完成--fastapi框架初步完成

This commit is contained in:
haotian 2025-02-24 11:55:35 +08:00
parent c4f09c9028
commit e81b2e96d2
12 changed files with 419 additions and 53 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,76 @@
from fastapi import APIRouter, HTTPException, Query
from typing import Optional, List, Dict
from function.data_manager import DataManager
from pydantic import BaseModel
# Router that collects the data-processing endpoints declared in this module.
router = APIRouter()
# One module-level DataManager instance, shared by every handler below.
data_manager = DataManager()
# 数据模型
class ProcessRequest(BaseModel):
    """Request body accepted by the ``POST /process`` endpoint."""

    # Forwarded to DataManager.process_dataset as ``input_path``.
    input_file: str
    # Forwarded as ``output_dir``.
    output_path: str
    # Forwarded as ``process_methods``; the per-item dict schema is not
    # defined in this module.
    preprocessing: List[Dict]
    # Forwarded as ``feature_methods``.
    feature_engineering: List[Dict]
    # Forwarded as ``split_params``.
    split_ratio: Dict[str, float]
@router.get("/preprocessing/methods")
async def get_preprocessing_methods():
    """Return the list of data preprocessing methods.

    Returns:
        The result dict produced by ``data_manager.get_preprocessing_methods()``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    methods = data_manager.get_preprocessing_methods()
    if methods['status'] != 'error':
        return methods
    raise HTTPException(status_code=500, detail=methods['error'])
@router.get("/preprocessing/method/{method_name}")
async def get_preprocessing_method_details(method_name: str):
    """Return the details of one preprocessing method.

    Args:
        method_name: Path parameter naming the method to look up.

    Returns:
        The result dict produced by
        ``data_manager.get_preprocessing_method_details(method_name)``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    details = data_manager.get_preprocessing_method_details(method_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
@router.get("/feature/methods")
async def get_feature_methods():
    """Return the list of feature engineering methods.

    Returns:
        The result dict produced by
        ``data_manager.get_feature_engineering_methods()``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    methods = data_manager.get_feature_engineering_methods()
    if methods['status'] != 'error':
        return methods
    raise HTTPException(status_code=500, detail=methods['error'])
@router.get("/feature/method/{method_name}")
async def get_feature_method_details(method_name: str):
    """Return the details of one feature engineering method.

    Args:
        method_name: Path parameter naming the method to look up.

    Returns:
        The result dict produced by
        ``data_manager.get_feature_engineering_method_details(method_name)``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    details = data_manager.get_feature_engineering_method_details(method_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
@router.post("/process")
async def process_dataset(request: ProcessRequest):
    """Process a dataset as described by the request body.

    Args:
        request: Validated ``ProcessRequest``; its fields are forwarded to
            ``data_manager.process_dataset`` under that method's keyword names.

    Returns:
        The result dict produced by ``data_manager.process_dataset``.

    Raises:
        HTTPException: 500 with the result's ``message`` value when the
            result's ``status`` is ``'error'``.
    """
    # Map request field names onto the manager's keyword names.
    manager_kwargs = {
        "input_path": request.input_file,
        "output_dir": request.output_path,
        "process_methods": request.preprocessing,
        "feature_methods": request.feature_engineering,
        "split_params": request.split_ratio,
    }
    outcome = data_manager.process_dataset(**manager_kwargs)
    if outcome['status'] != 'error':
        return outcome
    raise HTTPException(status_code=500, detail=outcome['message'])
@router.get("/datasets")
async def get_datasets():
    """Return the list of available datasets.

    Returns:
        ``{"status": "success", "datasets": ...}`` where ``datasets`` is
        whatever ``data_manager.get_dataset()`` returns.

    Raises:
        HTTPException: 500 when ``data_manager.get_dataset()`` raises; the
            original exception is kept as the cause.
    """
    # Only the manager call sits inside ``try`` so unrelated failures are
    # not relabelled as dataset-listing errors.
    try:
        datasets = data_manager.get_dataset()
    except Exception as e:
        # ``from e`` keeps the original traceback attached for debugging.
        raise HTTPException(
            status_code=500,
            detail=f"获取数据集列表失败: {str(e)}"
        ) from e
    return {
        "status": "success",
        "datasets": datasets
    }

View File

@ -0,0 +1,114 @@
from fastapi import APIRouter, HTTPException, Query, Path
from typing import Optional, List, Dict
from function.model_manager import ModelManager
from pydantic import BaseModel
# Router that collects the model-related endpoints declared in this module.
router = APIRouter()
# One module-level ModelManager instance, shared by every handler below.
model_manager = ModelManager()
# 数据模型
class TrainRequest(BaseModel):
    """Request body accepted by the ``POST /train`` endpoint."""

    # Forwarded to the model manager as ``model_config['model_name']``.
    model: str
    # The train handler reads the 'train' key (required) and the 'val' key
    # (optional) from this mapping.
    dataset: Dict[str, str]
    # Forwarded as ``model_config['parameters']``; schema not defined here.
    parameters: Dict
    # Forwarded as ``model_config['metrics']``.
    metrics: List[str]
class PredictRequest(BaseModel):
    """Request body accepted by the ``POST /predict`` endpoint."""

    # Forwarded to ModelManager.predict as ``run_id``.
    run_id: str
    # Forwarded as ``data_path``.
    data: str
    # Forwarded as ``output_path``.
    output_path: str
    # Optional settings; the defaults below apply when a field is omitted.
    batch_size: int = 32
    device: str = "cuda"
    return_proba: bool = True
    metrics: Optional[List[str]] = None
@router.get("/available")
async def get_available_models():
    """Return the list of available models.

    Returns:
        The result dict produced by ``model_manager.get_models()``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    models = model_manager.get_models()
    if models['status'] != 'error':
        return models
    raise HTTPException(status_code=500, detail=models['error'])
@router.get("/available/{model_name}")
async def get_model_details(model_name: str):
    """Return the details of one model.

    Args:
        model_name: Path parameter naming the model to look up.

    Returns:
        The result dict produced by
        ``model_manager.get_model_details(model_name)``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    details = model_manager.get_model_details(model_name)
    if details['status'] != 'error':
        return details
    raise HTTPException(status_code=500, detail=details['error'])
@router.get("/metrics")
async def get_metrics():
    """Return the list of evaluation metrics.

    Returns:
        The result dict produced by ``model_manager.get_metrics()``.

    Raises:
        HTTPException: 500 with the result's ``error`` value when the
            result's ``status`` is ``'error'``.
    """
    metrics = model_manager.get_metrics()
    if metrics['status'] != 'error':
        return metrics
    raise HTTPException(status_code=500, detail=metrics['error'])
@router.post("/train")
async def train_model(request: TrainRequest):
    """Train a model as described by the request body.

    Args:
        request: Validated ``TrainRequest``. ``dataset['train']`` is
            required; ``dataset['val']`` is optional.

    Returns:
        The result dict produced by ``model_manager.train_model``.

    Raises:
        HTTPException: 422 when ``dataset`` has no ``'train'`` entry;
            500 with the result's ``message`` value when the result's
            ``status`` is ``'error'``.
    """
    # ``dataset`` is a free-form mapping, so a missing 'train' key would
    # otherwise surface as an unhandled KeyError (a generic 500) instead
    # of a client error.
    train_path = request.dataset.get('train')
    if train_path is None:
        raise HTTPException(
            status_code=422,
            detail="dataset 缺少必需的 'train' 字段"
        )
    # NOTE(review): this looks like a synchronous call inside an async
    # handler; if training is long-running it will block the event loop —
    # confirm and consider a background task.
    result = model_manager.train_model(
        train_data=train_path,
        val_data=request.dataset.get('val'),
        model_config={
            'model_name': request.model,
            'parameters': request.parameters,
            'metrics': request.metrics
        }
    )
    if result['status'] == 'error':
        raise HTTPException(status_code=500, detail=result['message'])
    return result
@router.get("/experiments")
async def get_experiments(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """Return the experiments stored in MLflow, one page at a time.

    Args:
        page: Page number, at least 1 (default 1).
        page_size: Items per page, 1 to 100 (default 10).

    Returns:
        The result dict produced by ``model_manager.get_experiments``.

    Raises:
        HTTPException: 500 with the result's ``message`` value when the
            result's ``status`` is ``'error'``.
    """
    experiments = model_manager.get_experiments(page=page, page_size=page_size)
    if experiments['status'] != 'error':
        return experiments
    raise HTTPException(status_code=500, detail=experiments['message'])
@router.get("/experiment/{experiment_name}")
async def get_finished_models(
    experiment_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100)
):
    """Return the finished (trained) models of one experiment, paginated.

    Args:
        experiment_name: Path parameter naming the experiment.
        page: Page number, at least 1 (default 1).
        page_size: Items per page, 1 to 100 (default 10).

    Returns:
        The result dict produced by ``model_manager.get_finished_models``.

    Raises:
        HTTPException: 500 with the result's ``message`` value when the
            result's ``status`` is ``'error'``.
    """
    query = {
        "experiment_name": experiment_name,
        "page": page,
        "page_size": page_size,
    }
    finished = model_manager.get_finished_models(**query)
    if finished['status'] != 'error':
        return finished
    raise HTTPException(status_code=500, detail=finished['message'])
@router.delete("/{run_id}")
async def delete_model(run_id: str = Path(..., description="MLflow运行ID")):
    """Delete the trained model identified by an MLflow run ID.

    Args:
        run_id: Path parameter holding the MLflow run ID.

    Returns:
        The result dict produced by ``model_manager.delete_model(run_id)``.

    Raises:
        HTTPException: 500 with the result's ``message`` value when the
            result's ``status`` is ``'error'``.
    """
    outcome = model_manager.delete_model(run_id)
    if outcome['status'] != 'error':
        return outcome
    raise HTTPException(status_code=500, detail=outcome['message'])
@router.post("/predict")
async def predict(request: PredictRequest):
    """Run prediction with a previously trained model.

    Args:
        request: Validated ``PredictRequest``; its fields are forwarded to
            ``model_manager.predict`` under that method's keyword names.

    Returns:
        The result dict produced by ``model_manager.predict``.

    Raises:
        HTTPException: 500 with the result's ``message`` value when the
            result's ``status`` is ``'error'``.
    """
    # Only ``data`` is renamed (to ``data_path``); the rest pass through as-is.
    predict_kwargs = {
        "run_id": request.run_id,
        "data_path": request.data,
        "output_path": request.output_path,
        "batch_size": request.batch_size,
        "device": request.device,
        "return_proba": request.return_proba,
        "metrics": request.metrics,
    }
    prediction = model_manager.predict(**predict_kwargs)
    if prediction['status'] != 'error':
        return prediction
    raise HTTPException(status_code=500, detail=prediction['message'])

View File

@ -6,7 +6,7 @@ from function.system_monitor import SystemMonitor
router = APIRouter()
system_monitor = SystemMonitor()
@router.get("/system/resources")
@router.get("/resources")
async def get_system_resources():
"""获取系统资源使用情况"""
result = system_monitor.get_system_resources()
@ -14,7 +14,7 @@ async def get_system_resources():
raise HTTPException(status_code=500, detail=result['message'])
return result
@router.get("/system/history")
@router.get("/history")
async def get_training_history(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
@ -36,7 +36,7 @@ async def get_training_history(
raise HTTPException(status_code=500, detail=result['message'])
return result
@router.get("/system/logs")
@router.get("/logs")
async def get_system_logs(
level: Optional[str] = None,
start_time: Optional[str] = None,

63
doc/安装文档.md Normal file
View File

@ -0,0 +1,63 @@
# 运行说明
## 1. 启动mlflow
- mlflow server --host 10.0.0.202 --port 5000
## 2. 运行 main.py
## 3. 测试接口
GET http://10.0.0.202:8992/health
### 3.1 获取数据预处理方法列表
GET http://10.0.0.202:8992/data/preprocessing/methods
### 3.2 获取预处理方法详情
GET http://10.0.0.202:8992/data/preprocessing/method/{method_name}
### 3.3 获取特征工程方法列表
GET http://10.0.0.202:8992/data/feature/methods
### 3.4 获取特征工程方法详情
GET http://10.0.0.202:8992/data/feature/method/{method_name}
### 3.5 处理数据集
POST http://10.0.0.202:8992/data/process
### 3.6 获取可用数据集列表
GET http://10.0.0.202:8992/data/datasets
### 3.7 获取可用模型列表
GET http://10.0.0.202:8992/model/available
### 3.8 获取模型详情
GET http://10.0.0.202:8992/model/available/{model_name}
### 3.9 获取评价指标列表
GET http://10.0.0.202:8992/model/metrics
### 3.10 模型训练
POST http://10.0.0.202:8992/model/train
### 3.11 获取实验列表
GET http://10.0.0.202:8992/model/experiments
### 3.12 获取具体实验内容
GET http://10.0.0.202:8992/model/experiment/{experiment_name}
### 3.13 删除指定的已训练模型
DELETE http://10.0.0.202:8992/model/{run_id}
### 3.14 模型预测
POST http://10.0.0.202:8992/model/predict
### 3.15 查看系统资源使用情况
GET http://10.0.0.202:8992/system/resources
### 3.16 查看训练历史
GET http://10.0.0.202:8992/system/history
### 3.17 查看系统训练日志
GET http://10.0.0.202:8992/system/logs

View File

@ -3,7 +3,7 @@
## 1. 数据处理模块
### 1.1 获取数据预处理方法列表
```http
GET /api/data/preprocessing/methods
GET /data/preprocessing/methods
Response:
{
@ -33,7 +33,7 @@ Response:
### 1.2 获取预处理方法详情
```http
GET /api/data/preprocessing/methods/{method_name}
GET /data/preprocessing/method/{method_name}
Response:
{
@ -67,7 +67,7 @@ Response:
### 1.3 获取特征工程方法列表
```http
GET /api/data/feature/methods
GET /data/feature/methods
Response:
{
@ -97,7 +97,7 @@ Response:
### 1.4 获取特征工程方法详情
```http
GET /api/data/feature/methods/{method_name}
GET /data/feature/method/{method_name}
Response:
{
@ -128,7 +128,7 @@ Response:
### 1.5 处理数据集
```http
POST /api/data/process
POST /data/process
Content-Type: application/json
Request:
@ -182,7 +182,7 @@ Response:
### 1.6 查看可用数据集
```http
GET /api/data/datasets
GET /data/datasets
Response:
{
@ -275,7 +275,7 @@ Response:
## 2. 模型接口
### 2.1 获取可用模型列表
```http
GET /api/models/available
GET /model/available
Response:
{
@ -299,7 +299,7 @@ Response:
### 2.2 获取模型详情
```http
GET /api/models/{model_name}
GET /model/available/{model_name}
Response:
{
@ -329,7 +329,7 @@ Response:
### 2.3 获取评价指标列表
```http
GET /api/metrics
GET /model/metrics
Response:
{
@ -369,7 +369,7 @@ Response:
### 2.4 模型训练
```http
POST /api/train
POST /model/train
Content-Type: application/json
Request:
@ -396,7 +396,7 @@ Response:
```
### 2.5 获取MLFlow中保存的实验
```http
GET /api/experiments
GET /model/experiments
Response:
{
@ -424,7 +424,7 @@ Response:
### 2.6 获取已经训练好的模型列表
```http
GET /api/models/finished/{experiment_name}
GET /model/experiment/{experiment_name}
Response:
{
@ -458,7 +458,7 @@ Response:
```
### 2.7 删除指定的训练好的模型
```http
DELETE /api/models/{run_id}
DELETE /model/{run_id}
Response:
{
@ -478,7 +478,7 @@ Response:
```
### 2.8 模型预测
```http
POST /api/model/predict
POST /model/predict
Content-Type: application/json
Request:
@ -526,10 +526,11 @@ Error Response:
### 2.9 模型优化 -- 未实现
## 3. 系统监控
### 3.1 获取资源使用情况
```http
GET /api/system/resources
GET /system/resources
Response:
{
@ -623,7 +624,7 @@ Error Response:
### 3.2 获取训练历史
```http
GET /api/system/history?page=1&page_size=10&start_time=2025-02-01&end_time=2025-02-19&status=completed&experiment_name=breast_cancer_classification
GET /system/history?page=1&page_size=10&start_time=2025-02-01&end_time=2025-02-19&status=completed&experiment_name=breast_cancer_classification
Parameters:
- page: 页码 (默认: 1)
@ -709,7 +710,7 @@ Response:
### 3.4 获取系统日志
```http
GET /api/system/logs?level=error&start_time=2025-02-19T00:00:00&end_time=2025-02-19T23:59:59&module=training&page=1&page_size=20
GET /system/logs?level=error&start_time=2025-02-19T00:00:00&end_time=2025-02-19T23:59:59&module=training&page=1&page_size=20
Parameters:
- level: 日志级别过滤 (可选: debug, info, warning, error, critical)
@ -778,9 +779,8 @@ MLPlatform/
│ ├── model_api.py # 模型相关接口
│ └── system_api.py # 系统监控相关接口
├── function/ # 功能实现层
│ ├── data_processor.py # 数据处理类
│ ├── data_manager.py # 数据处理类
│ ├── model_manager.py # 模型管理类
│ ├── model_trainer.py # 模型训练类
│ ├── system_monitor.py # 系统监控类
│ └── utils/ # 工具函数
├── config/ # 配置文件

View File

@ -597,6 +597,14 @@ class DataManager:
"error": str(e)
}
def _clean_json_line(self, line):
    """Replace non-standard JSON float tokens in one line of JSON text.

    Bare ``NaN`` becomes ``null``, ``Infinity`` becomes ``1e308`` and
    ``-Infinity`` becomes ``-1e308``. Text inside double-quoted JSON
    strings is left untouched, so names or descriptions that merely
    contain "NaN" or "Infinity" are not corrupted.

    Args:
        line: One line of JSON text.

    Returns:
        The line with the special float tokens substituted.
    """
    # Function-scope import keeps this method self-contained.
    import re

    replacements = {
        'NaN': 'null',
        'Infinity': '1e308',
        '-Infinity': '-1e308',
    }
    # The first alternative consumes whole string literals (including
    # escaped quotes) so the special tokens are only matched outside them.
    token_pattern = r'"(?:\\.|[^"\\])*"|-?Infinity|NaN'

    def _substitute(match):
        token = match.group(0)
        # String literals are not in the table and pass through unchanged.
        return replacements.get(token, token)

    return re.sub(token_pattern, _substitute, line)
def get_dataset(self):
back = list()
@ -611,8 +619,15 @@ class DataManager:
json_files = list(folder_path.glob('*.json'))
for json_file in json_files:
# json是不能处理NaN这类的数值,需要单独处理他们
with open(json_file.as_posix(), 'r', encoding='utf-8') as f:
json_data = json.load(f)
cleaned_lines = [self._clean_json_line(line) for line in f]
json_data = json.loads(''.join(cleaned_lines))
# json_data = json.load(f ,allow_nan=True)
back.append(json_data)
# print("可用数据集", back)
return back

158
main.py
View File

@ -1,12 +1,15 @@
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from typing import Optional, Dict, List
import uvicorn
from pathlib import Path
import logging
import yaml
from datetime import datetime
import time
# 导入API路由
from api.data_api import router as data_router
@ -17,30 +20,11 @@ from api.system_api import router as system_router
app = FastAPI(
title="机器学习平台API",
description="提供数据处理、模型训练和系统监控功能的API服务",
version="1.0.0"
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'.log/server_{datetime.now():%Y%m%d_%H%M%S}.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# 加载配置
def load_config():
try:
@ -50,23 +34,137 @@ def load_config():
logger.error(f"Error loading config: {str(e)}")
return {}
# Set up logging first: load_config() reports failures through the
# module-level ``logger``, so that name must exist before the config is
# read. Otherwise a config error turns into a NameError.
log_dir = Path('.log')
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # One log file per process start, named by the start timestamp.
        logging.FileHandler(log_dir / f'server_{datetime.now():%Y%m%d_%H%M%S}.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Load the configuration (load_config returns {} when loading fails).
config = load_config()
# 配置CORS
# Read the CORS section once. ``or {}`` also covers a ``cors`` key that is
# present but empty/null, which would otherwise fail on ``None.get``.
cors_config = config.get('cors') or {}
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS docs — confirm explicit origins are
# configured for deployments that use credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_config.get('origins', ["*"]),
    allow_credentials=True,
    allow_methods=cors_config.get('methods', ["*"]),
    allow_headers=cors_config.get('headers', ["*"]),
    max_age=cors_config.get('max_age', 600),
)
# 请求计时中间件
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the request handling time as an ``X-Process-Time`` header.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on and returns the
            response.

    Returns:
        The response, with ``X-Process-Time`` set to the elapsed seconds
        as a string.
    """
    # perf_counter() is a monotonic clock meant for measuring durations;
    # time.time() can jump when the system clock is adjusted.
    started = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - started)
    return response
# 全局错误处理
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn request validation errors into the API's error envelope.

    Args:
        request: The request that failed validation (unused).
        exc: The validation error raised by FastAPI.

    Returns:
        A 422 ``JSONResponse`` with ``status``, ``message`` and the
        validation ``details``.
    """
    # Function-scope import keeps this handler self-contained.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": "请求参数验证失败",
            # exc.errors() may contain values the JSON encoder rejects
            # (e.g. bytes input or exception objects), so encode it first.
            "details": jsonable_encoder(exc.errors())
        }
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Turn an ``HTTPException`` into the API's error envelope.

    Args:
        request: The request being handled (unused).
        exc: The raised HTTP exception.

    Returns:
        A ``JSONResponse`` with the exception's status code and headers,
        and its ``detail`` as ``message``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.detail
        },
        # Forward headers set on the exception (e.g. WWW-Authenticate);
        # they were previously dropped.
        headers=exc.headers
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and return a generic 500 response.

    Args:
        request: The request being handled (unused).
        exc: The unhandled exception.

    Returns:
        A 500 ``JSONResponse``; ``details`` holds ``str(exc)`` only when
        the ``debug`` config flag is truthy, otherwise ``None``.
    """
    # Pass the exception explicitly so the traceback is logged regardless
    # of the ambient sys.exc_info() state; %-style args defer formatting.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    debug_enabled = config.get('debug', False)
    return JSONResponse(
        status_code=500,
        content={
            "status": "error",
            "message": "服务器内部错误",
            "details": str(exc) if debug_enabled else None
        }
    )
# 注册路由
app.include_router(data_router, prefix="/api", tags=["数据处理"])
app.include_router(model_router, prefix="/api", tags=["模型管理"])
app.include_router(system_router, prefix="/api", tags=["系统监控"])
app.include_router(
data_router,
prefix="/data",
tags=["数据处理"],
responses={404: {"description": "Not found"}},
)
app.include_router(
model_router,
prefix="/model",
tags=["模型管理"],
responses={404: {"description": "Not found"}},
)
app.include_router(
system_router,
prefix="/system",
tags=["系统监控"],
responses={404: {"description": "Not found"}},
)
# 健康检查
@app.get("/health")
async def health_check():
return {"status": "healthy", "timestamp": datetime.now().isoformat()}
"""系统健康检查"""
return {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"version": app.version,
"environment": config.get('environment', 'production')
}
# 启动事件
@app.on_event("startup")
async def startup_event():
    """Initialise the service when the application starts.

    Logs start-up progress and makes sure the raw and processed dataset
    directories and the ``.log`` directory exist.
    """
    logger.info("Server starting up...")
    # The dataset directories are nested, hence parents=True.
    for dataset_dir in ("dataset/dataset_raw", "dataset/dataset_processed"):
        Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")
# 关闭事件
@app.on_event("shutdown")
async def shutdown_event():
    """Run clean-up when the application shuts down.

    Currently this only logs that the server is shutting down.
    """
    shutdown_message = "Server shutting down..."
    logger.info(shutdown_message)
if __name__ == "__main__":
uvicorn.run(
"main:app",
host=config.get('host', '0.0.0.0'),
port=config.get('port', 8000),
reload=True,
workers=config.get('workers', 4)
port=config.get('port', 8992),
reload=config.get('debug', True),
workers=config.get('workers', 4),
log_level=config.get('log_level', 'info'),
access_log=True
)