test_api_server 距离估算完成

This commit is contained in:
haotian 2025-01-14 14:15:06 +08:00
parent a03ccf9a22
commit 7f5e223d8c
14 changed files with 272 additions and 110 deletions

View File

@ -10,9 +10,16 @@ model:
confidence_threshold: 0.5
distance_estimation:
focal_length: 1000 # 摄像头焦距
sensor_height: 3000 # 摄像头传感器高度(mm)
average_person_height: 1700 # 平均人身高(mm)
# focal_length: 1000 # 摄像头焦距
# sensor_height: 3000 # 摄像头传感器高度(mm)
# average_person_height: 1700 # 平均人身高(mm)
focal_length_mm : 35 # 焦距(mm)
sensor_width_mm : 23.6 # 传感器宽度(mm)
sensor_height_mm : 15.6 # 传感器高度(mm)
image_width_pixels : 640 # 图像宽度(像素)
image_height_pixels : 640 # 图像高度(像素)
camera_height_mm: 1700 # 摄像头安装高度(mm)
camera_tilt_angle: 15 # 摄像头俯仰角(度)
api:
host: "0.0.0.0"

21
main.py
View File

@ -1,6 +1,6 @@
import yaml
import uvicorn
from src. import RTSPCamera
from src.camera_handler import RTSPCamera
from src.person_detector import PersonDetector
from src.distance_estimator import DistanceEstimator
from src.api_server import DistanceAPI, app
@ -26,10 +26,23 @@ def main():
)
# 初始化距离估算器
'''
focal_length_mm : 35 # 焦距(mm)
sensor_width_mm : 23.6 # 传感器宽度(mm)
sensor_height_mm : 15.6 # 传感器高度(mm)
image_width_pixels : 640 # 图像宽度(像素)
image_height_pixels : 640 # 图像高度(像素)
camera_height_mm: 1700 # 摄像头安装高度(mm)
camera_tilt_angle : 15 # 摄像头俯仰角(度)
'''
estimator = DistanceEstimator(
config['model']['distance_estimation']['focal_length'],
config['model']['distance_estimation']['sensor_height'],
config['model']['distance_estimation']['average_person_height']
config['model']['distance_estimation']['focal_length_mm'],
config['model']['distance_estimation']['sensor_width_mm'],
config['model']['distance_estimation']['sensor_height_mm'],
config['model']['distance_estimation']['image_width_pixels'],
config['model']['distance_estimation']['image_height_pixels'],
config['model']['distance_estimation']['camera_height_mm'],
config['model']['distance_estimation']['camera_tilt_angle']
)
# 初始化API

Binary file not shown.

After

