import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
# Was "import numpy as pd" — numpy was imported twice, once under the pandas
# alias; pandas was clearly intended, and this preserves the `pd` name.
import pandas as pd
import yaml

import mlflow
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
from sklearn.metrics import adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score


class ModelTrainer:
    """Train scikit-learn-style models, evaluate on a validation split, and log runs to MLflow."""

    def __init__(self, config: Dict = None):
        """Initialize the trainer.

        Args:
            config: Optional settings dict. Recognized key: ``mlflow_uri``
                (tracking-server URI, defaults to ``http://10.0.0.202:5000``).

        Raises:
            Exception: propagated if ``model/metrics.yaml`` or
                ``model/parameter.yaml`` cannot be read.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        # Point MLflow at the configured (or default) tracking server.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/`` (dir created if missing)."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        # Guard against re-adding a handler when several ModelTrainer instances
        # share this module-level logger — otherwise every log record would be
        # written once per instance.
        if not self.logger.handlers:
            file_handler = logging.FileHandler(
                log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
            )
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _load_metrics(self):
        """Load the evaluation-metric configuration from ``model/metrics.yaml``."""
        try:
            with open('model/metrics.yaml', 'r', encoding='utf-8') as f:
                self.metrics_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading metrics config: {str(e)}")
            raise

    def _load_parameters(self):
        """Load the model-parameter configuration from ``model/parameter.yaml``."""
        try:
            with open('model/parameter.yaml', 'r', encoding='utf-8') as f:
                self.parameter_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading parameter config: {str(e)}")
            raise

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model, evaluate it on the validation split, and log to MLflow.

        Args:
            train_data: Training data; expects keys ``'features'`` and ``'labels'``.
            val_data: Validation data; expects keys ``'features'`` and ``'labels'``.
            model_config: Expects keys ``'algorithm'`` (name understood by
                :meth:`_get_model_class`), ``'params'`` (constructor kwargs) and
                ``'task_type'`` (see :meth:`_calculate_metrics`).
            experiment_name: MLflow experiment to record the run under.

        Returns:
            On success: ``{'status': 'success', 'run_id': ..., 'metrics': {...}}``.
            On failure: ``{'status': 'error', 'message': ...}`` — errors are
            caught and reported, never raised, so callers can branch on 'status'.
        """
        try:
            # Select (or create) the MLflow experiment for this run.
            mlflow.set_experiment(experiment_name)
            with mlflow.start_run() as run:
                # Record hyper-parameters.
                mlflow.log_params(model_config['params'])

                # Resolve the estimator class and instantiate it.
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**model_config['params'])

                # Fit on the training split.
                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])

                # Evaluate on the validation split.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type']
                )

                # Record each metric with MLflow.
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                # Persist the fitted model as a run artifact.
                mlflow.sklearn.log_model(model, "model")

                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics
                }
        except Exception as e:
            # Deliberate best-effort contract: report failure via the return
            # value instead of raising, so batch callers keep going.
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _get_model_class(self, algorithm_name: str):
        """Map an algorithm name to its estimator class.

        Imports are local so that optional back-ends (xgboost/lightgbm/catboost)
        are only required when this method is actually used.

        Raises:
            ValueError: if ``algorithm_name`` is not one of the supported names.
        """
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import RandomForestClassifier
        from xgboost import XGBClassifier
        from lightgbm import LGBMClassifier
        from catboost import CatBoostClassifier

        algorithm_map = {
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': XGBClassifier,
            'LGBMClassifier': LGBMClassifier,
            'CatBoostClassifier': CatBoostClassifier
        }

        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")

        return algorithm_map[algorithm_name]

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str
    ) -> Dict:
        """Compute evaluation metrics for the given task type.

        Args:
            true_labels: Ground-truth labels/values (1-D array-like).
            predictions: Model outputs aligned with ``true_labels``.
            task_type: One of ``'classification'``, ``'regression'``,
                ``'clustering'``; any other value yields an empty dict.

        Returns:
            Mapping of metric name to float value.
        """
        metrics = {}

        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            # Weighted averaging so multi-class problems are handled too.
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            # silhouette_score expects a 2-D feature matrix; here the 1-D labels
            # are reshaped to (n, 1) — NOTE(review): this scores cluster
            # separation in label space only; confirm that is intended.
            metrics['silhouette'] = silhouette_score(true_labels.reshape(-1, 1), predictions)

        return metrics