50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
|
|
from .aggregate import l2_normalize
|
|
|
|
|
|
class OnnxFaceRecognizer:
    """Compute face embeddings from aligned 112x112 RGB crops with an ONNX model.

    The onnxruntime session is created lazily on the first embedding request,
    so constructing a recognizer neither imports onnxruntime nor opens the
    model file.
    """

    def __init__(self, model_path: str, expected_dim: int = 512) -> None:
        """Store the configuration; no model is loaded here.

        Args:
            model_path: Path handed to ``onnxruntime.InferenceSession``.
            expected_dim: Number of elements the flattened model output must
                have; coerced with ``int()``.
        """
        self.model_path = model_path
        self.expected_dim = int(expected_dim)
        # Populated together by _ensure_session(); _sess doubles as the
        # "already initialised" flag.
        self._sess = None
        self._input_name: Optional[str] = None
        self._output_name: Optional[str] = None

    def _ensure_session(self) -> None:
        """Create the CPU inference session on first use; no-op afterwards.

        Raises:
            RuntimeError: If onnxruntime cannot be imported.
        """
        if self._sess is not None:
            return
        try:
            import onnxruntime as ort
        except Exception as e:  # pragma: no cover
            raise RuntimeError("onnxruntime is required for recognition") from e

        # Build everything in locals first: if resolving the tensor names
        # fails, the instance must not be left with a session but no names,
        # because the guard above would then skip initialisation forever.
        sess = ort.InferenceSession(self.model_path, providers=["CPUExecutionProvider"])
        input_name = sess.get_inputs()[0].name
        output_name = sess.get_outputs()[0].name

        self._input_name = input_name
        self._output_name = output_name
        self._sess = sess  # assigned last: marks initialisation as complete

    def embed_aligned_rgb112(self, aligned_rgb112: np.ndarray) -> np.ndarray:
        """Embed one aligned face crop.

        Args:
            aligned_rgb112: uint8 RGB image of shape 112x112x3. Other dtypes
                are converted to uint8; values outside [0, 255] saturate
                instead of wrapping around.

        Returns:
            The flattened float32 model output of length ``expected_dim``,
            passed through ``l2_normalize``.

        Raises:
            ValueError: If the image is not 112x112x3, or if the model output
                does not contain exactly ``expected_dim`` values.
            RuntimeError: If onnxruntime cannot be imported.
        """
        self._ensure_session()

        x = np.asarray(aligned_rgb112)
        if x.shape != (112, 112, 3):
            raise ValueError("aligned image must be 112x112x3 RGB")
        if x.dtype != np.uint8:
            # A bare astype(np.uint8) wraps out-of-range values modulo 256
            # (300 -> 44, -1 -> 255), silently corrupting pixels. Saturate
            # first; in-range values convert exactly as before.
            x = np.clip(x.astype(np.float64), 0.0, 255.0).astype(np.uint8)

        # Scale [0, 255] to roughly [-1, 1) and lay out as a 1-image NCHW batch.
        x = (x.astype(np.float32) - 127.5) / 128.0
        x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
        x = np.expand_dims(x, axis=0)  # CHW -> NCHW

        out = self._sess.run([self._output_name], {self._input_name: x})[0]
        out = np.asarray(out, dtype=np.float32).reshape(-1)
        if out.size != self.expected_dim:
            raise ValueError(f"unexpected embedding dim: got {out.size}, expected {self.expected_dim}")
        return l2_normalize(out)
|