91 lines
2.7 KiB
Python
91 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
"""检查 RKNN 模型结构和输出信息"""
|
|
|
|
import sys
|
|
|
|
try:
|
|
from rknn.api import RKNN
|
|
except ImportError:
|
|
print("错误: 未安装 rknn-toolkit2")
|
|
sys.exit(1)
|
|
|
|
def check_model(model_path):
|
|
rknn = RKNN(verbose=False)
|
|
|
|
# 加载模型
|
|
ret = rknn.load_rknn(model_path)
|
|
if ret != 0:
|
|
print(f"加载模型失败: {model_path}")
|
|
return
|
|
|
|
# 获取模型信息
|
|
print(f"\n=== 模型信息: {model_path} ===\n")
|
|
|
|
# 使用 rknn_query 获取信息
|
|
try:
|
|
# 获取输入输出数量
|
|
from ctypes import c_void_p, sizeof, Structure, c_int, c_uint32
|
|
|
|
class rknn_input_output_num(Structure):
|
|
_fields_ = [("n_input", c_uint32), ("n_output", c_uint32)]
|
|
|
|
io_num = rknn_input_output_num()
|
|
# 尝试获取输入输出数量
|
|
print(f"尝试分析模型结构...")
|
|
|
|
except Exception as e:
|
|
print(f"查询失败: {e}")
|
|
|
|
# 尝试用推理方式测试
|
|
print("\n尝试模拟推理查看输出形状...")
|
|
import numpy as np
|
|
|
|
# 创建假输入
|
|
dummy_input = np.zeros((1, 768, 768, 3), dtype=np.uint8)
|
|
|
|
ret = rknn.init_runtime(core_mask=RKNN.NPU_CORE_AUTO)
|
|
if ret != 0:
|
|
print("初始化 runtime 失败")
|
|
rknn.release()
|
|
return
|
|
|
|
# 推理
|
|
outputs = rknn.inference(inputs=[dummy_input], data_format=['nhwc'])
|
|
|
|
print(f"\n输出数量: {len(outputs)}")
|
|
for i, out in enumerate(outputs):
|
|
print(f" 输出[{i}]: shape={out.shape}, dtype={out.dtype}")
|
|
# 显示部分数据
|
|
flat = out.flatten()
|
|
print(f" 数据范围: [{flat.min():.4f}, {flat.max():.4f}]")
|
|
print(f" 前10个值: {flat[:10]}")
|
|
|
|
# 判断模型类型
|
|
print(f"\n=== 分析结果 ===")
|
|
if len(outputs) == 1:
|
|
shape = outputs[0].shape
|
|
print(f"模型类型: YOLOv8 (单输出)")
|
|
print(f"输出形状: {shape}")
|
|
if len(shape) == 3:
|
|
# YOLOv8 输出通常是 [1, 84, 8400] 或 [1, 15, 8400] 等
|
|
num_classes = shape[1] - 4 # 减去 x,y,w,h
|
|
num_boxes = shape[2]
|
|
print(f"检测框数量: {num_boxes}")
|
|
print(f"类别数: {num_classes}")
|
|
elif len(outputs) == 3:
|
|
print(f"模型类型: YOLOv5 (三输出)")
|
|
for i, out in enumerate(outputs):
|
|
print(f" 输出[{i}] shape={out.shape}")
|
|
else:
|
|
print(f"模型类型: 其他 ({len(outputs)} 个输出)")
|
|
|
|
rknn.release()
|
|
|
|
if __name__ == "__main__":
|
|
if len(sys.argv) < 2:
|
|
print("用法: python3 check_model.py <model_path>")
|
|
print("示例: python3 check_model.py models/ppe_det_yolov8_ppe11_768_rk3588.rknn")
|
|
sys.exit(1)
|
|
|
|
check_model(sys.argv[1])
|