from __future__ import annotations

import os
import sqlite3
import uuid
from pathlib import Path
from typing import Iterable, Tuple

import numpy as np

_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS person (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS embedding (
    person_id INTEGER NOT NULL,
    emb BLOB NOT NULL,
    FOREIGN KEY(person_id) REFERENCES person(id)
);
CREATE INDEX IF NOT EXISTS idx_embedding_person_id ON embedding(person_id);
"""

# Every stored embedding blob is the raw bytes of a float32 vector,
# i.e. expected_dim * _FLOAT32_BYTES bytes long.
_FLOAT32_BYTES = np.dtype(np.float32).itemsize


class GalleryDbWriter:
    """Write a person/embedding gallery into a fresh SQLite database file.

    Each person becomes one row in ``person`` and one or more rows in
    ``embedding``; each embedding is stored as the raw bytes of a float32
    vector of exactly ``expected_dim`` elements.
    """

    def __init__(self, db_path: str, expected_dim: int = 512) -> None:
        """
        Args:
            db_path: Destination database file (stored as an absolute path).
            expected_dim: Required number of elements in every embedding.

        Raises:
            ValueError: If ``expected_dim`` is not a positive integer.
        """
        self.db_path = os.path.abspath(db_path)
        self.expected_dim = int(expected_dim)
        if self.expected_dim <= 0:
            raise ValueError(f"expected_dim must be positive, got {self.expected_dim}")

    def _iter_embeddings(self, emb_or_embs) -> Iterable[np.ndarray]:
        """Yield flattened float32 vectors from one embedding or a collection of them.

        A 1-D ndarray is treated as a single embedding and a 2-D ndarray as one
        embedding per row. Any other object is iterated, and each element is
        converted to float32 and flattened into one embedding.

        NOTE: because non-ndarray inputs are iterated, a single embedding given
        as a plain list of floats is seen as many size-1 embeddings; pass it as
        a 1-D ndarray or wrap it in an outer list.

        Raises:
            ValueError: For an ndarray whose rank is neither 1 nor 2.
        """
        if isinstance(emb_or_embs, np.ndarray):
            arr = np.asarray(emb_or_embs, dtype=np.float32)
            if arr.ndim == 1:
                yield arr.reshape(-1)
                return
            if arr.ndim == 2:
                for row in arr:
                    yield row.reshape(-1)
                return
            raise ValueError(f"unsupported embedding ndarray ndim: {arr.ndim}")
        for emb in emb_or_embs:
            yield np.asarray(emb, dtype=np.float32).reshape(-1)

    def write(self, items: Iterable[Tuple[str, object]]) -> None:
        """Write ``items`` to ``self.db_path``, replacing any existing file.

        items: (person_name, embedding_or_embeddings) pairs.

        The database is built in a temporary file beside the target and moved
        into place with ``os.replace`` only after its transaction commits. A
        failed write therefore leaves any previous database untouched rather
        than deleting it up front.

        Raises:
            ValueError: On an embedding dimension/size mismatch, an unsupported
                ndarray rank, or a person with no embeddings.
            sqlite3.IntegrityError: If the same person name appears twice
                (``person.name`` is UNIQUE).
        """
        target_dir = os.path.dirname(self.db_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        # Same directory as the target so the final os.replace is a rename on
        # one filesystem; the random suffix avoids clashing with concurrent writers.
        tmp_path = f"{self.db_path}.{uuid.uuid4().hex}.tmp"
        try:
            self._write_db(tmp_path, items)
            os.replace(tmp_path, self.db_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass
            raise

    def _write_db(self, path: str, items: Iterable[Tuple[str, object]]) -> None:
        """Create the schema at ``path`` and insert all items in one transaction."""
        expected_bytes = self.expected_dim * _FLOAT32_BYTES
        conn = sqlite3.connect(path)
        try:
            # executescript commits on its own, so the schema exists before BEGIN.
            conn.executescript(_SCHEMA_SQL)
            cur = conn.cursor()
            cur.execute("BEGIN")
            for name, emb_or_embs in items:
                cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
                person_id = cur.lastrowid
                insert_count = 0
                for emb in self._iter_embeddings(emb_or_embs):
                    if emb.size != self.expected_dim:
                        raise ValueError(
                            f"embedding dim mismatch for {name}: got {emb.size}, expected {self.expected_dim}"
                        )
                    blob = emb.astype(np.float32, copy=False).tobytes()
                    if len(blob) != expected_bytes:
                        raise ValueError(f"embedding blob size mismatch for {name}: got {len(blob)} bytes")
                    cur.execute(
                        "INSERT INTO embedding(person_id, emb) VALUES(?, ?)",
                        (person_id, sqlite3.Binary(blob)),
                    )
                    insert_count += 1
                if insert_count == 0:
                    raise ValueError(f"no embeddings provided for {name}")
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()


def db_selfcheck(db_path: str, expected_dim: int = 512, sample_n: int = 5) -> dict:
    """Summarize a gallery database and spot-check embedding blob sizes.

    The database is opened read-only, so pointing this at a missing path raises
    ``sqlite3.OperationalError`` without creating an empty file there.

    Args:
        db_path: Path to the gallery database.
        expected_dim: Expected number of float32 elements per embedding.
        sample_n: Number of randomly chosen embedding rows whose blob length is
            checked. NOTE: SQLite treats a negative LIMIT as unlimited, so a
            negative value samples every row.

    Returns:
        dict with ``person_count``, ``embedding_count``, ``sample_lengths``
        (byte lengths of the sampled blobs) and ``sample_lengths_ok`` (True when
        every sampled length equals ``expected_dim * 4``, or nothing was sampled).
    """
    db_uri = Path(os.path.abspath(db_path)).as_uri() + "?mode=ro"
    conn = sqlite3.connect(db_uri, uri=True)
    try:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) FROM person")
        person_cnt = int(cur.fetchone()[0])
        cur.execute("SELECT COUNT(*) FROM embedding")
        emb_cnt = int(cur.fetchone()[0])
        cur.execute("SELECT length(emb) FROM embedding ORDER BY RANDOM() LIMIT ?", (int(sample_n),))
        lengths = [int(r[0]) for r in cur.fetchall()]
        expected_len = expected_dim * _FLOAT32_BYTES
        ok_len = all(length == expected_len for length in lengths) if lengths else True
        return {
            "person_count": person_cnt,
            "embedding_count": emb_cnt,
            "sample_lengths": lengths,
            "sample_lengths_ok": ok_len,
        }
    finally:
        conn.close()