"""Write and sanity-check the SQLite gallery database (person names + float32 embeddings)."""
from __future__ import annotations

import os
import sqlite3
from pathlib import Path
from typing import Iterable, Tuple

import numpy as np


_SCHEMA_SQL = """
|
|
CREATE TABLE IF NOT EXISTS person (
|
|
id INTEGER PRIMARY KEY,
|
|
name TEXT NOT NULL UNIQUE
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS embedding (
|
|
person_id INTEGER NOT NULL,
|
|
emb BLOB NOT NULL,
|
|
FOREIGN KEY(person_id) REFERENCES person(id)
|
|
);
|
|
|
|
CREATE INDEX IF NOT EXISTS idx_embedding_person_id ON embedding(person_id);
|
|
"""


class GalleryDbWriter:
    """Write (person name, embedding) pairs into a freshly built SQLite file.

    Each :meth:`write` call produces a database containing one ``person`` row
    and one ``embedding`` row per item. The embedding is stored as the raw
    float32 bytes of a vector with exactly ``expected_dim`` elements.
    """

    def __init__(self, db_path: str, expected_dim: int = 512) -> None:
        """
        Args:
            db_path: Destination database file; stored as an absolute path.
            expected_dim: Number of elements every embedding must have.
        """
        self.db_path = os.path.abspath(db_path)
        self.expected_dim = int(expected_dim)

    def write(self, items: Iterable[Tuple[str, np.ndarray]]) -> None:
        """items: (person_name, centroid_float32(D,)). Overwrites existing db file.

        The new database is built in a temporary file next to ``db_path``.
        That file is moved over ``db_path`` with :func:`os.replace` only after
        the transaction has committed. If anything fails part-way — a bad
        dimension, a duplicate name, or an exception raised by ``items`` —
        the previous database is left untouched and the temporary file is
        removed.

        Raises:
            ValueError: if an embedding does not flatten to exactly
                ``expected_dim`` elements.
            sqlite3.IntegrityError: if a person name occurs more than once
                (``person.name`` is UNIQUE).
        """
        target_dir = os.path.dirname(self.db_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        # Same directory as the target so os.replace is a same-filesystem
        # rename. SQLite creates the file itself, so it gets the usual
        # umask-derived permissions.
        tmp_path = f"{self.db_path}.tmp{os.getpid()}"
        # A leftover from a crashed earlier run (and its hot journal) must
        # not be reused.
        self._discard(tmp_path)
        try:
            self._populate(tmp_path, items)
            os.replace(tmp_path, self.db_path)
        except BaseException:
            # Best-effort cleanup; never let it mask the original error.
            try:
                self._discard(tmp_path)
            except OSError:
                pass
            raise

    def _populate(self, path: str, items: Iterable[Tuple[str, np.ndarray]]) -> None:
        """Create the schema in the file at ``path`` and insert ``items`` in one transaction."""
        conn = sqlite3.connect(path)
        try:
            conn.executescript(_SCHEMA_SQL)
            cur = conn.cursor()
            cur.execute("BEGIN")
            for name, emb in items:
                vec = np.asarray(emb, dtype=np.float32).reshape(-1)
                # After this check the blob below is exactly expected_dim * 4
                # bytes, because float32 elements are 4 bytes each.
                if vec.size != self.expected_dim:
                    raise ValueError(
                        f"embedding dim mismatch for {name}: got {vec.size}, expected {self.expected_dim}"
                    )
                cur.execute("INSERT INTO person(name) VALUES(?)", (name,))
                cur.execute(
                    "INSERT INTO embedding(person_id, emb) VALUES(?, ?) ",
                    (cur.lastrowid, sqlite3.Binary(vec.tobytes())),
                )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            # Must be closed before os.replace (Windows cannot rename open files).
            conn.close()

    @staticmethod
    def _discard(path: str) -> None:
        """Delete ``path`` and its rollback journal, ignoring files that do not exist."""
        for candidate in (path, path + "-journal"):
            try:
                os.remove(candidate)
            except FileNotFoundError:
                pass


def db_selfcheck(db_path: str, expected_dim: int = 512, sample_n: int = 5) -> dict:
    """Summarize a gallery database for a quick sanity check.

    The database is opened read-only. Checking a path that does not exist
    raises instead of silently creating an empty database file there.

    Args:
        db_path: Path to the SQLite file; relative paths are resolved against
            the current working directory.
        expected_dim: Number of float32 values each embedding should hold, so
            a valid ``emb`` blob is ``expected_dim * 4`` bytes long.
        sample_n: Number of randomly chosen embedding rows to inspect. The
            value is passed straight to ``LIMIT``, where SQLite treats a
            negative value as "no limit".

    Returns:
        dict with ``person_count``, ``embedding_count``, ``sample_lengths``
        (byte lengths of the sampled ``emb`` values) and ``sample_lengths_ok``.
        ``sample_lengths_ok`` is True when every sampled length equals
        ``expected_dim * 4``, and also when nothing was sampled.

    Raises:
        sqlite3.OperationalError: if the file cannot be opened or lacks the
            ``person`` / ``embedding`` tables.
    """
    db_path = os.path.abspath(db_path)
    # mode=ro: never create or modify the file under inspection.
    uri = Path(db_path).as_uri() + "?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    try:
        cur = conn.cursor()
        person_cnt = int(cur.execute("SELECT COUNT(*) FROM person").fetchone()[0])
        emb_cnt = int(cur.execute("SELECT COUNT(*) FROM embedding").fetchone()[0])
        cur.execute("SELECT length(emb) FROM embedding ORDER BY RANDOM() LIMIT ?", (int(sample_n),))
        lengths = [int(r[0]) for r in cur.fetchall()]
        expected_len = expected_dim * 4  # float32 = 4 bytes per element
        return {
            "person_count": person_cnt,
            "embedding_count": emb_cnt,
            "sample_lengths": lengths,
            # all() of an empty list is True: nothing sampled, nothing to flag.
            "sample_lengths_ok": all(n == expected_len for n in lengths),
        }
    finally:
        conn.close()