test_api_server 距离估算完成
This commit is contained in:
parent
a03ccf9a22
commit
7f5e223d8c
@ -10,9 +10,16 @@ model:
|
||||
confidence_threshold: 0.5
|
||||
|
||||
distance_estimation:
|
||||
focal_length: 1000 # 摄像头焦距
|
||||
sensor_height: 3000 # 摄像头传感器高度(mm)
|
||||
average_person_height: 1700 # 平均人身高(mm)
|
||||
# focal_length: 1000 # 摄像头焦距
|
||||
# sensor_height: 3000 # 摄像头传感器高度(mm)
|
||||
# average_person_height: 1700 # 平均人身高(mm)
|
||||
focal_length_mm : 35 # 焦距(mm)
|
||||
sensor_width_mm : 23.6 # 传感器宽度(mm)
|
||||
sensor_height_mm : 15.6 # 传感器高度(mm)
|
||||
image_width_pixels : 640 # 图像宽度(像素)
|
||||
image_height_pixels : 640 # 图像高度(像素)
|
||||
camera_height_mm: 1700 # 摄像头安装高度(mm)
|
||||
camera_tilt_angle: 15 # 摄像头俯仰角(度)
|
||||
|
||||
api:
|
||||
host: "0.0.0.0"
|
||||
|
||||
21
main.py
21
main.py
@ -1,6 +1,6 @@
|
||||
import yaml
|
||||
import uvicorn
|
||||
from src. import RTSPCamera
|
||||
from src.camera_handler import RTSPCamera
|
||||
from src.person_detector import PersonDetector
|
||||
from src.distance_estimator import DistanceEstimator
|
||||
from src.api_server import DistanceAPI, app
|
||||
@ -26,10 +26,23 @@ def main():
|
||||
)
|
||||
|
||||
# 初始化距离估算器
|
||||
'''
|
||||
focal_length_mm : 35 # 焦距(mm)
|
||||
sensor_width_mm : 23.6 # 传感器宽度(mm)
|
||||
sensor_height_mm : 15.6 # 传感器高度(mm)
|
||||
image_width_pixels : 640 # 图像宽度(像素)
|
||||
image_height_pixels : 640 # 图像高度(像素)
|
||||
camera_height_mm: 1700 # 摄像头安装高度(mm)
|
||||
camera_tilt_angle : 15 # 摄像头俯仰角(度)
|
||||
'''
|
||||
estimator = DistanceEstimator(
|
||||
config['model']['distance_estimation']['focal_length'],
|
||||
config['model']['distance_estimation']['sensor_height'],
|
||||
config['model']['distance_estimation']['average_person_height']
|
||||
config['model']['distance_estimation']['focal_length_mm'],
|
||||
config['model']['distance_estimation']['sensor_width_mm'],
|
||||
config['model']['distance_estimation']['sensor_height_mm'],
|
||||
config['model']['distance_estimation']['image_width_pixels'],
|
||||
config['model']['distance_estimation']['image_height_pixels'],
|
||||
config['model']['distance_estimation']['camera_height_mm'],
|
||||
config['model']['distance_estimation']['camera_tilt_angle']
|
||||
)
|
||||
|
||||
# 初始化API
|
||||
|
||||
BIN
output/detection_20250113_165518_851152.jpg
Normal file
BIN
output/detection_20250113_165518_851152.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 85 KiB |
BIN
output/detection_20250113_165638_412391.jpg
Normal file
BIN
output/detection_20250113_165638_412391.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
Binary file not shown.
Binary file not shown.
@ -1,35 +1,249 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from flask import Flask, request, jsonify, send_file
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict
|
||||
import cv2
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
import os
|
||||
import time
|
||||
|
||||
app = FastAPI()
|
||||
app = Flask(__name__)
|
||||
|
||||
class Distance(BaseModel):
|
||||
@dataclass
|
||||
class Distance:
|
||||
"""距离估计结果模型"""
|
||||
person_id: int
|
||||
distance_mm: float
|
||||
confidence: float
|
||||
bbox: List[int]
|
||||
timestamp: str
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'person_id': self.person_id,
|
||||
'distance_mm': self.distance_mm,
|
||||
'confidence': self.confidence,
|
||||
'bbox': self.bbox,
|
||||
'timestamp': self.timestamp
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class HealthCheck:
|
||||
"""健康检查响应模型"""
|
||||
status: str
|
||||
camera_connected: bool
|
||||
detector_loaded: bool
|
||||
estimator_initialized: bool
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'status': self.status,
|
||||
'camera_connected': self.camera_connected,
|
||||
'detector_loaded': self.detector_loaded,
|
||||
'estimator_initialized': self.estimator_initialized
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class CameraConfig:
|
||||
"""相机配置模型"""
|
||||
rtsp_url: str
|
||||
fps: int = 30
|
||||
width: int = 1920
|
||||
height: int = 1080
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
return cls(
|
||||
rtsp_url=data['rtsp_url'],
|
||||
fps=data.get('fps', 30),
|
||||
width=data.get('width', 1920),
|
||||
height=data.get('height', 1080)
|
||||
)
|
||||
|
||||
class DistanceAPI:
|
||||
def __init__(self, camera, detector, estimator):
|
||||
self.camera = camera
|
||||
self.camera.start()
|
||||
time.sleep(10)
|
||||
print('相机启动')
|
||||
|
||||
self.detector = detector
|
||||
self.estimator = estimator
|
||||
self.output_dir = "output"
|
||||
self.ensure_output_dir()
|
||||
|
||||
async def get_distances(self) -> List[Distance]:
|
||||
def ensure_output_dir(self):
|
||||
"""确保输出目录存在"""
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
|
||||
def get_timestamp(self):
|
||||
"""获取当前时间戳"""
|
||||
return datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
def save_visualization(self, frame, distances):
|
||||
"""保存可视化结果"""
|
||||
timestamp = self.get_timestamp()
|
||||
output_path = os.path.join(self.output_dir, f"detection_{timestamp}.jpg")
|
||||
cv2.imwrite(output_path, frame)
|
||||
return output_path
|
||||
|
||||
def get_distances(self, save_visualization: bool = False) -> Dict:
|
||||
"""获取距离估计结果"""
|
||||
# 获取视频帧
|
||||
frame = self.camera.get_frame()
|
||||
if frame is None:
|
||||
raise HTTPException(status_code=500, detail="No frame available")
|
||||
return jsonify({'error': '无法获取视频帧'}), 500
|
||||
|
||||
persons = self.detector.detect(frame)
|
||||
distances = []
|
||||
# 检测人物
|
||||
persons, _ = self.detector.detect(frame)
|
||||
if not persons:
|
||||
return {
|
||||
"count": 0,
|
||||
"distances": [],
|
||||
"visualization_path": None,
|
||||
"timestamp": self.get_timestamp()
|
||||
}
|
||||
|
||||
for person in persons:
|
||||
# 估计距离
|
||||
distances = []
|
||||
vis_frame = frame.copy()
|
||||
|
||||
for i, person in enumerate(persons):
|
||||
bbox = person['bbox']
|
||||
person_height = bbox[3] - bbox[1] # y2 - y1
|
||||
distance = self.estimator.estimate_distance(person_height, frame.shape[0])
|
||||
distance, confidence = self.estimator.estimate_distance(frame, bbox)
|
||||
|
||||
# 可视化
|
||||
vis_frame = self.estimator.visualize_estimation(
|
||||
vis_frame, bbox, distance, confidence
|
||||
)
|
||||
|
||||
# 添加结果
|
||||
distances.append(Distance(
|
||||
person_id=i,
|
||||
distance_mm=distance,
|
||||
confidence=person['confidence']
|
||||
confidence=confidence,
|
||||
bbox=list(bbox),
|
||||
timestamp=self.get_timestamp()
|
||||
))
|
||||
|
||||
# 保存可视化结果
|
||||
visualization_path = None
|
||||
if save_visualization and distances:
|
||||
visualization_path = self.save_visualization(vis_frame, distances)
|
||||
|
||||
return distances
|
||||
return {
|
||||
"count": len(distances),
|
||||
"distances": [d.to_dict() for d in distances],
|
||||
"visualization_path": visualization_path,
|
||||
"timestamp": self.get_timestamp()
|
||||
}
|
||||
|
||||
def check_health(self) -> Dict:
|
||||
"""检查服务健康状态"""
|
||||
camera_ok = self.camera.get_frame() is not None
|
||||
|
||||
health = HealthCheck(
|
||||
status="healthy" if camera_ok else "degraded",
|
||||
camera_connected=camera_ok,
|
||||
detector_loaded=True,
|
||||
estimator_initialized=True
|
||||
)
|
||||
return health.to_dict()
|
||||
|
||||
def update_camera_config(self, config: CameraConfig) -> bool:
|
||||
"""更新相机配置"""
|
||||
try:
|
||||
# 停止当前相机
|
||||
self.camera.stop()
|
||||
|
||||
# 更新相机参数
|
||||
self.camera.rtsp_url = config.rtsp_url
|
||||
self.camera.fps = config.fps
|
||||
|
||||
# 重启相机
|
||||
self.camera.start()
|
||||
time.sleep(10) # 等待相机初始化
|
||||
|
||||
# 验证连接
|
||||
frame = self.camera.get_frame()
|
||||
if frame is None:
|
||||
raise Exception("无法连接到新的相机配置")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def create_routes(app: Flask, api: DistanceAPI):
|
||||
@app.route('/health', methods=['GET'])
|
||||
def health_check():
|
||||
"""健康检查接口"""
|
||||
return jsonify(api.check_health())
|
||||
|
||||
@app.route('/distances', methods=['GET'])
|
||||
def get_distances():
|
||||
"""获取距离估计结果"""
|
||||
save_visualization = request.args.get('save_visualization', 'false').lower() == 'true'
|
||||
return jsonify(api.get_distances(save_visualization))
|
||||
|
||||
@app.route('/visualization/<filename>', methods=['GET'])
|
||||
def get_visualization(filename):
|
||||
"""获取可视化结果图片"""
|
||||
file_path = os.path.join(api.output_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return jsonify({'error': '图片不存在'}), 404
|
||||
return send_file(file_path, mimetype='image/jpeg')
|
||||
|
||||
@app.route('/camera/config', methods=['POST'])
|
||||
def update_camera_config():
|
||||
"""更新相机配置"""
|
||||
try:
|
||||
config_data = request.get_json()
|
||||
config = CameraConfig.from_dict(config_data)
|
||||
success = api.update_camera_config(config)
|
||||
return jsonify({"status": "success" if success else "failed"})
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@app.route('/camera/frame', methods=['GET'])
|
||||
def get_current_frame():
|
||||
"""获取当前视频帧"""
|
||||
frame = api.camera.get_frame()
|
||||
if frame is None:
|
||||
return jsonify({'error': '无法获取视频帧'}), 500
|
||||
|
||||
# 保存当前帧
|
||||
timestamp = api.get_timestamp()
|
||||
output_path = os.path.join(api.output_dir, f"frame_{timestamp}.jpg")
|
||||
cv2.imwrite(output_path, frame)
|
||||
|
||||
return send_file(output_path, mimetype='image/jpeg')
|
||||
|
||||
def create_app(camera, detector, estimator):
|
||||
"""创建Flask应用实例"""
|
||||
api = DistanceAPI(camera, detector, estimator)
|
||||
create_routes(app, api)
|
||||
return app
|
||||
|
||||
# 使用示例
|
||||
def main():
|
||||
from src.camera_handler import RTSPCamera
|
||||
from src.person_detector import PersonDetector
|
||||
from src.distance_estimator import DistanceEstimator
|
||||
|
||||
# 初始化组件
|
||||
camera = RTSPCamera("rtsp://10.0.0.17:8554/camera_test/2")
|
||||
detector = PersonDetector("yolov8n.pt")
|
||||
estimator = DistanceEstimator(
|
||||
focal_length_mm=35,
|
||||
sensor_width_mm=23.5,
|
||||
sensor_height_mm=15.6,
|
||||
image_width_pixels=1920,
|
||||
image_height_pixels=1080
|
||||
)
|
||||
|
||||
# 创建应用
|
||||
app = create_app(camera, detector, estimator)
|
||||
|
||||
# 启动服务
|
||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
@ -20,10 +20,15 @@ class RTSPCamera:
|
||||
def _update_frame(self):
|
||||
cap = cv2.VideoCapture(self.rtsp_url)
|
||||
while self.running:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
self.frame = frame
|
||||
time.sleep(1/self.fps)
|
||||
if cap is None:
|
||||
time.sleep(5)
|
||||
cap = cv2.VideoCapture(self.rtsp_url)
|
||||
else:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
self.frame = frame
|
||||
# 若视频流是实时的 cv2捕获的视频帧是当前时刻的帧.
|
||||
time.sleep(1/self.fps)
|
||||
|
||||
def get_frame(self):
|
||||
return self.frame.copy() if self.frame is not None else None
|
||||
|
||||
3
test_api_server.py
Normal file
3
test_api_server.py
Normal file
@ -0,0 +1,3 @@
|
||||
from src.api_server import main
|
||||
|
||||
main()
|
||||
Binary file not shown.
BIN
tests/__pycache__/test_api_server.cpython-39.pyc
Normal file
BIN
tests/__pycache__/test_api_server.cpython-39.pyc
Normal file
Binary file not shown.
@ -2,7 +2,7 @@ import asyncio
|
||||
from .test_camera import test_rtsp_camera
|
||||
from .test_person_detector import test_person_detector
|
||||
from .test_distance_estimator import test_distance_estimator
|
||||
from .test_api import test_distance_api
|
||||
from .test_api_server import test_run
|
||||
|
||||
async def run_all_tests():
|
||||
print("开始运行所有测试...")
|
||||
@ -12,11 +12,11 @@ async def run_all_tests():
|
||||
# print("\n2. 测试人物检测模块")
|
||||
# test_person_detector()
|
||||
|
||||
print("\n3. 测试距离估算模块")
|
||||
test_distance_estimator()
|
||||
# print("\n3. 测试距离估算模块")
|
||||
# test_distance_estimator()
|
||||
|
||||
# print("\n4. 测试API模块")
|
||||
# await test_distance_api()
|
||||
print("\n4. 测试API模块")
|
||||
test_run()
|
||||
|
||||
print("\n所有测试完成!")
|
||||
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from src.api_server import DistanceAPI, Distance
|
||||
from src.camera_handler import RTSPCamera
|
||||
from src.person_detector import PersonDetector
|
||||
from src.distance_estimator import DistanceEstimator
|
||||
|
||||
async def test_distance_api():
|
||||
# 初始化所有组件
|
||||
camera = RTSPCamera(0) # 使用本地摄像头测试
|
||||
detector = PersonDetector("yolov8n.pt")
|
||||
estimator = DistanceEstimator(
|
||||
focal_length=35,
|
||||
sensor_height=24,
|
||||
avg_person_height=1700
|
||||
)
|
||||
|
||||
# 初始化API
|
||||
api = DistanceAPI(camera, detector, estimator)
|
||||
|
||||
print("测试API...")
|
||||
camera.start()
|
||||
await asyncio.sleep(2) # 等待摄像头初始化
|
||||
|
||||
try:
|
||||
distances = await api.get_distances()
|
||||
assert isinstance(distances, list), "返回结果应该是列表"
|
||||
for distance in distances:
|
||||
assert isinstance(distance, Distance), "返回结果应该是Distance对象"
|
||||
assert distance.distance_mm > 0, "距离应该大于0"
|
||||
assert 0 <= distance.confidence <= 1, "置信度应该在0-1之间"
|
||||
|
||||
print(f"检测到 {len(distances)} 个人物的距离")
|
||||
for i, d in enumerate(distances):
|
||||
print(f"人物 {i+1}: 距离 = {d.distance_mm:.2f}mm, 置信度 = {d.confidence:.2f}")
|
||||
|
||||
finally:
|
||||
camera.stop()
|
||||
|
||||
print("API测试完成!")
|
||||
@ -1,40 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
class Distance(BaseModel):
|
||||
distance_mm: float
|
||||
confidence: float
|
||||
detection_image_path: Optional[str] = None
|
||||
|
||||
class DistanceAPI:
|
||||
def __init__(self, camera, detector, estimator):
|
||||
self.camera = camera
|
||||
self.detector = detector
|
||||
self.estimator = estimator
|
||||
|
||||
async def get_distances(self) -> List[Distance]:
|
||||
frame = self.camera.get_frame()
|
||||
if frame is None:
|
||||
raise HTTPException(status_code=500, detail="No frame available")
|
||||
|
||||
# 检测人物并保存可视化结果
|
||||
persons, vis_path = self.detector.detect(
|
||||
frame,
|
||||
save_visualization=True,
|
||||
save_path="output/detections"
|
||||
)
|
||||
|
||||
distances = []
|
||||
for person in persons:
|
||||
bbox = person['bbox']
|
||||
person_height = bbox[3] - bbox[1] # y2 - y1
|
||||
distance = self.estimator.estimate_distance(person_height, frame.shape[0])
|
||||
|
||||
distances.append(Distance(
|
||||
distance_mm=distance,
|
||||
confidence=person['confidence'],
|
||||
detection_image_path=vis_path
|
||||
))
|
||||
|
||||
return distances
|
||||
Loading…
Reference in New Issue
Block a user