initial commit
commit e02380c201

.gitignore (vendored) · 2 lines · Normal file
@@ -0,0 +1,2 @@
/faces
/models
Readme.md · 199 lines · Normal file
@@ -0,0 +1,199 @@

PRD: Fully offline generation of `face_gallery.db` (one centroid per person, Windows + Python + ONNX)

1. Background and Goals

• Process enrollment photos offline on a Windows machine and generate one 512-D face feature (centroid) per person, written to the SQLite database face_gallery.db.
• At runtime, ai_face_recog on the RK3588 reads this database with gallery.backend=sqlite and performs recognition (no changes to the existing threshold/config logic are required).

──────────────────────────────────────────

2. Scope (In/Out)

In Scope
• Dataset scanning (one directory per person)
• Face detection (with 5-point landmarks) + alignment to 112×112
• ArcFace/MobileFaceNet ONNX inference producing 512-D embeddings
• Aggregating each person's images into one centroid (L2-normalize, average, then L2-normalize again)
• Generating SQLite with two tables: person and embedding
• Generating a build report (statistics, failures, optional similarity spot checks)

Out of Scope
• Online enrollment / database write APIs
• RKNN inference / on-device tooling
• Incremental hot updates of the gallery (only the db file is produced)

──────────────────────────────────────────

3. Users and Usage

Target users
• Developers/testers (preparing the face gallery offline on Windows)

Usage (CLI)

The script build_gallery.py is provided. Example:

```bash
python build_gallery.py ^
  --dataset "D:\faces\dataset" ^
  --db_out "D:\faces\face_gallery.db" ^
  --det_model "D:\models\face_det.onnx" ^
  --recog_model "D:\models\mobilefacenet_arcface_bs1.onnx" ^
  --det_outputs_config "D:\models\det_outputs.json" ^
  --expected_dim 512 ^
  --max_imgs_per_person 10 ^
  --pick_face largest ^
  --min_face_size 80 ^
  --fail_on_empty true
```

(`--det_outputs_config` is required by the shipped script; it accepts a JSON string or a path to a JSON file describing the detector's output mapping. The path shown above is a placeholder.)

──────────────────────────────────────────

4. Input Data Specification

Directory layout

```
dataset/
  张三/
    001.jpg
    002.jpg
  李四/
    a.png
    b.png
```

• The folder name becomes person.name (UTF-8)
• 3–10 images per person are recommended (mixing is allowed: standard headshots + photos taken from the camera position)

Supported image formats
• jpg/jpeg/png/bmp (read via OpenCV)

──────────────────────────────────────────

5. Model and Preprocessing Specification (must match exactly)

Recognition model (fixed)
• Input: float32 [1,3,112,112]; output: float32 [1,512]
• The input image is the aligned 112×112 RGB crop
• Normalization: (x - 127.5) / 128.0, where x is uint8 in 0..255
• HWC → CHW → NCHW

Detection model (its source is out of scope here, but the script must satisfy the following)
• The det ONNX inference outputs must include at least: bbox + 5 landmarks
• The landmark order must be made explicit (fixed by the developer in config/code)

Alignment target points (112×112, must be fixed to exactly these values)
• (38.2946, 51.6963)
• (73.5318, 51.5014)
• (56.0252, 71.7366)
• (41.5493, 92.3655)
• (70.7299, 92.2041)

Alignment method:
• Estimate a similarity transform from the 5 points (similarity / partial affine) → warpAffine to 112×112 with bilinear interpolation

──────────────────────────────────────────
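A minimal sketch of the recognition-model preprocessing chain above (the helper name `preprocess_rgb112` is illustrative, not part of the shipped code; the input is assumed to be an already-aligned 112×112 RGB uint8 image):

```python
import numpy as np

def preprocess_rgb112(aligned_rgb: np.ndarray) -> np.ndarray:
    """Map an aligned 112x112 RGB uint8 image to a float32 [1,3,112,112] tensor."""
    assert aligned_rgb.shape == (112, 112, 3) and aligned_rgb.dtype == np.uint8
    x = (aligned_rgb.astype(np.float32) - 127.5) / 128.0  # normalize to roughly [-1, 1]
    x = np.transpose(x, (2, 0, 1))                        # HWC -> CHW
    return np.expand_dims(x, axis=0)                      # CHW -> NCHW

# A dummy black image produces the expected shape and value.
dummy = np.zeros((112, 112, 3), dtype=np.uint8)
out = preprocess_rgb112(dummy)
print(out.shape)               # (1, 3, 112, 112)
print(float(out[0, 0, 0, 0]))  # -0.99609375  == (0 - 127.5) / 128.0
```

The same transform must be reproduced bit-for-bit on the device side; any deviation (e.g. BGR order, or /255 scaling) silently degrades similarity scores.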
6. Feature Generation and Aggregation (one centroid per person)

For each valid image:
1. Detect → pick one face (default: the largest)
2. 5-point alignment → 112×112 RGB
3. Preprocess + ONNX inference → 512-D embedding
4. L2-normalize (per sample)

For each person:
• Take up to max_imgs_per_person valid sample embeddings (fewer is acceptable)
• Centroid computation:
  • centroid = mean(emb_i) (mean over the already-normalized embeddings)
  • centroid = L2_normalize(centroid)
• Write exactly 1 embedding row (the centroid)

──────────────────────────────────────────
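The aggregation rule above can be sketched in a few lines (names here are illustrative; the random embeddings stand in for real model outputs):

```python
import numpy as np

def l2n(v: np.ndarray) -> np.ndarray:
    """L2-normalize along the last axis, guarding against zero vectors."""
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / np.maximum(n, 1e-12)

rng = np.random.default_rng(0)
embs = rng.normal(size=(4, 512)).astype(np.float32)  # 4 per-image embeddings
embs = l2n(embs)                                     # step 4: per-sample L2 normalize
centroid = l2n(embs.mean(axis=0))                    # mean, then renormalize
print(centroid.shape)                                # (512,)  -- norm is ~1.0 by construction
```

Renormalizing after the mean matters: the mean of unit vectors has norm < 1, and downstream cosine matching assumes unit-norm gallery vectors.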
7. SQLite Database Specification (must stay compatible with the existing plugin)

File
• Output file name: face_gallery.db (path given by --db_out)

Table schema (minimum requirement)

```sql
CREATE TABLE IF NOT EXISTS person (
  id INTEGER PRIMARY KEY,
  name TEXT NOT NULL UNIQUE
);

CREATE TABLE IF NOT EXISTS embedding (
  person_id INTEGER NOT NULL,
  emb BLOB NOT NULL,
  FOREIGN KEY(person_id) REFERENCES person(id)
);

CREATE INDEX IF NOT EXISTS idx_embedding_person_id ON embedding(person_id);
```

Write rules
• person.name = folder name
• embedding.emb = raw float32 bytes of the centroid vector (BLOB)
• The BLOB length must be expected_dim * 4 (512 → 2048 bytes)
• Exactly 1 embedding row per person

──────────────────────────────────────────
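The BLOB convention can be verified with a short round-trip against an in-memory SQLite database (table names match the schema above; this is a sketch, not part of the shipped code):

```python
import sqlite3
import numpy as np

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE person (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE)")
conn.execute("CREATE TABLE embedding (person_id INTEGER NOT NULL, emb BLOB NOT NULL)")

vec = np.arange(512, dtype=np.float32) / 512.0
cur = conn.execute("INSERT INTO person(name) VALUES(?)", ("张三",))
conn.execute("INSERT INTO embedding(person_id, emb) VALUES(?, ?)",
             (cur.lastrowid, vec.tobytes()))

(length,) = conn.execute("SELECT length(emb) FROM embedding").fetchone()
print(length)  # 2048 == 512 * 4

# Reading back: np.frombuffer restores the float32 vector exactly.
(blob,) = conn.execute("SELECT emb FROM embedding").fetchone()
restored = np.frombuffer(blob, dtype=np.float32)
print(np.array_equal(restored, vec))  # True
```

Raw `tobytes()` bytes are little-endian on Windows/x86 and ARM alike, which is why the same BLOB can be consumed unchanged on the RK3588.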
8. Quality Control and Error Handling

Filtering rules (configurable)
• min_face_size: discard a face whose bbox width or height is below the threshold
• No face detected: record as a failed sample
• Multiple faces: pick the largest by default; configurable via --pick_face (largest/first/highest_score)

Fault tolerance (configurable)
• If every image of a person fails: that person is not written to the database and is listed in the report
• --fail_on_empty true: if the final person count or embedding count is 0, the script exits with a non-zero code

──────────────────────────────────────────

9. Output Report (required)

At the end of a run the script prints:
• Total persons, persons enrolled, total images, successful samples, failed samples
• The number of samples used per person
• Failure-reason statistics (no_face / small_face / align_fail / infer_fail, etc.)
• Database self-check:
  • COUNT(person), COUNT(embedding)
  • Randomly sample N rows and verify length(emb) == 2048

(Optional enhancement) Print each person's centroid L2 norm (should be close to 1.0)

──────────────────────────────────────────

10. Dependencies and Runtime Environment

• Windows 10/11
• Python 3.9+
• Packages:
  • numpy
  • opencv-python
  • onnxruntime
  • sqlite3: Python standard library

──────────────────────────────────────────

11. Acceptance Criteria (directly verifiable)

1. The generated face_gallery.db loads in the current ai_face_recog; the device log shows: gallery loaded: n=<enrolled count> dim=512
2. SQLite checks:
   • SELECT COUNT(*) FROM person; == enrolled count
   • SELECT COUNT(*) FROM embedding; == enrolled count
   • SELECT length(emb) FROM embedding LIMIT 5; all values are 2048
3. Field test: for at least 3 people, with 3 enrollment images plus 1 live image each, recognition outputs the correct name (keep the existing threshold config; no changes required)

──────────────────────────────────────────
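The SQLite checks in item 2 can be scripted; `check_gallery` is a hypothetical helper for acceptance testing, not part of build_gallery.py:

```python
import sqlite3

def check_gallery(db_path: str, expected_dim: int = 512) -> bool:
    """Return True iff the db passes the acceptance checks in section 11."""
    conn = sqlite3.connect(db_path)
    try:
        (persons,) = conn.execute("SELECT COUNT(*) FROM person").fetchone()
        (embs,) = conn.execute("SELECT COUNT(*) FROM embedding").fetchone()
        lengths = [r[0] for r in conn.execute("SELECT length(emb) FROM embedding LIMIT 5")]
        # One centroid per person, and every sampled BLOB has expected_dim * 4 bytes.
        return persons == embs and all(l == expected_dim * 4 for l in lengths)
    finally:
        conn.close()
```

Running this against the generated file before copying it to the device catches schema and dimension mistakes without involving the RK3588 at all.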
12. Device-Side Configuration Requirements (reminder for developers)

• gallery.backend="sqlite"
• gallery.path points at the copied face_gallery.db
• gallery.expected_dim=512
• threshold.margin can keep its current value (since each person has only 1 centroid, the "same person occupying top-2 and shrinking the margin" problem cannot occur)
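To illustrate the margin remark: with one unit-norm centroid per person, runtime matching reduces to a dot product, and top-1 and top-2 necessarily belong to different persons. A hedged sketch (not the actual ai_face_recog code; names are illustrative):

```python
import numpy as np

def match(probe: np.ndarray, names: list, centroids: np.ndarray):
    """probe: L2-normalized (D,); centroids: L2-normalized (N, D), one row per person."""
    sims = centroids @ probe                 # cosine similarity == dot product on unit vectors
    order = np.argsort(-sims)
    top1, top2 = order[0], order[1]
    margin = float(sims[top1] - sims[top2])  # top-2 are distinct persons: 1 centroid each
    return names[top1], float(sims[top1]), margin
```

If a person had several gallery rows instead, its own second-best row could occupy rank 2 and collapse the margin, which is exactly the failure mode the single-centroid design avoids.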
build_gallery.py · 167 lines · Normal file
@@ -0,0 +1,167 @@
from __future__ import annotations

import argparse
import os
import sys
from typing import Any, Dict, List, Tuple

import numpy as np

from gallery_builder.aggregate import compute_centroid
from gallery_builder.align import align_face_5pts
from gallery_builder.dataset import DatasetScanner
from gallery_builder.db import GalleryDbWriter, db_selfcheck
from gallery_builder.detector import OnnxFaceDetector, load_det_outputs_config
from gallery_builder.recognizer import OnnxFaceRecognizer
from gallery_builder.types import BuildReport, FailureReason


def _parse_bool(s: str) -> bool:
    if isinstance(s, bool):
        return s
    v = str(s).strip().lower()
    if v in ("1", "true", "t", "yes", "y", "on"):
        return True
    if v in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"invalid bool: {s}")


def _bbox_wh(b: np.ndarray) -> Tuple[float, float]:
    return float(b[2] - b[0]), float(b[3] - b[1])


def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
    try:
        import cv2
    except Exception as e:
        raise RuntimeError("opencv-python is required") from e

    det_cfg = load_det_outputs_config(args.det_outputs_config)
    if args.det_input_rgb is not None:
        det_cfg.setdefault("input", {})
        det_cfg["input"]["color"] = "RGB" if args.det_input_rgb else "BGR"

    scanner = DatasetScanner(args.dataset)
    report = BuildReport()
    persons_total, images_total = scanner.summary()
    report.total_person_dirs = persons_total
    report.total_images = images_total

    detector = OnnxFaceDetector(
        model_path=args.det_model,
        det_outputs_config=det_cfg,
        score_thresh=args.det_score_thresh,
        pick_face=args.pick_face,
    )
    recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)

    enrolled: List[Tuple[str, np.ndarray]] = []

    for person in scanner.iter_persons():
        used_embs: List[np.ndarray] = []
        if not person.image_paths:
            report.skipped_persons.append(person.name)
        for img_path in person.image_paths:
            if len(used_embs) >= args.max_imgs_per_person:
                break
            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                report.add_failure(person.name, img_path, FailureReason.read_fail, "cv2.imread returned None")
                continue

            try:
                det = detector.detect_one(img_bgr)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.det_fail, repr(e))
                continue

            if det is None:
                report.add_failure(person.name, img_path, FailureReason.no_face, "no detection")
                continue

            bw, bh = _bbox_wh(det.bbox_xyxy)
            if bw < args.min_face_size or bh < args.min_face_size:
                report.add_failure(
                    person.name,
                    img_path,
                    FailureReason.small_face,
                    f"bbox too small: w={bw:.1f} h={bh:.1f} < {args.min_face_size}",
                )
                continue

            try:
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                aligned = align_face_5pts(img_rgb, det.landmarks5, out_size=(112, 112))
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.align_fail, repr(e))
                continue

            try:
                emb = recognizer.embed_aligned_rgb112(aligned)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.infer_fail, repr(e))
                continue

            used_embs.append(emb)
            report.add_success(person.name)

        if used_embs:
            centroid = compute_centroid(np.stack(used_embs, axis=0))
            enrolled.append((person.name, centroid))
        else:
            report.skipped_persons.append(person.name)

    report.enrolled_persons = len(enrolled)

    if args.fail_on_empty and report.enrolled_persons == 0:
        return 2, report

    writer = GalleryDbWriter(args.db_out, expected_dim=args.expected_dim)
    writer.write(enrolled)

    selfcheck = db_selfcheck(args.db_out, expected_dim=args.expected_dim, sample_n=args.selfcheck_samples)
    from gallery_builder.report import format_report

    print(format_report(report, selfcheck, show_per_person=True))
    if args.print_centroid_norm:
        for name, c in enrolled:
            print(f"centroid_norm {name}: {float(np.linalg.norm(c)):.6f}")

    if args.fail_on_empty and (selfcheck["person_count"] == 0 or selfcheck["embedding_count"] == 0):
        return 3, report
    if selfcheck["person_count"] != selfcheck["embedding_count"]:
        return 4, report
    if not selfcheck["sample_lengths_ok"]:
        return 5, report
    return 0, report


def main(argv: List[str]) -> int:
    p = argparse.ArgumentParser(description="Build offline face_gallery.db (SQLite) from per-person image folders")
    p.add_argument("--dataset", required=True, help="dataset root: dataset/person_name/*.jpg")
    p.add_argument("--db_out", required=True, help="output sqlite db path")
    p.add_argument("--det_model", required=True, help="face detection onnx")
    p.add_argument("--recog_model", required=True, help="face recognition onnx")
    p.add_argument("--det_outputs_config", required=True, help="JSON string or path to JSON file for detection output mapping")

    p.add_argument("--expected_dim", type=int, default=512)
    p.add_argument("--max_imgs_per_person", type=int, default=10)
    p.add_argument("--pick_face", choices=["largest", "first", "highest_score"], default="largest")
    p.add_argument("--min_face_size", type=int, default=80)
    p.add_argument("--fail_on_empty", type=_parse_bool, default=False)

    p.add_argument("--det_score_thresh", type=float, default=0.0)
    p.add_argument("--det_input_rgb", type=_parse_bool, default=None, help="override det config input.color (true=RGB,false=BGR)")
    p.add_argument("--selfcheck_samples", type=int, default=5)
    p.add_argument("--print_centroid_norm", type=_parse_bool, default=False)

    args = p.parse_args(argv)
    os.makedirs(os.path.dirname(os.path.abspath(args.db_out)) or ".", exist_ok=True)

    code, _report = build_gallery(args)
    return code


if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))
gallery_builder/__init__.py · 15 lines · Normal file
@@ -0,0 +1,15 @@
__all__ = [
    "DatasetScanner",
    "OnnxFaceDetector",
    "OnnxFaceRecognizer",
    "align_face_5pts",
    "compute_centroid",
    "GalleryDbWriter",
]

from .dataset import DatasetScanner
from .detector import OnnxFaceDetector
from .recognizer import OnnxFaceRecognizer
from .align import align_face_5pts
from .aggregate import compute_centroid
from .db import GalleryDbWriter
BIN gallery_builder/__pycache__/__init__.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/aggregate.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/align.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/dataset.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/db.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/detector.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/recognizer.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/report.cpython-310.pyc · Normal file · Binary file not shown.
BIN gallery_builder/__pycache__/types.cpython-310.pyc · Normal file · Binary file not shown.
gallery_builder/aggregate.py · 20 lines · Normal file
@@ -0,0 +1,20 @@
from __future__ import annotations

import numpy as np


def l2_normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    v = np.asarray(v, dtype=np.float32)
    n = float(np.linalg.norm(v))
    if n < eps:
        return v * 0.0
    return v / n


def compute_centroid(embs: np.ndarray) -> np.ndarray:
    """embs: float32 shape (K,D), expected already L2-normalized per row."""
    embs = np.asarray(embs, dtype=np.float32)
    if embs.ndim != 2 or embs.shape[0] < 1:
        raise ValueError("embs must be (K,D) with K>=1")
    c = embs.mean(axis=0)
    return l2_normalize(c)
gallery_builder/align.py · 44 lines · Normal file
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import Tuple

import numpy as np


_DST_5PTS_112 = np.array(
    [
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.7299, 92.2041],
    ],
    dtype=np.float32,
)


def align_face_5pts(img_rgb: np.ndarray, landmarks5: np.ndarray, out_size: Tuple[int, int] = (112, 112)) -> np.ndarray:
    """Return aligned RGB image (H,W,3) uint8."""
    import cv2

    if img_rgb is None or img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
        raise ValueError("img_rgb must be HxWx3")
    lmk = np.asarray(landmarks5, dtype=np.float32)
    if lmk.shape != (5, 2):
        raise ValueError("landmarks5 must be shape (5,2)")

    dst = _DST_5PTS_112.copy()
    if out_size != (112, 112):
        sx = out_size[0] / 112.0
        sy = out_size[1] / 112.0
        dst[:, 0] *= sx
        dst[:, 1] *= sy

    M, inliers = cv2.estimateAffinePartial2D(lmk, dst, method=cv2.LMEDS)
    if M is None or not np.isfinite(M).all():
        raise ValueError("estimateAffinePartial2D failed")

    w, h = out_size
    aligned = cv2.warpAffine(img_rgb, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return aligned
gallery_builder/dataset.py · 48 lines · Normal file
@@ -0,0 +1,48 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Iterable, List, Tuple


_IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


@dataclass(frozen=True)
class PersonSamples:
    name: str
    image_paths: List[str]


class DatasetScanner:
    def __init__(self, dataset_root: str) -> None:
        self.dataset_root = os.path.abspath(dataset_root)

    def iter_persons(self) -> Iterable[PersonSamples]:
        if not os.path.isdir(self.dataset_root):
            raise FileNotFoundError(f"dataset root not found: {self.dataset_root}")

        with os.scandir(self.dataset_root) as it:
            person_dirs = [e for e in it if e.is_dir()]

        person_dirs.sort(key=lambda e: e.name)
        for e in person_dirs:
            name = e.name
            img_paths: List[str] = []
            with os.scandir(e.path) as it2:
                for f in it2:
                    if not f.is_file():
                        continue
                    ext = os.path.splitext(f.name)[1].lower()
                    if ext in _IMG_EXTS:
                        img_paths.append(os.path.abspath(f.path))
            img_paths.sort()
            yield PersonSamples(name=name, image_paths=img_paths)

    def summary(self) -> Tuple[int, int]:
        persons = 0
        images = 0
        for p in self.iter_persons():
            persons += 1
            images += len(p.image_paths)
        return persons, images
gallery_builder/db.py · 84 lines · Normal file
@@ -0,0 +1,84 @@
from __future__ import annotations

import os
import sqlite3
from typing import Iterable, Tuple

import numpy as np


_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS person (
  id INTEGER PRIMARY KEY,
  name TEXT NOT NULL UNIQUE
);

CREATE TABLE IF NOT EXISTS embedding (
  person_id INTEGER NOT NULL,
  emb BLOB NOT NULL,
  FOREIGN KEY(person_id) REFERENCES person(id)
);

CREATE INDEX IF NOT EXISTS idx_embedding_person_id ON embedding(person_id);
"""


class GalleryDbWriter:
    def __init__(self, db_path: str, expected_dim: int = 512) -> None:
        self.db_path = os.path.abspath(db_path)
        self.expected_dim = int(expected_dim)

    def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
        """items: (person_name, centroid_float32(D,)). Overwrites existing db file."""
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

        os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
        conn = sqlite3.connect(self.db_path)
        try:
            conn.executescript(_SCHEMA_SQL)
            cur = conn.cursor()
            cur.execute("BEGIN")
            for name, emb in items:
                emb = np.asarray(emb, dtype=np.float32).reshape(-1)
                if emb.size != self.expected_dim:
                    raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
                blob = emb.astype(np.float32).tobytes()
                if len(blob) != self.expected_dim * 4:
                    raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")

                cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
                person_id = cur.lastrowid
                cur.execute(
                    "INSERT INTO embedding(person_id, emb) VALUES(?, ?)",
                    (person_id, sqlite3.Binary(blob)),
                )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()


def db_selfcheck(db_path: str, expected_dim: int = 512, sample_n: int = 5) -> dict:
    db_path = os.path.abspath(db_path)
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) FROM person")
        person_cnt = int(cur.fetchone()[0])
        cur.execute("SELECT COUNT(*) FROM embedding")
        emb_cnt = int(cur.fetchone()[0])
        cur.execute("SELECT length(emb) FROM embedding ORDER BY RANDOM() LIMIT ?", (int(sample_n),))
        lengths = [int(r[0]) for r in cur.fetchall()]
        ok_len = all(l == expected_dim * 4 for l in lengths) if lengths else True
        return {
            "person_count": person_cnt,
            "embedding_count": emb_cnt,
            "sample_lengths": lengths,
            "sample_lengths_ok": ok_len,
        }
    finally:
        conn.close()
gallery_builder/detector.py · 514 lines · Normal file
@@ -0,0 +1,514 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .types import Detection
|
||||||
|
|
||||||
|
|
||||||
|
_CANONICAL_LMK_ORDER = ["left_eye", "right_eye", "nose", "left_mouth", "right_mouth"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _ResizeMeta:
|
||||||
|
orig_w: int
|
||||||
|
orig_h: int
|
||||||
|
in_w: int
|
||||||
|
in_h: int
|
||||||
|
mode: str # none|stretch|keep_ratio
|
||||||
|
scale_x: float
|
||||||
|
scale_y: float
|
||||||
|
pad_x: float
|
||||||
|
pad_y: float
|
||||||
|
|
||||||
|
|
||||||
|
def load_det_outputs_config(s: str) -> Dict[str, Any]:
|
||||||
|
"""Accept JSON string or a JSON file path."""
|
||||||
|
if s is None:
|
||||||
|
raise ValueError("det_outputs_config is required (Option B)")
|
||||||
|
p = os.path.abspath(s)
|
||||||
|
if os.path.isfile(p):
|
||||||
|
with open(p, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
return json.loads(s)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxFaceDetector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
det_outputs_config: Dict[str, Any],
|
||||||
|
score_thresh: float = 0.0,
|
||||||
|
pick_face: str = "largest",
|
||||||
|
) -> None:
|
||||||
|
self.model_path = model_path
|
||||||
|
self.cfg = det_outputs_config
|
||||||
|
self.score_thresh = float(score_thresh)
|
||||||
|
self.pick_face = pick_face
|
||||||
|
self._sess = None
|
||||||
|
self._input_name: Optional[str] = None
|
||||||
|
self._output_names: Optional[List[str]] = None
|
||||||
|
|
||||||
|
if pick_face not in ("largest", "first", "highest_score"):
|
||||||
|
raise ValueError("pick_face must be one of: largest|first|highest_score")
|
||||||
|
|
||||||
|
def _ensure_session(self) -> None:
|
||||||
|
if self._sess is not None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
raise RuntimeError("onnxruntime is required for detection") from e
|
||||||
|
|
||||||
|
self._sess = ort.InferenceSession(self.model_path, providers=["CPUExecutionProvider"])
|
||||||
|
self._input_name = self._sess.get_inputs()[0].name
|
||||||
|
self._output_names = [o.name for o in self._sess.get_outputs()]
|
||||||
|
|
||||||
|
def detect_one(self, img_bgr: np.ndarray) -> Optional[Detection]:
|
||||||
|
dets = self.detect_all(img_bgr)
|
||||||
|
if not dets:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.pick_face == "first":
|
||||||
|
return dets[0]
|
||||||
|
if self.pick_face == "highest_score":
|
||||||
|
return max(dets, key=lambda d: float(d.score))
|
||||||
|
return max(dets, key=lambda d: float((d.bbox_xyxy[2] - d.bbox_xyxy[0]) * (d.bbox_xyxy[3] - d.bbox_xyxy[1])))
|
||||||
|
|
||||||
|
def detect_all(self, img_bgr: np.ndarray) -> List[Detection]:
|
||||||
|
"""Return detections in original image coords."""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
self._ensure_session()
|
||||||
|
if img_bgr is None or img_bgr.ndim != 3 or img_bgr.shape[2] != 3:
|
||||||
|
raise ValueError("img_bgr must be HxWx3")
|
||||||
|
|
||||||
|
inp_cfg = self.cfg.get("input", {})
|
||||||
|
color = str(inp_cfg.get("color", "BGR")).upper()
|
||||||
|
img = img_bgr
|
||||||
|
if color == "RGB":
|
||||||
|
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
elif color != "BGR":
|
||||||
|
raise ValueError(f"unsupported input.color: {color}")
|
||||||
|
|
||||||
|
inp, meta = self._preprocess(img, inp_cfg)
|
||||||
|
outputs = self._sess.run(None, {self._input_name: inp})
|
||||||
|
out_by_name = {name: val for name, val in zip(self._output_names, outputs)}
|
||||||
|
dets = self._parse_outputs(out_by_name, meta)
|
||||||
|
if self.score_thresh > 0:
|
||||||
|
dets = [d for d in dets if float(d.score) >= self.score_thresh]
|
||||||
|
return dets
|
||||||
|
|
||||||
|
def _preprocess(self, img_hwc: np.ndarray, inp_cfg: Dict[str, Any]) -> Tuple[np.ndarray, _ResizeMeta]:
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
h, w = int(img_hwc.shape[0]), int(img_hwc.shape[1])
|
||||||
|
resize_cfg = inp_cfg.get("resize", None)
|
||||||
|
if not resize_cfg:
|
||||||
|
in_w, in_h = w, h
|
||||||
|
meta = _ResizeMeta(orig_w=w, orig_h=h, in_w=w, in_h=h, mode="none", scale_x=1.0, scale_y=1.0, pad_x=0.0, pad_y=0.0)
|
||||||
|
resized = img_hwc
|
||||||
|
else:
|
||||||
|
size = resize_cfg.get("size")
|
||||||
|
if not (isinstance(size, (list, tuple)) and len(size) == 2):
|
||||||
|
raise ValueError("input.resize.size must be [w,h]")
|
||||||
|
in_w, in_h = int(size[0]), int(size[1])
|
||||||
|
mode = str(resize_cfg.get("mode", "stretch")).lower()
|
||||||
|
if mode == "stretch":
|
||||||
|
resized = cv2.resize(img_hwc, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
meta = _ResizeMeta(
|
||||||
|
orig_w=w,
|
||||||
|
orig_h=h,
|
||||||
|
in_w=in_w,
|
||||||
|
in_h=in_h,
|
||||||
|
mode="stretch",
|
||||||
|
scale_x=in_w / float(w),
|
||||||
|
scale_y=in_h / float(h),
|
||||||
|
pad_x=0.0,
|
||||||
|
pad_y=0.0,
|
||||||
|
)
|
||||||
|
elif mode == "keep_ratio":
|
||||||
|
scale = min(in_w / float(w), in_h / float(h))
|
||||||
|
new_w = int(round(w * scale))
|
||||||
|
new_h = int(round(h * scale))
|
||||||
|
resized_small = cv2.resize(img_hwc, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
canvas = np.zeros((in_h, in_w, 3), dtype=resized_small.dtype)
|
||||||
|
pad_x = (in_w - new_w) // 2
|
||||||
|
pad_y = (in_h - new_h) // 2
|
||||||
|
canvas[pad_y : pad_y + new_h, pad_x : pad_x + new_w] = resized_small
|
||||||
|
resized = canvas
|
||||||
|
meta = _ResizeMeta(
|
||||||
|
orig_w=w,
|
||||||
|
orig_h=h,
|
||||||
|
in_w=in_w,
|
||||||
|
in_h=in_h,
|
||||||
|
mode="keep_ratio",
|
||||||
|
scale_x=scale,
|
||||||
|
scale_y=scale,
|
||||||
|
pad_x=float(pad_x),
|
||||||
|
pad_y=float(pad_y),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("input.resize.mode must be stretch|keep_ratio")
|
||||||
|
|
||||||
|
dtype = str(inp_cfg.get("dtype", "float32")).lower()
|
||||||
|
x = resized
|
||||||
|
if dtype in ("float32", "fp32"):
|
||||||
|
x = x.astype(np.float32)
|
||||||
|
elif dtype in ("uint8",):
|
||||||
|
x = x.astype(np.uint8)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unsupported input.dtype: {dtype}")
|
||||||
|
|
||||||
|
norm = inp_cfg.get("normalize", None)
|
||||||
|
if norm and dtype in ("float32", "fp32"):
|
||||||
|
scale = float(norm.get("scale", 1.0))
|
||||||
|
mean = norm.get("mean", [0.0, 0.0, 0.0])
|
||||||
|
std = norm.get("std", [1.0, 1.0, 1.0])
|
||||||
|
mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, 3)
|
||||||
|
std = np.asarray(std, dtype=np.float32).reshape(1, 1, 3)
|
||||||
|
x = x * scale
|
||||||
|
x = (x - mean) / std
|
||||||
|
|
||||||
|
layout = str(inp_cfg.get("layout", "NCHW")).upper()
|
||||||
|
if layout == "NHWC":
|
||||||
|
x = np.expand_dims(x, axis=0)
|
||||||
|
elif layout == "NCHW":
|
||||||
|
x = np.transpose(x, (2, 0, 1))
|
||||||
|
x = np.expand_dims(x, axis=0)
|
||||||
|
else:
|
||||||
|
raise ValueError("input.layout must be NCHW|NHWC")
|
||||||
|
|
||||||
|
return x, meta
|
||||||
|
|
||||||
|
    def _parse_outputs(self, out_by_name: Dict[str, Any], meta: _ResizeMeta) -> List[Detection]:
        decoder_cfg = self.cfg.get("decoder")
        if decoder_cfg and str(decoder_cfg.get("type", "")).lower() == "retinaface":
            return self._parse_outputs_retinaface(out_by_name, meta, decoder_cfg)

        out_cfg = self.cfg.get("outputs", {})
        bbox_cfg = out_cfg.get("bbox")
        lmk_cfg = out_cfg.get("landmarks")
        score_cfg = out_cfg.get("score")
        if not bbox_cfg or not lmk_cfg:
            raise ValueError("det_outputs_config must include either decoder.type=retinaface OR outputs.bbox+outputs.landmarks")

        bbox_arr = self._select_output(out_by_name, bbox_cfg)
        lmk_arr = self._select_output(out_by_name, lmk_cfg)
        score_arr = self._select_output(out_by_name, score_cfg) if score_cfg else None

        bbox = self._to_Nx4(bbox_arr)
        lmks = self._to_landmarks(lmk_arr, lmk_cfg)

        if score_arr is None:
            scores = np.ones((bbox.shape[0],), dtype=np.float32)
        else:
            scores = np.asarray(score_arr, dtype=np.float32).reshape(-1)
            if scores.size != bbox.shape[0]:
                raise ValueError(f"score count mismatch: scores={scores.size}, bbox={bbox.shape[0]}")

        bbox_format = str(bbox_cfg.get("format", "xyxy")).lower()
        bbox_norm = bool(bbox_cfg.get("normalized", False))
        lmk_norm = bool(lmk_cfg.get("normalized", False))

        if bbox_norm:
            bbox = bbox.copy()
            bbox[:, [0, 2]] *= float(meta.in_w)
            bbox[:, [1, 3]] *= float(meta.in_h)
        if bbox_format == "xywh":
            # convert (x, y, w, h) to (x1, y1, x2, y2)
            bbox = bbox.copy()
            bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
            bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
        elif bbox_format != "xyxy":
            raise ValueError("outputs.bbox.format must be xyxy|xywh")

        if lmk_norm:
            lmks = lmks.copy()
            lmks[:, :, 0] *= float(meta.in_w)
            lmks[:, :, 1] *= float(meta.in_h)

        bbox, lmks = self._map_to_original(bbox, lmks, meta)
        bbox = self._clip_bbox(bbox, meta.orig_w, meta.orig_h)

        dets: List[Detection] = []
        for i in range(bbox.shape[0]):
            dets.append(
                Detection(
                    bbox_xyxy=bbox[i].astype(np.float32),
                    landmarks5=lmks[i].astype(np.float32),
                    score=float(scores[i]),
                )
            )
        return dets

    def _parse_outputs_retinaface(self, out_by_name: Dict[str, Any], meta: _ResizeMeta, decoder_cfg: Dict[str, Any]) -> List[Detection]:
        out_cfg = self.cfg.get("outputs", {})
        loc_spec = out_cfg.get("loc") or out_cfg.get("bbox")
        conf_spec = out_cfg.get("conf") or out_cfg.get("score")
        lmk_spec = out_cfg.get("landmarks")
        if not loc_spec or not conf_spec or not lmk_spec:
            raise ValueError("retinaface decoder requires outputs.loc, outputs.conf, outputs.landmarks")

        loc = np.asarray(self._select_output(out_by_name, loc_spec), dtype=np.float32)
        conf = np.asarray(self._select_output(out_by_name, conf_spec), dtype=np.float32)
        landms = np.asarray(self._select_output(out_by_name, lmk_spec), dtype=np.float32)

        # squeeze a leading batch dim of 1 if present
        if loc.ndim == 3 and loc.shape[0] == 1:
            loc = loc[0]
        if conf.ndim == 3 and conf.shape[0] == 1:
            conf = conf[0]
        if landms.ndim == 3 and landms.shape[0] == 1:
            landms = landms[0]

        if loc.ndim != 2 or loc.shape[1] != 4:
            raise ValueError(f"retinaface loc must be [N,4] (or [1,N,4]); got {loc.shape}")
        if conf.ndim != 2 or conf.shape[1] != 2:
            raise ValueError(f"retinaface conf must be [N,2] (or [1,N,2]); got {conf.shape}")
        if landms.ndim != 2 or landms.shape[1] != 10:
            raise ValueError(f"retinaface landmarks must be [N,10] (or [1,N,10]); got {landms.shape}")

        steps = decoder_cfg.get("steps", [8, 16, 32])
        min_sizes = decoder_cfg.get("min_sizes", [[16, 32], [64, 128], [256, 512]])
        variances = decoder_cfg.get("variances", [0.1, 0.2])
        score_index = int(decoder_cfg.get("score_index", 1))
        nms_iou = float(decoder_cfg.get("nms_iou_thresh", 0.4))
        top_k = int(decoder_cfg.get("top_k", 5000))
        keep_top_k = int(decoder_cfg.get("keep_top_k", 750))
        prob_mode = str(decoder_cfg.get("conf_mode", "auto")).lower()  # auto|prob|logits

        priors = self._retinaface_priors(meta.in_w, meta.in_h, steps=steps, min_sizes=min_sizes)
        if priors.shape[0] != loc.shape[0]:
            raise ValueError(f"prior count mismatch: priors={priors.shape[0]} loc={loc.shape[0]}")

        scores = self._retinaface_scores(conf, score_index=score_index, mode=prob_mode)

        # score filtering
        keep = np.where(scores >= float(self.score_thresh))[0] if self.score_thresh > 0 else np.arange(scores.size)
        if keep.size == 0:
            return []

        if top_k > 0 and keep.size > top_k:
            idx = np.argsort(scores[keep])[::-1][:top_k]
            keep = keep[idx]

        pri = priors[keep]
        loc_k = loc[keep]
        lmk_k = landms[keep].reshape(-1, 5, 2)
        sc_k = scores[keep]

        bbox_in, lmks_in = self._retinaface_decode(pri, loc_k, lmk_k, meta.in_w, meta.in_h, variances=variances)

        order = np.argsort(sc_k)[::-1]
        bbox_in = bbox_in[order]
        lmks_in = lmks_in[order]
        sc_k = sc_k[order]

        keep_nms = self._nms_xyxy(bbox_in, sc_k, iou_thresh=nms_iou)
        if keep_top_k > 0:
            keep_nms = keep_nms[:keep_top_k]

        bbox_in = bbox_in[keep_nms]
        lmks_in = lmks_in[keep_nms]
        sc_k = sc_k[keep_nms]

        bbox, lmks = self._map_to_original(bbox_in, lmks_in, meta)
        bbox = self._clip_bbox(bbox, meta.orig_w, meta.orig_h)

        dets: List[Detection] = []
        for i in range(bbox.shape[0]):
            dets.append(Detection(bbox_xyxy=bbox[i].astype(np.float32), landmarks5=lmks[i].astype(np.float32), score=float(sc_k[i])))
        return dets

    def _retinaface_priors(self, in_w: int, in_h: int, steps: Sequence[int], min_sizes: Sequence[Sequence[int]]) -> np.ndarray:
        from itertools import product

        priors: List[List[float]] = []
        for k, step in enumerate(steps):
            fm_h = int(np.ceil(in_h / float(step)))
            fm_w = int(np.ceil(in_w / float(step)))
            for i, j in product(range(fm_h), range(fm_w)):
                for ms in min_sizes[k]:
                    s_kx = ms / float(in_w)
                    s_ky = ms / float(in_h)
                    cx = (j + 0.5) * step / float(in_w)
                    cy = (i + 0.5) * step / float(in_h)
                    priors.append([cx, cy, s_kx, s_ky])
        return np.asarray(priors, dtype=np.float32)

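As a sanity check for the "prior count mismatch" error raised earlier, the expected prior count can be computed from the grid math alone. A standalone sketch, assuming the default steps and min_sizes:

```python
import math

def retinaface_prior_count(in_w, in_h, steps=(8, 16, 32),
                           min_sizes=((16, 32), (64, 128), (256, 512))):
    # one prior per (feature-map cell, min_size); cell counts use ceil division
    total = 0
    for step, sizes in zip(steps, min_sizes):
        total += math.ceil(in_h / step) * math.ceil(in_w / step) * len(sizes)
    return total

# e.g. a 640x640 input yields (80*80 + 40*40 + 20*20) * 2 = 16800 priors
```

If the model's loc output has a different first dimension, the configured steps/min_sizes do not match the network.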
    def _retinaface_scores(self, conf: np.ndarray, score_index: int, mode: str) -> np.ndarray:
        x = conf.astype(np.float32)
        if mode == "prob":
            prob = x
        elif mode == "logits":
            prob = self._softmax(x, axis=1)
        else:  # auto: treat as probabilities only if rows already look normalized
            row_sum = x.sum(axis=1)
            looks_prob = (x.min() >= 0.0) and (x.max() <= 1.0) and (np.mean(np.abs(row_sum - 1.0)) < 1e-2)
            prob = x if looks_prob else self._softmax(x, axis=1)

        if score_index < 0 or score_index >= prob.shape[1]:
            raise ValueError(f"score_index out of range: {score_index}")
        return prob[:, score_index]

    def _retinaface_decode(
        self,
        priors: np.ndarray,
        loc: np.ndarray,
        landms: np.ndarray,
        in_w: int,
        in_h: int,
        variances: Sequence[float],
    ) -> Tuple[np.ndarray, np.ndarray]:
        v0 = float(variances[0])
        v1 = float(variances[1])

        pri_c = priors[:, 0:2]
        pri_s = priors[:, 2:4]

        # box centers are offset from prior centers; sizes scale exponentially
        boxes_c = pri_c + loc[:, 0:2] * v0 * pri_s
        boxes_s = pri_s * np.exp(loc[:, 2:4] * v1)
        boxes = np.concatenate([boxes_c - boxes_s / 2.0, boxes_c + boxes_s / 2.0], axis=1)
        boxes[:, [0, 2]] *= float(in_w)
        boxes[:, [1, 3]] *= float(in_h)

        # landmarks are per-point offsets from the prior center
        lm = pri_c[:, None, :] + landms * v0 * pri_s[:, None, :]
        lm[:, :, 0] *= float(in_w)
        lm[:, :, 1] *= float(in_h)
        return boxes.astype(np.float32), lm.astype(np.float32)

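A zero loc vector should decode back to the prior box itself, which makes the center/size math easy to verify in isolation. A minimal standalone sketch with a hypothetical prior:

```python
import numpy as np

def decode_boxes(priors, loc, in_w, in_h, variances=(0.1, 0.2)):
    # priors: [N,4] as (cx, cy, sx, sy) in [0,1]; loc: [N,4] regression offsets
    c = priors[:, 0:2] + loc[:, 0:2] * variances[0] * priors[:, 2:4]
    s = priors[:, 2:4] * np.exp(loc[:, 2:4] * variances[1])
    boxes = np.concatenate([c - s / 2.0, c + s / 2.0], axis=1)
    boxes[:, [0, 2]] *= float(in_w)
    boxes[:, [1, 3]] *= float(in_h)
    return boxes

prior = np.array([[0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
box = decode_boxes(prior, np.zeros((1, 4), dtype=np.float32), 100, 100)
# zero offsets -> the prior itself: center (50, 50), size 20x20 -> [40, 40, 60, 60]
```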
    def _softmax(self, x: np.ndarray, axis: int = -1) -> np.ndarray:
        x = x.astype(np.float32)
        m = np.max(x, axis=axis, keepdims=True)  # subtract the max for numerical stability
        e = np.exp(x - m)
        s = np.sum(e, axis=axis, keepdims=True)
        return e / s

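The max-subtraction is what keeps the exponentials finite for large logits; a naive `np.exp(x)` would overflow. A standalone illustration of the same trick:

```python
import numpy as np

def softmax(x, axis=-1):
    m = np.max(x, axis=axis, keepdims=True)  # shift so the largest logit becomes 0
    e = np.exp(x - m)
    return e / np.sum(e, axis=axis, keepdims=True)

p = softmax(np.array([[1000.0, 999.0]]), axis=1)  # huge logits, still finite
```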
    def _nms_xyxy(self, boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> List[int]:
        b = boxes.astype(np.float32)
        s = scores.astype(np.float32)
        x1 = b[:, 0]
        y1 = b[:, 1]
        x2 = b[:, 2]
        y2 = b[:, 3]
        areas = np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)

        order = np.argsort(s)[::-1]
        keep: List[int] = []
        while order.size > 0:
            i = int(order[0])
            keep.append(i)
            if order.size == 1:
                break
            rest = order[1:]

            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])

            w = np.maximum(0.0, xx2 - xx1)
            h = np.maximum(0.0, yy2 - yy1)
            inter = w * h
            union = areas[i] + areas[rest] - inter
            iou = np.where(union > 0, inter / union, 0.0)

            inds = np.where(iou <= float(iou_thresh))[0]
            order = rest[inds]
        return keep

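On two heavily overlapping boxes plus one distant box, the greedy loop keeps the higher-scoring member of the overlap and the distant box. A self-contained rerun of the same algorithm on toy data:

```python
import numpy as np

def nms_xyxy(boxes, scores, iou_thresh=0.4):
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        i = int(order[0])
        keep.append(i)
        rest = order[1:]
        # intersection of box i with every remaining box
        inter = (np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]))
                 * np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest])))
        union = areas[i] + areas[rest] - inter
        iou = np.where(union > 0, inter / union, 0.0)
        order = rest[iou <= iou_thresh]
    return keep

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
kept = nms_xyxy(boxes, scores)  # box 1 overlaps box 0 at IoU ~0.68 and is dropped
```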
    def _select_output(self, out_by_name: Dict[str, Any], spec: Optional[Dict[str, Any]]) -> Any:
        if spec is None:
            return None
        if "name" in spec:
            name = spec["name"]
            if name not in out_by_name:
                raise KeyError(f"output not found: {name}")
            return out_by_name[name]
        if "index" in spec:
            idx = int(spec["index"])
            keys = list(out_by_name.keys())
            if idx < 0 or idx >= len(keys):
                raise IndexError(f"output index out of range: {idx}")
            return out_by_name[keys[idx]]
        raise ValueError("output spec must include name or index")

    def _to_Nx4(self, arr: Any) -> np.ndarray:
        x = np.asarray(arr)
        if x.ndim == 3 and x.shape[0] == 1:
            x = x[0]
        if x.ndim != 2 or x.shape[1] != 4:
            raise ValueError(f"bbox output must be [N,4] (or [1,N,4]); got {x.shape}")
        return x.astype(np.float32)

    def _to_landmarks(self, arr: Any, lmk_cfg: Dict[str, Any]) -> np.ndarray:
        x = np.asarray(arr)
        if x.ndim == 4 and x.shape[0] == 1:
            x = x[0]
        if x.ndim == 3 and x.shape[0] == 1:
            x = x[0]

        layout = str(lmk_cfg.get("layout", "flat10")).lower()
        if layout == "flat10":
            if x.ndim != 2 or x.shape[1] != 10:
                raise ValueError(f"landmarks flat10 must be [N,10]; got {x.shape}")
            x = x.reshape(-1, 5, 2)
        elif layout in ("5x2", "five_two"):
            if x.ndim != 3 or x.shape[1:] != (5, 2):
                raise ValueError(f"landmarks 5x2 must be [N,5,2]; got {x.shape}")
        else:
            raise ValueError("outputs.landmarks.layout must be flat10|5x2")

        order = lmk_cfg.get("order")
        if order:
            x = self._reorder_landmarks(x, order)
        return x.astype(np.float32)

    def _reorder_landmarks(self, lmks: np.ndarray, order: Sequence[str]) -> np.ndarray:
        order = [str(o) for o in order]
        if sorted(order) != sorted(_CANONICAL_LMK_ORDER):
            raise ValueError(f"outputs.landmarks.order must be a permutation of {_CANONICAL_LMK_ORDER}")
        # map each canonical point name to its position in the model's order
        idx = {name: i for i, name in enumerate(order)}
        take = [idx[name] for name in _CANONICAL_LMK_ORDER]
        return lmks[:, take, :]

    def _map_to_original(self, bbox_xyxy_in: np.ndarray, lmks_in: np.ndarray, meta: _ResizeMeta) -> Tuple[np.ndarray, np.ndarray]:
        if meta.mode == "none":
            return bbox_xyxy_in, lmks_in
        if meta.mode == "stretch":
            sx = meta.scale_x
            sy = meta.scale_y
            bbox = bbox_xyxy_in.copy()
            bbox[:, [0, 2]] /= sx
            bbox[:, [1, 3]] /= sy
            lmks = lmks_in.copy()
            lmks[:, :, 0] /= sx
            lmks[:, :, 1] /= sy
            return bbox, lmks
        if meta.mode == "keep_ratio":
            # undo the letterbox padding first, then the uniform scale
            s = meta.scale_x
            px = meta.pad_x
            py = meta.pad_y
            bbox = bbox_xyxy_in.copy()
            bbox[:, [0, 2]] = (bbox[:, [0, 2]] - px) / s
            bbox[:, [1, 3]] = (bbox[:, [1, 3]] - py) / s
            lmks = lmks_in.copy()
            lmks[:, :, 0] = (lmks[:, :, 0] - px) / s
            lmks[:, :, 1] = (lmks[:, :, 1] - py) / s
            return bbox, lmks
        raise ValueError(f"unknown resize mode: {meta.mode}")

    def _clip_bbox(self, bbox: np.ndarray, w: int, h: int) -> np.ndarray:
        b = bbox.copy()
        b[:, 0] = np.clip(b[:, 0], 0, w - 1)
        b[:, 1] = np.clip(b[:, 1], 0, h - 1)
        b[:, 2] = np.clip(b[:, 2], 0, w - 1)
        b[:, 3] = np.clip(b[:, 3], 0, h - 1)
        return b
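For keep_ratio (letterbox) inputs, the inverse mapping subtracts the padding and divides by the uniform scale. A worked example with hypothetical meta values, assuming a 1920x1080 frame letterboxed into a 640x640 model input:

```python
# scale = 640 / 1920 -> scaled content is 640x360, so pad_y = (640 - 360) / 2 = 140
s, pad_x, pad_y = 640.0 / 1920.0, 0.0, 140.0

x_in, y_in = 320.0, 320.0        # center of the model input
x_orig = (x_in - pad_x) / s      # ~960.0, center of the 1920-wide frame
y_orig = (y_in - pad_y) / s      # ~540.0, center of the 1080-high frame
```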
49
gallery_builder/recognizer.py
Normal file
@ -0,0 +1,49 @@
from __future__ import annotations

from typing import Optional

import numpy as np

from .aggregate import l2_normalize


class OnnxFaceRecognizer:
    def __init__(self, model_path: str, expected_dim: int = 512) -> None:
        self.model_path = model_path
        self.expected_dim = int(expected_dim)
        self._sess = None
        self._input_name: Optional[str] = None
        self._output_name: Optional[str] = None

    def _ensure_session(self) -> None:
        if self._sess is not None:
            return
        try:
            import onnxruntime as ort
        except Exception as e:  # pragma: no cover
            raise RuntimeError("onnxruntime is required for recognition") from e

        self._sess = ort.InferenceSession(self.model_path, providers=["CPUExecutionProvider"])
        self._input_name = self._sess.get_inputs()[0].name
        self._output_name = self._sess.get_outputs()[0].name

    def embed_aligned_rgb112(self, aligned_rgb112: np.ndarray) -> np.ndarray:
        """aligned_rgb112: uint8 RGB 112x112x3. Return float32 (D,) L2-normalized."""
        self._ensure_session()
        x = np.asarray(aligned_rgb112)
        if x.ndim != 3 or x.shape != (112, 112, 3):
            raise ValueError("aligned image must be 112x112x3 RGB")
        if x.dtype != np.uint8:
            x = x.astype(np.uint8)

        # ArcFace-style normalization: map [0, 255] to roughly [-1, 1)
        x = x.astype(np.float32)
        x = (x - 127.5) / 128.0
        x = np.transpose(x, (2, 0, 1))  # CHW
        x = np.expand_dims(x, axis=0)  # NCHW

        out = self._sess.run([self._output_name], {self._input_name: x})[0]
        out = np.asarray(out, dtype=np.float32).reshape(-1)
        if out.size != self.expected_dim:
            raise ValueError(f"unexpected embedding dim: got {out.size}, expected {self.expected_dim}")
        return l2_normalize(out)
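Per the PRD, the per-person centroid is built from these embeddings by L2-normalizing each, averaging, and normalizing again. A minimal standalone sketch, with a hypothetical `l2_normalize` standing in for what `gallery_builder.aggregate` is assumed to provide:

```python
import numpy as np

def l2_normalize(v, eps=1e-12):
    # hypothetical stand-in for gallery_builder.aggregate.l2_normalize
    return v / max(float(np.linalg.norm(v)), eps)

embs = [l2_normalize(np.array([1.0, 0.0], dtype=np.float32)),
        l2_normalize(np.array([0.0, 1.0], dtype=np.float32))]
centroid = l2_normalize(np.mean(embs, axis=0))  # unit length again after averaging
```

The final normalization matters: the mean of unit vectors is generally shorter than unit length, and cosine matching at runtime expects unit-norm gallery entries.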
45
gallery_builder/report.py
Normal file
@ -0,0 +1,45 @@
from __future__ import annotations

import json
from typing import Any, Dict

from .types import BuildReport


def format_report(report: BuildReport, db_selfcheck: Dict[str, Any], show_per_person: bool = True) -> str:
    lines = []
    lines.append("==== build report ====")
    lines.append(f"total_person_dirs: {report.total_person_dirs}")
    lines.append(f"enrolled_persons: {report.enrolled_persons}")
    lines.append(f"total_images: {report.total_images}")
    lines.append(f"processed_images: {report.processed_images}")
    lines.append(f"ok_images: {report.ok_images}")
    lines.append(f"failed_images: {report.failed_images}")
    lines.append("\n-- failure reasons --")
    if report.failure_reasons:
        for k in sorted(report.failure_reasons.keys()):
            lines.append(f"{k}: {report.failure_reasons[k]}")
    else:
        lines.append("(none)")

    if show_per_person:
        lines.append("\n-- per person used samples --")
        for name in sorted(report.per_person_used.keys()):
            lines.append(f"{name}: {report.per_person_used[name]}")

    if report.skipped_persons:
        lines.append("\n-- persons skipped (no valid embeddings) --")
        for name in sorted(set(report.skipped_persons)):
            fails = report.per_person_failed.get(name, [])
            lines.append(f"{name}: failed_images={len(fails)}")
            if fails:
                reason_cnt: Dict[str, int] = {}
                for _path, r, _detail in fails:
                    reason_cnt[str(r)] = reason_cnt.get(str(r), 0) + 1
                lines.append("  reasons: " + ", ".join(f"{k}={reason_cnt[k]}" for k in sorted(reason_cnt.keys())))
                for path, r, detail in fails[:3]:
                    lines.append(f"  sample: {r} {path} | {detail}")

    lines.append("\n-- db selfcheck --")
    lines.append(json.dumps(db_selfcheck, ensure_ascii=False, indent=2))
    return "\n".join(lines)
67
gallery_builder/types.py
Normal file
@ -0,0 +1,67 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple

import numpy as np


class FailureReason(str, Enum):
    no_face = "no_face"
    small_face = "small_face"
    align_fail = "align_fail"
    infer_fail = "infer_fail"
    read_fail = "read_fail"
    det_fail = "det_fail"


@dataclass(frozen=True)
class Detection:
    bbox_xyxy: np.ndarray  # float32 shape (4,) in original image coords
    landmarks5: np.ndarray  # float32 shape (5,2) in original image coords
    score: float


@dataclass(frozen=True)
class ImageResult:
    ok: bool
    reason: Optional[FailureReason]
    detail: str
    emb: Optional[np.ndarray] = None  # float32 (D,) L2-normalized
    det: Optional[Detection] = None


@dataclass
class BuildReport:
    total_person_dirs: int = 0
    enrolled_persons: int = 0
    total_images: int = 0
    processed_images: int = 0
    ok_images: int = 0
    failed_images: int = 0
    # mutable defaults use default_factory so each report gets its own containers
    per_person_used: Dict[str, int] = field(default_factory=dict)
    per_person_failed: Dict[str, List[Tuple[str, FailureReason, str]]] = field(default_factory=dict)
    failure_reasons: Dict[str, int] = field(default_factory=dict)
    skipped_persons: List[str] = field(default_factory=list)

    def add_failure(self, person: str, img_path: str, reason: FailureReason, detail: str) -> None:
        self.processed_images += 1
        self.failed_images += 1
        self.failure_reasons[str(reason)] = self.failure_reasons.get(str(reason), 0) + 1
        self.per_person_failed.setdefault(person, []).append((img_path, reason, detail))

    def add_success(self, person: str) -> None:
        self.processed_images += 1
        self.ok_images += 1
        self.per_person_used[person] = self.per_person_used.get(person, 0) + 1
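A side note on the report's dict/list fields: a dataclass field must not use a plain mutable object as its default, or all instances would share one container. A minimal illustration of the `default_factory` pattern:

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class Tally:
    counts: Dict[str, int] = field(default_factory=dict)  # fresh dict per instance

a, b = Tally(), Tally()
a.counts["no_face"] = 1
# b.counts is untouched: the factory ran separately for each instance
```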
187
inspect_onnx.py
Normal file
@ -0,0 +1,187 @@
from __future__ import annotations

import argparse
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def _np_dtype_from_ort(ort_type: str):
    t = ort_type.lower()
    if "float" in t:
        return np.float32
    if "uint8" in t:
        return np.uint8
    if "int8" in t:
        return np.int8
    if "int32" in t:
        return np.int32
    if "int64" in t:
        return np.int64
    return np.float32


def _print_io(sess) -> None:
    print("=== model info ===")
    for inp in sess.get_inputs():
        print(f"Input: {inp.name}, shape: {inp.shape}, dtype: {inp.type}")
    for out in sess.get_outputs():
        print(f"Output: {out.name}, shape: {out.shape}, dtype: {out.type}")


def _describe_array(name: str, arr: np.ndarray, max_show: int = 8) -> None:
    a = np.asarray(arr)
    flat = a.reshape(-1)
    if flat.size:
        mn = float(flat.min())
        mx = float(flat.max())
    else:
        mn = float("nan")
        mx = float("nan")
    print(f"- {name}: dtype={a.dtype}, shape={list(a.shape)}, min={mn:.6g}, max={mx:.6g}")
    if flat.size:
        head = flat[:max_show]
        print(f"  head[{len(head)}]: {np.array2string(head, precision=6, separator=', ')}")


def _guess_det_outputs(out_by_name: Dict[str, Any]) -> None:
    print("\n=== output candidate analysis (for writing det_outputs_config) ===")
    for name, a in out_by_name.items():
        x = np.asarray(a)
        shp = list(x.shape)

        is_bbox = False
        is_lmk10 = False
        is_lmk5x2 = False
        is_score = False

        if (x.ndim == 2 and shp[1] == 4) or (x.ndim == 3 and shp[-1] == 4):
            is_bbox = True
        if (x.ndim == 2 and shp[1] == 10) or (x.ndim == 3 and shp[-1] == 10):
            is_lmk10 = True
        if x.ndim >= 3 and shp[-2:] == [5, 2]:
            is_lmk5x2 = True
        if (x.ndim == 2 and shp[1] == 1) or (x.ndim == 3 and shp[-1] == 1) or (x.ndim == 1):
            is_score = True

        tags: List[str] = []
        if is_bbox:
            tags.append("bbox?")
        if is_lmk10:
            tags.append("lmk10?")
        if is_lmk5x2:
            tags.append("lmk5x2?")
        if is_score:
            tags.append("score?")
        tag_str = (" " + ",".join(tags)) if tags else ""
        print(f"* {name}: shape={shp}{tag_str}")


def _make_image_input(
    inp_meta,
    image_path: str,
    color: str,
    size_wh: Optional[Tuple[int, int]],
    norm: str,
) -> np.ndarray:
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {image_path}")

    color = color.upper()
    if color == "RGB":
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    elif color != "BGR":
        raise ValueError("--color must be BGR or RGB")

    shp = inp_meta.shape
    if len(shp) != 4:
        raise ValueError(f"unsupported input rank={len(shp)}: shape={shp}")

    # infer layout from which axis holds the 3 channels; default to NCHW
    if shp[1] == 3:
        layout = "NCHW"
        h = shp[2] if isinstance(shp[2], int) else None
        w = shp[3] if isinstance(shp[3], int) else None
    elif shp[3] == 3:
        layout = "NHWC"
        h = shp[1] if isinstance(shp[1], int) else None
        w = shp[2] if isinstance(shp[2], int) else None
    else:
        layout = "NCHW"
        h = shp[2] if isinstance(shp[2], int) else None
        w = shp[3] if isinstance(shp[3], int) else None

    if size_wh is not None:
        w, h = int(size_wh[0]), int(size_wh[1])
    else:
        if h is None or w is None:
            raise ValueError(f"dynamic H/W input; please pass --size W H. got shape={shp}")
        w, h = int(w), int(h)

    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    np_dtype = _np_dtype_from_ort(inp_meta.type)
    x = img.astype(np_dtype)

    if np_dtype == np.float32:
        if norm == "none":
            pass
        elif norm == "div255":
            x = x / 255.0
        elif norm == "arcface_128":
            x = (x - 127.5) / 128.0
        elif norm == "arcface_1275":
            x = (x - 127.5) / 127.5
        else:
            raise ValueError("--norm must be none|div255|arcface_128|arcface_1275")

    if layout == "NCHW":
        x = np.transpose(x, (2, 0, 1))
    x = np.expand_dims(x, axis=0)

    return x


def main() -> int:
    ap = argparse.ArgumentParser(description="Inspect ONNX model IO + run one inference and print outputs")
    ap.add_argument(
        "--model",
        default="mobilefacenet_arcface_bs1.onnx",
        help="onnx path (default: mobilefacenet_arcface_bs1.onnx)",
    )
    ap.add_argument("--image", help="optional image path for one dry-run inference")
    ap.add_argument("--size", nargs=2, type=int, metavar=("W", "H"), help="override input resize W H for dynamic shapes")
    ap.add_argument("--color", default="RGB", choices=["BGR", "RGB"], help="input color order (default RGB)")
    ap.add_argument("--norm", default="arcface_128", choices=["none", "div255", "arcface_128", "arcface_1275"], help="float input normalization")
    ap.add_argument("--providers", default="CPUExecutionProvider")
    args = ap.parse_args()

    import onnxruntime as ort

    sess = ort.InferenceSession(args.model, providers=[args.providers])
    _print_io(sess)

    if not args.image:
        return 0

    inp0 = sess.get_inputs()[0]
    size_wh = (int(args.size[0]), int(args.size[1])) if args.size else None
    x = _make_image_input(inp0, args.image, args.color, size_wh=size_wh, norm=args.norm)

    print("\n=== single-inference outputs (output names / shapes / value ranges) ===")
    outputs = sess.run(None, {inp0.name: x})
    out_names = [o.name for o in sess.get_outputs()]
    out_by_name = {n: v for n, v in zip(out_names, outputs)}
    for n in out_names:
        _describe_array(n, out_by_name[n])

    _guess_det_outputs(out_by_name)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())