调整目录结构, 拆分一些模块
This commit is contained in:
parent
e6a8571626
commit
23fb86f8bd
@ -8,11 +8,12 @@ from pathlib import Path
|
||||
import subprocess
|
||||
import logging
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
import yaml
|
||||
|
||||
|
||||
from util.yolo_util import YOLODetector
|
||||
from util.entity_utl import DetectionResult, AlarmConfig
|
||||
from util.log_util import TimeBasedDuplicateFilter
|
||||
|
||||
|
||||
@ -35,23 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
# 5为时间间隔,单位s. 5s内不会输出相同日志
|
||||
logger.addFilter(TimeBasedDuplicateFilter(5))
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""检测结果数据类"""
|
||||
boxes: np.ndarray
|
||||
confidences: np.ndarray
|
||||
class_ids: np.ndarray
|
||||
class_names: List[str]
|
||||
timestamp: datetime
|
||||
|
||||
@dataclass
|
||||
class AlarmConfig:
|
||||
"""告警配置"""
|
||||
target_classes: List[str] # 目标类别
|
||||
confidence_threshold: float = 0.5 # 置信度阈值
|
||||
alarm_duration: int = 10 # 告警录制时长(秒)
|
||||
cooldown_duration: int = 30 # 告警冷却时间(秒)
|
||||
save_path: str = "./alarm_videos" # 保存路径
|
||||
|
||||
class FrameBuffer:
|
||||
"""帧缓冲区,用于告警录制"""
|
||||
@ -74,34 +59,7 @@ class FrameBuffer:
|
||||
|
||||
return frames
|
||||
|
||||
class YOLODetector:
|
||||
"""YOLO检测器"""
|
||||
def __init__(self, model_path: str = "yolov8n.pt"):
|
||||
self.model = YOLO(model_path)
|
||||
self.class_names = self.model.names
|
||||
|
||||
def detect(self, frame: np.ndarray, confidence_threshold: float = 0.5) -> DetectionResult:
|
||||
"""执行目标检测"""
|
||||
results = self.model(frame, conf=confidence_threshold, verbose=False)
|
||||
|
||||
if len(results) > 0 and results[0].boxes is not None:
|
||||
boxes = results[0].boxes.xyxy.cpu().numpy()
|
||||
confidences = results[0].boxes.conf.cpu().numpy()
|
||||
class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
|
||||
class_names = [self.class_names[id] for id in class_ids]
|
||||
else:
|
||||
boxes = np.array([])
|
||||
confidences = np.array([])
|
||||
class_ids = np.array([])
|
||||
class_names = []
|
||||
|
||||
return DetectionResult(
|
||||
boxes=boxes,
|
||||
confidences=confidences,
|
||||
class_ids=class_ids,
|
||||
class_names=class_names,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
class AlarmManager:
|
||||
"""告警管理器"""
|
||||
@ -425,6 +383,8 @@ class RTSPProcessor:
|
||||
def connect_to_rtsp_stream(self, url):
|
||||
""" 尝试连接到 RTSP 流 """
|
||||
cap = cv2.VideoCapture(url)
|
||||
# 设置缓冲区大小
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
# cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'HEVC'))
|
||||
if not cap.isOpened():
|
||||
logger.error(f"Failed to connect to {url}")
|
||||
@ -435,9 +395,6 @@ class RTSPProcessor:
|
||||
"""捕获RTSP流帧"""
|
||||
cap = self.connect_to_rtsp_stream(self.rtsp_url)
|
||||
|
||||
# 设置缓冲区大小
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
|
||||
logger.info(f"开始捕获RTSP流: {self.rtsp_url}")
|
||||
|
||||
try:
|
||||
|
||||
22
util/entity_utl.py
Normal file
22
util/entity_utl.py
Normal file
@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""检测结果数据类"""
|
||||
boxes: np.ndarray
|
||||
confidences: np.ndarray
|
||||
class_ids: np.ndarray
|
||||
class_names: List[str]
|
||||
timestamp: datetime
|
||||
|
||||
@dataclass
|
||||
class AlarmConfig:
|
||||
"""告警配置"""
|
||||
target_classes: List[str] # 目标类别
|
||||
confidence_threshold: float = 0.5 # 置信度阈值
|
||||
alarm_duration: int = 10 # 告警录制时长(秒)
|
||||
cooldown_duration: int = 30 # 告警冷却时间(秒)
|
||||
save_path: str = "./alarm_videos" # 保存路径
|
||||
33
util/yolo_util.py
Normal file
33
util/yolo_util.py
Normal file
@ -0,0 +1,33 @@
|
||||
from ultralytics import YOLO
|
||||
import numpy as np
|
||||
from util.entity_utl import DetectionResult
|
||||
from datetime import datetime
|
||||
|
||||
class YOLODetector:
|
||||
"""YOLO检测器"""
|
||||
def __init__(self, model_path: str = "yolov8n.pt"):
|
||||
self.model = YOLO(model_path)
|
||||
self.class_names = self.model.names
|
||||
|
||||
def detect(self, frame: np.ndarray, confidence_threshold: float = 0.5) -> DetectionResult:
|
||||
"""执行目标检测"""
|
||||
results = self.model(frame, conf=confidence_threshold, verbose=False)
|
||||
|
||||
if len(results) > 0 and results[0].boxes is not None:
|
||||
boxes = results[0].boxes.xyxy.cpu().numpy()
|
||||
confidences = results[0].boxes.conf.cpu().numpy()
|
||||
class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
|
||||
class_names = [self.class_names[id] for id in class_ids]
|
||||
else:
|
||||
boxes = np.array([])
|
||||
confidences = np.array([])
|
||||
class_ids = np.array([])
|
||||
class_names = []
|
||||
|
||||
return DetectionResult(
|
||||
boxes=boxes,
|
||||
confidences=confidences,
|
||||
class_ids=class_ids,
|
||||
class_names=class_names,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user