AddFaceTo3588/inspect_onnx.py
2026-01-08 13:46:50 +08:00

188 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import argparse
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
def _np_dtype_from_ort(ort_type: str):
    """Map an ONNX Runtime tensor type string to the matching NumPy dtype.

    ``ort_type`` is the ``.type`` of an input's metadata, e.g. ``"tensor(float)"``
    or ``"tensor(uint8)"``. The bare element name (``"float"``) is accepted too.

    Matching is exact on the element name, so ``float16``, ``double`` and
    ``uint32`` each get their own dtype instead of being folded into whatever
    type name they happen to contain. Unknown element types fall back to
    ``np.float32``.
    """
    elem = ort_type.strip().lower()
    if elem.startswith("tensor(") and elem.endswith(")"):
        elem = elem[len("tensor("):-1]
    table = {
        "float": np.float32,
        "float16": np.float16,
        "double": np.float64,
        "uint8": np.uint8,
        "int8": np.int8,
        "uint16": np.uint16,
        "int16": np.int16,
        "uint32": np.uint32,
        "int32": np.int32,
        "uint64": np.uint64,
        "int64": np.int64,
        "bool": np.bool_,
    }
    return table.get(elem, np.float32)
def _print_io(sess) -> None:
    """Print name, shape and type of every session input, then every output.

    ``sess`` only needs ``get_inputs()`` / ``get_outputs()`` returning objects
    with ``.name``, ``.shape`` and ``.type`` attributes.
    """
    print("=== 模型信息 ===")
    for label, fetch in (("Input", sess.get_inputs), ("Output", sess.get_outputs)):
        for node in fetch():
            print(f"{label}: {node.name}, shape: {node.shape}, dtype: {node.type}")
def _describe_array(name: str, arr: np.ndarray, max_show: int = 8) -> None:
    """Print a one-line summary of ``arr`` plus a preview of its first values.

    The summary line shows dtype, shape and the min/max over all elements
    (``nan`` for an empty array). For a non-empty array a second line shows
    up to ``max_show`` leading elements of the flattened array.
    """
    values = np.asarray(arr)
    flat = values.reshape(-1)
    has_data = flat.size > 0
    if has_data:
        lo, hi = float(flat.min()), float(flat.max())
    else:
        lo = hi = float("nan")
    print(f"- {name}: dtype={values.dtype}, shape={list(values.shape)}, min={lo:.6g}, max={hi:.6g}")
    if not has_data:
        return
    preview = flat[:max_show]
    text = np.array2string(preview, precision=6, separator=", ")
    print(f" head[{len(preview)}]: {text}")
def _guess_det_outputs(out_by_name: Dict[str, Any]) -> None:
    """Print every output's shape with heuristic tags hinting at its role.

    The tags are guesses based purely on shape, intended as help when writing
    a detector's ``det_outputs_config``:

    - ``bbox?``   rank 2 or 3 with a trailing dim of 4
    - ``lmk10?``  rank 2 or 3 with a trailing dim of 10
      (presumably 5 flattened (x, y) landmarks)
    - ``lmk5x2?`` rank >= 3 ending in ``[5, 2]``
    - ``score?``  rank 1, or rank 2/3 with a trailing dim of 1

    Outputs matching none of these are listed without tags. Several tags may
    apply to the same output.
    """
    print("\n=== 输出候选分析(用于写 det_outputs_config) ===")
    for name, value in out_by_name.items():
        arr = np.asarray(value)
        shp = list(arr.shape)
        # Rank-2 (N, k) and rank-3 (B, N, k) layouts are both judged by the last dim.
        trailing = shp[-1] if arr.ndim in (2, 3) else None
        checks = (
            ("bbox?", trailing == 4),
            ("lmk10?", trailing == 10),
            ("lmk5x2?", arr.ndim >= 3 and shp[-2:] == [5, 2]),
            ("score?", arr.ndim == 1 or trailing == 1),
        )
        tags = [tag for tag, hit in checks if hit]
        tag_str = (" " + ",".join(tags)) if tags else ""
        print(f"* {name}: shape={shp}{tag_str}")
