MLPlatform/main.py

180 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from typing import Optional, Dict, List
from contextlib import asynccontextmanager
import uvicorn
from pathlib import Path
import logging
import yaml
from datetime import datetime
import time
# 导入API路由
from api.data_api import router as data_router
from api.model_api import router as model_router
from api.system_api import router as system_router
# Only let WARNING-and-above records through from the "watchfiles" logger,
# keeping its routine messages out of the server logs.
logging.getLogger("watchfiles").setLevel("WARNING")
# Application lifespan management.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup/shutdown work around the application's lifetime.

    Replaces the older startup/shutdown event hooks. Startup makes sure the
    dataset and log directories exist (paths are relative to the current
    working directory); shutdown only logs a message.
    """
    logger.info("Server starting up...")
    for dataset_dir in ("dataset/dataset_raw", "dataset/dataset_processed"):
        Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")

    yield  # the application serves requests while suspended here

    logger.info("Server shutting down...")
# Create the FastAPI application; interactive API docs are served at
# /docs and /redoc.
app = FastAPI(
    lifespan=lifespan,  # lifespan handler instead of startup/shutdown event hooks
    title="机器学习平台API",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Load configuration.
def load_config(path='config/config.yaml'):
    """Load the application configuration from a YAML file.

    Best-effort: any failure to read or parse the file is logged and an empty
    dict is returned, so the server can still start with its defaults.

    Args:
        path: Path of the YAML config file (relative to the current working
            directory unless absolute). Defaults to ``config/config.yaml``.

    Returns:
        The parsed configuration mapping, or ``{}`` if the file could not be
        loaded or does not contain a mapping (e.g. an empty file).
    """
    # Look the logger up here instead of using the module-level ``logger``:
    # this function is first called at import time, possibly before that
    # name has been assigned, which would turn a config error into NameError.
    log = logging.getLogger(__name__)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except Exception as e:
        log.error("Error loading config: %s", e)
        return {}
    # safe_load yields None for an empty document and may yield a list or a
    # scalar for other YAML; callers use ``config.get(...)``, so anything that
    # is not a dict would crash them later.
    if not isinstance(data, dict):
        log.warning("Config file %s does not contain a mapping; using defaults", path)
        return {}
    return data
# Set up logging first, so that (a) the module-level ``logger`` exists before
# load_config() runs and may log through it, and (b) any problem loading the
# config is recorded in the log file as well as on the console.
log_dir = Path('.log')
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # A new log file named after the local time this module is imported.
        # Explicit UTF-8 so non-ASCII messages are written independently of
        # the platform's default locale encoding.
        logging.FileHandler(
            log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Initialise configuration (an empty dict if the config file cannot be loaded).
config = load_config()
# Configure CORS from the optional ``cors`` section of the config.
# ``or {}`` also covers a present-but-empty ``cors:`` key, which YAML loads as
# None and which would otherwise make the ``.get`` calls below raise.
# NOTE(review): the defaults combine wildcard origins with
# allow_credentials=True, which is very permissive (the CORS spec forbids '*'
# with credentials; Starlette handles it by echoing the request's Origin).
# Set explicit ``cors.origins`` for production deployments — confirm.
_cors_config = config.get('cors') or {}
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_config.get('origins', ["*"]),
    allow_credentials=True,
    allow_methods=_cors_config.get('methods', ["*"]),
    allow_headers=_cors_config.get('headers', ["*"]),
    max_age=_cors_config.get('max_age', 600),
)
# Request-timing middleware.
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach an ``X-Process-Time`` header with the handling time in seconds.

    The duration is measured around ``call_next``, i.e. everything downstream
    of this middleware, and is sent as the string form of a float.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be adjusted mid-request and skew the duration.
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response
# Global error handling: errors are returned in a uniform
# {"status": "error", "message": ..., ...} JSON envelope.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request-parameter validation errors.

    Responds with HTTP 422; the individual validation errors are returned
    under ``details``.
    """
    # exc.errors() may contain values json.dumps cannot serialize (e.g. the
    # exception object raised by a validator, stored in an error's ``ctx``).
    # jsonable_encoder converts them, as FastAPI's own default handler does;
    # without it JSONResponse would fail and the client would get a 500.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            "details": jsonable_encoder(exc.errors()),
        },
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Handle HTTPException raised by application code.

    Responds with the exception's status code, puts ``exc.detail`` under
    ``message``, and preserves any headers attached to the exception.
    """
    # Forward exc.headers: e.g. a 401 raised by FastAPI's OAuth2 security
    # dependencies carries a WWW-Authenticate header that clients rely on.
    # NOTE(review): this is registered for fastapi.HTTPException only; 404/405
    # raised by Starlette's router use starlette.exceptions.HTTPException and
    # will not reach this handler — confirm whether they should.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail,
        },
        headers=exc.headers,
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all for exceptions no more specific handler dealt with.

    Logs the error with its traceback and answers with a generic HTTP 500.
    The exception text is included in ``details`` only when ``debug`` is
    enabled in the config; otherwise ``details`` is None.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    # Keep internal error text out of responses unless running in debug mode.
    details = None
    if config.get('debug', False):
        details = str(exc)
    payload = {
        "status": "error",
        "message": "服务器内部错误",
        "details": details,
    }
    return JSONResponse(status_code=500, content=payload)
# Register the feature routers: each is mounted under its own URL prefix,
# gets its own tag, and documents a generic 404 response.
for _router, _prefix, _tag in (
    (data_router, "/data", "数据处理"),
    (model_router, "/model", "模型管理"),
    (system_router, "/system", "系统监控"),
):
    app.include_router(
        _router,
        prefix=_prefix,
        tags=[_tag],
        responses={404: {"description": "Not found"}},
    )
del _router, _prefix, _tag  # don't leave loop variables in the module namespace
# Health check endpoint.
@app.get("/health")
async def health_check():
    """Report service health together with basic metadata.

    Returns a fixed ``"healthy"`` status, the current local time as an ISO 8601
    string (naive, no UTC offset), the app version, and the configured
    environment name (``"production"`` when not set in the config).
    """
    now = datetime.now().isoformat()
    environment = config.get('environment', 'production')
    return dict(
        status="healthy",
        timestamp=now,
        version=app.version,
        environment=environment,
    )
if __name__ == "__main__":
    # Server settings come from the config, falling back to these defaults.
    # NOTE(review): uvicorn documents `reload` and `workers` as mutually
    # exclusive; when debug is on, confirm the workers setting is meant to be ignored.
    server_options = {
        "host": config.get('host', '0.0.0.0'),
        "port": config.get('port', 8992),
        # Hot reload is enabled only in debug mode.
        "reload": config.get('debug', False),
        "workers": config.get('workers', 4),
        "log_level": config.get('log_level', 'warning'),
        "access_log": True,
    }
    uvicorn.run("main:app", **server_options)