OrangePi3588Media/scripts/pt2rknn_yolov8.py

264 lines
8.5 KiB
Python

#!/usr/bin/env python3
import argparse
import subprocess
import sys
from pathlib import Path
def run_cmd(cmd):
    """Echo and run an external command, raising if it exits non-zero.

    Args:
        cmd: Argument list handed to ``subprocess.run`` (no shell involved).

    Raises:
        RuntimeError: If the process returns a non-zero exit status.
    """
    rendered = " ".join(cmd)
    print("[CMD]", rendered)
    status = subprocess.run(cmd, check=False).returncode
    if status == 0:
        return
    raise RuntimeError(f"command failed ({status}): {rendered}")
def export_pt_to_onnx(pt_path: Path, onnx_path: Path, imgsz: int, opset: int):
    """Export a YOLOv8 ``.pt`` checkpoint to a static-shape ONNX file.

    The ultralytics Python API is tried first. If anything in that path
    raises, the function falls back to the ``yolo export`` command line and
    copies the file the CLI produces next to the checkpoint.

    Args:
        pt_path: Path to the source ``.pt`` checkpoint.
        onnx_path: Destination path for the ONNX model.
        imgsz: Image size passed to the exporter.
        opset: ONNX opset version passed to the exporter.

    Raises:
        RuntimeError: If the CLI fallback fails, or no ONNX file exists at
            ``onnx_path`` once the export has finished.
    """
    import shutil

    # Create the destination directory up front so a missing folder is not
    # mistaken for an export failure (which would trigger a second export).
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        import torch
        from ultralytics import YOLO  # type: ignore
        from ultralytics.nn.tasks import (
            ClassificationModel,
            DetectionModel,
            PoseModel,
            SegmentationModel,
        )

        # Fix for PyTorch 2.6+ weights_only default change. Older PyTorch
        # releases do not provide add_safe_globals (and do not need it), so
        # only call it when it is available instead of forcing the fallback.
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is not None:
            add_safe_globals(
                [DetectionModel, SegmentationModel, ClassificationModel, PoseModel]
            )

        print("[INFO] exporting via ultralytics python API")
        model = YOLO(str(pt_path))
        exported = model.export(
            format="onnx",
            imgsz=imgsz,
            opset=opset,
            simplify=True,
            dynamic=False,
            nms=False,
        )
        exported_path = Path(str(exported))
        if exported_path.exists() and exported_path.resolve() != onnx_path.resolve():
            # Stream the copy rather than loading the whole model into memory.
            shutil.copyfile(exported_path, onnx_path)
    except Exception as e:
        # Deliberate best-effort: any API failure falls back to the CLI.
        print(f"[WARN] ultralytics API export failed: {e}")
        print("[INFO] fallback to CLI: yolo export ...")
        run_cmd(
            [
                "yolo",
                "export",
                f"model={pt_path}",
                "format=onnx",
                f"imgsz={imgsz}",
                f"opset={opset}",
                "simplify=True",
                "dynamic=False",
                "nms=False",
            ]
        )
        # The fallback looks for the CLI result next to the checkpoint.
        default_onnx = pt_path.with_suffix(".onnx")
        if not default_onnx.exists():
            raise RuntimeError(f"ONNX not found after export: {default_onnx}")
        if default_onnx.resolve() != onnx_path.resolve():
            shutil.copyfile(default_onnx, onnx_path)

    if not onnx_path.exists():
        raise RuntimeError(f"ONNX export failed, file missing: {onnx_path}")
    print(f"[OK] ONNX: {onnx_path}")
