# MLPlatform/api/model_api.py
from fastapi import APIRouter, HTTPException, Query, Path
from typing import Optional, List, Dict
from function.model_manager import ModelManager
from pydantic import BaseModel
# Router for the model API; mounted by the application (prefix configured by the caller).
router = APIRouter()
# Single shared ModelManager instance used by every endpoint below.
model_manager = ModelManager()
# ---- Request schemas ----
class TrainRequest(BaseModel):
    """Request body for POST /train.

    The path fields point at dataset files readable by ModelManager;
    ``parameters`` is forwarded verbatim as the algorithm's hyperparameters.
    """

    train_path: str       # path to the training dataset
    val_path: str         # path to the validation dataset
    algorithm: str        # algorithm identifier understood by ModelManager
    task_type: str        # task kind -- presumably classification/regression; TODO confirm accepted values
    parameters: Dict      # algorithm hyperparameters, passed through unchanged
    experiment_name: str  # MLflow experiment the run is logged under
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    run_id: str                          # MLflow run id of the trained model to load
    data_path: str                       # path to the input dataset
    output_path: str                     # where the predictions are written
    batch_size: int = 32                 # inference batch size
    device: str = "cuda"                 # NOTE(review): defaults to GPU -- confirm a CPU fallback exists
    return_proba: bool = True            # also return class probabilities, not just labels
    metrics: Optional[List[str]] = None  # optional metric names to evaluate alongside prediction
@router.get("/available")
async def get_available_models():
    """Return the list of models available for training."""
    models = model_manager.get_models()
    if models['status'] != 'error':
        return models
    raise HTTPException(status_code=500, detail=models['error'])
@router.get("/available/{model_name}")
async def get_model_details(model_name: str):
    """Return detailed information about a single model."""
    details = model_manager.get_model_details(model_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
@router.get("/metrics")
async def get_metrics():
    """Return the list of supported evaluation metrics."""
    metrics = model_manager.get_metrics()
    if metrics['status'] != 'error':
        return metrics
    raise HTTPException(status_code=500, detail=metrics['error'])
@router.post("/train")
async def train_model(request: TrainRequest):
    """Launch a training run described by *request*."""
    # Assemble the model configuration expected by ModelManager.
    config = {
        'algorithm': request.algorithm,
        'task_type': request.task_type,
        'params': request.parameters,
    }
    outcome = model_manager.train_model(
        train_path=request.train_path,
        val_path=request.val_path,
        model_config=config,
        experiment_name=request.experiment_name,
    )
    if outcome['status'] != 'error':
        return outcome
    raise HTTPException(status_code=500, detail=outcome['message'])
@router.get("/experiments")
async def get_experiments(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """Page through the experiments stored in MLflow."""
    listing = model_manager.get_experiments(page=page, page_size=page_size)
    if listing['status'] != 'error':
        return listing
    raise HTTPException(status_code=500, detail=listing['message'])
@router.get("/experiment/{experiment_name}")
async def get_finished_models(
    experiment_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """Page through the trained models recorded under one experiment."""
    listing = model_manager.get_finished_models(
        experiment_name=experiment_name,
        page=page,
        page_size=page_size,
    )
    if listing['status'] != 'error':
        return listing
    raise HTTPException(status_code=500, detail=listing['message'])
@router.delete("/{run_id}")
async def delete_model(run_id: str = Path(..., description="MLflow运行ID")):
    """Delete a trained model identified by its MLflow run id."""
    outcome = model_manager.delete_model(run_id)
    if outcome['status'] != 'error':
        return outcome
    raise HTTPException(status_code=500, detail=outcome['message'])
@router.post("/predict")
async def predict(request: PredictRequest):
    """Run batch prediction with a previously trained model."""
    outcome = model_manager.predict(
        run_id=request.run_id,
        data_path=request.data_path,
        output_path=request.output_path,
        batch_size=request.batch_size,
        device=request.device,
        return_proba=request.return_proba,
        metrics=request.metrics,
    )
    if outcome['status'] != 'error':
        return outcome
    raise HTTPException(status_code=500, detail=outcome['message'])