Add robot dog OCR service and ignore local artifacts

This commit is contained in:
sladro 2026-03-19 11:32:09 +08:00
commit 19b2aa43d6
29 changed files with 21280 additions and 0 deletions

34
.gitignore vendored Normal file
View File

@ -0,0 +1,34 @@
__pycache__/
*.py[cod]
*$py.class
.Python
.venv/
venv/
env/
ENV/
.env
.env.*
.idea/
.vscode/
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
.coverage.*
htmlcov/
build/
dist/
*.egg-info/
.eggs/
logs/
tmp/
*.log
*.pid

32
app/api/main.py Normal file
View File

@ -0,0 +1,32 @@
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware

from app.api.v1.api import router
# from app.core.config import settings
# from app.services.websocket_service import websocket_service

# FastAPI application for the robot-dog backend service.
app = FastAPI(
    title="机器狗后台服务",
    description="机器狗后台API接口",
    version="1.0.0"
)

# CORS: wide open for development; in production restrict allow_origins
# to concrete domains (as the original comment noted).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the versioned API routes (router already carries its own prefix).
app.include_router(router)

# @app.websocket("/ws")
# async def websocket_endpoint(websocket: WebSocket):
#     """WebSocket endpoint"""
#     await websocket_service.handle_websocket(websocket)

@app.get("/")
async def root():
    # Simple liveness probe.
    return {"message": "机器狗后台API服务正在运行"}

394
app/api/v1/api.py Normal file
View File

@ -0,0 +1,394 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
import base64
import io
import cv2
import numpy as np
from PIL import Image
import os
import uuid
from datetime import datetime
from app.core.database import get_db
from app.services.imageServices import ImageService
from app.schemas.image import ImageBase
from app.schemas.ocr import ImageBase64Request
from app.util.responseHttp import ResponseUtil
from app.util.baiduOCR import BaiduOCR, BaiduOCRONNX
from app.util.yolov8Obj import Yolov8Obj, YOLOv8ONNX
# from app.crud.event import event
# from app.schemas.event import EventList, EventDetail, EventUpdate, EventQuery, TestEvent
# Module-level singletons shared by all request handlers below.
baiduOCR = BaiduOCR()          # PaddleOCR-backed recognizer (takes a file path)
baiduOcrOnnx = BaiduOCRONNX()  # ONNX variant of the OCR recognizer
yolov8Obj = Yolov8Obj()        # YOLOv8 detector (takes a file path)
yolov8ONNX = YOLOv8ONNX()      # ONNX YOLOv8 detector (takes an in-memory image)

router = APIRouter(prefix="/api/v1", tags=["ocr"])
@router.get("/hello")
async def get_hello():
# return {"data":"hello"}
return ResponseUtil.error(msg=f"OCR识别失败", data=None)
@router.get("/test_select")
async def test_select(
db: AsyncSession = Depends(get_db),
image_query: ImageBase = None
):
"""
测试查询
"""
result = await ImageService.get_image_list(db,image_query)
return ResponseUtil.success(msg="success", data=result)
@router.get("/test_ocr")
async def test_ocr(
# image_path: str = None
):
"""
测试OCR
"""
image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg'
result = baiduOCR.ocr(image_path)
# print(result)
return ResponseUtil.success(msg="success", data=result)
# return ResponseUtil.success(msg="success")
@router.post("/ocr_from_base64")
async def ocr_from_base64(request: ImageBase64Request):
"""
从base64图片数据进行OCR识别
"""
try:
# 移除base64数据的前缀如果有
image_base64 = request.image_base64
if image_base64.startswith('data:image'):
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
image_base64 = image_base64.split(',')[1]
# 解码base64数据
image_data = base64.b64decode(image_base64)
# 将字节数据转换为PIL Image对象
image = Image.open(io.BytesIO(image_data))
# 创建临时文件路径
temp_dir = "tmp/ocr_images"
os.makedirs(temp_dir, exist_ok=True)
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
temp_path = os.path.join(temp_dir, temp_filename)
# 保存图片到临时文件
image.save(temp_path)
# 使用PaddleOCR进行识别
result = baiduOCR.ocr(temp_path)
# 删除临时文件
# os.remove(temp_path)
return ResponseUtil.success(msg="OCR识别成功", data=result)
except Exception as e:
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
@router.post("/ocr_onnx_from_base64")
async def ocr_from_base64(request: ImageBase64Request):
"""
从base64图片数据进行OCR识别
"""
try:
# 移除base64数据的前缀如果有
image_base64 = request.image_base64
if image_base64.startswith('data:image'):
# 格式如: data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQ...
image_base64 = image_base64.split(',')[1]
# 解码base64数据
image_data = base64.b64decode(image_base64)
# 将字节数据转换为PIL Image对象
image = Image.open(io.BytesIO(image_data))
# 创建临时文件路径
temp_dir = "./tmp/ocr_images"
os.makedirs(temp_dir, exist_ok=True)
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
temp_path = os.path.join(temp_dir, temp_filename)
# 保存图片到临时文件
image.save(temp_path)
# 使用PaddleOCR进行识别
result = baiduOcrOnnx.ocr(temp_path)
print(result)
# 删除临时文件
# os.remove(temp_path)
# return ResponseUtil.success(msg="OCR识别成功", data=[result['text'], result['confidence']])
return ResponseUtil.success(msg="OCR识别成功", data=result)
except Exception as e:
return ResponseUtil.error(msg=f"OCR识别失败: {str(e)}", data=None)
@router.post("/detect_from_base64_0")
async def ocr_from_base64(request: ImageBase64Request):
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
try:
image_base64 = request.image_base64
if image_base64.startswith('data:image'):
image_base64 = image_base64.split(',')[1]
image_data = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_data))
temp_dir = "./tmp/detect_images"
os.makedirs(temp_dir, exist_ok=True)
temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
temp_path = os.path.join(temp_dir, temp_filename)
image.save(temp_path)
cls, conf, coords = yolov8Obj.detect(temp_path)
os.remove(temp_path)
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(cls)==0))
except Exception as e:
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
@router.post("/detect_onnx_from_base64_0")
async def ocr_from_base64(request: ImageBase64Request):
""" 从base64图片进行目标检测, 检测是否侵占消防区域"""
# 为不同类别生成不同颜色
def get_color(class_id):
np.random.seed(class_id)
return tuple(map(int, np.random.randint(0, 255, 3)))
try:
# image_base64 = request.image_base64
# if image_base64.startswith('data:image'):
# image_base64 = image_base64.split(',')[1]
# image_data = base64.b64decode(image_base64)
# image = Image.open(io.BytesIO(image_data))
# temp_dir = "./tmp/detect_images"
# os.makedirs(temp_dir, exist_ok=True)
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
# temp_path = os.path.join(temp_dir, temp_filename)
# image.save(temp_path)
# boxes, scores, class_ids = yolov8ONNX.detect(image_data)
# os.remove(temp_path)
image_base64 = request.image_base64
if image_base64.startswith('data:image'):
image_base64 = image_base64.split(',')[1]
image_data = base64.b64decode(image_base64)
# 将字节流转换为OpenCV格式的BGR图像
image_np = np.frombuffer(image_data, np.uint8)
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
# 不再需要保存到临时文件,直接使用内存中的图像
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
# 绘制检测结果
image_with_boxes = image_cv.copy()
for box, score, class_id in zip(boxes, scores, class_ids):
# 获取边界框坐标
x1, y1, x2, y2 = map(int, box)
# 生成随机颜色或使用固定颜色方案
color = get_color(class_id) # 绿色也可以根据class_id设置不同颜色
# 绘制边界框
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
# 准备标签文本
label = f"Class {class_id}: {score:.2f}"
# 如果你有类别名称字典,可以这样使用:
# label = f"{class_names[class_id]}: {score:.2f}"
# 计算文本大小以绘制背景
(text_width, text_height), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
# 绘制文本背景
cv2.rectangle(
image_with_boxes,
(x1, y1 - text_height - baseline - 5),
(x1 + text_width, y1),
color,
-1
)
# 绘制文本
cv2.putText(
image_with_boxes,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0), # 黑色文字
1,
cv2.LINE_AA
)
# 创建保存目录
save_dir = "tmp/detect_images"
os.makedirs(save_dir, exist_ok=True)
# 生成带时间戳的随机文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
filename = f"detect_{timestamp}_{random_id}.jpg"
output_path = os.path.join(save_dir, filename)
# 保存图像
cv2.imwrite(output_path, image_with_boxes)
print(f"检测结果已保存到: {output_path}")
return ResponseUtil.success(msg="侵占消防区域目标检测成功,是否有遮挡", data=(len(boxes)==0))
except Exception as e:
return ResponseUtil.error(msg=f"检测是否侵占消防区域失败: {str(e)}", data=None)
@router.post("/detect_from_base64_1")
async def ocr_from_base64(request: ImageBase64Request):
""" 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
# try:
# 为不同类别生成不同颜色
def get_color(class_id):
np.random.seed(class_id)
return tuple(map(int, np.random.randint(0, 255, 3)))
try:
image_base64 = request.image_base64
if image_base64.startswith('data:image'):
image_base64 = image_base64.split(',')[1]
image_data = base64.b64decode(image_base64)
# 将字节流转换为OpenCV格式的BGR图像
image_np = np.frombuffer(image_data, np.uint8)
image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # 这会得到BGR格式的图像
# 不再需要保存到临时文件,直接使用内存中的图像
boxes, scores, class_ids = yolov8ONNX.detect(image_cv)
# 绘制检测结果
image_with_boxes = image_cv.copy()
for box, score, class_id in zip(boxes, scores, class_ids):
# 获取边界框坐标
x1, y1, x2, y2 = map(int, box)
# 生成随机颜色或使用固定颜色方案
color = get_color(class_id) # 绿色也可以根据class_id设置不同颜色
# 绘制边界框
cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 2)
# 准备标签文本
label = f"Class {class_id}: {score:.2f}"
# 如果你有类别名称字典,可以这样使用:
# label = f"{class_names[class_id]}: {score:.2f}"
# 计算文本大小以绘制背景
(text_width, text_height), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
# 绘制文本背景
cv2.rectangle(
image_with_boxes,
(x1, y1 - text_height - baseline - 5),
(x1 + text_width, y1),
color,
-1
)
# 绘制文本
cv2.putText(
image_with_boxes,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0), # 黑色文字
1,
cv2.LINE_AA
)
# 创建保存目录
save_dir = "tmp/detect_images"
os.makedirs(save_dir, exist_ok=True)
# 生成带时间戳的随机文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
random_id = uuid.uuid4().hex[:8] # 取UUID的前8位
filename = f"detect_{timestamp}_{random_id}.jpg"
output_path = os.path.join(save_dir, filename)
# 保存图像
cv2.imwrite(output_path, image_with_boxes)
print(f"检测结果已保存到: {output_path}")
return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(class_ids)!=0 and 0 in class_ids))
# return ResponseUtil.success(msg="目标检测成功", data=cls)
except Exception as e:
return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)
# @router.post("/detect_from_base64_1")
# async def ocr_from_base64(request: ImageBase64Request):
# """ 从吧色图64图片进行目标检测, 检测是否存在灭火器"""
# try:
# image_base64 = request.image_base64
# if image_base64.startswith('data:image'):
# image_base64 = image_base64.split(',')[1]
# image_data = base64.b64decode(image_base64)
# image = Image.open(io.BytesIO(image_data))
# temp_dir = "./tmp/detect_images"
# os.makedirs(temp_dir, exist_ok=True)
# temp_filename = f"{uuid.uuid4()}.{request.image_type or 'jpg'}"
# temp_path = os.path.join(temp_dir, temp_filename)
# image.save(temp_path)
# cls, conf, coords = yolov8Obj.detect(temp_path)
# os.remove(temp_path)
# return ResponseUtil.success(msg="灭火器目标检测成功,是否存在灭火器", data=(len(cls)!=0 and 0 in cls))
# # return ResponseUtil.success(msg="目标检测成功", data=cls)
# except Exception as e:
# return ResponseUtil.error(msg=f"检测是否存在灭火器失败: {str(e)}", data=None)

122
app/api/v1/testLogin.py Normal file
View File

@ -0,0 +1,122 @@
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
# Configuration (should be read from environment variables).
# NOTE(review): hard-coded secret key — must be replaced before any real
# deployment; anyone with this value can forge tokens.
SECRET_KEY = "your-secret-key-keep-it-secret!"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Mock user database keyed by username.
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        # bcrypt hash (plaintext is "secret")
        "hashed_password": "$2b$12$EixZaYVK1fsbY1eZIbOnjesN9NwG1s3Z6FDcjyH103a2.dJgD0L4q",
        "disabled": False,
    }
}
# Token models
class Token(BaseModel):
    """Access-token response body."""
    access_token: str
    token_type: str

class TokenData(BaseModel):
    """Claims extracted from a decoded JWT."""
    username: Optional[str] = None

# User models
class User(BaseModel):
    """Public user fields."""
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None

class UserInDB(User):
    """User as stored in the database (adds the password hash)."""
    hashed_password: str

# Password hashing configuration (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 password-flow configuration.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()
# Password verification helper.
def verify_password(plain_password: str, hashed_password: str):
    """Return True when plain_password matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
