264 lines
8.5 KiB
Python
264 lines
8.5 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def run_cmd(cmd):
    """Echo *cmd* to stdout, run it, and raise RuntimeError on nonzero exit.

    Args:
        cmd: argv-style list of strings (executed without a shell).

    Raises:
        RuntimeError: if the process exits with a nonzero return code.
    """
    print("[CMD]", " ".join(cmd))
    result = subprocess.run(cmd, check=False)
    code = result.returncode
    if code == 0:
        return
    raise RuntimeError(f"command failed ({code}): {' '.join(cmd)}")
|
|
|
|
|
|
def export_pt_to_onnx(pt_path: Path, onnx_path: Path, imgsz: int, opset: int):
    """Export a YOLOv8 ``.pt`` checkpoint to an ONNX file at *onnx_path*.

    Tries the ultralytics Python API first; ANY exception in that path
    (missing dependency, export failure, copy failure) falls through to the
    ``yolo export`` CLI. Either way the result is copied to *onnx_path* when
    the exporter wrote it elsewhere.

    Args:
        pt_path: path to the YOLOv8 ``.pt`` weights file.
        onnx_path: desired output path for the ONNX model.
        imgsz: square input size in pixels baked into the export.
        opset: ONNX opset version to target.

    Raises:
        RuntimeError: if no ONNX file exists at *onnx_path* after both
            export attempts.
    """
    try:
        import torch
        from ultralytics import YOLO  # type: ignore
        from ultralytics.nn.tasks import DetectionModel, SegmentationModel, ClassificationModel, PoseModel

        # Fix for PyTorch 2.6+ weights_only default change: allowlist the
        # ultralytics model classes so torch.load can unpickle the checkpoint.
        torch.serialization.add_safe_globals([DetectionModel, SegmentationModel, ClassificationModel, PoseModel])

        print("[INFO] exporting via ultralytics python API")
        model = YOLO(str(pt_path))
        exported = model.export(
            format="onnx",
            imgsz=imgsz,
            opset=opset,
            simplify=True,
            dynamic=False,  # fixed input shape; RKNN path expects static dims
            nms=False,      # raw head output; NMS is done by the consumer
        )
        # export() returns the path it actually wrote; copy it to the
        # requested location when the two differ.
        exported_path = Path(str(exported))
        if exported_path.exists() and exported_path.resolve() != onnx_path.resolve():
            onnx_path.write_bytes(exported_path.read_bytes())
    except Exception as e:
        # Best-effort fallback: shell out to the ultralytics CLI with the
        # same export options.
        print(f"[WARN] ultralytics API export failed: {e}")
        print("[INFO] fallback to CLI: yolo export ...")
        run_cmd(
            [
                "yolo",
                "export",
                f"model={pt_path}",
                "format=onnx",
                f"imgsz={imgsz}",
                f"opset={opset}",
                "simplify=True",
                "dynamic=False",
                "nms=False",
            ]
        )
        # The CLI writes next to the .pt file; relocate to *onnx_path* if needed.
        default_onnx = pt_path.with_suffix(".onnx")
        if not default_onnx.exists():
            raise RuntimeError(f"ONNX not found after export: {default_onnx}")
        if default_onnx.resolve() != onnx_path.resolve():
            onnx_path.write_bytes(default_onnx.read_bytes())

    # Final sanity check regardless of which export path ran.
    if not onnx_path.exists():
        raise RuntimeError(f"ONNX export failed, file missing: {onnx_path}")
    print(f"[OK] ONNX: {onnx_path}")
|
|
|
|
|
|
def _value_info_shape(value_info):
    """Return the tensor dims of an ONNX value_info as a list of ints.

    Symbolic/unknown dimensions are reported as -1. Returns [] when
    *value_info* is None.
    """
    if value_info is None:
        return []
    tensor_type = value_info.type.tensor_type
    return [
        int(dim.dim_value) if dim.HasField("dim_value") else -1
        for dim in tensor_type.shape.dim
    ]
|
|
|
|
|
|
def _find_value_info(model, name):
    """Look up a tensor by *name* among the graph's value_info entries,
    outputs, and inputs (searched in that order). Returns None if absent."""
    graph = model.graph
    for pool in (graph.value_info, graph.output, graph.input):
        for entry in pool:
            if entry.name == name:
                return entry
    return None
