from fastapi import APIRouter, HTTPException, Query, Path
from typing import Optional, List, Dict
from function.model_manager import ModelManager
from pydantic import BaseModel

router = APIRouter()
model_manager = ModelManager()


# Request models
class TrainRequest(BaseModel):
    """Request body for ``POST /train``."""

    model: str
    # train_model() requires a 'train' entry and optionally reads 'val';
    # the values are passed through to ModelManager.train_model unchanged.
    dataset: Dict[str, str]
    parameters: Dict
    metrics: List[str]


class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    run_id: str
    data_path: str
    output_path: str
    batch_size: int = 32
    device: str = "cuda"
    return_proba: bool = True
    metrics: Optional[List[str]] = None


def _raise_if_error(result: Dict, preferred_key: str = "message") -> Dict:
    """Return a ModelManager *result* unchanged, or raise HTTP 500 if it failed.

    ModelManager results are dicts carrying a ``'status'`` key. When the status
    is ``'error'`` the detail text is looked up under *preferred_key* first and
    then under the other known key (``'message'`` / ``'error'``). This way an
    endpoint never crashes with ``KeyError`` just because the manager reported
    its error under the other key.

    Raises:
        HTTPException: status 500 when ``result['status'] == 'error'``.
    """
    if result["status"] == "error":
        fallback_key = "error" if preferred_key == "message" else "message"
        detail = result.get(preferred_key)
        if detail is None:
            detail = result.get(fallback_key)
        if detail is None:
            detail = "Unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result


# NOTE(review): the ModelManager calls below are synchronous but run inside
# ``async def`` handlers, so a long-running call (e.g. training) blocks the
# event loop for all other requests. Consider plain ``def`` handlers or
# ``run_in_threadpool`` once ModelManager is confirmed to be thread-safe.


@router.get("/available")
async def get_available_models():
    """List the available models."""
    return _raise_if_error(model_manager.get_models(), preferred_key="error")


@router.get("/available/{model_name}")
async def get_model_details(model_name: str):
    """Get the details of a single model by name."""
    return _raise_if_error(
        model_manager.get_model_details(model_name), preferred_key="error"
    )


@router.get("/metrics")
async def get_metrics():
    """List the available evaluation metrics."""
    return _raise_if_error(model_manager.get_metrics(), preferred_key="error")


@router.post("/train")
async def train_model(request: TrainRequest):
    """Train a model.

    ``request.dataset['train']`` is required. ``request.dataset['val']`` is
    optional and is passed as ``None`` when absent.

    Raises:
        HTTPException: 422 if the dataset has no ``'train'`` entry;
            500 if ModelManager reports an error.
    """
    train_data = request.dataset.get("train")
    if train_data is None:
        # Previously a bare KeyError, which surfaced as an opaque 500.
        raise HTTPException(
            status_code=422, detail="dataset must contain a 'train' entry"
        )
    result = model_manager.train_model(
        train_data=train_data,
        val_data=request.dataset.get("val"),
        model_config={
            "model_name": request.model,
            "parameters": request.parameters,
            "metrics": request.metrics,
        },
    )
    return _raise_if_error(result)


@router.get("/experiments")
async def get_experiments(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
):
    """List the experiments stored in MLflow (paginated)."""
    result = model_manager.get_experiments(page=page, page_size=page_size)
    return _raise_if_error(result)


@router.get("/experiment/{experiment_name}")
async def get_finished_models(
    experiment_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
):
    """List the finished (trained) models of an experiment (paginated)."""
    result = model_manager.get_finished_models(
        experiment_name=experiment_name,
        page=page,
        page_size=page_size,
    )
    return _raise_if_error(result)


@router.delete("/{run_id}")
async def delete_model(run_id: str = Path(..., description="MLflow运行ID")):
    """Delete the trained model identified by an MLflow run ID."""
    return _raise_if_error(model_manager.delete_model(run_id))


@router.post("/predict")
async def predict(request: PredictRequest):
    """Run prediction with a trained model."""
    result = model_manager.predict(
        run_id=request.run_id,
        # Bug fix: PredictRequest has no ``data`` field; the path lives in
        # ``data_path`` (the old ``request.data`` raised AttributeError).
        data_path=request.data_path,
        output_path=request.output_path,
        batch_size=request.batch_size,
        device=request.device,
        return_proba=request.return_proba,
        metrics=request.metrics,
    )
    return _raise_if_error(result)