115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
from fastapi import APIRouter, HTTPException, Query, Path
|
|
from typing import Optional, List, Dict
|
|
from function.model_manager import ModelManager
|
|
from pydantic import BaseModel
|
|
|
|
# Router that collects the model-management endpoints defined in this module.
router: APIRouter = APIRouter()

# Single ModelManager instance shared by every endpoint below; it is built
# once, as a side effect of importing this module.
model_manager: ModelManager = ModelManager()
|
|
|
|
# Request body models
|
|
class TrainRequest(BaseModel):
    """Request body for the ``POST /train`` endpoint."""

    # Name of the model to train; forwarded as ``model_config['model_name']``.
    model: str

    # Dataset locations keyed by split. The /train endpoint reads the 'train'
    # and 'val' entries; values are presumably file paths or dataset
    # identifiers — TODO confirm against ModelManager.train_model.
    dataset: Dict[str, str]

    # Model hyper-parameters, passed through unchanged as
    # ``model_config['parameters']``; their keys are not validated here.
    parameters: Dict

    # Metric names forwarded as ``model_config['metrics']`` (presumably names
    # from GET /metrics — confirm).
    metrics: List[str]
|
|
|
|
class PredictRequest(BaseModel):
    """Request body for the ``POST /predict`` endpoint.

    Every field is forwarded unchanged to ``model_manager.predict``.
    """

    # Identifier of the trained run to use; presumably the MLflow run ID, the
    # same identifier DELETE /{run_id} accepts — confirm.
    run_id: str

    # Location of the input data to predict on (presumably a file path).
    data_path: str

    # Location the predictions are written to (presumably a file path).
    output_path: str

    # Batch size forwarded to predict; defaults to 32.
    batch_size: int = 32

    # Device string forwarded to predict; defaults to "cuda". Presumably a
    # torch-style device name (e.g. "cpu") — confirm.
    device: str = "cuda"

    # Forwarded as ``return_proba``; defaults to True. Presumably requests
    # class probabilities in the output — confirm.
    return_proba: bool = True

    # Optional metric names; None (the default) is passed through as-is.
    metrics: Optional[List[str]] = None
|
|
|
|
@router.get("/available")
async def get_available_models():
    """Return the list of available models.

    Proxies ``model_manager.get_models()`` and returns its result dict as-is.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.get_models()
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('error' vs
        # 'message'). Prefer 'error' but fall back to 'message' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('error', result.get('message', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/available/{model_name}")
async def get_model_details(model_name: str):
    """Return details for a single model.

    Args:
        model_name: Model name taken from the URL path.

    Returns:
        The result dict from ``model_manager.get_model_details`` on success.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.get_model_details(model_name)
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('error' vs
        # 'message'). Prefer 'error' but fall back to 'message' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('error', result.get('message', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/metrics")
async def get_metrics():
    """Return the list of available evaluation metrics.

    Proxies ``model_manager.get_metrics()`` and returns its result dict as-is.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.get_metrics()
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('error' vs
        # 'message'). Prefer 'error' but fall back to 'message' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('error', result.get('message', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.post("/train")
async def train_model(request: TrainRequest):
    """Train a model on the given training and validation datasets.

    Args:
        request: Model name, dataset locations (must include 'train' and
            'val' entries), model parameters and metric names.

    Returns:
        The result dict from ``model_manager.train_model`` on success.

    Raises:
        HTTPException: 422 if ``request.dataset`` lacks a 'train' or 'val'
            entry; 500 when the manager reports ``status == 'error'``.
    """
    # Validate up front: indexing a missing key would raise KeyError and be
    # reported as an opaque 500 even though the client sent a bad request.
    missing = [split for split in ('train', 'val') if split not in request.dataset]
    if missing:
        raise HTTPException(
            status_code=422,
            detail=f"dataset is missing required key(s): {', '.join(missing)}",
        )

    # NOTE(review): this synchronous call runs on the event loop inside an
    # async handler; if training is long-running it stalls every other
    # request. Consider offloading it once ModelManager is confirmed
    # thread-safe.
    result = model_manager.train_model(
        train_data=request.dataset['train'],
        val_data=request.dataset['val'],
        model_config={
            'model_name': request.model,
            'parameters': request.parameters,
            'metrics': request.metrics,
        },
    )
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('message'
        # vs 'error'). Prefer 'message' but fall back to 'error' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('message', result.get('error', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/experiments")
async def get_experiments(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """List the experiments stored in MLflow, one page at a time.

    Args:
        page: 1-based page number (query parameter, >= 1, default 1).
        page_size: Items per page (query parameter, 1..100, default 10).

    Returns:
        The result dict from ``model_manager.get_experiments`` on success.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.get_experiments(page=page, page_size=page_size)
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('message'
        # vs 'error'). Prefer 'message' but fall back to 'error' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('message', result.get('error', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/experiment/{experiment_name}")
async def get_finished_models(
    experiment_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """List the trained (finished) models of one experiment, one page at a time.

    Args:
        experiment_name: Experiment name taken from the URL path.
        page: 1-based page number (query parameter, >= 1, default 1).
        page_size: Items per page (query parameter, 1..100, default 10).

    Returns:
        The result dict from ``model_manager.get_finished_models`` on success.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.get_finished_models(
        experiment_name=experiment_name,
        page=page,
        page_size=page_size,
    )
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('message'
        # vs 'error'). Prefer 'message' but fall back to 'error' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('message', result.get('error', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.delete("/{run_id}")
async def delete_model(run_id: str = Path(..., description="MLflow运行ID")):
    """Delete the specified trained model.

    Args:
        run_id: MLflow run ID of the model to delete (URL path parameter).

    Returns:
        The result dict from ``model_manager.delete_model`` on success.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    result = model_manager.delete_model(run_id)
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('message'
        # vs 'error'). Prefer 'message' but fall back to 'error' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('message', result.get('error', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.post("/predict")
async def predict(request: PredictRequest):
    """Run prediction with a previously trained model.

    Args:
        request: Run ID, input/output locations and inference options; every
            field is forwarded unchanged to ``model_manager.predict``.

    Returns:
        The result dict from ``model_manager.predict`` on success.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``.
    """
    # NOTE(review): this synchronous call runs on the event loop inside an
    # async handler; if inference is long-running it stalls every other
    # request. Consider offloading it once ModelManager is confirmed
    # thread-safe.
    result = model_manager.predict(
        run_id=request.run_id,
        data_path=request.data_path,
        output_path=request.output_path,
        batch_size=request.batch_size,
        device=request.device,
        return_proba=request.return_proba,
        metrics=request.metrics,
    )
    if result['status'] == 'error':
        # The error-text key is not consistent across this router ('message'
        # vs 'error'). Prefer 'message' but fall back to 'error' so a mismatch
        # still reports the real reason instead of crashing with KeyError.
        detail = result.get('message', result.get('error', 'Unknown error'))
        raise HTTPException(status_code=500, detail=detail)
    return result
|