170 lines
4.7 KiB
Python
170 lines
4.7 KiB
Python
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.exceptions import RequestValidationError
|
|
from typing import Optional, Dict, List
|
|
import uvicorn
|
|
from pathlib import Path
|
|
import logging
|
|
import yaml
|
|
from datetime import datetime
|
|
import time
|
|
|
|
# 导入API路由
|
|
from api.data_api import router as data_router
|
|
from api.model_api import router as model_router
|
|
from api.system_api import router as system_router
|
|
|
|
# Create the FastAPI application instance. Interactive API docs are served
# at /docs (Swagger UI) and /redoc (ReDoc). The title/description literals
# are user-facing Chinese text ("Machine Learning Platform API" /
# data processing, model training and system monitoring services).
app = FastAPI(
    version="1.0.0",
    title="机器学习平台API",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    redoc_url="/redoc",
    docs_url="/docs",
)
# Configuration loading.
def load_config(path='config/config.yaml'):
    """Load the application configuration from a YAML file.

    Args:
        path: Path of the YAML config file. Defaults to the historical
            location ``config/config.yaml``, resolved relative to the
            current working directory.

    Returns:
        The parsed configuration as a dict. An empty dict is returned when
        the file cannot be opened or parsed, when it is empty, or when its
        top level is not a mapping. Callers can therefore always use
        ``config.get(...)`` safely.
    """
    # The module-level ``logger`` is assigned further down in this module,
    # AFTER this function has already been called at import time. Using it
    # here raised NameError on any load failure. logging.getLogger returns
    # the same named logger object, so fetch it directly.
    log = logging.getLogger(__name__)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except Exception as e:
        # Deliberate best-effort: fall back to built-in defaults rather than
        # refusing to start.
        log.error("Error loading config: %s", e)
        return {}

    if data is None:
        # An empty YAML document parses to None.
        return {}
    if not isinstance(data, dict):
        log.error(
            "Error loading config: expected a mapping at the top level, got %s",
            type(data).__name__,
        )
        return {}
    return data


# Load the configuration once at import time; the rest of this module reads
# settings from ``config``.
config = load_config()
# Logging setup: log to a per-start timestamped file under .log/ and to the
# console.
log_dir = Path('.log')
log_dir.mkdir(parents=True, exist_ok=True)

# Explicit UTF-8 so non-ASCII text in log records (e.g. exception messages)
# is written correctly regardless of the platform's default locale encoding.
file_handler = logging.FileHandler(
    log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log',
    encoding='utf-8',
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[file_handler, logging.StreamHandler()],
)

logger = logging.getLogger(__name__)
# CORS configuration, read from the optional ``cors`` section of the config.
# ``or {}`` also covers a present-but-empty ``cors:`` key. YAML parses that
# to None, and ``.get`` on None would raise AttributeError at import time.
cors_config = config.get('cors') or {}
cors_origins = cors_config.get('origins', ["*"])
cors_allow_credentials = cors_config.get('allow_credentials', True)

if "*" in cors_origins and cors_allow_credentials:
    # NOTE(security): browsers reject a literal "*" Allow-Origin together
    # with credentials. Starlette's CORSMiddleware then reflects the
    # request's Origin instead, which effectively lets any site make
    # credentialed requests. This stays the default for backward
    # compatibility. Set explicit ``cors.origins`` (or
    # ``cors.allow_credentials: false``) for production deployments.
    logger.warning(
        "CORS allows all origins with credentials; configure cors.origins explicitly"
    )

