添加使用yolov8_rknn模型脚本
This commit is contained in:
parent
21a2b16a26
commit
fdb821fc67
240
009使用rknn模型.py
Normal file
240
009使用rknn模型.py
Normal file
@ -0,0 +1,240 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from rknn.api import RKNN
|
||||
import time
|
||||
|
||||
class YOLOv8RKNN:
|
||||
def __init__(self, model_path, input_size=(640, 640)):
|
||||
self.model_path = model_path
|
||||
self.input_size = input_size
|
||||
self.rknn = RKNN()
|
||||
|
||||
# 类别名称,根据你的2个类别修改
|
||||
self.class_names = ['class1', 'class2'] # 请替换为你实际的类别名称
|
||||
|
||||
# 初始化模型
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""加载RKNN模型"""
|
||||
print("Loading RKNN model...")
|
||||
ret = self.rknn.load_rknn(self.model_path)
|
||||
if ret != 0:
|
||||
print("Load RKNN model failed!")
|
||||
return False
|
||||
|
||||
# 初始化运行时环境(在RK3588设备上运行)
|
||||
print("Init RKNN runtime...")
|
||||
ret = self.rknn.init_runtime(target='rk3588', device_id=None, perf_debug=False, eval_mem=False)
|
||||
if ret != 0:
|
||||
print("Init RKNN runtime failed!")
|
||||
return False
|
||||
|
||||
print("RKNN model loaded successfully!")
|
||||
return True
|
||||
|
||||
def preprocess(self, image):
|
||||
"""图像预处理"""
|
||||
# 获取原始图像尺寸
|
||||
self.orig_height, self.orig_width = image.shape[:2]
|
||||
|
||||
# Resize到模型输入尺寸,保持宽高比
|
||||
scale = min(self.input_size[0]/self.orig_width, self.input_size[1]/self.orig_height)
|
||||
new_width = int(self.orig_width * scale)
|
||||
new_height = int(self.orig_height * scale)
|
||||
|
||||
# 缩放图像
|
||||
resized = cv2.resize(image, (new_width, new_height))
|
||||
|
||||
# 创建输入图像(填充到目标尺寸)
|
||||
input_image = np.full((self.input_size[1], self.input_size[0], 3), 114, dtype=np.uint8)
|
||||
|
||||
# 计算填充位置(居中)
|
||||
y_offset = (self.input_size[1] - new_height) // 2
|
||||
x_offset = (self.input_size[0] - new_width) // 2
|
||||
|
||||
# 将缩放后的图像放到中心位置
|
||||
input_image[y_offset:y_offset+new_height, x_offset:x_offset+new_width] = resized
|
||||
|
||||
# 保存缩放参数用于后处理
|
||||
self.scale = scale
|
||||
self.x_offset = x_offset
|
||||
self.y_offset = y_offset
|
||||
|
||||
return input_image
|
||||
|
||||
def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
|
||||
"""后处理:解析YOLO输出并进行NMS"""
|
||||
# YOLOv8输出格式: [batch, 84, 8400] (2个类别: 4+2+80=84,但实际只有6维)
|
||||
# 对于2类别: [x, y, w, h, conf_class1, conf_class2]
|
||||
predictions = outputs[0][0] # 移除batch维度
|
||||
|
||||
# 转置为 [8400, 6] 格式
|
||||
predictions = predictions.transpose()
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
class_ids = []
|
||||
|
||||
for detection in predictions:
|
||||
# 提取坐标和类别置信度
|
||||
x, y, w, h = detection[:4]
|
||||
class_confs = detection[4:6] # 2个类别的置信度
|
||||
|
||||
# 找到最大置信度的类别
|
||||
class_id = np.argmax(class_confs)
|
||||
max_conf = class_confs[class_id]
|
||||
|
||||
if max_conf >= conf_threshold:
|
||||
# 转换坐标格式 (中心点 -> 左上角)
|
||||
x1 = x - w/2
|
||||
y1 = y - h/2
|
||||
x2 = x + w/2
|
||||
y2 = y + h/2
|
||||
|
||||
# 将坐标映射回原图尺寸
|
||||
x1 = (x1 - self.x_offset) / self.scale
|
||||
y1 = (y1 - self.y_offset) / self.scale
|
||||
x2 = (x2 - self.x_offset) / self.scale
|
||||
y2 = (y2 - self.y_offset) / self.scale
|
||||
|
||||
# 限制在图像边界内
|
||||
x1 = max(0, min(x1, self.orig_width))
|
||||
y1 = max(0, min(y1, self.orig_height))
|
||||
x2 = max(0, min(x2, self.orig_width))
|
||||
y2 = max(0, min(y2, self.orig_height))
|
||||
|
||||
boxes.append([x1, y1, x2, y2])
|
||||
scores.append(max_conf)
|
||||
class_ids.append(class_id)
|
||||
|
||||
# 执行NMS
|
||||
if len(boxes) > 0:
|
||||
boxes = np.array(boxes)
|
||||
scores = np.array(scores)
|
||||
class_ids = np.array(class_ids)
|
||||
|
||||
# OpenCV NMS
|
||||
indices = cv2.dnn.NMSBoxes(boxes, scores, conf_threshold, nms_threshold)
|
||||
|
||||
if len(indices) > 0:
|
||||
indices = indices.flatten()
|
||||
return boxes[indices], scores[indices], class_ids[indices]
|
||||
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
|
||||
"""执行检测"""
|
||||
# 预处理
|
||||
input_image = self.preprocess(image)
|
||||
|
||||
# 推理
|
||||
start_time = time.time()
|
||||
outputs = self.rknn.inference(inputs=[input_image])
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
# 后处理
|
||||
boxes, scores, class_ids = self.postprocess(outputs, conf_threshold, nms_threshold)
|
||||
|
||||
return boxes, scores, class_ids, inference_time
|
||||
|
||||
def draw_detections(self, image, boxes, scores, class_ids):
|
||||
"""在图像上绘制检测结果"""
|
||||
for i in range(len(boxes)):
|
||||
x1, y1, x2, y2 = boxes[i].astype(int)
|
||||
score = scores[i]
|
||||
class_id = int(class_ids[i])
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# 绘制标签
|
||||
label = f"{self.class_names[class_id]}: {score:.2f}"
|
||||
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||
cv2.rectangle(image, (x1, y1-label_size[1]-10),
|
||||
(x1+label_size[0], y1), (0, 255, 0), -1)
|
||||
cv2.putText(image, label, (x1, y1-5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||
|
||||
return image
|
||||
|
||||
def release(self):
|
||||
"""释放资源"""
|
||||
if self.rknn:
|
||||
self.rknn.release()
|
||||
|
||||
def main():
|
||||
# 初始化检测器
|
||||
model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
|
||||
detector = YOLOv8RKNN(model_path)
|
||||
|
||||
# 测试单张图片
|
||||
def test_image(image_path):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
print(f"Cannot load image: {image_path}")
|
||||
return
|
||||
|
||||
# 执行检测
|
||||
boxes, scores, class_ids, inference_time = detector.detect(image)
|
||||
|
||||
print(f"Inference time: {inference_time*1000:.2f}ms")
|
||||
print(f"Detected {len(boxes)} objects")
|
||||
|
||||
# 绘制结果
|
||||
result_image = detector.draw_detections(image, boxes, scores, class_ids)
|
||||
|
||||
# 显示结果
|
||||
# cv2.imshow("Detection Result", result_image)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
cv2.imwrite("xxxxxxx.jpg", result_image)
|
||||
|
||||
# 测试摄像头实时检测
|
||||
def test_camera():
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
if not cap.isOpened():
|
||||
print("Cannot open camera")
|
||||
return
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 执行检测
|
||||
boxes, scores, class_ids, inference_time = detector.detect(frame)
|
||||
|
||||
# 绘制结果
|
||||
result_frame = detector.draw_detections(frame, boxes, scores, class_ids)
|
||||
|
||||
# 显示FPS
|
||||
fps = 1.0 / inference_time if inference_time > 0 else 0
|
||||
cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow("Real-time Detection", result_frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 选择测试模式
|
||||
mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ")
|
||||
|
||||
if mode == "1":
|
||||
image_path = input("输入图片路径: ")
|
||||
test_image(image_path)
|
||||
elif mode == "2":
|
||||
test_camera()
|
||||
else:
|
||||
print("无效选择")
|
||||
|
||||
# 释放资源
|
||||
detector.release()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -29,7 +29,8 @@ router = APIRouter(prefix="/api/v1", tags=["ocr"])
|
||||
@router.get("/hello")
|
||||
async def get_hello():
|
||||
|
||||
return {"data":"hello"}
|
||||
# return {"data":"hello"}
|
||||
return ResponseUtil.error(msg=f"OCR识别失败", data=None)
|
||||
|
||||
@router.get("/test_select")
|
||||
async def test_select(
|
||||
|
||||
@ -6,6 +6,8 @@ import onnxruntime as ort
|
||||
from app.config.config import yolov8_settings
|
||||
|
||||
|
||||
|
||||
|
||||
class Yolov8Obj:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@ -90,18 +90,18 @@ if __name__ == "__main__":
|
||||
# test_ocr_api(test_image_path, api_url)
|
||||
# #---------------------------------------测试ocrender-----------------------------------------
|
||||
|
||||
#-----------------------------------------测试yolov8 侵占消防区域检测-----------------------------------------
|
||||
test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000000.jpg"
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000720.jpg"
|
||||
api_url = "http://10.0.0.202:12342/api/v1/detect_onnx_from_base64_0"
|
||||
test_detect(test_image_path)
|
||||
#-----------------------------------------测试yolov8 侵占消防区域检测 end-----------------------------------------
|
||||
# # -----------------------------------------测试yolov8 侵占消防区域检测-----------------------------------------
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000000.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000720.jpg"
|
||||
# api_url = "http://10.0.0.202:12342/api/v1/detect_onnx_from_base64_0"
|
||||
# test_detect(test_image_path)
|
||||
# #-----------------------------------------测试yolov8 侵占消防区域检测 end-----------------------------------------
|
||||
|
||||
|
||||
# #-----------------------------------------测试yolov8 灭火器检测-----------------------------------------
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/ce81420a27cdaff14fe42f967eaa49a3_frame_001060.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
# api_url = "http://10.0.0.202:12342/api/v1/detect_from_base64_1"
|
||||
# test_detect(test_image_path, api_url=api_url)
|
||||
# #-----------------------------------------测试yolov8 灭火器检测 end-----------------------------------------
|
||||
#-----------------------------------------测试yolov8 灭火器检测-----------------------------------------
|
||||
test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/ce81420a27cdaff14fe42f967eaa49a3_frame_001060.jpg"
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
api_url = "http://10.0.0.202:12342/api/v1/detect_from_base64_1"
|
||||
test_detect(test_image_path, api_url=api_url)
|
||||
#-----------------------------------------测试yolov8 灭火器检测 end-----------------------------------------
|
||||
Loading…
Reference in New Issue
Block a user