AddFaceTo3588/gallery_builder/recognizer.py
2026-01-08 13:46:50 +08:00

50 lines
1.8 KiB
Python

from __future__ import annotations
from typing import Optional
import numpy as np
from .aggregate import l2_normalize
class OnnxFaceRecognizer:
    """Face-embedding extractor backed by a lazily created ONNX Runtime session.

    Constructing the object only stores configuration. The inference session
    is created on the first embedding request, using the CPU execution
    provider, so onnxruntime is not needed until then.
    """

    def __init__(self, model_path: str, expected_dim: int = 512) -> None:
        """Store the configuration; the model itself is not loaded here.

        Args:
            model_path: Path handed to ``onnxruntime.InferenceSession``.
            expected_dim: Number of elements the flattened model output must
                contain; coerced with ``int()``.
        """
        self.model_path = model_path
        self.expected_dim = int(expected_dim)
        self._sess = None
        self._input_name: Optional[str] = None
        self._output_name: Optional[str] = None

    def _ensure_session(self) -> None:
        """Create the inference session on first call; later calls are no-ops.

        Raises:
            RuntimeError: If onnxruntime cannot be imported.
        """
        if self._sess is not None:
            return
        try:
            import onnxruntime as ort
        except Exception as e:  # pragma: no cover
            raise RuntimeError("onnxruntime is required for recognition") from e

        sess = ort.InferenceSession(self.model_path, providers=["CPUExecutionProvider"])
        input_name = sess.get_inputs()[0].name
        output_name = sess.get_outputs()[0].name

        # Publish the state only once every step has succeeded. Setting
        # ``self._sess`` last guarantees a failed initialisation is retried on
        # the next call instead of leaving a session with missing tensor names.
        self._input_name = input_name
        self._output_name = output_name
        self._sess = sess

    def embed_aligned_rgb112(self, aligned_rgb112: np.ndarray) -> np.ndarray:
        """Compute the embedding of one aligned face crop.

        Args:
            aligned_rgb112: Array of shape (112, 112, 3), expected to be uint8
                RGB. Any other dtype is cast to uint8 before preprocessing.

        Returns:
            A float32 1-D array with ``expected_dim`` elements, passed through
            ``l2_normalize``.

        Raises:
            ValueError: If the input is not 112x112x3, or if the flattened
                model output does not have ``expected_dim`` elements.
            RuntimeError: If onnxruntime cannot be imported.
        """
        x = np.asarray(aligned_rgb112)
        # Validate before touching the model so malformed input fails fast
        # with ValueError and never triggers the (expensive) session creation.
        if x.shape != (112, 112, 3):
            raise ValueError("aligned image must be 112x112x3 RGB")

        self._ensure_session()

        if x.dtype != np.uint8:
            # Kept for backward compatibility: non-uint8 input is cast as-is
            # (no rescaling or clipping), so callers should pass 0..255 data.
            x = x.astype(np.uint8)
        x = x.astype(np.float32)
        x = (x - 127.5) / 128.0  # map 0..255 to roughly [-1, 1)
        x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
        x = np.expand_dims(x, axis=0)  # CHW -> NCHW with a batch of one

        out = self._sess.run([self._output_name], {self._input_name: x})[0]
        out = np.asarray(out, dtype=np.float32).reshape(-1)
        if out.size != self.expected_dim:
            raise ValueError(f"unexpected embedding dim: got {out.size}, expected {self.expected_dim}")
        return l2_normalize(out)