# MLPlatform/api/data_api.py
# 2025-02-26 10:38:11 +08:00
#
# 130 lines
# 4.0 KiB
# Python

import logging
from typing import Optional, List, Dict

from fastapi import APIRouter, HTTPException, Query, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from function.data_manager import DataManager
# Router holding every data-related endpoint in this module (method catalogues,
# dataset processing, CSV preview, upload, dataset listing). Presumably included
# into the main FastAPI app elsewhere — confirm the mount prefix there.
router = APIRouter()
# One DataManager instance, created once at import time and shared by every
# handler below.
data_manager = DataManager()
# Request models
class ProcessRequest(BaseModel):
    """Request body for ``POST /process``.

    Each field is forwarded by keyword, under the same name, to
    ``DataManager.process_dataset``.
    """
    # Path of the dataset to process.
    input_path: str
    # Output directory handed to DataManager (presumably where processed files
    # are written — confirm in DataManager.process_dataset).
    output_dir: str
    # Preprocessing step specs; the dict schema is defined by DataManager
    # (likely a method name plus parameters — TODO confirm).
    process_methods: List[Dict]
    # Feature-engineering step specs; dict schema is defined by DataManager.
    feature_methods: List[Dict]
    # Named float parameters for the dataset split; the expected key names are
    # defined by DataManager — confirm there.
    split_params: Dict[str, float]
class CSVRequest(BaseModel):
    """Request body for ``POST /csv``.

    Each field is forwarded by keyword, under the same name, to
    ``DataManager.read_csv``.
    """
    # Path of the CSV file to read.
    data_path: str
    # Passed as ``head=``; presumably the number of leading rows to return
    # (default 5) — confirm in DataManager.read_csv.
    head: int = 5
    # Passed as ``tail=``; presumably the number of trailing rows (default 5).
    tail: int = 5
    # Passed as ``info=``; presumably toggles a column/dtype summary.
    info: bool = True
    # Passed as ``describe=``; presumably toggles descriptive statistics.
    describe: bool = True
@router.get("/preprocessing/methods")
async def get_preprocessing_methods():
    """Return the catalogue of data-preprocessing methods.

    Delegates to ``DataManager.get_preprocessing_methods`` and returns its
    result dict unchanged on success.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    result = data_manager.get_preprocessing_methods()
    if result['status'] == 'error':
        # Accept either 'error' or 'message' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('error') or result.get('message') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.get("/preprocessing/method/{method_name}")
async def get_preprocessing_method_details(method_name: str):
    """Return the details of one preprocessing method.

    Args:
        method_name: Method identifier taken from the URL path, passed
            straight to ``DataManager.get_preprocessing_method_details``.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    result = data_manager.get_preprocessing_method_details(method_name)
    if result['status'] == 'error':
        # Accept either 'error' or 'message' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('error') or result.get('message') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.get("/feature/methods")
async def get_feature_methods():
    """Return the catalogue of feature-engineering methods.

    Delegates to ``DataManager.get_feature_engineering_methods`` and returns
    its result dict unchanged on success.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    result = data_manager.get_feature_engineering_methods()
    if result['status'] == 'error':
        # Accept either 'error' or 'message' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('error') or result.get('message') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.get("/feature/method/{method_name}")
async def get_feature_method_details(method_name: str):
    """Return the details of one feature-engineering method.

    Args:
        method_name: Method identifier taken from the URL path, passed
            straight to ``DataManager.get_feature_engineering_method_details``.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    result = data_manager.get_feature_engineering_method_details(method_name)
    if result['status'] == 'error':
        # Accept either 'error' or 'message' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('error') or result.get('message') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.post("/process")