|
|
|
|
|
|
def ensure_v8_output_layout_cxn(in_onnx: Path, out_onnx: Path):
    """Ensure the single model output is laid out as [1, C, N] (channels
    first, e.g. [1, 84, 8400]) and save the result to *out_onnx*.

    If the output is already CxN the model is saved unchanged; if it is
    NxC (the common Ultralytics layout, e.g. [1, 8400, 84]) a Transpose
    node is appended and the graph output is rewired to the transposed
    tensor.

    Args:
        in_onnx: path of the ONNX model to inspect.
        out_onnx: path to write the (possibly fixed) model to.

    Raises:
        RuntimeError: if the 'onnx' package is missing, the model does not
            have exactly one output, or the output shape is not a
            recognizable rank-3 detection tensor.
    """
    try:
        import onnx
        from onnx import TensorProto, helper, shape_inference
    except Exception as e:
        raise RuntimeError(f"missing dependency 'onnx': {e}")

    model = onnx.load(str(in_onnx))
    # Run shape inference so the output's dims are populated for inspection.
    model = shape_inference.infer_shapes(model)
    if len(model.graph.output) != 1:
        raise RuntimeError(
            f"expect single output for this media-server v8 path, got {len(model.graph.output)}"
        )

    out = model.graph.output[0]
    out_name = out.name
    shape = _value_info_shape(_find_value_info(model, out_name))
    print(f"[INFO] ONNX output shape: {shape if shape else 'unknown'}")

    # Expected CxN style: [1, 4+num_classes, num_boxes], e.g. [1,84,8400]
    # Heuristic: channels (4 box coords + classes) > 4 and boxes > channels.
    if len(shape) == 3 and shape[1] > 4 and shape[2] > shape[1]:
        # already [1,84,8400] like
        onnx.save(model, str(out_onnx))
        print("[OK] output layout already CxN")
        return

    # Common Ultralytics style: [1, num_boxes, 4+num_classes], e.g. [1,8400,84]
    if len(shape) == 3 and shape[2] > 4 and shape[1] > shape[2]:
        # Append a Transpose node swapping the last two axes and make the
        # transposed tensor the sole graph output.
        transposed = out_name + "_cxn"
        node = helper.make_node(
            "Transpose",
            inputs=[out_name],
            outputs=[transposed],
            perm=[0, 2, 1],
            name="transpose_to_cxn",
        )
        model.graph.node.append(node)
        model.graph.output.clear()
        model.graph.output.extend(
            [
                helper.make_tensor_value_info(
                    transposed, TensorProto.FLOAT, [shape[0], shape[2], shape[1]]
                )
            ]
        )
        # Validate the modified graph before saving.
        onnx.checker.check_model(model)
        onnx.save(model, str(out_onnx))
        print("[OK] output layout fixed by transpose to CxN")
        return

    # Rank != 3, unknown dims, or ambiguous layout: refuse rather than guess.
    raise RuntimeError(
        "unsupported output shape for YOLOv8 in this project; "
        "need rank-3 tensor and channels=4+num_classes"
    )
|
|
|
|
|
|
def build_rknn(
    onnx_path: Path,
    rknn_path: Path,
    target: str,
    imgsz: int,
    quant: bool,
    dataset: str,
    mean: float,
    std: float,
):
    """Convert an ONNX model to RKNN using rknn-toolkit2.

    Args:
        onnx_path: input ONNX model.
        rknn_path: output .rknn path.
        target: target platform string (e.g. 'rk3588').
        imgsz: square input size in pixels.
        quant: enable int8 quantization (requires *dataset*).
        dataset: path to the quantization dataset txt; required when *quant*.
        mean: per-channel mean, replicated to all three channels.
        std: per-channel std, replicated to all three channels.

    Raises:
        RuntimeError: if rknn-toolkit2 is missing or any toolkit step
            returns a nonzero status.
    """
    try:
        from rknn.api import RKNN
    except Exception as e:
        raise RuntimeError(f"missing dependency 'rknn-toolkit2': {e}")

    rknn = RKNN(verbose=True)
    # Fix: release the RKNN context even when a step fails — the original
    # only released on the success path, leaking the context on errors.
    try:
        rknn.config(
            target_platform=target,
            mean_values=[[mean, mean, mean]],
            std_values=[[std, std, std]],
        )

        # NOTE(review): input_size_list is passed HWC ([imgsz, imgsz, 3]);
        # confirm this matches the toolkit version's expected layout.
        ret = rknn.load_onnx(model=str(onnx_path), input_size_list=[[imgsz, imgsz, 3]])
        if ret != 0:
            raise RuntimeError(f"rknn.load_onnx failed: {ret}")

        if quant:
            if not dataset:
                raise RuntimeError("--quant requires --dataset")
            ret = rknn.build(do_quantization=True, dataset=dataset)
        else:
            ret = rknn.build(do_quantization=False)
        if ret != 0:
            raise RuntimeError(f"rknn.build failed: {ret}")

        ret = rknn.export_rknn(str(rknn_path))
        if ret != 0:
            raise RuntimeError(f"rknn.export_rknn failed: {ret}")
    finally:
        rknn.release()
    print(f"[OK] RKNN: {rknn_path}")
