from __future__ import annotations

import argparse
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Debug helper: print an ONNX model's input/output signature and, optionally,
# run a single inference on one image and summarize every output tensor.

# ONNX Runtime reports tensor types as "tensor(<element type>)"; map the element
# type name to the numpy scalar type the fed array must have. An exact lookup is
# used (instead of substring matching) so that e.g. "float16", "double" and
# "uint32" are not silently mapped to float32 / int32.
_ORT_ELEM_TO_NP: Dict[str, Any] = {
    "float": np.float32,
    "float16": np.float16,
    "double": np.float64,
    "uint8": np.uint8,
    "int8": np.int8,
    "uint16": np.uint16,
    "int16": np.int16,
    "uint32": np.uint32,
    "int32": np.int32,
    "uint64": np.uint64,
    "int64": np.int64,
    "bool": np.bool_,
}

_DEFAULT_PROVIDER = "CPUExecutionProvider"


def _np_dtype_from_ort(ort_type: str):
    """Return the numpy scalar type matching an ONNX Runtime type string.

    Accepts either the wrapped form (``"tensor(float)"``) or a bare element
    name (``"float"``). Unknown or non-tensor types fall back to
    ``np.float32``.
    """
    t = ort_type.strip().lower()
    if t.startswith("tensor(") and t.endswith(")"):
        t = t[len("tensor("):-1]
    return _ORT_ELEM_TO_NP.get(t, np.float32)


def _print_io(sess) -> None:
    """Print name, shape and type string of every session input and output."""
    print("=== 模型信息 ===")
    for inp in sess.get_inputs():
        print(f"Input: {inp.name}, shape: {inp.shape}, dtype: {inp.type}")
    for out in sess.get_outputs():
        print(f"Output: {out.name}, shape: {out.shape}, dtype: {out.type}")


def _describe_array(name: str, arr: np.ndarray, max_show: int = 8) -> None:
    """Print dtype, shape, min/max and the first ``max_show`` flattened values.

    min/max are reported as NaN for empty arrays and for dtypes that cannot be
    converted to a Python float (strings, objects, complex numbers).
    """
    a = np.asarray(arr)
    flat = a.reshape(-1)
    # bool / signed int / unsigned int / float are the kinds float() accepts.
    if flat.size and a.dtype.kind in "biuf":
        mn = float(flat.min())
        mx = float(flat.max())
    else:
        mn = mx = float("nan")
    print(f"- {name}: dtype={a.dtype}, shape={list(a.shape)}, min={mn:.6g}, max={mx:.6g}")
    if flat.size:
        head = flat[:max_show]
        print(f"  head[{len(head)}]: {np.array2string(head, precision=6, separator=', ')}")


def _guess_det_outputs(out_by_name: Dict[str, Any]) -> None:
    """Print each output's shape tagged with shape-based role guesses.

    The tags are heuristics only: last dim 4 -> ``bbox?``, last dim 10 ->
    ``lmk10?``, trailing ``[5, 2]`` -> ``lmk5x2?``, last dim 1 or rank 1 ->
    ``score?`` (the last-dim rules apply to rank-2 and rank-3 tensors only).
    """
    print("\n=== 输出候选分析(用于写 det_outputs_config) ===")
    for name, a in out_by_name.items():
        x = np.asarray(a)
        shp = list(x.shape)
        last = shp[-1] if x.ndim in (2, 3) else None
        tags: List[str] = []
        if last == 4:
            tags.append("bbox?")
        if last == 10:
            tags.append("lmk10?")
        if x.ndim >= 3 and shp[-2:] == [5, 2]:
            tags.append("lmk5x2?")
        if last == 1 or x.ndim == 1:
            tags.append("score?")
        tag_str = (" " + ",".join(tags)) if tags else ""
        print(f"* {name}: shape={shp}{tag_str}")


def _resolve_layout(shape) -> Tuple[str, int, Optional[int], Optional[int]]:
    """Infer ``(layout, channels, H, W)`` from a rank-4 ORT input shape.

    A 3 in the channel position wins first (NCHW before NHWC), then a 1
    (single-channel models). Otherwise NCHW with 3 channels is assumed.
    Non-int dims (symbolic names / None for dynamic axes) come back as None.
    """

    def _dim(d):
        return d if isinstance(d, int) else None

    if shape[1] == 3:
        return "NCHW", 3, _dim(shape[2]), _dim(shape[3])
    if shape[3] == 3:
        return "NHWC", 3, _dim(shape[1]), _dim(shape[2])
    if shape[1] == 1:
        return "NCHW", 1, _dim(shape[2]), _dim(shape[3])
    if shape[3] == 1:
        return "NHWC", 1, _dim(shape[1]), _dim(shape[2])
    return "NCHW", 3, _dim(shape[2]), _dim(shape[3])


