168 lines
6.3 KiB
Python
168 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from gallery_builder.aggregate import compute_centroid
|
|
from gallery_builder.align import align_face_5pts
|
|
from gallery_builder.dataset import DatasetScanner
|
|
from gallery_builder.db import GalleryDbWriter, db_selfcheck
|
|
from gallery_builder.detector import OnnxFaceDetector, load_det_outputs_config
|
|
from gallery_builder.recognizer import OnnxFaceRecognizer
|
|
from gallery_builder.types import BuildReport, FailureReason
|
|
|
|
|
|
def _parse_bool(s: str) -> bool:
    """Turn a command-line token into a boolean (usable as an argparse ``type=``).

    A value that is already a ``bool`` is handed back unchanged. Anything else
    is stringified, stripped of surrounding whitespace and lower-cased before
    being matched against the accepted spellings.

    Accepted as true:  1, true, t, yes, y, on
    Accepted as false: 0, false, f, no, n, off

    Raises:
        argparse.ArgumentTypeError: if the text matches none of the spellings.
    """
    if isinstance(s, bool):
        return s

    spellings = {
        "1": True,
        "true": True,
        "t": True,
        "yes": True,
        "y": True,
        "on": True,
        "0": False,
        "false": False,
        "f": False,
        "no": False,
        "n": False,
        "off": False,
    }
    token = str(s).strip().lower()
    parsed = spellings.get(token)
    if parsed is None:
        # The message shows the caller's original value, not the normalised token.
        raise argparse.ArgumentTypeError(f"invalid bool: {s}")
    return parsed
|
|
|
|
|
|
def _bbox_wh(b: np.ndarray) -> Tuple[float, float]:
    """Return ``(width, height)`` of a box as plain Python floats.

    Width is ``b[2] - b[0]`` and height is ``b[3] - b[1]``; the caller passes
    a detection's ``bbox_xyxy``.
    """
    width = float(b[2] - b[0])
    height = float(b[3] - b[1])
    return width, height
