AddFaceTo3588/gallery_builder/aggregate.py
2026-01-08 13:46:50 +08:00

21 lines
586 B
Python

from __future__ import annotations
import numpy as np
def l2_normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale ``v`` to unit Euclidean (L2) length.

    The input is first converted to a float32 array. Its L2 norm is taken
    over *all* elements, so a multi-dimensional array is treated as one
    flattened vector (Frobenius norm), not normalized per row.

    Args:
        v: Array-like of numbers.
        eps: Norms strictly below this threshold are treated as zero.

    Returns:
        A new float32 array with the same shape as ``v``. It is ``v`` divided
        by its norm, or all zeros when the norm is below ``eps``. The zero
        case avoids dividing by a (near-)zero length.
    """
    arr = np.asarray(v, dtype=np.float32)
    length = float(np.linalg.norm(arr))
    # Written as `< eps` so a NaN norm still falls through to the division.
    return arr * 0.0 if length < eps else arr / length
def compute_centroid(embs: np.ndarray) -> np.ndarray:
    """Return the unit-length mean direction of a set of embeddings.

    Args:
        embs: float32 array-like of shape (K, D) with K >= 1. Rows are
            expected to be already L2-normalized. This is not checked, and
            rows with larger norms pull the centroid toward themselves.

    Returns:
        float32 array of shape (D,): the row-wise mean passed through
        ``l2_normalize``. If the mean has (near-)zero norm, for example
        because the rows cancel each other out, an all-zero vector is
        returned.

    Raises:
        ValueError: if ``embs`` is not 2-D with at least one row, or if it
            contains NaN or infinite values.
    """
    embs = np.asarray(embs, dtype=np.float32)
    if embs.ndim != 2 or embs.shape[0] < 1:
        raise ValueError("embs must be (K,D) with K>=1")
    # One NaN/inf row would otherwise silently turn the entire centroid into
    # NaNs, and every similarity computed against it would be NaN as well.
    if not np.all(np.isfinite(embs)):
        raise ValueError("embs must contain only finite values")
    c = embs.mean(axis=0)
    return l2_normalize(c)