184 lines
4.6 KiB
Python
184 lines
4.6 KiB
Python
import argparse
|
||
|
||
import onnx
|
||
from onnx import utils as onnx_utils
|
||
from rknn.api import RKNN
|
||
|
||
|
||
# Ops that only relabel, reshape, permute or retype a tensor. When walking
# backwards from the graph output, _find_prenorm_tensor steps over these by
# following their first input.
PASSTHRU_OPS = {
    # no-op / dtype change
    "Cast", "Identity",
    # shape-only changes
    "Flatten", "Reshape", "Squeeze", "Unsqueeze",
    # layout permutation
    "Transpose",
}
|
||
|
||
|
||
# Ops that may sit between a ReduceL2 node and the tensor examined by
# _is_reduce_l2_chain (e.g. an epsilon clamp, sqrt/reciprocal, broadcasting
# the norm back to the embedding shape). A producer whose op type is not in
# this set ends the search.
NORM_CHAIN_OPS = {
    # the norm itself and related reductions
    "ReduceL2", "ReduceSum", "ReduceMean",
    # roots, powers, inverses
    "Sqrt", "Rsqrt", "Reciprocal", "Pow",
    # clamping (typically an epsilon floor)
    "Clip", "Max", "Min",
    # elementwise arithmetic
    "Add", "Sub", "Mul", "Div",
    # shape / broadcasting / dtype
    "Expand", "Reshape", "Squeeze", "Unsqueeze", "Cast",
}
|
||
|
||
|
||
def _build_producer_map(model: onnx.ModelProto):
    """Map every tensor name to the graph node that outputs it.

    Nodes are scanned in graph order, so if a name were produced more than
    once the later node wins. Empty output names (omitted optional outputs)
    are skipped.
    """
    return {
        out_name: producer
        for producer in model.graph.node
        for out_name in producer.output
        if out_name
    }
|
||
|
||
|
||
def _is_reduce_l2_chain(tensor_name: str, prod_map) -> bool:
|
||
cur = tensor_name
|
||
for _ in range(50):
|
||
node = prod_map.get(cur)
|
||
if node is None:
|
||
return False
|
||
if node.op_type == "ReduceL2":
|
||
return True
|
||
if node.op_type not in NORM_CHAIN_OPS:
|
||
return False
|
||
if not node.input:
|
||
return False
|
||
cur = node.input[0]
|
||
return False
|
||
|
||
|
||
def _find_prenorm_tensor(model: onnx.ModelProto, output_tensor: str) -> str | None:
    """Locate the un-normalized embedding tensor that feeds the final L2 norm.

    Starting at ``output_tensor``, walk producers backwards through
    ``PASSTHRU_OPS`` (following input 0) until a Div or Mul is reached. Then
    decide which operand is the embedding:

    * ``Div(x, norm)``: returns ``x``. Only the numerator can be the
      embedding. ``Div(norm, x)`` is not a normalization of ``x`` and is
      rejected.
    * ``Mul(x, s)`` / ``Mul(s, x)``: returns ``x`` when exactly one operand
      derives from a ReduceL2, e.g. ``s = Reciprocal(norm)``.
      NOTE(review): this does not verify that ``s`` is an *inverse* norm.

    Args:
        model: Loaded ONNX model.
        output_tensor: Name of the (normalized) graph output to start from.

    Returns:
        The pre-norm tensor name, or None if the pattern is not recognized or
        is ambiguous (both or neither operand look like a norm).
    """
    prod = _build_producer_map(model)

    cur = output_tensor
    for _ in range(100):
        node = prod.get(cur)
        if node is None:
            return None

        if node.op_type in PASSTHRU_OPS and node.input:
            cur = node.input[0]
            continue

        if node.op_type not in ("Div", "Mul") or len(node.input) < 2:
            return None

        a, b = node.input[0], node.input[1]
        a_is = _is_reduce_l2_chain(a, prod)
        b_is = _is_reduce_l2_chain(b, prod)
        if b_is and not a_is:
            return a
        # Only multiplication is commutative; for Div the norm must be the
        # divisor, so a norm in the numerator is not a normalization pattern.
        if a_is and not b_is and node.op_type == "Mul":
            return b
        return None

    return None
|
||
|
||
|
||
def _get_real_input_names(model: onnx.ModelProto):
    """Return the names of true data inputs, in graph declaration order.

    Graph inputs that share a name with an initializer are treated as weights
    and excluded. Unnamed inputs are excluded as well.
    """
    weight_names = set(init.name for init in model.graph.initializer)
    return [
        graph_input.name
        for graph_input in model.graph.input
        if graph_input.name and graph_input.name not in weight_names
    ]
