OrangePi3588Media/scripts/check_model.py

91 lines
2.7 KiB
Python

#!/usr/bin/env python3
"""检查 RKNN 模型结构和输出信息"""
import sys
try:
from rknn.api import RKNN
except ImportError:
print("错误: 未安装 rknn-toolkit2")
sys.exit(1)
def check_model(model_path):
rknn = RKNN(verbose=False)
# 加载模型
ret = rknn.load_rknn(model_path)
if ret != 0:
print(f"加载模型失败: {model_path}")
return
# 获取模型信息
print(f"\n=== 模型信息: {model_path} ===\n")
# 使用 rknn_query 获取信息
try:
# 获取输入输出数量
from ctypes import c_void_p, sizeof, Structure, c_int, c_uint32
class rknn_input_output_num(Structure):
_fields_ = [("n_input", c_uint32), ("n_output", c_uint32)]
io_num = rknn_input_output_num()
# 尝试获取输入输出数量
print(f"尝试分析模型结构...")
except Exception as e:
print(f"查询失败: {e}")
# 尝试用推理方式测试
print("\n尝试模拟推理查看输出形状...")
import numpy as np
# 创建假输入
dummy_input = np.zeros((1, 768, 768, 3), dtype=np.uint8)
ret = rknn.init_runtime(core_mask=RKNN.NPU_CORE_AUTO)
if ret != 0:
print("初始化 runtime 失败")
rknn.release()
return
# 推理
outputs = rknn.inference(inputs=[dummy_input], data_format=['nhwc'])
print(f"\n输出数量: {len(outputs)}")
for i, out in enumerate(outputs):
print(f" 输出[{i}]: shape={out.shape}, dtype={out.dtype}")
# 显示部分数据
flat = out.flatten()
print(f" 数据范围: [{flat.min():.4f}, {flat.max():.4f}]")
print(f" 前10个值: {flat[:10]}")
# 判断模型类型
print(f"\n=== 分析结果 ===")
if len(outputs) == 1:
shape = outputs[0].shape
print(f"模型类型: YOLOv8 (单输出)")
print(f"输出形状: {shape}")
if len(shape) == 3:
# YOLOv8 输出通常是 [1, 84, 8400] 或 [1, 15, 8400] 等
num_classes = shape[1] - 4 # 减去 x,y,w,h
num_boxes = shape[2]
print(f"检测框数量: {num_boxes}")
print(f"类别数: {num_classes}")
elif len(outputs) == 3:
print(f"模型类型: YOLOv5 (三输出)")
for i, out in enumerate(outputs):
print(f" 输出[{i}] shape={out.shape}")
else:
print(f"模型类型: 其他 ({len(outputs)} 个输出)")
rknn.release()
if __name__ == "__main__":
if len(sys.argv) < 2:
print("用法: python3 check_model.py <model_path>")
print("示例: python3 check_model.py models/ppe_det_yolov8_ppe11_768_rk3588.rknn")
sys.exit(1)
check_model(sys.argv[1])