Refactor code structure for improved readability and maintainability

This commit is contained in:
sladro 2026-02-26 13:08:37 +08:00
parent 1df19a1353
commit 48f5a44cb3
4 changed files with 236 additions and 1 deletions

View File

@ -120,4 +120,10 @@ ls -l ./rk3588-agent_linux_arm64
.\minio.exe server E:\minio\data --address ":9000" --console-address ":9001"
.\minio.exe server E:\minio\data --address ":9000" --console-address ":9001"
python scripts/pt2rknn_yolov8.py `
--pt yolov8n.pt `
--out models/yolov8n-640.rknn `
--imgsz 640

229
scripts/pt2rknn_yolov8.py Normal file
View File

@ -0,0 +1,229 @@
#!/usr/bin/env python3
import argparse
import os
import subprocess
import sys
from pathlib import Path
def run_cmd(cmd):
print("[CMD]", " ".join(cmd))
ret = subprocess.run(cmd, check=False)
if ret.returncode != 0:
raise RuntimeError(f"command failed ({ret.returncode}): {' '.join(cmd)}")
def export_pt_to_onnx(pt_path: Path, onnx_path: Path, imgsz: int, opset: int):
try:
from ultralytics import YOLO # type: ignore
print("[INFO] exporting via ultralytics python API")
model = YOLO(str(pt_path))
exported = model.export(
format="onnx",
imgsz=imgsz,
opset=opset,
simplify=True,
dynamic=False,
nms=False,
)
exported_path = Path(str(exported))
if exported_path.exists() and exported_path.resolve() != onnx_path.resolve():
onnx_path.write_bytes(exported_path.read_bytes())
except Exception as e:
print(f"[WARN] ultralytics API export failed: {e}")
print("[INFO] fallback to CLI: yolo export ...")
run_cmd(
[
"yolo",
"export",
f"model={pt_path}",
"format=onnx",
f"imgsz={imgsz}",
f"opset={opset}",
"simplify=True",
"dynamic=False",
"nms=False",
]
)
default_onnx = pt_path.with_suffix(".onnx")
if not default_onnx.exists():
raise RuntimeError(f"ONNX not found after export: {default_onnx}")
if default_onnx.resolve() != onnx_path.resolve():
onnx_path.write_bytes(default_onnx.read_bytes())
if not onnx_path.exists():
raise RuntimeError(f"ONNX export failed, file missing: {onnx_path}")
print(f"[OK] ONNX: {onnx_path}")
def _value_info_shape(value_info):
if value_info is None:
return []
dims = []
tt = value_info.type.tensor_type
for d in tt.shape.dim:
if d.HasField("dim_value"):
dims.append(int(d.dim_value))
else:
dims.append(-1)
return dims
def _find_value_info(model, name):
for vi in list(model.graph.value_info) + list(model.graph.output) + list(model.graph.input):
if vi.name == name:
return vi
return None
def ensure_v8_output_layout_cxn(in_onnx: Path, out_onnx: Path):
try:
import onnx
from onnx import TensorProto, helper, shape_inference
except Exception as e:
raise RuntimeError(f"missing dependency 'onnx': {e}")
model = onnx.load(str(in_onnx))
model = shape_inference.infer_shapes(model)
if len(model.graph.output) != 1:
raise RuntimeError(
f"expect single output for this media-server v8 path, got {len(model.graph.output)}"
)
out = model.graph.output[0]
out_name = out.name
shape = _value_info_shape(_find_value_info(model, out_name))
print(f"[INFO] ONNX output shape: {shape if shape else 'unknown'}")
# Expected CxN style: [1, 4+num_classes, num_boxes], e.g. [1,84,8400]
if len(shape) == 3 and shape[1] > 4 and shape[2] > shape[1]:
# already [1,84,8400] like
onnx.save(model, str(out_onnx))
print("[OK] output layout already CxN")
return
# Common Ultralytics style: [1, num_boxes, 4+num_classes], e.g. [1,8400,84]
if len(shape) == 3 and shape[2] > 4 and shape[1] > shape[2]:
transposed = out_name + "_cxn"
node = helper.make_node(
"Transpose",
inputs=[out_name],
outputs=[transposed],
perm=[0, 2, 1],
name="transpose_to_cxn",
)
model.graph.node.append(node)
model.graph.output.clear()
model.graph.output.extend(
[
helper.make_tensor_value_info(
transposed, TensorProto.FLOAT, [shape[0], shape[2], shape[1]]
)
]
)
onnx.checker.check_model(model)
onnx.save(model, str(out_onnx))
print("[OK] output layout fixed by transpose to CxN")
return
raise RuntimeError(
"unsupported output shape for YOLOv8 in this project; "
"need rank-3 tensor and channels=4+num_classes"
)
def build_rknn(
onnx_path: Path,
rknn_path: Path,
target: str,
imgsz: int,
quant: bool,
dataset: str,
mean: float,
std: float,
):
try:
from rknn.api import RKNN
except Exception as e:
raise RuntimeError(f"missing dependency 'rknn-toolkit2': {e}")
rknn = RKNN(verbose=True)
rknn.config(
target_platform=target,
mean_values=[[mean, mean, mean]],
std_values=[[std, std, std]],
)
ret = rknn.load_onnx(model=str(onnx_path), input_size_list=[[imgsz, imgsz, 3]])
if ret != 0:
raise RuntimeError(f"rknn.load_onnx failed: {ret}")
if quant:
if not dataset:
raise RuntimeError("--quant requires --dataset")
ret = rknn.build(do_quantization=True, dataset=dataset)
else:
ret = rknn.build(do_quantization=False)
if ret != 0:
raise RuntimeError(f"rknn.build failed: {ret}")
ret = rknn.export_rknn(str(rknn_path))
if ret != 0:
raise RuntimeError(f"rknn.export_rknn failed: {ret}")
rknn.release()
print(f"[OK] RKNN: {rknn_path}")
def main():
ap = argparse.ArgumentParser(description="YOLOv8 .pt -> .rknn for rk3588 media-server")
ap.add_argument("--pt", required=True, help="path to yolov8 .pt")
ap.add_argument("--onnx", default="", help="output onnx path")
ap.add_argument("--onnx_fixed", default="", help="onnx path after layout fix")
ap.add_argument("--out", required=True, help="output .rknn path")
ap.add_argument("--imgsz", type=int, default=640)
ap.add_argument("--opset", type=int, default=12)
ap.add_argument("--target", default="rk3588")
ap.add_argument("--quant", action="store_true", help="enable int8 quantization")
ap.add_argument("--dataset", default="", help="dataset txt for quantization")
ap.add_argument("--mean", type=float, default=0.0, help="default 0.0")
ap.add_argument("--std", type=float, default=255.0, help="default 255.0")
args = ap.parse_args()
pt_path = Path(args.pt).resolve()
if not pt_path.exists():
raise SystemExit(f"pt file not found: {pt_path}")
onnx_path = Path(args.onnx).resolve() if args.onnx else pt_path.with_suffix(".onnx").resolve()
onnx_fixed = (
Path(args.onnx_fixed).resolve()
if args.onnx_fixed
else onnx_path.with_name(onnx_path.stem + "_cxn.onnx").resolve()
)
out_path = Path(args.out).resolve()
out_path.parent.mkdir(parents=True, exist_ok=True)
export_pt_to_onnx(pt_path, onnx_path, args.imgsz, args.opset)
ensure_v8_output_layout_cxn(onnx_path, onnx_fixed)
build_rknn(
onnx_path=onnx_fixed,
rknn_path=out_path,
target=args.target,
imgsz=args.imgsz,
quant=args.quant,
dataset=args.dataset,
mean=args.mean,
std=args.std,
)
print(
"[NEXT] config ai_yolo with model_version='v8', and preprocess dst_w/dst_h == imgsz, keep_ratio=false"
)
if __name__ == "__main__":
try:
main()
except Exception as e:
print("[ERROR]", e)
sys.exit(1)

BIN
yolov8n.onnx Normal file

Binary file not shown.

BIN
yolov8n_cxn.onnx Normal file

Binary file not shown.