130 lines
4.0 KiB
Python
130 lines
4.0 KiB
Python
import logging
from typing import Optional, List, Dict

from fastapi import APIRouter, HTTPException, Query, UploadFile, File
from pydantic import BaseModel

from function.data_manager import DataManager
|
|
|
|
# Router that every endpoint in this module registers on via @router.get/post.
router: APIRouter = APIRouter()

# Single DataManager instance created at import time and shared by all
# endpoints in this module.
data_manager: DataManager = DataManager()
|
|
|
|
# Data models

# Request body for POST /process; every field is forwarded unchanged to
# DataManager.process_dataset.
class ProcessRequest(BaseModel):
    # Path of the input dataset (client-supplied).
    input_path: str
    # Output directory (presumably where DataManager writes results — confirm).
    output_dir: str
    # Preprocessing steps; the dict schema is defined by DataManager.
    process_methods: List[Dict]
    # Feature-engineering steps; the dict schema is defined by DataManager.
    feature_methods: List[Dict]
    # Split settings mapping a name to a float (likely ratios — confirm).
    split_params: Dict[str, float]
|
|
|
|
# Request body for POST /csv; every field is forwarded unchanged to
# DataManager.read_csv.
class CSVRequest(BaseModel):
    # Path of the CSV file to read (client-supplied).
    data_path: str
    # Number of leading rows to show (presumably DataFrame.head — confirm).
    head: int = 5
    # Number of trailing rows to show (presumably DataFrame.tail — confirm).
    tail: int = 5
    # Whether to include column/dtype info in the result.
    info: bool = True
    # Whether to include summary statistics in the result.
    describe: bool = True
|
|
|
|
@router.get("/preprocessing/methods")
async def get_preprocessing_methods():
    """Return the list of available data-preprocessing methods.

    The DataManager result dict is returned unchanged on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    result = data_manager.get_preprocessing_methods()
    if result['status'] == 'error':
        # process_dataset/read_csv in this module read the failure text from
        # 'message'; accept either key so a mismatch cannot raise KeyError
        # and mask the real error behind a generic 500.
        detail = result.get('error') or result.get('message') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/preprocessing/method/{method_name}")
async def get_preprocessing_method_details(method_name: str):
    """Return details of the preprocessing method ``method_name``.

    The DataManager result dict is returned unchanged on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    result = data_manager.get_preprocessing_method_details(method_name)
    if result['status'] == 'error':
        # NOTE(review): an unknown method_name presumably also ends up here and
        # yields 500; a 404 may be more appropriate — confirm DataManager's contract.
        # Accept either 'error' or 'message' (both are used in this module) so a
        # key mismatch cannot raise KeyError and hide the real failure.
        detail = result.get('error') or result.get('message') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/feature/methods")
async def get_feature_methods():
    """Return the list of available feature-engineering methods.

    The DataManager result dict is returned unchanged on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    result = data_manager.get_feature_engineering_methods()
    if result['status'] == 'error':
        # Accept either 'error' or 'message' (both are used in this module) so a
        # key mismatch cannot raise KeyError and hide the real failure.
        detail = result.get('error') or result.get('message') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/feature/method/{method_name}")
async def get_feature_method_details(method_name: str):
    """Return details of the feature-engineering method ``method_name``.

    The DataManager result dict is returned unchanged on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    result = data_manager.get_feature_engineering_method_details(method_name)
    if result['status'] == 'error':
        # Accept either 'error' or 'message' (both are used in this module) so a
        # key mismatch cannot raise KeyError and hide the real failure.
        detail = result.get('error') or result.get('message') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.post("/process")
