33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from ultralytics import YOLO
|
|
import numpy as np
|
|
from util.entity_utl import DetectionResult
|
|
from datetime import datetime
|
|
|
|
class YOLODetector:
    """Object detector backed by an ultralytics ``YOLO`` model."""

    def __init__(self, model_path: str = "yolov8n.pt"):
        """Load the detection model.

        Args:
            model_path: Weights path/name forwarded unchanged to ``YOLO``.
        """
        self.model = YOLO(model_path)
        # Class-id -> class-name lookup exposed by the loaded model.
        self.class_names = self.model.names

    def detect(self, frame: np.ndarray, confidence_threshold: float = 0.5) -> DetectionResult:
        """Run object detection on a single frame.

        Args:
            frame: Image passed directly to the model.
            confidence_threshold: Forwarded to the model as ``conf``; lower-
                confidence detections are dropped by the model.

        Returns:
            A ``DetectionResult`` holding the xyxy boxes, per-detection
            confidences, integer class ids, the matching class names, and a
            timestamp taken after inference completes. When the model yields
            no result or no ``boxes`` object, the arrays are empty but keep
            the same rank/dtype kind as the populated case (boxes ``(0, 4)``,
            integer class ids), so callers need not special-case "no
            detections".
        """
        results = self.model(frame, conf=confidence_threshold, verbose=False)

        detections = results[0].boxes if len(results) > 0 else None

        if detections is not None:
            boxes = detections.xyxy.cpu().numpy()
            confidences = detections.conf.cpu().numpy()
            class_ids = detections.cls.cpu().numpy().astype(int)
            # int() makes the dict lookup independent of the numpy scalar type.
            class_names = [self.class_names[int(cid)] for cid in class_ids]
        else:
            # Empty arrays shaped/typed like the populated branch: previously
            # np.array([]) gave boxes shape (0,) and float64 class ids, which
            # broke boxes[:, k] indexing and int-id comparisons downstream.
            # float32 presumably matches the model's tensor dtype — confirm.
            boxes = np.empty((0, 4), dtype=np.float32)
            confidences = np.empty(0, dtype=np.float32)
            class_ids = np.empty(0, dtype=int)
            class_names = []

        return DetectionResult(
            boxes=boxes,
            confidences=confidences,
            class_ids=class_ids,
            class_names=class_names,
            timestamp=datetime.now(),
        )