|
|
|
|
|
|
def build_gallery(args: argparse.Namespace) -> Tuple[int, BuildReport]:
    """Build the face gallery database from a per-person image dataset.

    For every person yielded by the dataset scanner, up to
    ``args.max_imgs_per_person`` images are successfully processed: each image
    is read with OpenCV, a face is detected, the face is aligned to 112x112
    and embedded. The embeddings of one person are reduced to a single
    centroid, and all centroids are written to ``args.db_out``. Afterwards the
    database is self-checked and a report is printed.

    Args:
        args: parsed command-line namespace (see ``main`` for the options).

    Returns:
        ``(exit_code, report)`` where ``exit_code`` is

        * 0 -- success;
        * 2 -- ``fail_on_empty`` is set and nobody was enrolled (returned
          before anything is written or printed);
        * 3 -- ``fail_on_empty`` is set and the self-check reports zero
          persons or zero embeddings;
        * 4 -- the self-check's person count differs from its embedding count;
        * 5 -- the self-check reports ``sample_lengths_ok`` as false.

    Raises:
        RuntimeError: if ``cv2`` (opencv-python) cannot be imported.
    """
    try:
        import cv2
    except Exception as e:
        raise RuntimeError("opencv-python is required") from e

    det_cfg = load_det_outputs_config(args.det_outputs_config)
    if args.det_input_rgb is not None:
        # An explicit CLI flag overrides whatever colour order the config says.
        det_cfg.setdefault("input", {})
        det_cfg["input"]["color"] = "RGB" if args.det_input_rgb else "BGR"

    scanner = DatasetScanner(args.dataset)
    report = BuildReport()
    persons_total, images_total = scanner.summary()
    report.total_person_dirs = persons_total
    report.total_images = images_total

    detector = OnnxFaceDetector(
        model_path=args.det_model,
        det_outputs_config=det_cfg,
        score_thresh=args.det_score_thresh,
        pick_face=args.pick_face,
    )
    recognizer = OnnxFaceRecognizer(args.recog_model, expected_dim=args.expected_dim)

    enrolled: List[Tuple[str, np.ndarray]] = []

    for person in scanner.iter_persons():
        used_embs: List[np.ndarray] = []

        if not person.image_paths:
            report.skipped_persons.append(person.name)
            # Move on immediately: without this the "no usable embeddings"
            # branch at the end of the iteration would record the same
            # person as skipped a second time.
            continue

        for img_path in person.image_paths:
            # Only successful embeddings count towards the per-person cap.
            if len(used_embs) >= args.max_imgs_per_person:
                break

            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                report.add_failure(person.name, img_path, FailureReason.read_fail, "cv2.imread returned None")
                continue

            try:
                det = detector.detect_one(img_bgr)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.det_fail, repr(e))
                continue

            if det is None:
                report.add_failure(person.name, img_path, FailureReason.no_face, "no detection")
                continue

            bw, bh = _bbox_wh(det.bbox_xyxy)
            if bw < args.min_face_size or bh < args.min_face_size:
                report.add_failure(
                    person.name,
                    img_path,
                    FailureReason.small_face,
                    f"bbox too small: w={bw:.1f} h={bh:.1f} < {args.min_face_size}",
                )
                continue

            try:
                # Alignment and recognition work on RGB; detection used the BGR image.
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                aligned = align_face_5pts(img_rgb, det.landmarks5, out_size=(112, 112))
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.align_fail, repr(e))
                continue

            try:
                emb = recognizer.embed_aligned_rgb112(aligned)
            except Exception as e:
                report.add_failure(person.name, img_path, FailureReason.infer_fail, repr(e))
                continue

            used_embs.append(emb)
            report.add_success(person.name)

        if used_embs:
            centroid = compute_centroid(np.stack(used_embs, axis=0))
            enrolled.append((person.name, centroid))
        else:
            # Had images, but none of them produced an embedding.
            report.skipped_persons.append(person.name)

    report.enrolled_persons = len(enrolled)

    if args.fail_on_empty and report.enrolled_persons == 0:
        return 2, report

    writer = GalleryDbWriter(args.db_out, expected_dim=args.expected_dim)
    writer.write(enrolled)

    selfcheck = db_selfcheck(args.db_out, expected_dim=args.expected_dim, sample_n=args.selfcheck_samples)
    from gallery_builder.report import format_report

    print(format_report(report, selfcheck, show_per_person=True))
    if args.print_centroid_norm:
        for name, c in enrolled:
            print(f"centroid_norm {name}: {float(np.linalg.norm(c)):.6f}")

    if args.fail_on_empty and (selfcheck["person_count"] == 0 or selfcheck["embedding_count"] == 0):
        return 3, report
    if selfcheck["person_count"] != selfcheck["embedding_count"]:
        return 4, report
    if not selfcheck["sample_lengths_ok"]:
        return 5, report
    return 0, report
|
|
|
|
|
|
def main(argv: List[str]) -> int:
    """Command-line entry point: parse ``argv``, build the gallery, return its exit code.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        The integer exit code produced by ``build_gallery``.
    """
    parser = argparse.ArgumentParser(
        description="Build offline face_gallery.db (SQLite) from per-person image folders"
    )
    add = parser.add_argument

    # Required paths and model files.
    add("--dataset", required=True, help="dataset root: dataset/person_name/*.jpg")
    add("--db_out", required=True, help="output sqlite db path")
    add("--det_model", required=True, help="face detection onnx")
    add("--recog_model", required=True, help="face recognition onnx")
    add(
        "--det_outputs_config",
        required=True,
        help="JSON string or path to JSON file for detection output mapping",
    )

    # Enrollment behaviour.
    add("--expected_dim", type=int, default=512)
    add("--max_imgs_per_person", type=int, default=10)
    add("--pick_face", choices=["largest", "first", "highest_score"], default="largest")
    add("--min_face_size", type=int, default=80)
    add("--fail_on_empty", type=_parse_bool, default=False)

    # Detector tuning and diagnostics.
    add("--det_score_thresh", type=float, default=0.0)
    add(
        "--det_input_rgb",
        type=_parse_bool,
        default=None,
        help="override det config input.color (true=RGB,false=BGR)",
    )
    add("--selfcheck_samples", type=int, default=5)
    add("--print_centroid_norm", type=_parse_bool, default=False)

    args = parser.parse_args(argv)

    # Create the directory that will hold the database before building starts.
    db_dir = os.path.dirname(os.path.abspath(args.db_out)) or "."
    os.makedirs(db_dir, exist_ok=True)

    exit_code, _ = build_gallery(args)
    return exit_code
|
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: pass the CLI arguments (minus the program name) to
    # main() and use its return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
|