Store multiple embeddings per person
This commit is contained in:
parent
3ea3f9069f
commit
fd7e004d47
24
Readme.md
24
Readme.md
@ -1,8 +1,8 @@
|
|||||||
PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python+ONNX)
|
PRD:全离线生成 `face_gallery.db`(每人多条 embedding,Windows+Python+ONNX)
|
||||||
|
|
||||||
1. 背景与目标
|
1. 背景与目标
|
||||||
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成 1 条 512D
|
• 在 Windows 电脑上离线处理“注册照片”,为每个人生成多条 512D
|
||||||
人脸特征(centroid),写入 SQLite 数据库 face_gallery.db。
|
人脸特征(每张有效照片 1 条 embedding),写入 SQLite 数据库 face_gallery.db。
|
||||||
• RK3588 运行时 ai_face_recog 以 gallery.backend=sqlite
|
• RK3588 运行时 ai_face_recog 以 gallery.backend=sqlite
|
||||||
读取该库,实现识别(无需改线上阈值/配置逻辑)。
|
读取该库,实现识别(无需改线上阈值/配置逻辑)。
|
||||||
|
|
||||||
@ -14,7 +14,6 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
• 数据集扫描(按人目录)
|
• 数据集扫描(按人目录)
|
||||||
• 人脸检测(含 5 点关键点)+ 对齐到 112×112
|
• 人脸检测(含 5 点关键点)+ 对齐到 112×112
|
||||||
• ArcFace/MobileFaceNet ONNX 推理生成 512D embedding
|
• ArcFace/MobileFaceNet ONNX 推理生成 512D embedding
|
||||||
• 每人多图聚合为 centroid(L2 normalize 后平均,再 L2 normalize)
|
|
||||||
• 生成 SQLite:person、embedding 两表
|
• 生成 SQLite:person、embedding 两表
|
||||||
• 生成构建报告(统计、异常、可选相似度抽检)
|
• 生成构建报告(统计、异常、可选相似度抽检)
|
||||||
|
|
||||||
@ -94,7 +93,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
|
|
||||||
──────────────────────────────────────────
|
──────────────────────────────────────────
|
||||||
|
|
||||||
6. 特征生成与聚合(每人 1 条 centroid)
|
6. 特征生成与入库(每人多条 embedding)
|
||||||
|
|
||||||
对每张有效图片:
|
对每张有效图片:
|
||||||
1. 检测 → 选择人脸(默认:最大脸)
|
1. 检测 → 选择人脸(默认:最大脸)
|
||||||
@ -104,10 +103,8 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
|
|
||||||
对每个人:
|
对每个人:
|
||||||
• 取最多 max_imgs_per_person 张有效样本 embedding(不足也可)
|
• 取最多 max_imgs_per_person 张有效样本 embedding(不足也可)
|
||||||
• centroid 计算:
|
• 每张有效样本保留 1 条 embedding
|
||||||
• centroid = mean(emb_i)(对已归一化 emb 做均值)
|
• 同一 person 可写入多条 embedding
|
||||||
• centroid = L2_normalize(centroid)
|
|
||||||
• 只写入 1 条 embedding(centroid)
|
|
||||||
|
|
||||||
──────────────────────────────────────────
|
──────────────────────────────────────────
|
||||||
|
|
||||||
@ -134,9 +131,9 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
|
|
||||||
写入规则
|
写入规则
|
||||||
• person.name = 文件夹名
|
• person.name = 文件夹名
|
||||||
• embedding.emb = centroid 向量的 float32 原始字节(BLOB)
|
• embedding.emb = 单张样本 embedding 的 float32 原始字节(BLOB)
|
||||||
• 长度必须为 expected_dim * 4(512→2048 bytes)
|
• 长度必须为 expected_dim * 4(512→2048 bytes)
|
||||||
• 每个 person 仅 1 行 embedding
|
• 每个 person 可对应多行 embedding
|
||||||
|
|
||||||
──────────────────────────────────────────
|
──────────────────────────────────────────
|
||||||
|
|
||||||
@ -184,7 +181,7 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
loaded: n=<入库人数> dim=512
|
loaded: n=<入库人数> dim=512
|
||||||
2. SQLite 校验:
|
2. SQLite 校验:
|
||||||
• SELECT COUNT(*) FROM person; == 入库人数
|
• SELECT COUNT(*) FROM person; == 入库人数
|
||||||
• SELECT COUNT(*) FROM embedding; == 入库人数
|
• SELECT COUNT(*) FROM embedding; >= 入库人数
|
||||||
• SELECT length(emb) FROM embedding LIMIT 5; 全为 2048
|
• SELECT length(emb) FROM embedding LIMIT 5; 全为 2048
|
||||||
3. 实测:对至少 3 人,每人 3 张注册图 + 1
|
3. 实测:对至少 3 人,每人 3 张注册图 + 1
|
||||||
条现场图,识别能输出对应姓名(阈值用现有配置不强制修改)
|
条现场图,识别能输出对应姓名(阈值用现有配置不强制修改)
|
||||||
@ -195,5 +192,4 @@ PRD:全离线生成 `face_gallery.db`(每人 1 条 centroid,Windows+Python
|
|||||||
• gallery.backend="sqlite"
|
• gallery.backend="sqlite"
|
||||||
• gallery.path 指向拷贝后的 face_gallery.db
|
• gallery.path 指向拷贝后的 face_gallery.db
|
||||||
• gallery.expected_dim=512
|
• gallery.expected_dim=512
|
||||||
• threshold.margin 保持现有值即可(因为每人只有 1 条 centroid,不会出现“同人占
|
• 设备端检索需要按 person 聚合,避免同人多条 embedding 互相占据 top2
|
||||||
top2 导致 margin 过小”的问题)
|
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from typing import Any, Dict, List, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gallery_builder.aggregate import compute_centroid
|
|
||||||
from gallery_builder.align import align_face_5pts
|
from gallery_builder.align import align_face_5pts
|
||||||
from gallery_builder.dataset import DatasetScanner
|
from gallery_builder.dataset import DatasetScanner
|
||||||
from gallery_builder.db import GalleryDbWriter, db_selfcheck
|
from gallery_builder.db import GalleryDbWriter, db_selfcheck
|
||||||
@ -56,7 +55,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
|
|||||||
)
|
)
|
||||||
recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)
|
recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)
|
||||||
|
|
||||||
enrolled: List[Tuple[str, np.ndarray]] = []
|
enrolled: List[Tuple[str, List[np.ndarray]]] = []
|
||||||
|
|
||||||
for person in scanner.iter_persons():
|
for person in scanner.iter_persons():
|
||||||
used_embs: List[np.ndarray] = []
|
used_embs: List[np.ndarray] = []
|
||||||
@ -107,8 +106,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
|
|||||||
report.add_success(person.name)
|
report.add_success(person.name)
|
||||||
|
|
||||||
if used_embs:
|
if used_embs:
|
||||||
centroid = compute_centroid(np.stack(used_embs, axis=0))
|
enrolled.append((person.name, list(used_embs)))
|
||||||
enrolled.append((person.name, centroid))
|
|
||||||
else:
|
else:
|
||||||
report.skipped_persons.append(person.name)
|
report.skipped_persons.append(person.name)
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,23 @@ class GalleryDbWriter:
|
|||||||
self.db_path = os.path.abspath(db_path)
|
self.db_path = os.path.abspath(db_path)
|
||||||
self.expected_dim = int(expected_dim)
|
self.expected_dim = int(expected_dim)
|
||||||
|
|
||||||
def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
|
def _iter_embeddings(self, emb_or_embs) -> Iterable[np.ndarray]:
|
||||||
"""items: (person_name, centroid_float32(D,)). Overwrites existing db file."""
|
if isinstance(emb_or_embs, np.ndarray):
|
||||||
|
arr = np.asarray(emb_or_embs, dtype=np.float32)
|
||||||
|
if arr.ndim == 1:
|
||||||
|
yield arr.reshape(-1)
|
||||||
|
return
|
||||||
|
if arr.ndim == 2:
|
||||||
|
for row in arr:
|
||||||
|
yield np.asarray(row, dtype=np.float32).reshape(-1)
|
||||||
|
return
|
||||||
|
raise ValueError(f"unsupported embedding ndarray ndim: {arr.ndim}")
|
||||||
|
|
||||||
|
for emb in emb_or_embs:
|
||||||
|
yield np.asarray(emb, dtype=np.float32).reshape(-1)
|
||||||
|
|
||||||
|
def write(self, items: Iterable[Tuple[str, object]]) -> None:
|
||||||
|
"""items: (person_name, embedding_or_embeddings). Overwrites existing db file."""
|
||||||
|
|
||||||
if os.path.exists(self.db_path):
|
if os.path.exists(self.db_path):
|
||||||
os.remove(self.db_path)
|
os.remove(self.db_path)
|
||||||
@ -40,20 +55,23 @@ class GalleryDbWriter:
|
|||||||
conn.executescript(_SCHEMA_SQL)
|
conn.executescript(_SCHEMA_SQL)
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("BEGIN")
|
cur.execute("BEGIN")
|
||||||
for name, emb in items:
|
for name, emb_or_embs in items:
|
||||||
emb = np.asarray(emb, dtype=np.float32).reshape(-1)
|
|
||||||
if emb.size != self.expected_dim:
|
|
||||||
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
|
|
||||||
blob = emb.astype(np.float32).tobytes()
|
|
||||||
if len(blob) != self.expected_dim * 4:
|
|
||||||
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
|
|
||||||
|
|
||||||
cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
|
cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
|
||||||
person_id = cur.lastrowid
|
person_id = cur.lastrowid
|
||||||
cur.execute(
|
insert_count = 0
|
||||||
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
|
for emb in self._iter_embeddings(emb_or_embs):
|
||||||
(person_id, sqlite3.Binary(blob)),
|
if emb.size != self.expected_dim:
|
||||||
)
|
raise ValueError(f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}")
|
||||||
|
blob = emb.astype(np.float32).tobytes()
|
||||||
|
if len(blob) != self.expected_dim * 4:
|
||||||
|
raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
|
||||||
|
(person_id, sqlite3.Binary(blob)),
|
||||||
|
)
|
||||||
|
insert_count += 1
|
||||||
|
if insert_count == 0:
|
||||||
|
raise ValueError(f"no embeddings provided for {name}")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
|
|||||||
40
tests/test_db_writer.py
Normal file
40
tests/test_db_writer.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gallery_builder.db import GalleryDbWriter
|
||||||
|
|
||||||
|
|
||||||
|
class GalleryDbWriterTest(unittest.TestCase):
|
||||||
|
def test_writes_multiple_embeddings_for_same_person(self) -> None:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
db_path = Path(tmp) / "face_gallery.db"
|
||||||
|
writer = GalleryDbWriter(str(db_path), expected_dim=4)
|
||||||
|
|
||||||
|
writer.write(
|
||||||
|
[
|
||||||
|
("alice", [np.array([1, 0, 0, 0], dtype=np.float32), np.array([0, 1, 0, 0], dtype=np.float32)]),
|
||||||
|
("bob", [np.array([0, 0, 1, 0], dtype=np.float32)]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("select count(*) from person")
|
||||||
|
self.assertEqual(cur.fetchone()[0], 2)
|
||||||
|
cur.execute("select count(*) from embedding")
|
||||||
|
self.assertEqual(cur.fetchone()[0], 3)
|
||||||
|
cur.execute(
|
||||||
|
"select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
|
||||||
|
)
|
||||||
|
self.assertEqual(cur.fetchall(), [("alice", 2), ("bob", 1)])
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue
Block a user