OrangePi3588Media/scripts/onnx2rknn_prenorm.py
sladro 49699263e0
Some checks are pending
CI / host-build (push) Waiting to run
CI / rk3588-cross-build (push) Waiting to run
开始测试人脸识别,修改人脸识别模型
2026-01-08 14:11:41 +08:00

184 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import onnx
from onnx import utils as onnx_utils
from rknn.api import RKNN
# Ops that only reshape, reorder or re-type a single tensor. When walking back
# from the graph output toward the normalization node, these are stepped
# through via their first input.
PASSTHRU_OPS = {
    "Identity", "Reshape", "Squeeze", "Unsqueeze",
    "Flatten", "Transpose", "Cast",
}

# Ops that may sit between a norm reduction (e.g. ReduceL2) and the tensor that
# divides / multiplies the embedding: roots, reciprocals, eps clamping,
# arithmetic and broadcasting. A tensor reached from a norm reduction only
# through these ops is treated as "the norm".
NORM_CHAIN_OPS = {
    # reductions
    "ReduceL2", "ReduceSum", "ReduceMean",
    # powers / roots / reciprocals
    "Sqrt", "Rsqrt", "Reciprocal", "Pow",
    # eps clamping
    "Clip", "Max", "Min",
    # binary arithmetic
    "Add", "Sub", "Mul", "Div",
    # broadcasting / shape / dtype
    "Expand", "Reshape", "Squeeze", "Unsqueeze", "Cast",
}
def _build_producer_map(model: onnx.ModelProto):
prod = {}
for node in model.graph.node:
for o in node.output:
if o:
prod[o] = node
return prod
def _is_reduce_l2_chain(tensor_name: str, prod_map, max_nodes: int = 200) -> bool:
    """Return True if ``tensor_name`` is derived from an L2-norm reduction.

    Starting at ``tensor_name``, walk upstream through ops listed in
    NORM_CHAIN_OPS and report whether a ``ReduceL2`` (or ``ReduceSumSquare``)
    node is reached, e.g. for ``Expand(Clip(ReduceL2(x)))``.

    Every input of a chain op is explored, depth-first with the first input
    popped first. The first path tried is therefore the plain "follow input[0]"
    path, and norms sitting in a later operand are also found. Examples are
    ``Max(eps, norm)`` and ``Div(1, norm)``. A branch ends at a tensor with no
    producer (graph input or initializer) or at an op outside NORM_CHAIN_OPS.

    Args:
        tensor_name: Tensor to start from.
        prod_map: Mapping tensor name -> producing node (objects with
            ``op_type`` and ``input``), as built by ``_build_producer_map``.
        max_nodes: Upper bound on tensors visited, keeping the search cheap
            on large graphs.

    Returns:
        True if a norm reduction was reached, otherwise False.
    """
    norm_sources = ("ReduceL2", "ReduceSumSquare")
    stack = [tensor_name]
    visited = set()
    while stack and len(visited) < max_nodes:
        cur = stack.pop()
        if not cur or cur in visited:
            continue
        visited.add(cur)
        node = prod_map.get(cur)
        if node is None:
            # Graph input / initializer / constant without a node: dead end.
            continue
        if node.op_type in norm_sources:
            return True
        if node.op_type not in NORM_CHAIN_OPS:
            continue
        # Reverse so input[0] is popped next, preserving the original
        # first-operand walk as the first path explored.
        stack.extend(reversed(node.input))
    return False
def _find_prenorm_tensor(model: onnx.ModelProto, output_tensor: str) -> "str | None":
    """Find the embedding tensor that feeds the output's L2 normalization.

    Walks upstream from ``output_tensor`` through PASSTHRU_OPS (following each
    node's first input) until it reaches the node that performs the
    normalization. It then returns that node's un-normalized operand.
    Recognized forms:

    * ``LpNormalization`` with ``p == 2`` (the ONNX default) -> its input.
    * ``Div(x, norm)`` where ``norm`` is an L2-norm chain -> ``x``.
    * ``Mul(x, s)`` or ``Mul(s, x)`` where exactly one operand is an
      L2-norm chain -> the other operand.

    The return annotation is quoted so the module still imports on
    Python < 3.10, where ``str | None`` is evaluated at definition time and
    raises TypeError.

    Args:
        model: Loaded ONNX model.
        output_tensor: Name of the (normalized) graph output to start from.

    Returns:
        The pre-normalization tensor name, or None if no pattern matched.
    """
    prod = _build_producer_map(model)
    cur = output_tensor
    for _ in range(100):
        node = prod.get(cur)
        if node is None:
            return None
        op = node.op_type
        if op in PASSTHRU_OPS and node.input:
            cur = node.input[0]
            continue
        if op == "LpNormalization" and node.input:
            # Attribute `p` defaults to 2 when absent.
            p = next((attr.i for attr in node.attribute if attr.name == "p"), 2)
            return node.input[0] if p == 2 else None
        if op in ("Div", "Mul") and len(node.input) >= 2:
            a, b = node.input[0], node.input[1]
            a_is = _is_reduce_l2_chain(a, prod)
            b_is = _is_reduce_l2_chain(b, prod)
            if b_is and not a_is:
                return a
            # Mul is commutative, so the norm factor may be either operand.
            # For Div only the divisor can be the norm: `norm / x` is not a
            # normalization of x, so that orientation is rejected.
            if op == "Mul" and a_is and not b_is:
                return b
        return None
    return None
