41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import sqlite3
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
from gallery_builder.db import GalleryDbWriter
|
|
|
|
|
|
class GalleryDbWriterTest(unittest.TestCase):
    """Exercises ``GalleryDbWriter`` against a throwaway SQLite file."""

    def test_writes_multiple_embeddings_for_same_person(self) -> None:
        """Write two people with three embeddings in total and inspect the DB.

        "alice" is given two 4-element float32 vectors and "bob" one. After
        the write, the ``person`` table must hold two rows, the ``embedding``
        table three, and the per-person embedding counts (joined through
        ``embedding.person_id``) must be alice=2 and bob=1.
        """
        alice_vectors = [
            np.array([1, 0, 0, 0], dtype=np.float32),
            np.array([0, 1, 0, 0], dtype=np.float32),
        ]
        bob_vectors = [np.array([0, 0, 1, 0], dtype=np.float32)]
        gallery = [
            ("alice", alice_vectors),
            ("bob", bob_vectors),
        ]

        with tempfile.TemporaryDirectory() as scratch_dir:
            gallery_file = Path(scratch_dir) / "face_gallery.db"

            # Keep the writer bound to a local for the rest of the test so its
            # lifetime spans the verification queries below.
            db_writer = GalleryDbWriter(str(gallery_file), expected_dim=4)
            db_writer.write(gallery)

            connection = sqlite3.connect(gallery_file)
            try:
                cursor = connection.cursor()

                person_rows = cursor.execute("select count(*) from person").fetchone()[0]
                self.assertEqual(person_rows, 2)

                embedding_rows = cursor.execute("select count(*) from embedding").fetchone()[0]
                self.assertEqual(embedding_rows, 3)

                per_person = cursor.execute(
                    "select p.name, count(*) from embedding e join person p on p.id = e.person_id group by p.name order by p.name"
                ).fetchall()
                self.assertEqual(per_person, [("alice", 2), ("bob", 1)])
            finally:
                # Release the file handle before the temporary directory is removed.
                connection.close()
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|