From fd7e004d4778c0ee3ac4d2065771c39d2e652d4f Mon Sep 17 00:00:00 2001 From: tian <11429339@qq.com> Date: Thu, 16 Apr 2026 10:41:26 +0800 Subject: [PATCH] Store multiple embeddings per person --- Readme.md | 24 +++++++++------------ build_gallery.py | 6 ++---- gallery_builder/db.py | 46 ++++++++++++++++++++++++++++------------- tests/test_db_writer.py | 40 +++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 32 deletions(-) create mode 100644 tests/test_db_writer.py diff --git a/Readme.md b/Readme.md index d14200d..a70bca3 100644 --- a/Readme.md +++ b/Readme.md @@ -1,8 +1,8 @@ -PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python+ONNX) +PRD:全离线生成 `face_gallery.db`(每人多条 embedding,Windows+Python+ONNX) 1. 背景与目标 - • 在 Windows 电脑上离线处理“注册照片”,为每个人生成 1 条 512D - 人脸特征(centroid),写入 SQLite 数据库 face_gallery.db。 + • 在 Windows 电脑上离线处理“注册照片”,为每个人生成多条 512D + 人脸特征(每张有效照片 1 条 embedding),写入 SQLite 数据库 face_gallery.db。 • RK3588 运行时 ai_face_recog 以 gallery.backend=sqlite 读取该库,实现识别(无需改线上阈值/配置逻辑)。 @@ -14,7 +14,6 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python • 数据集扫描(按人目录) • 人脸检测(含 5 点关键点)+ 对齐到 112×112 • ArcFace/MobileFaceNet ONNX 推理生成 512D embedding - • 每人多图聚合为 centroid(L2 normalize 后平均,再 L2 normalize) • 生成 SQLite:person、embedding 两表 • 生成构建报告(统计、异常、可选相似度抽检) @@ -94,7 +93,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python ────────────────────────────────────────── - 6. 特征生成与聚合(每人 1 条 centroid) + 6. 特征生成与入库(每人多条 embedding) 对每张有效图片: 1. 检测 → 选择人脸(默认:最大脸) @@ -104,10 +103,8 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python 对每个人: • 取最多 max_imgs_per_person 张有效样本 embedding(不足也可) - • centroid 计算: - • centroid = mean(emb_i)(对已归一化 emb 做均值) - • centroid = L2_normalize(centroid) - • 只写入 1 条 embedding(centroid) + • 每张有效样本保留 1 条 embedding + • 同一 person 可写入多条 embedding ────────────────────────────────────────── @@ -134,9 +131,9 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python 写入规则 • person.name = 文件夹名 - • embedding.emb = centroid 向量的 float32 原始字节(BLOB) + • embedding.emb = 单张样本 embedding 的 float32 原始字节(BLOB) • 长度必须为 expected_dim * 4(512→2048 bytes) - • 每个 person 仅 1 行 embedding + • 每个 person 可对应多行 embedding ────────────────────────────────────────── @@ -184,7 +181,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python loaded: n=<入库人数> dim=512 2. SQLite 校验: • SELECT COUNT(*) FROM person; == 入库人数 - • SELECT COUNT(*) FROM embedding; == 入库人数 + • SELECT COUNT(*) FROM embedding; >= 入库人数 • SELECT length(emb) FROM embedding LIMIT 5; 全为 2048 3. 实测:对至少 3 人,每人 3 张注册图 + 1 条现场图,识别能输出对应姓名(阈值用现有配置不强制修改) @@ -195,5 +192,4 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python • gallery.backend="sqlite" • gallery.path 指向拷贝后的 face_gallery.db • gallery.expected_dim=512 - • threshold.margin 保持现有值即可(因为每人只有 1 条 centroid,不会出现“同人占 - top2 导致 margin 过小”的问题) \ No newline at end of file + • 设备端检索需要按 person 聚合,避免同人多条 embedding 互相占据 top2 diff --git a/build_gallery.py b/build_gallery.py index 90d21a5..753b34f 100644 --- a/build_gallery.py +++ b/build_gallery.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Tuple import numpy as np -from gallery_builder.aggregate import compute_centroid from gallery_builder.align import align_face_5pts from gallery_builder.dataset import DatasetScanner from gallery_builder.db import GalleryDbWriter, db_selfcheck @@ -56,7 +55,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]: ) recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim) - enrolled: List[Tuple[str, np.ndarray]] = [] + enrolled: List[Tuple[str, List[np.ndarray]]] = [] for person in scanner.iter_persons(): used_embs: List[np.ndarray] = [] @@ -107,8 +106,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]: report.add_success(person.name) if used_embs: - centroid = compute_centroid(np.stack(used_embs, axis=0)) - enrolled.append((person.name, centroid)) + enrolled.append((person.name, list(used_embs))) else: report.skipped_persons.append(person.name) diff --git a/gallery_builder/db.py b/gallery_builder/db.py index b0f6fb6..9e855cd 100644 --- a/gallery_builder/db.py +++ b/gallery_builder/db.py @@ -28,8 +28,23 @@ class GalleryDbWriter: self.db_path = os.path.abspath(db_path) self.expected_dim = int(expected_dim) - def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None: - """items: (person_name, centroid_float32(D,)). Overwrites existing db file.""" + def _iter_embeddings(self, emb_or_embs) -> Iterable[np.ndarray]: + if isinstance(emb_or_embs, np.ndarray): + arr = np.asarray(emb_or_embs, dtype=np.float32) + if arr.ndim == 1: + yield arr.reshape(-1) + return + if arr.ndim == 2: + for row in arr: + yield np.asarray(row, dtype=np.float32).reshape(-1) + return + raise ValueError(f"unsupported embedding ndarray ndim: {arr.ndim}") + + for emb in emb_or_embs: + yield np.asarray(emb, dtype=np.float32).reshape(-1) + + def write(self, items: Iterable[Tuple[str, object]]) -> None: + """items: (person_name, embedding_or_embeddings). Overwrites existing db file.""" if os.path.exists(self.db_path): os.remove(self.db_path) @@ -40,20 +55,23 @@ class GalleryDbWriter: conn.executescript(_SCHEMA_SQL) cur = conn.cursor() cur.execute("BEGIN") - for name, emb in items: - emb = np.asarray(emb, dtype=np.float32).reshape(-1) - if emb.size != self.expected_dim: - raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}") - blob = emb.astype(np.float32).tobytes() - if len(blob) != self.expected_dim * 4: - raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes") - + for name, emb_or_embs in items: cur.execute("INSERT INTO person(name) VALUES(?)", (name,)) person_id = cur.lastrowid - cur.execute( - "INSERT INTO embedding(person_id, emb) VALUES(?, ?) ", - (person_id, sqlite3.Binary(blob)), - ) + insert_count = 0 + for emb in self._iter_embeddings(emb_or_embs): + if emb.size != self.expected_dim: + raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}") + blob = emb.astype(np.float32).tobytes() + if len(blob) != self.expected_dim * 4: + raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes") + cur.execute( + "INSERT INTO embedding(person_id, emb) VALUES(?, ?) ", + (person_id, sqlite3.Binary(blob)), + ) + insert_count += 1 + if insert_count == 0: + raise ValueError(f"no embeddings provided for {name}") conn.commit() except Exception: conn.rollback() diff --git a/tests/test_db_writer.py b/tests/test_db_writer.py new file mode 100644 index 0000000..9f2959c --- /dev/null +++ b/tests/test_db_writer.py @@ -0,0 +1,40 @@ +import sqlite3 +import tempfile +import unittest +from pathlib import Path + +import numpy as np + +from gallery_builder.db import GalleryDbWriter + + +class GalleryDbWriterTest(unittest.TestCase): + def test_writes_multiple_embeddings_for_same_person(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + db_path = Path(tmp) / "face_gallery.db" + writer = GalleryDbWriter(str(db_path), expected_dim=4) + + writer.write( + [ + ("alice", [np.array([1, 0, 0, 0], dtype=np.float32), np.array([0, 1, 0, 0], dtype=np.float32)]), + ("bob", [np.array([0, 0, 1, 0], dtype=np.float32)]), + ] + ) + + conn = sqlite3.connect(db_path) + try: + cur = conn.cursor() + cur.execute("select count(*) from person") + self.assertEqual(cur.fetchone()[0], 2) + cur.execute("select count(*) from embedding") + self.assertEqual(cur.fetchone()[0], 3) + cur.execute( + "select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name" + ) + self.assertEqual(cur.fetchall(), [("alice", 2), ("bob", 1)]) + finally: + conn.close() + + +if __name__ == "__main__": + unittest.main()