Width:  |  Height:  |  Size: 85 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@ -1,35 +1,249 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from flask import Flask, request, jsonify, send_file
from dataclasses import dataclass
from typing import List, Optional, Dict
import cv2
import numpy as np
from datetime import datetime
import os
import time
app = FastAPI()
app = Flask(__name__)
class Distance(BaseModel):
@dataclass
class Distance:
"""距离估计结果模型"""
person_id: int
distance_mm: float
confidence: float
bbox: List[int]
timestamp: str
def to_dict(self):
return {
'person_id': self.person_id,
'distance_mm': self.distance_mm,
'confidence': self.confidence,
'bbox': self.bbox,
'timestamp': self.timestamp
}
@dataclass
class HealthCheck:
"""健康检查响应模型"""
status: str
camera_connected: bool
detector_loaded: bool
estimator_initialized: bool
def to_dict(self):
return {
'status': self.status,
'camera_connected': self.camera_connected,
'detector_loaded': self.detector_loaded,
'estimator_initialized': self.estimator_initialized
}
@dataclass
class CameraConfig:
"""相机配置模型"""
rtsp_url: str
fps: int = 30
width: int = 1920
height: int = 1080
@classmethod
def from_dict(cls, data):
return cls(
rtsp_url=data['rtsp_url'],
fps=data.get('fps', 30),
width=data.get('width', 1920),
height=data.get('height', 1080)
)
class DistanceAPI:
def __init__(self, camera, detector, estimator):
self.camera = camera
self.camera.start()
time.sleep(10)
print('相机启动')
self.detector = detector
self.estimator = estimator
self.output_dir = "output"
self.ensure_output_dir()
async def get_distances(self) -> List[Distance]:
def ensure_output_dir(self):
"""确保输出目录存在"""
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
def get_timestamp(self):
"""获取当前时间戳"""
return datetime.now().strftime("%Y%m%d_%H%M%S_%f")
def save_visualization(self, frame, distances):
"""保存可视化结果"""
timestamp = self.get_timestamp()
output_path = os.path.join(self.output_dir, f"detection_{timestamp}.jpg")
cv2.imwrite(output_path, frame)
return output_path
def get_distances(self, save_visualization: bool = False) -> Dict:
"""获取距离估计结果"""
# 获取视频帧
frame = self.camera.get_frame()
if frame is None:
raise HTTPException(status_code=500, detail="No frame available")
return jsonify({'error': '无法获取视频帧'}), 500
persons = self.detector.detect(frame)
distances = []
# 检测人物
persons, _ = self.detector.detect(frame)
if not persons:
return {
"count": 0,
"distances": [],
"visualization_path": None,
"timestamp": self.get_timestamp()
}
for person in persons:
# 估计距离
distances = []
vis_frame = frame.copy()
for i, person in enumerate(persons):
bbox = person['bbox']
person_height = bbox[3] - bbox[1] # y2 - y1
distance = self.estimator.estimate_distance(person_height, frame.shape[0])
distance, confidence = self.estimator.estimate_distance(frame, bbox)
# 可视化
vis_frame = self.estimator.visualize_estimation(
vis_frame, bbox, distance, confidence
)
# 添加结果
distances.append(Distance(
person_id=i,
distance_mm=distance,
confidence=person['confidence']
confidence=confidence,
bbox=list(bbox),
timestamp=self.get_timestamp()
))
# 保存可视化结果
visualization_path = None
if save_visualization and distances:
visualization_path = self.save_visualization(vis_frame, distances)
return distances
return {
"count": len(distances),
"distances": [d.to_dict() for d in distances],
"visualization_path": visualization_path,
"timestamp": self.get_timestamp()
}
def check_health(self) -> Dict:
"""检查服务健康状态"""
camera_ok = self.camera.get_frame() is not None
health = HealthCheck(
status="healthy" if camera_ok else "degraded",
camera_connected=camera_ok,
detector_loaded=True,
estimator_initialized=True
)
return health.to_dict()
def update_camera_config(self, config: CameraConfig) -> bool:
"""更新相机配置"""
try:
# 停止当前相机
self.camera.stop()
# 更新相机参数
self.camera.rtsp_url = config.rtsp_url
self.camera.fps = config.fps
# 重启相机
self.camera.start()
time.sleep(10) # 等待相机初始化
# 验证连接
frame = self.camera.get_frame()
if frame is None:
raise Exception("无法连接到新的相机配置")
return True
except Exception as e:
return False
def create_routes(app: Flask, api: DistanceAPI):
@app.route('/health', methods=['GET'])
def health_check():
"""健康检查接口"""
return jsonify(api.check_health())
@app.route('/distances', methods=['GET'])
def get_distances():
"""获取距离估计结果"""
save_visualization = request.args.get('save_visualization', 'false').lower() == 'true'
return jsonify(api.get_distances(save_visualization))
@app.route('/visualization/<filename>', methods=['GET'])
def get_visualization(filename):
"""获取可视化结果图片"""
file_path = os.path.join(api.output_dir, filename)
if not os.path.exists(file_path):
return jsonify({'error': '图片不存在'}), 404
return send_file(file_path, mimetype='image/jpeg')
@app.route('/camera/config', methods=['POST'])
def update_camera_config():
"""更新相机配置"""
try:
config_data = request.get_json()
config = CameraConfig.from_dict(config_data)
success = api.update_camera_config(config)
return jsonify({"status": "success" if success else "failed"})
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/camera/frame', methods=['GET'])
def get_current_frame():
"""获取当前视频帧"""
frame = api.camera.get_frame()
if frame is None:
return jsonify({'error': '无法获取视频帧'}), 500
# 保存当前帧
timestamp = api.get_timestamp()
output_path = os.path.join(api.output_dir, f"frame_{timestamp}.jpg")
cv2.imwrite(output_path, frame)
return send_file(output_path, mimetype='image/jpeg')
def create_app(camera, detector, estimator):
"""创建Flask应用实例"""
api = DistanceAPI(camera, detector, estimator)
create_routes(app, api)
return app
# 使用示例
def main():
from src.camera_handler import RTSPCamera
from src.person_detector import PersonDetector
from src.distance_estimator import DistanceEstimator
# 初始化组件
camera = RTSPCamera("rtsp://10.0.0.17:8554/camera_test/2")
detector = PersonDetector("yolov8n.pt")
estimator = DistanceEstimator(
focal_length_mm=35,
sensor_width_mm=23.5,
sensor_height_mm=15.6,
image_width_pixels=1920,
image_height_pixels=1080
)
# 创建应用
app = create_app(camera, detector, estimator)
# 启动服务
app.run(host='0.0.0.0', port=5000, debug=True)

