Store multiple embeddings per person

This commit is contained in:
tian 2026-04-16 10:41:26 +08:00
parent 3ea3f9069f
commit fd7e004d47
4 changed files with 84 additions and 32 deletions

View File

@ -1,8 +1,8 @@
PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python+ONNX
PRD全离线生成 `face_gallery.db`(每人多条 embeddingWindows+Python+ONNX
1. 背景与目标
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成 1 条 512D
人脸特征(centroid),写入 SQLite 数据库 face_gallery.db。
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成多条 512D
人脸特征(每张有效照片 1 条 embedding),写入 SQLite 数据库 face_gallery.db。
• RK3588 运行时 ai_face_recog 以 gallery.backend=sqlite
读取该库,实现识别(无需改线上阈值/配置逻辑)。
@ -14,7 +14,6 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
• 数据集扫描(按人目录)
• 人脸检测(含 5 点关键点)+ 对齐到 112×112
• ArcFace/MobileFaceNet ONNX 推理生成 512D embedding
• 每人多图聚合为 centroidL2 normalize 后平均,再 L2 normalize
• 生成 SQLiteperson、embedding 两表
• 生成构建报告(统计、异常、可选相似度抽检)
@ -94,7 +93,7 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
──────────────────────────────────────────
6. 特征生成与聚合(每人 1 条 centroid
6. 特征生成与入库(每人多条 embedding
对每张有效图片:
1. 检测 → 选择人脸(默认:最大脸)
@ -104,10 +103,8 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
对每个人:
• 取最多 max_imgs_per_person 张有效样本 embedding不足也可
• centroid 计算:
• centroid = mean(emb_i)(对已归一化 emb 做均值)
• centroid = L2_normalize(centroid)
• 只写入 1 条 embeddingcentroid
• 每张有效样本保留 1 条 embedding
• 同一 person 可写入多条 embedding
──────────────────────────────────────────
@ -134,9 +131,9 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
写入规则
• person.name = 文件夹名
• embedding.emb = centroid 向量的 float32 原始字节BLOB
• embedding.emb = 单张样本 embedding 的 float32 原始字节BLOB
• 长度必须为 expected_dim * 4512→2048 bytes
• 每个 person 仅 1 行 embedding
• 每个 person 可对应多行 embedding
──────────────────────────────────────────
@ -184,7 +181,7 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
loaded: n=<入库人数> dim=512
2. SQLite 校验:
• SELECT COUNT(*) FROM person; == 入库人数
• SELECT COUNT(*) FROM embedding; == 入库人数
• SELECT COUNT(*) FROM embedding; >= 入库人数
• SELECT length(emb) FROM embedding LIMIT 5; 全为 2048
3. 实测:对至少 3 人,每人 3 张注册图 + 1
条现场图,识别能输出对应姓名(阈值用现有配置不强制修改)
@ -195,5 +192,4 @@ PRD全离线生成 `face_gallery.db`(每人 1 条 centroidWindows+Python
• gallery.backend="sqlite"
• gallery.path 指向拷贝后的 face_gallery.db
• gallery.expected_dim=512
• threshold.margin 保持现有值即可(因为每人只有 1 条 centroid不会出现“同人占
top2 导致 margin 过小”的问题)
• 设备端检索需要按 person 聚合,避免同人多条 embedding 互相占据 top2

View File

@ -7,7 +7,6 @@ from typing import Any, Dict, List, Tuple
import numpy as np
from gallery_builder.aggregate import compute_centroid
from gallery_builder.align import align_face_5pts
from gallery_builder.dataset import DatasetScanner
from gallery_builder.db import GalleryDbWriter, db_selfcheck
@ -56,7 +55,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
)
recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)
enrolled: List[Tuple[str, np.ndarray]] = []
enrolled: List[Tuple[str, List[np.ndarray]]] = []
for person in scanner.iter_persons():
used_embs: List[np.ndarray] = []
@ -107,8 +106,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
report.add_success(person.name)
if used_embs:
centroid = compute_centroid(np.stack(used_embs, axis=0))
enrolled.append((person.name, centroid))
enrolled.append((person.name, list(used_embs)))
else:
report.skipped_persons.append(person.name)

View File

@ -28,8 +28,23 @@ class GalleryDbWriter:
self.db_path = os.path.abspath(db_path)
self.expected_dim = int(expected_dim)
def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
"""items: (person_name, centroid_float32(D,)). Overwrites existing db file."""
def _iter_embeddings(self, emb_or_embs) -> Iterable[np.ndarray]:
if isinstance(emb_or_embs, np.ndarray):
arr = np.asarray(emb_or_embs, dtype=np.float32)
if arr.ndim == 1:
yield arr.reshape(-1)
return
if arr.ndim == 2:
for row in arr:
yield np.asarray(row, dtype=np.float32).reshape(-1)
return
raise ValueError(f"unsupported embedding ndarray ndim: {arr.ndim}")
for emb in emb_or_embs:
yield np.asarray(emb, dtype=np.float32).reshape(-1)
def write(self, items: Iterable[Tuple[str, object]]) -> None:
"""items: (person_name, embedding_or_embeddings). Overwrites existing db file."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
@ -40,20 +55,23 @@ class GalleryDbWriter:
conn.executescript(_SCHEMA_SQL)
cur = conn.cursor()
cur.execute("BEGIN")
for name, emb in items:
emb = np.asarray(emb, dtype=np.float32).reshape(-1)
if emb.size != self.expected_dim:
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
blob = emb.astype(np.float32).tobytes()
if len(blob) != self.expected_dim * 4:
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
for name, emb_or_embs in items:
cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
person_id = cur.lastrowid
cur.execute(
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
(person_id, sqlite3.Binary(blob)),
)
insert_count = 0
for emb in self._iter_embeddings(emb_or_embs):
if emb.size != self.expected_dim:
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
blob = emb.astype(np.float32).tobytes()
if len(blob) != self.expected_dim * 4:
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
cur.execute(
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
(person_id, sqlite3.Binary(blob)),
)
insert_count += 1
if insert_count == 0:
raise ValueError(f"no embeddings provided for {name}")
conn.commit()
except Exception:
conn.rollback()

40
tests/test_db_writer.py Normal file
View File

@ -0,0 +1,40 @@
import sqlite3
import tempfile
import unittest
from pathlib import Path
import numpy as np
from gallery_builder.db import GalleryDbWriter
class GalleryDbWriterTest(unittest.TestCase):
def test_writes_multiple_embeddings_for_same_person(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
db_path = Path(tmp) / "face_gallery.db"
writer = GalleryDbWriter(str(db_path), expected_dim=4)
writer.write(
[
("alice", [np.array([1, 0, 0, 0], dtype=np.float32), np.array([0, 1, 0, 0], dtype=np.float32)]),
("bob", [np.array([0, 0, 1, 0], dtype=np.float32)]),
]
)
conn = sqlite3.connect(db_path)
try:
cur = conn.cursor()
cur.execute("select count(*) from person")
self.assertEqual(cur.fetchone()[0], 2)
cur.execute("select count(*) from embedding")
self.assertEqual(cur.fetchone()[0], 3)
cur.execute(
"select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
)
self.assertEqual(cur.fetchall(), [("alice", 2), ("bob", 1)])
finally:
conn.close()
if __name__ == "__main__":
unittest.main()