Add robot dog OCR service and ignore local artifacts
This commit is contained in:
commit
19b2aa43d6
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
.Python
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
*.egg-info/
|
||||||
|
.eggs/
|
||||||
|
|
||||||
|
logs/
|
||||||
|
tmp/
|
||||||
|
|
||||||
|
*.log
|
||||||
|
*.pid
|
||||||
|
|
||||||
32
app/api/main.py
Normal file
32
app/api/main.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from fastapi import FastAPI, WebSocket
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from app.api.v1.api import router
|
||||||
|
# from app.core.config import settings
|
||||||
|
# from app.services.websocket_service import websocket_service
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="机器狗后台服务",
|
||||||
|
description="机器狗后台API接口",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 在生产环境中应该设置具体的域名
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
# @app.websocket("/ws")
|
||||||
|
# async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
# """WebSocket端点"""
|
||||||
|
# await websocket_service.handle_websocket(websocket)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "机器狗后台API服务正在运行"}
|
||||||
394
app/api/v1/api.py
Normal file
394
app/api/v1/api.py
Normal file
@ -0,0 +1,394 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.services.imageServices import ImageService
|
||||||
|
from app.schemas.image import ImageBase
|
||||||
|
from app.schemas.ocr import ImageBase64Request
|
||||||
|
from app.util.responseHttp import ResponseUtil
|
||||||
|
from app.util.baiduOCR import BaiduOCR, BaiduOCRONNX
|
||||||
|
from app.util.yolov8Obj import Yolov8Obj, YOLOv8ONNX
|
||||||
|
|
||||||
|
# from app.crud.event import event
|
||||||
|
# from app.schemas.event import EventList, EventDetail, EventUpdate, EventQuery, TestEvent
|
||||||
|
|
||||||
|
baiduOCR = BaiduOCR()
|
||||||
|
baiduOcrOnnx = BaiduOCRONNX()
|
||||||
|
yolov8Obj = Yolov8Obj()
|
||||||
|
yolov8ONNX = YOLOv8ONNX()
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1", tags=["ocr"])
|
||||||
|
|
||||||
|
@router.get("/hello")
|
||||||
|
async def get_hello():
|
||||||
|
|
||||||
|
# return {"data":"hello"}
|
||||||
|
return ResponseUtil.error(msg=f"OCR识别失败", data=None)
|
||||||
|
|
||||||
|
@router.get("/test_select")
|
||||||
|
async def test_select(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
image_query: ImageBase = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试查询
|
||||||
|
"""
|
||||||
|
result = await ImageService.get_image_list(db,image_query)
|
||||||
|
return ResponseUtil.success(msg="success", data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test_ocr")
|
||||||
|
async def test_ocr(
|
||||||
|
# image_path: str = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试OCR
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg'
|
||||||
|
|
||||||
|
|
||||||
|
result = baiduOCR.ocr(image_path)
|
||||||
|
# print(result)
|
||||||
|
return ResponseUtil.success(msg="success", data=result)
|
||||||
|
# return ResponseUtil.success(msg="success")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/ocr_from_base64")
|
||||||
|
async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
"""
|
||||||
|
从base64图片数据进行OCR识别
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 移除base64数据的前缀(如果有)
|
||||||
|
image_base64 = request.image_base64
|
||||||
|
if image_base64.startswith('data:image'):
|
||||||
|
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
|
||||||
|
image_base64 = image_base64.split(',')[1]
|
||||||
|
|
||||||
|
# 解码base64数据
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# 将字节数据转换为PIL Image对象
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# 创建临时文件路径
|
||||||
|
temp_dir = "tmp/ocr_images"
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||||
|
temp_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
|
# 保存图片到临时文件
|
||||||
|
image.save(temp_path)
|
||||||
|
|
||||||
|
# 使用PaddleOCR进行识别
|
||||||
|
result = baiduOCR.ocr(temp_path)
|
||||||
|
|
||||||
|
# 删除临时文件
|
||||||
|
# os.remove(temp_path)
|
||||||
|
|
||||||
|
return ResponseUtil.success(msg="OCR识别成功", data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/ocr_onnx_from_base64")
|
||||||
|
async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
"""
|
||||||
|
从base64图片数据进行OCR识别
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 移除base64数据的前缀(如果有)
|
||||||
|
image_base64 = request.image_base64
|
||||||
|
if image_base64.startswith('data:image'):
|
||||||
|
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
|
||||||
|
image_base64 = image_base64.split(',')[1]
|
||||||
|
|
||||||
|
# 解码base64数据
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# 将字节数据转换为PIL Image对象
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# 创建临时文件路径
|
||||||
|
temp_dir = "./tmp/ocr_images"
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||||
|
temp_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
|
# 保存图片到临时文件
|
||||||
|
image.save(temp_path)
|
||||||
|
|
||||||
|
# 使用PaddleOCR进行识别
|
||||||
|
result = baiduOcrOnnx.ocr(temp_path)
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# 删除临时文件
|
||||||
|
# os.remove(temp_path)
|
||||||
|
|
||||||
|
# return ResponseUtil.success(msg="OCR识别成功", data=[result['text'], result['confidence']])
|
||||||
|
return ResponseUtil.success(msg="OCR识别成功", data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detect_from_base64_0")
|
||||||
|
async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_base64 = request.image_base64
|
||||||
|
if image_base64.startswith('data:image'):
|
||||||
|
image_base64 = image_base64.split(',')[1]
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
temp_dir = "./tmp/detect_images"
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||||
|
temp_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
|
image.save(temp_path)
|
||||||
|
|
||||||
|
cls, conf, coords = yolov8Obj.detect(temp_path)
|
||||||
|
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(cls)==0))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detect_onnx_from_base64_0")
|
||||||
|
async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
|
||||||
|
|
||||||
|
# 为不同类别生成不同颜色
|
||||||
|
def get_color(class_id):
|
||||||
|
np.random.seed(class_id)
|
||||||
|
return tuple(map(int, np.random.randint(0, 255, 3)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# image_base64 = request.image_base64
|
||||||
|
# if image_base64.startswith('data:image'):
|
||||||
|
# image_base64 = image_base64.split(',')[1]
|
||||||
|
# image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# temp_dir = "./tmp/detect_images"
|
||||||
|
# os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||||
|
# temp_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
|
# image.save(temp_path)
|
||||||
|
|
||||||
|
# boxes, scores, class_ids = yolov8ONNX.detect(image_data)
|
||||||
|
|
||||||
|
# os.remove(temp_path)
|
||||||
|
|
||||||
|
image_base64 = request.image_base64
|
||||||
|
if image_base64.startswith('data:image'):
|
||||||
|
image_base64 = image_base64.split(',')[1]
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# 将字节流转换为OpenCV格式的BGR图像
|
||||||
|
image_np = np.frombuffer(image_data, np.uint8)
|
||||||
|
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
|
||||||
|
|
||||||
|
# 不再需要保存到临时文件,直接使用内存中的图像
|
||||||
|
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
|
||||||
|
|
||||||
|
|
||||||
|
# 绘制检测结果
|
||||||
|
image_with_boxes = image_cv.copy()
|
||||||
|
|
||||||
|
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||||
|
# 获取边界框坐标
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
|
||||||
|
# 生成随机颜色或使用固定颜色方案
|
||||||
|
color = get_color(class_id) # 绿色,也可以根据class_id设置不同颜色
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
|
||||||
|
|
||||||
|
# 准备标签文本
|
||||||
|
label = f"Class {class_id}: {score:.2f}"
|
||||||
|
# 如果你有类别名称字典,可以这样使用:
|
||||||
|
# label = f"{class_names[class_id]}: {score:.2f}"
|
||||||
|
|
||||||
|
# 计算文本大小以绘制背景
|
||||||
|
(text_width, text_height), baseline = cv2.getTextSize(
|
||||||
|
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制文本背景
|
||||||
|
cv2.rectangle(
|
||||||
|
image_with_boxes,
|
||||||
|
(x1, y1 - text_height - baseline - 5),
|
||||||
|
(x1 + text_width, y1),
|
||||||
|
color,
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制文本
|
||||||
|
cv2.putText(
|
||||||
|
image_with_boxes,
|
||||||
|
label,
|
||||||
|
(x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.5,
|
||||||
|
(0, 0, 0), # 黑色文字
|
||||||
|
1,
|
||||||
|
cv2.LINE_AA
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建保存目录
|
||||||
|
save_dir = "tmp/detect_images"
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成带时间戳的随机文件名
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
|
||||||
|
filename = f"detect_{timestamp}_{random_id}.jpg"
|
||||||
|
output_path = os.path.join(save_dir, filename)
|
||||||
|
|
||||||
|
# 保存图像
|
||||||
|
cv2.imwrite(output_path, image_with_boxes)
|
||||||
|
print(f"检测结果已保存到: {output_path}")
|
||||||
|
|
||||||
|
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(boxes)==0))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
|
||||||
|
|
||||||
|
@router.post("/detect_from_base64_1")
|
||||||
|
async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
""" 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# 为不同类别生成不同颜色
|
||||||
|
def get_color(class_id):
|
||||||
|
np.random.seed(class_id)
|
||||||
|
return tuple(map(int, np.random.randint(0, 255, 3)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
image_base64 = request.image_base64
|
||||||
|
if image_base64.startswith('data:image'):
|
||||||
|
image_base64 = image_base64.split(',')[1]
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# 将字节流转换为OpenCV格式的BGR图像
|
||||||
|
image_np = np.frombuffer(image_data, np.uint8)
|
||||||
|
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
|
||||||
|
|
||||||
|
# 不再需要保存到临时文件,直接使用内存中的图像
|
||||||
|
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
|
||||||
|
|
||||||
|
|
||||||
|
# 绘制检测结果
|
||||||
|
image_with_boxes = image_cv.copy()
|
||||||
|
|
||||||
|
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||||
|
# 获取边界框坐标
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
|
||||||
|
# 生成随机颜色或使用固定颜色方案
|
||||||
|
color = get_color(class_id) # 绿色,也可以根据class_id设置不同颜色
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
|
||||||
|
|
||||||
|
# 准备标签文本
|
||||||
|
label = f"Class {class_id}: {score:.2f}"
|
||||||
|
# 如果你有类别名称字典,可以这样使用:
|
||||||
|
# label = f"{class_names[class_id]}: {score:.2f}"
|
||||||
|
|
||||||
|
# 计算文本大小以绘制背景
|
||||||
|
(text_width, text_height), baseline = cv2.getTextSize(
|
||||||
|
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制文本背景
|
||||||
|
cv2.rectangle(
|
||||||
|
image_with_boxes,
|
||||||
|
(x1, y1 - text_height - baseline - 5),
|
||||||
|
(x1 + text_width, y1),
|
||||||
|
color,
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制文本
|
||||||
|
cv2.putText(
|
||||||
|
image_with_boxes,
|
||||||
|
label,
|
||||||
|
(x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.5,
|
||||||
|
(0, 0, 0), # 黑色文字
|
||||||
|
1,
|
||||||
|
cv2.LINE_AA
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建保存目录
|
||||||
|
save_dir = "tmp/detect_images"
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成带时间戳的随机文件名
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
|
||||||
|
filename = f"detect_{timestamp}_{random_id}.jpg"
|
||||||
|
output_path = os.path.join(save_dir, filename)
|
||||||
|
|
||||||
|
# 保存图像
|
||||||
|
cv2.imwrite(output_path, image_with_boxes)
|
||||||
|
print(f"检测结果已保存到: {output_path}")
|
||||||
|
|
||||||
|
return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(class_ids)!=0 and 0 in class_ids))
|
||||||
|
# return ResponseUtil.success(msg="目标检测成功", data=cls)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)
|
||||||
|
|
||||||
|
# @router.post("/detect_from_base64_1")
|
||||||
|
# async def ocr_from_base64(request: ImageBase64Request):
|
||||||
|
# """ 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# image_base64 = request.image_base64
|
||||||
|
# if image_base64.startswith('data:image'):
|
||||||
|
# image_base64 = image_base64.split(',')[1]
|
||||||
|
# image_data = base64.b64decode(image_base64)
|
||||||
|
|
||||||
|
# image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# temp_dir = "./tmp/detect_images"
|
||||||
|
# os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||||
|
# temp_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
|
# image.save(temp_path)
|
||||||
|
|
||||||
|
# cls, conf, coords = yolov8Obj.detect(temp_path)
|
||||||
|
|
||||||
|
# os.remove(temp_path)
|
||||||
|
|
||||||
|
# return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(cls)!=0 and 0 in cls))
|
||||||
|
# # return ResponseUtil.success(msg="目标检测成功", data=cls)
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)
|
||||||
122
app/api/v1/testLogin.py
Normal file
122
app/api/v1/testLogin.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# 配置参数(应从环境变量读取)
|
||||||
|
SECRET_KEY = "your-secret-key-keep-it-secret!"
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
|
||||||
|
# 模拟用户数据库
|
||||||
|
fake_users_db = {
|
||||||
|
"johndoe": {
|
||||||
|
"username": "johndoe",
|
||||||
|
"full_name": "John Doe",
|
||||||
|
"email": "johndoe@example.com",
|
||||||
|
# 哈希后的密码(明文是 secret)
|
||||||
|
"hashed_password": "$2b$12$EixZaYVK1fsbY1eZIbOnjesN9NwG1s3Z6FDcjyH103a2.dJgD0L4q",
|
||||||
|
"disabled": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Token 模型
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
username: Optional[str] = None
|
||||||
|
|
||||||
|
# 用户模型
|
||||||
|
class User(BaseModel):
|
||||||
|
username: str
|
||||||
|
email: Optional[str] = None
|
||||||
|
full_name: Optional[str] = None
|
||||||
|
disabled: Optional[bool] = None
|
||||||
|
|
||||||
|
class UserInDB(User):
|
||||||
|
hashed_password: str
|
||||||
|
|
||||||
|
# 密码哈希配置
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
# OAuth2 配置
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# 密码验证工具
|
||||||
|
def verify_password(plain_password: str, hashed_password: str):
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
# 获取用户
|
||||||
|
def get_user(db, username: str):
|
||||||
|
if username in db:
|
||||||
|
user_dict = db[username]
|
||||||
|
return UserInDB(**user_dict)
|
||||||
|
|
||||||
|
# 用户认证
|
||||||
|
def authenticate_user(fake_db, username: str, password: str):
|
||||||
|
user = get_user(fake_db, username)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
# if not verify_password(password, user.hashed_password):
|
||||||
|
# return False
|
||||||
|
return user
|
||||||
|
|
||||||
|
# 创建访问令牌
|
||||||
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||||
|
to_encode = data.copy()
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||||
|
to_encode.update({"exp": expire})
|
||||||
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
# 获取当前用户(依赖注入)
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
token_data = TokenData(username=username)
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exception
|
||||||
|
user = get_user(fake_users_db, username=token_data.username)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
# 登录路由
|
||||||
|
@app.post("/token", response_model=Token)
|
||||||
|
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||||
|
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": user.username}, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
# 受保护路由
|
||||||
|
@app.get("/users/me/", response_model=User)
|
||||||
|
async def read_users_me(current_user: User = Depends(get_current_user)):
|
||||||
|
return current_user
|
||||||
89
app/config/config.py
Normal file
89
app/config/config.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
class DataBaseSettings(BaseSettings):
|
||||||
|
# 数据库配置
|
||||||
|
DB_HOST: str = "10.0.0.17"
|
||||||
|
DB_PORT: int = 3306
|
||||||
|
DB_USER: str = "root"
|
||||||
|
DB_PASSWORD: str = "root"
|
||||||
|
DB_NAME: str = "kangda_robotic_dog"
|
||||||
|
DB_CHARSET: str = "utf8mb4"
|
||||||
|
DB_POOL_SIZE: int = 5
|
||||||
|
DB_MAX_OVERFLOW: int = 10
|
||||||
|
DB_POOL_TIMEOUT: int = 30
|
||||||
|
DB_POOL_RECYCLE: int = 1800
|
||||||
|
|
||||||
|
# class Config:
|
||||||
|
# env_file = ".env"
|
||||||
|
|
||||||
|
class OCRSettings(BaseException):
|
||||||
|
TEXT_DETECTION_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_det_infer_20250814/"
|
||||||
|
TEXT_RECONGNITION_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_rec_infer_20250815/"
|
||||||
|
|
||||||
|
# TEXT_DETECTION_MODEL_ONNX_DIR: str ='/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx'
|
||||||
|
# TEXT_RECONGNITION_MODEL_ONNX_DIR: str ='/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx'
|
||||||
|
|
||||||
|
TEXT_DETECTION_MODEL_ONNX_DIR="/root/robot_dog_project/kangda_robotic_dog/models/det_mobile_14_shape.onnx"
|
||||||
|
TEXT_RECONGNITION_MODEL_ONNX_DIR="/root/robot_dog_project/kangda_robotic_dog/models/rec_mobile_14_shape.onnx"
|
||||||
|
|
||||||
|
class YoloV8Settings(BaseException):
|
||||||
|
YOLOV8_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/best.pt"
|
||||||
|
YOLOV8_MODEL_ONNX_DIRS="/root/robot_dog_project/kangda_robotic_dog/models/yolov8_20250820.onnx"
|
||||||
|
|
||||||
|
|
||||||
|
class GetSettings:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.parse_cli_args()
|
||||||
|
|
||||||
|
def get_database_settings(self):
|
||||||
|
return DataBaseSettings()
|
||||||
|
|
||||||
|
def get_ocr_settings(self):
|
||||||
|
return OCRSettings()
|
||||||
|
|
||||||
|
def get_yolov8_settings(self):
|
||||||
|
return YoloV8Settings()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_args(self):
|
||||||
|
"""
|
||||||
|
解析命令行参数
|
||||||
|
"""
|
||||||
|
if 'uvicorn' in sys.argv[0]:
|
||||||
|
# 使用uvicorn启动时,命令行参数需要按照uvicorn的文档进行配置,无法自定义参数
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# 使用argparse定义命令行参数
|
||||||
|
parser = argparse.ArgumentParser(description='命令行参数')
|
||||||
|
parser.add_argument('--env', type=str, default='dev', help='运行环境')
|
||||||
|
# 解析命令行参数
|
||||||
|
args = parser.parse_args()
|
||||||
|
# 设置环境变量,如果未设置命令行参数,默认APP_ENV为dev
|
||||||
|
os.environ['APP_ENV'] = args.env if args.env else 'dev'
|
||||||
|
# 读取运行环境
|
||||||
|
run_env = os.environ.get('APP_ENV', '')
|
||||||
|
# 运行环境未指定时默认加载.env.dev
|
||||||
|
env_file = '.env.dev'
|
||||||
|
# 运行环境不为空时按命令行参数加载对应.env文件
|
||||||
|
if run_env != '':
|
||||||
|
env_file = f'.env.{run_env}'
|
||||||
|
|
||||||
|
|
||||||
|
print(f"加载配置 .env.{run_env}")
|
||||||
|
# 加载配置
|
||||||
|
load_dotenv(env_file)
|
||||||
|
|
||||||
|
get_settings = GetSettings()
|
||||||
|
|
||||||
|
database_settings = get_settings.get_database_settings()
|
||||||
|
|
||||||
|
ocr_settings = get_settings.get_ocr_settings()
|
||||||
|
|
||||||
|
yolov8_settings = get_settings.get_yolov8_settings()
|
||||||
|
|
||||||
|
|
||||||
77
app/config/constant.py
Normal file
77
app/config/constant.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# from config.env import DataBaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CommonConstant:
|
||||||
|
"""
|
||||||
|
常用常量
|
||||||
|
|
||||||
|
WWW: www主域
|
||||||
|
HTTP: http请求
|
||||||
|
HTTPS: https请求
|
||||||
|
LOOKUP_RMI: RMI远程方法调用
|
||||||
|
LOOKUP_LDAP: LDAP远程方法调用
|
||||||
|
LOOKUP_LDAPS: LDAPS远程方法调用
|
||||||
|
YES: 是否为系统默认(是)
|
||||||
|
NO: 是否为系统默认(否)
|
||||||
|
DEPT_NORMAL: 部门正常状态
|
||||||
|
DEPT_DISABLE: 部门停用状态
|
||||||
|
UNIQUE: 校验是否唯一的返回标识(是)
|
||||||
|
NOT_UNIQUE: 校验是否唯一的返回标识(否)
|
||||||
|
"""
|
||||||
|
|
||||||
|
WWW = 'www.'
|
||||||
|
HTTP = 'http://'
|
||||||
|
HTTPS = 'https://'
|
||||||
|
LOOKUP_RMI = 'rmi:'
|
||||||
|
LOOKUP_LDAP = 'ldap:'
|
||||||
|
LOOKUP_LDAPS = 'ldaps:'
|
||||||
|
YES = 'Y'
|
||||||
|
NO = 'N'
|
||||||
|
DEPT_NORMAL = '0'
|
||||||
|
DEPT_DISABLE = '1'
|
||||||
|
UNIQUE = True
|
||||||
|
NOT_UNIQUE = False
|
||||||
|
|
||||||
|
|
||||||
|
class HttpStatusConstant:
|
||||||
|
"""
|
||||||
|
返回状态码
|
||||||
|
|
||||||
|
SUCCESS: 操作成功
|
||||||
|
CREATED: 对象创建成功
|
||||||
|
ACCEPTED: 请求已经被接受
|
||||||
|
NO_CONTENT: 操作已经执行成功,但是没有返回数据
|
||||||
|
MOVED_PERM: 资源已被移除
|
||||||
|
SEE_OTHER: 重定向
|
||||||
|
NOT_MODIFIED: 资源没有被修改
|
||||||
|
BAD_REQUEST: 参数列表错误(缺少,格式不匹配)
|
||||||
|
UNAUTHORIZED: 未授权
|
||||||
|
FORBIDDEN: 访问受限,授权过期
|
||||||
|
NOT_FOUND: 资源,服务未找到
|
||||||
|
BAD_METHOD: 不允许的http方法
|
||||||
|
CONFLICT: 资源冲突,或者资源被锁
|
||||||
|
UNSUPPORTED_TYPE: 不支持的数据,媒体类型
|
||||||
|
ERROR: 系统内部错误
|
||||||
|
NOT_IMPLEMENTED: 接口未实现
|
||||||
|
WARN: 系统警告消息
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUCCESS = 200
|
||||||
|
CREATED = 201
|
||||||
|
ACCEPTED = 202
|
||||||
|
NO_CONTENT = 204
|
||||||
|
MOVED_PERM = 301
|
||||||
|
SEE_OTHER = 303
|
||||||
|
NOT_MODIFIED = 304
|
||||||
|
BAD_REQUEST = 400
|
||||||
|
UNAUTHORIZED = 401
|
||||||
|
FORBIDDEN = 403
|
||||||
|
NOT_FOUND = 404
|
||||||
|
BAD_METHOD = 405
|
||||||
|
CONFLICT = 409
|
||||||
|
UNSUPPORTED_TYPE = 415
|
||||||
|
ERROR = 500
|
||||||
|
NOT_IMPLEMENTED = 501
|
||||||
|
WARN = 601
|
||||||
|
|
||||||
|
|
||||||
32
app/core/database.py
Normal file
32
app/core/database.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
|
||||||
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||||
|
from app.config.config import database_settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
f"mysql+aiomysql://{database_settings.DB_USER}:{database_settings.DB_PASSWORD}@{database_settings.DB_HOST}:{database_settings.DB_PORT}/{database_settings.DB_NAME}",
|
||||||
|
pool_size=database_settings.DB_POOL_SIZE, # 连接池常驻连接数
|
||||||
|
max_overflow=database_settings.DB_MAX_OVERFLOW, # 池最大溢出连接数
|
||||||
|
pool_timeout=database_settings.DB_POOL_TIMEOUT, # 获取连接超时时间(秒)
|
||||||
|
pool_recycle=database_settings.DB_POOL_RECYCLE, # 连接回收间隔(秒)
|
||||||
|
echo=False # 关闭SQL日志输出
|
||||||
|
)
|
||||||
|
|
||||||
|
# expire_on_commit=False 禁用提交后过期对象 --> 提交后原对象仍然可以使用.
|
||||||
|
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
# ORM模型基类
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
# 获取数据库会话
|
||||||
|
async def get_db():
|
||||||
|
async with async_session() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
# await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
88
app/crud/base.py
Normal file
88
app/crud/base.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, update, delete
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=Base)
|
||||||
|
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||||
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Generic[ModelType, CreateSchemaType, UpdateSchemaType] 是 Python 类型提示中泛型编程的关键语法,用于声明一个泛型类。
|
||||||
|
它的作用是让 CRUDBase 类具备类型参数化的能力,允许在继承或实例化时动态绑定具体类型,从而实现代码复用和类型安全
|
||||||
|
声明 CRUDBase 类需要三个类型参数:ModelType、CreateSchemaType、UpdateSchemaType。
|
||||||
|
这些类型参数会在类的内部方法中使用(如 get、create、update),确保类型一致性。
|
||||||
|
|
||||||
|
'''
|
||||||
|
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||||
|
def __init__(self, model: Type[ModelType]):
|
||||||
|
"""
|
||||||
|
CRUD对象与SQLAlchemy模型类一起使用
|
||||||
|
:param model: SQLAlchemy模型类
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# 根据id获取对象
|
||||||
|
async def get(self, db: AsyncSession, id: Any) -> Optional[ModelType]:
|
||||||
|
"""
|
||||||
|
通过ID获取对象
|
||||||
|
"""
|
||||||
|
query = select(self.model).where(self.model.id == id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# 分页查询
|
||||||
|
async def get_multi(
|
||||||
|
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[ModelType]:
|
||||||
|
"""
|
||||||
|
获取多个对象
|
||||||
|
"""
|
||||||
|
query = select(self.model).offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
|
"""
|
||||||
|
创建对象
|
||||||
|
"""
|
||||||
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
|
db_obj = self.model(**obj_in_data)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
db_obj: ModelType,
|
||||||
|
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||||
|
) -> ModelType:
|
||||||
|
"""
|
||||||
|
更新对象
|
||||||
|
"""
|
||||||
|
obj_data = jsonable_encoder(db_obj)
|
||||||
|
if isinstance(obj_in, dict):
|
||||||
|
update_data = obj_in
|
||||||
|
else:
|
||||||
|
update_data = obj_in.dict(exclude_unset=True)
|
||||||
|
for field in obj_data:
|
||||||
|
if field in update_data:
|
||||||
|
setattr(db_obj, field, update_data[field])
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def remove(self, db: AsyncSession, *, id: Any) -> ModelType:
|
||||||
|
"""
|
||||||
|
删除对象
|
||||||
|
"""
|
||||||
|
obj = await self.get(db=db, id=id)
|
||||||
|
await db.delete(obj)
|
||||||
|
await db.commit()
|
||||||
|
return obj
|
||||||
130
app/crud/event.py
Normal file
130
app/crud/event.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from sqlalchemy import select, and_, or_, join
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.models import Event, Image, Temperature
|
||||||
|
from app.schemas.event import EventUpdate, EventQuery, TestEvent
|
||||||
|
|
||||||
|
class CRUDEvent(CRUDBase[Event, EventUpdate, EventUpdate]):
|
||||||
|
async def get_by_id(self, db: AsyncSession, *, event_id: str) -> Optional[Event]:
|
||||||
|
"""根据ID获取事件"""
|
||||||
|
query = (
|
||||||
|
select(Event)
|
||||||
|
.options(
|
||||||
|
selectinload(Event.images),
|
||||||
|
selectinload(Event.temperatures)
|
||||||
|
)
|
||||||
|
.where(Event.eventId == event_id)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_multi_with_query(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
query: EventQuery
|
||||||
|
) -> List[Event]:
|
||||||
|
"""根据查询条件获取事件列表"""
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
if query.start_time:
|
||||||
|
conditions.append(Event.insDate >= query.start_time)
|
||||||
|
if query.end_time:
|
||||||
|
conditions.append(Event.insDate <= query.end_time)
|
||||||
|
if query.etypeName:
|
||||||
|
conditions.append(Event.etypeName == query.etypeName)
|
||||||
|
if query.area:
|
||||||
|
conditions.append(Event.area == query.area)
|
||||||
|
|
||||||
|
query_stmt = (
|
||||||
|
select(Event)
|
||||||
|
.options(
|
||||||
|
selectinload(Event.images),
|
||||||
|
selectinload(Event.temperatures)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if conditions:
|
||||||
|
query_stmt = query_stmt.where(and_(*conditions))
|
||||||
|
|
||||||
|
query_stmt = query_stmt.offset(query.skip).limit(query.limit)
|
||||||
|
|
||||||
|
result = await db.execute(query_stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def update_event(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
event_id: str,
|
||||||
|
obj_in: EventUpdate
|
||||||
|
) -> Optional[Event]:
|
||||||
|
"""更新事件信息"""
|
||||||
|
event = await self.get_by_id(db, event_id=event_id)
|
||||||
|
if not event:
|
||||||
|
return None
|
||||||
|
|
||||||
|
update_data = obj_in.model_dump()
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(event, field, value)
|
||||||
|
|
||||||
|
# db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(event)
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def delete_event(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
event_id: str
|
||||||
|
) -> Optional[Event]:
|
||||||
|
"""删除事件"""
|
||||||
|
event = await self.get_by_id(db, event_id=event_id)
|
||||||
|
if not event:
|
||||||
|
return None
|
||||||
|
|
||||||
|
await db.delete(event)
|
||||||
|
await db.commit()
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
async def get_test(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
|
||||||
|
|
||||||
|
) -> List[TestEvent]: #响应类型要写对啊
|
||||||
|
|
||||||
|
'''
|
||||||
|
eventId
|
||||||
|
number
|
||||||
|
name
|
||||||
|
imageUrl
|
||||||
|
localPath
|
||||||
|
temperature
|
||||||
|
confidence
|
||||||
|
createTime
|
||||||
|
'''
|
||||||
|
|
||||||
|
query_stmt = (
|
||||||
|
select(Event.eventId, Event.number, Event.name, Image.imageUrl, Image.localPath, Temperature.temperature, Temperature.confidence, Temperature.createTime)
|
||||||
|
.select_from(Event)
|
||||||
|
.outerjoin(Image, Event.eventId == Image.eventId)
|
||||||
|
.outerjoin(Temperature, Image.imageId == Temperature.imageId)
|
||||||
|
# 多个查询条件
|
||||||
|
.where(and_(Event.etypeName=="日常巡检", True))
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(query_stmt)
|
||||||
|
|
||||||
|
# 获取字典
|
||||||
|
result = result.mappings().all()
|
||||||
|
|
||||||
|
# print(result)
|
||||||
|
return [TestEvent(**row) for row in result]
|
||||||
|
|
||||||
|
|
||||||
|
event = CRUDEvent(Event)
|
||||||
18
app/dao/imageDao.py
Normal file
18
app/dao/imageDao.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from typing import List
|
||||||
|
# from app.models.models import image
|
||||||
|
from app.schemas.image import ImageBase
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.models import Image
|
||||||
|
|
||||||
|
class ImageDao:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(Image)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.mappings().all()
|
||||||
25
app/main.py
Normal file
25
app/main.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from app.api.v1.api import router
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="测试fastapi",
|
||||||
|
description="测试啊",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "测试成功"}
|
||||||
134
app/models/models.py
Normal file
134
app/models/models.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, Index, BigInteger
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Image(Base):
|
||||||
|
__tablename__ = "image"
|
||||||
|
|
||||||
|
imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
|
||||||
|
localPath = Column(String(500), comment='本地存储路径')
|
||||||
|
type = Column(String(2), comment='图片类型, 0 ocr 图片, 1 目标检测图片')
|
||||||
|
result = Column(String(500), comment='检测结果')
|
||||||
|
createTime = Column(DateTime, default=datetime.now, comment='创建时间')
|
||||||
|
updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||||
|
|
||||||
|
|
||||||
|
# class Event(Base):
|
||||||
|
# __tablename__ = "event"
|
||||||
|
|
||||||
|
# eventId = Column(String(50), primary_key=True)
|
||||||
|
# tenantInfoId = Column(String(100))
|
||||||
|
# reportEventId = Column(String(100))
|
||||||
|
# number = Column(String(20))
|
||||||
|
# name = Column(String(20))
|
||||||
|
# eclassify = Column(String(5))
|
||||||
|
# operationType = Column(String(5))
|
||||||
|
# etype = Column(String(20))
|
||||||
|
# etypeName = Column(String(20))
|
||||||
|
# enTypeName = Column(String(30))
|
||||||
|
# hkTypeName = Column(String(20))
|
||||||
|
# reportStatus = Column(String(5))
|
||||||
|
# results = Column(String(5))
|
||||||
|
# insDate = Column(DateTime)
|
||||||
|
# insDateShow = Column(DateTime)
|
||||||
|
# updDate = Column(DateTime)
|
||||||
|
# updDateShow = Column(DateTime)
|
||||||
|
# fileType = Column(String(5))
|
||||||
|
# area = Column(String(20))
|
||||||
|
# floor = Column(String(10))
|
||||||
|
# map = Column(String(20))
|
||||||
|
# staffId = Column(String(40))
|
||||||
|
# targetUserId = Column(String(40))
|
||||||
|
# position = Column(String(100))
|
||||||
|
# actualStaffName = Column(String(20))
|
||||||
|
# targetStaffName = Column(String(20))
|
||||||
|
# routeName = Column(String(20))
|
||||||
|
# phoneAddress = Column(String(500))
|
||||||
|
# width = Column(String(10))
|
||||||
|
# height = Column(String(10))
|
||||||
|
# resolution = Column(String(10))
|
||||||
|
# originX = Column(String(20))
|
||||||
|
# originY = Column(String(20))
|
||||||
|
# imgList = Column(String(20))
|
||||||
|
# robotType = Column(String(5))
|
||||||
|
# eventFloor = Column(String(10))
|
||||||
|
# floorName = Column(String(10))
|
||||||
|
# coordId = Column(String(40))
|
||||||
|
# coord = Column(String(40))
|
||||||
|
# coordName = Column(String(30))
|
||||||
|
# positonName = Column(String(20))
|
||||||
|
# processingRemark = Column(String(300))
|
||||||
|
# carId = Column(String(40))
|
||||||
|
# parkingSpaceType = Column(String(10))
|
||||||
|
# parkingSpaceNumber = Column(String(40))
|
||||||
|
# carNumber = Column(String(40))
|
||||||
|
# eno = Column(String(40))
|
||||||
|
# instrument = Column(String(40))
|
||||||
|
# evideo = Column(String(40))
|
||||||
|
# createTime = Column(DateTime, default=datetime.now, comment='本地后台创建时间')
|
||||||
|
# updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='本地后台更新时间')
|
||||||
|
|
||||||
|
# # 关系
|
||||||
|
# images = relationship("Image", back_populates="event", cascade="all, delete-orphan")
|
||||||
|
# temperatures = relationship("Temperature", back_populates="event", cascade="all, delete-orphan")
|
||||||
|
# process_logs = relationship("ProcessLog", back_populates="event", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
|
||||||
|
# class Image(Base):
|
||||||
|
# __tablename__ = "image"
|
||||||
|
|
||||||
|
# imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
|
||||||
|
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||||
|
# imageUrl = Column(String(500), nullable=False, comment='图片URL')
|
||||||
|
# localPath = Column(String(500), comment='本地存储路径')
|
||||||
|
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||||
|
|
||||||
|
# # 关系
|
||||||
|
# event = relationship("Event", back_populates="images" )
|
||||||
|
# temperatures = relationship("Temperature", back_populates="image")
|
||||||
|
|
||||||
|
# __table_args__ = (
|
||||||
|
# Index('idx_image_event_id', 'eventId'),
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# class Temperature(Base):
|
||||||
|
# __tablename__ = "temperature"
|
||||||
|
|
||||||
|
# tempId = Column(BigInteger, primary_key=True, autoincrement=True, comment='温度记录ID')
|
||||||
|
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||||
|
# imageId = Column(BigInteger, ForeignKey('image.imageId'), nullable=False, comment='关联图片ID')
|
||||||
|
# temperature = Column(String(20), nullable=False, comment='温度值')
|
||||||
|
# # status = Column(String(2), comment='温度是否正常')
|
||||||
|
# confidence = Column(String(40), nullable=False, comment='识别置信度')
|
||||||
|
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||||
|
|
||||||
|
# # 关系
|
||||||
|
# event = relationship("Event", back_populates="temperatures")
|
||||||
|
# image = relationship("Image", back_populates="temperatures")
|
||||||
|
|
||||||
|
# __table_args__ = (
|
||||||
|
# Index('idx_temp_event_id', 'eventId'),
|
||||||
|
# Index('idx_temp_create_time', 'createTime'),
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# class ProcessLog(Base):
|
||||||
|
# __tablename__ = "process_log"
|
||||||
|
|
||||||
|
# logId = Column(BigInteger, primary_key=True, autoincrement=True, comment='日志ID')
|
||||||
|
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||||
|
# processStatus = Column(Integer, nullable=False, comment='处理状态')
|
||||||
|
# errorMessage = Column(Text, comment='错误信息')
|
||||||
|
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||||
|
|
||||||
|
# # 关系
|
||||||
|
# event = relationship("Event", back_populates="process_logs")
|
||||||
|
|
||||||
|
# __table_args__ = (
|
||||||
|
# Index('idx_log_event_id', 'eventId'),
|
||||||
|
# Index('idx_log_create_time', 'createTime'),
|
||||||
|
# )
|
||||||
0
app/schemas/__init__.py
Normal file
0
app/schemas/__init__.py
Normal file
116
app/schemas/event.py
Normal file
116
app/schemas/event.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
class ImageBase(BaseModel):
|
||||||
|
imageUrl: str
|
||||||
|
localPath: Optional[str] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class TemperatureBase(BaseModel):
|
||||||
|
temperature: str
|
||||||
|
confidence: str
|
||||||
|
createTime: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class EventBase(BaseModel):
|
||||||
|
eventId: str
|
||||||
|
tenantInfoId: Optional[str] = None
|
||||||
|
reportEventId: Optional[str] = None
|
||||||
|
number: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
etypeName: Optional[str] = None
|
||||||
|
insDate: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class EventList(EventBase):
|
||||||
|
images: List[ImageBase] = []
|
||||||
|
temperatures: List[TemperatureBase] = []
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class EventDetail(EventBase):
|
||||||
|
eclassify: Optional[str] = None
|
||||||
|
operationType: Optional[str] = None
|
||||||
|
etype: Optional[str] = None
|
||||||
|
enTypeName: Optional[str] = None
|
||||||
|
hkTypeName: Optional[str] = None
|
||||||
|
reportStatus: Optional[str] = None
|
||||||
|
results: Optional[str] = None
|
||||||
|
insDateShow: Optional[datetime] = None
|
||||||
|
updDate: Optional[datetime] = None
|
||||||
|
updDateShow: Optional[datetime] = None
|
||||||
|
fileType: Optional[str] = None
|
||||||
|
area: Optional[str] = None
|
||||||
|
floor: Optional[str] = None
|
||||||
|
map: Optional[str] = None
|
||||||
|
staffId: Optional[str] = None
|
||||||
|
targetUserId: Optional[str] = None
|
||||||
|
position: Optional[str] = None
|
||||||
|
actualStaffName: Optional[str] = None
|
||||||
|
targetStaffName: Optional[str] = None
|
||||||
|
routeName: Optional[str] = None
|
||||||
|
phoneAddress: Optional[str] = None
|
||||||
|
width: Optional[str] = None
|
||||||
|
height: Optional[str] = None
|
||||||
|
resolution: Optional[str] = None
|
||||||
|
originX: Optional[str] = None
|
||||||
|
originY: Optional[str] = None
|
||||||
|
robotType: Optional[str] = None
|
||||||
|
eventFloor: Optional[str] = None
|
||||||
|
floorName: Optional[str] = None
|
||||||
|
coordId: Optional[str] = None
|
||||||
|
coord: Optional[str] = None
|
||||||
|
coordName: Optional[str] = None
|
||||||
|
positonName: Optional[str] = None
|
||||||
|
processingRemark: Optional[str] = None
|
||||||
|
carId: Optional[str] = None
|
||||||
|
parkingSpaceType: Optional[str] = None
|
||||||
|
parkingSpaceNumber: Optional[str] = None
|
||||||
|
carNumber: Optional[str] = None
|
||||||
|
eno: Optional[str] = None
|
||||||
|
instrument: Optional[str] = None
|
||||||
|
evideo: Optional[str] = None
|
||||||
|
createTime: Optional[datetime] = None
|
||||||
|
updateTime: Optional[datetime] = None
|
||||||
|
images: List[ImageBase] = []
|
||||||
|
temperatures: List[TemperatureBase] = []
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class EventUpdate(BaseModel):
|
||||||
|
number: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
etypeName: Optional[str] = None
|
||||||
|
area: Optional[str] = None
|
||||||
|
position: Optional[str] = None
|
||||||
|
processingRemark: Optional[str] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class EventQuery(BaseModel):
|
||||||
|
start_time: Optional[datetime] = None
|
||||||
|
end_time: Optional[datetime] = None
|
||||||
|
etypeName: Optional[str] = None
|
||||||
|
area: Optional[str] = None
|
||||||
|
skip: int = 0
|
||||||
|
limit: int = 100
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvent(BaseModel):
|
||||||
|
eventId: str
|
||||||
|
number: Optional[str] = None
|
||||||
|
name : Optional[str] = None
|
||||||
|
imageUrl : Optional[str] = None
|
||||||
|
localPath: Optional[str] = None
|
||||||
|
temperature: Optional[str] = None
|
||||||
|
confidence : Optional[str] = None
|
||||||
|
createTime: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
15
app/schemas/image.py
Normal file
15
app/schemas/image.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class ImageBase(BaseModel):
|
||||||
|
|
||||||
|
imageId: Optional[int] = None
|
||||||
|
localPath: Optional[str] = None
|
||||||
|
type : Optional[str] = None
|
||||||
|
result : Optional[str] = None
|
||||||
|
createTime : Optional[datetime] = None
|
||||||
|
updateTime : Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
11
app/schemas/ocr.py
Normal file
11
app/schemas/ocr.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class ImageBase64Request(BaseModel):
|
||||||
|
image_base64: str
|
||||||
|
image_type: Optional[str] = "jpg" # 图片类型,如 jpg, png 等
|
||||||
|
|
||||||
|
# class OCRResponse(BaseModel):
|
||||||
|
# success: bool
|
||||||
|
# message: str
|
||||||
|
# data: Optional[list] = None
|
||||||
15
app/services/imageServices.py
Normal file
15
app/services/imageServices.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List, Optional
|
||||||
|
from app.schemas.image import ImageBase
|
||||||
|
from app.dao.imageDao import ImageDao
|
||||||
|
|
||||||
|
|
||||||
|
class ImageService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
|
||||||
|
|
||||||
|
result = await ImageDao.get_image_list(db, query_model)
|
||||||
|
|
||||||
|
return result
|
||||||
535
app/util/baiduOCR.py
Normal file
535
app/util/baiduOCR.py
Normal file
@ -0,0 +1,535 @@
|
|||||||
|
from paddleocr import PaddleOCR
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
# import math
|
||||||
|
|
||||||
|
|
||||||
|
from app.config.config import OCRSettings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduOCR:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = PaddleOCR(
|
||||||
|
# 文本检测模型地址
|
||||||
|
# text_detection_model_dir = "/home/admin-root/haotian/康达瑞贝斯机器狗/ocr_model/PP-OCRv5_server_det",
|
||||||
|
text_detection_model_dir=OCRSettings.TEXT_DETECTION_MODEL_DIR,
|
||||||
|
# 文本识别模型地址
|
||||||
|
# text_recognition_model_dir = "/home/admin-root/haotian/康达瑞贝斯机器狗/ocr_model/PP-OCRv5_server_rec",
|
||||||
|
text_recognition_model_dir=OCRSettings.TEXT_RECONGNITION_MODEL_DIR,
|
||||||
|
use_doc_orientation_classify=False,
|
||||||
|
use_doc_unwarping=False,
|
||||||
|
use_textline_orientation=False,
|
||||||
|
# 多gpu有问题
|
||||||
|
device='gpu:2'
|
||||||
|
)
|
||||||
|
def ocr(self, image_path:str):
|
||||||
|
try:
|
||||||
|
result = self.model.predict(image_path)
|
||||||
|
except IndexError:
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.draw_ocr_result(image_path, result)
|
||||||
|
|
||||||
|
result = self.parse_result(result)
|
||||||
|
if not result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
def parse_result(self, result):
|
||||||
|
text_list = []
|
||||||
|
for item in result:
|
||||||
|
rec_texts = item.get('rec_texts')
|
||||||
|
rec_scores = item.get('rec_scores')
|
||||||
|
|
||||||
|
if rec_texts is None or rec_scores is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
text_list.append([rec_texts, rec_scores])
|
||||||
|
return text_list
|
||||||
|
|
||||||
|
def draw_ocr_result(self, image_path: str, result):
|
||||||
|
"""
|
||||||
|
在原图上绘制 OCR 识别结果
|
||||||
|
"""
|
||||||
|
# 读取原图
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# 转换为 PIL Image 以支持中文显示
|
||||||
|
pil_img = Image.fromarray(image)
|
||||||
|
draw = ImageDraw.Draw(pil_img)
|
||||||
|
|
||||||
|
# 尝试加载中文字体(根据系统调整路径)
|
||||||
|
try:
|
||||||
|
# font = ImageFont.truetype("simhei.ttf", 20) # Windows
|
||||||
|
font = ImageFont.truetype("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 20) # Linux
|
||||||
|
# font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20) # macOS
|
||||||
|
except:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
|
||||||
|
# 绘制每个检测框和文本
|
||||||
|
for idx, item in enumerate(result):
|
||||||
|
box = item['rec_boxes'] # shape=(1, 4)
|
||||||
|
text = item['rec_texts']
|
||||||
|
score = item['rec_scores']
|
||||||
|
|
||||||
|
# 处理可能的列表类型
|
||||||
|
if isinstance(text, list):
|
||||||
|
text = text[0] if len(text) > 0 else ""
|
||||||
|
if isinstance(score, list):
|
||||||
|
score = score[0] if len(score) > 0 else 0.0
|
||||||
|
|
||||||
|
# 提取坐标
|
||||||
|
if box is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
box = box.flatten()
|
||||||
|
if len(box) < 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
|
||||||
|
|
||||||
|
# 绘制矩形框
|
||||||
|
draw.rectangle([(x1, y1), (x2, y2)], outline=(255, 0, 0), width=2)
|
||||||
|
|
||||||
|
# 绘制文本和置信度
|
||||||
|
text_position = (x1, max(0, y1 - 25))
|
||||||
|
text_with_score = f"{text} ({float(score):.2f})"
|
||||||
|
draw.text(text_position, text_with_score, fill=(0, 255, 0), font=font)
|
||||||
|
|
||||||
|
# 转换回 OpenCV 格式
|
||||||
|
result_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# 获取文件夹路径并创建
|
||||||
|
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
||||||
|
|
||||||
|
cv2.imwrite(image_path, result_img)
|
||||||
|
|
||||||
|
# return result_img
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduOCRONNX:
|
||||||
|
def __init__(self, det_model_path=OCRSettings.TEXT_DETECTION_MODEL_ONNX_DIR, rec_model_path=OCRSettings.TEXT_RECONGNITION_MODEL_ONNX_DIR):
|
||||||
|
"""
|
||||||
|
初始化ONNX推理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
det_model_path: 检测模型路径 (det.onnx)
|
||||||
|
rec_model_path: 识别模型路径 (rec.onnx)
|
||||||
|
"""
|
||||||
|
# 初始化检测模型
|
||||||
|
self.det_session = ort.InferenceSession(det_model_path)
|
||||||
|
self.det_input_name = self.det_session.get_inputs()[0].name
|
||||||
|
|
||||||
|
# 初始化识别模型
|
||||||
|
self.rec_session = ort.InferenceSession(rec_model_path)
|
||||||
|
self.rec_input_name = self.rec_session.get_inputs()[0].name
|
||||||
|
|
||||||
|
# 字符集(根据您的模型调整)
|
||||||
|
# self.character = ['blank', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+',
|
||||||
|
# ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8',
|
||||||
|
# '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E',
|
||||||
|
# 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
|
||||||
|
# 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
|
||||||
|
# '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
|
||||||
|
# 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y',
|
||||||
|
# 'z', '{', '|', '}', '~'] + [chr(i) for i in range(19968, 40870)] # 中文字符
|
||||||
|
self.character = self.get_dict()
|
||||||
|
|
||||||
|
if self.character is None:
|
||||||
|
raise ValueError('请检查字典文件是否存在!')
|
||||||
|
|
||||||
|
def get_dict(self, dict_path='./dict.yaml'):
|
||||||
|
"""
|
||||||
|
加载字典
|
||||||
|
"""
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
dict_rec = yaml.safe_load(f)
|
||||||
|
return dict_rec.get('character_dict', [])
|
||||||
|
|
||||||
|
def resize_norm_img_det(self, img, input_shape=(640, 640)):
|
||||||
|
"""
|
||||||
|
检测模型的图像预处理 - 固定输入形状 [1, 3, 640, 640]
|
||||||
|
"""
|
||||||
|
h, w, _ = img.shape
|
||||||
|
target_h, target_w = input_shape
|
||||||
|
|
||||||
|
# 计算缩放比例 - 保持宽高比
|
||||||
|
ratio_h = target_h / h
|
||||||
|
ratio_w = target_w / w
|
||||||
|
ratio = min(ratio_h, ratio_w)
|
||||||
|
|
||||||
|
# 计算缩放后的尺寸
|
||||||
|
new_h = int(h * ratio)
|
||||||
|
new_w = int(w * ratio)
|
||||||
|
|
||||||
|
# 调整图像大小
|
||||||
|
resized_img = cv2.resize(img, (new_w, new_h))
|
||||||
|
|
||||||
|
# 创建目标尺寸的图像,用灰色填充
|
||||||
|
padded_img = np.ones((target_h, target_w, 3), dtype=np.float32) * 114.0 # 直接用float32
|
||||||
|
|
||||||
|
# 计算居中位置
|
||||||
|
top = (target_h - new_h) // 2
|
||||||
|
left = (target_w - new_w) // 2
|
||||||
|
|
||||||
|
# 将缩放后的图像放到居中位置
|
||||||
|
padded_img[top:top+new_h, left:left+new_w] = resized_img.astype(np.float32)
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
img = (padded_img / 255.0 - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
img = img.transpose(2, 0, 1).astype(np.float32)
|
||||||
|
img = np.expand_dims(img, axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
return img, ratio, (top, left)
|
||||||
|
|
||||||
|
def post_process_det(self, dt_boxes, ratio, padding_info, ori_shape):
|
||||||
|
"""
|
||||||
|
检测结果后处理 - 适配固定输入形状
|
||||||
|
"""
|
||||||
|
if dt_boxes is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ori_h, ori_w = ori_shape
|
||||||
|
top, left = padding_info
|
||||||
|
|
||||||
|
# 将坐标从模型输出空间转换回原图空间
|
||||||
|
dt_boxes[:, :, 0] = (dt_boxes[:, :, 0] - left) / ratio
|
||||||
|
dt_boxes[:, :, 1] = (dt_boxes[:, :, 1] - top) / ratio
|
||||||
|
|
||||||
|
# 裁剪到原图范围内
|
||||||
|
dt_boxes[:, :, 0] = np.clip(dt_boxes[:, :, 0], 0, ori_w)
|
||||||
|
dt_boxes[:, :, 1] = np.clip(dt_boxes[:, :, 1], 0, ori_h)
|
||||||
|
|
||||||
|
return dt_boxes
|
||||||
|
|
||||||
|
def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height, max_candidates=1000, box_thresh=0.6):
|
||||||
|
"""
|
||||||
|
从位图中提取文本框
|
||||||
|
"""
|
||||||
|
bitmap = bitmap.astype(np.uint8)
|
||||||
|
height, width = bitmap.shape
|
||||||
|
|
||||||
|
# 查找轮廓
|
||||||
|
contours, _ = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
num_contours = min(len(contours), max_candidates)
|
||||||
|
boxes = []
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for i in range(num_contours):
|
||||||
|
contour = contours[i]
|
||||||
|
points, sside = self.get_mini_boxes(contour)
|
||||||
|
if sside < 5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
points = np.array(points)
|
||||||
|
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||||
|
if box_thresh > score:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 扩展box
|
||||||
|
box = self.unclip(points, 1.5).reshape(-1, 1, 2)
|
||||||
|
box, sside = self.get_mini_boxes(box)
|
||||||
|
if sside < 5 + 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
box = np.array(box)
|
||||||
|
box[:, 0] = np.clip(box[:, 0] / width * dest_width, 0, dest_width)
|
||||||
|
box[:, 1] = np.clip(box[:, 1] / height * dest_height, 0, dest_height)
|
||||||
|
|
||||||
|
boxes.append(box.astype(np.int16))
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
return np.array(boxes), scores
|
||||||
|
|
||||||
|
def get_mini_boxes(self, contour):
|
||||||
|
"""获取最小外接矩形"""
|
||||||
|
bounding_box = cv2.minAreaRect(contour)
|
||||||
|
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||||
|
|
||||||
|
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||||
|
if points[1][1] > points[0][1]:
|
||||||
|
index_1 = 0
|
||||||
|
index_4 = 1
|
||||||
|
else:
|
||||||
|
index_1 = 1
|
||||||
|
index_4 = 0
|
||||||
|
|
||||||
|
if points[3][1] > points[2][1]:
|
||||||
|
index_2 = 2
|
||||||
|
index_3 = 3
|
||||||
|
else:
|
||||||
|
index_2 = 3
|
||||||
|
index_3 = 2
|
||||||
|
|
||||||
|
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
|
||||||
|
return box, min(bounding_box[1])
|
||||||
|
|
||||||
|
def box_score_fast(self, bitmap, _box):
|
||||||
|
"""快速计算box得分"""
|
||||||
|
h, w = bitmap.shape[:2]
|
||||||
|
box = _box.copy()
|
||||||
|
xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
|
||||||
|
xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
|
||||||
|
ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
|
||||||
|
ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
|
||||||
|
|
||||||
|
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||||
|
box[:, 0] = box[:, 0] - xmin
|
||||||
|
box[:, 1] = box[:, 1] - ymin
|
||||||
|
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||||
|
|
||||||
|
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||||
|
|
||||||
|
def unclip(self, box, unclip_ratio):
|
||||||
|
"""扩展文本框"""
|
||||||
|
from shapely.geometry import Polygon
|
||||||
|
import pyclipper
|
||||||
|
|
||||||
|
poly = Polygon(box)
|
||||||
|
distance = poly.area * unclip_ratio / poly.length
|
||||||
|
|
||||||
|
offset = pyclipper.PyclipperOffset()
|
||||||
|
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||||
|
expanded = offset.Execute(distance)
|
||||||
|
|
||||||
|
if len(expanded) == 0:
|
||||||
|
return box
|
||||||
|
else:
|
||||||
|
return np.array(expanded[0])
|
||||||
|
|
||||||
|
def resize_norm_img_rec(self, img, input_shape=(320, 48)):
|
||||||
|
"""
|
||||||
|
识别模型的图像预处理 - 固定输入形状 [1, 3, 48, 320]
|
||||||
|
"""
|
||||||
|
target_w, target_h = input_shape # 注意:宽度在前
|
||||||
|
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
|
||||||
|
# 计算缩放比例,保持宽高比
|
||||||
|
ratio_h = target_h / h
|
||||||
|
ratio_w = target_w / w
|
||||||
|
ratio = min(ratio_h, ratio_w)
|
||||||
|
|
||||||
|
# 计算缩放后的尺寸
|
||||||
|
new_h = int(h * ratio)
|
||||||
|
new_w = int(w * ratio)
|
||||||
|
|
||||||
|
# 调整图像大小
|
||||||
|
resized_image = cv2.resize(img, (new_w, new_h))
|
||||||
|
|
||||||
|
# 创建目标尺寸的图像,用黑色填充
|
||||||
|
padded_image = np.zeros((target_h, target_w, 3), dtype=np.float32) # 直接用float32
|
||||||
|
|
||||||
|
# 将缩放后的图像放到左上角(识别模型通常左对齐)
|
||||||
|
padded_image[:new_h, :new_w] = resized_image.astype(np.float32)
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
# padded_image = (padded_image / 255.0 - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
|
||||||
|
# 不缩放反而会将识别结果再移后一个??
|
||||||
|
padded_image = (padded_image / 255.0).astype(np.float32)
|
||||||
|
padded_image = padded_image.transpose((2, 0, 1)).astype(np.float32)
|
||||||
|
|
||||||
|
return np.expand_dims(padded_image, axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
def decode_rec_result(self, preds_prob):
|
||||||
|
"""
|
||||||
|
解码识别结果
|
||||||
|
"""
|
||||||
|
preds_idx = np.argmax(preds_prob, axis=1)
|
||||||
|
preds_prob = np.max(preds_prob, axis=1)
|
||||||
|
|
||||||
|
# CTC解码
|
||||||
|
last_idx = 0
|
||||||
|
preds_text = []
|
||||||
|
preds_conf = []
|
||||||
|
|
||||||
|
for i, idx in enumerate(preds_idx):
|
||||||
|
if idx != last_idx and idx != 0: # 0是blank
|
||||||
|
if idx < len(self.character):
|
||||||
|
preds_text.append(self.character[idx])
|
||||||
|
preds_conf.append(preds_prob[i])
|
||||||
|
last_idx = idx
|
||||||
|
|
||||||
|
text = ''.join(preds_text)
|
||||||
|
conf = np.mean(preds_conf) if preds_conf else 0.0
|
||||||
|
|
||||||
|
return text, conf
|
||||||
|
|
||||||
|
def detect_text(self, image):
|
||||||
|
"""
|
||||||
|
文本检测 - 适配固定输入形状 [1, 3, 640, 640]
|
||||||
|
"""
|
||||||
|
ori_h, ori_w = image.shape[:2]
|
||||||
|
|
||||||
|
# 预处理
|
||||||
|
det_img, ratio, padding_info = self.resize_norm_img_det(image)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
det_output = self.det_session.run(None, {self.det_input_name: det_img})[0]
|
||||||
|
|
||||||
|
# 后处理
|
||||||
|
mask = det_output[0, 0, :, :]
|
||||||
|
threshold = 0.3
|
||||||
|
bitmap = (mask > threshold).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
# 从位图中提取文本框(坐标是在640x640空间中的)
|
||||||
|
boxes, scores = self.boxes_from_bitmap(mask, bitmap, 640, 640)
|
||||||
|
|
||||||
|
# 将坐标转换回原图空间
|
||||||
|
if len(boxes) > 0:
|
||||||
|
boxes = self.post_process_det(boxes, ratio, padding_info, (ori_h, ori_w))
|
||||||
|
|
||||||
|
return boxes, scores
|
||||||
|
|
||||||
|
def recognize_text(self, image):
|
||||||
|
"""
|
||||||
|
文本识别
|
||||||
|
"""
|
||||||
|
# 预处理
|
||||||
|
rec_img = self.resize_norm_img_rec(image)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
rec_output = self.rec_session.run(None, {self.rec_input_name: rec_img})[0]
|
||||||
|
|
||||||
|
# 解码
|
||||||
|
text, conf = self.decode_rec_result(rec_output[0])
|
||||||
|
|
||||||
|
return text, conf
|
||||||
|
|
||||||
|
def get_rotate_crop_image(self, img, points):
|
||||||
|
"""
|
||||||
|
根据四个点坐标裁剪并矫正图像
|
||||||
|
"""
|
||||||
|
img_crop_width = int(
|
||||||
|
max(
|
||||||
|
np.linalg.norm(points[0] - points[1]),
|
||||||
|
np.linalg.norm(points[2] - points[3])))
|
||||||
|
img_crop_height = int(
|
||||||
|
max(
|
||||||
|
np.linalg.norm(points[0] - points[3]),
|
||||||
|
np.linalg.norm(points[1] - points[2])))
|
||||||
|
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||||
|
[img_crop_width, img_crop_height],
|
||||||
|
[0, img_crop_height]])
|
||||||
|
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||||
|
dst_img = cv2.warpPerspective(
|
||||||
|
img,
|
||||||
|
M, (img_crop_width, img_crop_height),
|
||||||
|
borderMode=cv2.BORDER_REPLICATE,
|
||||||
|
flags=cv2.INTER_CUBIC)
|
||||||
|
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||||
|
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||||
|
dst_img = np.rot90(dst_img)
|
||||||
|
return dst_img
|
||||||
|
|
||||||
|
def ocr(self, image_path):
|
||||||
|
"""
|
||||||
|
完整的OCR流程
|
||||||
|
"""
|
||||||
|
# 读取图像
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 1. 文本检测
|
||||||
|
dt_boxes, scores = self.detect_text(image)
|
||||||
|
|
||||||
|
if dt_boxes is None or len(dt_boxes) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 2. 文本识别
|
||||||
|
ocr_results = []
|
||||||
|
|
||||||
|
text_list = []
|
||||||
|
confidence_list = []
|
||||||
|
for i, box in enumerate(dt_boxes):
|
||||||
|
# 裁剪文本区域
|
||||||
|
box_points = box.astype(np.float32)
|
||||||
|
crop_img = self.get_rotate_crop_image(image, box_points)
|
||||||
|
if crop_img.size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 识别文本
|
||||||
|
text, conf = self.recognize_text(crop_img)
|
||||||
|
|
||||||
|
if conf > 0.5: # 置信度过滤
|
||||||
|
ocr_results.append({
|
||||||
|
'text': text,
|
||||||
|
'confidence': conf,
|
||||||
|
'box': box.tolist(),
|
||||||
|
'score': scores[i] if i < len(scores) else 0.0
|
||||||
|
})
|
||||||
|
|
||||||
|
text_list.append(text)
|
||||||
|
confidence_list.append(round(conf.item(), 2))
|
||||||
|
|
||||||
|
# return ocr_results
|
||||||
|
return [text_list, confidence_list]
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
def main():
|
||||||
|
# 初始化OCR
|
||||||
|
# ocr = PaddleOCRONNX('/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape.onnx', '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape.onnx')
|
||||||
|
|
||||||
|
ocr = BaiduOCRONNX('/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx', '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx')
|
||||||
|
|
||||||
|
# 执行OCR
|
||||||
|
image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/3aee64cc1f90d93a5a45979f7b17cb4b_frame_001460.jpg'
|
||||||
|
results = ocr.ocr(image_path)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
for result in results:
|
||||||
|
print(f"文本: {result['text']}")
|
||||||
|
print(f"置信度: {result['confidence']:.3f}")
|
||||||
|
print(f"检测得分: {result['score']:.3f}")
|
||||||
|
print(f"坐标: {result['box']}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# 可视化结果
|
||||||
|
visualize_results(image_path, results)
|
||||||
|
|
||||||
|
def visualize_results(image_path, results):
|
||||||
|
"""
|
||||||
|
可视化OCR结果
|
||||||
|
"""
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
box = np.array(result['box'], dtype=np.int32)
|
||||||
|
cv2.polylines(image, [box], True, (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# 在框上方显示文本
|
||||||
|
cv2.putText(image, result['text'],
|
||||||
|
(box[0][0], box[0][1] - 10),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
|
||||||
|
|
||||||
|
cv2.imwrite('result_shape_20250815.jpg', image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ocr = BaiduOCR()
|
||||||
|
print(ocr.ocr(""))
|
||||||
266
app/util/responseHttp.py
Normal file
266
app/util/responseHttp.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.background import BackgroundTask
|
||||||
|
from typing import Any, Dict, Mapping, Optional
|
||||||
|
from app.config.constant import HttpStatusConstant
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseUtil:
|
||||||
|
"""
|
||||||
|
响应工具类
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def success(
|
||||||
|
cls,
|
||||||
|
msg: str = '操作成功',
|
||||||
|
data: Optional[Any] = None,
|
||||||
|
rows: Optional[Any] = None,
|
||||||
|
dict_content: Optional[Dict] = None,
|
||||||
|
model_content: Optional[BaseModel] = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
成功响应方法
|
||||||
|
|
||||||
|
:param msg: 可选,自定义成功响应信息
|
||||||
|
:param data: 可选,成功响应结果中属性为data的值
|
||||||
|
:param rows: 可选,成功响应结果中属性为rows的值
|
||||||
|
:param dict_content: 可选,dict类型,成功响应结果中自定义属性的值
|
||||||
|
:param model_content: 可选,BaseModel类型,成功响应结果中自定义属性的值
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 成功响应结果
|
||||||
|
"""
|
||||||
|
result = {'code': HttpStatusConstant.SUCCESS, 'msg': msg}
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
result['data'] = data
|
||||||
|
if rows is not None:
|
||||||
|
result['rows'] = rows
|
||||||
|
if dict_content is not None:
|
||||||
|
result.update(dict_content)
|
||||||
|
if model_content is not None:
|
||||||
|
result.update(model_content.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
result.update({'success': True, 'time': datetime.now()})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder(result),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
background=background,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def failure(
|
||||||
|
cls,
|
||||||
|
msg: str = '操作失败',
|
||||||
|
data: Optional[Any] = None,
|
||||||
|
rows: Optional[Any] = None,
|
||||||
|
dict_content: Optional[Dict] = None,
|
||||||
|
model_content: Optional[BaseModel] = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
失败响应方法
|
||||||
|
|
||||||
|
:param msg: 可选,自定义失败响应信息
|
||||||
|
:param data: 可选,失败响应结果中属性为data的值
|
||||||
|
:param rows: 可选,失败响应结果中属性为rows的值
|
||||||
|
:param dict_content: 可选,dict类型,失败响应结果中自定义属性的值
|
||||||
|
:param model_content: 可选,BaseModel类型,失败响应结果中自定义属性的值
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 失败响应结果
|
||||||
|
"""
|
||||||
|
result = {'code': HttpStatusConstant.WARN, 'msg': msg}
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
result['data'] = data
|
||||||
|
if rows is not None:
|
||||||
|
result['rows'] = rows
|
||||||
|
if dict_content is not None:
|
||||||
|
result.update(dict_content)
|
||||||
|
if model_content is not None:
|
||||||
|
result.update(model_content.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
result.update({'success': False, 'time': datetime.now()})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder(result),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
background=background,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unauthorized(
|
||||||
|
cls,
|
||||||
|
msg: str = '登录信息已过期,访问系统资源失败',
|
||||||
|
data: Optional[Any] = None,
|
||||||
|
rows: Optional[Any] = None,
|
||||||
|
dict_content: Optional[Dict] = None,
|
||||||
|
model_content: Optional[BaseModel] = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
未认证响应方法
|
||||||
|
|
||||||
|
:param msg: 可选,自定义未认证响应信息
|
||||||
|
:param data: 可选,未认证响应结果中属性为data的值
|
||||||
|
:param rows: 可选,未认证响应结果中属性为rows的值
|
||||||
|
:param dict_content: 可选,dict类型,未认证响应结果中自定义属性的值
|
||||||
|
:param model_content: 可选,BaseModel类型,未认证响应结果中自定义属性的值
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 未认证响应结果
|
||||||
|
"""
|
||||||
|
result = {'code': HttpStatusConstant.UNAUTHORIZED, 'msg': msg}
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
result['data'] = data
|
||||||
|
if rows is not None:
|
||||||
|
result['rows'] = rows
|
||||||
|
if dict_content is not None:
|
||||||
|
result.update(dict_content)
|
||||||
|
if model_content is not None:
|
||||||
|
result.update(model_content.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
result.update({'success': False, 'time': datetime.now()})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder(result),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
background=background,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def forbidden(
|
||||||
|
cls,
|
||||||
|
msg: str = '该用户无此接口权限',
|
||||||
|
data: Optional[Any] = None,
|
||||||
|
rows: Optional[Any] = None,
|
||||||
|
dict_content: Optional[Dict] = None,
|
||||||
|
model_content: Optional[BaseModel] = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
未授权响应方法
|
||||||
|
|
||||||
|
:param msg: 可选,自定义未授权响应信息
|
||||||
|
:param data: 可选,未授权响应结果中属性为data的值
|
||||||
|
:param rows: 可选,未授权响应结果中属性为rows的值
|
||||||
|
:param dict_content: 可选,dict类型,未授权响应结果中自定义属性的值
|
||||||
|
:param model_content: 可选,BaseModel类型,未授权响应结果中自定义属性的值
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 未授权响应结果
|
||||||
|
"""
|
||||||
|
result = {'code': HttpStatusConstant.FORBIDDEN, 'msg': msg}
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
result['data'] = data
|
||||||
|
if rows is not None:
|
||||||
|
result['rows'] = rows
|
||||||
|
if dict_content is not None:
|
||||||
|
result.update(dict_content)
|
||||||
|
if model_content is not None:
|
||||||
|
result.update(model_content.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
result.update({'success': False, 'time': datetime.now()})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder(result),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
background=background,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error(
|
||||||
|
cls,
|
||||||
|
msg: str = '接口异常',
|
||||||
|
data: Optional[Any] = None,
|
||||||
|
rows: Optional[Any] = None,
|
||||||
|
dict_content: Optional[Dict] = None,
|
||||||
|
model_content: Optional[BaseModel] = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
错误响应方法
|
||||||
|
|
||||||
|
:param msg: 可选,自定义错误响应信息
|
||||||
|
:param data: 可选,错误响应结果中属性为data的值
|
||||||
|
:param rows: 可选,错误响应结果中属性为rows的值
|
||||||
|
:param dict_content: 可选,dict类型,错误响应结果中自定义属性的值
|
||||||
|
:param model_content: 可选,BaseModel类型,错误响应结果中自定义属性的值
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 错误响应结果
|
||||||
|
"""
|
||||||
|
result = {'code': HttpStatusConstant.ERROR, 'msg': msg}
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
result['data'] = data
|
||||||
|
if rows is not None:
|
||||||
|
result['rows'] = rows
|
||||||
|
if dict_content is not None:
|
||||||
|
result.update(dict_content)
|
||||||
|
if model_content is not None:
|
||||||
|
result.update(model_content.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
result.update({'success': False, 'time': datetime.now()})
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder(result),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
background=background,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def streaming(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
data: Any = None,
|
||||||
|
headers: Optional[Mapping[str, str]] = None,
|
||||||
|
media_type: Optional[str] = None,
|
||||||
|
background: Optional[BackgroundTask] = None,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
流式响应方法
|
||||||
|
|
||||||
|
:param data: 流式传输的内容
|
||||||
|
:param headers: 可选,响应头信息
|
||||||
|
:param media_type: 可选,响应结果媒体类型
|
||||||
|
:param background: 可选,响应返回后执行的后台任务
|
||||||
|
:return: 流式响应结果
|
||||||
|
"""
|
||||||
|
return StreamingResponse(
|
||||||
|
status_code=status.HTTP_200_OK, content=data, headers=headers, media_type=media_type, background=background
|
||||||
|
)
|
||||||
470
app/util/yolov8Obj.py
Normal file
470
app/util/yolov8Obj.py
Normal file
@ -0,0 +1,470 @@
|
|||||||
|
from ultralytics import YOLO
|
||||||
|
# from rknn.api import RKNN
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import time
|
||||||
|
|
||||||
|
from app.config.config import yolov8_settings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Yolov8Obj:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = YOLO(yolov8_settings.YOLOV8_MODEL_DIR)
|
||||||
|
|
||||||
|
def detect(self, image_path):
|
||||||
|
result = self.model.predict(image_path)
|
||||||
|
boxes = result[0].boxes
|
||||||
|
|
||||||
|
cls = boxes.cls.tolist()
|
||||||
|
conf = boxes.conf.tolist()
|
||||||
|
coords = boxes.xyxy.tolist()
|
||||||
|
|
||||||
|
return cls, conf, coords
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOv8ONNX:
|
||||||
|
def __init__(self, model_path=yolov8_settings.YOLOV8_MODEL_ONNX_DIRS, conf_threshold=0.5, iou_threshold=0.4):
|
||||||
|
"""
|
||||||
|
初始化YOLOv8 ONNX模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: ONNX模型文件路径
|
||||||
|
conf_threshold: 置信度阈值
|
||||||
|
iou_threshold: NMS IoU阈值
|
||||||
|
"""
|
||||||
|
self.conf_threshold = conf_threshold
|
||||||
|
self.iou_threshold = iou_threshold
|
||||||
|
|
||||||
|
# 创建ONNX Runtime会话
|
||||||
|
self.session = ort.InferenceSession(model_path)
|
||||||
|
|
||||||
|
# 获取模型输入输出信息
|
||||||
|
self.input_name = self.session.get_inputs()[0].name
|
||||||
|
self.output_name = self.session.get_outputs()[0].name
|
||||||
|
|
||||||
|
# 获取输入尺寸
|
||||||
|
input_shape = self.session.get_inputs()[0].shape
|
||||||
|
self.input_height = input_shape[2]
|
||||||
|
self.input_width = input_shape[3]
|
||||||
|
|
||||||
|
def preprocess(self, image):
|
||||||
|
"""
|
||||||
|
预处理图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 输入图像 (BGR格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
preprocessed_image: 预处理后的图像
|
||||||
|
scale_ratio: 缩放比例
|
||||||
|
pad_info: 填充信息 (pad_x, pad_y)
|
||||||
|
"""
|
||||||
|
# 获取原图尺寸
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
|
# 计算缩放比例
|
||||||
|
scale = min(self.input_height / h, self.input_width / w)
|
||||||
|
new_h, new_w = int(h * scale), int(w * scale)
|
||||||
|
|
||||||
|
# 等比例缩放
|
||||||
|
resized_image = cv2.resize(image, (new_w, new_h))
|
||||||
|
|
||||||
|
# 计算填充
|
||||||
|
pad_x = (self.input_width - new_w) // 2
|
||||||
|
pad_y = (self.input_height - new_h) // 2
|
||||||
|
|
||||||
|
# 创建填充后的图像
|
||||||
|
padded_image = np.full((self.input_height, self.input_width, 3), 114, dtype=np.uint8)
|
||||||
|
padded_image[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized_image
|
||||||
|
|
||||||
|
# 转换为模型输入格式: BGR -> RGB, HWC -> CHW, 归一化
|
||||||
|
input_image = padded_image[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||||
|
input_image = np.expand_dims(input_image, axis=0) # 添加batch维度
|
||||||
|
|
||||||
|
return input_image, scale, (pad_x, pad_y)
|
||||||
|
|
||||||
|
def postprocess(self, outputs, scale, pad_info, original_shape):
|
||||||
|
"""
|
||||||
|
后处理模型输出 - 针对YOLOv8格式优化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: 模型原始输出
|
||||||
|
scale: 图像缩放比例
|
||||||
|
pad_info: 填充信息
|
||||||
|
original_shape: 原图尺寸
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
boxes: 检测框 [[x1, y1, x2, y2], ...]
|
||||||
|
scores: 置信度分数
|
||||||
|
class_ids: 类别ID
|
||||||
|
"""
|
||||||
|
predictions = outputs[0] # 形状通常是: [1, 6, 8400] 或 [1, num_classes+4, num_boxes]
|
||||||
|
|
||||||
|
# YOLOv8输出格式: [batch, 4+num_classes, num_boxes]
|
||||||
|
# 需要转置为 [batch, num_boxes, 4+num_classes]
|
||||||
|
if len(predictions.shape) == 3:
|
||||||
|
predictions = predictions.transpose(0, 2, 1) # [1, num_boxes, 4+num_classes]
|
||||||
|
|
||||||
|
predictions = predictions[0] # 移除batch维度: [num_boxes, 4+num_classes]
|
||||||
|
|
||||||
|
# 打印调试信息
|
||||||
|
# print(f"预测输出形状: {predictions.shape}")
|
||||||
|
# print(f"前几个预测值: {predictions[:5]}")
|
||||||
|
|
||||||
|
# 分离坐标和分类信息
|
||||||
|
boxes = predictions[:, :4] # [x_center, y_center, width, height]
|
||||||
|
scores = predictions[:, 4:] # 类别置信度 [num_boxes, num_classes]
|
||||||
|
|
||||||
|
# print(f"检测框形状: {boxes.shape}")
|
||||||
|
# print(f"分数形状: {scores.shape}")
|
||||||
|
|
||||||
|
# 获取最高置信度和对应类别
|
||||||
|
class_ids = np.argmax(scores, axis=1)
|
||||||
|
confidences = np.max(scores, axis=1)
|
||||||
|
|
||||||
|
# print(f"置信度范围: {confidences.min():.4f} - {confidences.max():.4f}")
|
||||||
|
# print(f"检测到的类别: {np.unique(class_ids)}")
|
||||||
|
|
||||||
|
# 过滤低置信度检测
|
||||||
|
valid_indices = confidences > self.conf_threshold
|
||||||
|
valid_boxes = boxes[valid_indices]
|
||||||
|
valid_confidences = confidences[valid_indices]
|
||||||
|
valid_class_ids = class_ids[valid_indices]
|
||||||
|
|
||||||
|
# print(f"过滤后检测数量: {len(valid_boxes)}")
|
||||||
|
|
||||||
|
if len(valid_boxes) == 0:
|
||||||
|
return [], [], []
|
||||||
|
|
||||||
|
# 转换为 [x1, y1, x2, y2] 格式
|
||||||
|
x_center, y_center, width, height = valid_boxes[:, 0], valid_boxes[:, 1], valid_boxes[:, 2], valid_boxes[:, 3]
|
||||||
|
x1 = x_center - width / 2
|
||||||
|
y1 = y_center - height / 2
|
||||||
|
x2 = x_center + width / 2
|
||||||
|
y2 = y_center + height / 2
|
||||||
|
|
||||||
|
converted_boxes = np.stack([x1, y1, x2, y2], axis=1)
|
||||||
|
|
||||||
|
# 坐标反变换到原图
|
||||||
|
pad_x, pad_y = pad_info
|
||||||
|
converted_boxes[:, [0, 2]] = (converted_boxes[:, [0, 2]] - pad_x) / scale
|
||||||
|
converted_boxes[:, [1, 3]] = (converted_boxes[:, [1, 3]] - pad_y) / scale
|
||||||
|
|
||||||
|
# 限制坐标范围
|
||||||
|
h, w = original_shape[:2]
|
||||||
|
converted_boxes[:, [0, 2]] = np.clip(converted_boxes[:, [0, 2]], 0, w)
|
||||||
|
converted_boxes[:, [1, 3]] = np.clip(converted_boxes[:, [1, 3]], 0, h)
|
||||||
|
|
||||||
|
# 非极大值抑制 (NMS)
|
||||||
|
indices = cv2.dnn.NMSBoxes(
|
||||||
|
converted_boxes.tolist(),
|
||||||
|
valid_confidences.tolist(),
|
||||||
|
self.conf_threshold,
|
||||||
|
self.iou_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(indices) > 0:
|
||||||
|
indices = indices.flatten()
|
||||||
|
return converted_boxes[indices], valid_confidences[indices], valid_class_ids[indices]
|
||||||
|
|
||||||
|
return [], [], []
|
||||||
|
|
||||||
|
def detect(self, image):
|
||||||
|
"""
|
||||||
|
对图像进行目标检测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 输入图像 (BGR格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
boxes: 检测框列表
|
||||||
|
scores: 置信度分数列表
|
||||||
|
class_ids: 类别ID列表
|
||||||
|
"""
|
||||||
|
# 预处理
|
||||||
|
input_image, scale, pad_info = self.preprocess(image)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
outputs = self.session.run([self.output_name], {self.input_name: input_image})
|
||||||
|
|
||||||
|
# 后处理
|
||||||
|
boxes, scores, class_ids = self.postprocess(outputs, scale, pad_info, image.shape)
|
||||||
|
|
||||||
|
return boxes, scores, class_ids
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOv8RKNN:
|
||||||
|
def __init__(self, model_path, input_size=(640, 640)):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.input_size = input_size
|
||||||
|
self.rknn = RKNN()
|
||||||
|
|
||||||
|
# 类别名称,根据你的2个类别修改
|
||||||
|
self.class_names = ['class1', 'class2'] # 请替换为你实际的类别名称
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
"""加载RKNN模型"""
|
||||||
|
print("Loading RKNN model...")
|
||||||
|
ret = self.rknn.load_rknn(self.model_path)
|
||||||
|
if ret != 0:
|
||||||
|
print("Load RKNN model failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 初始化运行时环境(在RK3588设备上运行)
|
||||||
|
print("Init RKNN runtime...")
|
||||||
|
ret = self.rknn.init_runtime(target='rk3588', device_id=None, perf_debug=False, eval_mem=False)
|
||||||
|
if ret != 0:
|
||||||
|
print("Init RKNN runtime failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("RKNN model loaded successfully!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def preprocess(self, image):
|
||||||
|
"""图像预处理"""
|
||||||
|
# 获取原始图像尺寸
|
||||||
|
self.orig_height, self.orig_width = image.shape[:2]
|
||||||
|
|
||||||
|
# Resize到模型输入尺寸,保持宽高比
|
||||||
|
scale = min(self.input_size[0]/self.orig_width, self.input_size[1]/self.orig_height)
|
||||||
|
new_width = int(self.orig_width * scale)
|
||||||
|
new_height = int(self.orig_height * scale)
|
||||||
|
|
||||||
|
# 缩放图像
|
||||||
|
resized = cv2.resize(image, (new_width, new_height))
|
||||||
|
|
||||||
|
# 创建输入图像(填充到目标尺寸)
|
||||||
|
input_image = np.full((self.input_size[1], self.input_size[0], 3), 114, dtype=np.uint8)
|
||||||
|
|
||||||
|
# 计算填充位置(居中)
|
||||||
|
y_offset = (self.input_size[1] - new_height) // 2
|
||||||
|
x_offset = (self.input_size[0] - new_width) // 2
|
||||||
|
|
||||||
|
# 将缩放后的图像放到中心位置
|
||||||
|
input_image[y_offset:y_offset+new_height, x_offset:x_offset+new_width] = resized
|
||||||
|
|
||||||
|
# 保存缩放参数用于后处理
|
||||||
|
self.scale = scale
|
||||||
|
self.x_offset = x_offset
|
||||||
|
self.y_offset = y_offset
|
||||||
|
|
||||||
|
return input_image
|
||||||
|
|
||||||
|
def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
|
||||||
|
"""后处理:解析YOLO输出并进行NMS"""
|
||||||
|
# YOLOv8输出格式: [batch, 84, 8400] (2个类别: 4+2+80=84,但实际只有6维)
|
||||||
|
# 对于2类别: [x, y, w, h, conf_class1, conf_class2]
|
||||||
|
predictions = outputs[0][0] # 移除batch维度
|
||||||
|
|
||||||
|
# 转置为 [8400, 6] 格式
|
||||||
|
predictions = predictions.transpose()
|
||||||
|
|
||||||
|
boxes = []
|
||||||
|
scores = []
|
||||||
|
class_ids = []
|
||||||
|
|
||||||
|
for detection in predictions:
|
||||||
|
# 提取坐标和类别置信度
|
||||||
|
x, y, w, h = detection[:4]
|
||||||
|
class_confs = detection[4:6] # 2个类别的置信度
|
||||||
|
|
||||||
|
# 找到最大置信度的类别
|
||||||
|
class_id = np.argmax(class_confs)
|
||||||
|
max_conf = class_confs[class_id]
|
||||||
|
|
||||||
|
if max_conf >= conf_threshold:
|
||||||
|
# 转换坐标格式 (中心点 -> 左上角)
|
||||||
|
x1 = x - w/2
|
||||||
|
y1 = y - h/2
|
||||||
|
x2 = x + w/2
|
||||||
|
y2 = y + h/2
|
||||||
|
|
||||||
|
# 将坐标映射回原图尺寸
|
||||||
|
x1 = (x1 - self.x_offset) / self.scale
|
||||||
|
y1 = (y1 - self.y_offset) / self.scale
|
||||||
|
x2 = (x2 - self.x_offset) / self.scale
|
||||||
|
y2 = (y2 - self.y_offset) / self.scale
|
||||||
|
|
||||||
|
# 限制在图像边界内
|
||||||
|
x1 = max(0, min(x1, self.orig_width))
|
||||||
|
y1 = max(0, min(y1, self.orig_height))
|
||||||
|
x2 = max(0, min(x2, self.orig_width))
|
||||||
|
y2 = max(0, min(y2, self.orig_height))
|
||||||
|
|
||||||
|
boxes.append([x1, y1, x2, y2])
|
||||||
|
scores.append(max_conf)
|
||||||
|
class_ids.append(class_id)
|
||||||
|
|
||||||
|
# 执行NMS
|
||||||
|
if len(boxes) > 0:
|
||||||
|
boxes = np.array(boxes)
|
||||||
|
scores = np.array(scores)
|
||||||
|
class_ids = np.array(class_ids)
|
||||||
|
|
||||||
|
# OpenCV NMS
|
||||||
|
indices = cv2.dnn.NMSBoxes(boxes, scores, conf_threshold, nms_threshold)
|
||||||
|
|
||||||
|
if len(indices) > 0:
|
||||||
|
indices = indices.flatten()
|
||||||
|
return boxes[indices], scores[indices], class_ids[indices]
|
||||||
|
|
||||||
|
return np.array([]), np.array([]), np.array([])
|
||||||
|
|
||||||
|
def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
|
||||||
|
"""执行检测"""
|
||||||
|
# 预处理
|
||||||
|
input_image = self.preprocess(image)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
start_time = time.time()
|
||||||
|
outputs = self.rknn.inference(inputs=[input_image])
|
||||||
|
inference_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 后处理
|
||||||
|
boxes, scores, class_ids = self.postprocess(outputs, conf_threshold, nms_threshold)
|
||||||
|
|
||||||
|
return boxes, scores, class_ids, inference_time
|
||||||
|
|
||||||
|
def draw_detections(self, image, boxes, scores, class_ids):
|
||||||
|
"""在图像上绘制检测结果"""
|
||||||
|
for i in range(len(boxes)):
|
||||||
|
x1, y1, x2, y2 = boxes[i].astype(int)
|
||||||
|
score = scores[i]
|
||||||
|
class_id = int(class_ids[i])
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# 绘制标签
|
||||||
|
label = f"{self.class_names[class_id]}: {score:.2f}"
|
||||||
|
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||||
|
cv2.rectangle(image, (x1, y1-label_size[1]-10),
|
||||||
|
(x1+label_size[0], y1), (0, 255, 0), -1)
|
||||||
|
cv2.putText(image, label, (x1, y1-5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
"""释放资源"""
|
||||||
|
if self.rknn:
|
||||||
|
self.rknn.release()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 初始化检测器
|
||||||
|
model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
|
||||||
|
detector = YOLOv8RKNN(model_path)
|
||||||
|
|
||||||
|
# 测试单张图片
|
||||||
|
def test_image(image_path):
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
print(f"Cannot load image: {image_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 执行检测
|
||||||
|
boxes, scores, class_ids, inference_time = detector.detect(image)
|
||||||
|
|
||||||
|
print(f"Inference time: {inference_time*1000:.2f}ms")
|
||||||
|
print(f"Detected {len(boxes)} objects")
|
||||||
|
|
||||||
|
# 绘制结果
|
||||||
|
result_image = detector.draw_detections(image, boxes, scores, class_ids)
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
# cv2.imshow("Detection Result", result_image)
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
cv2.imwrite("xxxxxxx.jpg", result_image)
|
||||||
|
|
||||||
|
# 测试摄像头实时检测
|
||||||
|
def test_camera():
|
||||||
|
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||||
|
if not cap.isOpened():
|
||||||
|
print("Cannot open camera")
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 执行检测
|
||||||
|
boxes, scores, class_ids, inference_time = detector.detect(frame)
|
||||||
|
|
||||||
|
# 绘制结果
|
||||||
|
result_frame = detector.draw_detections(frame, boxes, scores, class_ids)
|
||||||
|
|
||||||
|
# 显示FPS
|
||||||
|
fps = 1.0 / inference_time if inference_time > 0 else 0
|
||||||
|
cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
||||||
|
|
||||||
|
cv2.imshow("Real-time Detection", result_frame)
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
# 选择测试模式
|
||||||
|
mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ")
|
||||||
|
|
||||||
|
if mode == "1":
|
||||||
|
image_path = input("输入图片路径: ")
|
||||||
|
test_image(image_path)
|
||||||
|
elif mode == "2":
|
||||||
|
test_camera()
|
||||||
|
else:
|
||||||
|
print("无效选择")
|
||||||
|
|
||||||
|
# 释放资源
|
||||||
|
detector.release()
|
||||||
|
|
||||||
|
def draw_detections(image, boxes, scores, class_ids, class_names=None):
|
||||||
|
"""
|
||||||
|
在图像上绘制检测结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 输入图像
|
||||||
|
boxes: 检测框
|
||||||
|
scores: 置信度分数
|
||||||
|
class_ids: 类别ID
|
||||||
|
class_names: 类别名称列表(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
绘制了检测结果的图像
|
||||||
|
"""
|
||||||
|
result_image = image.copy()
|
||||||
|
|
||||||
|
for i, (box, score, class_id) in enumerate(zip(boxes, scores, class_ids)):
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
cv2.rectangle(result_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# 准备标签文本
|
||||||
|
if class_names and class_id < len(class_names):
|
||||||
|
label = f"{class_names[class_id]}: {score:.2f}"
|
||||||
|
else:
|
||||||
|
label = f"Class {class_id}: {score:.2f}"
|
||||||
|
|
||||||
|
# 绘制标签背景
|
||||||
|
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||||
|
cv2.rectangle(result_image, (x1, y1 - label_size[1] - 10),
|
||||||
|
(x1 + label_size[0], y1), (0, 255, 0), -1)
|
||||||
|
|
||||||
|
# 绘制标签文本
|
||||||
|
cv2.putText(result_image, label, (x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||||
|
|
||||||
|
return result_image
|
||||||
1
base64.txt
Normal file
1
base64.txt
Normal file
File diff suppressed because one or more lines are too long
14
robot_dog.service
Normal file
14
robot_dog.service
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=Robot Dog Service
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
User=root
|
||||||
|
Group=root
|
||||||
|
WorkingDirectory=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务
|
||||||
|
ExecStart=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务/start.sh
|
||||||
|
Restart=always
|
||||||
|
RestartSec=5s
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
14
run.py
Normal file
14
run.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import uvicorn
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# uvicorn.run(
|
||||||
|
# "app.main:app",
|
||||||
|
# host="10.0.0.202",
|
||||||
|
# port=12342
|
||||||
|
# )
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"app.api.main:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=12345
|
||||||
|
)
|
||||||
32
start.sh
Normal file
32
start.sh
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 激活虚拟环境
|
||||||
|
source /root/robot_dog_project/myvenv/bin/activate
|
||||||
|
|
||||||
|
# 进入项目目录
|
||||||
|
cd /root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 定义日志目录和带时间戳的日志文件名
|
||||||
|
LOG_DIR="/root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1/logs"
|
||||||
|
TIMESTAMP=$(date +'%Y%m%d_%H%M%S')
|
||||||
|
LOG_FILE="${LOG_DIR}/app_${TIMESTAMP}.log"
|
||||||
|
|
||||||
|
# 创建日志目录(如果不存在)
|
||||||
|
mkdir -p "$LOG_DIR"
|
||||||
|
|
||||||
|
# 激活Python虚拟环境
|
||||||
|
source /root/robot_dog_project/myvenv/bin/activate
|
||||||
|
|
||||||
|
# 切换到项目目录
|
||||||
|
cd /root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1
|
||||||
|
|
||||||
|
# 打印启动信息
|
||||||
|
echo "启动应用程序,日志将保存到: $LOG_FILE"
|
||||||
|
echo "启动时间: $(date '+%Y-%m-%d %H:%M:%S')" | tee -a "$LOG_FILE"
|
||||||
|
|
||||||
|
# 运行Python程序并将输出同时显示在控制台和保存到日志文件
|
||||||
|
exec python run.py --env=dev 2>&1 | tee -a "$LOG_FILE"
|
||||||
110
test_base64_server.py
Normal file
110
test_base64_server.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
def image_to_base64(image_path):
|
||||||
|
"""
|
||||||
|
将图片文件转换为base64编码
|
||||||
|
"""
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
return encoded_string
|
||||||
|
|
||||||
|
def test_ocr_api(image_path, api_url="http://10.0.0.202:12342/api/v1/ocr_from_base64"):
|
||||||
|
"""
|
||||||
|
测试OCR API接口
|
||||||
|
"""
|
||||||
|
# 将图片转换为base64
|
||||||
|
image_base64 = image_to_base64(image_path)
|
||||||
|
|
||||||
|
with open("base64.txt", "w") as f:
|
||||||
|
f.write(repr(image_base64))
|
||||||
|
|
||||||
|
# 准备请求数据
|
||||||
|
payload = {
|
||||||
|
"image_base64": image_base64,
|
||||||
|
"image_type": "jpg" # 根据实际图片类型修改
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送POST请求
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, data=json.dumps(payload), headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
print("OCR识别结果:")
|
||||||
|
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
print(f"请求失败,状态码: {response.status_code}")
|
||||||
|
print(response.text)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求异常: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def test_detect(image_path: str, api_url:str = "http://10.0.0.202:12342/api/v1/detect_from_base64_0"):
|
||||||
|
"""
|
||||||
|
测试yolov8消防区域侵占接口
|
||||||
|
"""
|
||||||
|
image_base64 = image_to_base64(image_path)
|
||||||
|
# 准备请求数据
|
||||||
|
payload = {
|
||||||
|
"image_base64": image_base64,
|
||||||
|
"image_type": "jpg" # 根据实际图片类型修改
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送POST请求
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, data=json.dumps(payload), headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
print("detect识别结果:")
|
||||||
|
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
print(f"请求失败,状态码: {response.status_code}")
|
||||||
|
print(response.text)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求异常: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试图片路径,请根据实际情况修改
|
||||||
|
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg"
|
||||||
|
|
||||||
|
#---------------------------------------测试ocr-----------------------------------------
|
||||||
|
test_image_path = "images/AiCheck_20251016114138.jpg"
|
||||||
|
|
||||||
|
# api_url="http://10.0.0.202:12342/api/v1/ocr_onnx_from_base64"
|
||||||
|
# api_url="http://192.168.30.195:12345/api/v1/ocr_onnx_from_base64"
|
||||||
|
# # 调用测试函数
|
||||||
|
# test_ocr_api(test_image_path, api_url)
|
||||||
|
#---------------------------------------测试ocrender-----------------------------------------
|
||||||
|
|
||||||
|
# # -----------------------------------------测试yolov8 侵占消防区域检测-----------------------------------------
|
||||||
|
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000000.jpg"
|
||||||
|
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000720.jpg"
|
||||||
|
# api_url = "http://10.0.0.202:12342/api/v1/detect_onnx_from_base64_0"
|
||||||
|
# api_url = "http://192.168.30.195:12345/api/v1/detect_onnx_from_base64_0"
|
||||||
|
# test_detect(test_image_path, api_url)
|
||||||
|
# #-----------------------------------------测试yolov8 侵占消防区域检测 end-----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# #-----------------------------------------测试yolov8 灭火器检测-----------------------------------------
|
||||||
|
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/ce81420a27cdaff14fe42f967eaa49a3_frame_001060.jpg"
|
||||||
|
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||||
|
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||||
|
# api_url = "http://10.0.0.202:12342/api/v1/detect_from_base64_1"
|
||||||
|
api_url = "http://192.168.30.195:12345/api/v1/detect_from_base64_1"
|
||||||
|
test_detect(test_image_path, api_url=api_url)
|
||||||
|
# #-----------------------------------------测试yolov8 灭火器检测 end-----------------------------------------
|
||||||
27
test_generic.py
Normal file
27
test_generic.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from typing import Generic, TypeVar, Dict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
K = TypeVar("K") # 键类型
|
||||||
|
V = TypeVar("V") # 值类型
|
||||||
|
|
||||||
|
class GenericCache(Generic[K, V]):
|
||||||
|
def __init__(self):
|
||||||
|
self._store: Dict[K, V] = {}
|
||||||
|
|
||||||
|
def set(self, key: K, value: V) -> None:
|
||||||
|
self._store[key] = value
|
||||||
|
|
||||||
|
def get(self, key: K) -> Optional[V]:
|
||||||
|
return self._store.get(key)
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
cache = GenericCache[str, int]() # 键为 str,值为 int
|
||||||
|
cache.set("count", 100)
|
||||||
|
value: Optional[int] = cache.get("count") # 返回 100
|
||||||
|
|
||||||
|
cache.set(112, "123")
|
||||||
|
|
||||||
|
|
||||||
|
for key in cache._store.keys():
|
||||||
|
|
||||||
|
print(type(cache.get(key)))
|
||||||
7
test_ocr.py
Normal file
7
test_ocr.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from app.util.baiduOCR import BaiduOCR
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ocr = BaiduOCR()
|
||||||
|
result = ocr.ocr('tmp/ocr_images/c723d5d8-697b-41e5-8cf4-d4644ce89a77.jpg')
|
||||||
|
print(result)
|
||||||
87
test_requests.py
Normal file
87
test_requests.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import requests
|
||||||
|
from requests import Response
|
||||||
|
from typing import List, TypeVar, Dict, Any, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar('T') # 泛型类型声明
|
||||||
|
|
||||||
|
REQUEST_URL = "http://aijinan.biaofun.com.cn/app/query/shandong" # 替换为实际API地址
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cities = ["济南市", "青岛市", "临沂市", "菏泽市", "威海市", "东营市"]
|
||||||
|
|
||||||
|
# 修正:每次请求创建新的参数对象,避免累积
|
||||||
|
for city in cities:
|
||||||
|
params = {"city": city}
|
||||||
|
|
||||||
|
# 发起RPC请求
|
||||||
|
response_list = rpc_get(REQUEST_URL, params, False)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
if not response_list:
|
||||||
|
logger.warning(f"Empty response for city: {city}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 设置城市名称并处理后续逻辑 (根据实际Response类调整)
|
||||||
|
response_list[0]['name'] = city # 这里假定response_list是字典列表
|
||||||
|
after(response_list)
|
||||||
|
|
||||||
|
def rpc_get(url: str, params: Dict[str, Any], right_or_wrong: bool) -> Optional[List[T]]:
|
||||||
|
"""
|
||||||
|
RPC请求封装
|
||||||
|
:param url: 请求URL
|
||||||
|
:param params: 请求参数
|
||||||
|
:param right_or_wrong: 未使用的标志(保留原Java参数)
|
||||||
|
:return: 响应对象列表
|
||||||
|
"""
|
||||||
|
logger.info(f"Request URL: {url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发起POST请求
|
||||||
|
response: Response = requests.post(url, data=params)
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Request failed, Status Code: {response.status_code}, URL: {url}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
response_text = response.text
|
||||||
|
logger.debug(f"Response: {response_text}")
|
||||||
|
|
||||||
|
# 解析响应(需要根据实际API响应格式实现)
|
||||||
|
return parse_response(response_text)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"RPC request failed: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_response(response_text: str) -> List[T]:
|
||||||
|
"""
|
||||||
|
解析API响应 (需要根据实际API响应格式实现)
|
||||||
|
示例实现 - 替换为实际解析逻辑
|
||||||
|
"""
|
||||||
|
# 这里需要根据实际API返回格式实现解析
|
||||||
|
# 示例:解析为JSON对象
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
data = json.loads(response_text)
|
||||||
|
# 根据实际数据结构返回列表
|
||||||
|
return [data] if isinstance(data, dict) else data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error("Invalid JSON response")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def after(response_list: List[Any]):
|
||||||
|
"""
|
||||||
|
后续处理逻辑 (根据实际需求实现)
|
||||||
|
"""
|
||||||
|
# 示例实现
|
||||||
|
print(f"Processing response for: {response_list[0].get('name')}")
|
||||||
|
# 实际业务逻辑...
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue
Block a user