def _get_real_input_names(model: onnx.ModelProto):
init_names = {i.name for i in model.graph.initializer}
inputs = []
for i in model.graph.input:
if i.name and i.name not in init_names:
inputs.append(i.name)
return inputs
def _convert_to_rknn(onnx_path: str, rknn_path: str, target: str) -> None:
    """Build a non-quantized RKNN model from ``onnx_path`` and export it.

    The RKNN handle is released on every path, including the SystemExit
    raised when a toolkit call returns a non-zero status.
    """
    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            target_platform=target,
            # NOTE(review): presumably the toolkit applies (x - mean) / std per
            # channel, i.e. the model expects (pixel - 127.5) / 128 -- confirm
            # against the training preprocessing.
            mean_values=[[127.5, 127.5, 127.5]],
            std_values=[[128.0, 128.0, 128.0]],
        )
        # NOTE(review): the input size is hard-coded as HWC [112, 112, 3].
        # rknn-toolkit2 documents input_size_list for ONNX in the model's own
        # layout (typically NCHW) and pairs it with `inputs` -- confirm it is
        # honored for the toolkit version in use.
        ret = rknn.load_onnx(
            model=onnx_path,
            input_size_list=[[112, 112, 3]],
        )
        if ret != 0:
            raise SystemExit(f"ERROR: load_onnx failed ret={ret}")
        ret = rknn.build(do_quantization=False)
        if ret != 0:
            raise SystemExit(f"ERROR: build failed ret={ret}")
        ret = rknn.export_rknn(rknn_path)
        if ret != 0:
            raise SystemExit(f"ERROR: export_rknn failed ret={ret}")
    finally:
        rknn.release()


def main():
    """CLI entry: cut an embedding model before its output L2 normalization
    and convert the truncated graph to RKNN.

    Steps:
    1. Load and check the ONNX model.
    2. Locate the pre-normalization tensor, taken from ``--pre_norm_tensor``
       or auto-detected from the first graph output.
    3. Extract the ``first input -> pre-norm tensor`` subgraph to
       ``--onnx_out``.
    4. Build and export the RKNN model to ``--out``.

    All failures exit through SystemExit with a message.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx", default="mobilefacenet_arcface_bs1.onnx")
    ap.add_argument("--out", default="mobilefacenet_arcface_prenorm.rknn")
    ap.add_argument("--target", default="rk3588")
    ap.add_argument("--onnx_out", default="mobilefacenet_arcface_prenorm.onnx")
    ap.add_argument(
        "--pre_norm_tensor",
        default="",
        help="手动指定“归一化前 embedding”的张量名自动识别失败时用 Netron 查到后填这里)",
    )
    args = ap.parse_args()

    model = onnx.load(args.onnx)
    onnx.checker.check_model(model)
    if not model.graph.output:
        raise SystemExit("ERROR: ONNX has no graph outputs")
    # Only the first graph output is treated as the (normalized) embedding.
    orig_out = model.graph.output[0].name

    pre_norm = args.pre_norm_tensor.strip()
    if not pre_norm:
        pre_norm = _find_prenorm_tensor(model, orig_out) or ""
    if not pre_norm:
        raise SystemExit(
            "ERROR: 自动寻找 pre-norm embedding 失败。\n"
            "请用 Netron 打开 ONNX找到输出归一化(L2Norm/ReduceL2)之前的 512D 张量名,"
            "然后用 --pre_norm_tensor <name> 重新运行。"
        )

    input_names = _get_real_input_names(model)
    if not input_names:
        raise SystemExit("ERROR: ONNX has no real inputs")
    in0 = input_names[0]

    # Fail early with a clear message when a hand-typed tensor name does not
    # exist. Otherwise the failure surfaces as an opaque error inside
    # onnx.utils.extract_model.
    known_tensors = set(_build_producer_map(model)).union(input_names)
    if pre_norm not in known_tensors:
        raise SystemExit(
            f"ERROR: tensor '{pre_norm}' not found in ONNX graph {args.onnx}"
        )

    onnx_utils.extract_model(args.onnx, args.onnx_out, [in0], [pre_norm])
    print(f"[OK] Extracted pre-norm ONNX: {args.onnx_out}")
    print(f" input={in0}")
    print(f" output(pre_norm)={pre_norm}")

    _convert_to_rknn(args.onnx_out, args.out, args.target)
    print("OK:", args.out)
# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()