AddFaceTo3588/tests/test_db_writer.py

41 lines
1.3 KiB
Python

import sqlite3
import tempfile
import unittest
from pathlib import Path
import numpy as np
from gallery_builder.db import GalleryDbWriter
class GalleryDbWriterTest(unittest.TestCase):
def test_writes_multiple_embeddings_for_same_person(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
db_path = Path(tmp) / "face_gallery.db"
writer = GalleryDbWriter(str(db_path), expected_dim=4)
writer.write(
[
("alice", [np.array([1, 0, 0, 0], dtype=np.float32), np.array([0, 1, 0, 0], dtype=np.float32)]),
("bob", [np.array([0, 0, 1, 0], dtype=np.float32)]),
]
)
conn = sqlite3.connect(db_path)
try:
cur = conn.cursor()
cur.execute("select count(*) from person")
self.assertEqual(cur.fetchone()[0], 2)
cur.execute("select count(*) from embedding")
self.assertEqual(cur.fetchone()[0], 3)
cur.execute(
"select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
)
self.assertEqual(cur.fetchall(), [("alice", 2), ("bob", 1)])
finally:
conn.close()
if __name__ == "__main__":
    # Lets the module be executed directly as a script, in addition to being
    # collected by a test runner that imports it.
    unittest.main()