#!/usr/bin/env python3 """检查 RKNN 模型结构和输出信息""" import sys try: from rknn.api import RKNN except ImportError: print("错误: 未安装 rknn-toolkit2") sys.exit(1) def check_model(model_path): rknn = RKNN(verbose=False) # 加载模型 ret = rknn.load_rknn(model_path) if ret != 0: print(f"加载模型失败: {model_path}") return # 获取模型信息 print(f"\n=== 模型信息: {model_path} ===\n") # 使用 rknn_query 获取信息 try: # 获取输入输出数量 from ctypes import c_void_p, sizeof, Structure, c_int, c_uint32 class rknn_input_output_num(Structure): _fields_ = [("n_input", c_uint32), ("n_output", c_uint32)] io_num = rknn_input_output_num() # 尝试获取输入输出数量 print(f"尝试分析模型结构...") except Exception as e: print(f"查询失败: {e}") # 尝试用推理方式测试 print("\n尝试模拟推理查看输出形状...") import numpy as np # 创建假输入 dummy_input = np.zeros((1, 768, 768, 3), dtype=np.uint8) ret = rknn.init_runtime(core_mask=RKNN.NPU_CORE_AUTO) if ret != 0: print("初始化 runtime 失败") rknn.release() return # 推理 outputs = rknn.inference(inputs=[dummy_input], data_format=['nhwc']) print(f"\n输出数量: {len(outputs)}") for i, out in enumerate(outputs): print(f" 输出[{i}]: shape={out.shape}, dtype={out.dtype}") # 显示部分数据 flat = out.flatten() print(f" 数据范围: [{flat.min():.4f}, {flat.max():.4f}]") print(f" 前10个值: {flat[:10]}") # 判断模型类型 print(f"\n=== 分析结果 ===") if len(outputs) == 1: shape = outputs[0].shape print(f"模型类型: YOLOv8 (单输出)") print(f"输出形状: {shape}") if len(shape) == 3: # YOLOv8 输出通常是 [1, 84, 8400] 或 [1, 15, 8400] 等 num_classes = shape[1] - 4 # 减去 x,y,w,h num_boxes = shape[2] print(f"检测框数量: {num_boxes}") print(f"类别数: {num_classes}") elif len(outputs) == 3: print(f"模型类型: YOLOv5 (三输出)") for i, out in enumerate(outputs): print(f" 输出[{i}] shape={out.shape}") else: print(f"模型类型: 其他 ({len(outputs)} 个输出)") rknn.release() if __name__ == "__main__": if len(sys.argv) < 2: print("用法: python3 check_model.py ") print("示例: python3 check_model.py models/best-768.rknn") sys.exit(1) check_model(sys.argv[1])