"""Unit tests for ``gallery_builder.db.GalleryDbWriter``."""

import sqlite3
import tempfile
import unittest
from pathlib import Path

import numpy as np

from gallery_builder.db import GalleryDbWriter


class GalleryDbWriterTest(unittest.TestCase):
    """Checks the rows ``GalleryDbWriter`` leaves behind in its SQLite file."""

    def test_writes_multiple_embeddings_for_same_person(self) -> None:
        """Two people with 2 + 1 embeddings yield 2 person rows and 3 embedding rows.

        The per-person breakdown is read back by joining ``embedding`` onto
        ``person`` via ``person_id``, so it also checks that each embedding
        row is linked to the right owner.
        """

        def vec(*components: int) -> np.ndarray:
            # Build one float32 embedding vector from the given components.
            return np.array(components, dtype=np.float32)

        with tempfile.TemporaryDirectory() as workdir:
            database = Path(workdir) / "face_gallery.db"
            writer = GalleryDbWriter(str(database), expected_dim=4)
            gallery = [
                ("alice", [vec(1, 0, 0, 0), vec(0, 1, 0, 0)]),
                ("bob", [vec(0, 0, 1, 0)]),
            ]
            writer.write(gallery)

            connection = sqlite3.connect(database)
            try:
                cursor = connection.cursor()

                cursor.execute("select count(*) from person")
                (person_rows,) = cursor.fetchone()
                self.assertEqual(person_rows, 2)

                cursor.execute("select count(*) from embedding")
                (embedding_rows,) = cursor.fetchone()
                self.assertEqual(embedding_rows, 3)

                # ``order by p.name`` makes the row order deterministic, so the
                # result can be compared directly against a literal list.
                cursor.execute(
                    "select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
                )
                per_person = cursor.fetchall()
                self.assertEqual(per_person, [("alice", 2), ("bob", 1)])
            finally:
                connection.close()


if __name__ == "__main__":
    unittest.main()