# convert-the-model-to-rknn/015测试Yolov8_rknn.py
# YOLOv8 inference test for a converted RKNN model on an RK3588 device.
import cv2
import numpy as np
from rknn.api import RKNN
import time
class YOLOv8RKNN:
    """YOLOv8 detector that runs a compiled .rknn model on an RK3588 NPU.

    Pipeline: letterbox preprocessing -> RKNN inference -> decode + NMS.
    Boxes are returned in original-image pixel coordinates as [x1, y1, x2, y2].
    """

    def __init__(self, model_path, input_size=(640, 640)):
        """
        Args:
            model_path: path to the compiled .rknn model file.
            input_size: (width, height) the model was exported with.
        """
        self.model_path = model_path
        self.input_size = input_size
        self.rknn = RKNN()
        # Replace with your actual class names (model was trained on 2 classes).
        self.class_names = ['class1', 'class2']
        # Load the model and bring up the runtime immediately.
        self.load_model()

    def load_model(self):
        """Load the RKNN model and initialise the on-device runtime.

        Returns:
            True on success, False if loading or runtime init failed.
        """
        print("Loading RKNN model...")
        ret = self.rknn.load_rknn(self.model_path)
        if ret != 0:
            print("Load RKNN model failed!")
            return False
        # Initialise the runtime to execute on the RK3588 NPU itself.
        print("Init RKNN runtime...")
        ret = self.rknn.init_runtime(target='rk3588', device_id=None,
                                     perf_debug=False, eval_mem=False)
        if ret != 0:
            print("Init RKNN runtime failed!")
            return False
        print("RKNN model loaded successfully!")
        return True

    def preprocess(self, image):
        """Letterbox `image` into the model input size, preserving aspect ratio.

        Side effects: stores scale/offsets and original dimensions on `self`
        so postprocess() can map detections back to the source image.

        Args:
            image: HxWx3 uint8 BGR image (OpenCV convention).
        Returns:
            input_size[1] x input_size[0] x 3 uint8 letterboxed image.
        """
        self.orig_height, self.orig_width = image.shape[:2]
        # Single uniform scale so the whole image fits inside the input canvas.
        scale = min(self.input_size[0] / self.orig_width,
                    self.input_size[1] / self.orig_height)
        new_width = int(self.orig_width * scale)
        new_height = int(self.orig_height * scale)
        resized = cv2.resize(image, (new_width, new_height))
        # Grey (114) canvas — the standard YOLO letterbox padding value.
        input_image = np.full((self.input_size[1], self.input_size[0], 3),
                              114, dtype=np.uint8)
        # Centre the resized image on the canvas.
        y_offset = (self.input_size[1] - new_height) // 2
        x_offset = (self.input_size[0] - new_width) // 2
        input_image[y_offset:y_offset + new_height,
                    x_offset:x_offset + new_width] = resized
        # Remembered for mapping detections back in postprocess().
        self.scale = scale
        self.x_offset = x_offset
        self.y_offset = y_offset
        return input_image

    def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
        """Decode raw YOLOv8 output and apply class-agnostic NMS.

        Expected raw layout: [1, 4 + num_classes, 8400]; for 2 classes each
        column is [cx, cy, w, h, conf_class1, conf_class2] — TODO confirm
        against the exported model head.

        Returns:
            (boxes_xyxy, scores, class_ids) as numpy arrays in original-image
            pixels; three empty arrays when nothing survives thresholding.
        """
        predictions = outputs[0][0]            # drop the batch dimension
        predictions = predictions.transpose()  # -> [8400, 4 + num_classes]
        boxes = []
        scores = []
        class_ids = []
        for detection in predictions:
            x, y, w, h = detection[:4]
            class_confs = detection[4:6]  # per-class confidences (2 classes)
            class_id = np.argmax(class_confs)
            max_conf = class_confs[class_id]
            if max_conf >= conf_threshold:
                # Centre format -> corner format.
                x1 = x - w / 2
                y1 = y - h / 2
                x2 = x + w / 2
                y2 = y + h / 2
                # Undo letterboxing: map back to original image coordinates.
                x1 = (x1 - self.x_offset) / self.scale
                y1 = (y1 - self.y_offset) / self.scale
                x2 = (x2 - self.x_offset) / self.scale
                y2 = (y2 - self.y_offset) / self.scale
                # Clamp to the original image bounds.
                x1 = max(0, min(x1, self.orig_width))
                y1 = max(0, min(y1, self.orig_height))
                x2 = max(0, min(x2, self.orig_width))
                y2 = max(0, min(y2, self.orig_height))
                boxes.append([x1, y1, x2, y2])
                scores.append(float(max_conf))
                class_ids.append(class_id)
        if len(boxes) > 0:
            boxes = np.array(boxes)
            scores = np.array(scores)
            class_ids = np.array(class_ids)
            # BUG FIX: cv2.dnn.NMSBoxes expects boxes as [x, y, w, h], not
            # corner coordinates — passing x2/y2 directly corrupts the IoU
            # computation. Convert to width/height for the NMS call only.
            nms_boxes = boxes.copy()
            nms_boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            nms_boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            indices = cv2.dnn.NMSBoxes(nms_boxes.tolist(), scores.tolist(),
                                       conf_threshold, nms_threshold)
            if len(indices) > 0:
                # Older OpenCV returns [[i], ...]; np.array + flatten
                # normalises both the nested and the flat return shapes.
                indices = np.array(indices).flatten()
                return boxes[indices], scores[indices], class_ids[indices]
        return np.array([]), np.array([]), np.array([])

    def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
        """Run the full detection pipeline on one BGR image.

        Returns:
            (boxes, scores, class_ids, inference_time_seconds).
            inference_time covers only the NPU inference call, not pre/post.
        """
        input_image = self.preprocess(image)
        start_time = time.time()
        outputs = self.rknn.inference(inputs=[input_image])
        inference_time = time.time() - start_time
        boxes, scores, class_ids = self.postprocess(
            outputs, conf_threshold, nms_threshold)
        return boxes, scores, class_ids, inference_time

    def draw_detections(self, image, boxes, scores, class_ids):
        """Draw labelled bounding boxes on `image` in place and return it."""
        for i in range(len(boxes)):
            x1, y1, x2, y2 = boxes[i].astype(int)
            score = scores[i]
            class_id = int(class_ids[i])
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Filled background behind the label so the text stays readable.
            label = f"{self.class_names[class_id]}: {score:.2f}"
            label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(image, (x1, y1 - label_size[1] - 10),
                          (x1 + label_size[0], y1), (0, 255, 0), -1)
            cv2.putText(image, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        return image

    def release(self):
        """Release the RKNN runtime and its NPU resources."""
        if self.rknn:
            self.rknn.release()
def main():
    """Interactive entry point: single-image detection or live camera demo."""
    # Path to the compiled RKNN model on the target device.
    model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
    detector = YOLOv8RKNN(model_path)

    def test_image(image_path):
        """Detect on one image, report timing, save the annotated result."""
        image = cv2.imread(image_path)
        if image is None:
            print(f"Cannot load image: {image_path}")
            return
        boxes, scores, class_ids, inference_time = detector.detect(image)
        print(f"Inference time: {inference_time*1000:.2f}ms")
        print(f"Detected {len(boxes)} objects")
        result_image = detector.draw_detections(image, boxes, scores, class_ids)
        # Headless-friendly: write the result to disk instead of imshow.
        cv2.imwrite("xxxxxxx.jpg", result_image)

    def test_camera():
        """Detect on frames from the default camera until 'q' is pressed."""
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("Cannot open camera")
            return
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                boxes, scores, class_ids, inference_time = detector.detect(frame)
                result_frame = detector.draw_detections(
                    frame, boxes, scores, class_ids)
                # FPS derived from NPU inference time only (no pre/post).
                fps = 1.0 / inference_time if inference_time > 0 else 0
                cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                cv2.imshow("Real-time Detection", result_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # BUG FIX: release the camera and windows even if detection
            # raises mid-loop; previously an exception leaked the device.
            cap.release()
            cv2.destroyAllWindows()

    try:
        mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ")
        if mode == "1":
            image_path = input("输入图片路径: ")
            test_image(image_path)
        elif mode == "2":
            test_camera()
        else:
            print("无效选择")
    finally:
        # BUG FIX: always release NPU resources, even on error or Ctrl-C;
        # previously an exception skipped detector.release() entirely.
        detector.release()


if __name__ == "__main__":
    main()