# MLPlatform/function/model_trainer.py
import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

import mlflow
import numpy as np
import pandas as pd  # NOTE(review): was `import numpy as pd` — numpy aliased as `pd` was almost certainly a typo for pandas
import yaml
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
from sklearn.metrics import adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
class ModelTrainer:
    """Train scikit-learn-style models and track each run with MLflow.

    Loads metric and parameter configuration from YAML files at
    construction time, then trains/evaluates models on demand via
    :meth:`train_model`.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the model trainer.

        Args:
            config: Optional settings dict. Recognised key:
                'mlflow_uri' — MLflow tracking server URI
                (defaults to 'http://10.0.0.202:5000').

        Raises:
            Exception: Propagated from the YAML config loaders if either
                'model/metrics.yaml' or 'model/parameter.yaml' cannot be read.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        # Point MLflow at the tracking server before any experiment is created.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under .log/ and set INFO level."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        # Guard against stacking duplicate handlers: getLogger(__name__) returns
        # a shared logger, so constructing several trainers would otherwise
        # write every record once per instance.
        if not self.logger.handlers:
            file_handler = logging.FileHandler(
                log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
            )
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _load_metrics(self):
        """Load the evaluation-metrics configuration from model/metrics.yaml."""
        self.metrics_config = self._load_yaml('model/metrics.yaml', 'metrics')

    def _load_parameters(self):
        """Load the model-parameter configuration from model/parameter.yaml."""
        self.parameter_config = self._load_yaml('model/parameter.yaml', 'parameter')

    def _load_yaml(self, path: str, label: str) -> Dict:
        """Read a YAML file; log and re-raise on any failure.

        Args:
            path: Path of the YAML file to read.
            label: Short name used in the error log message.

        Returns:
            The parsed YAML content.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading {label} config: {str(e)}")
            raise

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model, evaluate it on the validation set, and log to MLflow.

        Args:
            train_data: Training data; must contain 'features' and 'labels'.
            val_data: Validation data; must contain 'features' and 'labels'.
            model_config: Must contain 'algorithm' (see _get_model_class),
                'params' (constructor kwargs) and 'task_type'
                ('classification' / 'regression' / 'clustering').
            experiment_name: MLflow experiment name.

        Returns:
            On success: {'status': 'success', 'run_id': ..., 'metrics': {...}}.
            On failure: {'status': 'error', 'message': ...} — errors are caught
            and reported, never raised to the caller.
        """
        try:
            mlflow.set_experiment(experiment_name)
            with mlflow.start_run() as run:
                mlflow.log_params(model_config['params'])
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**model_config['params'])
                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])
                # Evaluate on the held-out validation split.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type']
                )
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)
                # Persist the fitted model as an MLflow artifact named "model".
                mlflow.sklearn.log_model(model, "model")
                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics
                }
        except Exception as e:
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _get_model_class(self, algorithm_name: str):
        """Resolve an algorithm name to its estimator class.

        Args:
            algorithm_name: One of the keys of the map below.

        Returns:
            The estimator class (not an instance).

        Raises:
            ValueError: If the name is not a supported algorithm.
        """
        # Imports are deferred so the optional heavyweight libraries
        # (xgboost, lightgbm, catboost) are only loaded when actually needed.
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import RandomForestClassifier
        from xgboost import XGBClassifier
        from lightgbm import LGBMClassifier
        from catboost import CatBoostClassifier
        algorithm_map = {
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': XGBClassifier,
            'LGBMClassifier': LGBMClassifier,
            'CatBoostClassifier': CatBoostClassifier
        }
        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")
        return algorithm_map[algorithm_name]

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str
    ) -> Dict:
        """Compute evaluation metrics appropriate for the task type.

        Args:
            true_labels: Ground-truth labels/values.
            predictions: Model predictions aligned with true_labels.
            task_type: 'classification', 'regression' or 'clustering'.

        Returns:
            Mapping of metric name to value. NOTE(review): an unrecognised
            task_type silently yields an empty dict — consider raising instead.
        """
        metrics = {}
        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            # NOTE(review): silhouette_score expects the FEATURE matrix as its
            # first argument; passing reshaped true labels here looks wrong —
            # confirm against the caller and consider passing val features.
            metrics['silhouette'] = silhouette_score(true_labels.reshape(-1, 1), predictions)
        return metrics