# User lookup helper.
def get_user(db, username: str):
    """Return the stored user as UserInDB, or None when absent."""
    record = db.get(username)
    if record is None:
        return None
    return UserInDB(**record)
# User authentication.
def authenticate_user(fake_db, username: str, password: str):
    """Return the user when username exists and the password verifies.

    Returns False on unknown username or wrong password.
    Fix: the password check had been commented out, so ANY password was
    accepted for a known username — re-enabled the verification.
    """
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user
# Create a signed JWT access token.
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a JWT carrying *data* plus an "exp" expiry claim.

    :param data: claims to embed (e.g. {"sub": username})
    :param expires_delta: token lifetime; defaults to 15 minutes
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone
    to_encode = data.copy()
    # Fix: datetime.utcnow() is naive and deprecated — use aware UTC time.
    now = datetime.now(timezone.utc)
    expire = now + (expires_delta if expires_delta else timedelta(minutes=15))
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
# Current-user dependency (injected via the OAuth2 bearer scheme).
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Decode the bearer token and return the matching user.

    Raises HTTP 401 when the token is invalid, carries no "sub" claim,
    or the referenced user does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        # Signature/expiry failure — same 401 as a missing user.
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
# Login route: exchange username/password for a bearer token.
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate the submitted credentials and issue a JWT.

    Raises HTTP 401 when authentication fails.
    """
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
# Protected route: requires a valid bearer token.
@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's own profile."""
    return current_user

89
app/config/config.py Normal file
View File

@ -0,0 +1,89 @@
from pydantic_settings import BaseSettings
import os
import sys
import argparse
from dotenv import load_dotenv
class DataBaseSettings(BaseSettings):
    """Database connection and pool settings (overridable via environment).

    NOTE(review): real credentials should come from the environment /
    .env file, not these hard-coded defaults.
    """
    DB_HOST: str = "10.0.0.17"
    DB_PORT: int = 3306
    DB_USER: str = "root"
    DB_PASSWORD: str = "root"
    DB_NAME: str = "kangda_robotic_dog"
    DB_CHARSET: str = "utf8mb4"
    DB_POOL_SIZE: int = 5          # resident pool connections
    DB_MAX_OVERFLOW: int = 10      # extra connections allowed beyond the pool
    DB_POOL_TIMEOUT: int = 30      # seconds to wait for a free connection
    DB_POOL_RECYCLE: int = 1800    # seconds before a connection is recycled
    # class Config:
    #     env_file = ".env"
class OCRSettings(BaseSettings):
    """Paths to the PaddleOCR detection/recognition models.

    Fix: the class previously inherited from BaseException — almost
    certainly a typo for BaseSettings — which made it an exception type
    rather than a settings model; fields also lacked the annotations
    pydantic settings models require.
    NOTE: "RECONGNITION" is a typo for "RECOGNITION", kept because callers
    may already reference these attribute names.
    """
    TEXT_DETECTION_MODEL_DIR: str = "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_det_infer_20250814/"
    TEXT_RECONGNITION_MODEL_DIR: str = "/root/robot_dog_project/kangda_robotic_dog/models/PP-OCRv5_server_rec_infer_20250815/"
    TEXT_DETECTION_MODEL_ONNX_DIR: str = "/root/robot_dog_project/kangda_robotic_dog/models/det_mobile_14_shape.onnx"
    TEXT_RECONGNITION_MODEL_ONNX_DIR: str = "/root/robot_dog_project/kangda_robotic_dog/models/rec_mobile_14_shape.onnx"
class YoloV8Settings(BaseSettings):
    """Paths to the YOLOv8 detection models.

    Fix: inherited from BaseException (typo for BaseSettings) and lacked
    the field annotations pydantic settings models require.
    """
    YOLOV8_MODEL_DIR: str = "/root/robot_dog_project/kangda_robotic_dog/models/best.pt"
    YOLOV8_MODEL_ONNX_DIRS: str = "/root/robot_dog_project/kangda_robotic_dog/models/yolov8_20250820.onnx"
class GetSettings:
    """Resolves the runtime environment and exposes the settings objects.

    On construction it parses --env (or APP_ENV) and loads the matching
    .env.<env> file via python-dotenv.
    """

    def __init__(self):
        # Resolve the environment and load its .env file immediately.
        self.parse_cli_args()

    def get_database_settings(self):
        """Return a freshly constructed DataBaseSettings."""
        return DataBaseSettings()

    def get_ocr_settings(self):
        """Return a freshly constructed OCRSettings."""
        return OCRSettings()

    def get_yolov8_settings(self):
        """Return a freshly constructed YoloV8Settings."""
        return YoloV8Settings()

    def parse_cli_args(self):
        """
        Parse command-line arguments and load the matching .env file.
        """
        if 'uvicorn' in sys.argv[0]:
            # Under uvicorn the CLI flags belong to uvicorn itself; custom
            # arguments cannot be defined, so skip argument parsing.
            pass
        else:
            # Define command-line arguments with argparse.
            parser = argparse.ArgumentParser(description='命令行参数')
            parser.add_argument('--env', type=str, default='dev', help='运行环境')
            # Parse the command line.
            args = parser.parse_args()
            # Set APP_ENV; default to 'dev' when no argument was given.
            os.environ['APP_ENV'] = args.env if args.env else 'dev'
        # Read the runtime environment.
        run_env = os.environ.get('APP_ENV', '')
        # When unspecified, fall back to .env.dev.
        env_file = '.env.dev'
        # Otherwise load the env file for the requested environment.
        if run_env != '':
            env_file = f'.env.{run_env}'
            print(f"加载配置 .env.{run_env}")
        # Load the configuration file into the process environment.
        load_dotenv(env_file)
# Module-level singletons: CLI/env parsing and .env loading happen once at
# import time; other modules import these settings objects directly.
get_settings = GetSettings()
database_settings = get_settings.get_database_settings()
ocr_settings = get_settings.get_ocr_settings()
yolov8_settings = get_settings.get_yolov8_settings()

77
app/config/constant.py Normal file
View File

@ -0,0 +1,77 @@
# from config.env import DataBaseConfig
class CommonConstant:
    """
    Common constants.

    WWW: www root-domain prefix
    HTTP: http scheme prefix
    HTTPS: https scheme prefix
    LOOKUP_RMI: RMI remote-invocation scheme
    LOOKUP_LDAP: LDAP remote-invocation scheme
    LOOKUP_LDAPS: LDAPS remote-invocation scheme
    YES: system-default flag (yes)
    NO: system-default flag (no)
    DEPT_NORMAL: department active status
    DEPT_DISABLE: department disabled status
    UNIQUE: uniqueness-check result (unique)
    NOT_UNIQUE: uniqueness-check result (not unique)
    """
    WWW = 'www.'
    HTTP = 'http://'
    HTTPS = 'https://'
    LOOKUP_RMI = 'rmi:'
    LOOKUP_LDAP = 'ldap:'
    LOOKUP_LDAPS = 'ldaps:'
    YES = 'Y'
    NO = 'N'
    DEPT_NORMAL = '0'
    DEPT_DISABLE = '1'
    UNIQUE = True
    NOT_UNIQUE = False
class HttpStatusConstant:
    """
    Response status codes.

    SUCCESS: operation succeeded
    CREATED: object created
    ACCEPTED: request accepted
    NO_CONTENT: executed successfully, no data returned
    MOVED_PERM: resource removed/moved permanently
    SEE_OTHER: redirect
    NOT_MODIFIED: resource not modified
    BAD_REQUEST: bad parameters (missing / wrong format)
    UNAUTHORIZED: unauthorized
    FORBIDDEN: access restricted / authorization expired
    NOT_FOUND: resource or service not found
    BAD_METHOD: HTTP method not allowed
    CONFLICT: resource conflict or locked
    UNSUPPORTED_TYPE: unsupported media type
    ERROR: internal server error
    NOT_IMPLEMENTED: endpoint not implemented
    WARN: system warning message (non-standard code)
    """
    SUCCESS = 200
    CREATED = 201
    ACCEPTED = 202
    NO_CONTENT = 204
    MOVED_PERM = 301
    SEE_OTHER = 303
    NOT_MODIFIED = 304
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    BAD_METHOD = 405
    CONFLICT = 409
    UNSUPPORTED_TYPE = 415
    ERROR = 500
    NOT_IMPLEMENTED = 501
    WARN = 601

