188 lines
5.8 KiB
Python
188 lines
5.8 KiB
Python
from __future__ import annotations
|
||
|
||
import argparse
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
|
||
|
||
def _np_dtype_from_ort(ort_type: str):
|
||
t = ort_type.lower()
|
||
if "float" in t:
|
||
return np.float32
|
||
if "uint8" in t:
|
||
return np.uint8
|
||
if "int8" in t:
|
||
return np.int8
|
||
if "int32" in t:
|
||
return np.int32
|
||
if "int64" in t:
|
||
return np.int64
|
||
return np.float32
|
||
|
||
|
||
def _print_io(sess) -> None:
|
||
print("=== 模型信息 ===")
|
||
for inp in sess.get_inputs():
|
||
print(f"Input: {inp.name}, shape: {inp.shape}, dtype: {inp.type}")
|
||
for out in sess.get_outputs():
|
||
print(f"Output: {out.name}, shape: {out.shape}, dtype: {out.type}")
|
||
|
||
|
||
def _describe_array(name: str, arr: np.ndarray, max_show: int = 8) -> None:
|
||
a = np.asarray(arr)
|
||
flat = a.reshape(-1)
|
||
if flat.size:
|
||
mn = float(flat.min())
|
||
mx = float(flat.max())
|
||
else:
|
||
mn = float("nan")
|
||
mx = float("nan")
|
||
print(f"- {name}: dtype={a.dtype}, shape={list(a.shape)}, min={mn:.6g}, max={mx:.6g}")
|
||
if flat.size:
|
||
head = flat[:max_show]
|
||
print(f" head[{len(head)}]: {np.array2string(head, precision=6, separator=', ')}")
|
||
|
||
|
||
def _guess_det_outputs(out_by_name: Dict[str, Any]) -> None:
|
||
print("\n=== 输出候选分析(用于写 det_outputs_config) ===")
|
||
for name, a in out_by_name.items():
|
||
x = np.asarray(a)
|
||
shp = list(x.shape)
|
||
|
||
is_bbox = False
|
||
is_lmk10 = False
|
||
is_lmk5x2 = False
|
||
is_score = False
|
||
|
||
if (x.ndim == 2 and shp[1] == 4) or (x.ndim == 3 and shp[-1] == 4):
|
||
is_bbox = True
|
||
if (x.ndim == 2 and shp[1] == 10) or (x.ndim == 3 and shp[-1] == 10):
|
||
is_lmk10 = True
|
||
if x.ndim >= 3 and shp[-2:] == [5, 2]:
|
||
is_lmk5x2 = True
|
||
if (x.ndim == 2 and shp[1] == 1) or (x.ndim == 3 and shp[-1] == 1) or (x.ndim == 1):
|
||
is_score = True
|
||
|
||
tags: List[str] = []
|
||
if is_bbox:
|
||
tags.append("bbox?")
|
||
if is_lmk10:
|
||
tags.append("lmk10?")
|
||
if is_lmk5x2:
|
||
tags.append("lmk5x2?")
|
||
if is_score:
|
||
tags.append("score?")
|
||
tag_str = (" " + ",".join(tags)) if tags else ""
|
||
print(f"* {name}: shape={shp}{tag_str}")
|
||
|
||
|
||
def _make_image_input(
    inp_meta,
    image_path: str,
    color: str,
    size_wh: Optional[Tuple[int, int]],
    norm: str,
) -> np.ndarray:
    """Load an image from disk and turn it into a batch-of-one model input.

    Args:
        inp_meta: ONNX Runtime input metadata (e.g. ``sess.get_inputs()[0]``).
            Its ``shape`` must have rank 4. Its ``type`` string selects the
            NumPy dtype via ``_np_dtype_from_ort``.
        image_path: path handed to ``cv2.imread``.
        color: ``"BGR"`` (OpenCV's native order) or ``"RGB"``; case-insensitive.
        size_wh: explicit ``(width, height)`` resize target. It is required
            when the model's H or W dimension is not a fixed positive int.
        norm: normalization applied to floating-point inputs. One of
            ``"none"``, ``"div255"`` (x/255), ``"arcface_128"``
            ((x-127.5)/128) or ``"arcface_1275"`` ((x-127.5)/127.5). It is not
            applied (or validated) for integer/bool inputs. Those receive the
            raw 0-255 pixel values cast with ``astype``, so values above 127
            wrap for int8.

    Returns:
        A C-contiguous array with a leading batch axis of 1. It is channels-first
        (N, C, H, W) unless the input shape has 3 at index 3 but not at index 1,
        in which case it is channels-last (N, H, W, C).

    Raises:
        FileNotFoundError: ``cv2.imread`` could not read the image.
        ValueError: invalid ``color``/``norm``, input rank != 4, or dynamic
            H/W without ``size_wh``.
    """
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread failed: {image_path}")

    color = color.upper()
    if color == "RGB":
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    elif color != "BGR":
        raise ValueError("--color must be BGR or RGB")

    shp = inp_meta.shape
    if len(shp) != 4:
        raise ValueError(f"unsupported input rank={len(shp)}: shape={shp}")

    # Channels-last only when dim 3 is 3 and dim 1 is not; otherwise assume NCHW.
    if shp[1] != 3 and shp[3] == 3:
        layout, h_dim, w_dim = "NHWC", shp[1], shp[2]
    else:
        layout, h_dim, w_dim = "NCHW", shp[2], shp[3]
    # Symbolic/dynamic dims are reported as strings or None, so only positive
    # ints count as fixed sizes.
    h = h_dim if isinstance(h_dim, int) and h_dim > 0 else None
    w = w_dim if isinstance(w_dim, int) and w_dim > 0 else None

    if size_wh is not None:
        w, h = int(size_wh[0]), int(size_wh[1])
    elif h is None or w is None:
        raise ValueError(f"dynamic H/W input; please pass --size W H. got shape={shp}")

    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    np_dtype = _np_dtype_from_ort(inp_meta.type)
    x = img.astype(np_dtype)

    # Normalize every floating dtype (float16/32/64), not only float32, so that
    # --norm is never silently skipped for half/double-precision models.
    if np.issubdtype(np_dtype, np.floating):
        if norm == "none":
            pass
        elif norm == "div255":
            x = x / 255.0
        elif norm == "arcface_128":
            x = (x - 127.5) / 128.0
        elif norm == "arcface_1275":
            x = (x - 127.5) / 127.5
        else:
            raise ValueError("--norm must be none|div255|arcface_128|arcface_1275")

    if layout == "NCHW":
        x = np.transpose(x, (2, 0, 1))
    # Add the batch axis. transpose returns a strided view, so make the buffer
    # C-contiguous before handing it to the runtime.
    return np.ascontiguousarray(x[np.newaxis, ...])
