修改--将模型管理所有的方法都集成到一个文件中

This commit is contained in:
haotian 2025-02-21 11:34:54 +08:00
parent d222cfe968
commit c2b2e20a4a
7 changed files with 547 additions and 516 deletions

View File

@ -1,2 +1,37 @@
mlflow: # 服务器配置
uri: "http://localhost:5000" host: "10.0.0.202"
port: 8992
workers: 4
debug: true
# MLflow配置
mlflow_uri: "http://10.0.0.202:5000"
# 数据处理配置
dataset:
raw_dir: "dataset/dataset_raw"
processed_dir: "dataset/dataset_processed"
# 模型配置
model:
save_dir: "models"
batch_size: 32
num_workers: 4
# 系统监控配置
monitor:
log_dir: ".log"
resource_check_interval: 60 # 秒
cleanup_interval: 86400 # 24小时
max_log_days: 30 # 日志保留天数
# 安全配置
security:
secret_key: "your-secret-key"
token_expire_minutes: 1440 # 24小时
# 性能配置
performance:
max_concurrent_trains: 4 # 最大并发训练数
cache_size: 1024 # MB
timeout: 3600 # 秒

View File

@ -523,6 +523,9 @@ Error Response:
} }
``` ```
### 2.9 模型优化 -- 未实现
## 3. 系统监控 ## 3. 系统监控
### 3.1 获取资源使用情况 ### 3.1 获取资源使用情况
```http ```http
@ -681,7 +684,7 @@ Error Response:
} }
``` ```
### 3.3 获取系统中训练状态 ### 3.3 获取系统中训练状态 ---- 未完成, 等开发系统后台时再实现.
```http ```http
GET /api/train/status/{task_id} GET /api/train/status/{task_id}
@ -764,6 +767,76 @@ Error Response:
} }
``` ```
## 4. 系统后台整体实现
### 4.1 系统架构
```
MLPlatform/
├── api/ # API接口层
│ ├── __init__.py
│ ├── data_api.py # 数据处理相关接口
│ ├── model_api.py # 模型相关接口
│ └── system_api.py # 系统监控相关接口
├── function/ # 功能实现层
│ ├── data_processor.py # 数据处理类
│ ├── model_manager.py # 模型管理类
│ ├── model_trainer.py # 模型训练类
│ ├── system_monitor.py # 系统监控类
│ └── utils/ # 工具函数
├── config/ # 配置文件
│ └── config.yaml # 系统配置
├── dataset/ # 数据集
│ ├── dataset_raw/ # 原始数据
│ └── dataset_processed/ # 处理后数据
├── .log/ # 日志文件
├── doc/ # 文档
└── main.py # 主程序入口
```
### 4.2 技术栈
- FastAPI: Web框架
- MLflow: 模型管理和实验跟踪
- PyTorch/Scikit-learn: 机器学习框架
- Pydantic: 数据验证
- Uvicorn: ASGI服务器
### 4.3 主要功能
1. 异步任务处理
- 支持多个模型同时训练
- 后台任务状态监控
- 任务队列管理
2. 实时监控
- 系统资源监控
- 训练进度监控
- 日志实时查看
3. 错误处理
- 全局异常处理
- 错误日志记录
- 优雅降级策略
4. 安全性
- API认证授权
- 请求限流
- 参数验证
### 4.4 性能优化
1. 数据处理
- 数据流式处理
- 缓存机制
- 批量处理
2. 模型训练
- GPU利用优化
- 分布式训练支持
- 模型检查点
3. 系统监控
- 性能指标采集
- 资源使用预警
- 自动清理机制
## 附录A方法详细说明 ## 附录A方法详细说明
### A1. 数据预处理方法 ### A1. 数据预处理方法

View File