32
app/core/database.py Normal file
View File

@ -0,0 +1,32 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.config.config import database_settings
# Async MySQL engine (aiomysql driver); pool settings come from configuration.
engine = create_async_engine(
    f"mysql+aiomysql://{database_settings.DB_USER}:{database_settings.DB_PASSWORD}@{database_settings.DB_HOST}:{database_settings.DB_PORT}/{database_settings.DB_NAME}",
    pool_size=database_settings.DB_POOL_SIZE,        # resident pool connections
    max_overflow=database_settings.DB_MAX_OVERFLOW,  # max overflow connections
    pool_timeout=database_settings.DB_POOL_TIMEOUT,  # seconds to wait for a connection
    pool_recycle=database_settings.DB_POOL_RECYCLE,  # seconds before a connection is recycled
    echo=False  # disable SQL echo logging
)

# expire_on_commit=False keeps ORM objects usable after commit.
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base class for all ORM models.
Base = declarative_base()
# FastAPI dependency: yield one database session per request.
async def get_db():
    async with async_session() as session:
        try:
            yield session
            # await session.commit()
        except Exception:
            # Roll back on any error raised by the request handler, then re-raise.
            await session.rollback()
            raise
        finally:
            # Redundant with the async context manager, but harmless.
            await session.close()

88
app/crud/base.py Normal file
View File

@ -0,0 +1,88 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from app.core.database import Base
# Type variables binding a CRUD class to one ORM model and its Pydantic
# create/update schemas; used by the generic CRUDBase below.
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
# The bare string below is an author note (kept verbatim): it explains that
# Generic[ModelType, CreateSchemaType, UpdateSchemaType] parameterizes
# CRUDBase so subclasses bind concrete types, giving reuse plus type safety.
'''
Generic[ModelType, CreateSchemaType, UpdateSchemaType] Python 类型提示中泛型编程的关键语法用于声明一个泛型类
它的作用是让 CRUDBase 类具备类型参数化的能力允许在继承或实例化时动态绑定具体类型从而实现代码复用和类型安全
声明 CRUDBase 类需要三个类型参数ModelTypeCreateSchemaTypeUpdateSchemaType
这些类型参数会在类的内部方法中使用 getcreateupdate确保类型一致性
'''
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper bound to a single SQLAlchemy model class."""

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object to be used with a SQLAlchemy model class.
        :param model: SQLAlchemy model class this instance operates on
        """
        self.model = model

    async def get(self, db: AsyncSession, id: Any) -> Optional[ModelType]:
        """Return the object with the given primary key, or None."""
        query = select(self.model).where(self.model.id == id)
        result = await db.execute(query)
        return result.scalar_one_or_none()

    async def get_multi(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return a page of objects (offset/limit pagination)."""
        query = select(self.model).offset(skip).limit(limit)
        result = await db.execute(query)
        return result.scalars().all()

    async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a new object built from the create schema and return it."""
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]]
    ) -> ModelType:
        """Apply the fields of obj_in to db_obj and persist the change."""
        obj_data = jsonable_encoder(db_obj)
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            # Fix: .dict() is deprecated in pydantic v2 — use model_dump(),
            # matching the usage elsewhere in this project (CRUDEvent).
            update_data = obj_in.model_dump(exclude_unset=True)
        for field in obj_data:
            if field in update_data:
                setattr(db_obj, field, update_data[field])
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def remove(self, db: AsyncSession, *, id: Any) -> Optional[ModelType]:
        """Delete and return the object with the given id; None when absent.

        Fix: previously a missing id led to db.delete(None), which raises;
        now an absent row returns None without touching the session.
        """
        obj = await self.get(db=db, id=id)
        if obj is None:
            return None
        await db.delete(obj)
        await db.commit()
        return obj

130
app/crud/event.py Normal file
View File

@ -0,0 +1,130 @@
from typing import List, Optional, Dict, Any
from sqlalchemy import select, and_, or_, join
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.crud.base import CRUDBase
from app.models.models import Event, Image, Temperature
from app.schemas.event import EventUpdate, EventQuery, TestEvent
class CRUDEvent(CRUDBase[Event, EventUpdate, EventUpdate]):
    """CRUD operations for Event, eager-loading related images/temperatures."""

    async def get_by_id(self, db: AsyncSession, *, event_id: str) -> Optional[Event]:
        """Fetch one event by eventId (with images and temperatures), or None."""
        query = (
            select(Event)
            .options(
                selectinload(Event.images),
                selectinload(Event.temperatures)
            )
            .where(Event.eventId == event_id)
        )
        result = await db.execute(query)
        return result.scalar_one_or_none()

    async def get_multi_with_query(
        self,
        db: AsyncSession,
        *,
        query: EventQuery
    ) -> List[Event]:
        """Return events matching the optional filters in *query*, paginated."""
        # Build the filter list from whichever query fields are set.
        conditions = []
        if query.start_time:
            conditions.append(Event.insDate >= query.start_time)
        if query.end_time:
            conditions.append(Event.insDate <= query.end_time)
        if query.etypeName:
            conditions.append(Event.etypeName == query.etypeName)
        if query.area:
            conditions.append(Event.area == query.area)
        query_stmt = (
            select(Event)
            .options(
                selectinload(Event.images),
                selectinload(Event.temperatures)
            )
        )
        if conditions:
            query_stmt = query_stmt.where(and_(*conditions))
        query_stmt = query_stmt.offset(query.skip).limit(query.limit)
        result = await db.execute(query_stmt)
        return result.scalars().all()

    async def update_event(
        self,
        db: AsyncSession,
        *,
        event_id: str,
        obj_in: EventUpdate
    ) -> Optional[Event]:
        """Overwrite the event's editable fields from obj_in; None when absent."""
        event = await self.get_by_id(db, event_id=event_id)
        if not event:
            return None
        update_data = obj_in.model_dump()
        for field, value in update_data.items():
            setattr(event, field, value)
        # db.add(event)
        await db.commit()
        await db.refresh(event)
        return event

    async def delete_event(
        self,
        db: AsyncSession,
        *,
        event_id: str
    ) -> Optional[Event]:
        """Delete the event with the given id and return it; None when absent."""
        event = await self.get_by_id(db, event_id=event_id)
        if not event:
            return None
        await db.delete(event)
        await db.commit()
        return event

    async def get_test(
        self,
        db: AsyncSession,
    ) -> List[TestEvent]:  # the declared response type must match the row shape
        '''
        Columns returned per row:
        eventId
        number
        name
        imageUrl
        localPath
        temperature
        confidence
        createTime
        '''
        query_stmt = (
            select(Event.eventId, Event.number, Event.name, Image.imageUrl, Image.localPath, Temperature.temperature, Temperature.confidence, Temperature.createTime)
            .select_from(Event)
            .outerjoin(Image, Event.eventId == Image.eventId)
            .outerjoin(Temperature, Image.imageId == Temperature.imageId)
            # multiple filter conditions combined here
            .where(and_(Event.etypeName=="日常巡检", True))
        )
        result = await db.execute(query_stmt)
        # Rows as dicts keyed by column label.
        result = result.mappings().all()
        # print(result)
        return [TestEvent(**row) for row in result]

# Shared CRUD instance for the Event model.
event = CRUDEvent(Event)

18
app/dao/imageDao.py Normal file
View File

@ -0,0 +1,18 @@
from typing import List
# from app.models.models import image
from app.schemas.image import ImageBase
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.models import Image
class ImageDao:
    """Data-access layer for the image table."""

    @classmethod
    async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
        """Return all image rows as mappings.

        NOTE(review): query_model is currently unused — no filtering is
        applied; confirm whether filters were intended here.
        """
        stmt = (
            select(Image)
        )
        result = await db.execute(stmt)
        return result.mappings().all()

25
app/main.py Normal file
View File

@ -0,0 +1,25 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api.v1.api import router

# Test FastAPI application (alternate entry point to app/api/main.py).
app = FastAPI(
    title="测试fastapi",
    description="测试啊",
    version="1.0.0"
)

# CORS: wide open; restrict origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): the router already carries prefix="/api/v1" (see
# app/api/v1/api.py), so adding it again here yields /api/v1/api/v1/... —
# confirm which prefix is intended.
app.include_router(router, prefix="/api/v1")

@app.get("/")
async def root():
    # Simple liveness probe.
    return {"message": "测试成功"}

134
app/models/models.py Normal file
View File

@ -0,0 +1,134 @@
from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, Index, BigInteger
from sqlalchemy.orm import relationship
from app.core.database import Base
class Image(Base):
    """ORM model for the image table (OCR / detection inputs and results)."""
    __tablename__ = "image"

    imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
    localPath = Column(String(500), comment='本地存储路径')
    # '0' = OCR image, '1' = object-detection image (per the column comment).
    type = Column(String(2), comment='图片类型, 0 ocr 图片, 1 目标检测图片')
    result = Column(String(500), comment='检测结果')
    createTime = Column(DateTime, default=datetime.now, comment='创建时间')
    updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# class Event(Base):