def _normalize(x: np.ndarray, norm: str) -> np.ndarray:
    """Apply the named pixel normalization to a floating-point array.

    Raises:
        ValueError: if ``norm`` is not one of none|div255|arcface_128|arcface_1275.
    """
    if norm == "none":
        return x
    if norm == "div255":
        return x / 255.0
    if norm == "arcface_128":
        return (x - 127.5) / 128.0
    if norm == "arcface_1275":
        return (x - 127.5) / 127.5
    raise ValueError("--norm must be none|div255|arcface_128|arcface_1275")


def _make_image_input(
    inp_meta,
    image_path: str,
    color: str,
    size_wh: Optional[Tuple[int, int]],
    norm: str,
) -> np.ndarray:
    """Load an image and turn it into a batch-of-1 tensor matching ``inp_meta``.

    Args:
        inp_meta: ORT input descriptor; its ``.shape`` (rank 4) and ``.type``
            decide layout, channel count, spatial size and dtype.
        image_path: path passed to ``cv2.imread``.
        color: "BGR" or "RGB" (case-insensitive); ignored for 1-channel inputs,
            which receive a grayscale image.
        size_wh: optional ``(W, H)`` that overrides the model's spatial size;
            required when H or W is dynamic.
        norm: normalization name, applied only when the model dtype is floating.

    Returns:
        A C-contiguous array of shape (1, C, H, W) or (1, H, W, C).

    Raises:
        FileNotFoundError: if the image cannot be read.
        ValueError: on a bad color/norm value, a non-rank-4 input, or dynamic
            H/W without ``size_wh``.
    """
    # Imported lazily so the rest of the module does not require OpenCV.
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {image_path}")

    color = color.upper()
    if color not in ("BGR", "RGB"):
        raise ValueError("--color must be BGR or RGB")

    shp = inp_meta.shape
    if len(shp) != 4:
        raise ValueError(f"unsupported input rank={len(shp)}: shape={shp}")
    layout, channels, h, w = _resolve_layout(shp)

    # cv2.imread returns BGR channel order.
    if channels == 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif color == "RGB":
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if size_wh is not None:
        w, h = int(size_wh[0]), int(size_wh[1])
    elif h is None or w is None:
        raise ValueError(f"dynamic H/W input; please pass --size W H. got shape={shp}")

    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    if img.ndim == 2:
        # Grayscale images are 2-D; restore an explicit channel axis.
        img = img[:, :, np.newaxis]

    np_dtype = _np_dtype_from_ort(inp_meta.type)
    x = img.astype(np_dtype)
    if np.issubdtype(np_dtype, np.floating):
        x = _normalize(x, norm)

    if layout == "NCHW":
        x = np.transpose(x, (2, 0, 1))
    # transpose yields a strided view; hand ORT a plain contiguous buffer.
    return np.ascontiguousarray(x[np.newaxis])


def _parse_providers(spec: str) -> List[str]:
    """Split a comma-separated provider list; empty input falls back to CPU."""
    providers = [p.strip() for p in spec.split(",") if p.strip()]
    return providers or [_DEFAULT_PROVIDER]


def main() -> int:
    """CLI entry point: print model IO and optionally run one image inference.

    Returns the process exit code (always 0; errors propagate as exceptions).
    """
    ap = argparse.ArgumentParser(description="Inspect ONNX model IO + run one inference and print outputs")
    ap.add_argument(
        "--model",
        default="mobilefacenet_arcface_bs1.onnx",
        help="onnx path (default: mobilefacenet_arcface_bs1.onnx)",
    )
    ap.add_argument("--image", help="optional image path for one dry-run inference")
    ap.add_argument(
        "--size",
        nargs=2,
        type=int,
        metavar=("W", "H"),
        help="override input resize W H for dynamic shapes",
    )
    ap.add_argument(
        "--color",
        default="RGB",
        choices=["BGR", "RGB"],
        help="input color order (default RGB)",
    )
    ap.add_argument(
        "--norm",
        default="arcface_128",
        choices=["none", "div255", "arcface_128", "arcface_1275"],
        help="float input normalization",
    )
    ap.add_argument(
        "--providers",
        default=_DEFAULT_PROVIDER,
        help="comma-separated ONNX Runtime execution providers in priority order "
        f"(default: {_DEFAULT_PROVIDER})",
    )
    args = ap.parse_args()

    # Imported lazily so the helpers above can be used without onnxruntime.
    import onnxruntime as ort

    sess = ort.InferenceSession(args.model, providers=_parse_providers(args.providers))
    _print_io(sess)
    if not args.image:
        return 0

    # Only the first model input is fed; multi-input models are not supported here.
    inp0 = sess.get_inputs()[0]
    size_wh = (int(args.size[0]), int(args.size[1])) if args.size else None
    x = _make_image_input(inp0, args.image, args.color, size_wh=size_wh, norm=args.norm)

    print("\n=== 单次推理输出(用于看输出名/shape/范围) ===")
    outputs = sess.run(None, {inp0.name: x})
    out_names = [o.name for o in sess.get_outputs()]
    out_by_name = dict(zip(out_names, outputs))
    for n in out_names:
        _describe_array(n, out_by_name[n])

    _guess_det_outputs(out_by_name)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())