180 lines
5.0 KiB
Python
180 lines
5.0 KiB
Python
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.exceptions import RequestValidationError
|
||
from typing import Optional, Dict, List
|
||
from contextlib import asynccontextmanager
|
||
import uvicorn
|
||
from pathlib import Path
|
||
import logging
|
||
import yaml
|
||
from datetime import datetime
|
||
import time
|
||
|
||
# 导入API路由
|
||
from api.data_api import router as data_router
|
||
from api.model_api import router as model_router
|
||
from api.system_api import router as system_router
|
||
|
||
|
||
# Raise the "watchfiles" logger threshold to WARNING.
_watchfiles_logger = logging.getLogger("watchfiles")
_watchfiles_logger.setLevel(logging.WARNING)
|
||
|
||
|
||
# Application lifespan management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Used in place of the older startup/shutdown event hooks. Code before
    the ``yield`` runs at startup; code after it runs at shutdown.
    """
    # Startup: make sure the working directories exist.
    logger.info("Server starting up...")
    dataset_root = Path("dataset")
    for subdir in ("dataset_raw", "dataset_processed"):
        (dataset_root / subdir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")

    yield  # the application runs while this generator is suspended

    # Shutdown: nothing to release here, only record the event.
    logger.info("Server shutting down...")
|
||
|
||
# Create the FastAPI application
app = FastAPI(
    lifespan=lifespan,  # lifespan context manager instead of startup/shutdown events
    title="机器学习平台API",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
|
||
|
||
# Load configuration
def load_config():
    """Load the service configuration from ``config/config.yaml``.

    Loading is best-effort: any failure is logged and an empty dict is
    returned so that callers can rely on ``config.get(...)`` with defaults.

    Returns:
        dict: The parsed YAML mapping, or ``{}`` when the file is missing,
        unreadable, malformed, empty, or does not contain a mapping.
    """
    # This function is called at import time, before the module-level
    # ``logger`` name is assigned, so referencing ``logger`` here would raise
    # NameError exactly when a load error needs to be reported. Look the
    # logger up locally instead.
    log = logging.getLogger(__name__)

    try:
        with open('config/config.yaml', 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except Exception as e:  # deliberate best-effort boundary: fall back to defaults
        log.error("Error loading config: %s", e)
        return {}

    # An empty YAML document parses to None.
    if loaded is None:
        return {}

    # Callers use dict methods on the result, so reject lists/scalars.
    if not isinstance(loaded, dict):
        log.error(
            "Error loading config: expected a mapping at the top level, got %s",
            type(loaded).__name__,
        )
        return {}

    return loaded


# Initialize configuration
config = load_config()
|
||
|
||
# Logging setup: write to a timestamped file under .log/ and to the console.
log_dir = Path('.log')
log_dir.mkdir(exist_ok=True)

# One log file per process start, named after the start time.
_log_file = log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log'

logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler(_log_file),
        logging.StreamHandler(),
    ],
)

# Module-level logger used by the handlers and lifecycle hooks in this file.
logger = logging.getLogger(__name__)
|
||
|
||
# Configure CORS from the optional "cors" section of the config file.
# NOTE(review): the fallbacks combine wildcard origins/methods/headers with
# allow_credentials=True; the FastAPI CORS docs advise listing these
# explicitly when credentials are allowed — confirm that deployed configs
# always set cors.origins.
_cors = config.get('cors', {})

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=_cors.get('origins', ["*"]),
    allow_methods=_cors.get('methods', ["*"]),
    allow_headers=_cors.get('headers', ["*"]),
    max_age=_cors.get('max_age', 600),
)
|
||
|
||
# Request timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the request handling time as an ``X-Process-Time`` header.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that passes the request on and returns the
            response.

    Returns:
        The downstream response with ``X-Process-Time`` set to the elapsed
        time in seconds, formatted with ``str()``.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-request.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - started
    response.headers["X-Process-Time"] = str(elapsed)
    return response
|
||
|
||
# Global error handling
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a uniform 422 JSON body for request validation errors.

    Args:
        request: The request that failed validation (unused).
        exc: The validation error raised for the request.

    Returns:
        JSONResponse: Status 422 with ``status``, ``message`` and the list of
        validation errors under ``details``.
    """
    # Imported locally so this handler does not depend on the file-level
    # import list.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            # Error entries may carry values that json cannot serialize
            # directly (e.g. exception objects in "ctx"); encode them first
            # so a validation failure is not turned into a 500.
            "details": jsonable_encoder(exc.errors()),
        },
    )
|
||
|
||
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Return a uniform JSON body for ``HTTPException`` errors.

    Args:
        request: The request being handled (unused).
        exc: The raised HTTP exception.

    Returns:
        JSONResponse: The exception's status code with ``status`` and
        ``message`` (taken from ``exc.detail``), plus any headers attached
        to the exception.
    """
    # NOTE(review): this is registered for FastAPI's HTTPException; the
    # FastAPI docs suggest registering for Starlette's HTTPException if
    # errors raised by Starlette itself should share this format — confirm
    # the intended scope.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail,
        },
        # Forward headers set on the exception (e.g. WWW-Authenticate on a
        # 401) instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
|
||
|
||
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any otherwise unhandled exception and answer with a 500 JSON body.

    The exception text is included under ``details`` only when the ``debug``
    config flag is truthy; otherwise ``details`` is ``None``.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)

    # Only expose the exception text when debugging is enabled.
    if config.get('debug', False):
        details = str(exc)
    else:
        details = None

    body = {
        "status": "error",
        "message": "服务器内部错误",
        "details": details,
    }
    return JSONResponse(status_code=500, content=body)
|
||
|
||
# Register the API routers: (router, URL prefix, OpenAPI tag).
for _router, _prefix, _tag in (
    (data_router, "/data", "数据处理"),
    (model_router, "/model", "模型管理"),
    (system_router, "/system", "系统监控"),
):
    app.include_router(
        _router,
        prefix=_prefix,
        tags=[_tag],
        responses={404: {"description": "Not found"}},
    )
|
||
|
||
# Health check
@app.get("/health")
async def health_check():
    """Report service health.

    Returns a dict with a fixed ``"healthy"`` status, the current local time
    in ISO format, the application version, and the configured environment
    (``'production'`` when not set).
    """
    now = datetime.now()
    environment = config.get('environment', 'production')
    return {
        "status": "healthy",
        "timestamp": now.isoformat(),
        "version": app.version,
        "environment": environment,
    }
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Server settings come from the config file, with built-in fallbacks.
    server_options = dict(
        host=config.get('host', '0.0.0.0'),
        port=config.get('port', 8992),
        workers=config.get('workers', 4),
        log_level=config.get('log_level', 'warning'),
        access_log=True,
    )
    # Hot reload is left disabled here; the original kept
    # `reload=config.get('debug', False)` commented out.
    uvicorn.run("main:app", **server_options)