# __tablename__ = "event"
# eventId = Column(String(50), primary_key=True)
# tenantInfoId = Column(String(100))
# reportEventId = Column(String(100))
# number = Column(String(20))
# name = Column(String(20))
# eclassify = Column(String(5))
# operationType = Column(String(5))
# etype = Column(String(20))
# etypeName = Column(String(20))
# enTypeName = Column(String(30))
# hkTypeName = Column(String(20))
# reportStatus = Column(String(5))
# results = Column(String(5))
# insDate = Column(DateTime)
# insDateShow = Column(DateTime)
# updDate = Column(DateTime)
# updDateShow = Column(DateTime)
# fileType = Column(String(5))
# area = Column(String(20))
# floor = Column(String(10))
# map = Column(String(20))
# staffId = Column(String(40))
# targetUserId = Column(String(40))
# position = Column(String(100))
# actualStaffName = Column(String(20))
# targetStaffName = Column(String(20))
# routeName = Column(String(20))
# phoneAddress = Column(String(500))
# width = Column(String(10))
# height = Column(String(10))
# resolution = Column(String(10))
# originX = Column(String(20))
# originY = Column(String(20))
# imgList = Column(String(20))
# robotType = Column(String(5))
# eventFloor = Column(String(10))
# floorName = Column(String(10))
# coordId = Column(String(40))
# coord = Column(String(40))
# coordName = Column(String(30))
# positonName = Column(String(20))
# processingRemark = Column(String(300))
# carId = Column(String(40))
# parkingSpaceType = Column(String(10))
# parkingSpaceNumber = Column(String(40))
# carNumber = Column(String(40))
# eno = Column(String(40))
# instrument = Column(String(40))
# evideo = Column(String(40))
# createTime = Column(DateTime, default=datetime.now, comment='本地后台创建时间')
# updateTime = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='本地后台更新时间')
# # 关系
# images = relationship("Image", back_populates="event", cascade="all, delete-orphan")
# temperatures = relationship("Temperature", back_populates="event", cascade="all, delete-orphan")
# process_logs = relationship("ProcessLog", back_populates="event", cascade="all, delete-orphan")
# class Image(Base):
# __tablename__ = "image"
# imageId = Column(BigInteger, primary_key=True, autoincrement=True, comment='图片ID')
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
# imageUrl = Column(String(500), nullable=False, comment='图片URL')
# localPath = Column(String(500), comment='本地存储路径')
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
# # 关系
# event = relationship("Event", back_populates="images" )
# temperatures = relationship("Temperature", back_populates="image")
# __table_args__ = (
# Index('idx_image_event_id', 'eventId'),
# )
# class Temperature(Base):
# __tablename__ = "temperature"
# tempId = Column(BigInteger, primary_key=True, autoincrement=True, comment='温度记录ID')
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
# imageId = Column(BigInteger, ForeignKey('image.imageId'), nullable=False, comment='关联图片ID')
# temperature = Column(String(20), nullable=False, comment='温度值')
# # status = Column(String(2), comment='温度是否正常')
# confidence = Column(String(40), nullable=False, comment='识别置信度')
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
# # 关系
# event = relationship("Event", back_populates="temperatures")
# image = relationship("Image", back_populates="temperatures")
# __table_args__ = (
# Index('idx_temp_event_id', 'eventId'),
# Index('idx_temp_create_time', 'createTime'),
# )
# class ProcessLog(Base):
# __tablename__ = "process_log"
# logId = Column(BigInteger, primary_key=True, autoincrement=True, comment='日志ID')
# eventId = Column(String(50), ForeignKey('event.eventId'), nullable=False, comment='关联事件ID')
# processStatus = Column(Integer, nullable=False, comment='处理状态')
# errorMessage = Column(Text, comment='错误信息')
# createTime = Column(DateTime, default=datetime.now, nullable=False, comment='创建时间')
# # 关系
# event = relationship("Event", back_populates="process_logs")
# __table_args__ = (
# Index('idx_log_event_id', 'eventId'),
# Index('idx_log_create_time', 'createTime'),
# )

0
app/schemas/__init__.py Normal file
View File

116
app/schemas/event.py Normal file
View File

@ -0,0 +1,116 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, ConfigDict
class ImageBase(BaseModel):
    """Image attached to an event: remote URL plus optional local path."""
    imageUrl: str
    localPath: Optional[str] = None
    model_config = ConfigDict(from_attributes=True)
class TemperatureBase(BaseModel):
    """Temperature reading with its recognition confidence and timestamp."""
    temperature: str
    confidence: str
    createTime: datetime
    model_config = ConfigDict(from_attributes=True)
class EventBase(BaseModel):
    """Core event fields shared by the list and detail schemas."""
    eventId: str
    tenantInfoId: Optional[str] = None
    reportEventId: Optional[str] = None
    number: Optional[str] = None
    name: Optional[str] = None
    etypeName: Optional[str] = None
    insDate: Optional[datetime] = None
    model_config = ConfigDict(from_attributes=True)
class EventList(EventBase):
    """Event list item: core fields plus related images and temperatures."""
    images: List[ImageBase] = []
    temperatures: List[TemperatureBase] = []
    model_config = ConfigDict(from_attributes=True)
class EventDetail(EventBase):
    """Full event detail: every column of the event table, all optional,
    plus related images and temperature readings."""
    eclassify: Optional[str] = None
    operationType: Optional[str] = None
    etype: Optional[str] = None
    enTypeName: Optional[str] = None
    hkTypeName: Optional[str] = None
    reportStatus: Optional[str] = None
    results: Optional[str] = None
    insDateShow: Optional[datetime] = None
    updDate: Optional[datetime] = None
    updDateShow: Optional[datetime] = None
    fileType: Optional[str] = None
    area: Optional[str] = None
    floor: Optional[str] = None
    map: Optional[str] = None
    staffId: Optional[str] = None
    targetUserId: Optional[str] = None
    position: Optional[str] = None
    actualStaffName: Optional[str] = None
    targetStaffName: Optional[str] = None
    routeName: Optional[str] = None
    phoneAddress: Optional[str] = None
    width: Optional[str] = None
    height: Optional[str] = None
    resolution: Optional[str] = None
    originX: Optional[str] = None
    originY: Optional[str] = None
    robotType: Optional[str] = None
    eventFloor: Optional[str] = None
    floorName: Optional[str] = None
    coordId: Optional[str] = None
    coord: Optional[str] = None
    coordName: Optional[str] = None
    # NOTE(review): likely a typo for "positionName" — kept to match the data source.
    positonName: Optional[str] = None
    processingRemark: Optional[str] = None
    carId: Optional[str] = None
    parkingSpaceType: Optional[str] = None
    parkingSpaceNumber: Optional[str] = None
    carNumber: Optional[str] = None
    eno: Optional[str] = None
    instrument: Optional[str] = None
    evideo: Optional[str] = None
    createTime: Optional[datetime] = None
    updateTime: Optional[datetime] = None
    images: List[ImageBase] = []
    temperatures: List[TemperatureBase] = []
    model_config = ConfigDict(from_attributes=True)
class EventUpdate(BaseModel):
    """Mutable subset of event fields accepted by the update endpoint."""

    number: Optional[str] = None
    name: Optional[str] = None
    etypeName: Optional[str] = None
    area: Optional[str] = None
    position: Optional[str] = None
    processingRemark: Optional[str] = None
    model_config = ConfigDict(from_attributes=True)
class EventQuery(BaseModel):
    """Filter + pagination parameters for event list queries."""

    start_time: Optional[datetime] = None  # inclusive lower bound on event time
    end_time: Optional[datetime] = None    # inclusive upper bound on event time
    etypeName: Optional[str] = None
    area: Optional[str] = None
    skip: int = 0     # pagination offset
    limit: int = 100  # pagination page size
    model_config = ConfigDict(from_attributes=True)
class TestEvent(BaseModel):
    """Lightweight event payload used by the test/debug endpoints."""

    eventId: str
    number: Optional[str] = None
    name: Optional[str] = None
    imageUrl: Optional[str] = None
    localPath: Optional[str] = None  # path of the image on local disk
    temperature: Optional[str] = None
    confidence: Optional[str] = None  # stored as string, not float — confirm callers
    createTime: Optional[datetime] = None
    model_config = ConfigDict(from_attributes=True)

15
app/schemas/image.py Normal file
View File

@ -0,0 +1,15 @@
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
class ImageBase(BaseModel):
    """Image record schema attached to events."""

    imageId: Optional[int] = None
    localPath: Optional[str] = None  # path of the stored image on disk
    # NOTE: shadows the builtin `type` as an attribute name; kept for wire compat.
    type: Optional[str] = None
    result: Optional[str] = None  # recognition result associated with the image
    createTime: Optional[datetime] = None
    updateTime: Optional[datetime] = None
    model_config = ConfigDict(from_attributes=True)

11
app/schemas/ocr.py Normal file
View File

@ -0,0 +1,11 @@
from pydantic import BaseModel
from typing import Optional
class ImageBase64Request(BaseModel):
    """Request body for OCR on a base64-encoded image."""

    image_base64: str  # raw base64 payload (no data: URI prefix expected)
    image_type: Optional[str] = "jpg"  # image format, e.g. jpg, png

View File

@ -0,0 +1,15 @@
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from app.schemas.image import ImageBase
from app.dao.imageDao import ImageDao
class ImageService:
    """Service layer for image queries; delegates persistence to ImageDao."""

    @classmethod
    async def get_image_list(cls, db: AsyncSession, query_model: ImageBase):
        """Return images matching *query_model* via the given async session."""
        result = await ImageDao.get_image_list(db, query_model)
        return result

535
app/util/baiduOCR.py Normal file
View File