@ -1,79 +0,0 @@
import yaml
from typing import Dict, List
import os
import logging
from pathlib import Path
class MethodReader:
    """Reader for the evaluation-metric configuration file (model/metrics.yaml)."""

    # (config key, response entry name, description) per task category.
    _CATEGORIES = (
        ("classification", "classification_metrics", "分类方法评价指标"),
        ("regression", "regression_metrics", "回归方法评价指标"),
        ("clustering", "clustering_metrics", "聚类方法评价指标"),
    )

    def __init__(self):
        """Load the metric configuration on construction."""
        self.logger = logging.getLogger(__name__)
        self.method_config = self._load_metrics()

    def _load_metrics(self) -> Dict:
        """Load model/metrics.yaml.

        Returns:
            The parsed YAML mapping.

        Raises:
            FileNotFoundError: if the config file is missing.
            Exception: re-raised after logging for any other load error.
        """
        try:
            config_path = Path('model/metrics.yaml')
            if not config_path.exists():
                raise FileNotFoundError(f"Method config file not found at {config_path}")
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            self.logger.info("Successfully loaded method config")
            return config
        except Exception as e:
            self.logger.error(f"Error loading method config: {str(e)}")
            raise

    def get_metrics(self) -> Dict:
        """Return the available evaluation metrics grouped by task type.

        Returns:
            {"status": "success", "metric": [...]} on success, where each
            entry holds the category name, a description and the metric
            mapping; {"status": "error", "error": <message>} on failure.
        """
        try:
            metrics = []
            # One entry per category present in the config; empty or absent
            # categories are skipped (original behavior preserved).
            for key, name, description in self._CATEGORIES:
                category_metrics = self.method_config.get(key, {})
                if category_metrics:
                    metrics.append({
                        "name": name,
                        "description": description,
                        "metric": category_metrics,
                    })
            return {
                "status": "success",
                "metric": metrics
            }
        except Exception as e:
            # Bug fix: the original log message said "preprocessing methods"
            # although this method returns evaluation metrics.
            self.logger.error(f"Error getting metrics: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

View File

@ -1,135 +0,0 @@
import yaml
from typing import Dict, List
import os
import logging
from pathlib import Path
class MethodReader:
    """Reader for the algorithm (model/model.yaml) and parameter
    (model/parameter.yaml) configuration files."""

    # Algorithm categories searched, in order, by the lookup methods.
    _CATEGORIES = (
        'classification_algorithms',
        'regression_algorithms',
        'clustering_algorithms',
    )

    def __init__(self):
        """Load both configuration files on construction."""
        self.logger = logging.getLogger(__name__)
        self.method_config = self._load_model_config()
        self.parameter_config = self._load_parameter_config()

    def _load_yaml_config(self, path_str: str, label: str) -> Dict:
        """Load one YAML config file (shared by both loaders).

        Args:
            path_str: path of the YAML file to read.
            label: lowercase name used in log and error messages.

        Raises:
            FileNotFoundError: if the file is missing.
            Exception: re-raised after logging for any other load error.
        """
        try:
            config_path = Path(path_str)
            if not config_path.exists():
                raise FileNotFoundError(f"{label.capitalize()} config file not found at {config_path}")
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            self.logger.info(f"Successfully loaded {label} config")
            return config
        except Exception as e:
            self.logger.error(f"Error loading {label} config: {str(e)}")
            raise

    def _load_model_config(self) -> Dict:
        """Load model/model.yaml (algorithm descriptions)."""
        return self._load_yaml_config('model/model.yaml', 'method')

    def _load_parameter_config(self) -> Dict:
        """Load model/parameter.yaml (algorithm parameter specifications)."""
        return self._load_yaml_config('model/parameter.yaml', 'parameter')

    def _find_in(self, config: Dict, method_name: str):
        """Return method_name's entry from the first category containing it, else None."""
        for category in self._CATEGORIES:
            if method_name in config.get(category, {}):
                return config[category][method_name]
        return None

    def get_models(self) -> Dict:
        """Return the available algorithms grouped by task category.

        Returns:
            {"status": "success", "models": [...]} where each entry holds
            the category name, a description and the algorithm names, or
            {"status": "error", "error": <message>} on failure.
        """
        try:
            descriptions = {
                'classification_algorithms': '分类方法',
                'regression_algorithms': '回归方法',
                'clustering_algorithms': '聚类方法',
            }
            models = []
            # Empty or absent categories are skipped (original behavior).
            for key in self._CATEGORIES:
                names = list(self.method_config.get(key, {}).keys())
                if names:
                    models.append({
                        "name": key,
                        "description": descriptions[key],
                        "method": names,
                    })
            return {
                "status": "success",
                "models": models
            }
        except Exception as e:
            # Bug fix: the original log message said "preprocessing methods"
            # although this method lists models.
            self.logger.error(f"Error getting models: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    def get_model_details(self, method_name: str) -> Dict:
        """Return principle, pros/cons, scenarios and parameters for one algorithm.

        Args:
            method_name: the algorithm name to look up in both configs.

        Returns:
            {"status": "success", "method": {...}} combining description
            data from both config files, or {"status": "error", ...} when
            the name is missing from either config.
        """
        try:
            method_info = self._find_in(self.method_config, method_name)
            if method_info is None:
                raise ValueError(f"Method {method_name} not found in method config")
            parameter_info = self._find_in(self.parameter_config, method_name)
            if parameter_info is None:
                raise ValueError(f"Method {method_name} not found in parameter config")
            return {
                "status": "success",
                "method": {
                    "name": method_name,
                    "description": parameter_info.get('description', ''),
                    "principle": method_info.get('principle', ''),
                    "advantages": method_info.get('advantages', []),
                    "disadvantages": method_info.get('disadvantages', []),
                    "applicable_scenarios": method_info.get('applicable_scenarios', []),
                    "parameters": parameter_info.get('parameters', [])
                }
            }
        except Exception as e:
            self.logger.error(f"Error getting method details: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

View File

@ -18,6 +18,11 @@ from sklearn.metrics import (
import torch import torch
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
'''
模型管理整体集成
'''
class ModelManager: class ModelManager:
"""模型管理类""" """模型管理类"""
@ -27,12 +32,33 @@ class ModelManager:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self._setup_logging() self._setup_logging()
self._metrics_map() self._metrics_map()
self.method_config = self._load_metrics()
self.method_config = self._load_model_config()
self.parameter_config = self._load_parameter_config()
# 初始化MLflow客户端 # 初始化MLflow客户端
self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000') self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
mlflow.set_tracking_uri(self.mlflow_uri) mlflow.set_tracking_uri(self.mlflow_uri)
self.client = MlflowClient() self.client = MlflowClient()
def _load_metrics(self) -> Dict:
    """Load the evaluation-metric configuration from model/metrics.yaml.

    Raises:
        FileNotFoundError: if the file is missing.
        Exception: re-raised after logging for any other load error.
    """
    # NOTE(review): _load_model_config also reads metrics.yaml and merges it
    # into method_config — this loader looks redundant; confirm before removal.
    try:
        metrics_path = Path('model/metrics.yaml')
        if not metrics_path.exists():
            raise FileNotFoundError(f"Method config file not found at {metrics_path}")
        with open(metrics_path, 'r', encoding='utf-8') as fh:
            loaded = yaml.safe_load(fh)
        self.logger.info("Successfully loaded method config")
        return loaded
    except Exception as exc:
        self.logger.error(f"Error loading method config: {str(exc)}")
        raise
def _setup_logging(self): def _setup_logging(self):
"""设置日志""" """设置日志"""
log_dir = Path('.log') log_dir = Path('.log')
@ -47,6 +73,54 @@ class ModelManager:
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
self.logger.setLevel(logging.INFO) self.logger.setLevel(logging.INFO)
def _load_model_config(self) -> Dict:
    """Load model/model.yaml and model/metrics.yaml and merge them.

    Returns:
        A single dict holding both the algorithm descriptions and the
        metric definitions; metric keys win on collision (dict-unpacking
        order).

    Raises:
        FileNotFoundError: if either file is missing.
        Exception: re-raised after logging for any other load error.
    """
    try:
        model_path = Path('model/model.yaml')
        if not model_path.exists():
            raise FileNotFoundError(f"Method config file not found at {model_path}")
        with open(model_path, 'r', encoding='utf-8') as fh:
            model_cfg = yaml.safe_load(fh)
        self.logger.info("Successfully loaded model config")

        metrics_path = Path('model/metrics.yaml')
        if not metrics_path.exists():
            raise FileNotFoundError(f"Metrics config file not found at {metrics_path}")
        with open(metrics_path, 'r', encoding='utf-8') as fh:
            metrics_cfg = yaml.safe_load(fh)
        self.logger.info("Successfully loaded metrics config")

        return {**model_cfg, **metrics_cfg}
    except Exception as exc:
        self.logger.error(f"Error loading method or metric config: {str(exc)}")
        raise
def _load_parameter_config(self) -> Dict:
    """Load the hyper-parameter specifications from model/parameter.yaml.

    Raises:
        FileNotFoundError: if the file is missing.
        Exception: re-raised after logging for any other load error.
    """
    try:
        param_path = Path('model/parameter.yaml')
        if not param_path.exists():
            raise FileNotFoundError(f"Parameter config file not found at {param_path}")
        with open(param_path, 'r', encoding='utf-8') as fh:
            loaded = yaml.safe_load(fh)
        self.logger.info("Successfully loaded parameter config")
        return loaded
    except Exception as exc:
        self.logger.error(f"Error loading parameter config: {str(exc)}")
        raise
def _metrics_map(self): def _metrics_map(self):
self.metrics_map={ self.metrics_map={
'accuracy' : accuracy_score, 'accuracy' : accuracy_score,
@ -63,7 +137,99 @@ class ModelManager:
'completeness': completeness_score, 'completeness': completeness_score,
'silhouette' : silhouette_score 'silhouette' : silhouette_score
} }
def _get_algorithm_info(self, algorithm_name: str) -> Dict:
    """Return the config entry for algorithm_name, searching every category.

    Raises:
        ValueError: if the algorithm appears in no category.
    """
    categories = ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms')
    for cat in categories:
        bucket = self.method_config.get(cat, {})
        if algorithm_name in bucket:
            return bucket[algorithm_name]
    raise ValueError(f"Algorithm {algorithm_name} not found in model info")
def _get_model_class(self, algorithm_name: str):
    """Resolve an algorithm name to its estimator class.

    Imports are kept inside the method so the heavy ML libraries are only
    loaded when a model class is actually requested.

    Raises:
        ValueError: for names not present in the mapping.
    """
    from sklearn.linear_model import (
        LogisticRegression, LinearRegression, Ridge, Lasso, ElasticNet,
    )
    from sklearn.svm import SVC, SVR, OneClassSVM
    from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
    from sklearn.ensemble import (
        RandomForestClassifier, RandomForestRegressor,
        GradientBoostingClassifier, GradientBoostingRegressor,
        AdaBoostClassifier, AdaBoostRegressor, IsolationForest,
    )
    from sklearn.naive_bayes import GaussianNB
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neural_network import MLPClassifier, MLPRegressor
    from sklearn.cluster import (
        KMeans, AgglomerativeClustering, DBSCAN, SpectralClustering,
    )
    from sklearn.mixture import GaussianMixture
    import xgboost as xgb
    import lightgbm as lgb
    from catboost import CatBoostClassifier, CatBoostRegressor

    classifiers = {
        'LogisticRegression': LogisticRegression,
        'SVC': SVC,
        'SVDD': OneClassSVM,  # SVDD is approximated with OneClassSVM
        'DecisionTreeClassifier': DecisionTreeClassifier,
        'RandomForestClassifier': RandomForestClassifier,
        'XGBClassifier': xgb.XGBClassifier,
        'AdaBoostClassifier': AdaBoostClassifier,
        'CatBoostClassifier': CatBoostClassifier,
        'LGBMClassifier': lgb.LGBMClassifier,
        'GaussianNB': GaussianNB,
        'KNeighborsClassifier': KNeighborsClassifier,
        'MLPClassifier': MLPClassifier,
        'GradientBoostingClassifier': GradientBoostingClassifier,
    }
    regressors = {
        'LinearRegression': LinearRegression,
        'Ridge': Ridge,
        'Lasso': Lasso,
        'ElasticNet': ElasticNet,
        'SVR': SVR,
        'DecisionTreeRegressor': DecisionTreeRegressor,
        'RandomForestRegressor': RandomForestRegressor,
        'XGBRegressor': xgb.XGBRegressor,
        'AdaBoostRegressor': AdaBoostRegressor,
        'CatBoostRegressor': CatBoostRegressor,
        'LGBMRegressor': lgb.LGBMRegressor,
        'MLPRegressor': MLPRegressor,
    }
    clusterers = {
        'KMeans': KMeans,
        'KMeansPlusPlus': KMeans,  # KMeans++ = KMeans with init='k-means++'
        'AgglomerativeClustering': AgglomerativeClustering,
        'DBSCAN': DBSCAN,
        'GaussianMixture': GaussianMixture,
        'SpectralClustering': SpectralClustering,
    }
    algorithm_map = {**classifiers, **regressors, **clusterers}
    if algorithm_name not in algorithm_map:
        raise ValueError(f"Unknown algorithm: {algorithm_name}")
    return algorithm_map[algorithm_name]
def get_finished_models( def get_finished_models(
self, self,
page: int = 1, page: int = 1,
@ -440,7 +606,96 @@ class ModelManager:
'execution_time': f"{execution_time:.2f}s" 'execution_time': f"{execution_time:.2f}s"
} }
} }
def get_models(self) -> Dict:
    """Return the available algorithms grouped by task category.

    Returns:
        {"status": "success", "models": [...]} where each entry holds the
        category name, a description and the list of algorithm names; on
        failure {"status": "error", "error": <message>}.
    """
    try:
        categories = [
            ('classification_algorithms', '分类方法'),
            ('regression_algorithms', '回归方法'),
            ('clustering_algorithms', '聚类方法'),
        ]
        models = []
        # One entry per category present in the config; empty or absent
        # categories are skipped (original behavior preserved).
        for key, description in categories:
            names = list(self.method_config.get(key, {}).keys())
            if names:
                models.append({
                    "name": key,
                    "description": description,
                    "method": names,
                })
        return {
            "status": "success",
            "models": models
        }
    except Exception as e:
        # Bug fix: the original log message said "preprocessing methods"
        # although this method lists models.
        self.logger.error(f"Error getting models: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }
def get_model_details(self, method_name: str) -> Dict:
    """Return principle, pros/cons, scenarios and parameters for one algorithm.

    Args:
        method_name: the algorithm name to look up in both configs.

    Returns:
        {"status": "success", "method": {...}} combining data from the
        method and parameter configs, or {"status": "error", ...} when the
        name is missing from either.
    """
    categories = ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms')
    try:
        method_info = next(
            (self.method_config[cat][method_name]
             for cat in categories
             if method_name in self.method_config.get(cat, {})),
            None,
        )
        if method_info is None:
            raise ValueError(f"Method {method_name} not found in method config")
        parameter_info = next(
            (self.parameter_config[cat][method_name]
             for cat in categories
             if method_name in self.parameter_config.get(cat, {})),
            None,
        )
        if parameter_info is None:
            raise ValueError(f"Method {method_name} not found in parameter config")
        return {
            "status": "success",
            "method": {
                "name": method_name,
                "description": parameter_info.get('description', ''),
                "principle": method_info.get('principle', ''),
                "advantages": method_info.get('advantages', []),
                "disadvantages": method_info.get('disadvantages', []),
                "applicable_scenarios": method_info.get('applicable_scenarios', []),
                "parameters": parameter_info.get('parameters', [])
            }
        }
    except Exception as e:
        self.logger.error(f"Error getting method details: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }
# except Exception as e: # except Exception as e:
# error_msg = f"预测过程发生错误: {str(e)}" # error_msg = f"预测过程发生错误: {str(e)}"
# self.logger.error(error_msg) # self.logger.error(error_msg)
@ -451,4 +706,184 @@ class ModelManager:
# 'error_type': type(e).__name__, # 'error_type': type(e).__name__,
# 'error_message': str(e) # 'error_message': str(e)
# } # }
# } # }
def get_metrics(self) -> Dict:
    """Return the available evaluation metrics grouped by task type.

    Returns:
        {"status": "success", "metric": [...]} where each entry holds the
        category name, a description and the metric mapping; on failure
        {"status": "error", "error": <message>}.
    """
    try:
        categories = [
            ('classification', 'classification_metrics', '分类方法评价指标'),
            ('regression', 'regression_metrics', '回归方法评价指标'),
            ('clustering', 'clustering_metrics', '聚类方法评价指标'),
        ]
        metrics = []
        # One entry per category present in the config; empty or absent
        # categories are skipped (original behavior preserved).
        for key, name, description in categories:
            category_metrics = self.method_config.get(key, {})
            if category_metrics:
                metrics.append({
                    "name": name,
                    "description": description,
                    "metric": category_metrics,
                })
        return {
            "status": "success",
            "metric": metrics
        }
    except Exception as e:
        # Bug fix: the original log message said "preprocessing methods"
        # although this method returns evaluation metrics.
        self.logger.error(f"Error getting metrics: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }
def train_model(
    self,
    train_data: Dict,
    val_data: Dict,
    model_config: Dict,
    experiment_name: str
) -> Dict:
    """Train one model and record parameters, metrics and the artifact in MLflow.

    Args:
        train_data: dict with 'features' and 'labels' for training.
        val_data: dict with 'features' and 'labels' for validation.
        model_config: dict with 'algorithm', 'task_type', 'dataset' and 'params'.
        experiment_name: MLflow experiment to log the run under.

    Returns:
        On success {'status': 'success', 'run_id', 'metrics', 'algorithm_info'};
        on failure {'status': 'error', 'message': <error text>}.
    """
    try:
        # If an experiment of this name was soft-deleted in MLflow, logging
        # to it would fail — switch to a timestamped replacement name.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment and experiment.lifecycle_stage == 'deleted':
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            new_experiment_name = f"{experiment_name}_{timestamp}"
            self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
            experiment_name = new_experiment_name

        mlflow.set_experiment(experiment_name)
        with mlflow.start_run() as run:
            # Basic run facts.
            mlflow.log_param('algorithm', model_config['algorithm'])
            mlflow.log_param('task_type', model_config['task_type'])
            mlflow.log_param('dataset', model_config['dataset'])  # dataset path as given

            # Caller-supplied hyper-parameters.
            for key, value in model_config['params'].items():
                mlflow.log_param(key, value)

            # Static description of the algorithm from the config files.
            algorithm_info = self._get_algorithm_info(model_config['algorithm'])
            mlflow.log_param('principle', algorithm_info['principle'])
            mlflow.log_param('advantages', str(algorithm_info['advantages']))
            mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

            # KMeans++ is plain KMeans with a different initializer.
            # NOTE(review): this mutates the caller's params dict in place.
            if model_config['algorithm'] == 'KMeansPlusPlus':
                model_config['params']['init'] = 'k-means++'

            estimator_cls = self._get_model_class(model_config['algorithm'])
            estimator = estimator_cls(**model_config['params'])

            self.logger.info(f"Starting training {model_config['algorithm']}")
            estimator.fit(train_data['features'], train_data['labels'])

            # Evaluate on the validation split and log every metric.
            val_predictions = estimator.predict(val_data['features'])
            metrics = self._calculate_metrics(
                val_data['labels'],
                val_predictions,
                model_config['task_type']
            )
            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            mlflow.sklearn.log_model(estimator, "model")
            self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
            return {
                'status': 'success',
                'run_id': run.info.run_id,
                'metrics': metrics,
                'algorithm_info': algorithm_info
            }
    except Exception as e:
        error_msg = f"Error training model: {str(e)}"
        self.logger.error(error_msg)
        return {
            'status': 'error',
            'message': error_msg
        }
def _calculate_metrics(
    self,
    true_labels: np.ndarray,
    predictions: np.ndarray,
    task_type: str,
    features: np.ndarray = None
) -> Dict:
    """Compute evaluation metrics for one task type.

    Args:
        true_labels: ground-truth labels.
        predictions: model outputs for the same samples.
        task_type: 'classification' | 'regression' | 'clustering'.
        features: optional sample feature matrix; when given, the clustering
            silhouette score is computed on it (its correct input).

    Returns:
        Dict mapping metric name -> value; unknown task types yield {}.
    """
    metrics = {}
    if task_type == 'classification':
        metrics['accuracy'] = accuracy_score(true_labels, predictions)
        metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
        metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
        metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        if len(np.unique(true_labels)) == 2:  # binary classification only
            # NOTE(review): AUC from hard predictions is coarse; consider
            # passing probabilities (predict_proba) upstream.
            metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
    elif task_type == 'regression':
        metrics['mae'] = mean_absolute_error(true_labels, predictions)
        metrics['mse'] = mean_squared_error(true_labels, predictions)
        metrics['rmse'] = np.sqrt(metrics['mse'])
        metrics['r2'] = r2_score(true_labels, predictions)
        metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
    elif task_type == 'clustering':
        metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
        metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
        metrics['completeness'] = completeness_score(true_labels, predictions)
        if len(np.unique(predictions)) > 1:  # silhouette requires >= 2 clusters
            # Bug fix: silhouette_score expects the FEATURE matrix, not the
            # label vector. Use the features when supplied; otherwise keep
            # the old (incorrect) fallback for backward compatibility.
            sample_matrix = features if features is not None else true_labels.reshape(-1, 1)
            metrics['silhouette'] = silhouette_score(sample_matrix, predictions)
    return metrics

View File

@ -1,298 +0,0 @@
import numpy as pd
import numpy as np
from typing import Dict, List, Optional
import logging
from pathlib import Path
import datetime
import yaml
import mlflow
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
)
class ModelTrainer:
    """Trains scikit-learn-style models and records runs in MLflow."""

    # Algorithm categories searched by _get_algorithm_info.
    _CATEGORIES = (
        'classification_algorithms',
        'regression_algorithms',
        'clustering_algorithms',
    )

    def __init__(self, config: Dict = None):
        """Set up logging, load the YAML configs and point MLflow at the
        tracking server (config key 'mlflow_uri', default http://10.0.0.202:5000).
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        self._load_model_info()
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under .log/ at INFO level."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        file_handler = logging.FileHandler(
            log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _read_yaml(self, path: str, label: str) -> Dict:
        """Load one YAML file; log and re-raise on failure.

        Shared by the three loaders below (they were copy-pasted originally).
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading {label}: {str(e)}")
            raise

    def _load_metrics(self):
        """Load evaluation-metric definitions into self.metrics_config."""
        self.metrics_config = self._read_yaml('model/metrics.yaml', 'metrics config')

    def _load_parameters(self):
        """Load hyper-parameter specifications into self.parameter_config."""
        self.parameter_config = self._read_yaml('model/parameter.yaml', 'parameter config')

    def _load_model_info(self):
        """Load algorithm descriptions into self.model_info."""
        self.model_info = self._read_yaml('model/model.yaml', 'model info')

    def _get_model_class(self, algorithm_name: str):
        """Resolve an algorithm name to its estimator class.

        Imports are kept inside the method so the heavy ML libraries are
        only loaded when a model class is actually requested.

        Raises:
            ValueError: for names not present in the mapping.
        """
        from sklearn.linear_model import (
            LogisticRegression, LinearRegression, Ridge, Lasso, ElasticNet,
        )
        from sklearn.svm import SVC, SVR, OneClassSVM
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.ensemble import (
            RandomForestClassifier, RandomForestRegressor,
            GradientBoostingClassifier, GradientBoostingRegressor,
            AdaBoostClassifier, AdaBoostRegressor, IsolationForest,
        )
        from sklearn.naive_bayes import GaussianNB
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.neural_network import MLPClassifier, MLPRegressor
        from sklearn.cluster import (
            KMeans, AgglomerativeClustering, DBSCAN, SpectralClustering,
        )
        from sklearn.mixture import GaussianMixture
        import xgboost as xgb
        import lightgbm as lgb
        from catboost import CatBoostClassifier, CatBoostRegressor

        algorithm_map = {
            # classification
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'SVDD': OneClassSVM,  # SVDD is approximated with OneClassSVM
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': xgb.XGBClassifier,
            'AdaBoostClassifier': AdaBoostClassifier,
            'CatBoostClassifier': CatBoostClassifier,
            'LGBMClassifier': lgb.LGBMClassifier,
            'GaussianNB': GaussianNB,
            'KNeighborsClassifier': KNeighborsClassifier,
            'MLPClassifier': MLPClassifier,
            'GradientBoostingClassifier': GradientBoostingClassifier,
            # regression
            'LinearRegression': LinearRegression,
            'Ridge': Ridge,
            'Lasso': Lasso,
            'ElasticNet': ElasticNet,
            'SVR': SVR,
            'DecisionTreeRegressor': DecisionTreeRegressor,
            'RandomForestRegressor': RandomForestRegressor,
            'XGBRegressor': xgb.XGBRegressor,
            'AdaBoostRegressor': AdaBoostRegressor,
            'CatBoostRegressor': CatBoostRegressor,
            'LGBMRegressor': lgb.LGBMRegressor,
            'MLPRegressor': MLPRegressor,
            # clustering
            'KMeans': KMeans,
            'KMeansPlusPlus': KMeans,  # KMeans++ = KMeans with init='k-means++'
            'AgglomerativeClustering': AgglomerativeClustering,
            'DBSCAN': DBSCAN,
            'GaussianMixture': GaussianMixture,
            'SpectralClustering': SpectralClustering,
        }
        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")
        return algorithm_map[algorithm_name]

    def _get_algorithm_info(self, algorithm_name: str) -> Dict:
        """Return the config entry for algorithm_name, searching every category.

        Raises:
            ValueError: if the algorithm appears in no category.
        """
        for category in self._CATEGORIES:
            if algorithm_name in self.model_info.get(category, {}):
                return self.model_info[category][algorithm_name]
        raise ValueError(f"Algorithm {algorithm_name} not found in model info")

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train one model and record parameters, metrics and the artifact in MLflow.

        Args:
            train_data: dict with 'features' and 'labels' for training.
            val_data: dict with 'features' and 'labels' for validation.
            model_config: dict with 'algorithm', 'task_type', 'dataset' and 'params'.
            experiment_name: MLflow experiment to log the run under.

        Returns:
            On success {'status': 'success', 'run_id', 'metrics', 'algorithm_info'};
            on failure {'status': 'error', 'message': <error text>}.
        """
        try:
            # If an experiment of this name was soft-deleted in MLflow,
            # logging to it would fail — switch to a timestamped name.
            experiment = mlflow.get_experiment_by_name(experiment_name)
            if experiment and experiment.lifecycle_stage == 'deleted':
                timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
                new_experiment_name = f"{experiment_name}_{timestamp}"
                self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
                experiment_name = new_experiment_name

            mlflow.set_experiment(experiment_name)
            with mlflow.start_run() as run:
                # Basic run facts.
                mlflow.log_param('algorithm', model_config['algorithm'])
                mlflow.log_param('task_type', model_config['task_type'])
                mlflow.log_param('dataset', model_config['dataset'])  # dataset path as given

                # Caller-supplied hyper-parameters.
                for param_name, param_value in model_config['params'].items():
                    mlflow.log_param(param_name, param_value)

                # Static description of the algorithm from the config files.
                algorithm_info = self._get_algorithm_info(model_config['algorithm'])
                mlflow.log_param('principle', algorithm_info['principle'])
                mlflow.log_param('advantages', str(algorithm_info['advantages']))
                mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

                # KMeans++ is plain KMeans with a different initializer.
                # NOTE(review): this mutates the caller's params dict in place.
                if model_config['algorithm'] == 'KMeansPlusPlus':
                    model_config['params']['init'] = 'k-means++'

                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**model_config['params'])

                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])

                # Evaluate on the validation split; pass the features so
                # the clustering silhouette is computed correctly.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type'],
                    features=val_data['features']
                )
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                mlflow.sklearn.log_model(model, "model")
                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics,
                    'algorithm_info': algorithm_info
                }
        except Exception as e:
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str,
        features: np.ndarray = None
    ) -> Dict:
        """Compute evaluation metrics for one task type.

        Args:
            true_labels: ground-truth labels.
            predictions: model outputs for the same samples.
            task_type: 'classification' | 'regression' | 'clustering'.
            features: optional sample feature matrix; when given, the
                clustering silhouette score is computed on it (its correct
                input).

        Returns:
            Dict mapping metric name -> value; unknown task types yield {}.
        """
        metrics = {}
        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
            if len(np.unique(true_labels)) == 2:  # binary classification only
                # NOTE(review): AUC from hard predictions is coarse; consider
                # passing probabilities (predict_proba) upstream.
                metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['rmse'] = np.sqrt(metrics['mse'])
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            if len(np.unique(predictions)) > 1:  # silhouette requires >= 2 clusters
                # Bug fix: silhouette_score expects the FEATURE matrix, not
                # the label vector. Use the features when supplied; keep the
                # old (incorrect) fallback for backward compatibility.
                sample_matrix = features if features is not None else true_labels.reshape(-1, 1)
                metrics['silhouette'] = silhouette_score(sample_matrix, predictions)
        return metrics