async def process_dataset(request: ProcessRequest):
    """Run the requested processing pipeline on a dataset.

    Forwards every ``ProcessRequest`` field by keyword to
    ``DataManager.process_dataset`` and returns its result dict unchanged on
    success.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    logger = logging.getLogger(__name__)
    logger.info("处理数据集: %s", request.process_methods)
    # DataManager.process_dataset is synchronous (it is never awaited). Calling
    # it directly inside this coroutine would stall the event loop, and every
    # other request, for the whole run, so execute it in the worker threadpool.
    # NOTE(review): concurrent /process requests can now run in parallel
    # threads against the shared DataManager — confirm it tolerates that.
    result = await run_in_threadpool(
        data_manager.process_dataset,
        input_path=request.input_path,
        output_dir=request.output_dir,
        process_methods=request.process_methods,
        feature_methods=request.feature_methods,
        split_params=request.split_params,
    )
    logger.debug("result %s", result)
    if result['status'] == 'error':
        # Accept either 'message' or 'error' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('message') or result.get('error') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.get("/datasets")
async def get_datasets():
    """Return the list of available datasets.

    Returns:
        ``{"status": "success", "datasets": ...}`` where ``datasets`` is
        whatever ``DataManager.get_dataset()`` returns.

    Raises:
        HTTPException: 500 if ``DataManager.get_dataset`` raises; the detail
            carries the original exception message.
    """
    try:
        datasets = data_manager.get_dataset()
    except Exception as e:
        # FastAPI's default HTTPException handler only returns the detail to
        # the client, so log the traceback here or it is lost entirely.
        logging.getLogger(__name__).exception("获取数据集列表失败")
        raise HTTPException(
            status_code=500,
            detail=f"获取数据集列表失败: {str(e)}"
        ) from e
    return {
        "status": "success",
        "datasets": datasets
    }
@router.post("/csv")
async def read_csv(request: CSVRequest):
    """Read a CSV file and return a preview of it.

    Forwards every ``CSVRequest`` field by keyword to ``DataManager.read_csv``
    and returns its result dict unchanged on success.

    Raises:
        HTTPException: 500 when the manager's result has ``status == 'error'``.
    """
    result = data_manager.read_csv(
        data_path=request.data_path,
        head=request.head,
        tail=request.tail,
        info=request.info,
        describe=request.describe,
    )
    if result['status'] == 'error':
        # Accept either 'message' or 'error' as the failure text so a key
        # mismatch with DataManager's error dict cannot turn this error path
        # into an unhandled KeyError that hides the real message.
        detail = result.get('message') or result.get('error') or '未知错误'  # "unknown error"
        raise HTTPException(status_code=500, detail=detail)
    return result
@router.post("/upload")
async def upload_dataset(file: UploadFile = File(...)):
    """Upload a dataset file.

    Passes the uploaded file to ``DataManager.save_dataset`` and returns
    whatever it returns.

    Raises:
        HTTPException: 500 if ``save_dataset`` raises; the detail carries the
            original exception message.
    """
    logger = logging.getLogger(__name__)
    try:
        result = data_manager.save_dataset(file)
    except Exception as e:
        # FastAPI's default HTTPException handler only returns the detail to
        # the client, so log the traceback here or it is lost entirely.
        logger.exception("上传数据集失败")
        raise HTTPException(
            status_code=500,
            detail=f"上传数据集失败: {str(e)}"
        ) from e
    # Debug trace of the manager's response (was an unconditional print).
    logger.debug("upload result %s", result)
    return result
@router.get("/datasets/raw")
async def get_raw_datasets():
    """Return the list of datasets still awaiting processing.

    Returns:
        ``{"status": "success", "datasets": ...}`` where ``datasets`` is
        whatever ``DataManager.get_raw_datasets()`` returns.

    Raises:
        HTTPException: 500 if ``DataManager.get_raw_datasets`` raises; the
            detail carries the original exception message.
    """
    try:
        datasets = data_manager.get_raw_datasets()
    except Exception as e:
        # FastAPI's default HTTPException handler only returns the detail to
        # the client, so log the traceback here or it is lost entirely.
        logging.getLogger(__name__).exception("获取待处理数据集列表失败")
        raise HTTPException(
            status_code=500,
            detail=f"获取待处理数据集列表失败: {str(e)}"
        ) from e
    return {
        "status": "success",
        "datasets": datasets
    }