async def process_dataset(request: ProcessRequest):
    """Process a dataset via ``DataManager.process_dataset``.

    All request fields (input path, output directory, processing methods,
    feature methods, split parameters) are forwarded unchanged, and the
    DataManager result dict is returned on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    logger = logging.getLogger(__name__)
    # NOTE(review): input_path/output_dir are client-controlled and passed through
    # unchecked; confirm DataManager confines them to an allowed data directory,
    # otherwise this permits arbitrary file read/write on the server.
    # NOTE(review): the synchronous DataManager call runs inside an async endpoint
    # and blocks the event loop for its whole duration; consider a plain `def`
    # endpoint or run_in_threadpool once DataManager is confirmed thread-safe.
    logger.info("处理数据集: %s", request.process_methods)
    result = data_manager.process_dataset(
        input_path=request.input_path,
        output_dir=request.output_dir,
        process_methods=request.process_methods,
        feature_methods=request.feature_methods,
        split_params=request.split_params,
    )
    # The full result can be large; log it at debug level instead of printing.
    logger.debug("result %s", result)
    if result['status'] == 'error':
        # Prefer 'message' (read here originally) but fall back to 'error', the key
        # other DataManager calls in this module use, so a mismatch cannot raise
        # KeyError and hide the real failure.
        detail = result.get('message') or result.get('error') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.get("/datasets")
async def get_datasets():
    """Return the available datasets.

    Returns:
        ``{"status": "success", "datasets": <DataManager.get_dataset() result>}``.

    Raises:
        HTTPException: 500 if ``DataManager.get_dataset`` raises.
    """
    try:
        datasets = data_manager.get_dataset()
    except Exception as e:
        # FastAPI does not log the cause of an HTTPException, so record the
        # traceback here before converting it into a 500 response.
        logging.getLogger(__name__).exception("获取数据集列表失败")
        raise HTTPException(
            status_code=500,
            detail=f"获取数据集列表失败: {e}"
        ) from e
    return {
        "status": "success",
        "datasets": datasets
    }
|
|
|
|
@router.post("/csv")
async def read_csv(request: CSVRequest):
    """Read a CSV file via ``DataManager.read_csv`` and return it for display.

    ``data_path``, ``head``, ``tail``, ``info`` and ``describe`` are forwarded
    unchanged; the DataManager result dict is returned on success.

    Raises:
        HTTPException: 500 when the DataManager reports ``status == 'error'``.
    """
    # NOTE(review): data_path is client-controlled and forwarded unchecked;
    # confirm DataManager restricts it to the dataset directory, otherwise any
    # file readable by the server process can be requested.
    result = data_manager.read_csv(
        data_path=request.data_path,
        head=request.head,
        tail=request.tail,
        info=request.info,
        describe=request.describe,
    )
    if result['status'] == 'error':
        # Prefer 'message' (read here originally) but fall back to 'error', the key
        # other DataManager calls in this module use, so a mismatch cannot raise
        # KeyError and hide the real failure.
        detail = result.get('message') or result.get('error') or "未知错误"
        raise HTTPException(status_code=500, detail=detail)
    return result
|
|
|
|
@router.post("/upload")
async def upload_dataset(file: UploadFile = File(...)):
    """Upload a dataset file and store it via ``DataManager.save_dataset``.

    Returns whatever ``save_dataset`` returns.

    Raises:
        HTTPException: 500 if ``save_dataset`` raises.
    """
    logger = logging.getLogger(__name__)
    try:
        result = data_manager.save_dataset(file)
    except Exception as e:
        # FastAPI does not log the cause of an HTTPException, so record the
        # traceback here before converting it into a 500 response.
        logger.exception("上传数据集失败")
        raise HTTPException(
            status_code=500,
            detail=f"上传数据集失败: {e}"
        ) from e
    logger.debug("upload result %s", result)
    # NOTE(review): unlike the other endpoints, result['status'] is not checked;
    # if save_dataset reports failure through a status dict it is returned with
    # HTTP 200 — confirm this is intended.
    return result
|
|
|
|
@router.get("/datasets/raw")
async def get_raw_datasets():
    """Return the datasets that are still waiting to be processed.

    Returns:
        ``{"status": "success", "datasets": <DataManager.get_raw_datasets() result>}``.

    Raises:
        HTTPException: 500 if ``DataManager.get_raw_datasets`` raises.
    """
    try:
        datasets = data_manager.get_raw_datasets()
    except Exception as e:
        # FastAPI does not log the cause of an HTTPException, so record the
        # traceback here before converting it into a 500 response.
        logging.getLogger(__name__).exception("获取待处理数据集列表失败")
        raise HTTPException(
            status_code=500,
            detail=f"获取待处理数据集列表失败: {e}"
        ) from e
    return {
        "status": "success",
        "datasets": datasets
    }
|