161 lines
6.6 KiB
Python
161 lines
6.6 KiB
Python
from fastapi import APIRouter, Query, Path, Body, HTTPException, BackgroundTasks
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from pydantic import BaseModel, Field
|
|
from function.optimize_manager import OptimizeManager
|
|
import logging
|
|
from datetime import datetime
|
|
|
|
# Router on which every endpoint in this module is registered.
router = APIRouter()

# Module-level OptimizeManager instance; the handlers in this module delegate to it.
optimize_manager = OptimizeManager()

# Module logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Request body model used by the POST "/model" endpoint.
#
# NOTE: this model is documented with comments rather than a class docstring
# on purpose: pydantic copies a model's docstring into the generated JSON
# schema description, so adding one would change the published API schema.
class OptimizationRequest(BaseModel):
    # Required fields use an Ellipsis default; optional fields default to None.
    run_id: str = Field(default=..., description="已训练模型的运行ID")
    method: str = Field(default=..., description="优化方法名称,如 'GridSearchCV'")
    parameters: Dict = Field(default=..., description="优化方法参数")
    data_path: Optional[str] = Field(default=None, description="数据集路径")
    output_dir: Optional[str] = Field(default=None, description="输出目录路径")
    experiment_name: str = Field(default=..., description="实验名称")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "run_id": "bd3697dc238c4d1587e0f4f319d04448",
                "method": "GridSearchCV",
                "parameters": {
                    "max_depth": [3, 5, 7],
                    "n_estimators": [50, 100, 200],
                    "min_samples_split": [2, 5, 10],
                },
                "data_path": "dataset/dataset_processed/test_optimize/train.csv",
                "output_dir": None,
                "experiment_name": "测试模型优化方法",
            }
        }
|
|
|
|
# API接口实现
|
|
@router.get("/methods", summary="获取优化方法列表")
async def get_optimize_methods():
    """获取所有可用的模型优化方法列表"""
    # Lists the optimization methods reported by the module-level optimize_manager.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Returns: the value of optimize_manager.get_optimize_methods(), unmodified.
    # Raises: HTTPException(500) when the manager call raises.
    try:
        return optimize_manager.get_optimize_methods()
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error getting optimization methods: %s", e)
        raise HTTPException(status_code=500, detail=f"获取优化方法失败: {str(e)}") from e
|
|
|
|
@router.get("/method/{method_name}", summary="获取优化方法详情")
async def get_optimize_method_details(
    method_name: str = Path(..., description="优化方法名称")
):
    """获取特定优化方法的详细信息"""
    # Looks up one optimization method by name through optimize_manager.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args: method_name - path parameter naming the method to describe.
    # Returns: the value of optimize_manager.get_optimize_method_details(method_name).
    # Raises: HTTPException(500) when the manager call raises.
    try:
        return optimize_manager.get_optimize_method_details(method_name)
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error getting method details for %s: %s", method_name, e)
        raise HTTPException(status_code=500, detail=f"获取方法详情失败: {str(e)}") from e
|
|
|
|
@router.post("/model", summary="执行模型优化")
async def optimize_model(
    request: OptimizationRequest = Body(..., description="优化请求参数")
):
    """启动模型优化任务"""
    # Runs an optimization through optimize_manager using the fields of the request body.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args: request - validated OptimizationRequest body.
    # Returns: dict with "status", "message" and the manager result's "optimization" entry.
    # Raises: HTTPException(500) when the manager call (or reading its result) raises.
    #
    # NOTE(review): run_optimization is called synchronously inside an async
    # handler; if it is long-running it will block the event loop for its whole
    # duration - confirm whether it should be offloaded to a worker thread.
    try:
        result = optimize_manager.run_optimization(
            run_id=request.run_id,
            method=request.method,
            parameters=request.parameters,
            data_path=request.data_path,
            output_dir=request.output_dir,
            experiment_name=request.experiment_name,
        )
        return {
            "status": "success",
            "message": "优化任务已执行",
            "optimization": result.get("optimization"),
        }
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error starting optimization task: %s", e)
        raise HTTPException(status_code=500, detail=f"启动优化任务失败: {str(e)}") from e
