diff --git a/config/config.yaml b/config/config.yaml index 96cb058..e50d87b 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -10,9 +10,16 @@ model: confidence_threshold: 0.5 distance_estimation: - focal_length: 1000 # 摄像头焦距 - sensor_height: 3000 # 摄像头传感器高度(mm) - average_person_height: 1700 # 平均人身高(mm) + # focal_length: 1000 # 摄像头焦距 + # sensor_height: 3000 # 摄像头传感器高度(mm) + # average_person_height: 1700 # 平均人身高(mm) + focal_length_mm : 35 # 焦距(mm) + sensor_width_mm : 23.6 # 传感器宽度(mm) + sensor_height_mm : 15.6 # 传感器高度(mm) + image_width_pixels : 640 # 图像宽度(像素) + image_height_pixels : 640 # 图像高度(像素) + camera_height_mm: 1700 # 摄像头安装高度(mm) + camera_tilt_angle: 15 # 摄像头俯仰角(度) api: host: "0.0.0.0" diff --git a/main.py b/main.py index 9ed6263..881cded 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import yaml import uvicorn -from src. import RTSPCamera +from src.camera_handler import RTSPCamera from src.person_detector import PersonDetector from src.distance_estimator import DistanceEstimator from src.api_server import DistanceAPI, app @@ -26,10 +26,23 @@ def main(): ) # 初始化距离估算器 + ''' + focal_length_mm : 35 # 焦距(mm) + sensor_width_mm : 23.6 # 传感器宽度(mm) + sensor_height_mm : 15.6 # 传感器高度(mm) + image_width_pixels : 640 # 图像宽度(像素) + image_height_pixels : 640 # 图像高度(像素) + camera_height_mm: 1700 # 摄像头安装高度(mm) + camera_tilt_angle : 15 # 摄像头俯仰角(度) + ''' estimator = DistanceEstimator( - config['model']['distance_estimation']['focal_length'], - config['model']['distance_estimation']['sensor_height'], - config['model']['distance_estimation']['average_person_height'] + config['model']['distance_estimation']['focal_length_mm'], + config['model']['distance_estimation']['sensor_width_mm'], + config['model']['distance_estimation']['sensor_height_mm'], + config['model']['distance_estimation']['image_width_pixels'], + config['model']['distance_estimation']['image_height_pixels'], + config['model']['distance_estimation']['camera_height_mm'], + config['model']['distance_estimation']['camera_tilt_angle'] ) # 初始化API diff --git a/output/detection_20250113_165518_851152.jpg b/output/detection_20250113_165518_851152.jpg new file mode 100644 index 0000000..e092edf Binary files /dev/null and b/output/detection_20250113_165518_851152.jpg differ diff --git a/output/detection_20250113_165638_412391.jpg b/output/detection_20250113_165638_412391.jpg new file mode 100644 index 0000000..47fedab Binary files /dev/null and b/output/detection_20250113_165638_412391.jpg differ diff --git a/src/__pycache__/api_server.cpython-39.pyc b/src/__pycache__/api_server.cpython-39.pyc index 04898e4..546ad91 100644 Binary files a/src/__pycache__/api_server.cpython-39.pyc and b/src/__pycache__/api_server.cpython-39.pyc differ diff --git a/src/__pycache__/camera_handler.cpython-39.pyc b/src/__pycache__/camera_handler.cpython-39.pyc index 7b9355e..1e84446 100644 Binary files a/src/__pycache__/camera_handler.cpython-39.pyc and b/src/__pycache__/camera_handler.cpython-39.pyc differ diff --git a/src/api_server.py b/src/api_server.py index 32eb9e3..e44689f 100644 --- a/src/api_server.py +++ b/src/api_server.py @@ -1,35 +1,249 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import List, Optional +from flask import Flask, request, jsonify, send_file +from dataclasses import dataclass +from typing import List, Optional, Dict +import cv2 +import numpy as np +from datetime import datetime +import os +import time -app = FastAPI() +app = Flask(__name__) -class Distance(BaseModel): +@dataclass +class Distance: + """距离估计结果模型""" + person_id: int distance_mm: float confidence: float + bbox: List[int] + timestamp: str + + def to_dict(self): + return { + 'person_id': self.person_id, + 'distance_mm': self.distance_mm, + 'confidence': self.confidence, + 'bbox': self.bbox, + 'timestamp': self.timestamp + } + +@dataclass +class HealthCheck: + """健康检查响应模型""" + status: str + camera_connected: bool + detector_loaded: bool + estimator_initialized: bool + + def to_dict(self): + return { + 'status': self.status, + 'camera_connected': self.camera_connected, + 'detector_loaded': self.detector_loaded, + 'estimator_initialized': self.estimator_initialized + } + +@dataclass +class CameraConfig: + """相机配置模型""" + rtsp_url: str + fps: int = 30 + width: int = 1920 + height: int = 1080 + + @classmethod + def from_dict(cls, data): + return cls( + rtsp_url=data['rtsp_url'], + fps=data.get('fps', 30), + width=data.get('width', 1920), + height=data.get('height', 1080) + ) class DistanceAPI: def __init__(self, camera, detector, estimator): self.camera = camera + self.camera.start() + time.sleep(10) + print('相机启动') + self.detector = detector self.estimator = estimator + self.output_dir = "output" + self.ensure_output_dir() - async def get_distances(self) -> List[Distance]: + def ensure_output_dir(self): + """确保输出目录存在""" + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + + def get_timestamp(self): + """获取当前时间戳""" + return datetime.now().strftime("%Y%m%d_%H%M%S_%f") + + def save_visualization(self, frame, distances): + """保存可视化结果""" + timestamp = self.get_timestamp() + output_path = os.path.join(self.output_dir, f"detection_{timestamp}.jpg") + cv2.imwrite(output_path, frame) + return output_path + + def get_distances(self, save_visualization: bool = False) -> Dict: + """获取距离估计结果""" + # 获取视频帧 frame = self.camera.get_frame() if frame is None: - raise HTTPException(status_code=500, detail="No frame available") + return jsonify({'error': '无法获取视频帧'}), 500 - persons = self.detector.detect(frame) - distances = [] + # 检测人物 + persons, _ = self.detector.detect(frame) + if not persons: + return { + "count": 0, + "distances": [], + "visualization_path": None, + "timestamp": self.get_timestamp() + } - for person in persons: + # 估计距离 + distances = [] + vis_frame = frame.copy() + + for i, person in enumerate(persons): bbox = person['bbox'] - person_height = bbox[3] - bbox[1] # y2 - y1 - distance = self.estimator.estimate_distance(person_height, frame.shape[0]) + distance, confidence = self.estimator.estimate_distance(frame, bbox) + # 可视化 + vis_frame = self.estimator.visualize_estimation( + vis_frame, bbox, distance, confidence + ) + + # 添加结果 distances.append(Distance( + person_id=i, distance_mm=distance, - confidence=person['confidence'] + confidence=confidence, + bbox=list(bbox), + timestamp=self.get_timestamp() )) + + # 保存可视化结果 + visualization_path = None + if save_visualization and distances: + visualization_path = self.save_visualization(vis_frame, distances) - return distances \ No newline at end of file + return { + "count": len(distances), + "distances": [d.to_dict() for d in distances], + "visualization_path": visualization_path, + "timestamp": self.get_timestamp() + } + + def check_health(self) -> Dict: + """检查服务健康状态""" + camera_ok = self.camera.get_frame() is not None + + health = HealthCheck( + status="healthy" if camera_ok else "degraded", + camera_connected=camera_ok, + detector_loaded=True, + estimator_initialized=True + ) + return health.to_dict() + + def update_camera_config(self, config: CameraConfig) -> bool: + """更新相机配置""" + try: + # 停止当前相机 + self.camera.stop() + + # 更新相机参数 + self.camera.rtsp_url = config.rtsp_url + self.camera.fps = config.fps + + # 重启相机 + self.camera.start() + time.sleep(10) # 等待相机初始化 + + # 验证连接 + frame = self.camera.get_frame() + if frame is None: + raise Exception("无法连接到新的相机配置") + + return True + except Exception as e: + return False + +def create_routes(app: Flask, api: DistanceAPI): + @app.route('/health', methods=['GET']) + def health_check(): + """健康检查接口""" + return jsonify(api.check_health()) + + @app.route('/distances', methods=['GET']) + def get_distances(): + """获取距离估计结果""" + save_visualization = request.args.get('save_visualization', 'false').lower() == 'true' + return jsonify(api.get_distances(save_visualization)) + + @app.route('/visualization/', methods=['GET']) + def get_visualization(filename): + """获取可视化结果图片""" + file_path = os.path.join(api.output_dir, filename) + if not os.path.exists(file_path): + return jsonify({'error': '图片不存在'}), 404 + return send_file(file_path, mimetype='image/jpeg') + + @app.route('/camera/config', methods=['POST']) + def update_camera_config(): + """更新相机配置""" + try: + config_data = request.get_json() + config = CameraConfig.from_dict(config_data) + success = api.update_camera_config(config) + return jsonify({"status": "success" if success else "failed"}) + except Exception as e: + return jsonify({'error': str(e)}), 500 + + @app.route('/camera/frame', methods=['GET']) + def get_current_frame(): + """获取当前视频帧""" + frame = api.camera.get_frame() + if frame is None: + return jsonify({'error': '无法获取视频帧'}), 500 + + # 保存当前帧 + timestamp = api.get_timestamp() + output_path = os.path.join(api.output_dir, f"frame_{timestamp}.jpg") + cv2.imwrite(output_path, frame) + + return send_file(output_path, mimetype='image/jpeg') + +def create_app(camera, detector, estimator): + """创建Flask应用实例""" + api = DistanceAPI(camera, detector, estimator) + create_routes(app, api) + return app + +# 使用示例 +def main(): + from src.camera_handler import RTSPCamera + from src.person_detector import PersonDetector + from src.distance_estimator import DistanceEstimator + + # 初始化组件 + camera = RTSPCamera("rtsp://10.0.0.17:8554/camera_test/2") + detector = PersonDetector("yolov8n.pt") + estimator = DistanceEstimator( + focal_length_mm=35, + sensor_width_mm=23.5, + sensor_height_mm=15.6, + image_width_pixels=1920, + image_height_pixels=1080 + ) + + # 创建应用 + app = create_app(camera, detector, estimator) + + # 启动服务 + app.run(host='0.0.0.0', port=5000, debug=True) \ No newline at end of file diff --git a/src/camera_handler.py b/src/camera_handler.py index 18a3f15..d895492 100644 --- a/src/camera_handler.py +++ b/src/camera_handler.py @@ -20,10 +20,15 @@ class RTSPCamera: def _update_frame(self): cap = cv2.VideoCapture(self.rtsp_url) while self.running: - ret, frame = cap.read() - if ret: - self.frame = frame - time.sleep(1/self.fps) + if cap is None: + time.sleep(5) + cap = cv2.VideoCapture(self.rtsp_url) + else: + ret, frame = cap.read() + if ret: + self.frame = frame + # 若视频流是实时的 cv2捕获的视频帧是当前时刻的帧. + time.sleep(1/self.fps) def get_frame(self): return self.frame.copy() if self.frame is not None else None diff --git a/test_api_server.py b/test_api_server.py new file mode 100644 index 0000000..2ba5c98 --- /dev/null +++ b/test_api_server.py @@ -0,0 +1,3 @@ +from src.api_server import main + +main() \ No newline at end of file diff --git a/tests/__pycache__/run_tests.cpython-39.pyc b/tests/__pycache__/run_tests.cpython-39.pyc index 881861f..9ac6b74 100644 Binary files a/tests/__pycache__/run_tests.cpython-39.pyc and b/tests/__pycache__/run_tests.cpython-39.pyc differ diff --git a/tests/__pycache__/test_api_server.cpython-39.pyc b/tests/__pycache__/test_api_server.cpython-39.pyc new file mode 100644 index 0000000..4456b87 Binary files /dev/null and b/tests/__pycache__/test_api_server.cpython-39.pyc differ diff --git a/tests/run_tests.py b/tests/run_tests.py index 913fff4..b3fbf0e 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -2,7 +2,7 @@ import asyncio from .test_camera import test_rtsp_camera from .test_person_detector import test_person_detector from .test_distance_estimator import test_distance_estimator -from .test_api import test_distance_api +from .test_api_server import test_run async def run_all_tests(): print("开始运行所有测试...") @@ -12,11 +12,11 @@ async def run_all_tests(): # print("\n2. 测试人物检测模块") # test_person_detector() - print("\n3. 测试距离估算模块") - test_distance_estimator() + # print("\n3. 测试距离估算模块") + # test_distance_estimator() - # print("\n4. 测试API模块") - # await test_distance_api() + print("\n4. 测试API模块") + test_run() print("\n所有测试完成!") diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index 2ce3a88..0000000 --- a/tests/test_api.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import numpy as np -from src.api_server import DistanceAPI, Distance -from src.camera_handler import RTSPCamera -from src.person_detector import PersonDetector -from src.distance_estimator import DistanceEstimator - -async def test_distance_api(): - # 初始化所有组件 - camera = RTSPCamera(0) # 使用本地摄像头测试 - detector = PersonDetector("yolov8n.pt") - estimator = DistanceEstimator( - focal_length=35, - sensor_height=24, - avg_person_height=1700 - ) - - # 初始化API - api = DistanceAPI(camera, detector, estimator) - - print("测试API...") - camera.start() - await asyncio.sleep(2) # 等待摄像头初始化 - - try: - distances = await api.get_distances() - assert isinstance(distances, list), "返回结果应该是列表" - for distance in distances: - assert isinstance(distance, Distance), "返回结果应该是Distance对象" - assert distance.distance_mm > 0, "距离应该大于0" - assert 0 <= distance.confidence <= 1, "置信度应该在0-1之间" - - print(f"检测到 {len(distances)} 个人物的距离") - for i, d in enumerate(distances): - print(f"人物 {i+1}: 距离 = {d.distance_mm:.2f}mm, 置信度 = {d.confidence:.2f}") - - finally: - camera.stop() - - print("API测试完成!") \ No newline at end of file diff --git a/tests/test_api_server.py b/tests/test_api_server.py deleted file mode 100644 index 210888d..0000000 --- a/tests/test_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import List, Optional - -class Distance(BaseModel): - distance_mm: float - confidence: float - detection_image_path: Optional[str] = None - -class DistanceAPI: - def __init__(self, camera, detector, estimator): - self.camera = camera - self.detector = detector - self.estimator = estimator - - async def get_distances(self) -> List[Distance]: - frame = self.camera.get_frame() - if frame is None: - raise HTTPException(status_code=500, detail="No frame available") - - # 检测人物并保存可视化结果 - persons, vis_path = self.detector.detect( - frame, - save_visualization=True, - save_path="output/detections" - ) - - distances = [] - for person in persons: - bbox = person['bbox'] - person_height = bbox[3] - bbox[1] # y2 - y1 - distance = self.estimator.estimate_distance(person_height, frame.shape[0]) - - distances.append(Distance( - distance_mm=distance, - confidence=person['confidence'], - detection_image_path=vis_path - )) - - return distances \ No newline at end of file