diff --git a/build_gallery.py b/build_gallery.py index 753b34f..b52d6e1 100644 --- a/build_gallery.py +++ b/build_gallery.py @@ -30,6 +30,26 @@ def _bbox_wh(b: np.ndarray) -> Tuple[float, float]: return float(b[2] - b[0]), float(b[3] - b[1]) +def _selfcheck_exit_code( + selfcheck: Dict[str, Any], + report: BuildReport, + *, + fail_on_empty: bool, +) -> int: + person_count = int(selfcheck["person_count"]) + embedding_count = int(selfcheck["embedding_count"]) + + if fail_on_empty and (person_count == 0 or embedding_count == 0): + return 3 + if embedding_count < person_count: + return 4 + if embedding_count != report.ok_images: + return 4 + if not selfcheck["sample_lengths_ok"]: + return 5 + return 0 + + def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]: try: import cv2 @@ -126,13 +146,7 @@ def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]: for name, c in enrolled: print(f"centroid_norm {name}: {float(np.linalg.norm(c)):.6f}") - if args.fail_on_empty and (selfcheck["person_count"] == 0 or selfcheck["embedding_count"] == 0): - return 3, report - if selfcheck["person_count"] != selfcheck["embedding_count"]: - return 4, report - if not selfcheck["sample_lengths_ok"]: - return 5, report - return 0, report + return _selfcheck_exit_code(selfcheck, report, fail_on_empty=args.fail_on_empty), report def main(argv: List[str]) -> int: diff --git a/tests/test_build_gallery_selfcheck.py b/tests/test_build_gallery_selfcheck.py new file mode 100644 index 0000000..ec50d8f --- /dev/null +++ b/tests/test_build_gallery_selfcheck.py @@ -0,0 +1,34 @@ +import unittest + +from build_gallery import _selfcheck_exit_code +from gallery_builder.types import BuildReport + + +class BuildGallerySelfcheckTest(unittest.TestCase): + def test_accepts_multiple_embeddings_per_person(self) -> None: + report = BuildReport(ok_images=3) + selfcheck = {"person_count": 2, "embedding_count": 3, "sample_lengths_ok": True} + + self.assertEqual(_selfcheck_exit_code(selfcheck, report, fail_on_empty=True), 0) + + def test_rejects_missing_embedding_for_person(self) -> None: + report = BuildReport(ok_images=1) + selfcheck = {"person_count": 2, "embedding_count": 1, "sample_lengths_ok": True} + + self.assertEqual(_selfcheck_exit_code(selfcheck, report, fail_on_empty=True), 4) + + def test_rejects_embedding_count_that_does_not_match_successful_images(self) -> None: + report = BuildReport(ok_images=2) + selfcheck = {"person_count": 2, "embedding_count": 3, "sample_lengths_ok": True} + + self.assertEqual(_selfcheck_exit_code(selfcheck, report, fail_on_empty=False), 4) + + def test_rejects_bad_embedding_lengths(self) -> None: + report = BuildReport(ok_images=2) + selfcheck = {"person_count": 2, "embedding_count": 2, "sample_lengths_ok": False} + + self.assertEqual(_selfcheck_exit_code(selfcheck, report, fail_on_empty=False), 5) + + +if __name__ == "__main__": + unittest.main()