def _make_image_input(
    inp_meta,
    image_path: str,
    color: str,
    size_wh: Optional[Tuple[int, int]],
    norm: str,
) -> np.ndarray:
    """Load an image and turn it into a batch-of-one tensor for ``inp_meta``.

    Args:
        inp_meta: ONNX Runtime input metadata providing ``.shape`` (must be
            rank 4) and ``.type``.
        image_path: file passed to ``cv2.imread``.
        color: ``"BGR"`` (OpenCV's native order) or ``"RGB"`` (converted);
            case-insensitive.
        size_wh: explicit ``(W, H)`` resize target. It is required when the
            model's spatial dims are dynamic and overrides them otherwise.
        norm: applied to floating-point inputs only. One of ``"none"``,
            ``"div255"``, ``"arcface_128"`` (``(x - 127.5) / 128``) or
            ``"arcface_1275"`` (``(x - 127.5) / 127.5``).

    Returns:
        A C-contiguous array shaped ``(1, C, H, W)`` (NCHW) or ``(1, H, W, C)``
        (NHWC). ``C`` is the decoded image's channel count. The dtype follows
        ``inp_meta.type``.

    Raises:
        FileNotFoundError: the image could not be read.
        ValueError: bad ``color``/``norm``, input rank other than 4, or dynamic
            H/W without ``size_wh``.
    """
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {image_path}")
    color = color.upper()
    if color == "RGB":
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    elif color != "BGR":
        raise ValueError("--color must be BGR or RGB")

    shp = inp_meta.shape
    if len(shp) != 4:
        raise ValueError(f"unsupported input rank={len(shp)}: shape={shp}")

    def _static(dim):
        # Non-int dims (symbolic names / None) are treated as dynamic.
        return dim if isinstance(dim, int) else None

    # Channel-last only when dim 3 (and not dim 1) equals 3. Every other case,
    # including an unrecognised channel count, is treated as channel-first.
    if shp[1] != 3 and shp[3] == 3:
        layout, h, w = "NHWC", _static(shp[1]), _static(shp[2])
    else:
        layout, h, w = "NCHW", _static(shp[2]), _static(shp[3])

    if size_wh is not None:
        w, h = int(size_wh[0]), int(size_wh[1])
    elif h is None or w is None:
        raise ValueError(f"dynamic H/W input; please pass --size W H. got shape={shp}")
    img = cv2.resize(img, (int(w), int(h)), interpolation=cv2.INTER_LINEAR)

    np_dtype = _np_dtype_from_ort(inp_meta.type)
    x = img.astype(np_dtype)
    # Normalise any floating-point input (float16/32/64), not just float32.
    # Integer inputs (e.g. quantised uint8) keep raw pixel values.
    if np.issubdtype(np_dtype, np.floating):
        if norm == "div255":
            x = x / 255.0
        elif norm == "arcface_128":
            x = (x - 127.5) / 128.0
        elif norm == "arcface_1275":
            x = (x - 127.5) / 127.5
        elif norm != "none":
            raise ValueError("--norm must be none|div255|arcface_128|arcface_1275")

    if layout == "NCHW":
        # HWC -> CHW. np.transpose returns a strided view, not a copy.
        x = np.transpose(x, (2, 0, 1))
    # Hand ONNX Runtime a dense C-ordered buffer.
    return np.ascontiguousarray(np.expand_dims(x, axis=0))
def main() -> int:
    """Command-line entry point: print model I/O and optionally run one inference.

    Always prints the input/output metadata of ``--model``. When ``--image`` is
    given, the image is preprocessed for the model's first input and the
    session is run once. Each output is then summarised, followed by the
    shape-based role guesses.

    ``--providers`` accepts a single execution provider name or a
    comma-separated list in priority order. An empty value falls back to
    ``CPUExecutionProvider``.

    Returns:
        Process exit code (0). Errors propagate as exceptions.
    """
    ap = argparse.ArgumentParser(description="Inspect ONNX model IO + run one inference and print outputs")
    ap.add_argument(
        "--model",
        default="mobilefacenet_arcface_bs1.onnx",
        help="onnx path (default: mobilefacenet_arcface_bs1.onnx)",
    )
    ap.add_argument("--image", help="optional image path for one dry-run inference")
    ap.add_argument("--size", nargs=2, type=int, metavar=("W", "H"), help="override input resize W H for dynamic shapes")
    ap.add_argument("--color", default="RGB", choices=["BGR", "RGB"], help="input color order (default RGB)")
    ap.add_argument("--norm", default="arcface_128", choices=["none", "div255", "arcface_128", "arcface_1275"], help="float input normalization")
    ap.add_argument(
        "--providers",
        default="CPUExecutionProvider",
        help="comma-separated onnxruntime execution providers in priority order "
        "(e.g. CUDAExecutionProvider,CPUExecutionProvider; default: CPUExecutionProvider)",
    )
    args = ap.parse_args()

    # Imported after argument parsing, so `--help` does not require onnxruntime.
    import onnxruntime as ort

    providers = [p.strip() for p in args.providers.split(",") if p.strip()]
    sess = ort.InferenceSession(args.model, providers=providers or ["CPUExecutionProvider"])
    _print_io(sess)
    if not args.image:
        return 0

    # Only the first model input is fed.
    inp0 = sess.get_inputs()[0]
    size_wh = (args.size[0], args.size[1]) if args.size else None  # argparse already made these ints
    x = _make_image_input(inp0, args.image, args.color, size_wh=size_wh, norm=args.norm)

    print("\n=== 单次推理输出(用于看输出名/shape/范围) ===")
    # output_names=None -> all outputs, in sess.get_outputs() order.
    outputs = sess.run(None, {inp0.name: x})
    out_names = [o.name for o in sess.get_outputs()]
    out_by_name = dict(zip(out_names, outputs))
    for n in out_names:
        _describe_array(n, out_by_name[n])
    _guess_det_outputs(out_by_name)
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    status = main()
    raise SystemExit(status)