|
||
|
||
|
||
def _parse_args() -> argparse.Namespace:
    """Parse the command line for the pre-norm extraction and RKNN conversion."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx", default="mobilefacenet_arcface_bs1.onnx")
    ap.add_argument("--out", default="mobilefacenet_arcface_prenorm.rknn")
    ap.add_argument("--target", default="rk3588")
    ap.add_argument("--onnx_out", default="mobilefacenet_arcface_prenorm.onnx")
    ap.add_argument(
        "--pre_norm_tensor",
        default="",
        # Manually specify the name of the embedding tensor *before*
        # normalization (look it up in Netron when auto-detection fails).
        help="手动指定“归一化前 embedding”的张量名(自动识别失败时用 Netron 查到后填这里)",
    )
    return ap.parse_args()


def _extract_prenorm_onnx(args: argparse.Namespace) -> None:
    """Cut ``args.onnx`` at the pre-norm tensor and save it to ``args.onnx_out``.

    The pre-norm tensor is ``args.pre_norm_tensor`` if given, otherwise it is
    auto-detected from the first graph output. The sub-graph spans from the
    first non-initializer input to that tensor.

    Raises:
        SystemExit: if the model has no outputs, no real inputs, or the
            pre-norm tensor cannot be determined.
    """
    model = onnx.load(args.onnx)
    onnx.checker.check_model(model)

    if not model.graph.output:
        raise SystemExit("ERROR: ONNX has no graph outputs")

    orig_out = model.graph.output[0].name
    pre_norm = args.pre_norm_tensor.strip()
    if not pre_norm:
        pre_norm = _find_prenorm_tensor(model, orig_out) or ""

    if not pre_norm:
        # "Auto-detection of the pre-norm embedding failed. Open the ONNX in
        # Netron, find the 512-D tensor before the output normalization
        # (L2Norm/ReduceL2), then rerun with --pre_norm_tensor <name>."
        raise SystemExit(
            "ERROR: 自动寻找 pre-norm embedding 失败。\n"
            "请用 Netron 打开 ONNX,找到输出归一化(L2Norm/ReduceL2)之前的 512D 张量名,"
            "然后用 --pre_norm_tensor <name> 重新运行。"
        )

    input_names = _get_real_input_names(model)
    if not input_names:
        raise SystemExit("ERROR: ONNX has no real inputs")
    # Only the first real input is kept as the sub-graph input.
    in0 = input_names[0]

    onnx_utils.extract_model(args.onnx, args.onnx_out, [in0], [pre_norm])
    print(f"[OK] Extracted pre-norm ONNX: {args.onnx_out}")
    print(f" input={in0}")
    print(f" output(pre_norm)={pre_norm}")


def _check_ret(ret, step: str) -> None:
    """Abort with SystemExit if an RKNN API call returned a non-zero status."""
    if ret != 0:
        raise SystemExit(f"ERROR: {step} failed ret={ret}")


def _convert_to_rknn(args: argparse.Namespace) -> None:
    """Build a non-quantized RKNN model from ``args.onnx_out`` and export it to ``args.out``.

    The RKNN handle is released on every exit path, including failures.
    """
    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            target_platform=args.target,
            # Presumably applied as (x - mean) / std per channel, i.e.
            # (x - 127.5) / 128 — confirm against the training preprocessing.
            mean_values=[[127.5, 127.5, 127.5]],
            std_values=[[128.0, 128.0, 128.0]],
        )
        # NOTE(review): RKNN-Toolkit2 documents input_size_list together with
        # `inputs=` and in the model's own layout including batch (e.g.
        # [[1, 3, 112, 112]] for an NCHW input). Confirm [[112, 112, 3]] is
        # honored by the toolkit version in use.
        _check_ret(
            rknn.load_onnx(model=args.onnx_out, input_size_list=[[112, 112, 3]]),
            "load_onnx",
        )
        _check_ret(rknn.build(do_quantization=False), "build")
        _check_ret(rknn.export_rknn(args.out), "export_rknn")
    finally:
        # Previously skipped whenever a step failed and raised SystemExit.
        rknn.release()


def main():
    """Extract the pre-L2-norm sub-graph of an ONNX face model and convert it to RKNN."""
    args = _parse_args()
    _extract_prenorm_onnx(args)
    _convert_to_rknn(args)
    print("OK:", args.out)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0
    # unless main() itself raises SystemExit with an error message.
    raise SystemExit(main())
|