diff --git a/__pycache__/main.cpython-39.pyc b/__pycache__/main.cpython-39.pyc new file mode 100644 index 0000000..afaf6e2 Binary files /dev/null and b/__pycache__/main.cpython-39.pyc differ diff --git a/api/__pycache__/data_api.cpython-39.pyc b/api/__pycache__/data_api.cpython-39.pyc new file mode 100644 index 0000000..88acd5a Binary files /dev/null and b/api/__pycache__/data_api.cpython-39.pyc differ diff --git a/api/__pycache__/model_api.cpython-39.pyc b/api/__pycache__/model_api.cpython-39.pyc new file mode 100644 index 0000000..5186e8a Binary files /dev/null and b/api/__pycache__/model_api.cpython-39.pyc differ diff --git a/api/__pycache__/system_api.cpython-39.pyc b/api/__pycache__/system_api.cpython-39.pyc new file mode 100644 index 0000000..df15f65 Binary files /dev/null and b/api/__pycache__/system_api.cpython-39.pyc differ diff --git a/api/data_api.py b/api/data_api.py index e69de29..fd334e1 100644 --- a/api/data_api.py +++ b/api/data_api.py @@ -0,0 +1,76 @@ +from fastapi import APIRouter, HTTPException, Query +from typing import Optional, List, Dict +from function.data_manager import DataManager +from pydantic import BaseModel + +router = APIRouter() +data_manager = DataManager() + +# 数据模型 +class ProcessRequest(BaseModel): + input_file: str + output_path: str + preprocessing: List[Dict] + feature_engineering: List[Dict] + split_ratio: Dict[str, float] + +@router.get("/preprocessing/methods") +async def get_preprocessing_methods(): + """获取数据预处理方法列表""" + result = data_manager.get_preprocessing_methods() + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.get("/preprocessing/method/{method_name}") +async def get_preprocessing_method_details(method_name: str): + """获取预处理方法详情""" + result = data_manager.get_preprocessing_method_details(method_name) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + 
+@router.get("/feature/methods") +async def get_feature_methods(): + """获取特征工程方法列表""" + result = data_manager.get_feature_engineering_methods() + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.get("/feature/method/{method_name}") +async def get_feature_method_details(method_name: str): + """获取特征工程方法详情""" + result = data_manager.get_feature_engineering_method_details(method_name) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.post("/process") +async def process_dataset(request: ProcessRequest): + """处理数据集""" + result = data_manager.process_dataset( + input_path=request.input_file, + output_dir=request.output_path, + process_methods=request.preprocessing, + feature_methods=request.feature_engineering, + split_params=request.split_ratio + ) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result + +@router.get("/datasets") +async def get_datasets(): + """获取可用数据集列表""" + try: + datasets = data_manager.get_dataset() + return { + "status": "success", + "datasets": datasets + } + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"获取数据集列表失败: {str(e)}" + ) diff --git a/api/model_api.py b/api/model_api.py index e69de29..1fac347 100644 --- a/api/model_api.py +++ b/api/model_api.py @@ -0,0 +1,114 @@ +from fastapi import APIRouter, HTTPException, Query, Path +from typing import Optional, List, Dict +from function.model_manager import ModelManager +from pydantic import BaseModel + +router = APIRouter() +model_manager = ModelManager() + +# 数据模型 +class TrainRequest(BaseModel): + model: str + dataset: Dict[str, str] + parameters: Dict + metrics: List[str] + +class PredictRequest(BaseModel): + run_id: str + data: str + output_path: str + batch_size: int = 32 + device: str = "cuda" + return_proba: bool = True + metrics: Optional[List[str]] = None + 
+@router.get("/available") +async def get_available_models(): + """获取可用模型列表""" + result = model_manager.get_models() + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.get("/available/{model_name}") +async def get_model_details(model_name: str): + """获取模型详情""" + result = model_manager.get_model_details(model_name) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.get("/metrics") +async def get_metrics(): + """获取评价指标列表""" + result = model_manager.get_metrics() + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['error']) + return result + +@router.post("/train") +async def train_model(request: TrainRequest): + """模型训练""" + result = model_manager.train_model( + train_data=request.dataset['train'], + val_data=request.dataset.get('val'), + model_config={ + 'model_name': request.model, + 'parameters': request.parameters, + 'metrics': request.metrics + } + ) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result + +@router.get("/experiments") +async def get_experiments( + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1, le=100) +): + """获取MLFlow中保存的实验""" + result = model_manager.get_experiments(page=page, page_size=page_size) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result + +@router.get("/experiment/{experiment_name}") +async def get_finished_models( + experiment_name: str, + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1, le=100) +): + """获取已训练完成的模型列表""" + result = model_manager.get_finished_models( + experiment_name=experiment_name, + page=page, + page_size=page_size + ) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result + +@router.delete("/{run_id}") +async def 
delete_model(run_id: str = Path(..., description="MLflow运行ID")): + """删除指定的训练好的模型""" + result = model_manager.delete_model(run_id) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result + +@router.post("/predict") +async def predict(request: PredictRequest): + """模型预测""" + result = model_manager.predict( + run_id=request.run_id, + data_path=request.data, + output_path=request.output_path, + batch_size=request.batch_size, + device=request.device, + return_proba=request.return_proba, + metrics=request.metrics + ) + if result['status'] == 'error': + raise HTTPException(status_code=500, detail=result['message']) + return result diff --git a/api/system_api.py b/api/system_api.py index 340b0c5..4458f9a 100644 --- a/api/system_api.py +++ b/api/system_api.py @@ -6,7 +6,7 @@ from function.system_monitor import SystemMonitor router = APIRouter() system_monitor = SystemMonitor() -@router.get("/system/resources") +@router.get("/resources") async def get_system_resources(): """获取系统资源使用情况""" result = system_monitor.get_system_resources() @@ -14,7 +14,7 @@ async def get_system_resources(): raise HTTPException(status_code=500, detail=result['message']) return result -@router.get("/system/history") +@router.get("/history") async def get_training_history( page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100), @@ -36,7 +36,7 @@ async def get_training_history( raise HTTPException(status_code=500, detail=result['message']) return result -@router.get("/system/logs") +@router.get("/logs") async def get_system_logs( level: Optional[str] = None, start_time: Optional[str] = None, diff --git a/doc/安装文档.md b/doc/安装文档.md new file mode 100644 index 0000000..6d66e13 --- /dev/null +++ b/doc/安装文档.md @@ -0,0 +1,63 @@ +# 运行说明 + +## 1. 启动mlflow + - mlflow server --host 10.0.0.202 --port 5000 +## 2. 运行 main.py + + +## 3. 
测试接口 + +GET http://10.0.0.202:8992/health + +### 3.1 获取数据预处理方法列表 +GET http://10.0.0.202:8992/data/preprocessing/methods + +### 3.2 获取预处理方法详情 + +GET http://10.0.0.202:8992/data/preprocessing/method/{method_name} + +### 3.3 获取特征工程方法列表 +GET http://10.0.0.202:8992/data/feature/methods + +### 3.4 获取特征工程方法详情 + +GET http://10.0.0.202:8992/data/feature/method/{method_name} + +### 3.5 处理数据集 +POST http://10.0.0.202:8992/data/process + +### 3.6 获取可用数据集列表 +GET http://10.0.0.202:8992/data/datasets + +### 3.7 获取可用模型列表 +GET http://10.0.0.202:8992/model/available + +### 3.8 获取模型详情 +GET http://10.0.0.202:8992/model/available/{model_name} + +### 3.9 获取评价指标列表 +GET http://10.0.0.202:8992/model/metrics + +### 3.10 模型训练 +POST http://10.0.0.202:8992/model/train + +### 3.11 获取实验列表 +GET http://10.0.0.202:8992/model/experiments + +### 3.12 获取具体实验内容 +GET http://10.0.0.202:8992/model/experiment/{experiment_name} + +### 3.13 删除指定的训练好的模型 +DELETE http://10.0.0.202:8992/model/{run_id} + +### 3.14 模型预测 +POST http://10.0.0.202:8992/model/predict + +### 3.15 查看系统资源使用情况 +GET http://10.0.0.202:8992/system/resources + +### 3.16 查看训练历史 +GET http://10.0.0.202:8992/system/history + +### 3.17 查看系统训练日志 +GET http://10.0.0.202:8992/system/logs \ No newline at end of file diff --git a/doc/接口文档code.md b/doc/接口文档code.md index 6124707..362a38e 100644 --- a/doc/接口文档code.md +++ b/doc/接口文档code.md @@ -3,7 +3,7 @@ ## 1.
数据处理模块 ### 1.1 获取数据预处理方法列表 ```http -GET /api/data/preprocessing/methods +GET /data/preprocessing/methods Response: { @@ -33,7 +33,7 @@ Response: ### 1.2 获取预处理方法详情 ```http -GET /api/data/preprocessing/methods/{method_name} +GET /data/preprocessing/method/{method_name} Response: { @@ -67,7 +67,7 @@ Response: ### 1.3 获取特征工程方法列表 ```http -GET /api/data/feature/methods +GET /data/feature/methods Response: { @@ -97,7 +97,7 @@ Response: ### 1.4 获取特征工程方法详情 ```http -GET /api/data/feature/methods/{method_name} +GET /data/feature/method/{method_name} Response: { @@ -128,7 +128,7 @@ Response: ### 1.5 处理数据集 ```http -POST /api/data/process +POST /data/process Content-Type: application/json Request: @@ -182,7 +182,7 @@ Response: ### 1.6 查看可用数据集 ```http -GET /api/data/datasets +GET /data/datasets Response: { @@ -275,7 +275,7 @@ Response: ## 2. 模型接口 ### 2.1 获取可用模型列表 ```http -GET /api/models/available +GET /model/available Response: { @@ -299,7 +299,7 @@ Response: ### 2.2 获取模型详情 ```http -GET /api/models/{model_name} +GET /model/available/{model_name} Response: { @@ -329,7 +329,7 @@ Response: ### 2.3 获取评价指标列表 ```http -GET /api/metrics +GET /model/metrics Response: { @@ -369,7 +369,7 @@ Response: ### 2.4 模型训练 ```http -POST /api/train +POST /model/train Content-Type: application/json Request: @@ -396,7 +396,7 @@ Response: ``` ### 2.5 获取MLFlow中保存的实验 ```http -GET /api/experiments +GET /model/experiments Response: { @@ -424,7 +424,7 @@ Response: ### 2.6 获取已经训练好的模型列表 ```http -GET /api/models/finished/{experiment_name} +GET /model/experiment/{experiment_name} Response: { @@ -458,7 +458,7 @@ Response: ``` ### 2.7 删除指定的训练好的模型 ```http -DELETE /api/models/{run_id} +DELETE /model/{run_id} Response: { @@ -478,7 +478,7 @@ Response: ``` ### 2.8 模型预测 ```http -POST /api/model/predict +POST /model/predict Content-Type: application/json Request: @@ -526,10 +526,11 @@ Error Response: ### 2.9 模型优化 -- 未实现 + ## 3. 
系统监控 ### 3.1 获取资源使用情况 ```http -GET /api/system/resources +GET /system/resources Response: { @@ -623,7 +624,7 @@ Error Response: ### 3.2 获取训练历史 ```http -GET /api/system/history?page=1&page_size=10&start_time=2025-02-01&end_time=2025-02-19&status=completed&experiment_name=breast_cancer_classification +GET /system/history?page=1&page_size=10&start_time=2025-02-01&end_time=2025-02-19&status=completed&experiment_name=breast_cancer_classification Parameters: - page: 页码 (默认: 1) @@ -709,7 +710,7 @@ Response: ### 3.4 获取系统日志 ```http -GET /api/system/logs?level=error&start_time=2025-02-19T00:00:00&end_time=2025-02-19T23:59:59&module=training&page=1&page_size=20 +GET /system/logs?level=error&start_time=2025-02-19T00:00:00&end_time=2025-02-19T23:59:59&module=training&page=1&page_size=20 Parameters: - level: 日志级别过滤 (可选: debug, info, warning, error, critical) @@ -778,9 +779,8 @@ MLPlatform/ │ ├── model_api.py # 模型相关接口 │ └── system_api.py # 系统监控相关接口 ├── function/ # 功能实现层 -│ ├── data_processor.py # 数据处理类 +│ ├── data_manager.py # 数据处理类 │ ├── model_manager.py # 模型管理类 -│ ├── model_trainer.py # 模型训练类 │ ├── system_monitor.py # 系统监控类 │ └── utils/ # 工具函数 ├── config/ # 配置文件 diff --git a/function/__pycache__/data_manager.cpython-39.pyc b/function/__pycache__/data_manager.cpython-39.pyc index c8343e2..6bbf288 100644 Binary files a/function/__pycache__/data_manager.cpython-39.pyc and b/function/__pycache__/data_manager.cpython-39.pyc differ diff --git a/function/data_manager.py b/function/data_manager.py index 4fd29b9..c858d08 100644 --- a/function/data_manager.py +++ b/function/data_manager.py @@ -597,6 +597,14 @@ class DataManager: "error": str(e) } + + def _clean_json_line(self,line): + # 替换掉不符合JSON标准的特殊浮点数值 + line = line.replace('NaN', 'null') + line = line.replace('Infinity', '1e308') # 或者选择一个合适的替代值 + line = line.replace('-Infinity', '-1e308') # 同上 + return line + def get_dataset(self): back = list() @@ -611,8 +619,15 @@ class DataManager: json_files = list(folder_path.glob('*.json')) 
for json_file in json_files: + + + # json是不能处理NaN这类的数值,需要单独处理他们 with open(json_file.as_posix(), 'r', encoding='utf-8') as f: - json_data = json.load(f) + cleaned_lines = [self._clean_json_line(line) for line in f] + json_data = json.loads(''.join(cleaned_lines)) + # json_data = json.load(f ,allow_nan=True) back.append(json_data) + + # print("可用数据集", back) return back \ No newline at end of file diff --git a/main.py b/main.py index 753e676..21c8136 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,15 @@ -from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks +from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.security import OAuth2PasswordBearer +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError from typing import Optional, Dict, List import uvicorn from pathlib import Path import logging import yaml from datetime import datetime +import time # 导入API路由 from api.data_api import router as data_router @@ -17,30 +20,11 @@ from api.system_api import router as system_router app = FastAPI( title="机器学习平台API", description="提供数据处理、模型训练和系统监控功能的API服务", - version="1.0.0" + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc" ) -# 配置CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# 设置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(f'.log/server_{datetime.now():%Y%m%d_%H%M%S}.log'), - logging.StreamHandler() - ] -) - -logger = logging.getLogger(__name__) - # 加载配置 def load_config(): try: @@ -50,23 +34,137 @@ def load_config(): logger.error(f"Error loading config: {str(e)}") return {} +# 初始化配置 config = load_config() +# 设置日志 +log_dir = Path('.log') +log_dir.mkdir(exist_ok=True) + +logging.basicConfig( + level=logging.INFO, + 
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log'), + logging.StreamHandler() + ] +) + +logger = logging.getLogger(__name__) + +# 配置CORS +app.add_middleware( + CORSMiddleware, + allow_origins=config.get('cors', {}).get('origins', ["*"]), + allow_credentials=True, + allow_methods=config.get('cors', {}).get('methods', ["*"]), + allow_headers=config.get('cors', {}).get('headers', ["*"]), + max_age=config.get('cors', {}).get('max_age', 600), +) + +# 请求计时中间件 +@app.middleware("http") +async def add_process_time_header(request: Request, call_next): + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + +# 全局错误处理 +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """处理请求参数验证错误""" + return JSONResponse( + status_code=422, + content={ + "status": "error", + "message": "请求参数验证失败", + "details": exc.errors() + } + ) + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + """处理HTTP异常""" + return JSONResponse( + status_code=exc.status_code, + content={ + "status": "error", + "message": exc.detail + } + ) + +@app.exception_handler(Exception) +async def general_exception_handler(request: Request, exc: Exception): + """处理其他异常""" + logger.error(f"Unhandled exception: {str(exc)}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "status": "error", + "message": "服务器内部错误", + "details": str(exc) if config.get('debug', False) else None + } + ) + # 注册路由 -app.include_router(data_router, prefix="/api", tags=["数据处理"]) -app.include_router(model_router, prefix="/api", tags=["模型管理"]) -app.include_router(system_router, prefix="/api", tags=["系统监控"]) +app.include_router( + data_router, + prefix="/data", + 
tags=["数据处理"], + responses={404: {"description": "Not found"}}, +) + +app.include_router( + model_router, + prefix="/model", + tags=["模型管理"], + responses={404: {"description": "Not found"}}, +) + +app.include_router( + system_router, + prefix="/system", + tags=["系统监控"], + responses={404: {"description": "Not found"}}, +) # 健康检查 @app.get("/health") async def health_check(): - return {"status": "healthy", "timestamp": datetime.now().isoformat()} + """系统健康检查""" + return { + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "version": app.version, + "environment": config.get('environment', 'production') + } + +# 启动事件 +@app.on_event("startup") +async def startup_event(): + """服务启动时的初始化操作""" + logger.info("Server starting up...") + # 创建必要的目录 + Path("dataset/dataset_raw").mkdir(parents=True, exist_ok=True) + Path("dataset/dataset_processed").mkdir(parents=True, exist_ok=True) + Path(".log").mkdir(exist_ok=True) + logger.info("Server started successfully") + +# 关闭事件 +@app.on_event("shutdown") +async def shutdown_event(): + """服务关闭时的清理操作""" + logger.info("Server shutting down...") if __name__ == "__main__": uvicorn.run( "main:app", host=config.get('host', '0.0.0.0'), - port=config.get('port', 8000), - reload=True, - workers=config.get('workers', 4) + port=config.get('port', 8992), + reload=config.get('debug', True), + workers=config.get('workers', 4), + log_level=config.get('log_level', 'info'), + access_log=True ) \ No newline at end of file