MLPlatform/function/model_trainer.py

298 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as pd
import numpy as np
from typing import Dict, List, Optional
import logging
from pathlib import Path
import datetime
import yaml
import mlflow
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
)
class ModelTrainer:
    """Train scikit-learn-style models and track every run in MLflow.

    Configuration is read from three YAML files under ``model/`` at
    construction time; logs go to a timestamped file under ``.log/``.
    """

    def __init__(self, config: Dict = None):
        """Initialize the trainer.

        Args:
            config: Optional settings dict; only ``'mlflow_uri'`` is read here.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        self._load_model_info()
        # Point MLflow at the tracking server (defaults to the in-house host).
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under '.log/'."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        # NOTE(review): a new handler is appended on every instantiation, so
        # creating several trainers duplicates log lines — confirm intended.
        file_handler = logging.FileHandler(
            log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _read_yaml(self, path: str, description: str) -> Dict:
        """Read one YAML config file, logging and re-raising on failure.

        Args:
            path: File path to read.
            description: Short name used in the error log message.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading {description}: {str(e)}")
            raise

    def _load_metrics(self):
        """Load the evaluation-metric configuration."""
        self.metrics_config = self._read_yaml('model/metrics.yaml', 'metrics config')

    def _load_parameters(self):
        """Load the model-parameter configuration."""
        self.parameter_config = self._read_yaml('model/parameter.yaml', 'parameter config')

    def _load_model_info(self):
        """Load the model-information configuration (algorithm descriptions)."""
        self.model_info = self._read_yaml('model/model.yaml', 'model info')

    def _get_model_class(self, algorithm_name: str):
        """Resolve an algorithm name to its estimator class.

        Imports are kept local so heavy optional dependencies (xgboost,
        lightgbm, catboost) are only loaded when a model is actually built.

        Raises:
            ValueError: If the algorithm name is unknown.
        """
        # Classification algorithms
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC, OneClassSVM
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import (
            RandomForestClassifier, GradientBoostingClassifier,
            AdaBoostClassifier, IsolationForest
        )
        from sklearn.naive_bayes import GaussianNB
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.neural_network import MLPClassifier
        import xgboost as xgb
        import lightgbm as lgb
        from catboost import CatBoostClassifier
        # Regression algorithms
        from sklearn.linear_model import (
            LinearRegression, Ridge, Lasso,
            ElasticNet
        )
        from sklearn.svm import SVR
        from sklearn.tree import DecisionTreeRegressor
        from sklearn.ensemble import (
            RandomForestRegressor, GradientBoostingRegressor,
            AdaBoostRegressor
        )
        from catboost import CatBoostRegressor
        from sklearn.neural_network import MLPRegressor
        # Clustering algorithms
        from sklearn.cluster import (
            KMeans, AgglomerativeClustering,
            DBSCAN, SpectralClustering
        )
        from sklearn.mixture import GaussianMixture
        algorithm_map = {
            # Classification
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'SVDD': OneClassSVM,  # SVDD is implemented via OneClassSVM
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': xgb.XGBClassifier,
            'AdaBoostClassifier': AdaBoostClassifier,
            'CatBoostClassifier': CatBoostClassifier,
            'LGBMClassifier': lgb.LGBMClassifier,
            'GaussianNB': GaussianNB,
            'KNeighborsClassifier': KNeighborsClassifier,
            'MLPClassifier': MLPClassifier,
            'GradientBoostingClassifier': GradientBoostingClassifier,
            # Regression
            'LinearRegression': LinearRegression,
            'Ridge': Ridge,
            'Lasso': Lasso,
            'ElasticNet': ElasticNet,
            'SVR': SVR,
            'DecisionTreeRegressor': DecisionTreeRegressor,
            'RandomForestRegressor': RandomForestRegressor,
            'XGBRegressor': xgb.XGBRegressor,
            'AdaBoostRegressor': AdaBoostRegressor,
            'CatBoostRegressor': CatBoostRegressor,
            'LGBMRegressor': lgb.LGBMRegressor,
            'MLPRegressor': MLPRegressor,
            # Clustering
            'KMeans': KMeans,
            'KMeansPlusPlus': KMeans,  # KMeans++ is KMeans with init='k-means++'
            'AgglomerativeClustering': AgglomerativeClustering,
            'DBSCAN': DBSCAN,
            'GaussianMixture': GaussianMixture,
            'SpectralClustering': SpectralClustering
        }
        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")
        return algorithm_map[algorithm_name]

    def _get_algorithm_info(self, algorithm_name: str) -> Dict:
        """Look up an algorithm's description in the model-info config.

        Raises:
            ValueError: If the algorithm appears in no category.
        """
        for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
            if algorithm_name in self.model_info.get(category, {}):
                return self.model_info[category][algorithm_name]
        raise ValueError(f"Algorithm {algorithm_name} not found in model info")

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train one model and log parameters, metrics and the model to MLflow.

        Args:
            train_data: Dict with 'features' and 'labels' for training.
            val_data: Dict with 'features' and 'labels' for validation.
            model_config: Dict with 'algorithm', 'task_type', 'dataset' and a
                'params' dict of estimator constructor arguments.
            experiment_name: MLflow experiment name.

        Returns:
            On success: {'status': 'success', 'run_id', 'metrics',
            'algorithm_info'}; on any failure: {'status': 'error', 'message'}.
        """
        try:
            # An MLflow experiment name cannot be reused while a soft-deleted
            # experiment holds it — fall back to a timestamped name instead.
            experiment = mlflow.get_experiment_by_name(experiment_name)
            if experiment and experiment.lifecycle_stage == 'deleted':
                timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
                new_experiment_name = f"{experiment_name}_{timestamp}"
                self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
                experiment_name = new_experiment_name
            mlflow.set_experiment(experiment_name)
            with mlflow.start_run() as run:
                # Basic run metadata.
                mlflow.log_param('algorithm', model_config['algorithm'])
                mlflow.log_param('task_type', model_config['task_type'])
                mlflow.log_param('dataset', model_config['dataset'])  # dataset path as provided
                # Hyper-parameters exactly as supplied by the caller.
                for param_name, param_value in model_config['params'].items():
                    mlflow.log_param(param_name, param_value)
                # Human-readable algorithm description from model.yaml.
                algorithm_info = self._get_algorithm_info(model_config['algorithm'])
                mlflow.log_param('principle', algorithm_info['principle'])
                mlflow.log_param('advantages', str(algorithm_info['advantages']))
                mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))
                # Work on a copy so the caller's dict is never mutated
                # (the previous version wrote 'init' back into
                # model_config['params'] for KMeansPlusPlus).
                params = dict(model_config['params'])
                if model_config['algorithm'] == 'KMeansPlusPlus':
                    # KMeans++ is plain KMeans with the 'k-means++' init scheme.
                    params['init'] = 'k-means++'
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**params)
                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])
                # Evaluate on the validation split.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type']
                )
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)
                # Persist the fitted model as an MLflow sklearn artifact.
                mlflow.sklearn.log_model(model, "model")
                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics,
                    'algorithm_info': algorithm_info
                }
        except Exception as e:
            # Top-level boundary: report failure as a structured result rather
            # than letting the exception escape to the caller.
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str
    ) -> Dict:
        """Compute evaluation metrics appropriate for the task type.

        Returns an empty dict for unrecognized task types.
        """
        metrics = {}
        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
            if len(np.unique(true_labels)) == 2:  # binary problem only
                # NOTE(review): AUC is computed from hard class predictions,
                # not probability scores — this understates AUC; consider
                # predict_proba where the estimator supports it.
                metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['rmse'] = np.sqrt(metrics['mse'])
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            if len(np.unique(predictions)) > 1:  # silhouette needs >= 2 clusters
                # NOTE(review): silhouette_score expects the feature matrix as
                # its first argument; here the 1-D ground-truth labels are
                # passed reshaped to (n, 1) instead — TODO confirm and pass
                # the validation features if that was the intent.
                metrics['silhouette'] = silhouette_score(
                    true_labels.reshape(-1, 1),
                    predictions
                )
        return metrics