Refactor code structure for improved readability and maintainability
This commit is contained in:
parent
1df19a1353
commit
48f5a44cb3
@ -120,4 +120,10 @@ ls -l ./rk3588-agent_linux_arm64
|
||||
|
||||
|
||||
|
||||
.\minio.exe server E:\minio\data --address ":9000" --console-address ":9001"
|
||||
.\minio.exe server E:\minio\data --address ":9000" --console-address ":9001"
|
||||
|
||||
|
||||
python scripts/pt2rknn_yolov8.py `
|
||||
--pt yolov8n.pt `
|
||||
--out models/yolov8n-640.rknn `
|
||||
--imgsz 640
|
||||
|
||||
229
scripts/pt2rknn_yolov8.py
Normal file
229
scripts/pt2rknn_yolov8.py
Normal file
@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_cmd(cmd):
|
||||
print("[CMD]", " ".join(cmd))
|
||||
ret = subprocess.run(cmd, check=False)
|
||||
if ret.returncode != 0:
|
||||
raise RuntimeError(f"command failed ({ret.returncode}): {' '.join(cmd)}")
|
||||
|
||||
|
||||
def export_pt_to_onnx(pt_path: Path, onnx_path: Path, imgsz: int, opset: int):
|
||||
try:
|
||||
from ultralytics import YOLO # type: ignore
|
||||
|
||||
print("[INFO] exporting via ultralytics python API")
|
||||
model = YOLO(str(pt_path))
|
||||
exported = model.export(
|
||||
format="onnx",
|
||||
imgsz=imgsz,
|
||||
opset=opset,
|
||||
simplify=True,
|
||||
dynamic=False,
|
||||
nms=False,
|
||||
)
|
||||
exported_path = Path(str(exported))
|
||||
if exported_path.exists() and exported_path.resolve() != onnx_path.resolve():
|
||||
onnx_path.write_bytes(exported_path.read_bytes())
|
||||
except Exception as e:
|
||||
print(f"[WARN] ultralytics API export failed: {e}")
|
||||
print("[INFO] fallback to CLI: yolo export ...")
|
||||
run_cmd(
|
||||
[
|
||||
"yolo",
|
||||
"export",
|
||||
f"model={pt_path}",
|
||||
"format=onnx",
|
||||
f"imgsz={imgsz}",
|
||||
f"opset={opset}",
|
||||
"simplify=True",
|
||||
"dynamic=False",
|
||||
"nms=False",
|
||||
]
|
||||
)
|
||||
default_onnx = pt_path.with_suffix(".onnx")
|
||||
if not default_onnx.exists():
|
||||
raise RuntimeError(f"ONNX not found after export: {default_onnx}")
|
||||
if default_onnx.resolve() != onnx_path.resolve():
|
||||
onnx_path.write_bytes(default_onnx.read_bytes())
|
||||
|
||||
if not onnx_path.exists():
|
||||
raise RuntimeError(f"ONNX export failed, file missing: {onnx_path}")
|
||||
print(f"[OK] ONNX: {onnx_path}")
|
||||
|
||||
|
||||
def _value_info_shape(value_info):
|
||||
if value_info is None:
|
||||
return []
|
||||
dims = []
|
||||
tt = value_info.type.tensor_type
|
||||
for d in tt.shape.dim:
|
||||
if d.HasField("dim_value"):
|
||||
dims.append(int(d.dim_value))
|
||||
else:
|
||||
dims.append(-1)
|
||||
return dims
|
||||
|
||||
|
||||
def _find_value_info(model, name):
|
||||
for vi in list(model.graph.value_info) + list(model.graph.output) + list(model.graph.input):
|
||||
if vi.name == name:
|
||||
return vi
|
||||
return None
|
||||
|
||||
|
||||
def ensure_v8_output_layout_cxn(in_onnx: Path, out_onnx: Path):
|
||||
try:
|
||||
import onnx
|
||||
from onnx import TensorProto, helper, shape_inference
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"missing dependency 'onnx': {e}")
|
||||
|
||||
model = onnx.load(str(in_onnx))
|
||||
model = shape_inference.infer_shapes(model)
|
||||
if len(model.graph.output) != 1:
|
||||
raise RuntimeError(
|
||||
f"expect single output for this media-server v8 path, got {len(model.graph.output)}"
|
||||
)
|
||||
|
||||
out = model.graph.output[0]
|
||||
out_name = out.name
|
||||
shape = _value_info_shape(_find_value_info(model, out_name))
|
||||
print(f"[INFO] ONNX output shape: {shape if shape else 'unknown'}")
|
||||
|
||||
# Expected CxN style: [1, 4+num_classes, num_boxes], e.g. [1,84,8400]
|
||||
if len(shape) == 3 and shape[1] > 4 and shape[2] > shape[1]:
|
||||
# already [1,84,8400] like
|
||||
onnx.save(model, str(out_onnx))
|
||||
print("[OK] output layout already CxN")
|
||||
return
|
||||
|
||||
# Common Ultralytics style: [1, num_boxes, 4+num_classes], e.g. [1,8400,84]
|
||||
if len(shape) == 3 and shape[2] > 4 and shape[1] > shape[2]:
|
||||
transposed = out_name + "_cxn"
|
||||
node = helper.make_node(
|
||||
"Transpose",
|
||||
inputs=[out_name],
|
||||
outputs=[transposed],
|
||||
perm=[0, 2, 1],
|
||||
name="transpose_to_cxn",
|
||||
)
|
||||
model.graph.node.append(node)
|
||||
model.graph.output.clear()
|
||||
model.graph.output.extend(
|
||||
[
|
||||
helper.make_tensor_value_info(
|
||||
transposed, TensorProto.FLOAT, [shape[0], shape[2], shape[1]]
|
||||
)
|
||||
]
|
||||
)
|
||||
onnx.checker.check_model(model)
|
||||
onnx.save(model, str(out_onnx))
|
||||
print("[OK] output layout fixed by transpose to CxN")
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
"unsupported output shape for YOLOv8 in this project; "
|
||||
"need rank-3 tensor and channels=4+num_classes"
|
||||
)
|
||||
|
||||
|
||||
def build_rknn(
|
||||
onnx_path: Path,
|
||||
rknn_path: Path,
|
||||
target: str,
|
||||
imgsz: int,
|
||||
quant: bool,
|
||||
dataset: str,
|
||||
mean: float,
|
||||
std: float,
|
||||
):
|
||||
try:
|
||||
from rknn.api import RKNN
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"missing dependency 'rknn-toolkit2': {e}")
|
||||
|
||||
rknn = RKNN(verbose=True)
|
||||
rknn.config(
|
||||
target_platform=target,
|
||||
mean_values=[[mean, mean, mean]],
|
||||
std_values=[[std, std, std]],
|
||||
)
|
||||
|
||||
ret = rknn.load_onnx(model=str(onnx_path), input_size_list=[[imgsz, imgsz, 3]])
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"rknn.load_onnx failed: {ret}")
|
||||
|
||||
if quant:
|
||||
if not dataset:
|
||||
raise RuntimeError("--quant requires --dataset")
|
||||
ret = rknn.build(do_quantization=True, dataset=dataset)
|
||||
else:
|
||||
ret = rknn.build(do_quantization=False)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"rknn.build failed: {ret}")
|
||||
|
||||
ret = rknn.export_rknn(str(rknn_path))
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"rknn.export_rknn failed: {ret}")
|
||||
rknn.release()
|
||||
print(f"[OK] RKNN: {rknn_path}")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="YOLOv8 .pt -> .rknn for rk3588 media-server")
|
||||
ap.add_argument("--pt", required=True, help="path to yolov8 .pt")
|
||||
ap.add_argument("--onnx", default="", help="output onnx path")
|
||||
ap.add_argument("--onnx_fixed", default="", help="onnx path after layout fix")
|
||||
ap.add_argument("--out", required=True, help="output .rknn path")
|
||||
ap.add_argument("--imgsz", type=int, default=640)
|
||||
ap.add_argument("--opset", type=int, default=12)
|
||||
ap.add_argument("--target", default="rk3588")
|
||||
ap.add_argument("--quant", action="store_true", help="enable int8 quantization")
|
||||
ap.add_argument("--dataset", default="", help="dataset txt for quantization")
|
||||
ap.add_argument("--mean", type=float, default=0.0, help="default 0.0")
|
||||
ap.add_argument("--std", type=float, default=255.0, help="default 255.0")
|
||||
args = ap.parse_args()
|
||||
|
||||
pt_path = Path(args.pt).resolve()
|
||||
if not pt_path.exists():
|
||||
raise SystemExit(f"pt file not found: {pt_path}")
|
||||
|
||||
onnx_path = Path(args.onnx).resolve() if args.onnx else pt_path.with_suffix(".onnx").resolve()
|
||||
onnx_fixed = (
|
||||
Path(args.onnx_fixed).resolve()
|
||||
if args.onnx_fixed
|
||||
else onnx_path.with_name(onnx_path.stem + "_cxn.onnx").resolve()
|
||||
)
|
||||
out_path = Path(args.out).resolve()
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
export_pt_to_onnx(pt_path, onnx_path, args.imgsz, args.opset)
|
||||
ensure_v8_output_layout_cxn(onnx_path, onnx_fixed)
|
||||
build_rknn(
|
||||
onnx_path=onnx_fixed,
|
||||
rknn_path=out_path,
|
||||
target=args.target,
|
||||
imgsz=args.imgsz,
|
||||
quant=args.quant,
|
||||
dataset=args.dataset,
|
||||
mean=args.mean,
|
||||
std=args.std,
|
||||
)
|
||||
|
||||
print(
|
||||
"[NEXT] config ai_yolo with model_version='v8', and preprocess dst_w/dst_h == imgsz, keep_ratio=false"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
print("[ERROR]", e)
|
||||
sys.exit(1)
|
||||
BIN
yolov8n.onnx
Normal file
BIN
yolov8n.onnx
Normal file
Binary file not shown.
BIN
yolov8n_cxn.onnx
Normal file
BIN
yolov8n_cxn.onnx
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user