@ -0,0 +1,535 @@
from paddleocr import PaddleOCR
import os
import cv2
import yaml
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
# import math
from app.config.config import OCRSettings
class BaiduOCR:
    """PaddleOCR-backed text recognizer that also annotates the source image.

    ``ocr()`` runs detection + recognition on an image file, draws the boxes
    and texts back onto the same file, and returns ``[rec_texts, rec_scores]``
    for the first result page (or ``[]`` when nothing is found).
    """

    def __init__(self):
        self.model = PaddleOCR(
            # Text detection model directory (from settings)
            text_detection_model_dir=OCRSettings.TEXT_DETECTION_MODEL_DIR,
            # Text recognition model directory (from settings)
            text_recognition_model_dir=OCRSettings.TEXT_RECONGNITION_MODEL_DIR,
            use_doc_orientation_classify=False,
            use_doc_unwarping=False,
            use_textline_orientation=False,
            # Multi-GPU is problematic here; pin inference to one device.
            device='gpu:2'
        )

    def ocr(self, image_path: str):
        """Run OCR on *image_path*; returns ``[texts, scores]`` or ``[]``.

        IndexError from PaddleOCR internals is treated as "no text found";
        any other exception propagates to the caller.
        """
        try:
            result = self.model.predict(image_path)
        except IndexError:
            return []
        if not result:
            return []
        # Annotate the original file in place before parsing.
        self.draw_ocr_result(image_path, result)
        result = self.parse_result(result)
        if not result:
            return []
        return result[0]

    def parse_result(self, result):
        """Extract ``[rec_texts, rec_scores]`` pairs from raw predict output."""
        text_list = []
        for item in result:
            rec_texts = item.get('rec_texts')
            rec_scores = item.get('rec_scores')
            if rec_texts is None or rec_scores is None:
                continue
            text_list.append([rec_texts, rec_scores])
        return text_list

    def draw_ocr_result(self, image_path: str, result):
        """
        Draw the OCR boxes and recognized texts onto the image at *image_path*.

        The annotated image OVERWRITES the original file. Silently returns if
        the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Use PIL so CJK text renders correctly.
        pil_img = Image.fromarray(image)
        draw = ImageDraw.Draw(pil_img)
        # Try a CJK-capable font; fall back to PIL's default when the font
        # file is missing/unreadable (truetype raises OSError in that case —
        # a bare except here previously swallowed every error).
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 20)  # Linux
        except OSError:
            font = ImageFont.load_default()
        for idx, item in enumerate(result):
            box = item['rec_boxes']  # shape=(1, 4)
            text = item['rec_texts']
            score = item['rec_scores']
            # Fields may come back as single-element lists.
            if isinstance(text, list):
                text = text[0] if len(text) > 0 else ""
            if isinstance(score, list):
                score = score[0] if len(score) > 0 else 0.0
            if box is None:
                continue
            box = box.flatten()
            if len(box) < 4:
                continue
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            draw.rectangle([(x1, y1), (x2, y2)], outline=(255, 0, 0), width=2)
            # Text + confidence just above the box (clamped to the top edge).
            text_position = (x1, max(0, y1 - 25))
            text_with_score = f"{text} ({float(score):.2f})"
            draw.text(text_position, text_with_score, fill=(0, 255, 0), font=font)
        # Back to OpenCV BGR and write over the original file.
        result_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        cv2.imwrite(image_path, result_img)
class BaiduOCRONNX:
    """ONNX Runtime port of the PaddleOCR detection + recognition pipeline.

    Runs a DB-style text detector (fixed 1x3x640x640 input) followed by a
    CTC text recognizer (fixed 1x3x48x320 input), both loaded from .onnx
    files; the character set for decoding comes from ``dict.yaml``.
    """

    def __init__(self, det_model_path=OCRSettings.TEXT_DETECTION_MODEL_ONNX_DIR, rec_model_path=OCRSettings.TEXT_RECONGNITION_MODEL_ONNX_DIR):
        """
        Initialize the ONNX inference sessions.

        Args:
            det_model_path: detection model path (det.onnx)
            rec_model_path: recognition model path (rec.onnx)
        """
        # Detection model session
        self.det_session = ort.InferenceSession(det_model_path)
        self.det_input_name = self.det_session.get_inputs()[0].name
        # Recognition model session
        self.rec_session = ort.InferenceSession(rec_model_path)
        self.rec_input_name = self.rec_session.get_inputs()[0].name
        # Character set used for CTC decoding (index 0 is the blank).
        self.character = self.get_dict()
        if self.character is None:
            raise ValueError('请检查字典文件是否存在!')

    def get_dict(self, dict_path='./dict.yaml'):
        """Load the recognition character dictionary from a YAML file.

        NOTE(review): the path is relative to the process working directory —
        confirm the service always starts from the project root (start.sh cd's
        there before launching).
        """
        with open(dict_path, 'r', encoding='utf-8') as f:
            dict_rec = yaml.safe_load(f)
        return dict_rec.get('character_dict', [])

    def resize_norm_img_det(self, img, input_shape=(640, 640)):
        """
        Detector preprocessing — letterbox to a fixed [1, 3, 640, 640] input.

        Returns the normalized batch tensor, the resize ratio, and the
        (top, left) padding offsets needed to map boxes back to the original.
        """
        h, w, _ = img.shape
        target_h, target_w = input_shape
        # Scale factor preserving aspect ratio
        ratio_h = target_h / h
        ratio_w = target_w / w
        ratio = min(ratio_h, ratio_w)
        # Size after scaling
        new_h = int(h * ratio)
        new_w = int(w * ratio)
        resized_img = cv2.resize(img, (new_w, new_h))
        # Gray (114) canvas of the target size, float32 from the start
        padded_img = np.ones((target_h, target_w, 3), dtype=np.float32) * 114.0
        # Center the resized image on the canvas
        top = (target_h - new_h) // 2
        left = (target_w - new_w) // 2
        padded_img[top:top+new_h, left:left+new_w] = resized_img.astype(np.float32)
        # ImageNet mean/std normalization, then HWC -> NCHW
        img = (padded_img / 255.0 - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img = img.transpose(2, 0, 1).astype(np.float32)
        img = np.expand_dims(img, axis=0).astype(np.float32)
        return img, ratio, (top, left)

    def post_process_det(self, dt_boxes, ratio, padding_info, ori_shape):
        """
        Map detection boxes from the letterboxed 640x640 space back to the
        original image space, clipping to the original bounds.
        """
        if dt_boxes is None:
            return None
        ori_h, ori_w = ori_shape
        top, left = padding_info
        # Undo padding offset and scale
        dt_boxes[:, :, 0] = (dt_boxes[:, :, 0] - left) / ratio
        dt_boxes[:, :, 1] = (dt_boxes[:, :, 1] - top) / ratio
        # Clip to the original image
        dt_boxes[:, :, 0] = np.clip(dt_boxes[:, :, 0], 0, ori_w)
        dt_boxes[:, :, 1] = np.clip(dt_boxes[:, :, 1], 0, ori_h)
        return dt_boxes

    def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height, max_candidates=1000, box_thresh=0.6):
        """
        Extract quadrilateral text boxes from the binarized probability map
        (standard DB postprocessing: contours -> min-area rects -> unclip).
        """
        bitmap = bitmap.astype(np.uint8)
        height, width = bitmap.shape
        # Contours of the binarized regions
        contours, _ = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        num_contours = min(len(contours), max_candidates)
        boxes = []
        scores = []
        for i in range(num_contours):
            contour = contours[i]
            points, sside = self.get_mini_boxes(contour)
            if sside < 5:  # discard tiny regions
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if box_thresh > score:
                continue
            # Expand the box (DB predicts shrunk kernels)
            box = self.unclip(points, 1.5).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < 5 + 2:
                continue
            box = np.array(box)
            # Rescale from bitmap space to destination space
            box[:, 0] = np.clip(box[:, 0] / width * dest_width, 0, dest_width)
            box[:, 1] = np.clip(box[:, 1] / height * dest_height, 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes), scores

    def get_mini_boxes(self, contour):
        """Return the min-area rect corners in clockwise order plus its short side."""
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
        # Order the four corners: top-left, top-right, bottom-right, bottom-left
        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2
        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """Mean probability inside the box polygon — the detection score."""
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
        # Rasterize the polygon into a local mask and average under it
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def unclip(self, box, unclip_ratio):
        """Expand a (shrunk) text box polygon by the DB unclip formula."""
        from shapely.geometry import Polygon
        import pyclipper
        poly = Polygon(box)
        # Offset distance proportional to area/perimeter
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = offset.Execute(distance)
        if len(expanded) == 0:
            return box
        else:
            return np.array(expanded[0])

    def resize_norm_img_rec(self, img, input_shape=(320, 48)):
        """
        Recognizer preprocessing — fixed [1, 3, 48, 320] input.

        Note: *input_shape* is (width, height). The crop is left-aligned on a
        black canvas, as recognition models expect.
        """
        target_w, target_h = input_shape  # width first
        h, w = img.shape[:2]
        # Scale factor preserving aspect ratio
        ratio_h = target_h / h
        ratio_w = target_w / w
        ratio = min(ratio_h, ratio_w)
        new_h = int(h * ratio)
        new_w = int(w * ratio)
        resized_image = cv2.resize(img, (new_w, new_h))
        # Black canvas, float32 from the start
        padded_image = np.zeros((target_h, target_w, 3), dtype=np.float32)
        # Left-align the resized crop
        padded_image[:new_h, :new_w] = resized_image.astype(np.float32)
        # Plain [0, 1] scaling only — empirically, mean/std normalization here
        # shifted recognition results for this exported model.
        padded_image = (padded_image / 255.0).astype(np.float32)
        padded_image = padded_image.transpose((2, 0, 1)).astype(np.float32)
        return np.expand_dims(padded_image, axis=0).astype(np.float32)

    def decode_rec_result(self, preds_prob):
        """
        CTC-decode the recognizer output into (text, mean confidence).
        """
        preds_idx = np.argmax(preds_prob, axis=1)
        preds_prob = np.max(preds_prob, axis=1)
        # CTC decode: skip blanks (index 0) and collapse repeated indices
        last_idx = 0
        preds_text = []
        preds_conf = []
        for i, idx in enumerate(preds_idx):
            if idx != last_idx and idx != 0:  # 0 is the CTC blank
                if idx < len(self.character):
                    preds_text.append(self.character[idx])
                    preds_conf.append(preds_prob[i])
            last_idx = idx
        text = ''.join(preds_text)
        conf = np.mean(preds_conf) if preds_conf else 0.0
        return text, conf

    def detect_text(self, image):
        """
        Text detection stage — fixed [1, 3, 640, 640] input.

        Returns (boxes, scores); boxes are mapped back into original-image
        coordinates.
        """
        ori_h, ori_w = image.shape[:2]
        # Preprocess (letterbox + normalize)
        det_img, ratio, padding_info = self.resize_norm_img_det(image)
        # Inference
        det_output = self.det_session.run(None, {self.det_input_name: det_img})[0]
        # Binarize the probability map
        mask = det_output[0, 0, :, :]
        threshold = 0.3
        bitmap = (mask > threshold).astype(np.uint8) * 255
        # Boxes extracted here are still in the 640x640 letterboxed space
        boxes, scores = self.boxes_from_bitmap(mask, bitmap, 640, 640)
        # Map back into the original image space
        if len(boxes) > 0:
            boxes = self.post_process_det(boxes, ratio, padding_info, (ori_h, ori_w))
        return boxes, scores

    def recognize_text(self, image):
        """Text recognition stage — returns (text, confidence) for one crop."""
        rec_img = self.resize_norm_img_rec(image)
        rec_output = self.rec_session.run(None, {self.rec_input_name: rec_img})[0]
        text, conf = self.decode_rec_result(rec_output[0])
        return text, conf

    def get_rotate_crop_image(self, img, points):
        """
        Perspective-crop the quadrilateral *points* out of *img* and
        deskew it; tall crops (h/w >= 1.5) are rotated to horizontal.
        """
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        # Rotate likely-vertical text to horizontal
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def ocr(self, image_path):
        """
        Full OCR pipeline: detect -> crop -> recognize -> filter.

        Returns ``[text_list, confidence_list]`` (confidences rounded to 2
        decimals, filtered at > 0.5), or ``[]`` when the image cannot be read
        or nothing is detected. NOTE: this differs from the dict-per-box
        format built in ``ocr_results`` (kept but not returned).
        """
        image = cv2.imread(image_path)
        if image is None:
            return []
        # 1. Text detection
        dt_boxes, scores = self.detect_text(image)
        if dt_boxes is None or len(dt_boxes) == 0:
            return []
        # 2. Text recognition per detected box
        ocr_results = []
        text_list = []
        confidence_list = []
        for i, box in enumerate(dt_boxes):
            # Crop and deskew the text region
            box_points = box.astype(np.float32)
            crop_img = self.get_rotate_crop_image(image, box_points)
            if crop_img.size == 0:
                continue
            text, conf = self.recognize_text(crop_img)
            if conf > 0.5:  # confidence filter
                ocr_results.append({
                    'text': text,
                    'confidence': conf,
                    'box': box.tolist(),
                    'score': scores[i] if i < len(scores) else 0.0
                })
                text_list.append(text)
                confidence_list.append(round(conf.item(), 2))
        # return ocr_results
        return [text_list, confidence_list]
# 使用示例
def main():
    """Demo entry point: run the ONNX OCR pipeline on a sample image.

    ``BaiduOCRONNX.ocr()`` returns ``[text_list, confidence_list]`` (or ``[]``
    when nothing is detected) — NOT a list of per-box dicts. The previous
    version iterated the result expecting ``result['text']`` and would raise
    at runtime; this prints the actual return shape instead.
    """
    # Initialize OCR with explicit model paths
    ocr = BaiduOCRONNX('/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx', '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx')
    # Run OCR on the sample image
    image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/3aee64cc1f90d93a5a45979f7b17cb4b_frame_001460.jpg'
    results = ocr.ocr(image_path)
    if not results:
        print("no text detected")
        return
    text_list, confidence_list = results
    for text, conf in zip(text_list, confidence_list):
        print(f"文本: {text}")
        print(f"置信度: {conf:.3f}")
        print("-" * 50)
    # NOTE: visualize_results() expects the per-box dict format that ocr()
    # no longer returns, so it is not called here.
def visualize_results(image_path, results):
    """
    Draw OCR boxes and texts on the image and save to a fixed output file.

    Expects *results* as a list of dicts with 'box' and 'text' keys — the
    ``ocr_results`` format built (but not returned) by ``BaiduOCRONNX.ocr``.
    NOTE(review): the current ``ocr()`` return value ``[texts, confidences]``
    does not match this; confirm callers before use.
    """
    image = cv2.imread(image_path)
    for result in results:
        box = np.array(result['box'], dtype=np.int32)
        cv2.polylines(image, [box], True, (0, 255, 0), 2)
        # Put the recognized text just above the box
        cv2.putText(image, result['text'],
                    (box[0][0], box[0][1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    cv2.imwrite('result_shape_20250815.jpg', image)
if __name__ == '__main__':
    # Manual smoke test: instantiating BaiduOCR loads the Paddle models from
    # OCRSettings. NOTE(review): the empty path presumably makes predict()
    # fail/return nothing — confirm this is intentional.
    ocr = BaiduOCR()
    print(ocr.ocr(""))

266
app/util/responseHttp.py Normal file
View File

@ -0,0 +1,266 @@
from datetime import datetime
from fastapi import status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from typing import Any, Dict, Mapping, Optional
from app.config.constant import HttpStatusConstant
class ResponseUtil:
    """
    Helper for building the project's uniform JSON response envelope.

    Every JSON response uses HTTP status 200; the business outcome is carried
    in the body as ``code``/``msg``/``success`` plus optional ``data``/``rows``
    and merged extra attributes. The five public helpers previously duplicated
    the same envelope-building code; it now lives in ``_build``.
    """

    @classmethod
    def _build(
        cls,
        code,
        msg: str,
        success: bool,
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """
        Assemble the shared response envelope.

        :param code: business status code (an ``HttpStatusConstant`` value)
        :param msg: human-readable message
        :param success: value of the ``success`` flag in the body
        :param data: optional value for the ``data`` attribute
        :param rows: optional value for the ``rows`` attribute
        :param dict_content: optional extra attributes merged from a dict
        :param model_content: optional extra attributes merged from a BaseModel
        :param headers: optional response headers
        :param media_type: optional response media type
        :param background: optional background task run after the response
        :return: JSON response with HTTP status 200
        """
        result = {'code': code, 'msg': msg}
        if data is not None:
            result['data'] = data
        if rows is not None:
            result['rows'] = rows
        if dict_content is not None:
            result.update(dict_content)
        if model_content is not None:
            result.update(model_content.model_dump(by_alias=True))
        result.update({'success': success, 'time': datetime.now()})
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content=jsonable_encoder(result),
            headers=headers,
            media_type=media_type,
            background=background,
        )

    @classmethod
    def success(
        cls,
        msg: str = '操作成功',
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """Success response: code=SUCCESS, success=True."""
        return cls._build(
            HttpStatusConstant.SUCCESS, msg, True,
            data=data, rows=rows, dict_content=dict_content,
            model_content=model_content, headers=headers,
            media_type=media_type, background=background,
        )

    @classmethod
    def failure(
        cls,
        msg: str = '操作失败',
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """Business-failure (warning) response: code=WARN, success=False."""
        return cls._build(
            HttpStatusConstant.WARN, msg, False,
            data=data, rows=rows, dict_content=dict_content,
            model_content=model_content, headers=headers,
            media_type=media_type, background=background,
        )

    @classmethod
    def unauthorized(
        cls,
        msg: str = '登录信息已过期,访问系统资源失败',
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """Unauthenticated response: code=UNAUTHORIZED, success=False."""
        return cls._build(
            HttpStatusConstant.UNAUTHORIZED, msg, False,
            data=data, rows=rows, dict_content=dict_content,
            model_content=model_content, headers=headers,
            media_type=media_type, background=background,
        )

    @classmethod
    def forbidden(
        cls,
        msg: str = '该用户无此接口权限',
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """Permission-denied response: code=FORBIDDEN, success=False."""
        return cls._build(
            HttpStatusConstant.FORBIDDEN, msg, False,
            data=data, rows=rows, dict_content=dict_content,
            model_content=model_content, headers=headers,
            media_type=media_type, background=background,
        )

    @classmethod
    def error(
        cls,
        msg: str = '接口异常',
        data: Optional[Any] = None,
        rows: Optional[Any] = None,
        dict_content: Optional[Dict] = None,
        model_content: Optional[BaseModel] = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """Server-error response: code=ERROR, success=False."""
        return cls._build(
            HttpStatusConstant.ERROR, msg, False,
            data=data, rows=rows, dict_content=dict_content,
            model_content=model_content, headers=headers,
            media_type=media_type, background=background,
        )

    @classmethod
    def streaming(
        cls,
        *,
        data: Any = None,
        headers: Optional[Mapping[str, str]] = None,
        media_type: Optional[str] = None,
        background: Optional[BackgroundTask] = None,
    ) -> Response:
        """
        Streaming response (HTTP 200, no JSON envelope).

        :param data: the content to stream
        :param headers: optional response headers
        :param media_type: optional response media type
        :param background: optional background task run after the response
        :return: streaming response
        """
        return StreamingResponse(
            status_code=status.HTTP_200_OK, content=data, headers=headers, media_type=media_type, background=background
        )

470
app/util/yolov8Obj.py Normal file
View File

@ -0,0 +1,470 @@
from ultralytics import YOLO
# from rknn.api import RKNN
import cv2
import numpy as np
import onnxruntime as ort
import time
from app.config.config import yolov8_settings
class Yolov8Obj:
    """Thin wrapper around the Ultralytics YOLO detector."""

    def __init__(self):
        # Model weights path comes from settings
        self.model = YOLO(yolov8_settings.YOLOV8_MODEL_DIR)

    def detect(self, image_path):
        """Run detection on one image; returns (class_ids, confidences, xyxy boxes) as lists."""
        result = self.model.predict(image_path)
        boxes = result[0].boxes  # first (only) image in the batch
        cls = boxes.cls.tolist()
        conf = boxes.conf.tolist()
        coords = boxes.xyxy.tolist()
        return cls, conf, coords
class YOLOv8ONNX:
    """YOLOv8 object detector running through ONNX Runtime.

    Letterboxes the input to the model's fixed NCHW size, runs a single
    session, and decodes the ``[batch, 4+num_classes, num_boxes]`` output
    with confidence filtering and NMS.
    """

    def __init__(self, model_path=yolov8_settings.YOLOV8_MODEL_ONNX_DIRS, conf_threshold=0.5, iou_threshold=0.4):
        """
        Initialize the YOLOv8 ONNX model.

        Args:
            model_path: path to the ONNX model file
            conf_threshold: confidence threshold
            iou_threshold: NMS IoU threshold
        """
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        # ONNX Runtime session
        self.session = ort.InferenceSession(model_path)
        # Model input/output tensor names
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        # Fixed input size from the model metadata (NCHW)
        input_shape = self.session.get_inputs()[0].shape
        self.input_height = input_shape[2]
        self.input_width = input_shape[3]

    def preprocess(self, image):
        """
        Preprocess the image (letterbox + normalize).

        Args:
            image: input image (BGR)
        Returns:
            preprocessed_image: normalized NCHW batch tensor
            scale_ratio: resize scale applied
            pad_info: padding offsets (pad_x, pad_y)
        """
        h, w = image.shape[:2]
        # Scale preserving aspect ratio
        scale = min(self.input_height / h, self.input_width / w)
        new_h, new_w = int(h * scale), int(w * scale)
        resized_image = cv2.resize(image, (new_w, new_h))
        # Padding to center the resized image
        pad_x = (self.input_width - new_w) // 2
        pad_y = (self.input_height - new_h) // 2
        # Gray (114) canvas, YOLO convention
        padded_image = np.full((self.input_height, self.input_width, 3), 114, dtype=np.uint8)
        padded_image[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized_image
        # BGR -> RGB, HWC -> CHW, scale to [0, 1]
        input_image = padded_image[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
        input_image = np.expand_dims(input_image, axis=0)  # add batch dim
        return input_image, scale, (pad_x, pad_y)

    def postprocess(self, outputs, scale, pad_info, original_shape):
        """
        Decode the raw model output into final detections.

        Args:
            outputs: raw model outputs
            scale: resize scale from preprocessing
            pad_info: padding offsets from preprocessing
            original_shape: original image shape
        Returns:
            boxes: detections as [[x1, y1, x2, y2], ...] in original-image coords
            scores: confidence scores
            class_ids: class indices
        """
        predictions = outputs[0]  # typically [1, 4+num_classes, num_boxes]
        # YOLOv8 outputs [batch, 4+num_classes, num_boxes];
        # transpose to [batch, num_boxes, 4+num_classes]
        if len(predictions.shape) == 3:
            predictions = predictions.transpose(0, 2, 1)
        predictions = predictions[0]  # drop batch dim
        # Split coordinates and per-class scores
        boxes = predictions[:, :4]   # [x_center, y_center, width, height]
        scores = predictions[:, 4:]  # class confidences [num_boxes, num_classes]
        # Best class per box
        class_ids = np.argmax(scores, axis=1)
        confidences = np.max(scores, axis=1)
        # Confidence filtering
        valid_indices = confidences > self.conf_threshold
        valid_boxes = boxes[valid_indices]
        valid_confidences = confidences[valid_indices]
        valid_class_ids = class_ids[valid_indices]
        if len(valid_boxes) == 0:
            return [], [], []
        # Convert center/size to corner [x1, y1, x2, y2]
        x_center, y_center, width, height = valid_boxes[:, 0], valid_boxes[:, 1], valid_boxes[:, 2], valid_boxes[:, 3]
        x1 = x_center - width / 2
        y1 = y_center - height / 2
        x2 = x_center + width / 2
        y2 = y_center + height / 2
        converted_boxes = np.stack([x1, y1, x2, y2], axis=1)
        # Undo letterbox: remove padding, rescale to original image
        pad_x, pad_y = pad_info
        converted_boxes[:, [0, 2]] = (converted_boxes[:, [0, 2]] - pad_x) / scale
        converted_boxes[:, [1, 3]] = (converted_boxes[:, [1, 3]] - pad_y) / scale
        # Clip to the original image bounds
        h, w = original_shape[:2]
        converted_boxes[:, [0, 2]] = np.clip(converted_boxes[:, [0, 2]], 0, w)
        converted_boxes[:, [1, 3]] = np.clip(converted_boxes[:, [1, 3]], 0, h)
        # Non-maximum suppression
        indices = cv2.dnn.NMSBoxes(
            converted_boxes.tolist(),
            valid_confidences.tolist(),
            self.conf_threshold,
            self.iou_threshold
        )
        if len(indices) > 0:
            indices = indices.flatten()
            return converted_boxes[indices], valid_confidences[indices], valid_class_ids[indices]
        return [], [], []

    def detect(self, image):
        """
        Run object detection on an image.

        Args:
            image: input image (BGR)
        Returns:
            boxes: detection boxes
            scores: confidence scores
            class_ids: class indices
        """
        # Preprocess
        input_image, scale, pad_info = self.preprocess(image)
        # Inference
        outputs = self.session.run([self.output_name], {self.input_name: input_image})
        # Postprocess
        boxes, scores, class_ids = self.postprocess(outputs, scale, pad_info, image.shape)
        return boxes, scores, class_ids
class YOLOv8RKNN:
    """YOLOv8 detector running on a Rockchip NPU via the RKNN runtime.

    NOTE(review): ``from rknn.api import RKNN`` is commented out at module
    top, so instantiating this class raises NameError on non-RK hosts —
    re-enable the import (or guard it) before using this class on an RK3588.
    """

    def __init__(self, model_path, input_size=(640, 640)):
        self.model_path = model_path
        self.input_size = input_size
        self.rknn = RKNN()
        # Class names for the 2-class model — TODO: replace with the real names
        self.class_names = ['class1', 'class2']
        # Load and initialize the model immediately
        self.load_model()

    def load_model(self):
        """Load the RKNN model and initialize the on-device runtime; returns success flag."""
        print("Loading RKNN model...")
        ret = self.rknn.load_rknn(self.model_path)
        if ret != 0:
            print("Load RKNN model failed!")
            return False
        # Initialize the runtime on the RK3588 target
        print("Init RKNN runtime...")
        ret = self.rknn.init_runtime(target='rk3588', device_id=None, perf_debug=False, eval_mem=False)
        if ret != 0:
            print("Init RKNN runtime failed!")
            return False
        print("RKNN model loaded successfully!")
        return True

    def preprocess(self, image):
        """Letterbox the image to the model input size; caches scale/offsets on self for postprocess."""
        # Original image size
        self.orig_height, self.orig_width = image.shape[:2]
        # Resize to the model input size, preserving aspect ratio
        scale = min(self.input_size[0]/self.orig_width, self.input_size[1]/self.orig_height)
        new_width = int(self.orig_width * scale)
        new_height = int(self.orig_height * scale)
        resized = cv2.resize(image, (new_width, new_height))
        # Gray (114) canvas of the target size
        input_image = np.full((self.input_size[1], self.input_size[0], 3), 114, dtype=np.uint8)
        # Center the resized image on the canvas
        y_offset = (self.input_size[1] - new_height) // 2
        x_offset = (self.input_size[0] - new_width) // 2
        input_image[y_offset:y_offset+new_height, x_offset:x_offset+new_width] = resized
        # Save letterbox parameters for postprocessing
        self.scale = scale
        self.x_offset = x_offset
        self.y_offset = y_offset
        return input_image

    def postprocess(self, outputs, conf_threshold=0.5, nms_threshold=0.4):
        """Decode the YOLO output tensor and apply NMS; returns (boxes, scores, class_ids)."""
        # For this 2-class model each prediction is [x, y, w, h, conf_class1, conf_class2]
        predictions = outputs[0][0]  # drop batch dim
        # Transpose to [num_boxes, 6]
        predictions = predictions.transpose()
        boxes = []
        scores = []
        class_ids = []
        for detection in predictions:
            # Box coordinates and per-class confidences
            x, y, w, h = detection[:4]
            class_confs = detection[4:6]  # 2 class confidences
            # Best-scoring class
            class_id = np.argmax(class_confs)
            max_conf = class_confs[class_id]
            if max_conf >= conf_threshold:
                # Center/size -> corner coordinates
                x1 = x - w/2
                y1 = y - h/2
                x2 = x + w/2
                y2 = y + h/2
                # Undo letterbox back to original-image coordinates
                x1 = (x1 - self.x_offset) / self.scale
                y1 = (y1 - self.y_offset) / self.scale
                x2 = (x2 - self.x_offset) / self.scale
                y2 = (y2 - self.y_offset) / self.scale
                # Clamp to image bounds
                x1 = max(0, min(x1, self.orig_width))
                y1 = max(0, min(y1, self.orig_height))
                x2 = max(0, min(x2, self.orig_width))
                y2 = max(0, min(y2, self.orig_height))
                boxes.append([x1, y1, x2, y2])
                scores.append(max_conf)
                class_ids.append(class_id)
        # Non-maximum suppression over the surviving boxes
        if len(boxes) > 0:
            boxes = np.array(boxes)
            scores = np.array(scores)
            class_ids = np.array(class_ids)
            indices = cv2.dnn.NMSBoxes(boxes, scores, conf_threshold, nms_threshold)
            if len(indices) > 0:
                indices = indices.flatten()
                return boxes[indices], scores[indices], class_ids[indices]
        return np.array([]), np.array([]), np.array([])

    def detect(self, image, conf_threshold=0.5, nms_threshold=0.4):
        """Run detection; returns (boxes, scores, class_ids, inference_time_seconds)."""
        # Preprocess
        input_image = self.preprocess(image)
        # Inference (timed)
        start_time = time.time()
        outputs = self.rknn.inference(inputs=[input_image])
        inference_time = time.time() - start_time
        # Postprocess
        boxes, scores, class_ids = self.postprocess(outputs, conf_threshold, nms_threshold)
        return boxes, scores, class_ids, inference_time

    def draw_detections(self, image, boxes, scores, class_ids):
        """Draw boxes and labels onto *image* IN PLACE and return it."""
        for i in range(len(boxes)):
            x1, y1, x2, y2 = boxes[i].astype(int)
            score = scores[i]
            class_id = int(class_ids[i])
            # Bounding box
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Filled label background + label text
            label = f"{self.class_names[class_id]}: {score:.2f}"
            label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(image, (x1, y1-label_size[1]-10),
                          (x1+label_size[0], y1), (0, 255, 0), -1)
            cv2.putText(image, label, (x1, y1-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        return image

    def release(self):
        """Release the RKNN runtime resources."""
        if self.rknn:
            self.rknn.release()
def main():
    """Interactive demo for YOLOv8RKNN: single-image or live-camera detection."""
    # Initialize the detector with the on-device model
    model_path = "/home/orangepi/Desktop/康达机器狗/model_rknn/yolov8_20250820.rknn"
    detector = YOLOv8RKNN(model_path)

    # Detect on a single image and save the annotated result
    def test_image(image_path):
        image = cv2.imread(image_path)
        if image is None:
            print(f"Cannot load image: {image_path}")
            return
        # Run detection
        boxes, scores, class_ids, inference_time = detector.detect(image)
        print(f"Inference time: {inference_time*1000:.2f}ms")
        print(f"Detected {len(boxes)} objects")
        # Annotate and save (headless environment: no imshow)
        result_image = detector.draw_detections(image, boxes, scores, class_ids)
        cv2.imwrite("xxxxxxx.jpg", result_image)

    # Real-time detection from the default camera; press 'q' to quit
    def test_camera():
        cap = cv2.VideoCapture(0)  # default camera
        if not cap.isOpened():
            print("Cannot open camera")
            return
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Detect on this frame
            boxes, scores, class_ids, inference_time = detector.detect(frame)
            result_frame = detector.draw_detections(frame, boxes, scores, class_ids)
            # Overlay FPS derived from the inference time
            fps = 1.0 / inference_time if inference_time > 0 else 0
            cv2.putText(result_frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow("Real-time Detection", result_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        cap.release()
        cv2.destroyAllWindows()

    # Choose the test mode interactively
    mode = input("选择模式 (1: 图片检测, 2: 摄像头实时检测): ")
    if mode == "1":
        image_path = input("输入图片路径: ")
        test_image(image_path)
    elif mode == "2":
        test_camera()
    else:
        print("无效选择")
    # Release the NPU runtime
    detector.release()
def draw_detections(image, boxes, scores, class_ids, class_names=None):
    """
    Draw detection boxes with score labels on a copy of the image.

    Args:
        image: input image
        boxes: detection boxes as [x1, y1, x2, y2]
        scores: confidence scores
        class_ids: class indices
        class_names: optional list of class display names
    Returns:
        annotated copy of the input image
    """
    annotated = image.copy()
    for box, score, class_id in zip(boxes, scores, class_ids):
        x1, y1, x2, y2 = map(int, box)
        # Bounding box
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Label: class name when available, otherwise the numeric id
        if class_names and class_id < len(class_names):
            label = f"{class_names[class_id]}: {score:.2f}"
        else:
            label = f"Class {class_id}: {score:.2f}"
        # Filled background sized to the text, then the text itself
        label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
        cv2.rectangle(annotated, (x1, y1 - label_size[1] - 10),
                      (x1 + label_size[0], y1), (0, 255, 0), -1)
        cv2.putText(annotated, label, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
    return annotated

1
base64.txt Normal file

File diff suppressed because one or more lines are too long

18385
dict.yaml Normal file

File diff suppressed because it is too large Load Diff

14
robot_dog.service Normal file
View File

@ -0,0 +1,14 @@
[Unit]
# Human-readable name shown by `systemctl status robot_dog`.
Description=Robot Dog Service
# Start only after basic networking is up.
After=network.target
[Service]
# NOTE(review): runs as root — presumably required by the current file layout; confirm.
User=root
Group=root
# Backend project directory (directory name is UTF-8: "机器狗后台服务").
WorkingDirectory=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务
# start.sh activates the virtualenv and launches the app with logging.
ExecStart=/root/robot_dog_project/kangda_robotic_dog/机器狗后台服务/start.sh
# Always restart on exit, waiting 5 seconds between attempts.
Restart=always
RestartSec=5s
[Install]
# Enable for the standard multi-user boot target.
WantedBy=multi-user.target

14
run.py Normal file
View File

@ -0,0 +1,14 @@
import uvicorn

if __name__ == "__main__":
    # Entry point: serve the FastAPI app defined in app/api/main.py.
    # (Removed the stale commented-out host/port configuration.)
    uvicorn.run(
        "app.api.main:app",
        host="0.0.0.0",  # listen on all interfaces
        port=12345,
    )

32
start.sh Normal file
View File

@ -0,0 +1,32 @@
#!/bin/bash
# Start the robot-dog backend service.
#
# Fix: the original file was two startup scripts concatenated together
# (duplicate shebang, duplicate venv activation, duplicate cd); this keeps
# the single logging variant.

# Log directory and timestamped log file (directory names are git-quoted UTF-8).
LOG_DIR="/root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1/logs"
TIMESTAMP=$(date +'%Y%m%d_%H%M%S')
LOG_FILE="${LOG_DIR}/app_${TIMESTAMP}.log"

# Create the log directory if it does not exist.
mkdir -p "$LOG_DIR"

# Activate the Python virtual environment.
source /root/robot_dog_project/myvenv/bin/activate

# Switch to the project directory.
cd /root/robot_dog_project/kangda_robotic_dog/#U673a#U5668#U72d7#U540e#U53f0#U670d#U52a1

# Announce startup; mirror all output to the log file as well.
echo "启动应用程序,日志将保存到: $LOG_FILE"
echo "启动时间: $(date '+%Y-%m-%d %H:%M:%S')" | tee -a "$LOG_FILE"

# exec replaces this shell; tee duplicates output to console and log.
exec python run.py --env=dev 2>&1 | tee -a "$LOG_FILE"

110
test_base64_server.py Normal file
View File

@ -0,0 +1,110 @@
import base64
import requests
import json
def image_to_base64(image_path):
    """Read the file at *image_path* and return its contents base64-encoded as str."""
    with open(image_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')
def test_ocr_api(image_path, api_url="http://10.0.0.202:12342/api/v1/ocr_from_base64"):
    """Send *image_path* to the OCR endpoint as base64 and print the JSON result.

    Returns the decoded JSON response on success, or None on any failure.
    """
    image_base64 = image_to_base64(image_path)
    # Dump the raw base64 payload for offline inspection.
    # Fix: previously wrote repr(...), which wrapped the data in quotes and
    # made the dump unusable as a base64 string.
    with open("base64.txt", "w") as f:
        f.write(image_base64)
    payload = {
        "image_base64": image_base64,
        "image_type": "jpg"  # adjust to the actual image format
    }
    try:
        # json= serializes the payload and sets the JSON content type for us.
        response = requests.post(api_url, json=payload)
        if response.status_code == 200:
            result = response.json()
            print("OCR识别结果:")
            print(json.dumps(result, ensure_ascii=False, indent=2))
            return result
        print(f"请求失败,状态码: {response.status_code}")
        print(response.text)
        return None
    except Exception as e:
        print(f"请求异常: {str(e)}")
        return None
def test_detect(image_path: str, api_url: str = "http://10.0.0.202:12342/api/v1/detect_from_base64_0"):
    """POST *image_path* (base64-encoded) to a YOLOv8 detection endpoint.

    Prints and returns the decoded JSON response on success; returns None on failure.
    """
    payload = {
        "image_base64": image_to_base64(image_path),
        "image_type": "jpg"  # adjust to the actual image format
    }
    try:
        # json= serializes the payload and sets the JSON content type for us.
        response = requests.post(api_url, json=payload)
        if response.status_code == 200:
            result = response.json()
            print("detect识别结果:")
            print(json.dumps(result, ensure_ascii=False, indent=2))
            return result
        print(f"请求失败,状态码: {response.status_code}")
        print(response.text)
        return None
    except Exception as e:
        print(f"请求异常: {str(e)}")
        return None
if __name__ == "__main__":
    # Test image path; change it to match the local environment.
    # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/2c7cc83019e7388a7041101da92c9829_frame_000000.jpg"
    #--------------------------------------- OCR test -----------------------------------------
    test_image_path = "images/AiCheck_20251016114138.jpg"
    # api_url="http://10.0.0.202:12342/api/v1/ocr_onnx_from_base64"
    # api_url="http://192.168.30.195:12345/api/v1/ocr_onnx_from_base64"
    # # Call the OCR test helper
    # test_ocr_api(test_image_path, api_url)
    #--------------------------------------- OCR test end -----------------------------------------
    # # ----------------------------------------- YOLOv8 fire-zone occupancy detection test -----------------------------------------
    # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000000.jpg"
    # # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000720.jpg"
    # api_url = "http://10.0.0.202:12342/api/v1/detect_onnx_from_base64_0"
    # api_url = "http://192.168.30.195:12345/api/v1/detect_onnx_from_base64_0"
    # test_detect(test_image_path, api_url)
    # #----------------------------------------- YOLOv8 fire-zone occupancy detection test end -----------------------------------------
    # #----------------------------------------- YOLOv8 fire-extinguisher detection test -----------------------------------------
    # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/ce81420a27cdaff14fe42f967eaa49a3_frame_001060.jpg"
    # # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
    # # test_image_path = "/home/admin-root/haotian/康达瑞贝斯机器狗/YoloV8Obj/dataset_20250819/train/images/1e4c75b76e531606e2adc491a8f09ae8_frame_000120.jpg"
    # api_url = "http://10.0.0.202:12342/api/v1/detect_from_base64_1"
    # Active configuration: fire-extinguisher detection endpoint.
    api_url = "http://192.168.30.195:12345/api/v1/detect_from_base64_1"
    test_detect(test_image_path, api_url=api_url)
    # #----------------------------------------- YOLOv8 fire-extinguisher detection test end -----------------------------------------

27
test_generic.py Normal file
View File

@ -0,0 +1,27 @@
from typing import Generic, TypeVar, Dict
from typing import Optional
K = TypeVar("K")  # key type
V = TypeVar("V")  # value type


class GenericCache(Generic[K, V]):
    """A tiny generic key-value cache backed by a plain dict."""

    def __init__(self):
        # Underlying storage; name kept as _store for existing external users.
        self._store: Dict[K, V] = {}

    def set(self, key: K, value: V) -> None:
        """Store *value* under *key*, replacing any previous entry."""
        self._store[key] = value

    def get(self, key: K) -> Optional[V]:
        """Return the cached value for *key*, or None when absent."""
        return self._store[key] if key in self._store else None
# Usage example
cache = GenericCache[str, int]()  # keys are str, values are int
cache.set("count", 100)
value: Optional[int] = cache.get("count")  # returns 100
# NOTE(review): violates the declared [str, int] parameters — type hints are
# not enforced at runtime, so this succeeds anyway.
cache.set(112, "123")
for key in cache._store.keys():
    # Prints the runtime type of each stored value (int, then str).
    print(type(cache.get(key)))

7
test_ocr.py Normal file
View File

@ -0,0 +1,7 @@
from app.util.baiduOCR import BaiduOCR
if __name__ == '__main__':
    # Smoke test: run Baidu OCR on a sample image and print the raw result.
    ocr = BaiduOCR()
    result = ocr.ocr('tmp/ocr_images/c723d5d8-697b-41e5-8cf4-d4644ce89a77.jpg')
    print(result)

87
test_requests.py Normal file
View File

@ -0,0 +1,87 @@
import requests
from requests import Response
from typing import List, TypeVar, Dict, Any, Optional
import logging
# Configure module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
T = TypeVar('T')  # generic element type for parsed responses
REQUEST_URL = "http://aijinan.biaofun.com.cn/app/query/shandong"  # replace with the real API endpoint
def main():
    """Query the weather endpoint for each city and hand results to after()."""
    cities = ["济南市", "青岛市", "临沂市", "菏泽市", "威海市", "东营市"]
    for city in cities:
        # Build a fresh params dict per request so state never leaks between calls.
        response_list = rpc_get(REQUEST_URL, {"city": city}, False)
        if not response_list:
            logger.warning(f"Empty response for city: {city}")
            continue
        # Tag the first record with its city name before downstream processing
        # (assumes the parsed response is a list of dicts).
        response_list[0]['name'] = city
        after(response_list)
def rpc_get(url: str, params: Dict[str, Any], right_or_wrong: bool) -> Optional[List[T]]:
    """POST *params* to *url* and return the parsed response list.

    :param url: request URL
    :param params: form parameters for the POST body
    :param right_or_wrong: unused flag kept from the original Java signature
    :return: parsed response objects, or None on any failure
    """
    logger.info(f"Request URL: {url}")
    try:
        response: Response = requests.post(url, data=params)
        # Anything other than HTTP 200 is treated as a failed call.
        if response.status_code != 200:
            logger.error(f"Request failed, Status Code: {response.status_code}, URL: {url}")
            return None
        response_text = response.text
        logger.debug(f"Response: {response_text}")
        # Delegate format-specific parsing to parse_response().
        return parse_response(response_text)
    except Exception as e:
        logger.exception(f"RPC request failed: {str(e)}")
        return None
def parse_response(response_text: str) -> List[T]:
    """Parse the raw API response text into a list.

    Placeholder implementation: decodes JSON and wraps a lone object in a
    one-element list; arrays pass through unchanged. Replace with the real
    response-format parsing once the API contract is known.
    """
    import json
    try:
        parsed = json.loads(response_text)
    except json.JSONDecodeError:
        logger.error("Invalid JSON response")
        return []
    if isinstance(parsed, dict):
        return [parsed]
    return parsed
def after(response_list: List[Any]):
    """Placeholder post-processing hook for a parsed response list."""
    first = response_list[0]
    # Demo implementation: announce which record is being handled.
    print(f"Processing response for: {first.get('name')}")
    # Real business logic goes here...
if __name__ == "__main__":
    # Script entry point: iterate the configured cities once.
    main()