77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
from fastapi import APIRouter, HTTPException, Query
|
|
from typing import Optional, List, Dict
|
|
from function.data_manager import DataManager
|
|
from pydantic import BaseModel
|
|
|
|
# Router on which every endpoint in this module is registered.
router = APIRouter()
|
|
# Module-level DataManager instance shared by all endpoint handlers below.
data_manager = DataManager()
|
|
|
|
# Data models
|
|
class ProcessRequest(BaseModel):
    """Request body describing one dataset-processing job.

    The element/value schemas of the ``Dict`` fields are not defined in this
    module; they are presumably dictated by ``DataManager`` — TODO confirm.
    """

    # Location of the dataset to read.
    input_file: str
    # Location where results are written.
    output_path: str
    # Preprocessing steps, one dict per step.
    preprocessing: List[Dict]
    # Feature-engineering steps, one dict per step.
    feature_engineering: List[Dict]
    # Split parameters keyed by name, with float values.
    split_ratio: Dict[str, float]
|
|
|
|
@router.get("/preprocessing/methods")
async def get_preprocessing_methods():
    """Return the list of available data preprocessing methods.

    Returns:
        The result of ``data_manager.get_preprocessing_methods()``, unchanged.

    Raises:
        HTTPException: 500 when the result's ``status`` is ``'error'``; the
            detail is taken from the result's ``error`` entry.
    """
    payload = data_manager.get_preprocessing_methods()
    if payload['status'] != 'error':
        return payload
    raise HTTPException(status_code=500, detail=payload['error'])
|
|
|
|
@router.get("/preprocessing/method/{method_name}")
async def get_preprocessing_method_details(method_name: str):
    """Return the details of a single preprocessing method.

    Args:
        method_name: Method name taken from the URL path and forwarded to
            ``data_manager.get_preprocessing_method_details``.

    Returns:
        The data manager's result, unchanged.

    Raises:
        HTTPException: 500 when the result's ``status`` is ``'error'``; the
            detail is taken from the result's ``error`` entry.
    """
    payload = data_manager.get_preprocessing_method_details(method_name)
    if payload['status'] != 'error':
        return payload
    raise HTTPException(status_code=500, detail=payload['error'])
|
|
|
|
@router.get("/feature/methods")
async def get_feature_methods():
    """Return the list of available feature-engineering methods.

    Returns:
        The result of ``data_manager.get_feature_engineering_methods()``,
        unchanged.

    Raises:
        HTTPException: 500 when the result's ``status`` is ``'error'``; the
            detail is taken from the result's ``error`` entry.
    """
    payload = data_manager.get_feature_engineering_methods()
    if payload['status'] != 'error':
        return payload
    raise HTTPException(status_code=500, detail=payload['error'])
|
|
|
|
@router.get("/feature/method/{method_name}")
async def get_feature_method_details(method_name: str):
    """Return the details of a single feature-engineering method.

    Args:
        method_name: Method name taken from the URL path and forwarded to
            ``data_manager.get_feature_engineering_method_details``.

    Returns:
        The data manager's result, unchanged.

    Raises:
        HTTPException: 500 when the result's ``status`` is ``'error'``; the
            detail is taken from the result's ``error`` entry.
    """
    payload = data_manager.get_feature_engineering_method_details(method_name)
    if payload['status'] != 'error':
        return payload
    raise HTTPException(status_code=500, detail=payload['error'])
|
|
|
|
@router.post("/process")
async def process_dataset(request: ProcessRequest):
    """Process a dataset as described by the request body.

    Args:
        request: Job description; its fields are forwarded to
            ``data_manager.process_dataset`` under the keyword names that
            method takes.

    Returns:
        The data manager's result, unchanged.

    Raises:
        HTTPException: 500 when the result's ``status`` is ``'error'``; the
            detail is taken from the result's ``message`` entry.
    """
    # Translate request field names into the data manager's keyword names.
    job_kwargs = {
        "input_path": request.input_file,
        "output_dir": request.output_path,
        "process_methods": request.preprocessing,
        "feature_methods": request.feature_engineering,
        "split_params": request.split_ratio,
    }
    # NOTE(review): this is a synchronous call inside an async handler; if it
    # is long-running it will occupy the event loop — confirm this is intended.
    outcome = data_manager.process_dataset(**job_kwargs)
    if outcome['status'] != 'error':
        return outcome
    # NOTE(review): the other endpoints in this module read the 'error' key,
    # this one reads 'message' — confirm against DataManager.process_dataset.
    raise HTTPException(status_code=500, detail=outcome['message'])
|
|
|
|
@router.get("/datasets")
async def get_datasets():
    """Return the list of available datasets.

    Returns:
        A dict with ``"status": "success"`` and, under ``"datasets"``, the
        value returned by ``data_manager.get_dataset()``.

    Raises:
        HTTPException: 500 when ``data_manager.get_dataset()`` raises; the
            detail embeds the original error text and the original exception
            is chained as the cause.
    """
    # Keep the try body limited to the one call that can fail.
    try:
        datasets = data_manager.get_dataset()
    except Exception as e:
        # Chain explicitly so the underlying error stays attached to the
        # HTTPException instead of being reported as an implicit context.
        raise HTTPException(
            status_code=500,
            detail=f"获取数据集列表失败: {str(e)}",
        ) from e
    return {
        "status": "success",
        "datasets": datasets,
    }
|