"""FastAPI entry point for the machine-learning platform API.

Builds the ``app`` object: logging, YAML configuration, CORS, a
request-timing middleware, JSON error handlers, the data/model/system
routers and a ``/health`` endpoint. Running the file directly starts uvicorn.
"""

from contextlib import asynccontextmanager
from datetime import datetime
import logging
from pathlib import Path
import time
from typing import Optional, Dict, List

from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
import uvicorn
import yaml

# API routers
from api.data_api import router as data_router
from api.model_api import router as model_router
from api.system_api import router as system_router

# Keep the "watchfiles" logger at WARNING.
logging.getLogger("watchfiles").setLevel(logging.WARNING)

# Logging is configured before anything else at module level so that `logger`
# already exists when load_config() needs to report a failure.
log_dir = Path('.log')
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)


def load_config(config_path: str = 'config/config.yaml') -> Dict:
    """Load the YAML configuration file.

    Loading is best-effort: any failure is logged and an empty dict is
    returned so that callers can rely on ``config.get(...)`` defaults.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping, or ``{}`` when the file is missing,
        unreadable, empty, or does not contain a mapping.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except Exception as e:
        logger.error(f"Error loading config: {str(e)}")
        return {}

    # safe_load returns None for an empty document.
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        logger.error(
            "Error loading config: expected a mapping at the top level, "
            f"got {type(loaded).__name__}"
        )
        return {}
    return loaded


# Initialise configuration
config = load_config()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown (replaces startup/shutdown events)."""
    # Startup: make sure the working directories exist.
    logger.info("Server starting up...")
    Path("dataset/dataset_raw").mkdir(parents=True, exist_ok=True)
    Path("dataset/dataset_processed").mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")

    yield  # The application serves requests while suspended here.

    # Shutdown
    logger.info("Server shutting down...")


# Create the FastAPI application
app = FastAPI(
    title="机器学习平台API",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# Configure CORS. `or {}` also covers a `cors:` key that is present but empty.
cors_config = config.get('cors') or {}
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_config.get('origins', ["*"]),
    # NOTE(review): credentials combined with the wildcard origin default lets
    # any site send credentialed requests. The default is kept for backward
    # compatibility; set `cors.allow_credentials: false` or list explicit
    # origins in the config for production.
    allow_credentials=cors_config.get('allow_credentials', True),
    allow_methods=cors_config.get('methods', ["*"]),
    allow_headers=cors_config.get('headers', ["*"]),
    max_age=cors_config.get('max_age', 600),
)


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the request handling time, in seconds, as ``X-Process-Time``."""
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request validation errors as a 422 JSON payload."""
    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            # Error entries can carry objects that json cannot serialise;
            # jsonable_encoder converts them first.
            "details": jsonable_encoder(exc.errors()),
        },
    )


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Return an HTTPException as JSON, keeping its status code and headers."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail,
        },
        # Forward headers attached to the exception (e.g. WWW-Authenticate).
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any otherwise unhandled exception and return a generic 500 payload."""
    # Pass the exception object itself so the traceback is logged regardless
    # of the ambient sys.exc_info() state.
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "status": "error",
            "message": "服务器内部错误",
            # Exception text is only exposed when debug is enabled.
            "details": str(exc) if config.get('debug', False) else None,
        },
    )


# Register routers
app.include_router(
    data_router,
    prefix="/data",
    tags=["数据处理"],
    responses={404: {"description": "Not found"}},
)

app.include_router(
    model_router,
    prefix="/model",
    tags=["模型管理"],
    responses={404: {"description": "Not found"}},
)

app.include_router(
    system_router,
    prefix="/system",
    tags=["系统监控"],
    responses={404: {"description": "Not found"}},
)


# The explicit description keeps the published OpenAPI text unchanged.
@app.get("/health", description="系统健康检查")
async def health_check():
    """Report service status, current timestamp, version and environment."""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "version": app.version,
        "environment": config.get('environment', 'production'),
    }


if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host=config.get('host', '0.0.0.0'),
        port=config.get('port', 8992),
        # Hot reload is left disabled. To enable it during development pass
        # reload=config.get('debug', False); uvicorn ignores `workers` when
        # reload is on.
        workers=config.get('workers', 4),
        log_level=config.get('log_level', 'warning'),
        access_log=True,
    )