240 lines
8.2 KiB
Python
240 lines
8.2 KiB
Python
import cv2
|
||
import numpy as np
|
||
from rknn.api import RKNN
|
||
import time
|
||
|
||
class YOLOv8RKNN:
    """YOLOv8 object detector backed by a compiled RKNN model.

    Wraps an ``rknn.api.RKNN`` instance: loads the ``.rknn`` file, initializes
    the runtime for an RK3588 target, letterboxes input images, runs
    inference, and decodes the raw head output into boxes expressed in
    original-image pixel coordinates.
    """

    def __init__(self, model_path, input_size=(640, 640), class_names=None):
        """Create the detector and load the model.

        Args:
            model_path: Path to the compiled ``.rknn`` model file.
            input_size: Model input size as ``(width, height)``.
            class_names: Optional list of class labels indexed by class id.
                Defaults to two placeholder names; its length must equal the
                number of classes the model predicts, because postprocess()
                slices the model output by it.

        Raises:
            RuntimeError: If the model cannot be loaded or the runtime cannot
                be initialized.
        """
        self.model_path = model_path
        self.input_size = input_size
        self.rknn = RKNN()

        # Class labels, indexed by class id. Replace the placeholders with the
        # real names of your classes (or pass class_names=...).
        self.class_names = list(class_names) if class_names is not None else ['class1', 'class2']

        # preprocess() stores orig_width/orig_height/scale/x_offset/y_offset
        # on the instance; postprocess() reads them to undo the letterbox.

        # A failed load used to be ignored here, so the first inference call
        # crashed with an unrelated error. Fail fast with a clear message.
        if not self.load_model():
            self.release()
            raise RuntimeError(f"Failed to load RKNN model: {model_path}")

    def load_model(self):
        """Load the ``.rknn`` file and initialize the RKNN runtime.

        Returns:
            True on success; False if loading or runtime initialization fails
            (the failing step and its return code are printed).
        """
        print("Loading RKNN model...")
        ret = self.rknn.load_rknn(self.model_path)
        if ret != 0:
            print(f"Load RKNN model failed! (ret={ret})")
            return False

        # Initialize the runtime for the RK3588 target.
        print("Init RKNN runtime...")
        ret = self.rknn.init_runtime(target='rk3588', device_id=None, perf_debug=False, eval_mem=False)
        if ret != 0:
            print(f"Init RKNN runtime failed! (ret={ret})")
            return False

        print("RKNN model loaded successfully!")
        return True

    def preprocess(self, image):
        """Letterbox *image* to the model input size.

        The image is resized with its aspect ratio preserved and pasted onto
        the center of a canvas filled with gray (114). The scale factor and
        padding offsets are saved on the instance for postprocess().

        Args:
            image: uint8 array of shape (H, W, 3); presumably BGR as returned
                by ``cv2.imread`` / ``VideoCapture.read`` — confirm at caller.

        Returns:
            uint8 array of shape (input_height, input_width, 3).
        """
        in_w, in_h = self.input_size
        self.orig_height, self.orig_width = image.shape[:2]

        scale = min(in_w / self.orig_width, in_h / self.orig_height)
        # max(1, ...) avoids a zero-sized resize for extremely thin images.
        new_width = max(1, int(self.orig_width * scale))
        new_height = max(1, int(self.orig_height * scale))

        resized = cv2.resize(image, (new_width, new_height))

        input_image = np.full((in_h, in_w, 3), 114, dtype=np.uint8)
        y_offset = (in_h - new_height) // 2
        x_offset = (in_w - new_width) // 2
        input_image[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized

        # NOTE(review): the image keeps the caller's channel order. Ultralytics
        # models are trained on RGB; unless the RKNN build config reorders
        # channels, cv2.cvtColor(..., cv2.COLOR_BGR2RGB) may be needed — confirm.

        self.scale = scale
        self.x_offset = x_offset
        self.y_offset = y_offset
        return input_image

    def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
        """Decode raw YOLOv8 output and apply non-maximum suppression.

        Args:
            outputs: List returned by ``RKNN.inference``. ``outputs[0]`` is
                assumed to be (1, 4 + num_classes, N) — the usual YOLOv8
                detection head (N = 8400 for 640x640), with rows
                ``cx, cy, w, h`` followed by one score per class. TODO confirm
                for this exported model.
            conf_threshold: Minimum class score for a candidate to be kept.
            nms_threshold: IoU threshold used by NMS.

        Returns:
            ``(boxes, scores, class_ids)``: boxes is (K, 4) ``[x1, y1, x2, y2]``
            in original-image pixels, scores is (K,), class_ids is (K,).
            Three empty arrays when nothing survives.
        """
        num_classes = len(self.class_names)
        # Drop the batch dimension, then (4 + C, N) -> (N, 4 + C).
        predictions = np.asarray(outputs[0][0]).T

        class_confs = predictions[:, 4:4 + num_classes]
        class_ids = np.argmax(class_confs, axis=1)
        max_confs = class_confs.max(axis=1)

        keep = max_confs >= conf_threshold
        if not np.any(keep):
            return np.array([]), np.array([]), np.array([])

        cx, cy, w, h = predictions[keep, :4].T
        scores = max_confs[keep]
        class_ids = class_ids[keep]

        # Center/size -> corners, undo the letterbox, clip to the image bounds.
        x1 = np.clip((cx - w / 2 - self.x_offset) / self.scale, 0, self.orig_width)
        y1 = np.clip((cy - h / 2 - self.y_offset) / self.scale, 0, self.orig_height)
        x2 = np.clip((cx + w / 2 - self.x_offset) / self.scale, 0, self.orig_width)
        y2 = np.clip((cy + h / 2 - self.y_offset) / self.scale, 0, self.orig_height)
        boxes = np.stack([x1, y1, x2, y2], axis=1)

        # cv2.dnn.NMSBoxes takes (x, y, width, height) rectangles; feeding it
        # corner coordinates treats x2/y2 as sizes and yields wrong overlaps.
        nms_boxes = np.stack([x1, y1, x2 - x1, y2 - y1], axis=1).tolist()
        # NMS is class-agnostic: overlapping boxes of different classes can
        # suppress each other.
        indices = cv2.dnn.NMSBoxes(nms_boxes, scores.tolist(), conf_threshold, nms_threshold)
        # Depending on the OpenCV version this is (), an (M,) or an (M, 1) array.
        indices = np.asarray(indices, dtype=int).flatten()
        if indices.size == 0:
            return np.array([]), np.array([]), np.array([])

        return boxes[indices], scores[indices], class_ids[indices]

    def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
        """Run preprocessing, inference and postprocessing on one image.

        Returns:
            ``(boxes, scores, class_ids, inference_time)``; ``inference_time``
            is in seconds and covers only the ``rknn.inference`` call.
        """
        input_image = self.preprocess(image)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        outputs = self.rknn.inference(inputs=[input_image])
        inference_time = time.perf_counter() - start_time

        boxes, scores, class_ids = self.postprocess(outputs, conf_threshold, nms_threshold)
        return boxes, scores, class_ids, inference_time

    def draw_detections(self, image, boxes, scores, class_ids):
        """Draw boxes and ``"name: score"`` labels onto *image* in place.

        Returns:
            The same *image* array, for convenience.
        """
        color = (0, 255, 0)
        font = cv2.FONT_HERSHEY_SIMPLEX
        for box, score, class_id in zip(boxes, scores, class_ids):
            x1, y1, x2, y2 = (int(v) for v in box)
            class_id = int(class_id)
            # Fall back to the numeric id instead of raising IndexError.
            if 0 <= class_id < len(self.class_names):
                name = self.class_names[class_id]
            else:
                name = str(class_id)

            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

            label = f"{name}: {score:.2f}"
            (text_w, text_h), _ = cv2.getTextSize(label, font, 0.5, 2)
            # Label sits above the box; when the box touches the top edge it is
            # pushed down so it stays visible.
            label_bottom = max(y1, text_h + 10)
            cv2.rectangle(image, (x1, label_bottom - text_h - 10),
                          (x1 + text_w, label_bottom), color, -1)
            cv2.putText(image, label, (x1, label_bottom - 5),
                        font, 0.5, (0, 0, 0), 2)

        return image

    def release(self):
        """Release the RKNN runtime. Safe to call more than once."""
        if self.rknn is not None:
            self.rknn.release()
            self.rknn = None
|
||
|
||
def main():
    """Interactive demo: detect objects in one image file or a live camera feed.

    Prompts for a mode, runs the chosen test, and always releases the RKNN
    runtime afterwards (also on errors or Ctrl+C).
    """
    model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
    detector = YOLOv8RKNN(model_path)

    def test_image(image_path):
        """Detect objects in *image_path* and save the annotated image to disk."""
        image = cv2.imread(image_path)
        if image is None:
            print(f"Cannot load image: {image_path}")
            return

        boxes, scores, class_ids, inference_time = detector.detect(image)
        print(f"Inference time: {inference_time*1000:.2f}ms")
        print(f"Detected {len(boxes)} objects")

        result_image = detector.draw_detections(image, boxes, scores, class_ids)

        # The result is saved rather than shown; to display it instead use
        # cv2.imshow(...) followed by cv2.waitKey(0) and cv2.destroyAllWindows().
        output_path = "xxxxxxx.jpg"
        if cv2.imwrite(output_path, result_image):
            print(f"Result saved to {output_path}")
        else:
            print(f"Failed to save result to {output_path}")

    def test_camera():
        """Run detection on the default camera until 'q' is pressed or frames stop."""
        cap = cv2.VideoCapture(0)  # default camera
        if not cap.isOpened():
            print("Cannot open camera")
            return

        # try/finally so the camera and windows are released even on Ctrl+C.
        try:
            prev_time = time.perf_counter()
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                boxes, scores, class_ids, _ = detector.detect(frame)
                result_frame = detector.draw_detections(frame, boxes, scores, class_ids)

                # FPS from the full loop period (capture, pre/post-processing,
                # drawing, display), not NPU inference alone, so the number
                # reflects the frame rate actually achieved.
                now = time.perf_counter()
                elapsed = now - prev_time
                prev_time = now
                fps = 1.0 / elapsed if elapsed > 0 else 0
                cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow("Real-time Detection", result_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    try:
        # Mode choice: "1" = single image, "2" = live camera.
        mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ").strip()
        if mode == "1":
            image_path = input("输入图片路径: ").strip()
            test_image(image_path)
        elif mode == "2":
            test_camera()
        else:
            print("无效选择")
    finally:
        # Always free the RKNN runtime, even if a test raised or was interrupted.
        detector.release()
|
||
|
||
if __name__ == "__main__":
    # Run the interactive demo only when executed as a script, not on import.
    main()