Store multiple embeddings per person
This commit is contained in:
parent
3ea3f9069f
commit
fd7e004d47
24
Readme.md
24
Readme.md
@ -1,8 +1,8 @@
|
||||
PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python+ONNX)
|
||||
PRD:全离线生成 `face_gallery.db`(每人多条 embedding,Windows+Python+ONNX)
|
||||
|
||||
1. 背景与目标
|
||||
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成 1 条 512D
|
||||
人脸特征(centroid),写入 SQLite 数据库 face_gallery.db。
|
||||
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成多条 512D
|
||||
人脸特征(每张有效照片 1 条 embedding),写入 SQLite 数据库 face_gallery.db。
|
||||
• RK3588 运行时 ai_face_recog 以 gallery.backend=sqlite
|
||||
读取该库,实现识别(无需改线上阈值/配置逻辑)。
|
||||
|
||||
@ -14,7 +14,6 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
• 数据集扫描(按人目录)
|
||||
• 人脸检测(含 5 点关键点)+ 对齐到 112×112
|
||||
• ArcFace/MobileFaceNet ONNX 推理生成 512D embedding
|
||||
• 每人多图聚合为 centroid(L2 normalize 后平均,再 L2 normalize)
|
||||
• 生成 SQLite:person、embedding 两表
|
||||
• 生成构建报告(统计、异常、可选相似度抽检)
|
||||
|
||||
@ -94,7 +93,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
|
||||
──────────────────────────────────────────
|
||||
|
||||
6. 特征生成与聚合(每人 1 条 centroid)
|
||||
6. 特征生成与入库(每人多条 embedding)
|
||||
|
||||
对每张有效图片:
|
||||
1. 检测 → 选择人脸(默认:最大脸)
|
||||
@ -104,10 +103,8 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
|
||||
对每个人:
|
||||
• 取最多 max_imgs_per_person 张有效样本 embedding(不足也可)
|
||||
• centroid 计算:
|
||||
• centroid = mean(emb_i)(对已归一化 emb 做均值)
|
||||
• centroid = L2_normalize(centroid)
|
||||
• 只写入 1 条 embedding(centroid)
|
||||
• 每张有效样本保留 1 条 embedding
|
||||
• 同一 person 可写入多条 embedding
|
||||
|
||||
──────────────────────────────────────────
|
||||
|
||||
@ -134,9 +131,9 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
|
||||
写入规则
|
||||
• person.name = 文件夹名
|
||||
• embedding.emb = centroid 向量的 float32 原始字节(BLOB)
|
||||
• embedding.emb = 单张样本 embedding 的 float32 原始字节(BLOB)
|
||||
• 长度必须为 expected_dim * 4(512→2048 bytes)
|
||||
• 每个 person 仅 1 行 embedding
|
||||
• 每个 person 可对应多行 embedding
|
||||
|
||||
──────────────────────────────────────────
|
||||
|
||||
@ -184,7 +181,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
loaded: n=<入库人数> dim=512
|
||||
2. SQLite 校验:
|
||||
• SELECT COUNT(*) FROM person; == 入库人数
|
||||
• SELECT COUNT(*) FROM embedding; == 入库人数
|
||||
• SELECT COUNT(*) FROM embedding; >= 入库人数
|
||||
• SELECT length(emb) FROM embedding LIMIT 5; 全为 2048
|
||||
3. 实测:对至少 3 人,每人 3 张注册图 + 1
|
||||
条现场图,识别能输出对应姓名(阈值用现有配置不强制修改)
|
||||
@ -195,5 +192,4 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
||||
• gallery.backend="sqlite"
|
||||
• gallery.path 指向拷贝后的 face_gallery.db
|
||||
• gallery.expected_dim=512
|
||||
• threshold.margin 保持现有值即可(因为每人只有 1 条 centroid,不会出现“同人占
|
||||
top2 导致 margin 过小”的问题)
|
||||
• 设备端检索需要按 person 聚合,避免同人多条 embedding 互相占据 top2
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gallery_builder.aggregate import compute_centroid
|
||||
from gallery_builder.align import align_face_5pts
|
||||
from gallery_builder.dataset import DatasetScanner
|
||||
from gallery_builder.db import GalleryDbWriter, db_selfcheck
|
||||
@ -56,7 +55,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
|
||||
)
|
||||
recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)
|
||||
|
||||
enrolled: List[Tuple[str, np.ndarray]] = []
|
||||
enrolled: List[Tuple[str, List[np.ndarray]]] = []
|
||||
|
||||
for person in scanner.iter_persons():
|
||||
used_embs: List[np.ndarray] = []
|
||||
@ -107,8 +106,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
|
||||
report.add_success(person.name)
|
||||
|
||||
if used_embs:
|
||||
centroid = compute_centroid(np.stack(used_embs, axis=0))
|
||||
enrolled.append((person.name, centroid))
|
||||
enrolled.append((person.name, list(used_embs)))
|
||||
else:
|
||||
report.skipped_persons.append(person.name)
|
||||
|
||||
|
||||
@ -28,8 +28,23 @@ class GalleryDbWriter:
|
||||
self.db_path = os.path.abspath(db_path)
|
||||
self.expected_dim = int(expected_dim)
|
||||
|
||||
def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
|
||||
"""items: (person_name, centroid_float32(D,)). Overwrites existing db file."""
|
||||
def _iter_embeddings(self, emb_or_embs) -> Iterable[np.ndarray]:
|
||||
if isinstance(emb_or_embs, np.ndarray):
|
||||
arr = np.asarray(emb_or_embs, dtype=np.float32)
|
||||
if arr.ndim == 1:
|
||||
yield arr.reshape(-1)
|
||||
return
|
||||
if arr.ndim == 2:
|
||||
for row in arr:
|
||||
yield np.asarray(row, dtype=np.float32).reshape(-1)
|
||||
return
|
||||
raise ValueError(f"unsupported embedding ndarray ndim: {arr.ndim}")
|
||||
|
||||
for emb in emb_or_embs:
|
||||
yield np.asarray(emb, dtype=np.float32).reshape(-1)
|
||||
|
||||
def write(self, items: Iterable[Tuple[str, object]]) -> None:
|
||||
"""items: (person_name, embedding_or_embeddings). Overwrites existing db file."""
|
||||
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
@ -40,20 +55,23 @@ class GalleryDbWriter:
|
||||
conn.executescript(_SCHEMA_SQL)
|
||||
cur = conn.cursor()
|
||||
cur.execute("BEGIN")
|
||||
for name, emb in items:
|
||||
emb = np.asarray(emb, dtype=np.float32).reshape(-1)
|
||||
if emb.size != self.expected_dim:
|
||||
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
|
||||
blob = emb.astype(np.float32).tobytes()
|
||||
if len(blob) != self.expected_dim * 4:
|
||||
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
|
||||
|
||||
for name, emb_or_embs in items:
|
||||
cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
|
||||
person_id = cur.lastrowid
|
||||
cur.execute(
|
||||
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
|
||||
(person_id, sqlite3.Binary(blob)),
|
||||
)
|
||||
insert_count = 0
|
||||
for emb in self._iter_embeddings(emb_or_embs):
|
||||
if emb.size != self.expected_dim:
|
||||
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
|
||||
blob = emb.astype(np.float32).tobytes()
|
||||
if len(blob) != self.expected_dim * 4:
|
||||
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
|
||||
cur.execute(
|
||||
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
|
||||
(person_id, sqlite3.Binary(blob)),
|
||||
)
|
||||
insert_count += 1
|
||||
if insert_count == 0:
|
||||
raise ValueError(f"no embeddings provided for {name}")
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
|
||||
40
tests/test_db_writer.py
Normal file
40
tests/test_db_writer.py
Normal file
@ -0,0 +1,40 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gallery_builder.db import GalleryDbWriter
|
||||
|
||||
|
||||
class GalleryDbWriterTest(unittest.TestCase):
|
||||
def test_writes_multiple_embeddings_for_same_person(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
db_path = Path(tmp) / "face_gallery.db"
|
||||
writer = GalleryDbWriter(str(db_path), expected_dim=4)
|
||||
|
||||
writer.write(
|
||||
[
|
||||
("alice", [np.array([1, 0, 0, 0], dtype=np.float32), np.array([0, 1, 0, 0], dtype=np.float32)]),
|
||||
("bob", [np.array([0, 0, 1, 0], dtype=np.float32)]),
|
||||
]
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("select count(*) from person")
|
||||
self.assertEqual(cur.fetchone()[0], 2)
|
||||
cur.execute("select count(*) from embedding")
|
||||
self.assertEqual(cur.fetchone()[0], 3)
|
||||
cur.execute(
|
||||
"select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
|
||||
)
|
||||
self.assertEqual(cur.fetchall(), [("alice", 2), ("bob", 1)])
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user