80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
import logging
from typing import Optional, List, Dict

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

from function.data_manager import DataManager
|
|
|
|
# Router that every endpoint function in this module registers itself on.
router: APIRouter = APIRouter()

# One DataManager instance, created at import time and shared by all handlers below.
data_manager: DataManager = DataManager()
|
|
|
|
# Data models


class ProcessRequest(BaseModel):
    # Request body for POST /process. Each field is forwarded unchanged to
    # data_manager.process_dataset as the keyword argument of the same name.

    # Dataset input location (presumably a server-side file path — confirm).
    # NOTE(review): this path is client-controlled; if the API is reachable by
    # untrusted callers, restrict it to an allowed base directory.
    input_path: str

    # Destination for processed output (presumably a directory on the server —
    # confirm). Same client-controlled-path caveat as input_path.
    output_dir: str

    # Processing method specs; the dict schema is defined by DataManager and is
    # not validated here.
    process_methods: List[Dict]

    # Feature-engineering method specs; likewise free-form dicts.
    feature_methods: List[Dict]

    # Split parameters keyed by name (presumably split ratios — confirm).
    split_params: Dict[str, float]
|
|
|
|
@router.get("/preprocessing/methods")
async def get_preprocessing_methods():
    """List the available data-preprocessing methods.

    Returns the payload of ``data_manager.get_preprocessing_methods()``
    unchanged, unless that payload reports ``status == 'error'``.

    Raises:
        HTTPException: 500, carrying the payload's ``error`` text, when the
            manager reports a failure.
    """
    payload = data_manager.get_preprocessing_methods()
    if payload['status'] != 'error':
        return payload
    raise HTTPException(status_code=500, detail=payload['error'])
|
|
|
|
@router.get("/preprocessing/method/{method_name}")
async def get_preprocessing_method_details(method_name: str):
    """Describe a single preprocessing method.

    Args:
        method_name: Method identifier taken from the URL path.

    Returns:
        The payload of ``data_manager.get_preprocessing_method_details``
        unchanged, unless it reports ``status == 'error'``.

    Raises:
        HTTPException: 500 with the payload's ``error`` text on failure.
    """
    details = data_manager.get_preprocessing_method_details(method_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
|
|
|
|
@router.get("/feature/methods")
async def get_feature_methods():
    """List the available feature-engineering methods.

    Returns the payload of ``data_manager.get_feature_engineering_methods()``
    unchanged, unless that payload reports ``status == 'error'``.

    Raises:
        HTTPException: 500, carrying the payload's ``error`` text, when the
            manager reports a failure.
    """
    catalogue = data_manager.get_feature_engineering_methods()
    if catalogue['status'] != 'error':
        return catalogue
    raise HTTPException(status_code=500, detail=catalogue['error'])
|
|
|
|
@router.get("/feature/method/{method_name}")
async def get_feature_method_details(method_name: str):
    """Describe a single feature-engineering method.

    Args:
        method_name: Method identifier taken from the URL path.

    Returns:
        The payload of ``data_manager.get_feature_engineering_method_details``
        unchanged, unless it reports ``status == 'error'``.

    Raises:
        HTTPException: 500 with the payload's ``error`` text on failure.
    """
    details = data_manager.get_feature_engineering_method_details(method_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
|
|
|
|
@router.post("/process")
async def process_dataset(request: ProcessRequest):
    """Process a dataset through ``data_manager.process_dataset``.

    Args:
        request: Input/output locations, processing and feature method specs,
            and split parameters; all forwarded verbatim to the manager.

    Returns:
        The manager's result dict when its ``status`` is not ``'error'``.

    Raises:
        HTTPException: 500 when the manager reports ``status == 'error'``. The
            detail is the result's ``message`` (or ``error``) text.
    """
    logger = logging.getLogger(__name__)
    # Use logging rather than print() so these diagnostics respect the app's
    # log level and handlers instead of always going to stdout.
    logger.info("处理数据集: %s", request.process_methods)

    # NOTE(review): this is a synchronous call inside an ``async def`` handler,
    # so it blocks the event loop until it returns. If processing is slow,
    # consider a plain ``def`` endpoint or ``fastapi.concurrency.run_in_threadpool``.
    # First confirm that DataManager is safe to call from worker threads.
    result = data_manager.process_dataset(
        input_path=request.input_path,
        output_dir=request.output_dir,
        process_methods=request.process_methods,
        feature_methods=request.feature_methods,
        split_params=request.split_params,
    )
    # The full result can be large, so keep it at debug level.
    logger.debug("result %s", result)

    if result['status'] == 'error':
        # The other endpoints in this module read the failure text from
        # 'error'; this one historically reads 'message'. Accept either so a
        # mismatched key cannot surface as a KeyError with a generic 500.
        detail = (
            result.get('message')
            or result.get('error')
            or 'Dataset processing failed'
        )
        raise HTTPException(status_code=500, detail=detail)

    return result
|
|
|
|
@router.get("/datasets")
async def get_datasets():
    """Return the datasets available to clients.

    Returns:
        dict: ``{"status": "success", "datasets": ...}``, where ``datasets`` is
        whatever ``data_manager.get_dataset()`` returned.

    Raises:
        HTTPException: 500 when ``data_manager.get_dataset()`` raises. The
            detail includes the original exception text.
    """
    try:
        # Only this call can fail; keep the try body minimal.
        datasets = data_manager.get_dataset()
    except Exception as e:
        # Once converted to an HTTPException, the original traceback is no
        # longer recorded anywhere, so log it here before re-raising.
        logging.getLogger(__name__).exception("Failed to list datasets")
        raise HTTPException(
            status_code=500,
            detail=f"获取数据集列表失败: {str(e)}"
        ) from e

    return {
        "status": "success",
        "datasets": datasets
    }
|