xcms_beixiaocai/002批量测试模型图片.py

63 lines
2.2 KiB
Python

from ultralytics import YOLO
import cv2
import os
def predict_and_visualize(model_path, image_path, output_path):
# 加载训练好的模型
model = YOLO(model_path)
all_image_file = os.listdir(image_path)
all_image_path = [os.path.join(image_path, t) for t in all_image_file]
for i in range(len(all_image_path)):
# 进行预测
results = model.predict(source=all_image_path[i], conf=0.25) # conf设置置信度阈值
# 读取原始图片
img = cv2.imread(all_image_path[i])
# 获取预测结果
boxes = results[0].boxes
class_names = model.names # 获取类别名称字典
# {0: 'head', 1: 'safehat'}
print(class_names)
# break
# 遍历每个检测结果
for box in boxes:
# 获取坐标和类别信息
x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
cls_id = int(box.cls[0].item())
conf = box.conf[0].item()
# 绘制边界框
color = (0, 255, 0) # 绿色边框
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
# 准备显示文本
label = f"{class_names[cls_id]}: {conf:.2f}"
# 计算文本位置
(w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
# 绘制文本背景
cv2.rectangle(img, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
# 绘制文本
cv2.putText(img, label, (x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)
# 保存结果
cv2.imwrite(output_path+f"{i}.jpg", img)
print(f"结果已保存至: {output_path}")
if __name__ == "__main__":
# 使用示例
model_path = "/home/admin-root/haotian/xcms/models/安全帽检测模型OpenVINO/best_s.xml" # 替换为你的模型路径
image_path = "/home/admin-root/haotian/锻8/tensorrtx/yolov8/images" # 替换为你的图片路径
output_path = "/home/admin-root/haotian/xcms/output/" # 输出文件名
predict_and_visualize(model_path, image_path, output_path)