修改--修改fastapi中的方法
This commit is contained in:
parent
e81b2e96d2
commit
d13ce6c21a
Binary file not shown.
@ -6,127 +6,127 @@ manager = ModelManager()
|
||||
|
||||
|
||||
|
||||
# 获取所有预处理方法
|
||||
print("--------------------------------------------获取预处理方法---------------------------------------------------")
|
||||
methods = manager.get_models()
|
||||
print("模型列表:")
|
||||
print(methods)
|
||||
print("--------------------------------------------获取预处理方法 end ---------------------------------------------------")
|
||||
# # 获取所有预处理方法
|
||||
# print("--------------------------------------------获取预处理方法---------------------------------------------------")
|
||||
# methods = manager.get_models()
|
||||
# print("模型列表:")
|
||||
# print(methods)
|
||||
# print("--------------------------------------------获取预处理方法 end ---------------------------------------------------")
|
||||
|
||||
|
||||
print("--------------------------------------------获取方法详细信息---------------------------------------------------")
# Fetch and display the detailed description of one specific method.
# Only a single lookup is performed so the printed header always matches
# the details that follow it.
method_details = manager.get_model_details('LinearRegression')
print("\nLinearRegression方法详情:")
print(method_details)
print("--------------------------------------------获取方法详细信息 end ---------------------------------------------------")
|
||||
|
||||
|
||||
|
||||
print("--------------------------------------------评价指标 ---------------------------------------------------")
|
||||
# 获取评价指标
|
||||
print(manager.get_metrics())
|
||||
print("--------------------------------------------评价指标 end ---------------------------------------------------")
|
||||
# print("--------------------------------------------评价指标 ---------------------------------------------------")
|
||||
# # 获取评价指标
|
||||
# print(manager.get_metrics())
|
||||
# print("--------------------------------------------评价指标 end ---------------------------------------------------")
|
||||
|
||||
|
||||
print("--------------------------------------------获取所有已训练模型 ---------------------------------------------------")
|
||||
# 获取所有已训练模型
|
||||
result = manager.get_finished_models(
|
||||
page=1,
|
||||
page_size=10,
|
||||
experiment_name='breast_cancer_classification_3'
|
||||
)
|
||||
# print("--------------------------------------------获取所有已训练模型 ---------------------------------------------------")
|
||||
# # 获取所有已训练模型
|
||||
# result = manager.get_finished_models(
|
||||
# page=1,
|
||||
# page_size=10,
|
||||
# experiment_name='breast_cancer_classification_3'
|
||||
# )
|
||||
|
||||
# 打印结果
|
||||
print("\n已训练模型列表:")
|
||||
print(f"状态: {result['status']}")
|
||||
if result['status'] == 'success':
|
||||
print(f"\n总数: {result['total_count']}")
|
||||
print(f"当前页: {result['page']}")
|
||||
print(f"每页数量: {result['page_size']}")
|
||||
print("\n模型列表:")
|
||||
for model in result['models']:
|
||||
'''
|
||||
'run_id': run['run_id'],
|
||||
'experiment_id': run['experiment_id'],
|
||||
'''
|
||||
print(f"run_id", model['run_id'])
|
||||
print(f"experiment_id", model['experiment_id'])
|
||||
print(f"算法: {model['algorithm']}")
|
||||
print(f"任务类型: {model['task_type']}")
|
||||
print(f"数据集: {model['dataset']}")
|
||||
# # 打印结果
|
||||
# print("\n已训练模型列表:")
|
||||
# print(f"状态: {result['status']}")
|
||||
# if result['status'] == 'success':
|
||||
# print(f"\n总数: {result['total_count']}")
|
||||
# print(f"当前页: {result['page']}")
|
||||
# print(f"每页数量: {result['page_size']}")
|
||||
# print("\n模型列表:")
|
||||
# for model in result['models']:
|
||||
# '''
|
||||
# 'run_id': run['run_id'],
|
||||
# 'experiment_id': run['experiment_id'],
|
||||
# '''
|
||||
# print(f"run_id", model['run_id'])
|
||||
# print(f"experiment_id", model['experiment_id'])
|
||||
# print(f"算法: {model['algorithm']}")
|
||||
# print(f"任务类型: {model['task_type']}")
|
||||
# print(f"数据集: {model['dataset']}")
|
||||
|
||||
print(f"训练开始时间: {model['training_start_time']}")
|
||||
print(f"训练结束时间: {model['training_end_time']}")
|
||||
print("模型参数:")
|
||||
for k, v in model['parameters'].items():
|
||||
print(f" {k}: {v}")
|
||||
print("评估指标:")
|
||||
for metric_name, metric_value in model['metrics'].items():
|
||||
print(f" {metric_name}: {metric_value:.4f}")
|
||||
else:
|
||||
print(f"错误信息: {result['message']}")
|
||||
# print(f"训练开始时间: {model['training_start_time']}")
|
||||
# print(f"训练结束时间: {model['training_end_time']}")
|
||||
# print("模型参数:")
|
||||
# for k, v in model['parameters'].items():
|
||||
# print(f" {k}: {v}")
|
||||
# print("评估指标:")
|
||||
# for metric_name, metric_value in model['metrics'].items():
|
||||
# print(f" {metric_name}: {metric_value:.4f}")
|
||||
# else:
|
||||
# print(f"错误信息: {result['message']}")
|
||||
|
||||
print("--------------------------------------------获取所有已训练模型 end ---------------------------------------------------")
|
||||
# print("--------------------------------------------获取所有已训练模型 end ---------------------------------------------------")
|
||||
|
||||
|
||||
print("--------------------------------------------模型训练---------------------------------------------------")
|
||||
# 加载数据
|
||||
train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv')
|
||||
val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv')
|
||||
# print("--------------------------------------------模型训练---------------------------------------------------")
|
||||
# # 加载数据
|
||||
# train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv')
|
||||
# val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv')
|
||||
|
||||
# 准备特征和标签
|
||||
X_train = train_data.drop('target', axis=1)
|
||||
y_train = train_data['target']
|
||||
X_val = val_data.drop('target', axis=1)
|
||||
y_val = val_data['target']
|
||||
# # 准备特征和标签
|
||||
# X_train = train_data.drop('target', axis=1)
|
||||
# y_train = train_data['target']
|
||||
# X_val = val_data.drop('target', axis=1)
|
||||
# y_val = val_data['target']
|
||||
|
||||
# 模型配置
|
||||
model_config = {
|
||||
'algorithm': 'XGBClassifier',
|
||||
'task_type': 'classification',
|
||||
'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
|
||||
'params': {
|
||||
'n_estimators': 100,
|
||||
'learning_rate': 0.1,
|
||||
'max_depth': 6,
|
||||
'random_state': 42
|
||||
}
|
||||
}
|
||||
# # 模型配置
|
||||
# model_config = {
|
||||
# 'algorithm': 'XGBClassifier',
|
||||
# 'task_type': 'classification',
|
||||
# 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
|
||||
# 'params': {
|
||||
# 'n_estimators': 100,
|
||||
# 'learning_rate': 0.1,
|
||||
# 'max_depth': 6,
|
||||
# 'random_state': 42
|
||||
# }
|
||||
# }
|
||||
|
||||
# 训练模型, 删除训练实验时要删除 mlruns/.trash/ 回收站里的文件
|
||||
# 模型文件 直接在 mlruns/文件夹下
|
||||
for i in range(3, 4):
|
||||
result = manager.train_model(
|
||||
{
|
||||
'features': X_train,
|
||||
'labels': y_train
|
||||
},
|
||||
{
|
||||
'features': X_val,
|
||||
'labels': y_val
|
||||
},
|
||||
model_config,
|
||||
f'breast_cancer_classification_{i}'
|
||||
)
|
||||
# # 训练模型, 删除训练实验时要删除 mlruns/.trash/ 回收站里的文件
|
||||
# # 模型文件 直接在 mlruns/文件夹下
|
||||
# for i in range(3, 4):
|
||||
# result = manager.train_model(
|
||||
# {
|
||||
# 'features': X_train,
|
||||
# 'labels': y_train
|
||||
# },
|
||||
# {
|
||||
# 'features': X_val,
|
||||
# 'labels': y_val
|
||||
# },
|
||||
# model_config,
|
||||
# f'breast_cancer_classification_{i}'
|
||||
# )
|
||||
|
||||
# 打印结果
|
||||
print("\n训练结果:")
|
||||
print(f"状态: {result['status']}")
|
||||
if result['status'] == 'success':
|
||||
print(f"\nMLflow运行ID: {result['run_id']}")
|
||||
print("\n评估指标:")
|
||||
for metric_name, metric_value in result['metrics'].items():
|
||||
print(f"{metric_name}: {metric_value:.4f}")
|
||||
else:
|
||||
print(f"错误信息: {result['message']}")
|
||||
# # 打印结果
|
||||
# print("\n训练结果:")
|
||||
# print(f"状态: {result['status']}")
|
||||
# if result['status'] == 'success':
|
||||
# print(f"\nMLflow运行ID: {result['run_id']}")
|
||||
# print("\n评估指标:")
|
||||
# for metric_name, metric_value in result['metrics'].items():
|
||||
# print(f"{metric_name}: {metric_value:.4f}")
|
||||
# else:
|
||||
# print(f"错误信息: {result['message']}")
|
||||
|
||||
print("-------------------------------------------模型训练 end ---------------------------------------------------")
|
||||
# print("-------------------------------------------模型训练 end ---------------------------------------------------")
|
||||
|
||||
print("--------------------------------------------模型预测 ---------------------------------------------------")
|
||||
print(manager.predict(run_id = "33939ea6d8ce4d43a268f23f7361651e",\
|
||||
data_path="/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_145614/test_breast_cancer_20250219_145614.csv",\
|
||||
output_path="predictions/pred_breast_cancer_20250219_145614.csv" ,\
|
||||
metrics= ["accuracy", "f1", "precision", "recall"] ))
|
||||
# print("--------------------------------------------模型预测 ---------------------------------------------------")
|
||||
# print(manager.predict(run_id = "33939ea6d8ce4d43a268f23f7361651e",\
|
||||
# data_path="/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_145614/test_breast_cancer_20250219_145614.csv",\
|
||||
# output_path="predictions/pred_breast_cancer_20250219_145614.csv" ,\
|
||||
# metrics= ["accuracy", "f1", "precision", "recall"] ))
|
||||
|
||||
print("-------------------------------------------模型预测 end ---------------------------------------------------")
|
||||
# print("-------------------------------------------模型预测 end ---------------------------------------------------")
|
||||
44
main.py
44
main.py
@ -4,6 +4,7 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from typing import Optional, Dict, List
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
from pathlib import Path
|
||||
import logging
|
||||
@ -16,13 +17,35 @@ from api.data_api import router as data_router
|
||||
from api.model_api import router as model_router
|
||||
from api.system_api import router as system_router
|
||||
|
||||
|
||||
# 设置watchfiles 日志级别为warning
|
||||
logging.getLogger("watchfiles").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
# 生命周期管理
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Replaces the former startup/shutdown event handlers: everything before
    ``yield`` runs while the context is entered, everything after it runs
    when the context is exited.
    """
    # Startup: make sure the working directories exist.
    logger.info("Server starting up...")
    for dataset_dir in ("dataset/dataset_raw", "dataset/dataset_processed"):
        # Dataset folders are nested, so missing parents are created too.
        Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")

    yield  # the application serves requests while suspended here

    # Shutdown: nothing to release yet, only record the event.
    logger.info("Server shutting down...")
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
    title="机器学习平台API",
    description="提供数据处理、模型训练和系统监控功能的API服务",
    version="1.0.0",
    docs_url="/docs",
    # Each keyword may be given only once and must be comma-separated.
    redoc_url="/redoc",
    # Startup/shutdown work is handled by the lifespan context manager.
    lifespan=lifespan,
)
|
||||
|
||||
# 加载配置
|
||||
@ -141,30 +164,17 @@ async def health_check():
|
||||
"environment": config.get('environment', 'production')
|
||||
}
|
||||
|
||||
# 启动事件
|
||||
@app.on_event("startup")
async def startup_event():
    """Create the required working directories when the service starts."""
    # NOTE(review): FastAPI documents ``on_event`` as deprecated and states
    # that these handlers are not called when a ``lifespan`` is passed to the
    # application -- confirm whether this handler is still needed.
    logger.info("Server starting up...")
    for dataset_dir in ("dataset/dataset_raw", "dataset/dataset_processed"):
        # Dataset folders are nested, so missing parents are created too.
        Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    Path(".log").mkdir(exist_ok=True)
    logger.info("Server started successfully")
|
||||
|
||||
# 关闭事件
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Record that the service is shutting down; no other cleanup is done."""
    # NOTE(review): FastAPI documents ``on_event`` as deprecated and states
    # that these handlers are not called when a ``lifespan`` is passed to the
    # application -- confirm whether this handler is still needed.
    logger.info("Server shutting down...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hot reload is driven by the ``debug`` setting.
    hot_reload = config.get('debug', True)
    uvicorn.run(
        "main:app",
        host=config.get('host', '0.0.0.0'),
        port=config.get('port', 8992),
        reload=hot_reload,
        # Uvicorn documents ``workers`` as not valid together with ``reload``,
        # so only request multiple workers when hot reload is off.
        workers=1 if hot_reload else config.get('workers', 4),
        # A keyword may be passed only once; 'warning' is the current default.
        log_level=config.get('log_level', 'warning'),
        access_log=True,
    )
|
||||
@ -159,6 +159,7 @@ classification_algorithms:
|
||||
- "自然语言处理。"
|
||||
|
||||
regression_algorithms:
|
||||
|
||||
LinearRegression:
|
||||
principle: "线性回归通过最小化数据点与回归线之间的误差平方和,来拟合一条最佳的直线。"
|
||||
advantages:
|
||||
@ -310,15 +311,16 @@ clustering_algorithms:
|
||||
- "初始化仍然可能影响最终结果"
|
||||
applicable_scenarios: "适用于K均值聚类方法,并且希望改进初始中心选择的场景。"
|
||||
|
||||
HierarchicalKMeans:
|
||||
principle: "层次化K均值结合了层次聚类和K均值聚类的方法,逐步将样本合并到已有簇中,形成层次化结构。"
|
||||
advantages:
|
||||
- "能够自动确定簇的数量"
|
||||
- "生成的树状图有助于理解数据结构"
|
||||
disadvantages:
|
||||
- "计算量大,尤其是在数据量较大时"
|
||||
- "对噪声和离群点敏感"
|
||||
applicable_scenarios: "适用于不确定簇的数量且数据结构较复杂的情况。"
|
||||
# 没实现这个方法
|
||||
# HierarchicalKMeans:
|
||||
# principle: "层次化K均值结合了层次聚类和K均值聚类的方法,逐步将样本合并到已有簇中,形成层次化结构。"
|
||||
# advantages:
|
||||
# - "能够自动确定簇的数量"
|
||||
# - "生成的树状图有助于理解数据结构"
|
||||
# disadvantages:
|
||||
# - "计算量大,尤其是在数据量较大时"
|
||||
# - "对噪声和离群点敏感"
|
||||
# applicable_scenarios: "适用于不确定簇的数量且数据结构较复杂的情况。"
|
||||
|
||||
FCM:
|
||||
principle: "模糊C均值(FCM)允许每个数据点属于多个簇,基于隶属度来进行聚类。"
|
||||
|
||||
@ -202,23 +202,23 @@ classification_algorithms:
|
||||
|
||||
regression_algorithms:
|
||||
LinearRegression:
|
||||
parameters:
|
||||
- name: "fit_intercept"
|
||||
type: "bool"
|
||||
default: "True"
|
||||
description: "是否计算截距。"
|
||||
- name: "normalize"
|
||||
type: "bool"
|
||||
default: "False"
|
||||
description: "是否对数据进行归一化处理。(已弃用)"
|
||||
- name: "copy_X"
|
||||
type: "bool"
|
||||
default: "True"
|
||||
description: "是否复制输入数据。"
|
||||
- name: "n_jobs"
|
||||
type: "int"
|
||||
default: "None"
|
||||
description: "用于计算的并行作业数。"
|
||||
parameters:
|
||||
- name: "fit_intercept"
|
||||
type: "bool"
|
||||
default: "True"
|
||||
description: "是否计算截距。"
|
||||
- name: "normalize"
|
||||
type: "bool"
|
||||
default: "False"
|
||||
description: "是否对数据进行归一化处理。(已弃用)"
|
||||
- name: "copy_X"
|
||||
type: "bool"
|
||||
default: "True"
|
||||
description: "是否复制输入数据。"
|
||||
- name: "n_jobs"
|
||||
type: "int"
|
||||
default: "None"
|
||||
description: "用于计算的并行作业数。"
|
||||
|
||||
PolynomialRegression:
|
||||
parameters:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user