21 lines
586 B
Python
21 lines
586 B
Python
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
|
|
def l2_normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale *v* to unit Euclidean length.

    The input is coerced to a float32 array. Its norm is taken over all
    elements: ``np.linalg.norm`` with default arguments uses the 2-norm of
    the flattened array, which is the Frobenius norm for 2-D input.

    Args:
        v: Array-like input of any shape.
        eps: Threshold below which the norm is treated as zero.

    Returns:
        A float32 array shaped like *v*. If the norm is below *eps*, the
        array is multiplied by 0.0 instead of being divided, which yields
        all zeros. Otherwise the array is divided by its norm.
    """
    arr = np.asarray(v, dtype=np.float32)
    length = float(np.linalg.norm(arr))
    # Keep the strict `<` test: a NaN norm is not "too small", so it falls
    # through to the division, matching the original behaviour.
    too_small = length < eps
    return arr * 0.0 if too_small else arr / length
|
|
|
|
|
|
def compute_centroid(embs: np.ndarray) -> np.ndarray:
    """Return the re-normalised mean of a stack of embedding rows.

    embs: float32 shape (K,D), expected already L2-normalized per row.
    That per-row normalisation is not checked here; the caller must
    ensure it.

    The input is coerced to float32, averaged over axis 0, and the
    resulting (D,) vector is passed through ``l2_normalize``. If the mean
    has near-zero norm, ``l2_normalize`` returns a zero vector.

    Raises:
        ValueError: if *embs* is not 2-D or has no rows.
    """
    matrix = np.asarray(embs, dtype=np.float32)
    # `and` short-circuits, so shape[0] is only read once ndim == 2 holds.
    is_valid = matrix.ndim == 2 and matrix.shape[0] >= 1
    if not is_valid:
        raise ValueError("embs must be (K,D) with K>=1")
    mean_vec = np.mean(matrix, axis=0)
    return l2_normalize(mean_vec)
|