Add robot dog OCR service and ignore local artifacts
This commit is contained in:
commit
19b2aa43d6
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
.env
|
||||
.env.*
|
||||
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.coverage
|
||||
.coverage.*
|
||||
htmlcov/
|
||||
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
.eggs/
|
||||
|
||||
logs/
|
||||
tmp/
|
||||
|
||||
*.log
|
||||
*.pid
|
||||
|
||||
32
app/api/main.py
Normal file
32
app/api/main.py
Normal file
@ -0,0 +1,32 @@
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api.v1.api import router
|
||||
# from app.core.config import settings
|
||||
# from app.services.websocket_service import websocket_service
|
||||
|
||||
app = FastAPI(
|
||||
title="机器狗后台服务",
|
||||
description="机器狗后台API接口",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 在生产环境中应该设置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(router)
|
||||
|
||||
# @app.websocket("/ws")
|
||||
# async def websocket_endpoint(websocket: WebSocket):
|
||||
# """WebSocket端点"""
|
||||
# await websocket_service.handle_websocket(websocket)
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "机器狗后台API服务正在运行"}
|
||||
394
app/api/v1/api.py
Normal file
394
app/api/v1/api.py
Normal file
@ -0,0 +1,394 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import base64
|
||||
import io
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.imageServices import ImageService
|
||||
from app.schemas.image import ImageBase
|
||||
from app.schemas.ocr import ImageBase64Request
|
||||
from app.util.responseHttp import ResponseUtil
|
||||
from app.util.baiduOCR import BaiduOCR, BaiduOCRONNX
|
||||
from app.util.yolov8Obj import Yolov8Obj, YOLOv8ONNX
|
||||
|
||||
# from app.crud.event import event
|
||||
# from app.schemas.event import EventList, EventDetail, EventUpdate, EventQuery, TestEvent
|
||||
|
||||
baiduOCR = BaiduOCR()
|
||||
baiduOcrOnnx = BaiduOCRONNX()
|
||||
yolov8Obj = Yolov8Obj()
|
||||
yolov8ONNX = YOLOv8ONNX()
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["ocr"])
|
||||
|
||||
@router.get("/hello")
|
||||
async def get_hello():
|
||||
|
||||
# return {"data":"hello"}
|
||||
return ResponseUtil.error(msg=f"OCR识别失败", data=None)
|
||||
|
||||
@router.get("/test_select")
|
||||
async def test_select(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
image_query: ImageBase = None
|
||||
):
|
||||
"""
|
||||
测试查询
|
||||
"""
|
||||
result = await ImageService.get_image_list(db,image_query)
|
||||
return ResponseUtil.success(msg="success", data=result)
|
||||
|
||||
|
||||
@router.get("/test_ocr")
|
||||
async def test_ocr(
|
||||
# image_path: str = None
|
||||
):
|
||||
"""
|
||||
测试OCR
|
||||
"""
|
||||
|
||||
image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg'
|
||||
|
||||
|
||||
result = baiduOCR.ocr(image_path)
|
||||
# print(result)
|
||||
return ResponseUtil.success(msg="success", data=result)
|
||||
# return ResponseUtil.success(msg="success")
|
||||
|
||||
|
||||
@router.post("/ocr_from_base64")
|
||||
async def ocr_from_base64(request: ImageBase64Request):
|
||||
"""
|
||||
从base64图片数据进行OCR识别
|
||||
"""
|
||||
try:
|
||||
# 移除base64数据的前缀(如果有)
|
||||
image_base64 = request.image_base64
|
||||
if image_base64.startswith('data:image'):
|
||||
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
|
||||
image_base64 = image_base64.split(',')[1]
|
||||
|
||||
# 解码base64数据
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
# 将字节数据转换为PIL Image对象
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 创建临时文件路径
|
||||
temp_dir = "tmp/ocr_images"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||
temp_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
# 保存图片到临时文件
|
||||
image.save(temp_path)
|
||||
|
||||
# 使用PaddleOCR进行识别
|
||||
result = baiduOCR.ocr(temp_path)
|
||||
|
||||
# 删除临时文件
|
||||
# os.remove(temp_path)
|
||||
|
||||
return ResponseUtil.success(msg="OCR识别成功", data=result)
|
||||
except Exception as e:
|
||||
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
|
||||
|
||||
|
||||
@router.post("/ocr_onnx_from_base64")
|
||||
async def ocr_from_base64(request: ImageBase64Request):
|
||||
"""
|
||||
从base64图片数据进行OCR识别
|
||||
"""
|
||||
try:
|
||||
# 移除base64数据的前缀(如果有)
|
||||
image_base64 = request.image_base64
|
||||
if image_base64.startswith('data:image'):
|
||||
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
|
||||
image_base64 = image_base64.split(',')[1]
|
||||
|
||||
# 解码base64数据
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
# 将字节数据转换为PIL Image对象
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 创建临时文件路径
|
||||
temp_dir = "./tmp/ocr_images"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||
temp_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
# 保存图片到临时文件
|
||||
image.save(temp_path)
|
||||
|
||||
# 使用PaddleOCR进行识别
|
||||
result = baiduOcrOnnx.ocr(temp_path)
|
||||
print(result)
|
||||
|
||||
# 删除临时文件
|
||||
# os.remove(temp_path)
|
||||
|
||||
# return ResponseUtil.success(msg="OCR识别成功", data=[result['text'], result['confidence']])
|
||||
return ResponseUtil.success(msg="OCR识别成功", data=result)
|
||||
except Exception as e:
|
||||
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
|
||||
|
||||
|
||||
@router.post("/detect_from_base64_0")
|
||||
async def ocr_from_base64(request: ImageBase64Request):
|
||||
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
|
||||
|
||||
try:
|
||||
image_base64 = request.image_base64
|
||||
if image_base64.startswith('data:image'):
|
||||
image_base64 = image_base64.split(',')[1]
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
temp_dir = "./tmp/detect_images"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||
temp_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
image.save(temp_path)
|
||||
|
||||
cls, conf, coords = yolov8Obj.detect(temp_path)
|
||||
|
||||
os.remove(temp_path)
|
||||
|
||||
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(cls)==0))
|
||||
|
||||
except Exception as e:
|
||||
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
|
||||
|
||||
|
||||
@router.post("/detect_onnx_from_base64_0")
|
||||
async def ocr_from_base64(request: ImageBase64Request):
|
||||
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
|
||||
|
||||
# 为不同类别生成不同颜色
|
||||
def get_color(class_id):
|
||||
np.random.seed(class_id)
|
||||
return tuple(map(int, np.random.randint(0, 255, 3)))
|
||||
|
||||
try:
|
||||
# image_base64 = request.image_base64
|
||||
# if image_base64.startswith('data:image'):
|
||||
# image_base64 = image_base64.split(',')[1]
|
||||
# image_data = base64.b64decode(image_base64)
|
||||
|
||||
# image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# temp_dir = "./tmp/detect_images"
|
||||
# os.makedirs(temp_dir, exist_ok=True)
|
||||
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||
# temp_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
# image.save(temp_path)
|
||||
|
||||
# boxes, scores, class_ids = yolov8ONNX.detect(image_data)
|
||||
|
||||
# os.remove(temp_path)
|
||||
|
||||
image_base64 = request.image_base64
|
||||
if image_base64.startswith('data:image'):
|
||||
image_base64 = image_base64.split(',')[1]
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
# 将字节流转换为OpenCV格式的BGR图像
|
||||
image_np = np.frombuffer(image_data, np.uint8)
|
||||
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
|
||||
|
||||
# 不再需要保存到临时文件,直接使用内存中的图像
|
||||
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
|
||||
|
||||
|
||||
# 绘制检测结果
|
||||
image_with_boxes = image_cv.copy()
|
||||
|
||||
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||
# 获取边界框坐标
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 生成随机颜色或使用固定颜色方案
|
||||
color = get_color(class_id) # 绿色,也可以根据class_id设置不同颜色
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 准备标签文本
|
||||
label = f"Class {class_id}: {score:.2f}"
|
||||
# 如果你有类别名称字典,可以这样使用:
|
||||
# label = f"{class_names[class_id]}: {score:.2f}"
|
||||
|
||||
# 计算文本大小以绘制背景
|
||||
(text_width, text_height), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
|
||||
# 绘制文本背景
|
||||
cv2.rectangle(
|
||||
image_with_boxes,
|
||||
(x1, y1 - text_height - baseline - 5),
|
||||
(x1 + text_width, y1),
|
||||
color,
|
||||
-1
|
||||
)
|
||||
|
||||
# 绘制文本
|
||||
cv2.putText(
|
||||
image_with_boxes,
|
||||
label,
|
||||
(x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 0, 0), # 黑色文字
|
||||
1,
|
||||
cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 创建保存目录
|
||||
save_dir = "tmp/detect_images"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 生成带时间戳的随机文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
|
||||
filename = f"detect_{timestamp}_{random_id}.jpg"
|
||||
output_path = os.path.join(save_dir, filename)
|
||||
|
||||
# 保存图像
|
||||
cv2.imwrite(output_path, image_with_boxes)
|
||||
print(f"检测结果已保存到: {output_path}")
|
||||
|
||||
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(boxes)==0))
|
||||
|
||||
except Exception as e:
|
||||
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
|
||||
|
||||
@router.post("/detect_from_base64_1")
|
||||
async def ocr_from_base64(request: ImageBase64Request):
|
||||
""" 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
|
||||
|
||||
# try:
|
||||
# 为不同类别生成不同颜色
|
||||
def get_color(class_id):
|
||||
np.random.seed(class_id)
|
||||
return tuple(map(int, np.random.randint(0, 255, 3)))
|
||||
|
||||
try:
|
||||
|
||||
image_base64 = request.image_base64
|
||||
if image_base64.startswith('data:image'):
|
||||
image_base64 = image_base64.split(',')[1]
|
||||
image_data = base64.b64decode(image_base64)
|
||||
|
||||
# 将字节流转换为OpenCV格式的BGR图像
|
||||
image_np = np.frombuffer(image_data, np.uint8)
|
||||
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
|
||||
|
||||
# 不再需要保存到临时文件,直接使用内存中的图像
|
||||
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
|
||||
|
||||
|
||||
# 绘制检测结果
|
||||
image_with_boxes = image_cv.copy()
|
||||
|
||||
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||
# 获取边界框坐标
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 生成随机颜色或使用固定颜色方案
|
||||
color = get_color(class_id) # 绿色,也可以根据class_id设置不同颜色
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 准备标签文本
|
||||
label = f"Class {class_id}: {score:.2f}"
|
||||
# 如果你有类别名称字典,可以这样使用:
|
||||
# label = f"{class_names[class_id]}: {score:.2f}"
|
||||
|
||||
# 计算文本大小以绘制背景
|
||||
(text_width, text_height), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
|
||||
# 绘制文本背景
|
||||
cv2.rectangle(
|
||||
image_with_boxes,
|
||||
(x1, y1 - text_height - baseline - 5),
|
||||
(x1 + text_width, y1),
|
||||
color,
|
||||
-1
|
||||
)
|
||||
|
||||
# 绘制文本
|
||||
cv2.putText(
|
||||
image_with_boxes,
|
||||
label,
|
||||
(x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 0, 0), # 黑色文字
|
||||
1,
|
||||
cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 创建保存目录
|
||||
save_dir = "tmp/detect_images"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 生成带时间戳的随机文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
|
||||
filename = f"detect_{timestamp}_{random_id}.jpg"
|
||||
output_path = os.path.join(save_dir, filename)
|
||||
|
||||
# 保存图像
|
||||
cv2.imwrite(output_path, image_with_boxes)
|
||||
print(f"检测结果已保存到: {output_path}")
|
||||
|
||||
return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(class_ids)!=0 and 0 in class_ids))
|
||||
# return ResponseUtil.success(msg="目标检测成功", data=cls)
|
||||
|
||||
except Exception as e:
|
||||
return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)
|
||||
|
||||
# @router.post("/detect_from_base64_1")
|
||||
# async def ocr_from_base64(request: ImageBase64Request):
|
||||
# """ 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
|
||||
|
||||
# try:
|
||||
# image_base64 = request.image_base64
|
||||
# if image_base64.startswith('data:image'):
|
||||
# image_base64 = image_base64.split(',')[1]
|
||||
# image_data = base64.b64decode(image_base64)
|
||||
|
||||
# image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# temp_dir = "./tmp/detect_images"
|
||||
# os.makedirs(temp_dir, exist_ok=True)
|
||||
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
|
||||
# temp_path = os.path.join(temp_dir, temp_filename)
|
||||
|
||||
# image.save(temp_path)
|
||||
|
||||
# cls, conf, coords = yolov8Obj.detect(temp_path)
|
||||
|
||||
# os.remove(temp_path)
|
||||
|
||||
# return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(cls)!=0 and 0 in cls))
|
||||
# # return ResponseUtil.success(msg="目标检测成功", data=cls)
|
||||
|
||||
# except Exception as e:
|
||||
# return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)
|
||||
122
app/api/v1/testLogin.py
Normal file
122
app/api/v1/testLogin.py
Normal file
@ -0,0 +1,122 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置参数(应从环境变量读取)
|
||||
SECRET_KEY = "your-secret-key-keep-it-secret!"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
# 模拟用户数据库
|
||||
fake_users_db = {
|
||||
"johndoe": {
|
||||
"username": "johndoe",
|
||||
"full_name": "John Doe",
|
||||
"email": "johndoe@example.com",
|
||||
# 哈希后的密码(明文是 secret)
|
||||
"hashed_password": "$2b$12$EixZaYVK1fsbY1eZIbOnjesN9NwG1s3Z6FDcjyH103a2.dJgD0L4q",
|
||||
"disabled": False,
|
||||
}
|
||||
}
|
||||
|
||||
# Token 模型
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: Optional[str] = None
|
||||
|
||||
# 用户模型
|
||||
class User(BaseModel):
|
||||
username: str
|
||||
email: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
disabled: Optional[bool] = None
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
||||
|
||||
# 密码哈希配置
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# OAuth2 配置
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 密码验证工具
|
||||
def verify_password(plain_password: str, hashed_password: str):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
# 获取用户
|
||||
def get_user(db, username: str):
|
||||
if username in db:
|
||||
user_dict = db[username]
|
||||
return UserInDB(**user_dict)
|
||||
|
||||
# 用户认证
|
||||
def authenticate_user(fake_db, username: str, password: str):
|
||||
user = get_user(fake_db, username)
|
||||
if not user:
|
||||
return False
|
||||
# if not verify_password(password, user.hashed_password):
|
||||
# return False
|
||||
return user
|
||||
|
||||
# 创建访问令牌
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
# 获取当前用户(依赖注入)
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = get_user(fake_users_db, username=token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
# 登录路由
|
||||
@app.post("/token", response_model=Token)
|
||||
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
# 受保护路由
|
||||
@app.get("/users/me/", response_model=User)
|
||||
async def read_users_me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
89
app/config/config.py
Normal file
89
app/config/config.py
Normal file
@ -0,0 +1,89 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from dotenv import load_dotenv
|
||||
|
||||
class DataBaseSettings(BaseSettings):
|
||||
# 数据库配置
|
||||
DB_HOST: str = "10.0.0.17"
|
||||
DB_PORT: int = 3306
|
||||
DB_USER: str = "root"
|
||||
DB_PASSWORD: str = "root"
|
||||
DB_NAME: str = "kangda_robotic_dog"
|
||||
DB_CHARSET: str = "utf8mb4"
|
||||
DB_POOL_SIZE: int = 5
|
||||
DB_MAX_OVERFLOW: int = 10
|
||||
DB_POOL_TIMEOUT: int = 30
|
||||
DB_POOL_RECYCLE: int = 1800
|
||||
|
||||
# class Config:
|
||||
# env_file = ".env"
|
||||
|
||||
class OCRSettings(BaseException):
|
||||
TEXT_DETECTION_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_det_infer_20250814/"
|
||||
TEXT_RECONGNITION_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_rec_infer_20250815/"
|
||||
|
||||
# TEXT_DETECTION_MODEL_ONNX_DIR: str ='/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx'
|
||||
# TEXT_RECONGNITION_MODEL_ONNX_DIR: str ='/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx'
|
||||
|
||||
TEXT_DETECTION_MODEL_ONNX_DIR="/root/robot_dog_project/kangda_robotic_dog/models/det_mobile_14_shape.onnx"
|
||||
TEXT_RECONGNITION_MODEL_ONNX_DIR="/root/robot_dog_project/kangda_robotic_dog/models/rec_mobile_14_shape.onnx"
|
||||
|
||||
class YoloV8Settings(BaseException):
|
||||
YOLOV8_MODEL_DIR= "/root/robot_dog_project/kangda_robotic_dog/models/best.pt"
|
||||
YOLOV8_MODEL_ONNX_DIRS="/root/robot_dog_project/kangda_robotic_dog/models/yolov8_20250820.onnx"
|
||||
|
||||
|
||||
class GetSettings:
|
||||
|
||||
def __init__(self):
|
||||
self.parse_cli_args()
|
||||
|
||||
def get_database_settings(self):
|
||||
return DataBaseSettings()
|
||||
|
||||
def get_ocr_settings(self):
|
||||
return OCRSettings()
|
||||
|
||||
def get_yolov8_settings(self):
|
||||
return YoloV8Settings()
|
||||
|
||||
|
||||
def parse_cli_args(self):
|
||||
"""
|
||||
解析命令行参数
|
||||
"""
|
||||
if 'uvicorn' in sys.argv[0]:
|
||||
# 使用uvicorn启动时,命令行参数需要按照uvicorn的文档进行配置,无法自定义参数
|
||||
pass
|
||||
else:
|
||||
# 使用argparse定义命令行参数
|
||||
parser = argparse.ArgumentParser(description='命令行参数')
|
||||
parser.add_argument('--env', type=str, default='dev', help='运行环境')
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
# 设置环境变量,如果未设置命令行参数,默认APP_ENV为dev
|
||||
os.environ['APP_ENV'] = args.env if args.env else 'dev'
|
||||
# 读取运行环境
|
||||
run_env = os.environ.get('APP_ENV', '')
|
||||
# 运行环境未指定时默认加载.env.dev
|
||||
env_file = '.env.dev'
|
||||
# 运行环境不为空时按命令行参数加载对应.env文件
|
||||
if run_env != '':
|
||||
env_file = f'.env.{run_env}'
|
||||
|
||||
|
||||
print(f"加载配置 .env.{run_env}")
|
||||
# 加载配置
|
||||
load_dotenv(env_file)
|
||||
|
||||
get_settings = GetSettings()
|
||||
|
||||
database_settings = get_settings.get_database_settings()
|
||||
|
||||
ocr_settings = get_settings.get_ocr_settings()
|
||||
|
||||
yolov8_settings = get_settings.get_yolov8_settings()
|
||||
|
||||
|
||||
77
app/config/constant.py
Normal file
77
app/config/constant.py
Normal file
@ -0,0 +1,77 @@
|
||||
# from config.env import DataBaseConfig
|
||||
|
||||
|
||||
class CommonConstant:
|
||||
"""
|
||||
常用常量
|
||||
|
||||
WWW: www主域
|
||||
HTTP: http请求
|
||||
HTTPS: https请求
|
||||
LOOKUP_RMI: RMI远程方法调用
|
||||
LOOKUP_LDAP: LDAP远程方法调用
|
||||
LOOKUP_LDAPS: LDAPS远程方法调用
|
||||
YES: 是否为系统默认(是)
|
||||
NO: 是否为系统默认(否)
|
||||
DEPT_NORMAL: 部门正常状态
|
||||
DEPT_DISABLE: 部门停用状态
|
||||
UNIQUE: 校验是否唯一的返回标识(是)
|
||||
NOT_UNIQUE: 校验是否唯一的返回标识(否)
|
||||
"""
|
||||
|
||||
WWW = 'www.'
|
||||
HTTP = 'http://'
|
||||
HTTPS = 'https://'
|
||||
LOOKUP_RMI = 'rmi:'
|
||||
LOOKUP_LDAP = 'ldap:'
|
||||
LOOKUP_LDAPS = 'ldaps:'
|
||||
YES = 'Y'
|
||||
NO = 'N'
|
||||
DEPT_NORMAL = '0'
|
||||
DEPT_DISABLE = '1'
|
||||
UNIQUE = True
|
||||
NOT_UNIQUE = False
|
||||
|
||||
|
||||
class HttpStatusConstant:
|
||||
"""
|
||||
返回状态码
|
||||
|
||||
SUCCESS: 操作成功
|
||||
CREATED: 对象创建成功
|
||||
ACCEPTED: 请求已经被接受
|
||||
NO_CONTENT: 操作已经执行成功,但是没有返回数据
|
||||
MOVED_PERM: 资源已被移除
|
||||
SEE_OTHER: 重定向
|
||||
NOT_MODIFIED: 资源没有被修改
|
||||
BAD_REQUEST: 参数列表错误(缺少,格式不匹配)
|
||||
UNAUTHORIZED: 未授权
|
||||
FORBIDDEN: 访问受限,授权过期
|
||||
NOT_FOUND: 资源,服务未找到
|
||||
BAD_METHOD: 不允许的http方法
|
||||
CONFLICT: 资源冲突,或者资源被锁
|
||||
UNSUPPORTED_TYPE: 不支持的数据,媒体类型
|
||||
ERROR: 系统内部错误
|
||||
NOT_IMPLEMENTED: 接口未实现
|
||||
WARN: 系统警告消息
|
||||
"""
|
||||
|
||||
SUCCESS = 200
|
||||
CREATED = 201
|
||||
ACCEPTED = 202
|
||||
NO_CONTENT = 204
|
||||
MOVED_PERM = 301
|
||||
SEE_OTHER = 303
|
||||
NOT_MODIFIED = 304
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
BAD_METHOD = 405
|
||||
CONFLICT = 409
|
||||
UNSUPPORTED_TYPE = 415
|
||||
ERROR = 500
|
||||
NOT_IMPLEMENTED = 501
|
||||
WARN = 601
|
||||
|
||||
|
||||
32
app/core/database.py
Normal file
32
app/core/database.py
Normal file
@ -0,0 +1,32 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from app.config.config import database_settings
|
||||
|
||||
engine = create_async_engine(
|
||||
f"mysql+aiomysql://{database_settings.DB_USER}:{database_settings.DB_PASSWORD}@{database_settings.DB_HOST}:{database_settings.DB_PORT}/{database_settings.DB_NAME}",
|
||||
pool_size=database_settings.DB_POOL_SIZE, # 连接池常驻连接数
|
||||
max_overflow=database_settings.DB_MAX_OVERFLOW, # 池最大溢出连接数
|
||||
pool_timeout=database_settings.DB_POOL_TIMEOUT, # 获取连接超时时间(秒)
|
||||
pool_recycle=database_settings.DB_POOL_RECYCLE, # 连接回收间隔(秒)
|
||||
echo=False # 关闭SQL日志输出
|
||||
)
|
||||
|
||||
# expire_on_commit=False 禁用提交后过期对象 --> 提交后原对象仍然可以使用.
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
# ORM模型基类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# 获取数据库会话
|
||||
async def get_db():
|
||||
async with async_session() as session:
|
||||
try:
|
||||
yield session
|
||||
# await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
88
app/crud/base.py
Normal file
88
app/crud/base.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from app.core.database import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
'''
|
||||
Generic[ModelType, CreateSchemaType, UpdateSchemaType] 是 Python 类型提示中泛型编程的关键语法,用于声明一个泛型类。
|
||||
它的作用是让 CRUDBase 类具备类型参数化的能力,允许在继承或实例化时动态绑定具体类型,从而实现代码复用和类型安全
|
||||
声明 CRUDBase 类需要三个类型参数:ModelType、CreateSchemaType、UpdateSchemaType。
|
||||
这些类型参数会在类的内部方法中使用(如 get、create、update),确保类型一致性。
|
||||
|
||||
'''
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD对象与SQLAlchemy模型类一起使用
|
||||
:param model: SQLAlchemy模型类
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
# 根据id获取对象
|
||||
async def get(self, db: AsyncSession, id: Any) -> Optional[ModelType]:
|
||||
"""
|
||||
通过ID获取对象
|
||||
"""
|
||||
query = select(self.model).where(self.model.id == id)
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
# 分页查询
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
"""
|
||||
获取多个对象
|
||||
"""
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""
|
||||
创建对象
|
||||
"""
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""
|
||||
更新对象
|
||||
"""
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: Any) -> ModelType:
|
||||
"""
|
||||
删除对象
|
||||
"""
|
||||
obj = await self.get(db=db, id=id)
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
130
app/crud/event.py
Normal file
130
app/crud/event.py
Normal file
@ -0,0 +1,130 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import select, and_, or_, join
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.models import Event, Image, Temperature
|
||||
from app.schemas.event import EventUpdate, EventQuery, TestEvent
|
||||
|
||||
class CRUDEvent(CRUDBase[Event, EventUpdate, EventUpdate]):
|
||||
async def get_by_id(self, db: AsyncSession, *, event_id: str) -> Optional[Event]:
|
||||
"""根据ID获取事件"""
|
||||
query = (
|
||||
select(Event)
|
||||
.options(
|
||||
selectinload(Event.images),
|
||||
selectinload(Event.temperatures)
|
||||
)
|
||||
.where(Event.eventId == event_id)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_multi_with_query(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
query: EventQuery
|
||||
) -> List[Event]:
|
||||
"""根据查询条件获取事件列表"""
|
||||
conditions = []
|
||||
|
||||
if query.start_time:
|
||||
conditions.append(Event.insDate >= query.start_time)
|
||||
if query.end_time:
|
||||
conditions.append(Event.insDate <= query.end_time)
|
||||
if query.etypeName:
|
||||
conditions.append(Event.etypeName == query.etypeName)
|
||||
if query.area:
|
||||
conditions.append(Event.area == query.area)
|
||||
|
||||
query_stmt = (
|
||||
select(Event)
|
||||
.options(
|
||||
selectinload(Event.images),
|
||||
selectinload(Event.temperatures)
|
||||
)
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query_stmt = query_stmt.where(and_(*conditions))
|
||||
|
||||
query_stmt = query_stmt.offset(query.skip).limit(query.limit)
|
||||
|
||||
result = await db.execute(query_stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_event(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
event_id: str,
|
||||
obj_in: EventUpdate
|
||||
) -> Optional[Event]:
|
||||
"""更新事件信息"""
|
||||
event = await self.get_by_id(db, event_id=event_id)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
update_data = obj_in.model_dump()
|
||||
for field, value in update_data.items():
|
||||
setattr(event, field, value)
|
||||
|
||||
# db.add(event)
|
||||
await db.commit()
|
||||
await db.refresh(event)
|
||||
return event
|
||||
|
||||
async def delete_event(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
event_id: str
|
||||
) -> Optional[Event]:
|
||||
"""删除事件"""
|
||||
event = await self.get_by_id(db, event_id=event_id)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
await db.delete(event)
|
||||
await db.commit()
|
||||
return event
|
||||
|
||||
|
||||
async def get_test(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
|
||||
|
||||
) -> List[TestEvent]: #响应类型要写对啊
|
||||
|
||||
'''
|
||||
eventId
|
||||
number
|
||||
name
|
||||
imageUrl
|
||||
localPath
|
||||
temperature
|
||||
confidence
|
||||
createTime
|
||||
'''
|
||||
|
||||
query_stmt = (
|
||||
select(Event.eventId, Event.number, Event.name, Image.imageUrl, Image.localPath, Temperature.temperature, Temperature.confidence, Temperature.createTime)
|
||||
.select_from(Event)
|
||||
.outerjoin(Image, Event.eventId == Image.eventId)
|
||||
.outerjoin(Temperature, Image.imageId == Temperature.imageId)
|
||||
# 多个查询条件
|
||||
.where(and_(Event.etypeName=="日常巡检", True))
|
||||
)
|
||||
|
||||
result = await db.execute(query_stmt)
|
||||
|
||||
# 获取字典
|
||||
result = result.mappings().all()
|
||||
|
||||
# print(result)
|
||||
return [TestEvent(**row) for row in result]
|
||||
|
||||
|
||||
event = CRUDEvent(Event)
|
||||
18
app/dao/imageDao.py
Normal file
18
app/dao/imageDao.py
Normal file
@ -0,0 +1,18 @@
|
||||
from typing import List
|
||||
# from app.models.models import image
|
||||
from app.schemas.image import ImageBase
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.models import Image
|
||||
|
||||
class ImageDao:
|
||||
|
||||
@classmethod
|
||||
async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
|
||||
|
||||
stmt = (
|
||||
select(Image)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
return result.mappings().all()
|
||||
25
app/main.py
Normal file
25
app/main.py
Normal file
@ -0,0 +1,25 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api.v1.api import router
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="测试fastapi",
|
||||
description="测试啊",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "测试成功"}
|
||||
134
app/models/models.py
Normal file
134
app/models/models.py
Normal file
@ -0,0 +1,134 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, Index, BigInteger
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
|
||||
class Image(Base):
|
||||
__tablename__ = "image"
|
||||
|
||||
imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
|
||||
localPath = Column(String(500), comment='本地存储路径')
|
||||
type = Column(String(2), comment='图片类型, 0 ocr 图片, 1 目标检测图片')
|
||||
result = Column(String(500), comment='检测结果')
|
||||
createTime = Column(DateTime, default=datetime.now, comment='创建时间')
|
||||
updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||
|
||||
|
||||
# class Event(Base):
|
||||
# __tablename__ = "event"
|
||||
|
||||
# eventId = Column(String(50), primary_key=True)
|
||||
# tenantInfoId = Column(String(100))
|
||||
# reportEventId = Column(String(100))
|
||||
# number = Column(String(20))
|
||||
# name = Column(String(20))
|
||||
# eclassify = Column(String(5))
|
||||
# operationType = Column(String(5))
|
||||
# etype = Column(String(20))
|
||||
# etypeName = Column(String(20))
|
||||
# enTypeName = Column(String(30))
|
||||
# hkTypeName = Column(String(20))
|
||||
# reportStatus = Column(String(5))
|
||||
# results = Column(String(5))
|
||||
# insDate = Column(DateTime)
|
||||
# insDateShow = Column(DateTime)
|
||||
# updDate = Column(DateTime)
|
||||
# updDateShow = Column(DateTime)
|
||||
# fileType = Column(String(5))
|
||||
# area = Column(String(20))
|
||||
# floor = Column(String(10))
|
||||
# map = Column(String(20))
|
||||
# staffId = Column(String(40))
|
||||
# targetUserId = Column(String(40))
|
||||
# position = Column(String(100))
|
||||
# actualStaffName = Column(String(20))
|
||||
# targetStaffName = Column(String(20))
|
||||
# routeName = Column(String(20))
|
||||
# phoneAddress = Column(String(500))
|
||||
# width = Column(String(10))
|
||||
# height = Column(String(10))
|
||||
# resolution = Column(String(10))
|
||||
# originX = Column(String(20))
|
||||
# originY = Column(String(20))
|
||||
# imgList = Column(String(20))
|
||||
# robotType = Column(String(5))
|
||||
# eventFloor = Column(String(10))
|
||||
# floorName = Column(String(10))
|
||||
# coordId = Column(String(40))
|
||||
# coord = Column(String(40))
|
||||
# coordName = Column(String(30))
|
||||
# positonName = Column(String(20))
|
||||
# processingRemark = Column(String(300))
|
||||
# carId = Column(String(40))
|
||||
# parkingSpaceType = Column(String(10))
|
||||
# parkingSpaceNumber = Column(String(40))
|
||||
# carNumber = Column(String(40))
|
||||
# eno = Column(String(40))
|
||||
# instrument = Column(String(40))
|
||||
# evideo = Column(String(40))
|
||||
# createTime = Column(DateTime, default=datetime.now, comment='本地后台创建时间')
|
||||
# updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='本地后台更新时间')
|
||||
|
||||
# # 关系
|
||||
# images = relationship("Image", back_populates="event", cascade="all, delete-orphan")
|
||||
# temperatures = relationship("Temperature", back_populates="event", cascade="all, delete-orphan")
|
||||
# process_logs = relationship("ProcessLog", back_populates="event", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
# class Image(Base):
|
||||
# __tablename__ = "image"
|
||||
|
||||
# imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
|
||||
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||
# imageUrl = Column(String(500), nullable=False, comment='图片URL')
|
||||
# localPath = Column(String(500), comment='本地存储路径')
|
||||
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||
|
||||
# # 关系
|
||||
# event = relationship("Event", back_populates="images" )
|
||||
# temperatures = relationship("Temperature", back_populates="image")
|
||||
|
||||
# __table_args__ = (
|
||||
# Index('idx_image_event_id', 'eventId'),
|
||||
# )
|
||||
|
||||
|
||||
# class Temperature(Base):
|
||||
# __tablename__ = "temperature"
|
||||
|
||||
# tempId = Column(BigInteger, primary_key=True, autoincrement=True, comment='温度记录ID')
|
||||
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||
# imageId = Column(BigInteger, ForeignKey('image.imageId'), nullable=False, comment='关联图片ID')
|
||||
# temperature = Column(String(20), nullable=False, comment='温度值')
|
||||
# # status = Column(String(2), comment='温度是否正常')
|
||||
# confidence = Column(String(40), nullable=False, comment='识别置信度')
|
||||
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||
|
||||
# # 关系
|
||||
# event = relationship("Event", back_populates="temperatures")
|
||||
# image = relationship("Image", back_populates="temperatures")
|
||||
|
||||
# __table_args__ = (
|
||||
# Index('idx_temp_event_id', 'eventId'),
|
||||
# Index('idx_temp_create_time', 'createTime'),
|
||||
# )
|
||||
|
||||
|
||||
# class ProcessLog(Base):
|
||||
# __tablename__ = "process_log"
|
||||
|
||||
# logId = Column(BigInteger, primary_key=True, autoincrement=True, comment='日志ID')
|
||||
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
|
||||
# processStatus = Column(Integer, nullable=False, comment='处理状态')
|
||||
# errorMessage = Column(Text, comment='错误信息')
|
||||
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
|
||||
|
||||
# # 关系
|
||||
# event = relationship("Event", back_populates="process_logs")
|
||||
|
||||
# __table_args__ = (
|
||||
# Index('idx_log_event_id', 'eventId'),
|
||||
# Index('idx_log_create_time', 'createTime'),
|
||||
# )
|
||||
0
app/schemas/__init__.py
Normal file
0
app/schemas/__init__.py
Normal file
116
app/schemas/event.py
Normal file
116
app/schemas/event.py
Normal file
@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class ImageBase(BaseModel):
|
||||
imageUrl: str
|
||||
localPath: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class TemperatureBase(BaseModel):
|
||||
temperature: str
|
||||
confidence: str
|
||||
createTime: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class EventBase(BaseModel):
|
||||
eventId: str
|
||||
tenantInfoId: Optional[str] = None
|
||||
reportEventId: Optional[str] = None
|
||||
number: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
etypeName: Optional[str] = None
|
||||
insDate: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class EventList(EventBase):
|
||||
images: List[ImageBase] = []
|
||||
temperatures: List[TemperatureBase] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class EventDetail(EventBase):
|
||||
eclassify: Optional[str] = None
|
||||
operationType: Optional[str] = None
|
||||
etype: Optional[str] = None
|
||||
enTypeName: Optional[str] = None
|
||||
hkTypeName: Optional[str] = None
|
||||
reportStatus: Optional[str] = None
|
||||
results: Optional[str] = None
|
||||
insDateShow: Optional[datetime] = None
|
||||
updDate: Optional[datetime] = None
|
||||
updDateShow: Optional[datetime] = None
|
||||
fileType: Optional[str] = None
|
||||
area: Optional[str] = None
|
||||
floor: Optional[str] = None
|
||||
map: Optional[str] = None
|
||||
staffId: Optional[str] = None
|
||||
targetUserId: Optional[str] = None
|
||||
position: Optional[str] = None
|
||||
actualStaffName: Optional[str] = None
|
||||
targetStaffName: Optional[str] = None
|
||||
routeName: Optional[str] = None
|
||||
phoneAddress: Optional[str] = None
|
||||
width: Optional[str] = None
|
||||
height: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
originX: Optional[str] = None
|
||||
originY: Optional[str] = None
|
||||
robotType: Optional[str] = None
|
||||
eventFloor: Optional[str] = None
|
||||
floorName: Optional[str] = None
|
||||
coordId: Optional[str] = None
|
||||
coord: Optional[str] = None
|
||||
coordName: Optional[str] = None
|
||||
positonName: Optional[str] = None
|
||||
processingRemark: Optional[str] = None
|
||||
carId: Optional[str] = None
|
||||
parkingSpaceType: Optional[str] = None
|
||||
parkingSpaceNumber: Optional[str] = None
|
||||
carNumber: Optional[str] = None
|
||||
eno: Optional[str] = None
|
||||
instrument: Optional[str] = None
|
||||
evideo: Optional[str] = None
|
||||
createTime: Optional[datetime] = None
|
||||
updateTime: Optional[datetime] = None
|
||||
images: List[ImageBase] = []
|
||||
temperatures: List[TemperatureBase] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class EventUpdate(BaseModel):
|
||||
number: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
etypeName: Optional[str] = None
|
||||
area: Optional[str] = None
|
||||
position: Optional[str] = None
|
||||
processingRemark: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class EventQuery(BaseModel):
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
etypeName: Optional[str] = None
|
||||
area: Optional[str] = None
|
||||
skip: int = 0
|
||||
limit: int = 100
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TestEvent(BaseModel):
|
||||
eventId: str
|
||||
number: Optional[str] = None
|
||||
name : Optional[str] = None
|
||||
imageUrl : Optional[str] = None
|
||||
localPath: Optional[str] = None
|
||||
temperature: Optional[str] = None
|
||||
confidence : Optional[str] = None
|
||||
createTime: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
15
app/schemas/image.py
Normal file
15
app/schemas/image.py
Normal file
@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class ImageBase(BaseModel):
|
||||
|
||||
imageId: Optional[int] = None
|
||||
localPath: Optional[str] = None
|
||||
type : Optional[str] = None
|
||||
result : Optional[str] = None
|
||||
createTime : Optional[datetime] = None
|
||||
updateTime : Optional[datetime] = None
|
||||
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
11
app/schemas/ocr.py
Normal file
11
app/schemas/ocr.py
Normal file
@ -0,0 +1,11 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class ImageBase64Request(BaseModel):
|
||||
image_base64: str
|
||||
image_type: Optional[str] = "jpg" # 图片类型,如 jpg, png 等
|
||||
|
||||
# class OCRResponse(BaseModel):
|
||||
# success: bool
|
||||
# message: str
|
||||
# data: Optional[list] = None
|
||||
15
app/services/imageServices.py
Normal file
15
app/services/imageServices.py
Normal file
@ -0,0 +1,15 @@
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
from app.schemas.image import ImageBase
|
||||
from app.dao.imageDao import ImageDao
|
||||
|
||||
|
||||
class ImageService:
|
||||
|
||||
@classmethod
|
||||
async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
|
||||
|
||||
result = await ImageDao.get_image_list(db, query_model)
|
||||
|
||||
return result
|
||||
535
app/util/baiduOCR.py
Normal file
535
app/util/baiduOCR.py
Normal file
@ -0,0 +1,535 @@
|
||||
from paddleocr import PaddleOCR
|
||||
import os
|
||||
import cv2
|
||||
import yaml
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
# import math
|
||||
|
||||
|
||||
from app.config.config import OCRSettings
|
||||
|
||||
|
||||
|
||||
|
||||
class BaiduOCR:
|
||||
|
||||
def __init__(self):
|
||||
self.model = PaddleOCR(
|
||||
# 文本检测模型地址
|
||||
# text_detection_model_dir = "/home/admin-root/haotian/康达瑞贝斯机器狗/ocr_model/PP-OCRv5_server_det",
|
||||
text_detection_model_dir=OCRSettings.TEXT_DETECTION_MODEL_DIR,
|
||||
# 文本识别模型地址
|
||||
# text_recognition_model_dir = "/home/admin-root/haotian/康达瑞贝斯机器狗/ocr_model/PP-OCRv5_server_rec",
|
||||
text_recognition_model_dir=OCRSettings.TEXT_RECONGNITION_MODEL_DIR,
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
use_textline_orientation=False,
|
||||
# 多gpu有问题
|
||||
device='gpu:2'
|
||||
)
|
||||
def ocr(self, image_path:str):
|
||||
try:
|
||||
result = self.model.predict(image_path)
|
||||
except IndexError:
|
||||
return []
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
self.draw_ocr_result(image_path, result)
|
||||
|
||||
result = self.parse_result(result)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
return result[0]
|
||||
|
||||
def parse_result(self, result):
|
||||
text_list = []
|
||||
for item in result:
|
||||
rec_texts = item.get('rec_texts')
|
||||
rec_scores = item.get('rec_scores')
|
||||
|
||||
if rec_texts is None or rec_scores is None:
|
||||
continue
|
||||
|
||||
text_list.append([rec_texts, rec_scores])
|
||||
return text_list
|
||||
|
||||
def draw_ocr_result(self, image_path: str, result):
|
||||
"""
|
||||
在原图上绘制 OCR 识别结果
|
||||
"""
|
||||
# 读取原图
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
return
|
||||
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 转换为 PIL Image 以支持中文显示
|
||||
pil_img = Image.fromarray(image)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 尝试加载中文字体(根据系统调整路径)
|
||||
try:
|
||||
# font = ImageFont.truetype("simhei.ttf", 20) # Windows
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 20) # Linux
|
||||
# font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20) # macOS
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
# 绘制每个检测框和文本
|
||||
for idx, item in enumerate(result):
|
||||
box = item['rec_boxes'] # shape=(1, 4)
|
||||
text = item['rec_texts']
|
||||
score = item['rec_scores']
|
||||
|
||||
# 处理可能的列表类型
|
||||
if isinstance(text, list):
|
||||
text = text[0] if len(text) > 0 else ""
|
||||
if isinstance(score, list):
|
||||
score = score[0] if len(score) > 0 else 0.0
|
||||
|
||||
# 提取坐标
|
||||
if box is None:
|
||||
continue
|
||||
|
||||
box = box.flatten()
|
||||
if len(box) < 4:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
|
||||
|
||||
# 绘制矩形框
|
||||
draw.rectangle([(x1, y1), (x2, y2)], outline=(255, 0, 0), width=2)
|
||||
|
||||
# 绘制文本和置信度
|
||||
text_position = (x1, max(0, y1 - 25))
|
||||
text_with_score = f"{text} ({float(score):.2f})"
|
||||
draw.text(text_position, text_with_score, fill=(0, 255, 0), font=font)
|
||||
|
||||
# 转换回 OpenCV 格式
|
||||
result_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 获取文件夹路径并创建
|
||||
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
||||
|
||||
cv2.imwrite(image_path, result_img)
|
||||
|
||||
# return result_img
|
||||
|
||||
|
||||
|
||||
|
||||
class BaiduOCRONNX:
|
||||
def __init__(self, det_model_path=OCRSettings.TEXT_DETECTION_MODEL_ONNX_DIR, rec_model_path=OCRSettings.TEXT_RECONGNITION_MODEL_ONNX_DIR):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
det_model_path: 检测模型路径 (det.onnx)
|
||||
rec_model_path: 识别模型路径 (rec.onnx)
|
||||
"""
|
||||
# 初始化检测模型
|
||||
self.det_session = ort.InferenceSession(det_model_path)
|
||||
self.det_input_name = self.det_session.get_inputs()[0].name
|
||||
|
||||
# 初始化识别模型
|
||||
self.rec_session = ort.InferenceSession(rec_model_path)
|
||||
self.rec_input_name = self.rec_session.get_inputs()[0].name
|
||||
|
||||
# 字符集(根据您的模型调整)
|
||||
# self.character = ['blank', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+',
|
||||
# ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8',
|
||||
# '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E',
|
||||
# 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
|
||||
# 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
|
||||
# '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
|
||||
# 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y',
|
||||
# 'z', '{', '|', '}', '~'] + [chr(i) for i in range(19968, 40870)] # 中文字符
|
||||
self.character = self.get_dict()
|
||||
|
||||
if self.character is None:
|
||||
raise ValueError('请检查字典文件是否存在!')
|
||||
|
||||
def get_dict(self, dict_path='./dict.yaml'):
|
||||
"""
|
||||
加载字典
|
||||
"""
|
||||
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||
dict_rec = yaml.safe_load(f)
|
||||
return dict_rec.get('character_dict', [])
|
||||
|
||||
def resize_norm_img_det(self, img, input_shape=(640, 640)):
|
||||
"""
|
||||
检测模型的图像预处理 - 固定输入形状 [1, 3, 640, 640]
|
||||
"""
|
||||
h, w, _ = img.shape
|
||||
target_h, target_w = input_shape
|
||||
|
||||
# 计算缩放比例 - 保持宽高比
|
||||
ratio_h = target_h / h
|
||||
ratio_w = target_w / w
|
||||
ratio = min(ratio_h, ratio_w)
|
||||
|
||||
# 计算缩放后的尺寸
|
||||
new_h = int(h * ratio)
|
||||
new_w = int(w * ratio)
|
||||
|
||||
# 调整图像大小
|
||||
resized_img = cv2.resize(img, (new_w, new_h))
|
||||
|
||||
# 创建目标尺寸的图像,用灰色填充
|
||||
padded_img = np.ones((target_h, target_w, 3), dtype=np.float32) * 114.0 # 直接用float32
|
||||
|
||||
# 计算居中位置
|
||||
top = (target_h - new_h) // 2
|
||||
left = (target_w - new_w) // 2
|
||||
|
||||
# 将缩放后的图像放到居中位置
|
||||
padded_img[top:top+new_h, left:left+new_w] = resized_img.astype(np.float32)
|
||||
|
||||
# 归一化
|
||||
img = (padded_img / 255.0 - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
img = img.transpose(2, 0, 1).astype(np.float32)
|
||||
img = np.expand_dims(img, axis=0).astype(np.float32)
|
||||
|
||||
return img, ratio, (top, left)
|
||||
|
||||
def post_process_det(self, dt_boxes, ratio, padding_info, ori_shape):
|
||||
"""
|
||||
检测结果后处理 - 适配固定输入形状
|
||||
"""
|
||||
if dt_boxes is None:
|
||||
return None
|
||||
|
||||
ori_h, ori_w = ori_shape
|
||||
top, left = padding_info
|
||||
|
||||
# 将坐标从模型输出空间转换回原图空间
|
||||
dt_boxes[:, :, 0] = (dt_boxes[:, :, 0] - left) / ratio
|
||||
dt_boxes[:, :, 1] = (dt_boxes[:, :, 1] - top) / ratio
|
||||
|
||||
# 裁剪到原图范围内
|
||||
dt_boxes[:, :, 0] = np.clip(dt_boxes[:, :, 0], 0, ori_w)
|
||||
dt_boxes[:, :, 1] = np.clip(dt_boxes[:, :, 1], 0, ori_h)
|
||||
|
||||
return dt_boxes
|
||||
|
||||
def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height, max_candidates=1000, box_thresh=0.6):
|
||||
"""
|
||||
从位图中提取文本框
|
||||
"""
|
||||
bitmap = bitmap.astype(np.uint8)
|
||||
height, width = bitmap.shape
|
||||
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
num_contours = min(len(contours), max_candidates)
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
for i in range(num_contours):
|
||||
contour = contours[i]
|
||||
points, sside = self.get_mini_boxes(contour)
|
||||
if sside < 5:
|
||||
continue
|
||||
|
||||
points = np.array(points)
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if box_thresh > score:
|
||||
continue
|
||||
|
||||
# 扩展box
|
||||
box = self.unclip(points, 1.5).reshape(-1, 1, 2)
|
||||
box, sside = self.get_mini_boxes(box)
|
||||
if sside < 5 + 2:
|
||||
continue
|
||||
|
||||
box = np.array(box)
|
||||
box[:, 0] = np.clip(box[:, 0] / width * dest_width, 0, dest_width)
|
||||
box[:, 1] = np.clip(box[:, 1] / height * dest_height, 0, dest_height)
|
||||
|
||||
boxes.append(box.astype(np.int16))
|
||||
scores.append(score)
|
||||
|
||||
return np.array(boxes), scores
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
"""获取最小外接矩形"""
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_1 = 0
|
||||
index_4 = 1
|
||||
else:
|
||||
index_1 = 1
|
||||
index_4 = 0
|
||||
|
||||
if points[3][1] > points[2][1]:
|
||||
index_2 = 2
|
||||
index_3 = 3
|
||||
else:
|
||||
index_2 = 3
|
||||
index_3 = 2
|
||||
|
||||
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
|
||||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
"""快速计算box得分"""
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
|
||||
xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
|
||||
ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
|
||||
ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def unclip(self, box, unclip_ratio):
|
||||
"""扩展文本框"""
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
|
||||
offset = pyclipper.PyclipperOffset()
|
||||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
expanded = offset.Execute(distance)
|
||||
|
||||
if len(expanded) == 0:
|
||||
return box
|
||||
else:
|
||||
return np.array(expanded[0])
|
||||
|
||||
def resize_norm_img_rec(self, img, input_shape=(320, 48)):
|
||||
"""
|
||||
识别模型的图像预处理 - 固定输入形状 [1, 3, 48, 320]
|
||||
"""
|
||||
target_w, target_h = input_shape # 注意:宽度在前
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
# 计算缩放比例,保持宽高比
|
||||
ratio_h = target_h / h
|
||||
ratio_w = target_w / w
|
||||
ratio = min(ratio_h, ratio_w)
|
||||
|
||||
# 计算缩放后的尺寸
|
||||
new_h = int(h * ratio)
|
||||
new_w = int(w * ratio)
|
||||
|
||||
# 调整图像大小
|
||||
resized_image = cv2.resize(img, (new_w, new_h))
|
||||
|
||||
# 创建目标尺寸的图像,用黑色填充
|
||||
padded_image = np.zeros((target_h, target_w, 3), dtype=np.float32) # 直接用float32
|
||||
|
||||
# 将缩放后的图像放到左上角(识别模型通常左对齐)
|
||||
padded_image[:new_h, :new_w] = resized_image.astype(np.float32)
|
||||
|
||||
# 归一化
|
||||
# padded_image = (padded_image / 255.0 - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
|
||||
# 不缩放反而会将识别结果再移后一个??
|
||||
padded_image = (padded_image / 255.0).astype(np.float32)
|
||||
padded_image = padded_image.transpose((2, 0, 1)).astype(np.float32)
|
||||
|
||||
return np.expand_dims(padded_image, axis=0).astype(np.float32)
|
||||
|
||||
def decode_rec_result(self, preds_prob):
|
||||
"""
|
||||
解码识别结果
|
||||
"""
|
||||
preds_idx = np.argmax(preds_prob, axis=1)
|
||||
preds_prob = np.max(preds_prob, axis=1)
|
||||
|
||||
# CTC解码
|
||||
last_idx = 0
|
||||
preds_text = []
|
||||
preds_conf = []
|
||||
|
||||
for i, idx in enumerate(preds_idx):
|
||||
if idx != last_idx and idx != 0: # 0是blank
|
||||
if idx < len(self.character):
|
||||
preds_text.append(self.character[idx])
|
||||
preds_conf.append(preds_prob[i])
|
||||
last_idx = idx
|
||||
|
||||
text = ''.join(preds_text)
|
||||
conf = np.mean(preds_conf) if preds_conf else 0.0
|
||||
|
||||
return text, conf
|
||||
|
||||
def detect_text(self, image):
|
||||
"""
|
||||
文本检测 - 适配固定输入形状 [1, 3, 640, 640]
|
||||
"""
|
||||
ori_h, ori_w = image.shape[:2]
|
||||
|
||||
# 预处理
|
||||
det_img, ratio, padding_info = self.resize_norm_img_det(image)
|
||||
|
||||
# 推理
|
||||
det_output = self.det_session.run(None, {self.det_input_name: det_img})[0]
|
||||
|
||||
# 后处理
|
||||
mask = det_output[0, 0, :, :]
|
||||
threshold = 0.3
|
||||
bitmap = (mask > threshold).astype(np.uint8) * 255
|
||||
|
||||
# 从位图中提取文本框(坐标是在640x640空间中的)
|
||||
boxes, scores = self.boxes_from_bitmap(mask, bitmap, 640, 640)
|
||||
|
||||
# 将坐标转换回原图空间
|
||||
if len(boxes) > 0:
|
||||
boxes = self.post_process_det(boxes, ratio, padding_info, (ori_h, ori_w))
|
||||
|
||||
return boxes, scores
|
||||
|
||||
def recognize_text(self, image):
|
||||
"""
|
||||
文本识别
|
||||
"""
|
||||
# 预处理
|
||||
rec_img = self.resize_norm_img_rec(image)
|
||||
|
||||
# 推理
|
||||
rec_output = self.rec_session.run(None, {self.rec_input_name: rec_img})[0]
|
||||
|
||||
# 解码
|
||||
text, conf = self.decode_rec_result(rec_output[0])
|
||||
|
||||
return text, conf
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
"""
|
||||
根据四个点坐标裁剪并矫正图像
|
||||
"""
|
||||
img_crop_width = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[1]),
|
||||
np.linalg.norm(points[2] - points[3])))
|
||||
img_crop_height = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[3]),
|
||||
np.linalg.norm(points[1] - points[2])))
|
||||
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height]])
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
return dst_img
|
||||
|
||||
def ocr(self, image_path):
|
||||
"""
|
||||
完整的OCR流程
|
||||
"""
|
||||
# 读取图像
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
return []
|
||||
|
||||
# 1. 文本检测
|
||||
dt_boxes, scores = self.detect_text(image)
|
||||
|
||||
if dt_boxes is None or len(dt_boxes) == 0:
|
||||
return []
|
||||
|
||||
# 2. 文本识别
|
||||
ocr_results = []
|
||||
|
||||
text_list = []
|
||||
confidence_list = []
|
||||
for i, box in enumerate(dt_boxes):
|
||||
# 裁剪文本区域
|
||||
box_points = box.astype(np.float32)
|
||||
crop_img = self.get_rotate_crop_image(image, box_points)
|
||||
if crop_img.size == 0:
|
||||
continue
|
||||
|
||||
# 识别文本
|
||||
text, conf = self.recognize_text(crop_img)
|
||||
|
||||
if conf > 0.5: # 置信度过滤
|
||||
ocr_results.append({
|
||||
'text': text,
|
||||
'confidence': conf,
|
||||
'box': box.tolist(),
|
||||
'score': scores[i] if i < len(scores) else 0.0
|
||||
})
|
||||
|
||||
text_list.append(text)
|
||||
confidence_list.append(round(conf.item(), 2))
|
||||
|
||||
# return ocr_results
|
||||
return [text_list, confidence_list]
|
||||
|
||||
|
||||
# 使用示例
|
||||
def main():
|
||||
# 初始化OCR
|
||||
# ocr = PaddleOCRONNX('/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape.onnx', '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape.onnx')
|
||||
|
||||
ocr = BaiduOCRONNX('/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx', '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx')
|
||||
|
||||
# 执行OCR
|
||||
image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/3aee64cc1f90d93a5a45979f7b17cb4b_frame_001460.jpg'
|
||||
results = ocr.ocr(image_path)
|
||||
|
||||
# 打印结果
|
||||
for result in results:
|
||||
print(f"文本: {result['text']}")
|
||||
print(f"置信度: {result['confidence']:.3f}")
|
||||
print(f"检测得分: {result['score']:.3f}")
|
||||
print(f"坐标: {result['box']}")
|
||||
print("-" * 50)
|
||||
|
||||
# 可视化结果
|
||||
visualize_results(image_path, results)
|
||||
|
||||
def visualize_results(image_path, results):
|
||||
"""
|
||||
可视化OCR结果
|
||||
"""
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
for result in results:
|
||||
box = np.array(result['box'], dtype=np.int32)
|
||||
cv2.polylines(image, [box], True, (0, 255, 0), 2)
|
||||
|
||||
# 在框上方显示文本
|
||||
cv2.putText(image, result['text'],
|
||||
(box[0][0], box[0][1] - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
|
||||
|
||||
cv2.imwrite('result_shape_20250815.jpg', image)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ocr = BaiduOCR()
|
||||
print(ocr.ocr(""))
|
||||
266
app/util/responseHttp.py
Normal file
266
app/util/responseHttp.py
Normal file
@ -0,0 +1,266 @@
|
||||
from datetime import datetime
|
||||
from fastapi import status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.background import BackgroundTask
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
from app.config.constant import HttpStatusConstant
|
||||
|
||||
|
||||
class ResponseUtil:
|
||||
"""
|
||||
响应工具类
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def success(
|
||||
cls,
|
||||
msg: str = '操作成功',
|
||||
data: Optional[Any] = None,
|
||||
rows: Optional[Any] = None,
|
||||
dict_content: Optional[Dict] = None,
|
||||
model_content: Optional[BaseModel] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
成功响应方法
|
||||
|
||||
:param msg: 可选,自定义成功响应信息
|
||||
:param data: 可选,成功响应结果中属性为data的值
|
||||
:param rows: 可选,成功响应结果中属性为rows的值
|
||||
:param dict_content: 可选,dict类型,成功响应结果中自定义属性的值
|
||||
:param model_content: 可选,BaseModel类型,成功响应结果中自定义属性的值
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 成功响应结果
|
||||
"""
|
||||
result = {'code': HttpStatusConstant.SUCCESS, 'msg': msg}
|
||||
|
||||
if data is not None:
|
||||
result['data'] = data
|
||||
if rows is not None:
|
||||
result['rows'] = rows
|
||||
if dict_content is not None:
|
||||
result.update(dict_content)
|
||||
if model_content is not None:
|
||||
result.update(model_content.model_dump(by_alias=True))
|
||||
|
||||
result.update({'success': True, 'time': datetime.now()})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder(result),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def failure(
|
||||
cls,
|
||||
msg: str = '操作失败',
|
||||
data: Optional[Any] = None,
|
||||
rows: Optional[Any] = None,
|
||||
dict_content: Optional[Dict] = None,
|
||||
model_content: Optional[BaseModel] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
失败响应方法
|
||||
|
||||
:param msg: 可选,自定义失败响应信息
|
||||
:param data: 可选,失败响应结果中属性为data的值
|
||||
:param rows: 可选,失败响应结果中属性为rows的值
|
||||
:param dict_content: 可选,dict类型,失败响应结果中自定义属性的值
|
||||
:param model_content: 可选,BaseModel类型,失败响应结果中自定义属性的值
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 失败响应结果
|
||||
"""
|
||||
result = {'code': HttpStatusConstant.WARN, 'msg': msg}
|
||||
|
||||
if data is not None:
|
||||
result['data'] = data
|
||||
if rows is not None:
|
||||
result['rows'] = rows
|
||||
if dict_content is not None:
|
||||
result.update(dict_content)
|
||||
if model_content is not None:
|
||||
result.update(model_content.model_dump(by_alias=True))
|
||||
|
||||
result.update({'success': False, 'time': datetime.now()})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder(result),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unauthorized(
|
||||
cls,
|
||||
msg: str = '登录信息已过期,访问系统资源失败',
|
||||
data: Optional[Any] = None,
|
||||
rows: Optional[Any] = None,
|
||||
dict_content: Optional[Dict] = None,
|
||||
model_content: Optional[BaseModel] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
未认证响应方法
|
||||
|
||||
:param msg: 可选,自定义未认证响应信息
|
||||
:param data: 可选,未认证响应结果中属性为data的值
|
||||
:param rows: 可选,未认证响应结果中属性为rows的值
|
||||
:param dict_content: 可选,dict类型,未认证响应结果中自定义属性的值
|
||||
:param model_content: 可选,BaseModel类型,未认证响应结果中自定义属性的值
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 未认证响应结果
|
||||
"""
|
||||
result = {'code': HttpStatusConstant.UNAUTHORIZED, 'msg': msg}
|
||||
|
||||
if data is not None:
|
||||
result['data'] = data
|
||||
if rows is not None:
|
||||
result['rows'] = rows
|
||||
if dict_content is not None:
|
||||
result.update(dict_content)
|
||||
if model_content is not None:
|
||||
result.update(model_content.model_dump(by_alias=True))
|
||||
|
||||
result.update({'success': False, 'time': datetime.now()})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder(result),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def forbidden(
|
||||
cls,
|
||||
msg: str = '该用户无此接口权限',
|
||||
data: Optional[Any] = None,
|
||||
rows: Optional[Any] = None,
|
||||
dict_content: Optional[Dict] = None,
|
||||
model_content: Optional[BaseModel] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
未授权响应方法
|
||||
|
||||
:param msg: 可选,自定义未授权响应信息
|
||||
:param data: 可选,未授权响应结果中属性为data的值
|
||||
:param rows: 可选,未授权响应结果中属性为rows的值
|
||||
:param dict_content: 可选,dict类型,未授权响应结果中自定义属性的值
|
||||
:param model_content: 可选,BaseModel类型,未授权响应结果中自定义属性的值
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 未授权响应结果
|
||||
"""
|
||||
result = {'code': HttpStatusConstant.FORBIDDEN, 'msg': msg}
|
||||
|
||||
if data is not None:
|
||||
result['data'] = data
|
||||
if rows is not None:
|
||||
result['rows'] = rows
|
||||
if dict_content is not None:
|
||||
result.update(dict_content)
|
||||
if model_content is not None:
|
||||
result.update(model_content.model_dump(by_alias=True))
|
||||
|
||||
result.update({'success': False, 'time': datetime.now()})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder(result),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def error(
|
||||
cls,
|
||||
msg: str = '接口异常',
|
||||
data: Optional[Any] = None,
|
||||
rows: Optional[Any] = None,
|
||||
dict_content: Optional[Dict] = None,
|
||||
model_content: Optional[BaseModel] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
错误响应方法
|
||||
|
||||
:param msg: 可选,自定义错误响应信息
|
||||
:param data: 可选,错误响应结果中属性为data的值
|
||||
:param rows: 可选,错误响应结果中属性为rows的值
|
||||
:param dict_content: 可选,dict类型,错误响应结果中自定义属性的值
|
||||
:param model_content: 可选,BaseModel类型,错误响应结果中自定义属性的值
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 错误响应结果
|
||||
"""
|
||||
result = {'code': HttpStatusConstant.ERROR, 'msg': msg}
|
||||
|
||||
if data is not None:
|
||||
result['data'] = data
|
||||
if rows is not None:
|
||||
result['rows'] = rows
|
||||
if dict_content is not None:
|
||||
result.update(dict_content)
|
||||
if model_content is not None:
|
||||
result.update(model_content.model_dump(by_alias=True))
|
||||
|
||||
result.update({'success': False, 'time': datetime.now()})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder(result),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def streaming(
|
||||
cls,
|
||||
*,
|
||||
data: Any = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
media_type: Optional[str] = None,
|
||||
background: Optional[BackgroundTask] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
流式响应方法
|
||||
|
||||
:param data: 流式传输的内容
|
||||
:param headers: 可选,响应头信息
|
||||
:param media_type: 可选,响应结果媒体类型
|
||||
:param background: 可选,响应返回后执行的后台任务
|
||||
:return: 流式响应结果
|
||||
"""
|
||||
return StreamingResponse(
|
||||
status_code=status.HTTP_200_OK, content=data, headers=headers, media_type=media_type, background=background
|
||||
)
|
||||
470
app/util/yolov8Obj.py
Normal file
470
app/util/yolov8Obj.py
Normal file
@ -0,0 +1,470 @@
|
||||
from ultralytics import YOLO
|
||||
# from rknn.api import RKNN
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import time
|
||||
|
||||
from app.config.config import yolov8_settings
|
||||
|
||||
|
||||
|
||||
|
||||
class Yolov8Obj:
|
||||
|
||||
def __init__(self):
|
||||
self.model = YOLO(yolov8_settings.YOLOV8_MODEL_DIR)
|
||||
|
||||
def detect(self, image_path):
|
||||
result = self.model.predict(image_path)
|
||||
boxes = result[0].boxes
|
||||
|
||||
cls = boxes.cls.tolist()
|
||||
conf = boxes.conf.tolist()
|
||||
coords = boxes.xyxy.tolist()
|
||||
|
||||
return cls, conf, coords
|
||||
|
||||
|
||||
class YOLOv8ONNX:
|
||||
def __init__(self, model_path=yolov8_settings.YOLOV8_MODEL_ONNX_DIRS, conf_threshold=0.5, iou_threshold=0.4):
|
||||
"""
|
||||
初始化YOLOv8 ONNX模型
|
||||
|
||||
Args:
|
||||
model_path: ONNX模型文件路径
|
||||
conf_threshold: 置信度阈值
|
||||
iou_threshold: NMS IoU阈值
|
||||
"""
|
||||
self.conf_threshold = conf_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
|
||||
# 创建ONNX Runtime会话
|
||||
self.session = ort.InferenceSession(model_path)
|
||||
|
||||
# 获取模型输入输出信息
|
||||
self.input_name = self.session.get_inputs()[0].name
|
||||
self.output_name = self.session.get_outputs()[0].name
|
||||
|
||||
# 获取输入尺寸
|
||||
input_shape = self.session.get_inputs()[0].shape
|
||||
self.input_height = input_shape[2]
|
||||
self.input_width = input_shape[3]
|
||||
|
||||
def preprocess(self, image):
|
||||
"""
|
||||
预处理图像
|
||||
|
||||
Args:
|
||||
image: 输入图像 (BGR格式)
|
||||
|
||||
Returns:
|
||||
preprocessed_image: 预处理后的图像
|
||||
scale_ratio: 缩放比例
|
||||
pad_info: 填充信息 (pad_x, pad_y)
|
||||
"""
|
||||
# 获取原图尺寸
|
||||
h, w = image.shape[:2]
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(self.input_height / h, self.input_width / w)
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
|
||||
# 等比例缩放
|
||||
resized_image = cv2.resize(image, (new_w, new_h))
|
||||
|
||||
# 计算填充
|
||||
pad_x = (self.input_width - new_w) // 2
|
||||
pad_y = (self.input_height - new_h) // 2
|
||||
|
||||
# 创建填充后的图像
|
||||
padded_image = np.full((self.input_height, self.input_width, 3), 114, dtype=np.uint8)
|
||||
padded_image[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized_image
|
||||
|
||||
# 转换为模型输入格式: BGR -> RGB, HWC -> CHW, 归一化
|
||||
input_image = padded_image[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||
input_image = np.expand_dims(input_image, axis=0) # 添加batch维度
|
||||
|
||||
return input_image, scale, (pad_x, pad_y)
|
||||
|
||||
def postprocess(self, outputs, scale, pad_info, original_shape):
|
||||
"""
|
||||
后处理模型输出 - 针对YOLOv8格式优化
|
||||
|
||||
Args:
|
||||
outputs: 模型原始输出
|
||||
scale: 图像缩放比例
|
||||
pad_info: 填充信息
|
||||
original_shape: 原图尺寸
|
||||
|
||||
Returns:
|
||||
boxes: 检测框 [[x1, y1, x2, y2], ...]
|
||||
scores: 置信度分数
|
||||
class_ids: 类别ID
|
||||
"""
|
||||
predictions = outputs[0] # 形状通常是: [1, 6, 8400] 或 [1, num_classes+4, num_boxes]
|
||||
|
||||
# YOLOv8输出格式: [batch, 4+num_classes, num_boxes]
|
||||
# 需要转置为 [batch, num_boxes, 4+num_classes]
|
||||
if len(predictions.shape) == 3:
|
||||
predictions = predictions.transpose(0, 2, 1) # [1, num_boxes, 4+num_classes]
|
||||
|
||||
predictions = predictions[0] # 移除batch维度: [num_boxes, 4+num_classes]
|
||||
|
||||
# 打印调试信息
|
||||
# print(f"预测输出形状: {predictions.shape}")
|
||||
# print(f"前几个预测值: {predictions[:5]}")
|
||||
|
||||
# 分离坐标和分类信息
|
||||
boxes = predictions[:, :4] # [x_center, y_center, width, height]
|
||||
scores = predictions[:, 4:] # 类别置信度 [num_boxes, num_classes]
|
||||
|
||||
# print(f"检测框形状: {boxes.shape}")
|
||||
# print(f"分数形状: {scores.shape}")
|
||||
|
||||
# 获取最高置信度和对应类别
|
||||
class_ids = np.argmax(scores, axis=1)
|
||||
confidences = np.max(scores, axis=1)
|
||||
|
||||
# print(f"置信度范围: {confidences.min():.4f} - {confidences.max():.4f}")
|
||||
# print(f"检测到的类别: {np.unique(class_ids)}")
|
||||
|
||||
# 过滤低置信度检测
|
||||
valid_indices = confidences > self.conf_threshold
|
||||
valid_boxes = boxes[valid_indices]
|
||||
valid_confidences = confidences[valid_indices]
|
||||
valid_class_ids = class_ids[valid_indices]
|
||||
|
||||
# print(f"过滤后检测数量: {len(valid_boxes)}")
|
||||
|
||||
if len(valid_boxes) == 0:
|
||||
return [], [], []
|
||||
|
||||
# 转换为 [x1, y1, x2, y2] 格式
|
||||
x_center, y_center, width, height = valid_boxes[:, 0], valid_boxes[:, 1], valid_boxes[:, 2], valid_boxes[:, 3]
|
||||
x1 = x_center - width / 2
|
||||
y1 = y_center - height / 2
|
||||
x2 = x_center + width / 2
|
||||
y2 = y_center + height / 2
|
||||
|
||||
converted_boxes = np.stack([x1, y1, x2, y2], axis=1)
|
||||
|
||||
# 坐标反变换到原图
|
||||
pad_x, pad_y = pad_info
|
||||
converted_boxes[:, [0, 2]] = (converted_boxes[:, [0, 2]] - pad_x) / scale
|
||||
converted_boxes[:, [1, 3]] = (converted_boxes[:, [1, 3]] - pad_y) / scale
|
||||
|
||||
# 限制坐标范围
|
||||
h, w = original_shape[:2]
|
||||
converted_boxes[:, [0, 2]] = np.clip(converted_boxes[:, [0, 2]], 0, w)
|
||||
converted_boxes[:, [1, 3]] = np.clip(converted_boxes[:, [1, 3]], 0, h)
|
||||
|
||||
# 非极大值抑制 (NMS)
|
||||
indices = cv2.dnn.NMSBoxes(
|
||||
converted_boxes.tolist(),
|
||||
valid_confidences.tolist(),
|
||||
self.conf_threshold,
|
||||
self.iou_threshold
|
||||
)
|
||||
|
||||
if len(indices) > 0:
|
||||
indices = indices.flatten()
|
||||
return converted_boxes[indices], valid_confidences[indices], valid_class_ids[indices]
|
||||
|
||||
return [], [], []
|
||||
|
||||
def detect(self, image):
|
||||
"""
|
||||
对图像进行目标检测
|
||||
|
||||
Args:
|
||||
image: 输入图像 (BGR格式)
|
||||
|
||||
Returns:
|
||||
boxes: 检测框列表
|
||||
scores: 置信度分数列表
|
||||
class_ids: 类别ID列表
|
||||
"""
|
||||
# 预处理
|
||||
input_image, scale, pad_info = self.preprocess(image)
|
||||
|
||||
# 推理
|
||||
outputs = self.session.run([self.output_name], {self.input_name: input_image})
|
||||
|
||||
# 后处理
|
||||
boxes, scores, class_ids = self.postprocess(outputs, scale, pad_info, image.shape)
|
||||
|
||||
return boxes, scores, class_ids
|
||||
|
||||
|
||||
class YOLOv8RKNN:
|
||||
def __init__(self, model_path, input_size=(640, 640)):
|
||||
self.model_path = model_path
|
||||
self.input_size = input_size
|
||||
self.rknn = RKNN()
|
||||
|
||||
# 类别名称,根据你的2个类别修改
|
||||
self.class_names = ['class1', 'class2'] # 请替换为你实际的类别名称
|
||||
|
||||
# 初始化模型
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""加载RKNN模型"""
|
||||
print("Loading RKNN model...")
|
||||
ret = self.rknn.load_rknn(self.model_path)
|
||||
if ret != 0:
|
||||
print("Load RKNN model failed!")
|
||||
return False
|
||||
|
||||
# 初始化运行时环境(在RK3588设备上运行)
|
||||
print("Init RKNN runtime...")
|
||||
ret = self.rknn.init_runtime(target='rk3588', device_id=None, perf_debug=False, eval_mem=False)
|
||||
if ret != 0:
|
||||
print("Init RKNN runtime failed!")
|
||||
return False
|
||||
|
||||
print("RKNN model loaded successfully!")
|
||||
return True
|
||||
|
||||
def preprocess(self, image):
|
||||
"""图像预处理"""
|
||||
# 获取原始图像尺寸
|
||||
self.orig_height, self.orig_width = image.shape[:2]
|
||||
|
||||
# Resize到模型输入尺寸,保持宽高比
|
||||
scale = min(self.input_size[0]/self.orig_width, self.input_size[1]/self.orig_height)
|
||||
new_width = int(self.orig_width * scale)
|
||||
new_height = int(self.orig_height * scale)
|
||||
|
||||
# 缩放图像
|
||||
resized = cv2.resize(image, (new_width, new_height))
|
||||
|
||||
# 创建输入图像(填充到目标尺寸)
|
||||
input_image = np.full((self.input_size[1], self.input_size[0], 3), 114, dtype=np.uint8)
|
||||
|
||||
# 计算填充位置(居中)
|
||||
y_offset = (self.input_size[1] - new_height) // 2
|
||||
x_offset = (self.input_size[0] - new_width) // 2
|
||||
|
||||
# 将缩放后的图像放到中心位置
|
||||
input_image[y_offset:y_offset+new_height, x_offset:x_offset+new_width] = resized
|
||||
|
||||
# 保存缩放参数用于后处理
|
||||
self.scale = scale
|
||||
self.x_offset = x_offset
|
||||
self.y_offset = y_offset
|
||||
|
||||
return input_image
|
||||
|
||||
def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
|
||||
"""后处理:解析YOLO输出并进行NMS"""
|
||||
# YOLOv8输出格式: [batch, 84, 8400] (2个类别: 4+2+80=84,但实际只有6维)
|
||||
# 对于2类别: [x, y, w, h, conf_class1, conf_class2]
|
||||
predictions = outputs[0][0] # 移除batch维度
|
||||
|
||||
# 转置为 [8400, 6] 格式
|
||||
predictions = predictions.transpose()
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
class_ids = []
|
||||
|
||||
for detection in predictions:
|
||||
# 提取坐标和类别置信度
|
||||
x, y, w, h = detection[:4]
|
||||
class_confs = detection[4:6] # 2个类别的置信度
|
||||
|
||||
# 找到最大置信度的类别
|
||||
class_id = np.argmax(class_confs)
|
||||
max_conf = class_confs[class_id]
|
||||
|
||||
if max_conf >= conf_threshold:
|
||||
# 转换坐标格式 (中心点 -> 左上角)
|
||||
x1 = x - w/2
|
||||
y1 = y - h/2
|
||||
x2 = x + w/2
|
||||
y2 = y + h/2
|
||||
|
||||
# 将坐标映射回原图尺寸
|
||||
x1 = (x1 - self.x_offset) / self.scale
|
||||
y1 = (y1 - self.y_offset) / self.scale
|
||||
x2 = (x2 - self.x_offset) / self.scale
|
||||
y2 = (y2 - self.y_offset) / self.scale
|
||||
|
||||
# 限制在图像边界内
|
||||
x1 = max(0, min(x1, self.orig_width))
|
||||
y1 = max(0, min(y1, self.orig_height))
|
||||
x2 = max(0, min(x2, self.orig_width))
|
||||
y2 = max(0, min(y2, self.orig_height))
|
||||
|
||||
boxes.append([x1, y1, x2, y2])
|
||||
scores.append(max_conf)
|
||||
class_ids.append(class_id)
|
||||
|
||||
# 执行NMS
|
||||
if len(boxes) > 0:
|
||||
boxes = np.array(boxes)
|
||||
scores = np.array(scores)
|
||||
class_ids = np.array(class_ids)
|
||||
|
||||
# OpenCV NMS
|
||||
indices = cv2.dnn.NMSBoxes(boxes, scores, conf_threshold, nms_threshold)
|
||||
|
||||
if len(indices) > 0:
|
||||
indices = indices.flatten()
|
||||
return boxes[indices], scores[indices], class_ids[indices]
|
||||
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
|
||||
"""执行检测"""
|
||||
# 预处理
|
||||
input_image = self.preprocess(image)
|
||||
|
||||
# 推理
|
||||
start_time = time.time()
|
||||
outputs = self.rknn.inference(inputs=[input_image])
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
# 后处理
|
||||
boxes, scores, class_ids = self.postprocess(outputs, conf_threshold, nms_threshold)
|
||||
|
||||
return boxes, scores, class_ids, inference_time
|
||||
|
||||
def draw_detections(self, image, boxes, scores, class_ids):
|
||||
"""在图像上绘制检测结果"""
|
||||
for i in range(len(boxes)):
|
||||
x1, y1, x2, y2 = boxes[i].astype(int)
|
||||
score = scores[i]
|
||||
class_id = int(class_ids[i])
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# 绘制标签
|
||||
label = f"{self.class_names[class_id]}: {score:.2f}"
|
||||
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||
cv2.rectangle(image, (x1, y1-label_size[1]-10),
|
||||
(x1+label_size[0], y1), (0, 255, 0), -1)
|
||||
cv2.putText(image, label, (x1, y1-5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||
|
||||
return image
|
||||
|
||||
def release(self):
|
||||
"""释放资源"""
|
||||
if self.rknn:
|
||||
self.rknn.release()
|
||||
|
||||
def main():
|
||||
# 初始化检测器
|
||||
model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
|
||||
detector = YOLOv8RKNN(model_path)
|
||||
|
||||
# 测试单张图片
|
||||
def test_image(image_path):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
print(f"Cannot load image: {image_path}")
|
||||
return
|
||||
|
||||
# 执行检测
|
||||
boxes, scores, class_ids, inference_time = detector.detect(image)
|
||||
|
||||
print(f"Inference time: {inference_time*1000:.2f}ms")
|
||||
print(f"Detected {len(boxes)} objects")
|
||||
|
||||
# 绘制结果
|
||||
result_image = detector.draw_detections(image, boxes, scores, class_ids)
|
||||
|
||||
# 显示结果
|
||||
# cv2.imshow("Detection Result", result_image)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
cv2.imwrite("xxxxxxx.jpg", result_image)
|
||||
|
||||
# 测试摄像头实时检测
|
||||
def test_camera():
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
if not cap.isOpened():
|
||||
print("Cannot open camera")
|
||||
return
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 执行检测
|
||||
boxes, scores, class_ids, inference_time = detector.detect(frame)
|
||||
|
||||
# 绘制结果
|
||||
result_frame = detector.draw_detections(frame, boxes, scores, class_ids)
|
||||
|
||||
# 显示FPS
|
||||
fps = 1.0 / inference_time if inference_time > 0 else 0
|
||||
cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow("Real-time Detection", result_frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 选择测试模式
|
||||
mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ")
|
||||
|
||||
if mode == "1":
|
||||
image_path = input("输入图片路径: ")
|
||||
test_image(image_path)
|
||||
elif mode == "2":
|
||||
test_camera()
|
||||
else:
|
||||
print("无效选择")
|
||||
|
||||
# 释放资源
|
||||
detector.release()
|
||||
|
||||
def draw_detections(image, boxes, scores, class_ids, class_names=None):
|
||||
"""
|
||||
在图像上绘制检测结果
|
||||
|
||||
Args:
|
||||
image: 输入图像
|
||||
boxes: 检测框
|
||||
scores: 置信度分数
|
||||
class_ids: 类别ID
|
||||
class_names: 类别名称列表(可选)
|
||||
|
||||
Returns:
|
||||
绘制了检测结果的图像
|
||||
"""
|
||||
result_image = image.copy()
|
||||
|
||||
for i, (box, score, class_id) in enumerate(zip(boxes, scores, class_ids)):
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(result_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# 准备标签文本
|
||||
if class_names and class_id < len(class_names):
|
||||
label = f"{class_names[class_id]}: {score:.2f}"
|
||||
else:
|
||||
label = f"Class {class_id}: {score:.2f}"
|
||||
|
||||
# 绘制标签背景
|
||||
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||
cv2.rectangle(result_image, (x1, y1 - label_size[1] - 10),
|
||||
(x1 + label_size[0], y1), (0, 255, 0), -1)
|
||||
|
||||
# 绘制标签文本
|
||||
cv2.putText(result_image, label, (x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||
|
||||
return result_image
|
||||
1
base64.txt
Normal file
1
base64.txt
Normal file
File diff suppressed because one or more lines are too long
14
robot_dog.service
Normal file
14
robot_dog.service
Normal file
@ -0,0 +1,14 @@
|
||||
[Unit]
|
||||
Description=Robot Dog Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=root
|
||||
Group=root
|
||||
WorkingDirectory=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务
|
||||
ExecStart=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务/start.sh
|
||||
Restart=always
|
||||
RestartSec=5s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
14
run.py
Normal file
14
run.py
Normal file
@ -0,0 +1,14 @@
|
||||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
|
||||
# uvicorn.run(
|
||||
# "app.main:app",
|
||||
# host="10.0.0.202",
|
||||
# port=12342
|
||||
# )
|
||||
|
||||
uvicorn.run(
|
||||
"app.api.main:app",
|
||||
host="0.0.0.0",
|
||||
port=12345
|
||||
)
|
||||
32
start.sh
Normal file
32
start.sh
Normal file
@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 激活虚拟环境
|
||||
source /root/robot_dog_project/myvenv/bin/activate
|
||||
|
||||
# 进入项目目录
|
||||
cd /root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1
|
||||
|
||||
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
# 定义日志目录和带时间戳的日志文件名
|
||||
LOG_DIR="/root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1/logs"
|
||||
TIMESTAMP=$(date +'%Y%m%d_%H%M%S')
|
||||
LOG_FILE="${LOG_DIR}/app_${TIMESTAMP}.log"
|
||||
|
||||
# 创建日志目录(如果不存在)
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
# 激活Python虚拟环境
|
||||
source /root/robot_dog_project/myvenv/bin/activate
|
||||
|
||||
# 切换到项目目录
|
||||
cd /root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1
|
||||
|
||||
# 打印启动信息
|
||||
echo "启动应用程序,日志将保存到: $LOG_FILE"
|
||||
echo "启动时间: $(date '+%Y-%m-%d %H:%M:%S')" | tee -a "$LOG_FILE"
|
||||
|
||||
# 运行Python程序并将输出同时显示在控制台和保存到日志文件
|
||||
exec python run.py --env=dev 2>&1 | tee -a "$LOG_FILE"
|
||||
110
test_base64_server.py
Normal file
110
test_base64_server.py
Normal file
@ -0,0 +1,110 @@
|
||||
import base64
|
||||
import requests
|
||||
import json
|
||||
|
||||
def image_to_base64(image_path):
|
||||
"""
|
||||
将图片文件转换为base64编码
|
||||
"""
|
||||
with open(image_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return encoded_string
|
||||
|
||||
def test_ocr_api(image_path, api_url="http://10.0.0.202:12342/api/v1/ocr_from_base64"):
|
||||
"""
|
||||
测试OCR API接口
|
||||
"""
|
||||
# 将图片转换为base64
|
||||
image_base64 = image_to_base64(image_path)
|
||||
|
||||
with open("base64.txt", "w") as f:
|
||||
f.write(repr(image_base64))
|
||||
|
||||
# 准备请求数据
|
||||
payload = {
|
||||
"image_base64": image_base64,
|
||||
"image_type": "jpg" # 根据实际图片类型修改
|
||||
}
|
||||
|
||||
# 发送POST请求
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(api_url, data=json.dumps(payload), headers=headers)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print("OCR识别结果:")
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
return result
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
print(response.text)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求异常: {str(e)}")
|
||||
return None
|
||||
|
||||
def test_detect(image_path: str, api_url:str = "http://10.0.0.202:12342/api/v1/detect_from_base64_0"):
|
||||
"""
|
||||
测试yolov8消防区域侵占接口
|
||||
"""
|
||||
image_base64 = image_to_base64(image_path)
|
||||
# 准备请求数据
|
||||
payload = {
|
||||
"image_base64": image_base64,
|
||||
"image_type": "jpg" # 根据实际图片类型修改
|
||||
}
|
||||
|
||||
# 发送POST请求
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(api_url, data=json.dumps(payload), headers=headers)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print("detect识别结果:")
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
return result
|
||||
else:
|
||||
print(f"请求失败,状态码: {response.status_code}")
|
||||
print(response.text)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求异常: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试图片路径,请根据实际情况修改
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg"
|
||||
|
||||
#---------------------------------------测试ocr-----------------------------------------
|
||||
test_image_path = "images/AiCheck_20251016114138.jpg"
|
||||
|
||||
# api_url="http://10.0.0.202:12342/api/v1/ocr_onnx_from_base64"
|
||||
# api_url="http://192.168.30.195:12345/api/v1/ocr_onnx_from_base64"
|
||||
# # 调用测试函数
|
||||
# test_ocr_api(test_image_path, api_url)
|
||||
#---------------------------------------测试ocrender-----------------------------------------
|
||||
|
||||
# # -----------------------------------------测试yolov8 侵占消防区域检测-----------------------------------------
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000000.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000720.jpg"
|
||||
# api_url = "http://10.0.0.202:12342/api/v1/detect_onnx_from_base64_0"
|
||||
# api_url = "http://192.168.30.195:12345/api/v1/detect_onnx_from_base64_0"
|
||||
# test_detect(test_image_path, api_url)
|
||||
# #-----------------------------------------测试yolov8 侵占消防区域检测 end-----------------------------------------
|
||||
|
||||
|
||||
# #-----------------------------------------测试yolov8 灭火器检测-----------------------------------------
|
||||
# test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/ce81420a27cdaff14fe42f967eaa49a3_frame_001060.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
# # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
|
||||
# api_url = "http://10.0.0.202:12342/api/v1/detect_from_base64_1"
|
||||
api_url = "http://192.168.30.195:12345/api/v1/detect_from_base64_1"
|
||||
test_detect(test_image_path, api_url=api_url)
|
||||
# #-----------------------------------------测试yolov8 灭火器检测 end-----------------------------------------
|
||||
27
test_generic.py
Normal file
27
test_generic.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import Generic, TypeVar, Dict
|
||||
from typing import Optional
|
||||
|
||||
K = TypeVar("K") # 键类型
|
||||
V = TypeVar("V") # 值类型
|
||||
|
||||
class GenericCache(Generic[K, V]):
|
||||
def __init__(self):
|
||||
self._store: Dict[K, V] = {}
|
||||
|
||||
def set(self, key: K, value: V) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
def get(self, key: K) -> Optional[V]:
|
||||
return self._store.get(key)
|
||||
|
||||
# 使用示例
|
||||
cache = GenericCache[str, int]() # 键为 str,值为 int
|
||||
cache.set("count", 100)
|
||||
value: Optional[int] = cache.get("count") # 返回 100
|
||||
|
||||
cache.set(112, "123")
|
||||
|
||||
|
||||
for key in cache._store.keys():
|
||||
|
||||
print(type(cache.get(key)))
|
||||
7
test_ocr.py
Normal file
7
test_ocr.py
Normal file
@ -0,0 +1,7 @@
|
||||
from app.util.baiduOCR import BaiduOCR
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ocr = BaiduOCR()
|
||||
result = ocr.ocr('tmp/ocr_images/c723d5d8-697b-41e5-8cf4-d4644ce89a77.jpg')
|
||||
print(result)
|
||||
87
test_requests.py
Normal file
87
test_requests.py
Normal file
@ -0,0 +1,87 @@
|
||||
import requests
|
||||
from requests import Response
|
||||
from typing import List, TypeVar, Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar('T') # 泛型类型声明
|
||||
|
||||
REQUEST_URL = "http://aijinan.biaofun.com.cn/app/query/shandong" # 替换为实际API地址
|
||||
|
||||
def main():
|
||||
cities = ["济南市", "青岛市", "临沂市", "菏泽市", "威海市", "东营市"]
|
||||
|
||||
# 修正:每次请求创建新的参数对象,避免累积
|
||||
for city in cities:
|
||||
params = {"city": city}
|
||||
|
||||
# 发起RPC请求
|
||||
response_list = rpc_get(REQUEST_URL, params, False)
|
||||
|
||||
# 处理响应
|
||||
if not response_list:
|
||||
logger.warning(f"Empty response for city: {city}")
|
||||
continue
|
||||
|
||||
# 设置城市名称并处理后续逻辑 (根据实际Response类调整)
|
||||
response_list[0]['name'] = city # 这里假定response_list是字典列表
|
||||
after(response_list)
|
||||
|
||||
def rpc_get(url: str, params: Dict[str, Any], right_or_wrong: bool) -> Optional[List[T]]:
|
||||
"""
|
||||
RPC请求封装
|
||||
:param url: 请求URL
|
||||
:param params: 请求参数
|
||||
:param right_or_wrong: 未使用的标志(保留原Java参数)
|
||||
:return: 响应对象列表
|
||||
"""
|
||||
logger.info(f"Request URL: {url}")
|
||||
|
||||
try:
|
||||
# 发起POST请求
|
||||
response: Response = requests.post(url, data=params)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Request failed, Status Code: {response.status_code}, URL: {url}")
|
||||
return None
|
||||
|
||||
response_text = response.text
|
||||
logger.debug(f"Response: {response_text}")
|
||||
|
||||
# 解析响应(需要根据实际API响应格式实现)
|
||||
return parse_response(response_text)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"RPC request failed: {str(e)}")
|
||||
return None
|
||||
|
||||
def parse_response(response_text: str) -> List[T]:
|
||||
"""
|
||||
解析API响应 (需要根据实际API响应格式实现)
|
||||
示例实现 - 替换为实际解析逻辑
|
||||
"""
|
||||
# 这里需要根据实际API返回格式实现解析
|
||||
# 示例:解析为JSON对象
|
||||
import json
|
||||
try:
|
||||
data = json.loads(response_text)
|
||||
# 根据实际数据结构返回列表
|
||||
return [data] if isinstance(data, dict) else data
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Invalid JSON response")
|
||||
return []
|
||||
|
||||
def after(response_list: List[Any]):
|
||||
"""
|
||||
后续处理逻辑 (根据实际需求实现)
|
||||
"""
|
||||
# 示例实现
|
||||
print(f"Processing response for: {response_list[0].get('name')}")
|
||||
# 实际业务逻辑...
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user