xcms_beixiaocai/001测试模型.py

51 lines
1.7 KiB
Python

from ultralytics import YOLO
import cv2
def predict_and_visualize(model_path, image_path, output_path):
# 加载训练好的模型
model = YOLO(model_path)
# 进行预测
results = model.predict(source=image_path, conf=0.25) # conf设置置信度阈值
# 读取原始图片
img = cv2.imread(image_path)
# 获取预测结果
boxes = results[0].boxes
class_names = model.names # 获取类别名称字典
# 遍历每个检测结果
for box in boxes:
# 获取坐标和类别信息
x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
cls_id = int(box.cls[0].item())
conf = box.conf[0].item()
# 绘制边界框
color = (0, 255, 0) # 绿色边框
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
# 准备显示文本
label = f"{class_names[cls_id]}: {conf:.2f}"
# 计算文本位置
(w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
# 绘制文本背景
cv2.rectangle(img, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
# 绘制文本
cv2.putText(img, label, (x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)
# 保存结果
cv2.imwrite(output_path, img)
print(f"结果已保存至: {output_path}")
if __name__ == "__main__":
# 使用示例
model_path = "models/安全帽检测模型/yolo11n_safehat.pt" # 替换为你的模型路径
image_path = "images/mp4_509.jpg" # 替换为你的图片路径
output_path = "output/mp4_509.jpg" # 输出文件名
predict_and_visualize(model_path, image_path, output_path)