"""Model training utilities: train sklearn-style models and log runs to MLflow."""
import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

# NOTE: removed erroneous duplicate `import numpy as pd` (numpy bound to the
# conventional pandas alias); numpy is imported as `np` below and `pd` is
# never referenced in this module.
import mlflow
import numpy as np
import yaml
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
)
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    explained_variance_score,
)
from sklearn.metrics import (
    adjusted_rand_score,
    homogeneity_score,
    completeness_score,
    silhouette_score,
)


class ModelTrainer:
    """Train scikit-learn-compatible models and track runs with MLflow.

    Metrics and model-parameter YAML configs are loaded at construction time,
    and every training run is logged (params, metrics, model artifact) to the
    MLflow tracking server configured via ``config['mlflow_uri']``.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the trainer.

        Args:
            config: Optional settings dict. Recognized keys:
                - 'mlflow_uri': MLflow tracking server URI
                  (default: 'http://10.0.0.202:5000').
                - 'metrics_config_path': metrics YAML path
                  (default: 'model/metrics.yaml').
                - 'parameter_config_path': parameter YAML path
                  (default: 'model/parameter.yaml').

        Raises:
            Exception: Re-raised from YAML config loading failures.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()

        # Point MLflow at the configured tracking server.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under '.log/' and set level INFO."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)

        # Guard against stacking duplicate file handlers (and duplicate log
        # lines) when several trainers share the same module-level logger.
        if not any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            file_handler = logging.FileHandler(
                log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
            )
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _load_yaml_config(self, path: str, description: str) -> Dict:
        """Load a YAML config file, logging and re-raising on any failure.

        Args:
            path: Filesystem path of the YAML file.
            description: Short tag used in the error log message
                ('metrics' or 'parameter').
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading {description} config: {str(e)}")
            raise

    def _load_metrics(self):
        """Load the evaluation-metrics configuration into self.metrics_config."""
        self.metrics_config = self._load_yaml_config(
            self.config.get('metrics_config_path', 'model/metrics.yaml'),
            'metrics',
        )

    def _load_parameters(self):
        """Load the model-parameter configuration into self.parameter_config."""
        self.parameter_config = self._load_yaml_config(
            self.config.get('parameter_config_path', 'model/parameter.yaml'),
            'parameter',
        )

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model and log params, metrics and the artifact to MLflow.

        Args:
            train_data: Training data; expects keys 'features' and 'labels'.
            val_data: Validation data; expects keys 'features' and 'labels'.
            model_config: Expects keys 'algorithm' (a class name understood by
                `_get_model_class`), 'params' (constructor kwargs) and
                'task_type' ('classification' | 'regression' | 'clustering').
            experiment_name: MLflow experiment to log the run under.

        Returns:
            On success: {'status': 'success', 'run_id': <str>, 'metrics': {...}}.
            On failure: {'status': 'error', 'message': <str>} — no exception
            propagates to the caller.
        """
        try:
            mlflow.set_experiment(experiment_name)

            with mlflow.start_run() as run:
                mlflow.log_params(model_config['params'])

                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**model_config['params'])

                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])

                # Evaluate on the held-out validation split.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type'],
                    features=val_data['features'],
                )

                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                # Persist the fitted model as an MLflow artifact.
                mlflow.sklearn.log_model(model, "model")

                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics
                }

        except Exception as e:
            # Deliberate catch-all: training failures are reported back as a
            # status dict rather than raised to the caller.
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _get_model_class(self, algorithm_name: str):
        """Resolve an algorithm name to its classifier class.

        Imports are local so heavy optional dependencies (xgboost, lightgbm,
        catboost) are only loaded when a model is actually requested.

        Raises:
            ValueError: If `algorithm_name` is not a supported algorithm.
        """
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import RandomForestClassifier
        from xgboost import XGBClassifier
        from lightgbm import LGBMClassifier
        from catboost import CatBoostClassifier

        algorithm_map = {
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': XGBClassifier,
            'LGBMClassifier': LGBMClassifier,
            'CatBoostClassifier': CatBoostClassifier
        }

        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")

        return algorithm_map[algorithm_name]

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str,
        features: Optional[np.ndarray] = None
    ) -> Dict:
        """Compute evaluation metrics for the given task type.

        Args:
            true_labels: Ground-truth labels/targets.
            predictions: Model predictions aligned with `true_labels`.
            task_type: One of 'classification', 'regression' or 'clustering'.
                Any other value yields an empty dict.
            features: Optional sample feature matrix, used for the silhouette
                score in clustering. When omitted, falls back to the legacy
                behavior of scoring on reshaped labels.

        Returns:
            Mapping of metric name to metric value.
        """
        metrics = {}

        if task_type == 'classification':
            # 'weighted' averaging so multi-class inputs work out of the box.
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')

        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)

        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            if features is not None:
                # silhouette_score expects the sample feature matrix.
                metrics['silhouette'] = silhouette_score(features, predictions)
            else:
                # Legacy fallback: scoring labels reshaped to (n, 1) is almost
                # certainly not meaningful — TODO confirm intent and remove.
                metrics['silhouette'] = silhouette_score(true_labels.reshape(-1, 1), predictions)

        return metrics