298 lines
12 KiB
Python
298 lines
12 KiB
Python
import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

import mlflow
import numpy as np
import yaml
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
    adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
)
|
||
|
||
class ModelTrainer:
    """Train models and log runs, parameters and metrics to MLflow.

    Metric, hyper-parameter and algorithm-metadata configuration is loaded
    from YAML files under ``model/`` when the trainer is constructed.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the model trainer.

        Args:
            config: Optional settings dict; only ``mlflow_uri`` is currently
                read (defaults to ``http://10.0.0.202:5000``).

        Raises:
            Exception: Propagated from the YAML loaders when any of the
                ``model/*.yaml`` files cannot be read or parsed.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        self._load_model_info()

        # Point MLflow at the configured (or default) tracking server.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/`` and set INFO level."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)

        # Attach a handler only once: the module-level logger is shared, so
        # constructing several trainers would otherwise stack handlers and
        # duplicate every log record.
        if not self.logger.handlers:
            file_handler = logging.FileHandler(
                log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
            )
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _load_metrics(self):
        """Load the evaluation-metric config into ``self.metrics_config``."""
        try:
            with open('model/metrics.yaml', 'r', encoding='utf-8') as f:
                self.metrics_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading metrics config: {str(e)}")
            raise

    def _load_parameters(self):
        """Load the hyper-parameter config into ``self.parameter_config``."""
        try:
            with open('model/parameter.yaml', 'r', encoding='utf-8') as f:
                self.parameter_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading parameter config: {str(e)}")
            raise

    def _load_model_info(self):
        """Load algorithm metadata (principle/pros/cons) into ``self.model_info``."""
        try:
            with open('model/model.yaml', 'r', encoding='utf-8') as f:
                self.model_info = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading model info: {str(e)}")
            raise

    def _get_model_class(self, algorithm_name: str):
        """Map an algorithm name to its estimator class.

        Imports are kept local so that the heavyweight optional dependencies
        (xgboost, lightgbm, catboost) are only loaded when a model is built.

        Raises:
            ValueError: If ``algorithm_name`` is not a supported algorithm.
        """
        # Classification estimators
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC, OneClassSVM
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import (
            RandomForestClassifier, GradientBoostingClassifier,
            AdaBoostClassifier, IsolationForest
        )
        from sklearn.naive_bayes import GaussianNB
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.neural_network import MLPClassifier
        import xgboost as xgb
        import lightgbm as lgb
        from catboost import CatBoostClassifier

        # Regression estimators
        from sklearn.linear_model import (
            LinearRegression, Ridge, Lasso,
            ElasticNet
        )
        from sklearn.svm import SVR
        from sklearn.tree import DecisionTreeRegressor
        from sklearn.ensemble import (
            RandomForestRegressor, GradientBoostingRegressor,
            AdaBoostRegressor
        )
        from catboost import CatBoostRegressor
        from sklearn.neural_network import MLPRegressor

        # Clustering estimators
        from sklearn.cluster import (
            KMeans, AgglomerativeClustering,
            DBSCAN, SpectralClustering
        )
        from sklearn.mixture import GaussianMixture

        algorithm_map = {
            # Classification
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'SVDD': OneClassSVM,  # SVDD implemented via OneClassSVM
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': xgb.XGBClassifier,
            'AdaBoostClassifier': AdaBoostClassifier,
            'CatBoostClassifier': CatBoostClassifier,
            'LGBMClassifier': lgb.LGBMClassifier,
            'GaussianNB': GaussianNB,
            'KNeighborsClassifier': KNeighborsClassifier,
            'MLPClassifier': MLPClassifier,
            'GradientBoostingClassifier': GradientBoostingClassifier,

            # Regression
            'LinearRegression': LinearRegression,
            'Ridge': Ridge,
            'Lasso': Lasso,
            'ElasticNet': ElasticNet,
            'SVR': SVR,
            'DecisionTreeRegressor': DecisionTreeRegressor,
            'RandomForestRegressor': RandomForestRegressor,
            'XGBRegressor': xgb.XGBRegressor,
            'AdaBoostRegressor': AdaBoostRegressor,
            'CatBoostRegressor': CatBoostRegressor,
            'LGBMRegressor': lgb.LGBMRegressor,
            'MLPRegressor': MLPRegressor,

            # Clustering
            'KMeans': KMeans,
            'KMeansPlusPlus': KMeans,  # KMeans++ is KMeans with init='k-means++'
            'AgglomerativeClustering': AgglomerativeClustering,
            'DBSCAN': DBSCAN,
            'GaussianMixture': GaussianMixture,
            'SpectralClustering': SpectralClustering
        }

        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")

        return algorithm_map[algorithm_name]

    def _get_algorithm_info(self, algorithm_name: str) -> Dict:
        """Look up an algorithm's metadata entry in ``self.model_info``.

        Raises:
            ValueError: If the algorithm is in none of the known categories.
        """
        for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
            if algorithm_name in self.model_info.get(category, {}):
                return self.model_info[category][algorithm_name]
        raise ValueError(f"Algorithm {algorithm_name} not found in model info")

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model and record the run in MLflow.

        Args:
            train_data: Dict with ``features`` and ``labels`` for training.
            val_data: Dict with ``features`` and ``labels`` for validation.
            model_config: Dict with ``algorithm``, ``task_type``, ``dataset``
                and a ``params`` dict of estimator keyword arguments.
            experiment_name: MLflow experiment name.

        Returns:
            On success: ``{'status': 'success', 'run_id', 'metrics',
            'algorithm_info'}``; on failure: ``{'status': 'error',
            'message'}``.  Errors are reported via the result dict rather
            than raised (deliberate best-effort API).
        """
        try:
            # MLflow refuses to reuse the name of a deleted experiment, so
            # fall back to a fresh timestamped name in that case.
            experiment = mlflow.get_experiment_by_name(experiment_name)
            if experiment and experiment.lifecycle_stage == 'deleted':
                timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
                new_experiment_name = f"{experiment_name}_{timestamp}"
                self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
                experiment_name = new_experiment_name

            # Select (or create) the MLflow experiment.
            mlflow.set_experiment(experiment_name)

            with mlflow.start_run() as run:
                # Basic run information
                mlflow.log_param('algorithm', model_config['algorithm'])
                mlflow.log_param('task_type', model_config['task_type'])
                mlflow.log_param('dataset', model_config['dataset'])  # dataset path as-is

                # Model hyper-parameters
                for param_name, param_value in model_config['params'].items():
                    mlflow.log_param(param_name, param_value)

                # Algorithm metadata
                algorithm_info = self._get_algorithm_info(model_config['algorithm'])
                mlflow.log_param('principle', algorithm_info['principle'])
                mlflow.log_param('advantages', str(algorithm_info['advantages']))
                mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

                # Work on a copy so the caller's config dict is never mutated.
                params = dict(model_config['params'])
                if model_config['algorithm'] == 'KMeansPlusPlus':
                    # KMeans++ is plain KMeans with the k-means++ initializer.
                    params['init'] = 'k-means++'

                # Build the estimator.
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**params)

                # Fit on the training split (labels are ignored by
                # clustering estimators' fit signature).
                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])

                # Evaluate on the validation split; features are forwarded so
                # clustering silhouette is computed in the sample space.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type'],
                    features=val_data['features']
                )

                # Record the metrics.
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                # Persist the fitted model as a run artifact.
                mlflow.sklearn.log_model(model, "model")

                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics,
                    'algorithm_info': algorithm_info
                }

        except Exception as e:
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str,
        features: Optional[np.ndarray] = None
    ) -> Dict:
        """Compute evaluation metrics for the given task type.

        Args:
            true_labels: Ground-truth labels/targets.
            predictions: Model predictions (hard labels for classification).
            task_type: One of ``'classification'``, ``'regression'``,
                ``'clustering'``.
            features: Optional sample features; used for the clustering
                silhouette score, whose proper domain is the feature space.
                When omitted, the legacy fallback of using the labels as a
                1-D stand-in is kept for backward compatibility.

        Returns:
            Dict mapping metric name to value (empty for unknown task types).
        """
        metrics = {}

        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
            if len(np.unique(true_labels)) == 2:  # binary task only
                # NOTE(review): AUC from hard labels is coarse; prefer
                # predict_proba scores if callers can supply them.
                metrics['roc_auc'] = roc_auc_score(true_labels, predictions)

        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['rmse'] = np.sqrt(metrics['mse'])
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)

        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            if len(np.unique(predictions)) > 1:  # silhouette needs >= 2 clusters
                # Silhouette is defined over the sample feature space.
                sample_space = features if features is not None else true_labels.reshape(-1, 1)
                metrics['silhouette'] = silhouette_score(sample_space, predictions)

        return metrics