AddFaceTo3588/gallery_builder/db.py
2026-01-08 13:46:50 +08:00

85 lines
2.8 KiB
Python

from __future__ import annotations
import os
import sqlite3
from typing import Iterable, Tuple
import numpy as np
# DDL for the gallery database, run via ``executescript`` in
# ``GalleryDbWriter.write``.
#   person    -- one row per unique, non-NULL ``name``; ``id`` is the rowid key.
#   embedding -- embedding blobs keyed by ``person_id``; the schema itself does
#                not limit how many embeddings a person may have.
# The index on ``embedding.person_id`` serves lookups of a person's embeddings.
# ``IF NOT EXISTS`` makes the script harmless to re-run on an existing file.
# NOTE(review): SQLite only enforces FOREIGN KEY clauses when
# ``PRAGMA foreign_keys = ON`` is set on the connection; nothing in this module
# sets it, so the reference is not enforced at write time.
_SCHEMA_SQL: str = """
CREATE TABLE IF NOT EXISTS person (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS embedding (
person_id INTEGER NOT NULL,
emb BLOB NOT NULL,
FOREIGN KEY(person_id) REFERENCES person(id)
);
CREATE INDEX IF NOT EXISTS idx_embedding_person_id ON embedding(person_id);
"""
class GalleryDbWriter:
    """Write a gallery of (person name, embedding vector) pairs to a SQLite file.

    The file follows ``_SCHEMA_SQL``: one ``person`` row per name and one
    ``embedding`` row per person holding the vector as raw little-endian
    float32 bytes (``expected_dim * 4`` bytes per blob).
    """

    def __init__(self, db_path: str, expected_dim: int = 512) -> None:
        """Configure the writer.

        Args:
            db_path: Destination database file; stored as an absolute path.
            expected_dim: Number of float32 components every embedding must have.
        """
        self.db_path = os.path.abspath(db_path)
        self.expected_dim = int(expected_dim)

    def _embedding_blob(self, name: str, emb: np.ndarray) -> bytes:
        """Flatten and validate one embedding, returning its float32 byte blob.

        Raises:
            ValueError: if the flattened embedding does not have
                ``expected_dim`` elements.
        """
        vec = np.asarray(emb, dtype=np.float32).reshape(-1)
        if vec.size != self.expected_dim:
            raise ValueError(f"embedding dim mismatch for {name}: got {vec.size}, expected {self.expected_dim}")
        # Pin little-endian byte order so the blob layout does not depend on the
        # machine that builds the gallery (a no-op copy on little-endian hosts).
        blob = vec.astype("<f4", copy=False).tobytes()
        if len(blob) != self.expected_dim * 4:
            raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
        return blob

    def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
        """Replace the database at ``db_path`` with the given gallery.

        items: (person_name, centroid_float32(D,)). Overwrites existing db file.

        The new database is built in a temporary file next to ``db_path`` and
        moved over the destination with ``os.replace`` only after a successful
        commit. If anything fails part-way, the previous gallery (if any) is
        left untouched and the temporary file is removed.

        Raises:
            ValueError: if an embedding does not have ``expected_dim`` elements.
            sqlite3.IntegrityError: on a duplicate or NULL person name
                (``person.name`` is ``NOT NULL UNIQUE``).
        """
        db_dir = os.path.dirname(self.db_path) or "."
        os.makedirs(db_dir, exist_ok=True)
        # Sibling path keeps the final rename on the same filesystem.
        tmp_path = f"{self.db_path}.tmp-{os.getpid()}"
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)  # leftover from an earlier interrupted run
            conn = sqlite3.connect(tmp_path)
            try:
                conn.executescript(_SCHEMA_SQL)
                cur = conn.cursor()
                cur.execute("BEGIN")
                for name, emb in items:
                    blob = self._embedding_blob(name, emb)
                    cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
                    person_id = cur.lastrowid
                    cur.execute(
                        "INSERT INTO embedding(person_id, emb) VALUES(?, ?)",
                        (person_id, sqlite3.Binary(blob)),
                    )
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            finally:
                conn.close()
            os.replace(tmp_path, self.db_path)
        finally:
            # After a successful replace the temp file no longer exists; on
            # failure discard it without masking the original exception.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
def db_selfcheck(db_path: str, expected_dim: int = 512, sample_n: int = 5) -> dict:
    """Run quick sanity checks on a gallery database without modifying it.

    Counts the rows in ``person`` and ``embedding`` and samples up to
    ``sample_n`` random embedding blobs to verify each is
    ``expected_dim * 4`` bytes long (one float32 per component).

    The file is opened read-only, so a missing or mistyped path raises
    instead of silently creating an empty database file.

    Args:
        db_path: Path to the SQLite gallery file.
        expected_dim: Expected number of float32 components per embedding.
        sample_n: Number of random embedding rows to sample (SQL ``LIMIT``;
            a negative value means no limit in SQLite, i.e. every row).

    Returns:
        dict with ``person_count``, ``embedding_count``, ``sample_lengths``
        (byte lengths of the sampled blobs) and ``sample_lengths_ok`` (True if
        every sampled length matches, or if nothing was sampled).

    Raises:
        sqlite3.OperationalError: if the file cannot be opened or lacks the
            expected tables.
    """
    from pathlib import Path

    db_path = os.path.abspath(db_path)
    # URI with mode=ro: never create or write the file being checked.
    # as_uri() percent-encodes characters such as '?' and '#' in the path.
    uri = f"{Path(db_path).as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    try:
        cur = conn.cursor()
        person_cnt = int(cur.execute("SELECT COUNT(*) FROM person").fetchone()[0])
        emb_cnt = int(cur.execute("SELECT COUNT(*) FROM embedding").fetchone()[0])
        cur.execute("SELECT length(emb) FROM embedding ORDER BY RANDOM() LIMIT ?", (int(sample_n),))
        lengths = [int(row[0]) for row in cur.fetchall()]
        expected_len = expected_dim * 4
        # all() of an empty list is True, matching "nothing sampled is OK".
        ok_len = all(n_bytes == expected_len for n_bytes in lengths)
        return {
            "person_count": person_cnt,
            "embedding_count": emb_cnt,
            "sample_lengths": lengths,
            "sample_lengths_ok": ok_len,
        }
    finally:
        conn.close()