import mlflow
|
||
from mlflow.tracking import MlflowClient
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
import logging
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import yaml
|
||
import json
|
||
import time
|
||
import os
|
||
import numpy as np
|
||
from sklearn.metrics import (
|
||
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
|
||
mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
|
||
adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
|
||
)
|
||
import torch
|
||
from torch.utils.data import DataLoader, TensorDataset
|
||
|
||
|
||
'''
Integrated model management: configuration loading, MLflow experiment
tracking, model training, evaluation, and batch prediction utilities.
'''
class ModelManager:
    """Central manager for ML models: config loading, MLflow tracking, training and inference."""

    def __init__(self, config: Dict = None):
        """Initialize the model manager.

        Args:
            config: Optional settings dict. Recognised key: 'mlflow_uri'
                (defaults to 'http://10.0.0.202:5000').
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        # Build the metric-name -> scoring-function lookup table.
        self._metrics_map()

        # Merged model+metrics config, and the per-algorithm parameter schema.
        self.method_config = self._load_model_config()
        self.parameter_config = self._load_parameter_config()

        # Point MLflow at the configured tracking server and keep a client handle.
        self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient()
def _load_metrics(self) -> Dict:
|
||
"""加载模型评价指标配置文件"""
|
||
try:
|
||
config_path = Path('model/metrics.yaml')
|
||
if not config_path.exists():
|
||
raise FileNotFoundError(f"Metrics config file not found at {config_path}")
|
||
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = yaml.safe_load(f)
|
||
|
||
self.logger.info("Successfully loaded metrics config")
|
||
return config
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Error loading metrics config: {str(e)}")
|
||
raise
|
||
|
||
def _setup_logging(self):
|
||
"""设置日志"""
|
||
log_dir = Path('.log')
|
||
log_dir.mkdir(exist_ok=True)
|
||
|
||
file_handler = logging.FileHandler(
|
||
log_dir / f'model_manager_{datetime.now():%Y%m%d_%H%M%S}.log'
|
||
)
|
||
file_handler.setFormatter(
|
||
logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
)
|
||
self.logger.addHandler(file_handler)
|
||
self.logger.setLevel(logging.INFO)
|
||
|
||
def _load_model_config(self) -> Dict:
|
||
"""加载方法配置文件"""
|
||
try:
|
||
config_path = Path('model/model.yaml')
|
||
if not config_path.exists():
|
||
raise FileNotFoundError(f"Model config file not found at {config_path}")
|
||
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config_model = yaml.safe_load(f)
|
||
|
||
self.logger.info("Successfully loaded model config")
|
||
|
||
|
||
config_path = Path('model/metrics.yaml')
|
||
if not config_path.exists():
|
||
raise FileNotFoundError(f"Metrics config file not found at {config_path}")
|
||
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config_metric = yaml.safe_load(f)
|
||
|
||
self.logger.info("Successfully loaded metrics config")
|
||
|
||
|
||
config = {**config_model, **config_metric}
|
||
return config
|
||
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Error loading method or metric config: {str(e)}")
|
||
raise
|
||
|
||
def _load_parameter_config(self) -> Dict:
|
||
"""加载参数配置文件"""
|
||
try:
|
||
config_path = Path('model/parameter.yaml')
|
||
if not config_path.exists():
|
||
raise FileNotFoundError(f"Parameter config file not found at {config_path}")
|
||
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = yaml.safe_load(f)
|
||
|
||
self.logger.info("Successfully loaded parameter config")
|
||
return config
|
||
except Exception as e:
|
||
self.logger.error(f"Error loading parameter config: {str(e)}")
|
||
raise
|
||
|
||
|
||
def _metrics_map(self):
|
||
self.metrics_map={
|
||
'accuracy' : accuracy_score,
|
||
'precision' : precision_score,
|
||
'recall' : recall_score,
|
||
'f1' : f1_score,
|
||
'mae' : mean_absolute_error,
|
||
'mse' : mean_squared_error,
|
||
# 'rmse' : np.sqrt(mean_absolute_error), # 这里要特殊处理一下
|
||
'r2': r2_score,
|
||
'explained_variance' : explained_variance_score,
|
||
'adjusted_rand' : adjusted_rand_score,
|
||
'homogeneity' : homogeneity_score,
|
||
'completeness': completeness_score,
|
||
'silhouette' : silhouette_score
|
||
}
|
||
|
||
|
||
def _get_algorithm_info(self, algorithm_name: str) -> Dict:
|
||
"""获取算法信息"""
|
||
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||
if algorithm_name in self.method_config.get(category, {}):
|
||
return self.method_config[category][algorithm_name]
|
||
raise ValueError(f"Algorithm {algorithm_name} not found in model info")
|
||
|
||
|
||
def _get_model_class(self, algorithm_name: str):
|
||
"""获取模型类"""
|
||
# 分类算法
|
||
from sklearn.linear_model import LogisticRegression
|
||
from sklearn.svm import SVC, OneClassSVM
|
||
from sklearn.tree import DecisionTreeClassifier
|
||
from sklearn.ensemble import (
|
||
RandomForestClassifier, GradientBoostingClassifier,
|
||
AdaBoostClassifier, IsolationForest
|
||
)
|
||
from sklearn.naive_bayes import GaussianNB
|
||
from sklearn.neighbors import KNeighborsClassifier
|
||
from sklearn.neural_network import MLPClassifier
|
||
import xgboost as xgb
|
||
import lightgbm as lgb
|
||
from catboost import CatBoostClassifier
|
||
|
||
# 回归算法
|
||
from sklearn.linear_model import (
|
||
LinearRegression, Ridge, Lasso,
|
||
ElasticNet
|
||
)
|
||
from sklearn.svm import SVR
|
||
from sklearn.tree import DecisionTreeRegressor
|
||
from sklearn.ensemble import (
|
||
RandomForestRegressor, GradientBoostingRegressor,
|
||
AdaBoostRegressor
|
||
)
|
||
from catboost import CatBoostRegressor
|
||
from sklearn.neural_network import MLPRegressor
|
||
|
||
# 聚类算法
|
||
from sklearn.cluster import (
|
||
KMeans, AgglomerativeClustering,
|
||
DBSCAN, SpectralClustering
|
||
)
|
||
from sklearn.mixture import GaussianMixture
|
||
|
||
algorithm_map = {
|
||
# 分类算法
|
||
'LogisticRegression': LogisticRegression,
|
||
'SVC': SVC,
|
||
'SVDD': OneClassSVM, # SVDD使用OneClassSVM实现
|
||
'DecisionTreeClassifier': DecisionTreeClassifier,
|
||
'RandomForestClassifier': RandomForestClassifier,
|
||
'XGBClassifier': xgb.XGBClassifier,
|
||
'AdaBoostClassifier': AdaBoostClassifier,
|
||
'CatBoostClassifier': CatBoostClassifier,
|
||
'LGBMClassifier': lgb.LGBMClassifier,
|
||
'GaussianNB': GaussianNB,
|
||
'KNeighborsClassifier': KNeighborsClassifier,
|
||
'MLPClassifier': MLPClassifier,
|
||
'GradientBoostingClassifier': GradientBoostingClassifier,
|
||
|
||
# 回归算法
|
||
'LinearRegression': LinearRegression,
|
||
'Ridge': Ridge,
|
||
'Lasso': Lasso,
|
||
'ElasticNet': ElasticNet,
|
||
'SVR': SVR,
|
||
'DecisionTreeRegressor': DecisionTreeRegressor,
|
||
'RandomForestRegressor': RandomForestRegressor,
|
||
'XGBRegressor': xgb.XGBRegressor,
|
||
'AdaBoostRegressor': AdaBoostRegressor,
|
||
'CatBoostRegressor': CatBoostRegressor,
|
||
'LGBMRegressor': lgb.LGBMRegressor,
|
||
'MLPRegressor': MLPRegressor,
|
||
|
||
# 聚类算法
|
||
'KMeans': KMeans,
|
||
'KMeansPlusPlus': KMeans, # KMeans++使用KMeans实现,通过init参数控制
|
||
'AgglomerativeClustering': AgglomerativeClustering,
|
||
'DBSCAN': DBSCAN,
|
||
'GaussianMixture': GaussianMixture,
|
||
'SpectralClustering': SpectralClustering
|
||
}
|
||
|
||
if algorithm_name not in algorithm_map:
|
||
raise ValueError(f"Unknown algorithm: {algorithm_name}")
|
||
|
||
return algorithm_map[algorithm_name]
|
||
|
||
|
||
def get_finished_models(
|
||
self,
|
||
page: int = 1,
|
||
page_size: int = 10,
|
||
experiment_name: Optional[str] = None
|
||
) -> Dict:
|
||
"""
|
||
获取已训练完成的模型列表
|
||
|
||
Args:
|
||
page: 页码
|
||
page_size: 每页数量
|
||
experiment_name: 实验名称过滤
|
||
|
||
Returns:
|
||
模型列表信息
|
||
"""
|
||
try:
|
||
# 获取所有实验
|
||
if experiment_name:
|
||
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||
if experiment is None:
|
||
return {
|
||
'status': 'error',
|
||
'message': f'Experiment {experiment_name} not found'
|
||
}
|
||
experiments = [experiment]
|
||
else:
|
||
experiments = mlflow.search_experiments()
|
||
|
||
# 获取所有运行记录
|
||
all_runs = []
|
||
for exp in experiments:
|
||
runs = mlflow.search_runs(
|
||
experiment_ids=[exp.experiment_id],
|
||
filter_string="status = 'FINISHED'"
|
||
)
|
||
all_runs.extend(runs.to_dict('records'))
|
||
|
||
# 计算分页
|
||
total_count = len(all_runs)
|
||
start_idx = (page - 1) * page_size
|
||
end_idx = start_idx + page_size
|
||
page_runs = all_runs[start_idx:end_idx]
|
||
|
||
# 格式化模型信息
|
||
models = []
|
||
for run in page_runs:
|
||
print(run.keys())
|
||
# 收集所有参数
|
||
params = {}
|
||
metrics = {}
|
||
for key, value in run.items():
|
||
if key.startswith('params.'):
|
||
params[key.replace('params.', '')] = value
|
||
elif key.startswith('metrics.'):
|
||
metrics[key.replace('metrics.', '')] = value
|
||
|
||
# 转换时间为本地时间
|
||
start_time = pd.to_datetime(run['start_time']).tz_convert('Asia/Shanghai')
|
||
end_time = pd.to_datetime(run['end_time']).tz_convert('Asia/Shanghai')
|
||
|
||
# 构建模型信息
|
||
model_info = {
|
||
'run_id': run['run_id'],
|
||
'experiment_id': run['experiment_id'],
|
||
'algorithm': params['algorithm'], # 从配置或其他地方获取
|
||
'task_type': params['task_type'], # 从配置或其他地方获取
|
||
'dataset': params['dataset'], # 从配置或其他地方获取
|
||
'training_start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'), # 格式化为本地时间字符串
|
||
'training_end_time': end_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'metrics': metrics,
|
||
'parameters': {
|
||
k: v for k, v in params.items()
|
||
if k not in ['principle', 'advantages', 'disadvantages']
|
||
},
|
||
'algorithm_info': {
|
||
'principle': params.get('principle', ''),
|
||
'advantages': params.get('advantages', ''),
|
||
'disadvantages': params.get('disadvantages', '')
|
||
},
|
||
'mlflow_run_id': run['run_id'],
|
||
'model_path': f"models/{run['run_id']}"
|
||
}
|
||
models.append(model_info)
|
||
|
||
self.logger.info("获取已训练完成的模型列表")
|
||
|
||
return {
|
||
'status': 'success',
|
||
'models': models,
|
||
'total_count': total_count,
|
||
'page': page,
|
||
'page_size': page_size
|
||
}
|
||
|
||
except Exception as e:
|
||
error_msg = f"Error getting finished models: {str(e)}"
|
||
self.logger.error(error_msg)
|
||
return {
|
||
'status': 'error',
|
||
'message': error_msg
|
||
}
|
||
|
||
def get_experiments(
|
||
self,
|
||
page: int = 1,
|
||
page_size: int = 10,
|
||
include_deleted: bool = False
|
||
) -> Dict:
|
||
"""
|
||
获取MLflow中保存的实验列表
|
||
|
||
Args:
|
||
page: 页码
|
||
page_size: 每页数量
|
||
include_deleted: 是否包含已删除的实验
|
||
|
||
Returns:
|
||
实验列表信息
|
||
"""
|
||
try:
|
||
# 获取所有实验
|
||
experiments = mlflow.search_experiments()
|
||
|
||
# 过滤已删除的实验
|
||
if not include_deleted:
|
||
experiments = [exp for exp in experiments if exp.lifecycle_stage == 'active']
|
||
|
||
# 计算分页
|
||
total_count = len(experiments)
|
||
start_idx = (page - 1) * page_size
|
||
end_idx = start_idx + page_size
|
||
page_experiments = experiments[start_idx:end_idx]
|
||
|
||
# 格式化实验信息
|
||
experiment_list = []
|
||
for exp in page_experiments:
|
||
# 获取实验的运行次数
|
||
runs = mlflow.search_runs(
|
||
experiment_ids=[exp.experiment_id],
|
||
filter_string="status = 'FINISHED'"
|
||
)
|
||
|
||
# 获取最后更新时间并转换为本地时间
|
||
if len(runs) > 0:
|
||
last_update_time = pd.to_datetime(runs['end_time'].max())
|
||
else:
|
||
last_update_time = pd.to_datetime(exp.creation_time)
|
||
|
||
# 创建时间转换为本地时间
|
||
creation_time = pd.to_datetime(exp.creation_time)
|
||
|
||
experiment_info = {
|
||
'experiment_id': exp.experiment_id,
|
||
'name': exp.name,
|
||
'artifact_location': exp.artifact_location,
|
||
'lifecycle_stage': exp.lifecycle_stage,
|
||
'creation_time': creation_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'last_update_time': last_update_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'tags': exp.tags,
|
||
'runs_count': len(runs)
|
||
}
|
||
experiment_list.append(experiment_info)
|
||
|
||
self.logger.info("获取保存的实验列表")
|
||
return {
|
||
'status': 'success',
|
||
'experiments': experiment_list,
|
||
'total_count': total_count,
|
||
'page': page,
|
||
'page_size': page_size
|
||
}
|
||
|
||
except Exception as e:
|
||
error_msg = f"Error getting experiments: {str(e)}"
|
||
self.logger.error(error_msg)
|
||
return {
|
||
'status': 'error',
|
||
'message': error_msg
|
||
}
|
||
|
||
def delete_model(self, run_id: str) -> Dict:
|
||
"""
|
||
删除指定的训练好的模型
|
||
|
||
Args:
|
||
run_id: MLflow运行ID
|
||
|
||
Returns:
|
||
删除操作的结果信息
|
||
"""
|
||
try:
|
||
# 获取运行信息
|
||
run = self.client.get_run(run_id)
|
||
if not run:
|
||
return {
|
||
'status': 'error',
|
||
'message': f'未找到运行ID为 {run_id} 的模型'
|
||
}
|
||
|
||
# 获取实验ID
|
||
experiment_id = run.info.experiment_id
|
||
|
||
# 获取模型信息
|
||
run_data = mlflow.get_run(run_id)
|
||
model_name = run_data.data.params.get('algorithm', 'Unknown')
|
||
|
||
# 获取工件列表
|
||
artifacts = []
|
||
for artifact in self.client.list_artifacts(run_id):
|
||
artifacts.append(artifact.path)
|
||
|
||
# 删除运行记录
|
||
self.client.delete_run(run_id)
|
||
|
||
# 记录日志
|
||
self.logger.info(f"已删除模型 - Run ID: {run_id}, 实验ID: {experiment_id}, 模型: {model_name}")
|
||
|
||
return {
|
||
'status': 'success',
|
||
'message': '模型删除成功',
|
||
'details': {
|
||
'run_id': run_id,
|
||
'experiment_id': experiment_id,
|
||
'model_name': model_name,
|
||
'deleted_artifacts': artifacts
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
error_msg = f"删除模型时发生错误: {str(e)}"
|
||
self.logger.error(error_msg)
|
||
return {
|
||
'status': 'error',
|
||
'message': error_msg
|
||
}
|
||
|
||
def predict(
|
||
self,
|
||
run_id: str,
|
||
data_path: str,
|
||
output_path: str,
|
||
batch_size: int = 32,
|
||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
|
||
return_proba: bool = True,
|
||
metrics: List[str] = None
|
||
) -> Dict:
|
||
"""
|
||
使用指定的模型进行预测
|
||
|
||
Args:
|
||
run_id: MLflow运行ID
|
||
data_path: 输入数据路径
|
||
output_path: 预测结果保存路径
|
||
batch_size: 批处理大小
|
||
device: 计算设备 ('cuda' or 'cpu')
|
||
return_proba: 是否返回概率预测
|
||
metrics: 评估指标列表
|
||
|
||
Returns:
|
||
预测结果信息
|
||
"""
|
||
try:
|
||
start_time = time.time()
|
||
|
||
# 获取模型信息
|
||
run = self.client.get_run(run_id)
|
||
if not run:
|
||
return {
|
||
'status': 'error',
|
||
'message': f'未找到运行ID为 {run_id} 的模型'
|
||
}
|
||
|
||
# 加载模型
|
||
model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
|
||
model_name = run.data.params.get('algorithm', 'Unknown')
|
||
|
||
# 加载数据
|
||
try:
|
||
data = pd.read_csv(data_path)
|
||
if 'label' in data.columns:
|
||
y_true = data.pop('label').values
|
||
has_labels = True
|
||
elif 'target' in data.columns:
|
||
y_true = data.pop('target').values
|
||
has_labels = True
|
||
else:
|
||
has_labels = False
|
||
X = data.values
|
||
except Exception as e:
|
||
return {
|
||
'status': 'error',
|
||
'message': '数据加载失败',
|
||
'details': {
|
||
'error_type': type(e).__name__,
|
||
'error_message': str(e)
|
||
}
|
||
}
|
||
|
||
# 创建预测ID
|
||
pred_id = f"pred_{datetime.now():%Y%m%d_%H%M%S}"
|
||
|
||
# 进行预测
|
||
if isinstance(model, torch.nn.Module):
|
||
# PyTorch模型预测
|
||
model.to(device)
|
||
model.eval()
|
||
|
||
dataset = TensorDataset(torch.FloatTensor(X))
|
||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||
|
||
predictions = []
|
||
probas = []
|
||
|
||
with torch.no_grad():
|
||
for batch in dataloader:
|
||
batch = batch[0].to(device)
|
||
outputs = model(batch)
|
||
|
||
if return_proba:
|
||
proba = torch.softmax(outputs, dim=1)
|
||
probas.append(proba.cpu().numpy())
|
||
|
||
preds = outputs.argmax(dim=1)
|
||
predictions.append(preds.cpu().numpy())
|
||
|
||
predictions = np.concatenate(predictions)
|
||
if return_proba:
|
||
probas = np.concatenate(probas)
|
||
else:
|
||
# 其他模型预测
|
||
predictions = model.predict(X)
|
||
if return_proba and hasattr(model, 'predict_proba'):
|
||
probas = model.predict_proba(X)
|
||
else:
|
||
probas = []
|
||
|
||
# 计算评估指标
|
||
metrics_results = {}
|
||
if has_labels and metrics:
|
||
for metric in metrics:
|
||
if metric in self.metrics_map.keys():
|
||
metrics_results[metric] = float(self.metrics_map[metric](y_true, predictions))
|
||
|
||
# 保存预测结果
|
||
results_df = pd.DataFrame({
|
||
'prediction': predictions
|
||
})
|
||
if return_proba and len(probas) > 0:
|
||
for i in range(probas.shape[1]):
|
||
results_df[f'probability_{i}'] = probas[:, i]
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
results_df.to_csv(output_path, index=False)
|
||
|
||
# 计算执行时间
|
||
execution_time = time.time() - start_time
|
||
|
||
# 记录日志
|
||
self.logger.info(
|
||
f"预测完成 - Run ID: {run_id}, 模型: {model_name}, "
|
||
f"样本数: {len(predictions)}, 耗时: {execution_time:.2f}s"
|
||
)
|
||
|
||
return {
|
||
'status': 'success',
|
||
'prediction': {
|
||
'id': pred_id,
|
||
'run_id': run_id,
|
||
'model_name': model_name,
|
||
'output_file': output_path,
|
||
'prediction_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
||
'samples_count': len(predictions),
|
||
'metrics': metrics_results,
|
||
'execution_time': f"{execution_time:.2f}s"
|
||
}
|
||
}
|
||
except Exception as e:
|
||
self.logger.error(f"Error model predict: {str(e)}")
|
||
return {
|
||
"status": "error",
|
||
"error": str(e)
|
||
}
|
||
|
||
|
||
|
||
def get_models(self) -> Dict:
|
||
"""获取机器/深度学习方法列表"""
|
||
try:
|
||
models = []
|
||
|
||
# 分类方法
|
||
classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys())
|
||
if classification_algorithms:
|
||
models.append({
|
||
"name": "classification_algorithms",
|
||
"description": "分类方法",
|
||
"method": classification_algorithms
|
||
})
|
||
|
||
# 回归方法
|
||
regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys())
|
||
if regression_algorithms:
|
||
models.append({
|
||
"name": "regression_algorithms",
|
||
"description": "回归方法",
|
||
"method": regression_algorithms
|
||
})
|
||
|
||
# 聚类方法
|
||
clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys())
|
||
if clustering_algorithms:
|
||
models.append({
|
||
"name": "clustering_algorithms",
|
||
"description": "聚类方法",
|
||
"method": clustering_algorithms
|
||
})
|
||
|
||
|
||
self.logger.info("获取机器/深度学习方法列表")
|
||
|
||
return {
|
||
"status": "success",
|
||
"models": models
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Error get models: {str(e)}")
|
||
return {
|
||
"status": "error",
|
||
"error": str(e)
|
||
}
|
||
|
||
def get_model_details(self, method_name: str) -> Dict:
|
||
"""获取指定算法的详细信息"""
|
||
try:
|
||
# 在各个方法类别中查找方法原理和优缺点
|
||
method_info = None
|
||
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||
if method_name in self.method_config.get(category, {}):
|
||
method_info = self.method_config[category][method_name]
|
||
break
|
||
|
||
if method_info is None:
|
||
raise ValueError(f"Method {method_name} not found in method config")
|
||
|
||
# 查找方法参数信息
|
||
parameter_info = None
|
||
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||
if method_name in self.parameter_config.get(category, {}):
|
||
parameter_info = self.parameter_config[category][method_name]
|
||
break
|
||
|
||
if parameter_info is None:
|
||
raise ValueError(f"Method {method_name} not found in parameter config")
|
||
|
||
self.logger.info(f"获取{method_name}算法的详细信息")
|
||
|
||
# 组合返回信息
|
||
return {
|
||
"status": "success",
|
||
"method": {
|
||
"name": method_name,
|
||
"description": parameter_info.get('description', ''),
|
||
"principle": method_info.get('principle', ''),
|
||
"advantages": method_info.get('advantages', []),
|
||
"disadvantages": method_info.get('disadvantages', []),
|
||
"applicable_scenarios": method_info.get('applicable_scenarios', []),
|
||
"parameters": parameter_info.get('parameters', [])
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Error getting method details: {str(e)}")
|
||
return {
|
||
"status": "error",
|
||
"error": str(e)
|
||
}
|
||
|
||
|
||
# except Exception as e:
|
||
# error_msg = f"预测过程发生错误: {str(e)}"
|
||
# self.logger.error(error_msg)
|
||
# return {
|
||
# 'status': 'error',
|
||
# 'message': '模型预测失败',
|
||
# 'details': {
|
||
# 'error_type': type(e).__name__,
|
||
# 'error_message': str(e)
|
||
# }
|
||
# }
|
||
|
||
def get_metrics(self) -> Dict:
|
||
"""获取评价指标列表"""
|
||
try:
|
||
metrics = []
|
||
|
||
# 分类方法
|
||
classification_metrics = self.method_config.get('classification', {})
|
||
if classification_metrics:
|
||
metrics.append({
|
||
"name": "classification_metrics",
|
||
"description": "分类方法评价指标",
|
||
"metric": classification_metrics
|
||
})
|
||
|
||
# 回归方法
|
||
regression_metrics = self.method_config.get('regression', {})
|
||
if regression_metrics:
|
||
metrics.append({
|
||
"name": "regression_metrics",
|
||
"description": "回归方法评价指标",
|
||
"metric": regression_metrics
|
||
})
|
||
|
||
# 聚类方法
|
||
clustering_metrics = self.method_config.get('clustering', {})
|
||
if clustering_metrics:
|
||
metrics.append({
|
||
"name": "clustering_metrics",
|
||
"description": "聚类方法评价指标",
|
||
"metric": clustering_metrics
|
||
})
|
||
|
||
self.logger.info("获取评价指标列表")
|
||
|
||
return {
|
||
"status": "success",
|
||
"metric": metrics
|
||
}
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Error get metrics: {str(e)}")
|
||
return {
|
||
"status": "error",
|
||
"error": str(e)
|
||
}
|
||
def train_model(
|
||
self,
|
||
train_data: Dict,
|
||
val_data: Dict,
|
||
model_config: Dict,
|
||
experiment_name: str
|
||
) -> Dict:
|
||
"""
|
||
训练模型
|
||
|
||
Args:
|
||
train_data: 训练数据,包含特征和标签
|
||
val_data: 验证数据,包含特征和标签
|
||
model_config: 模型配置,包含算法名称和参数
|
||
experiment_name: MLflow实验名称
|
||
|
||
Returns:
|
||
训练结果字典
|
||
"""
|
||
try:
|
||
# 检查实验是否存在且被删除
|
||
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||
if experiment and experiment.lifecycle_stage == 'deleted':
|
||
# 如果实验被删除,则创建一个新的实验名称
|
||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
new_experiment_name = f"{experiment_name}_{timestamp}"
|
||
self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
|
||
experiment_name = new_experiment_name
|
||
|
||
# 设置MLflow实验
|
||
mlflow.set_experiment(experiment_name)
|
||
|
||
with mlflow.start_run() as run:
|
||
# 记录基本信息
|
||
mlflow.log_param('algorithm', model_config['algorithm'])
|
||
mlflow.log_param('task_type', model_config['task_type'])
|
||
# mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称
|
||
mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径
|
||
|
||
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
# mlflow.log_param('start_time', timestamp)
|
||
|
||
# 记录模型参数
|
||
for param_name, param_value in model_config['params'].items():
|
||
mlflow.log_param(param_name, param_value)
|
||
|
||
# 记录算法信息
|
||
algorithm_info = self._get_algorithm_info(model_config['algorithm'])
|
||
mlflow.log_param('principle', algorithm_info['principle'])
|
||
mlflow.log_param('advantages', str(algorithm_info['advantages']))
|
||
mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))
|
||
|
||
# 特殊处理KMeans++
|
||
if model_config['algorithm'] == 'KMeansPlusPlus':
|
||
model_config['params']['init'] = 'k-means++'
|
||
|
||
# 获取模型类和信息
|
||
model_class = self._get_model_class(model_config['algorithm'])
|
||
|
||
# 创建模型实例
|
||
model = model_class(**model_config['params'])
|
||
|
||
# 训练模型
|
||
self.logger.info(f"Starting training {model_config['algorithm']}")
|
||
model.fit(train_data['features'], train_data['labels'])
|
||
|
||
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
# mlflow.log_param('end_time', timestamp)
|
||
|
||
# 在验证集上评估
|
||
val_predictions = model.predict(val_data['features'])
|
||
metrics = self._calculate_metrics(
|
||
val_data['labels'],
|
||
val_predictions,
|
||
model_config['task_type']
|
||
)
|
||
|
||
# 记录指标
|
||
for metric_name, metric_value in metrics.items():
|
||
mlflow.log_metric(metric_name, metric_value)
|
||
|
||
# 保存模型
|
||
mlflow.sklearn.log_model(model, "model")
|
||
|
||
self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
|
||
|
||
return {
|
||
'status': 'success',
|
||
'run_id': run.info.run_id,
|
||
'metrics': metrics,
|
||
'algorithm_info': algorithm_info
|
||
}
|
||
|
||
except Exception as e:
|
||
error_msg = f"Error training model: {str(e)}"
|
||
self.logger.error(error_msg)
|
||
return {
|
||
'status': 'error',
|
||
'message': error_msg
|
||
}
|
||
|
||
def _calculate_metrics(
|
||
self,
|
||
true_labels: np.ndarray,
|
||
predictions: np.ndarray,
|
||
task_type: str
|
||
) -> Dict:
|
||
"""计算评估指标"""
|
||
metrics = {}
|
||
|
||
if task_type == 'classification':
|
||
metrics['accuracy'] = accuracy_score(true_labels, predictions)
|
||
metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
|
||
metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
|
||
metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
|
||
if len(np.unique(true_labels)) == 2: # 二分类问题
|
||
metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
|
||
|
||
elif task_type == 'regression':
|
||
metrics['mae'] = mean_absolute_error(true_labels, predictions)
|
||
metrics['mse'] = mean_squared_error(true_labels, predictions)
|
||
metrics['rmse'] = np.sqrt(metrics['mse'])
|
||
metrics['r2'] = r2_score(true_labels, predictions)
|
||
metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
|
||
|
||
elif task_type == 'clustering':
|
||
metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
|
||
metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
|
||
metrics['completeness'] = completeness_score(true_labels, predictions)
|
||
if len(np.unique(predictions)) > 1: # 确保有多个簇
|
||
metrics['silhouette'] = silhouette_score(
|
||
true_labels.reshape(-1, 1),
|
||
predictions
|
||
)
|
||
|
||
return metrics |