app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=cors_allow_credentials,
    allow_methods=cors_config.get('methods', ["*"]),
    allow_headers=cors_config.get('headers', ["*"]),
    max_age=cors_config.get('max_age', 600),
)
# Request timing middleware.
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach an ``X-Process-Time`` header to every HTTP response.

    The header value is the elapsed time, in seconds, spent in the
    downstream ``call_next`` chain, rendered with ``str()``.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows
    # the system clock, which can jump (NTP sync, manual changes) and yield
    # negative or skewed durations.
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response
# Global exception handlers.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors with the uniform error payload.

    Responds with HTTP 422. The body has ``status`` set to ``"error"``, a
    fixed ``message``, and the validation error list under ``details``.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.encoders import jsonable_encoder

    # exc.errors() may contain values that are not JSON-serializable, e.g.
    # the raised exception object under "ctx" for errors from custom
    # validators in Pydantic v2. Passing them raw makes JSONResponse fail
    # while rendering. FastAPI's own default handler encodes them with
    # jsonable_encoder for this reason.
    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            "details": jsonable_encoder(exc.errors()),
        },
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"status": "error", "message": detail}``.

    The exception's status code is preserved. Any headers attached to the
    exception (e.g. ``WWW-Authenticate`` on 401 responses from OAuth2
    security dependencies) are forwarded to the client.

    NOTE(review): this is registered for FastAPI's HTTPException. 404s that
    the router raises for unknown paths are Starlette's base HTTPException,
    so they do not reach this handler and keep FastAPI's default
    ``{"detail": ...}`` shape. Register for
    ``starlette.exceptions.HTTPException`` if a uniform format is wanted.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail,
        },
        # Previously dropped, which stripped headers such as WWW-Authenticate
        # that clients rely on. getattr guards against a missing attribute.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler for exceptions not claimed by a more specific handler.

    Logs the exception together with its traceback and responds with
    HTTP 500. The exception text is exposed under ``details`` only when
    ``debug`` is enabled in the config; otherwise ``details`` is null.
    """
    # logger.exception logs at ERROR level with exc_info=True.
    logger.exception(f"Unhandled exception: {str(exc)}")

    show_details = config.get('debug', False)
    payload = {
        "status": "error",
        "message": "服务器内部错误",
        "details": str(exc) if show_details else None,
    }
    return JSONResponse(status_code=500, content=payload)
# Register the feature routers: data processing, model management and system
# monitoring. Each gets its own URL prefix and OpenAPI tag, plus a shared
# documented 404 response.
for _router, _prefix, _tag in (
    (data_router, "/data", "数据处理"),
    (model_router, "/model", "模型管理"),
    (system_router, "/system", "系统监控"),
):
    app.include_router(
        _router,
        prefix=_prefix,
        tags=[_tag],
        responses={404: {"description": "Not found"}},
    )
# Drop the loop variables so they don't linger as module attributes.
del _router, _prefix, _tag
# Health check endpoint.
@app.get("/health")
async def health_check():
    """Report service liveness.

    Returns a JSON object with:
    - ``status``: always ``"healthy"``.
    - ``timestamp``: the current local time in ISO-8601 format.
    - ``version``: the application version.
    - ``environment``: the configured environment name, ``"production"``
      when not configured.
    """
    environment = config.get('environment', 'production')
    now_iso = datetime.now().isoformat()
    return dict(
        status="healthy",
        timestamp=now_iso,
        version=app.version,
        environment=environment,
    )
# Startup hook.
# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
# favour of lifespan handlers; consider migrating.
@app.on_event("startup")
async def startup_event():
    """Initialise the service on startup by ensuring required directories exist."""
    logger.info("Server starting up...")

    # Dataset directories, created together with any missing parents.
    for dataset_dir in ("dataset/dataset_raw", "dataset/dataset_processed"):
        Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    # Log directory; the module-level logging setup also creates it at import.
    Path(".log").mkdir(exist_ok=True)

    logger.info("Server started successfully")
# Shutdown hook.
@app.on_event("shutdown")
async def shutdown_event():
    """Run shutdown-time cleanup; currently this only records a log entry."""
    message = "Server shutting down..."
    logger.info(message)
if __name__ == "__main__":
    # Direct-run entry point.
    #
    # ``debug`` now defaults to False here, matching the default used when
    # deciding whether to expose error details. Previously an absent
    # ``debug`` key turned auto-reload ON (a development feature) in a
    # deployment whose environment defaults to "production".
    #
    # NOTE: uvicorn ignores ``workers`` when ``reload`` is enabled; multiple
    # worker processes are only started with reload off.
    uvicorn.run(
        "main:app",  # import string; assumes this module is importable as "main"
        host=config.get('host', '0.0.0.0'),
        port=config.get('port', 8992),
        reload=config.get('debug', False),
        workers=config.get('workers', 4),
        log_level=config.get('log_level', 'info'),
        access_log=True,
    )