|
||
|
||
|
||
def main() -> int:
    """Command-line entry point.

    Prints the model's inputs/outputs. If ``--image`` is given, it also runs
    one inference on that image and prints a summary and a role guess for
    every output.

    Returns:
        Process exit code (0 on success). Errors propagate as exceptions,
        or as argparse's usage error exit.
    """
    ap = argparse.ArgumentParser(description="Inspect ONNX model IO + run one inference and print outputs")
    ap.add_argument(
        "--model",
        default="mobilefacenet_arcface_bs1.onnx",
        help="onnx path (default: mobilefacenet_arcface_bs1.onnx)",
    )
    ap.add_argument("--image", help="optional image path for one dry-run inference")
    ap.add_argument("--size", nargs=2, type=int, metavar=("W", "H"), help="override input resize W H for dynamic shapes")
    ap.add_argument("--color", default="RGB", choices=["BGR", "RGB"], help="input color order (default RGB)")
    ap.add_argument("--norm", default="arcface_128", choices=["none", "div255", "arcface_128", "arcface_1275"], help="float input normalization")
    ap.add_argument(
        "--providers",
        default="CPUExecutionProvider",
        help="comma-separated ONNX Runtime execution providers in priority order "
        "(default: CPUExecutionProvider)",
    )
    args = ap.parse_args()

    # A single name (the old usage) still works; a comma-separated list such as
    # "CUDAExecutionProvider,CPUExecutionProvider" enables fallback providers.
    providers = [p.strip() for p in args.providers.split(",") if p.strip()]
    if not providers:
        ap.error("--providers must name at least one execution provider")

    import onnxruntime as ort

    sess = ort.InferenceSession(args.model, providers=providers)
    _print_io(sess)

    if not args.image:
        return 0

    inp0 = sess.get_inputs()[0]
    # argparse already converted --size to ints (type=int).
    size_wh = (args.size[0], args.size[1]) if args.size else None
    x = _make_image_input(inp0, args.image, args.color, size_wh=size_wh, norm=args.norm)

    print("\n=== 单次推理输出(用于看输出名/shape/范围) ===")
    outputs = sess.run(None, {inp0.name: x})
    # run(None, ...) returns outputs in the order of sess.get_outputs().
    out_names = [o.name for o in sess.get_outputs()]
    out_by_name = dict(zip(out_names, outputs))
    for out_name, value in out_by_name.items():
        _describe_array(out_name, value)

    _guess_det_outputs(out_by_name)
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|