# MLPlatform/api/optimize_api.py

import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Path, Query
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict, Field

from function.optimize_manager import OptimizeManager
# Module-level logger shared by every endpoint in this router.
logger: logging.Logger = logging.getLogger(__name__)

# Router that the optimization endpoints below are registered on.
router: APIRouter = APIRouter()

# One OptimizeManager instance, created at import time and used by all endpoints.
optimize_manager: OptimizeManager = OptimizeManager()
# Request body models.
# Body of POST /model: which trained run to optimize, with which method and
# parameters, and under which experiment name. The example in ``model_config`` is
# published in the OpenAPI schema. (No class docstring on purpose: Pydantic would
# emit it as the schema description.)
class OptimizationRequest(BaseModel):
    run_id: str = Field(..., description="已训练模型的运行ID")
    method: str = Field(..., description="优化方法名称,如 'GridSearchCV'")
    # Method-specific options; for GridSearchCV the example maps hyper-parameter
    # names to lists of candidate values.
    parameters: Dict = Field(..., description="优化方法参数")
    data_path: Optional[str] = Field(None, description="数据集路径")
    output_dir: Optional[str] = Field(None, description="输出目录路径")
    experiment_name: str = Field(..., description="实验名称")

    # Pydantic v2 configuration. The class-based ``Config`` previously used here is
    # deprecated in v2 (it emits a deprecation warning); ConfigDict is the replacement.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "run_id": "bd3697dc238c4d1587e0f4f319d04448",
                "method": "GridSearchCV",
                "parameters": {
                    "max_depth": [3, 5, 7],
                    "n_estimators": [50, 100, 200],
                    "min_samples_split": [2, 5, 10],
                },
                "data_path": "dataset/dataset_processed/test_optimize/train.csv",
                "output_dir": None,
                "experiment_name": "测试模型优化方法",
            }
        }
    )
# API endpoint implementations.
# Endpoint docstrings are kept verbatim: FastAPI publishes them as the OpenAPI
# description of each route.

# GET /methods: returns OptimizeManager.get_optimize_methods() unchanged.
# Any exception is logged with its traceback and mapped to HTTP 500.
@router.get("/methods", summary="获取优化方法列表")
async def get_optimize_methods():
    """获取所有可用的模型优化方法列表"""
    try:
        return optimize_manager.get_optimize_methods()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) discarded.
        logger.exception("Error getting optimization methods: %s", e)
        raise HTTPException(status_code=500, detail=f"获取优化方法失败: {str(e)}") from e
# GET /method/{method_name}: returns OptimizeManager.get_optimize_method_details()
# for the given method name, unchanged.
# Any exception (including however the manager signals an unknown method) is logged
# with its traceback and mapped to HTTP 500.
@router.get("/method/{method_name}", summary="获取优化方法详情")
async def get_optimize_method_details(
    method_name: str = Path(..., description="优化方法名称")
):
    """获取特定优化方法的详细信息"""
    try:
        return optimize_manager.get_optimize_method_details(method_name)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) discarded.
        logger.exception("Error getting method details for %s: %s", method_name, e)
        raise HTTPException(status_code=500, detail=f"获取方法详情失败: {str(e)}") from e
# POST /model: run an optimization job for an already-trained run and report it.
# The response is sent only after OptimizeManager.run_optimization returns, so the
# job is synchronous from the client's point of view. The call is dispatched to the
# worker threadpool: invoking a blocking function directly inside an ``async def``
# handler stalls the event loop, and therefore every other request, for the whole job.
# NOTE(review): concurrent requests can now run optimizations in parallel threads;
# confirm run_optimization is thread-safe.
# Returns {"status": "success", "message": ..., "optimization": result["optimization"]}.
# Raises HTTPException 400 if the manager reports status == "error", 500 if it raises.
@router.post("/model", summary="执行模型优化")
async def optimize_model(
    request: OptimizationRequest = Body(..., description="优化请求参数")
):
    """启动模型优化任务"""
    try:
        result = await run_in_threadpool(
            optimize_manager.run_optimization,
            run_id=request.run_id,
            method=request.method,
            parameters=request.parameters,
            data_path=request.data_path,
            output_dir=request.output_dir,
            experiment_name=request.experiment_name,
        )
        # The other manager calls in this module report failure via status == "error";
        # presumably run_optimization follows the same convention (verify). Without
        # this check such a result would be reported to the client as "success".
        if result.get("status") == "error":
            raise HTTPException(
                status_code=400,
                detail=result.get("message") or "启动优化任务失败",
            )
        return {
            "status": "success",
            "message": "优化任务已执行",
            "optimization": result.get("optimization"),
        }
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error starting optimization task: %s", e)
        raise HTTPException(status_code=500, detail=f"启动优化任务失败: {str(e)}") from e
