# MLPlatform/main.py
# (source-viewer metadata: 170 lines, 4.7 KiB, Python)

import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import uvicorn
import yaml
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer

# API routers
from api.data_api import router as data_router
from api.model_api import router as model_router
from api.system_api import router as system_router
# Create the FastAPI application. Interactive API docs are served at /docs,
# and the ReDoc rendering at /redoc.
app = FastAPI(
    title="机器学习平台API",
    version="1.0.0",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Configuration loading
def load_config(path='config/config.yaml'):
    """Load the application configuration from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to ``config/config.yaml``,
            resolved relative to the current working directory.

    Returns:
        The parsed top-level mapping. Returns ``{}`` in these cases:
        the file cannot be opened or parsed, the file is empty, or its
        top level is not a mapping. Callers can therefore always use
        ``.get()`` on the result.
    """
    # Look the logger up here instead of using the module-level ``logger``.
    # This function is called at import time, possibly before that name is bound.
    log = logging.getLogger(__name__)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except Exception as e:  # best-effort: any failure falls back to defaults
        log.error(f"Error loading config: {str(e)}")
        return {}
    if not isinstance(data, dict):
        # safe_load yields None for an empty file, or a list/scalar for other documents.
        log.warning("Config file %s is empty or not a mapping; using defaults", path)
        return {}
    return data
# Set up logging BEFORE loading the config. This guarantees the module-level
# ``logger`` is bound by the time load_config() runs, and that any error it
# reports reaches the handlers configured here.
log_dir = Path('.log')
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # File name is timestamped at import time. Use explicit utf-8 so
        # non-ASCII messages cannot fail to encode on platforms whose
        # default locale encoding is not UTF-8.
        logging.FileHandler(
            log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# Load the configuration; other module-level code reads it via ``config``.
config = load_config()
# Configure CORS from the optional ``cors`` section of the config.
# ``or {}`` also covers a key that is present but empty (``cors:``), which
# YAML loads as None and which would otherwise crash the ``.get`` calls below.
cors_config = config.get('cors') or {}
# NOTE(review): the defaults are wildcard origins together with
# allow_credentials=True. The FastAPI/Starlette CORS docs say credentials
# require explicitly listed origins. Consider setting ``cors.origins`` in
# config rather than relying on ["*"].
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_config.get('origins', ["*"]),
    allow_credentials=True,
    allow_methods=cors_config.get('methods', ["*"]),
    allow_headers=cors_config.get('headers', ["*"]),
    max_age=cors_config.get('max_age', 600),
)
# Request-timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add an ``X-Process-Time`` header with the request's handling time.

    The value is the elapsed time in seconds, as a string, between entering
    this middleware and ``call_next`` returning the response object.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump
    # if the wall clock is adjusted, which would corrupt the measurement.
    started = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - started)
    return response
# Global error handlers: the handlers below wrap errors in a
# {"status": "error", "message": ...} envelope.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a 422 response describing request-parameter validation failures.

    ``exc.errors()`` can contain values the JSON encoder cannot serialize
    directly. FastAPI's own default handler encodes them with
    ``jsonable_encoder`` for this reason, so the same is done here.
    """
    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            "details": jsonable_encoder(exc.errors()),
        },
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` using the API's error envelope.

    Headers attached to the exception (e.g. ``WWW-Authenticate`` on a 401)
    are forwarded, as FastAPI's default handler does. Without this they
    would be silently dropped.
    """
    # NOTE(review): this handler is registered for fastapi.HTTPException.
    # Per the FastAPI docs, HTTPExceptions raised by Starlette itself (e.g.
    # 404 for unknown routes) are only caught by a handler registered for
    # starlette.exceptions.HTTPException. Confirm whether those should use
    # this envelope too.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail,
        },
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all for exceptions that no more specific handler claimed.

    Logs the traceback and answers with a generic 500 payload. The exception
    text is included under ``details`` only when ``debug`` is truthy in the
    config; otherwise ``details`` is null.
    """
    logger.exception("Unhandled exception: %s", exc)
    expose_details = config.get('debug', False)
    payload = {
        "status": "error",
        "message": "服务器内部错误",
        "details": str(exc) if expose_details else None,
    }
    return JSONResponse(status_code=500, content=payload)
# Register the feature routers: (router, URL prefix, OpenAPI tag).
_ROUTER_TABLE = (
    (data_router, "/data", "数据处理"),
    (model_router, "/model", "模型管理"),
    (system_router, "/system", "系统监控"),
)
for _router, _prefix, _tag in _ROUTER_TABLE:
    app.include_router(
        _router,
        prefix=_prefix,
        tags=[_tag],
        # A fresh dict per router, as in the one-call-per-router form.
        responses={404: {"description": "Not found"}},
    )
# Health-check endpoint
@app.get("/health")
async def health_check():
    """Report service status, local timestamp, API version and environment.

    ``environment`` falls back to ``'production'`` when the config omits it.
    """
    now_iso = datetime.now().isoformat()
    env_name = config.get('environment', 'production')
    return {
        "status": "healthy",
        "timestamp": now_iso,
        "version": app.version,
        "environment": env_name,
    }
# Startup hook
# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
# lifespan handlers; consider migrating.
@app.on_event("startup")
async def startup_event():
    """Ensure the dataset and log directories exist before serving requests."""
    logger.info("Server starting up...")
    dataset_root = Path("dataset")
    for subdir in ("dataset_raw", "dataset_processed"):
        (dataset_root / subdir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")
# Shutdown hook
@app.on_event("shutdown")
async def shutdown_event():
    """Log that the server is stopping; no other cleanup is performed yet."""
    farewell = "Server shutting down..."
    logger.info(farewell)
if __name__ == "__main__":
    # Default ``debug`` to False, matching general_exception_handler. A
    # missing key then means production behaviour, not auto-reload.
    debug = config.get('debug', False)
    # ``reload`` can be set independently; it defaults to the debug flag.
    reload = config.get('reload', debug)
    # uvicorn does not support multiple workers together with reload; the
    # worker count would be ignored. Run a single process when reloading.
    workers = 1 if reload else config.get('workers', 4)
    uvicorn.run(
        "main:app",
        host=config.get('host', '0.0.0.0'),
        port=config.get('port', 8992),
        reload=reload,
        workers=workers,
        log_level=config.get('log_level', 'info'),
        access_log=True,
    )