def _value_info_shape(value_info):
if value_info is None:
return []
dims = []
tt = value_info.type.tensor_type
for d in tt.shape.dim:
if d.HasField("dim_value"):
dims.append(int(d.dim_value))
else:
dims.append(-1)
return dims
def _find_value_info(model, name):
for vi in list(model.graph.value_info) + list(model.graph.output) + list(model.graph.input):
if vi.name == name:
return vi
return None
def ensure_v8_output_layout_cxn(in_onnx: Path, out_onnx: Path):
    """Save a copy of ``in_onnx`` whose single output uses the CxN layout.

    CxN here means ``[1, 4+num_classes, num_boxes]`` (e.g. ``[1, 84, 8400]``).
    A model already in that layout is saved unchanged (after shape
    inference). A model in ``[1, num_boxes, 4+num_classes]`` layout gets a
    ``Transpose`` node appended and its graph output replaced by the
    transposed tensor.

    Args:
        in_onnx: Path of the ONNX model to inspect.
        out_onnx: Path where the (possibly modified) model is written.

    Raises:
        RuntimeError: If ``onnx`` cannot be imported, the model does not have
            exactly one output, or the output shape matches neither layout.
    """
    try:
        import onnx
        from onnx import TensorProto, helper, shape_inference
    except Exception as e:
        raise RuntimeError(f"missing dependency 'onnx': {e}") from e

    model = onnx.load(str(in_onnx))
    model = shape_inference.infer_shapes(model)
    if len(model.graph.output) != 1:
        raise RuntimeError(
            f"expect single output for this media-server v8 path, got {len(model.graph.output)}"
        )

    out = model.graph.output[0]
    out_name = out.name
    # Read the element type before the output list is rewritten below; fall
    # back to FLOAT only when the model leaves it undefined (0).
    elem_type = out.type.tensor_type.elem_type or TensorProto.FLOAT
    shape = _value_info_shape(_find_value_info(model, out_name))
    print(f"[INFO] ONNX output shape: {shape if shape else 'unknown'}")

    # The destination directory may not exist yet (e.g. a custom --onnx_fixed).
    out_onnx.parent.mkdir(parents=True, exist_ok=True)

    # Expected CxN style: [1, 4+num_classes, num_boxes], e.g. [1,84,8400]
    if len(shape) == 3 and shape[1] > 4 and shape[2] > shape[1]:
        onnx.save(model, str(out_onnx))
        print("[OK] output layout already CxN")
        return

    # Common Ultralytics style: [1, num_boxes, 4+num_classes], e.g. [1,8400,84]
    if len(shape) == 3 and shape[2] > 4 and shape[1] > shape[2]:
        transposed = out_name + "_cxn"
        node = helper.make_node(
            "Transpose",
            inputs=[out_name],
            outputs=[transposed],
            perm=[0, 2, 1],
            name="transpose_to_cxn",
        )
        model.graph.node.append(node)

        # -1 marks a dimension without a concrete value; declare it as
        # unknown (None) rather than writing a literal -1 into the model.
        batch = shape[0] if shape[0] > 0 else None
        new_output = helper.make_tensor_value_info(
            transposed, elem_type, [batch, shape[2], shape[1]]
        )
        # Slice deletion is used instead of .clear() to empty the repeated
        # protobuf field.
        del model.graph.output[:]
        model.graph.output.extend([new_output])

        onnx.checker.check_model(model)
        onnx.save(model, str(out_onnx))
        print("[OK] output layout fixed by transpose to CxN")
        return

    raise RuntimeError(
        "unsupported output shape for YOLOv8 in this project; "
        "need rank-3 tensor and channels=4+num_classes"
    )
def build_rknn(
    onnx_path: Path,
    rknn_path: Path,
    target: str,
    imgsz: int,
    quant: bool,
    dataset: str,
    mean: float,
    std: float,
):
    """Convert an ONNX model to an RKNN model file using rknn-toolkit2.

    Args:
        onnx_path: ONNX model to load.
        rknn_path: Destination path of the exported ``.rknn`` file.
        target: Value passed as ``target_platform`` to ``RKNN.config``.
        imgsz: Square input size used to build ``input_size_list``.
        quant: Whether to build with quantization enabled.
        dataset: Dataset list file used for quantization; required (and must
            exist) when ``quant`` is true.
        mean: Per-channel mean, replicated for three channels.
        std: Per-channel std, replicated for three channels.

    Raises:
        RuntimeError: On invalid quantization arguments, a missing
            ``rknn-toolkit2`` install, or a non-zero return code from
            ``load_onnx``, ``build`` or ``export_rknn``.
    """
    # Validate arguments first so bad input fails before the toolkit is
    # imported and the model is loaded.
    if quant:
        if not dataset:
            raise RuntimeError("--quant requires --dataset")
        if not Path(dataset).is_file():
            raise RuntimeError(f"dataset file not found: {dataset}")

    try:
        from rknn.api import RKNN
    except Exception as e:
        raise RuntimeError(f"missing dependency 'rknn-toolkit2': {e}") from e

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            target_platform=target,
            mean_values=[[mean, mean, mean]],
            std_values=[[std, std, std]],
        )

        # NOTE(review): the size list is given as [H, W, C]; confirm this is
        # the layout rknn-toolkit2 expects for this ONNX input.
        ret = rknn.load_onnx(model=str(onnx_path), input_size_list=[[imgsz, imgsz, 3]])
        if ret != 0:
            raise RuntimeError(f"rknn.load_onnx failed: {ret}")

        if quant:
            ret = rknn.build(do_quantization=True, dataset=dataset)
        else:
            ret = rknn.build(do_quantization=False)
        if ret != 0:
            raise RuntimeError(f"rknn.build failed: {ret}")

        ret = rknn.export_rknn(str(rknn_path))
        if ret != 0:
            raise RuntimeError(f"rknn.export_rknn failed: {ret}")
    finally:
        # Always release the toolkit context, including on failure paths.
        rknn.release()

    print(f"[OK] RKNN: {rknn_path}")