# GET /tasks: paginated list of optimization tasks from
# OptimizeManager.get_optimization_tasks, returned unchanged.
# Arguments are forwarded positionally as (experiment_name, page, page_size, status).
# status is documented as a filter; experiment_name presumably is one too.
# Any exception is logged with its traceback and mapped to HTTP 500.
@router.get("/tasks", summary="获取优化任务列表")
async def get_optimization_tasks(
    experiment_name: Optional[str] = Query(None, description="实验名称过滤"),
    page: int = Query(1, description="页码", ge=1),
    page_size: int = Query(10, description="每页任务数量", ge=1, le=100),
    status: Optional[str] = Query(None, description="任务状态过滤"),
):
    """获取优化任务列表"""
    try:
        return optimize_manager.get_optimization_tasks(
            experiment_name, page, page_size, status
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) discarded.
        logger.exception("Error getting optimization tasks: %s", e)
        raise HTTPException(status_code=500, detail=f"获取优化任务列表失败: {str(e)}") from e
# GET /task/{task_id}: returns the manager's task record unchanged.
# A manager result with status == "error" becomes HTTP 404 carrying the manager's
# message. Any other exception is logged with its traceback and mapped to HTTP 500.
@router.get("/task/{task_id}", summary="获取优化任务详情")
async def get_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """获取特定优化任务的详细信息"""
    try:
        result = optimize_manager.get_optimization_task(task_id)
        if result["status"] == "error":
            raise HTTPException(status_code=404, detail=result["message"])
        return result
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error getting optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"获取优化任务详情失败: {str(e)}") from e
# POST /task/{task_id}/cancel: asks the manager to cancel the task and returns its
# result unchanged.
# A manager result with status == "error" becomes HTTP 400 carrying the manager's
# message. Any other exception is logged with its traceback and mapped to HTTP 500.
@router.post("/task/{task_id}/cancel", summary="取消优化任务")
async def cancel_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """取消正在运行的优化任务"""
    try:
        result = optimize_manager.cancel_optimization_task(task_id)
        if result["status"] == "error":
            raise HTTPException(status_code=400, detail=result["message"])
        return result
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error cancelling optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"取消优化任务失败: {str(e)}") from e
# DELETE /task/{task_id}: asks the manager to delete the task and returns its result
# unchanged. A manager result with status == "error" is mapped as follows:
#   - message contains "任务不存在" (task does not exist)  -> 404 with that message
#   - details.reason contains "任务正在运行中" (running)  -> 400 with a fixed message
#   - anything else                                      -> 400 with the message
# Any other exception is logged with its traceback and mapped to HTTP 500.
@router.delete("/task/{task_id}", summary="删除优化任务")
async def delete_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """删除优化任务及其相关资源"""
    try:
        result = optimize_manager.delete_optimization_task(task_id)
        if result["status"] == "error":
            message = result["message"]
            if "任务不存在" in message:
                raise HTTPException(status_code=404, detail=message)
            # `details` may be missing or explicitly None (and `reason` likewise);
            # `.get("details", {})` alone crashed with AttributeError on None.
            reason = (result.get("details") or {}).get("reason") or ""
            if "任务正在运行中" in reason:
                raise HTTPException(status_code=400, detail="无法删除正在运行的任务")
            raise HTTPException(status_code=400, detail=message)
        return result
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error deleting optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"删除优化任务失败: {str(e)}") from e