|
|
|
|
@router.get("/tasks", summary="获取优化任务列表")
async def get_optimization_tasks(
    experiment_name: Optional[str] = None,
    page: int = Query(1, description="页码", ge=1),
    page_size: int = Query(10, description="每页任务数量", ge=1, le=100),
    status: Optional[str] = Query(None, description="任务状态过滤")
):
    """获取优化任务列表"""
    # Returns one page of optimization tasks from optimize_manager.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args:
    #   experiment_name - optional query parameter forwarded to the manager.
    #   page            - page number, validated as >= 1.
    #   page_size       - tasks per page, validated as 1..100.
    #   status          - optional status filter forwarded to the manager.
    # Returns: the value of optimize_manager.get_optimization_tasks(...), unmodified.
    # Raises: HTTPException(500) when the manager call raises.
    try:
        # Arguments stay positional: the manager's parameter names are not visible here.
        return optimize_manager.get_optimization_tasks(
            experiment_name, page, page_size, status
        )
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error getting optimization tasks: %s", e)
        raise HTTPException(status_code=500, detail=f"获取优化任务列表失败: {str(e)}") from e
|
|
|
|
@router.get("/task/{task_id}", summary="获取优化任务详情")
async def get_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """获取特定优化任务的详细信息"""
    # Fetches a single optimization task from optimize_manager.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args: task_id - path parameter identifying the task.
    # Returns: the manager's result when its "status" is not "error".
    # Raises:
    #   HTTPException(404) when the manager reports status "error".
    #   HTTPException(500) when the manager call raises.
    try:
        result = optimize_manager.get_optimization_task(task_id)
        # .get() so a payload without a "status"/"message" key is not turned
        # into a KeyError (and therefore a 500).
        if result.get("status") == "error":
            raise HTTPException(status_code=404, detail=result.get("message"))
        return result
    except HTTPException:
        # Let the deliberate 404 through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error getting optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"获取优化任务详情失败: {str(e)}") from e
|
|
|
|
@router.post("/task/{task_id}/cancel", summary="取消优化任务")
async def cancel_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """取消正在运行的优化任务"""
    # Asks optimize_manager to cancel the given optimization task.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args: task_id - path parameter identifying the task.
    # Returns: the manager's result when its "status" is not "error".
    # Raises:
    #   HTTPException(400) when the manager reports status "error".
    #   HTTPException(500) when the manager call raises.
    try:
        result = optimize_manager.cancel_optimization_task(task_id)
        # .get() so a payload without a "status"/"message" key is not turned
        # into a KeyError (and therefore a 500).
        if result.get("status") == "error":
            raise HTTPException(status_code=400, detail=result.get("message"))
        return result
    except HTTPException:
        # Let the deliberate 400 through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error cancelling optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"取消优化任务失败: {str(e)}") from e
|
|
|
|
@router.delete("/task/{task_id}", summary="删除优化任务")
async def delete_optimization_task(
    task_id: str = Path(..., description="任务ID")
):
    """删除优化任务及其相关资源"""
    # Asks optimize_manager to delete the given optimization task.
    # The docstring above is kept verbatim because FastAPI publishes a handler's
    # docstring as the operation description in the OpenAPI schema.
    #
    # Args: task_id - path parameter identifying the task.
    # Returns: the manager's result when its "status" is not "error".
    # Raises:
    #   HTTPException(404) when the error message contains "任务不存在".
    #   HTTPException(400) for any other error reported by the manager.
    #   HTTPException(500) when the manager call raises.
    try:
        result = optimize_manager.delete_optimization_task(task_id)
        if result.get("status") != "error":
            return result

        message = result.get("message")
        # Normalise missing/None values so the substring checks below cannot
        # raise and turn a reported error into a 500.
        message_text = message or ""
        reason = (result.get("details") or {}).get("reason") or ""

        if "任务不存在" in message_text:
            raise HTTPException(status_code=404, detail=message)
        if "任务正在运行中" in reason:
            raise HTTPException(status_code=400, detail="无法删除正在运行的任务")
        raise HTTPException(status_code=400, detail=message)
    except HTTPException:
        # Let the deliberate 4xx responses through instead of rewrapping them as 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error deleting optimization task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail=f"删除优化任务失败: {str(e)}") from e