View File

@ -20,10 +20,15 @@ class RTSPCamera:
def _update_frame(self):
cap = cv2.VideoCapture(self.rtsp_url)
while self.running:
ret, frame = cap.read()
if ret:
self.frame = frame
time.sleep(1/self.fps)
if cap is None:
time.sleep(5)
cap = cv2.VideoCapture(self.rtsp_url)
else:
ret, frame = cap.read()
if ret:
self.frame = frame
# 若视频流是实时的 cv2捕获的视频帧是当前时刻的帧.
time.sleep(1/self.fps)
def get_frame(self):
return self.frame.copy() if self.frame is not None else None

3
test_api_server.py Normal file
View File

@ -0,0 +1,3 @@
from src.api_server import main
main()

Binary file not shown.

View File

@ -2,7 +2,7 @@ import asyncio
from .test_camera import test_rtsp_camera
from .test_person_detector import test_person_detector
from .test_distance_estimator import test_distance_estimator
from .test_api import test_distance_api
from .test_api_server import test_run
async def run_all_tests():
print("开始运行所有测试...")
@ -12,11 +12,11 @@ async def run_all_tests():
# print("\n2. 测试人物检测模块")
# test_person_detector()
print("\n3. 测试距离估算模块")
test_distance_estimator()
# print("\n3. 测试距离估算模块")
# test_distance_estimator()
# print("\n4. 测试API模块")
# await test_distance_api()
print("\n4. 测试API模块")
test_run()
print("\n所有测试完成!")

View File

@ -1,40 +0,0 @@
import asyncio
import numpy as np
from src.api_server import DistanceAPI, Distance
from src.camera_handler import RTSPCamera
from src.person_detector import PersonDetector
from src.distance_estimator import DistanceEstimator
async def test_distance_api():
# 初始化所有组件
camera = RTSPCamera(0) # 使用本地摄像头测试
detector = PersonDetector("yolov8n.pt")
estimator = DistanceEstimator(
focal_length=35,
sensor_height=24,
avg_person_height=1700
)
# 初始化API
api = DistanceAPI(camera, detector, estimator)
print("测试API...")
camera.start()
await asyncio.sleep(2) # 等待摄像头初始化
try:
distances = await api.get_distances()
assert isinstance(distances, list), "返回结果应该是列表"
for distance in distances:
assert isinstance(distance, Distance), "返回结果应该是Distance对象"
assert distance.distance_mm > 0, "距离应该大于0"
assert 0 <= distance.confidence <= 1, "置信度应该在0-1之间"
print(f"检测到 {len(distances)} 个人物的距离")
for i, d in enumerate(distances):
print(f"人物 {i+1}: 距离 = {d.distance_mm:.2f}mm, 置信度 = {d.confidence:.2f}")
finally:
camera.stop()
print("API测试完成!")

View File

@ -1,40 +0,0 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
class Distance(BaseModel):
distance_mm: float
confidence: float
detection_image_path: Optional[str] = None
class DistanceAPI:
def __init__(self, camera, detector, estimator):
self.camera = camera
self.detector = detector
self.estimator = estimator
async def get_distances(self) -> List[Distance]:
frame = self.camera.get_frame()
if frame is None:
raise HTTPException(status_code=500, detail="No frame available")
# 检测人物并保存可视化结果
persons, vis_path = self.detector.detect(
frame,
save_visualization=True,
save_path="output/detections"
)
distances = []
for person in persons:
bbox = person['bbox']
person_height = bbox[3] - bbox[1] # y2 - y1
distance = self.estimator.estimate_distance(person_height, frame.shape[0])
distances.append(Distance(
distance_mm=distance,
confidence=person['confidence'],
detection_image_path=vis_path
))
return distances