# AddFaceTo3588/build_gallery.py (Python; listed as 166 lines, 6.2 KiB)

from __future__ import annotations
import argparse
import os
import sys
from typing import Any, Dict, List, Tuple
import numpy as np
from gallery_builder.align import align_face_5pts
from gallery_builder.dataset import DatasetScanner
from gallery_builder.db import GalleryDbWriter, db_selfcheck
from gallery_builder.detector import OnnxFaceDetector, load_det_outputs_config
from gallery_builder.recognizer import OnnxFaceRecognizer
from gallery_builder.types import BuildReport, FailureReason
def _parse_bool(s: str) -> bool:
    """Convert a command-line token into a bool (argparse ``type=`` helper).

    Real ``bool`` values are returned unchanged. Anything else is converted
    with ``str()``, stripped and lower-cased, then matched against the
    accepted spellings ("1/true/t/yes/y/on" and "0/false/f/no/n/off").

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(s, bool):
        return s

    truthy = {"1", "true", "t", "yes", "y", "on"}
    falsy = {"0", "false", "f", "no", "n", "off"}

    token = str(s).strip().lower()
    if token not in truthy and token not in falsy:
        # The message reports the raw input, not the normalised token.
        raise argparse.ArgumentTypeError(f"invalid bool: {s}")
    return token in truthy
def _bbox_wh(b: np.ndarray) -> Tuple[float, float]:
    """Return ``(width, height)`` of a box, computed as ``b[2]-b[0]`` and ``b[3]-b[1]``."""
    width = float(b[2] - b[0])
    height = float(b[3] - b[1])
    return width, height
def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
    """Embed faces from a per-person image dataset and write the gallery DB.

    For each person yielded by ``DatasetScanner.iter_persons()``, images are
    read with OpenCV, passed to the detector, filtered by minimum bounding-box
    size, aligned to 112x112 and embedded. Processing of a person stops once
    ``args.max_imgs_per_person`` embeddings have been collected. Every image
    that fails a stage is recorded in the report with the matching
    ``FailureReason`` and skipped. Persons that end up with no embedding are
    listed once in ``report.skipped_persons``.

    Args:
        args: Parsed command-line options (see ``main`` for the full list).

    Returns:
        ``(exit_code, report)``. Exit codes:
        0 - database written and self-check passed;
        2 - ``fail_on_empty`` is set and no person was enrolled (nothing is
            written);
        3 - ``fail_on_empty`` is set and the self-check reports zero persons
            or zero embeddings;
        4 - self-check ``person_count`` differs from ``embedding_count``;
        5 - self-check ``sample_lengths_ok`` is false.

    Raises:
        RuntimeError: If ``cv2`` cannot be imported.
    """
    try:
        import cv2
    except Exception as e:
        raise RuntimeError("opencv-python is required") from e

    det_cfg = load_det_outputs_config(args.det_outputs_config)
    if args.det_input_rgb is not None:
        # An explicit CLI flag overrides the colour order given in the config.
        det_cfg.setdefault("input", {})
        det_cfg["input"]["color"] = "RGB" if args.det_input_rgb else "BGR"

    scanner = DatasetScanner(args.dataset)
    report = BuildReport()
    persons_total, images_total = scanner.summary()
    report.total_person_dirs = persons_total
    report.total_images = images_total

    detector = OnnxFaceDetector(
        model_path=args.det_model,
        det_outputs_config=det_cfg,
        score_thresh=args.det_score_thresh,
        pick_face=args.pick_face,
    )
    recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)

    enrolled: List[Tuple[str, List[np.ndarray]]] = []
    for person in scanner.iter_persons():
        used_embs: List[np.ndarray] = []

        for img_path in person.image_paths:
            if len(used_embs) >= args.max_imgs_per_person:
                break

            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                report.add_failure(person.name, img_path, FailureReason.read_fail, "cv2.imread returned None")
                continue

            try:
                det = detector.detect_one(img_bgr)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.det_fail, repr(e))
                continue
            if det is None:
                report.add_failure(person.name, img_path, FailureReason.no_face, "no detection")
                continue

            bw, bh = _bbox_wh(det.bbox_xyxy)
            if bw < args.min_face_size or bh < args.min_face_size:
                report.add_failure(
                    person.name,
                    img_path,
                    FailureReason.small_face,
                    f"bbox too small: w={bw:.1f} h={bh:.1f} < {args.min_face_size}",
                )
                continue

            try:
                # Alignment is fed an RGB copy of the image; detection above
                # received the BGR array returned by cv2.imread.
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                aligned = align_face_5pts(img_rgb, det.landmarks5, out_size=(112, 112))
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.align_fail, repr(e))
                continue

            try:
                emb = recognizer.embed_aligned_rgb112(aligned)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.infer_fail, repr(e))
                continue

            used_embs.append(emb)
            report.add_success(person.name)

        if used_embs:
            enrolled.append((person.name, list(used_embs)))
        else:
            # Covers both "no images at all" and "every image failed", so each
            # skipped person is recorded exactly once.
            report.skipped_persons.append(person.name)

    report.enrolled_persons = len(enrolled)
    if args.fail_on_empty and report.enrolled_persons == 0:
        return 2, report

    writer = GalleryDbWriter(args.db_out, expected_dim=args.expected_dim)
    writer.write(enrolled)
    selfcheck = db_selfcheck(args.db_out, expected_dim=args.expected_dim, sample_n=args.selfcheck_samples)

    from gallery_builder.report import format_report

    print(format_report(report, selfcheck, show_per_person=True))

    if args.print_centroid_norm:
        for name, embs in enrolled:
            # ``embs`` is the list of per-image embeddings, so average them
            # first; the norm of the list itself would be the Frobenius norm
            # of the stacked matrix, not a centroid norm.
            # NOTE(review): this is the norm of the raw mean embedding. Whether
            # GalleryDbWriter re-normalises what it stores is not visible here.
            centroid = np.mean(np.stack(embs), axis=0)
            print(f"centroid_norm {name}: {float(np.linalg.norm(centroid)):.6f}")

    if args.fail_on_empty and (selfcheck["person_count"] == 0 or selfcheck["embedding_count"] == 0):
        return 3, report
    if selfcheck["person_count"] != selfcheck["embedding_count"]:
        return 4, report
    if not selfcheck["sample_lengths_ok"]:
        return 5, report
    return 0, report
def main(argv: List[str]) -> int:
    """Parse command-line arguments, build the gallery and return its exit code.

    The parent directory of ``--db_out`` is created if it does not exist
    before ``build_gallery`` is invoked.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        The integer exit code produced by ``build_gallery``.
    """
    parser = argparse.ArgumentParser(
        description="Build offline face_gallery.db (SQLite) from per-person image folders"
    )

    # Mandatory inputs/outputs: (flag, help text), registered in this order.
    required_specs = (
        ("--dataset", "dataset root: dataset/person_name/*.jpg"),
        ("--db_out", "output sqlite db path"),
        ("--det_model", "face detection onnx"),
        ("--recog_model", "face recognition onnx"),
        ("--det_outputs_config", "JSON string or path to JSON file for detection output mapping"),
    )
    for flag, help_text in required_specs:
        parser.add_argument(flag, required=True, help=help_text)

    # Tunables: (flag, add_argument keyword arguments), registered in this order.
    optional_specs = (
        ("--expected_dim", {"type": int, "default": 512}),
        ("--max_imgs_per_person", {"type": int, "default": 10}),
        ("--pick_face", {"choices": ["largest", "first", "highest_score"], "default": "largest"}),
        ("--min_face_size", {"type": int, "default": 80}),
        ("--fail_on_empty", {"type": _parse_bool, "default": False}),
        ("--det_score_thresh", {"type": float, "default": 0.0}),
        (
            "--det_input_rgb",
            {
                "type": _parse_bool,
                "default": None,
                "help": "override det config input.color (true=RGB,false=BGR)",
            },
        ),
        ("--selfcheck_samples", {"type": int, "default": 5}),
        ("--print_centroid_norm", {"type": _parse_bool, "default": False}),
    )
    for flag, options in optional_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args(argv)

    # Make sure the directory that will hold the database exists.
    out_dir = os.path.dirname(os.path.abspath(args.db_out)) or "."
    os.makedirs(out_dir, exist_ok=True)

    exit_code, _ = build_gallery(args)
    return exit_code
if __name__ == "__main__":
    # Script entry point: the return value of main() becomes the process exit status.
    sys.exit(main(sys.argv[1:]))