77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Convert RetinaFace ONNX to RKNN for RK3588"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# rknn-toolkit2 is an optional third-party dependency; without it the script
# cannot do anything useful, so report how to install it and exit with status 1.
try:
    from rknn.api import RKNN
except ImportError:
    for message in (
        "Error: rknn-toolkit2 not installed",
        "Install with: pip install rknn-toolkit2",
    ):
        print(message)
    sys.exit(1)
|
|
|
|
# Path of the ONNX model that convert() loads.
ONNX_MODEL = "face_det_retinaface_mobile320.onnx"
# Path that convert() exports the resulting RKNN model to.
RKNN_MODEL = "face_det_retinaface_mobile320_rk3588.rknn"
|
|
|
|
def convert():
    """Convert the ONNX model at ``ONNX_MODEL`` into an RKNN model.

    Runs the toolkit steps in order: configure, load the ONNX model, build
    with quantization enabled (calibration list read from ``./dataset.txt``),
    then export to ``RKNN_MODEL``. Progress and failure messages are printed.

    Returns:
        bool: True when load, build and export all returned 0; False as soon
        as one of them returns a non-zero code.
    """
    print(f"Converting {ONNX_MODEL} to {RKNN_MODEL}...")

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # release() lives in ``finally`` so the RKNN object is released on the
    # early-return failure paths as well as on success.
    try:
        # Pre-process config
        print("Configuring model...")
        rknn.config(
            target_platform='rk3588',
            mean_values=[[0, 0, 0]],  # No mean subtraction
            std_values=[[255, 255, 255]],  # Normalize to 0-1
            quantized_dtype='w8a8',
            optimization_level=2,
        )

        # Load ONNX model (auto-detect inputs/outputs)
        print("Loading ONNX model...")
        ret = rknn.load_onnx(model=ONNX_MODEL)
        if ret != 0:
            print("Failed to load ONNX model")
            return False

        # Build RKNN model
        print("Building RKNN model (this may take a while)...")
        ret = rknn.build(
            do_quantization=True,
            dataset='./dataset.txt',  # Need calibration dataset
        )
        if ret != 0:
            print("Failed to build RKNN model")
            return False

        # Export RKNN model
        print(f"Exporting to {RKNN_MODEL}...")
        ret = rknn.export_rknn(RKNN_MODEL)
        if ret != 0:
            print("Failed to export RKNN model")
            return False

        print("Conversion successful!")
        return True
    finally:
        rknn.release()
|
|
|
|
if __name__ == '__main__':
    # Write the calibration list file; it names a single image, calibration.jpg.
    with open('dataset.txt', 'w') as f:
        f.write('calibration.jpg\n')

    if not os.path.exists('calibration.jpg'):
        print("Warning: calibration.jpg not found, creating dummy...")
        # Imported lazily: numpy and PIL are only needed when the placeholder
        # image has to be generated.
        import numpy as np
        from PIL import Image

        # randint's upper bound is exclusive, so 256 is required for the noise
        # to cover the full uint8 range 0..255.
        # NOTE(review): 320x320 presumably matches the model input size
        # suggested by "mobile320" in the model file name — confirm.
        dummy = np.random.randint(0, 256, (320, 320, 3), dtype=np.uint8)
        Image.fromarray(dummy).save('calibration.jpg')

    # Exit status mirrors the conversion result: 0 on success, 1 on failure.
    success = convert()
    sys.exit(0 if success else 1)
|