|
|
|
|
|
|
def cmd_pt2onnx(args):
    """CLI handler for step 1: export a YOLOv8 .pt checkpoint to ONNX."""
    pt_path = Path(args.pt).resolve()
    if not pt_path.exists():
        raise SystemExit(f"pt file not found: {pt_path}")

    # Default the ONNX path to the checkpoint's name when not given.
    if args.onnx:
        onnx_path = Path(args.onnx).resolve()
    else:
        onnx_path = pt_path.with_suffix(".onnx").resolve()

    export_pt_to_onnx(pt_path, onnx_path, args.imgsz, args.opset)
    print(f"[DONE] pt2onnx -> {onnx_path}")
|
|
|
|
|
|
def cmd_onnx2rknn(args):
    """CLI handler for step 2: convert an ONNX model to RKNN, optionally
    normalizing the output layout to CxN first."""
    src = Path(args.onnx).resolve()
    if not src.exists():
        raise SystemExit(f"onnx file not found: {src}")

    dst = Path(args.out).resolve()
    dst.parent.mkdir(parents=True, exist_ok=True)

    model_for_build = src
    if not args.no_fix_layout:
        # Rewrite the output to [1, C, N] before handing it to the toolkit.
        if args.onnx_fixed:
            fixed = Path(args.onnx_fixed).resolve()
        else:
            fixed = src.with_name(src.stem + "_cxn.onnx").resolve()
        ensure_v8_output_layout_cxn(src, fixed)
        model_for_build = fixed

    build_rknn(
        onnx_path=model_for_build,
        rknn_path=dst,
        target=args.target,
        imgsz=args.imgsz,
        quant=args.quant,
        dataset=args.dataset,
        mean=args.mean,
        std=args.std,
    )
    print(f"[DONE] onnx2rknn -> {dst}")
    print("[NEXT] config ai_yolo: model_version='v8', preprocess dst_w/dst_h == imgsz, keep_ratio=false")
|
|
|
|
|
|
def build_parser():
    """Build the two-subcommand argument parser (pt2onnx / onnx2rknn)."""
    parser = argparse.ArgumentParser(
        description="YOLOv8 conversion for rk3588 media-server: step1 pt2onnx, step2 onnx2rknn"
    )
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    # Step 1: checkpoint -> ONNX
    pt2onnx = subcommands.add_parser("pt2onnx", help="step 1: export .pt to .onnx")
    pt2onnx.add_argument("--pt", required=True, help="path to yolov8 .pt")
    pt2onnx.add_argument("--onnx", default="", help="output onnx path (default: same name as pt)")
    pt2onnx.add_argument("--imgsz", type=int, default=640)
    pt2onnx.add_argument("--opset", type=int, default=12)
    pt2onnx.set_defaults(func=cmd_pt2onnx)

    # Step 2: ONNX -> RKNN
    onnx2rknn = subcommands.add_parser("onnx2rknn", help="step 2: convert .onnx to .rknn")
    onnx2rknn.add_argument("--onnx", required=True, help="input onnx path")
    onnx2rknn.add_argument("--onnx_fixed", default="", help="onnx path after layout fix")
    onnx2rknn.add_argument("--no_fix_layout", action="store_true", help="skip output CxN layout check/fix")
    onnx2rknn.add_argument("--out", required=True, help="output .rknn path")
    onnx2rknn.add_argument("--imgsz", type=int, default=640)
    onnx2rknn.add_argument("--target", default="rk3588")
    onnx2rknn.add_argument("--quant", action="store_true", help="enable int8 quantization")
    onnx2rknn.add_argument("--dataset", default="", help="dataset txt for quantization")
    onnx2rknn.add_argument("--mean", type=float, default=0.0, help="default 0.0")
    onnx2rknn.add_argument("--std", type=float, default=255.0, help="default 255.0")
    onnx2rknn.set_defaults(func=cmd_onnx2rknn)

    return parser
|
|
|
|
|
|
def main():
    """Entry point: parse CLI args and dispatch to the chosen subcommand."""
    cli = build_parser()
    options = cli.parse_args()
    # Each subparser installed its handler via set_defaults(func=...).
    options.func(options)
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Broad catch at the CLI boundary: print a short error message and
        # exit nonzero instead of dumping a traceback at the user.
        print("[ERROR]", e)
        sys.exit(1)
|