def cmd_pt2onnx(args):
    """Handle the ``pt2onnx`` sub-command.

    Reads ``args.pt``, ``args.onnx``, ``args.imgsz`` and ``args.opset``. When
    ``args.onnx`` is empty the output is written next to the checkpoint with
    an ``.onnx`` suffix.

    Raises:
        SystemExit: If the ``.pt`` file does not exist.
    """
    weights = Path(args.pt).resolve()
    if not weights.exists():
        raise SystemExit(f"pt file not found: {weights}")

    if args.onnx:
        destination = Path(args.onnx).resolve()
    else:
        destination = weights.with_suffix(".onnx").resolve()

    export_pt_to_onnx(weights, destination, args.imgsz, args.opset)
    print(f"[DONE] pt2onnx -> {destination}")
def cmd_onnx2rknn(args):
    """Handle the ``onnx2rknn`` sub-command.

    Unless ``args.no_fix_layout`` is set, the input model is first passed
    through ``ensure_v8_output_layout_cxn`` and the resulting file (either
    ``args.onnx_fixed`` or ``<stem>_cxn.onnx`` beside the input) is what gets
    converted. The parent directory of ``args.out`` is created if missing.

    Raises:
        SystemExit: If the input ONNX file does not exist.
    """
    source_onnx = Path(args.onnx).resolve()
    if not source_onnx.exists():
        raise SystemExit(f"onnx file not found: {source_onnx}")

    rknn_out = Path(args.out).resolve()
    rknn_out.parent.mkdir(parents=True, exist_ok=True)

    if args.no_fix_layout:
        # Caller opted out of the layout check: convert the input as-is.
        model_for_build = source_onnx
    else:
        if args.onnx_fixed:
            model_for_build = Path(args.onnx_fixed).resolve()
        else:
            model_for_build = source_onnx.with_name(source_onnx.stem + "_cxn.onnx").resolve()
        ensure_v8_output_layout_cxn(source_onnx, model_for_build)

    build_rknn(
        onnx_path=model_for_build,
        rknn_path=rknn_out,
        target=args.target,
        imgsz=args.imgsz,
        quant=args.quant,
        dataset=args.dataset,
        mean=args.mean,
        std=args.std,
    )
    print(f"[DONE] onnx2rknn -> {rknn_out}")
    print("[NEXT] config ai_yolo: model_version='v8', preprocess dst_w/dst_h == imgsz, keep_ratio=false")
def build_parser():
    """Build the command-line parser with ``pt2onnx`` and ``onnx2rknn`` sub-commands.

    Each sub-parser stores its handler function under ``func`` so the caller
    can dispatch with ``args.func(args)``. A sub-command is mandatory.
    """
    parser = argparse.ArgumentParser(
        description="YOLOv8 conversion for rk3588 media-server: step1 pt2onnx, step2 onnx2rknn"
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # Step 1: checkpoint -> ONNX.
    export_cmd = commands.add_parser("pt2onnx", help="step 1: export .pt to .onnx")
    export_cmd.add_argument("--pt", required=True, help="path to yolov8 .pt")
    export_cmd.add_argument(
        "--onnx", default="", help="output onnx path (default: same name as pt)"
    )
    export_cmd.add_argument("--imgsz", type=int, default=640)
    export_cmd.add_argument("--opset", type=int, default=12)
    export_cmd.set_defaults(func=cmd_pt2onnx)

    # Step 2: ONNX -> RKNN.
    convert_cmd = commands.add_parser("onnx2rknn", help="step 2: convert .onnx to .rknn")
    convert_cmd.add_argument("--onnx", required=True, help="input onnx path")
    convert_cmd.add_argument("--onnx_fixed", default="", help="onnx path after layout fix")
    convert_cmd.add_argument(
        "--no_fix_layout", action="store_true", help="skip output CxN layout check/fix"
    )
    convert_cmd.add_argument("--out", required=True, help="output .rknn path")
    convert_cmd.add_argument("--imgsz", type=int, default=640)
    convert_cmd.add_argument("--target", default="rk3588")
    convert_cmd.add_argument("--quant", action="store_true", help="enable int8 quantization")
    convert_cmd.add_argument("--dataset", default="", help="dataset txt for quantization")
    convert_cmd.add_argument("--mean", type=float, default=0.0, help="default 0.0")
    convert_cmd.add_argument("--std", type=float, default=255.0, help="default 255.0")
    convert_cmd.set_defaults(func=cmd_onnx2rknn)

    return parser
def main():
    """Parse the command line and run the handler of the chosen sub-command."""
    namespace = build_parser().parse_args()
    handler = namespace.func
    handler(namespace)
if __name__ == "__main__":
    # Script boundary: report any failure as a single "[ERROR]" line and
    # exit with status 1 instead of showing a traceback.
    try:
        main()
    except Exception as err:
        print("[ERROR]", err)
        raise SystemExit(1)