# MLPlatform/function/model_manager.py
# (936 lines, 35 KiB, Python)
#
# Note: this file intentionally contains non-ASCII Unicode characters (Chinese
# comments and user-facing messages) that might be confused with other
# characters.

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from typing import Dict, List, Optional
import logging
from pathlib import Path
from datetime import datetime
import yaml
import json
import time
import os
import numpy as np
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
)
import torch
from torch.utils.data import DataLoader, TensorDataset
# Integrated model management (模型管理整体集成): training, prediction, listing
# and deletion of models on top of an MLflow tracking server.
class ModelManager:
    """Model management facade over MLflow.

    Reads algorithm / metric / parameter descriptions from YAML files in the
    ``model`` directory (overridable with ``config['config_dir']``), trains
    estimators chosen by name, and logs runs, metrics and models to the MLflow
    tracking server given by ``config['mlflow_uri']``.
    """

    # Sections of model.yaml / parameter.yaml that hold per-algorithm entries.
    _ALGORITHM_CATEGORIES = (
        'classification_algorithms',
        'regression_algorithms',
        'clustering_algorithms',
    )

    # Descriptions returned by get_models() for each algorithm section.
    _ALGORITHM_DESCRIPTIONS = {
        'classification_algorithms': '分类方法',
        'regression_algorithms': '回归方法',
        'clustering_algorithms': '聚类方法',
    }

    # (metrics.yaml section, response name, response description) for get_metrics().
    _METRIC_SECTIONS = (
        ('classification', 'classification_metrics', '分类方法评价指标'),
        ('regression', 'regression_metrics', '回归方法评价指标'),
        ('clustering', 'clustering_metrics', '聚类方法评价指标'),
    )

    # Metrics evaluated as fn(features, cluster_labels) rather than fn(y_true, y_pred).
    _FEATURE_BASED_METRICS = frozenset({'silhouette'})

    # Timezone and format for every timestamp returned to callers.
    _LOCAL_TIMEZONE = 'Asia/Shanghai'
    _TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    # Algorithm name -> (module, class name). Resolved lazily in _get_model_class
    # so a missing optional dependency (xgboost / lightgbm / catboost) only
    # affects the algorithms that actually need it.
    _MODEL_REGISTRY = {
        # Classification
        'LogisticRegression': ('sklearn.linear_model', 'LogisticRegression'),
        'SVC': ('sklearn.svm', 'SVC'),
        'SVDD': ('sklearn.svm', 'OneClassSVM'),  # SVDD is implemented with OneClassSVM
        'DecisionTreeClassifier': ('sklearn.tree', 'DecisionTreeClassifier'),
        'RandomForestClassifier': ('sklearn.ensemble', 'RandomForestClassifier'),
        'XGBClassifier': ('xgboost', 'XGBClassifier'),
        'AdaBoostClassifier': ('sklearn.ensemble', 'AdaBoostClassifier'),
        'CatBoostClassifier': ('catboost', 'CatBoostClassifier'),
        'LGBMClassifier': ('lightgbm', 'LGBMClassifier'),
        'GaussianNB': ('sklearn.naive_bayes', 'GaussianNB'),
        'KNeighborsClassifier': ('sklearn.neighbors', 'KNeighborsClassifier'),
        'MLPClassifier': ('sklearn.neural_network', 'MLPClassifier'),
        'GradientBoostingClassifier': ('sklearn.ensemble', 'GradientBoostingClassifier'),
        # Regression
        'LinearRegression': ('sklearn.linear_model', 'LinearRegression'),
        'Ridge': ('sklearn.linear_model', 'Ridge'),
        'Lasso': ('sklearn.linear_model', 'Lasso'),
        'ElasticNet': ('sklearn.linear_model', 'ElasticNet'),
        'SVR': ('sklearn.svm', 'SVR'),
        'DecisionTreeRegressor': ('sklearn.tree', 'DecisionTreeRegressor'),
        'RandomForestRegressor': ('sklearn.ensemble', 'RandomForestRegressor'),
        'XGBRegressor': ('xgboost', 'XGBRegressor'),
        'AdaBoostRegressor': ('sklearn.ensemble', 'AdaBoostRegressor'),
        'CatBoostRegressor': ('catboost', 'CatBoostRegressor'),
        'LGBMRegressor': ('lightgbm', 'LGBMRegressor'),
        'MLPRegressor': ('sklearn.neural_network', 'MLPRegressor'),
        # Clustering
        'KMeans': ('sklearn.cluster', 'KMeans'),
        'KMeansPlusPlus': ('sklearn.cluster', 'KMeans'),  # KMeans with init='k-means++'
        'AgglomerativeClustering': ('sklearn.cluster', 'AgglomerativeClustering'),
        'DBSCAN': ('sklearn.cluster', 'DBSCAN'),
        'GaussianMixture': ('sklearn.mixture', 'GaussianMixture'),
        'SpectralClustering': ('sklearn.cluster', 'SpectralClustering'),
    }

    def __init__(self, config: Optional[Dict] = None):
        """Create the manager, load the YAML configs and connect to MLflow.

        Args:
            config: optional settings. Recognised keys: ``mlflow_uri``
                (tracking server, default ``http://10.0.0.202:5000``) and
                ``config_dir`` (directory with model.yaml, metrics.yaml and
                parameter.yaml, default ``model``).

        Raises:
            FileNotFoundError: if one of the YAML config files is missing.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._metrics_map()
        # Algorithm descriptions (model.yaml) merged with metric lists (metrics.yaml).
        self.method_config = self._load_model_config()
        self.parameter_config = self._load_parameter_config()
        # Point the MLflow fluent API and client at the tracking server.
        self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient()

    # ------------------------------------------------------------------
    # Configuration and logging
    # ------------------------------------------------------------------

    def _config_path(self, filename: str) -> Path:
        """Return the path of a YAML config file inside the config directory."""
        return Path(self.config.get('config_dir', 'model')) / filename

    def _read_yaml(self, filename: str, label: str) -> Dict:
        """Load one YAML config file; an empty file yields ``{}``.

        Args:
            filename: file name inside the config directory.
            label: capitalised name used in messages, e.g. ``'Model'``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        config_path = self._config_path(filename)
        if not config_path.exists():
            raise FileNotFoundError(f"{label} config file not found at {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            content = yaml.safe_load(f)
        self.logger.info(f"Successfully loaded {label.lower()} config")
        # safe_load returns None for an empty document; callers expect a mapping.
        return content or {}

    def _load_metrics(self) -> Dict:
        """Load the evaluation-metric config (metrics.yaml)."""
        try:
            return self._read_yaml('metrics.yaml', 'Metrics')
        except Exception as e:
            self.logger.error(f"Error loading metrics config: {str(e)}")
            raise

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/`` to the module logger.

        The logger returned by ``logging.getLogger(__name__)`` is shared by all
        instances, so the handler is attached only once; otherwise every new
        ModelManager would duplicate each log line and leak a file descriptor.
        """
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        self.logger.setLevel(logging.INFO)
        if any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            return
        file_handler = logging.FileHandler(
            log_dir / f'model_manager_{datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',  # log messages contain Chinese text
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)

    def _load_model_config(self) -> Dict:
        """Load model.yaml and metrics.yaml and merge them into one mapping.

        Keys from metrics.yaml win on collision.
        """
        try:
            config_model = self._read_yaml('model.yaml', 'Model')
            config_metric = self._read_yaml('metrics.yaml', 'Metrics')
            return {**config_model, **config_metric}
        except Exception as e:
            self.logger.error(f"Error loading method or metric config: {str(e)}")
            raise

    def _load_parameter_config(self) -> Dict:
        """Load the per-algorithm parameter descriptions (parameter.yaml)."""
        try:
            return self._read_yaml('parameter.yaml', 'Parameter')
        except Exception as e:
            self.logger.error(f"Error loading parameter config: {str(e)}")
            raise

    def _metrics_map(self):
        """Build ``self.metrics_map``: metric name -> scoring function.

        Label metrics are called as ``fn(y_true, y_pred)``. precision, recall
        and f1 use ``average='weighted'`` so that predict() reports the same
        definitions train_model() logs and so they work for multiclass or
        non-0/1 labels. ``silhouette`` is called as ``fn(features, labels)``.
        """
        from functools import partial

        def rmse(y_true, y_pred):
            return np.sqrt(mean_squared_error(y_true, y_pred))

        self.metrics_map = {
            'accuracy': accuracy_score,
            'precision': partial(precision_score, average='weighted'),
            'recall': partial(recall_score, average='weighted'),
            'f1': partial(f1_score, average='weighted'),
            'mae': mean_absolute_error,
            'mse': mean_squared_error,
            'rmse': rmse,
            'r2': r2_score,
            'explained_variance': explained_variance_score,
            'adjusted_rand': adjusted_rand_score,
            'homogeneity': homogeneity_score,
            'completeness': completeness_score,
            'silhouette': silhouette_score,
        }

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @classmethod
    def _find_algorithm_entry(cls, config: Optional[Dict], name: str) -> Optional[Dict]:
        """Return ``name``'s entry from the first algorithm section containing it.

        Sections that are absent or empty (``None`` in YAML) are skipped.
        Returns None when no section has the name.
        """
        for category in cls._ALGORITHM_CATEGORIES:
            entries = (config or {}).get(category) or {}
            if name in entries:
                return entries[name]
        return None

    @staticmethod
    def _page_bounds(page: int, page_size: int) -> tuple:
        """Return ``(start, end)`` slice indices for 1-based ``page``.

        Pages below 1 are treated as page 1 and a negative size as 0, so
        slicing never wraps around from the end of the list.
        """
        size = max(page_size, 0)
        start = (max(page, 1) - 1) * size
        return start, start + size

    @staticmethod
    def _is_missing(value) -> bool:
        """True for None and NaN (search_runs' placeholder for unlogged values)."""
        import math
        return value is None or (isinstance(value, float) and math.isnan(value))

    @classmethod
    def _format_local_time(cls, value) -> str:
        """Format a timestamp as local (Asia/Shanghai) time.

        Accepts epoch milliseconds (MLflow's ``creation_time`` convention),
        pandas/NumPy/datetime timestamps (naive values are taken as UTC) or
        strings. Missing values (None, NaN, NaT) give ``''``.
        """
        if value is None:
            return ''
        if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
            if pd.isna(value):
                return ''
            timestamp = pd.to_datetime(value, unit='ms', utc=True)
        else:
            timestamp = pd.to_datetime(value)
            if pd.isna(timestamp):
                return ''
            if timestamp.tzinfo is None:
                timestamp = timestamp.tz_localize('UTC')
        return timestamp.tz_convert(cls._LOCAL_TIMEZONE).strftime(cls._TIME_FORMAT)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create ``path``'s parent directory if it has one.

        A bare file name has no directory part, and ``os.makedirs('')`` would
        raise FileNotFoundError.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _get_algorithm_info(self, algorithm_name: str) -> Dict:
        """Return the model.yaml description (principle, pros, cons...) of an algorithm.

        Raises:
            ValueError: if no algorithm section contains ``algorithm_name``.
        """
        info = self._find_algorithm_entry(self.method_config, algorithm_name)
        if info is None:
            raise ValueError(f"Algorithm {algorithm_name} not found in model info")
        return info

    def _get_model_class(self, algorithm_name: str):
        """Return the estimator class registered for ``algorithm_name``.

        Only the module that provides the requested class is imported.

        Raises:
            ValueError: for an unregistered algorithm name.
            ImportError: if the providing library is not installed.
        """
        import importlib

        try:
            module_name, class_name = self._MODEL_REGISTRY[algorithm_name]
        except KeyError:
            raise ValueError(f"Unknown algorithm: {algorithm_name}") from None
        return getattr(importlib.import_module(module_name), class_name)

    # ------------------------------------------------------------------
    # Listing
    # ------------------------------------------------------------------

    def _format_run(self, run: Dict) -> Dict:
        """Convert one ``mlflow.search_runs`` record into the API model-info dict.

        ``run`` is a row of the search_runs DataFrame as a dict with flat keys
        such as ``run_id``, ``start_time``, ``params.<name>`` and
        ``metrics.<name>``. search_runs fills params/metrics a run never logged
        with NaN, so those entries are dropped.
        """
        params = {}
        metrics = {}
        for key, value in run.items():
            if self._is_missing(value):
                continue
            if key.startswith('params.'):
                params[key[len('params.'):]] = value
            elif key.startswith('metrics.'):
                metrics[key[len('metrics.'):]] = value

        info_keys = ('principle', 'advantages', 'disadvantages')
        return {
            'run_id': run['run_id'],
            'experiment_id': run['experiment_id'],
            'algorithm': params.get('algorithm', ''),
            'task_type': params.get('task_type', ''),
            # train_model() logs 'dataset_train'; 'dataset' is honoured for runs that logged it.
            'dataset': params.get('dataset', params.get('dataset_train', '')),
            'training_start_time': self._format_local_time(run.get('start_time')),
            'training_end_time': self._format_local_time(run.get('end_time')),
            'metrics': metrics,
            'parameters': {k: v for k, v in params.items() if k not in info_keys},
            'algorithm_info': {k: params.get(k, '') for k in info_keys},
            'mlflow_run_id': run['run_id'],
            'model_path': f"models/{run['run_id']}"
        }

    def get_finished_models(
        self,
        page: int = 1,
        page_size: int = 10,
        experiment_name: Optional[str] = None
    ) -> Dict:
        """List finished training runs, paginated.

        Args:
            page: 1-based page number.
            page_size: number of runs per page.
            experiment_name: restrict to this experiment; all active
                experiments otherwise.

        Returns:
            ``{'status': 'success', 'models', 'total_count', 'page',
            'page_size'}``, or ``{'status': 'error', 'message'}``.
        """
        try:
            if experiment_name:
                experiment = mlflow.get_experiment_by_name(experiment_name)
                if experiment is None:
                    return {
                        'status': 'error',
                        'message': f'Experiment {experiment_name} not found'
                    }
                experiments = [experiment]
            else:
                experiments = mlflow.search_experiments()

            all_runs = []
            for exp in experiments:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                all_runs.extend(runs.to_dict('records'))

            start_idx, end_idx = self._page_bounds(page, page_size)
            models = [self._format_run(run) for run in all_runs[start_idx:end_idx]]
            self.logger.info("获取已训练完成的模型列表")
            return {
                'status': 'success',
                'models': models,
                'total_count': len(all_runs),
                'page': page,
                'page_size': page_size
            }
        except Exception as e:
            error_msg = f"Error getting finished models: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def get_experiments(
        self,
        page: int = 1,
        page_size: int = 10,
        include_deleted: bool = False
    ) -> Dict:
        """List MLflow experiments, paginated.

        Args:
            page: 1-based page number.
            page_size: number of experiments per page.
            include_deleted: also return experiments whose lifecycle stage is
                ``deleted``.

        Returns:
            ``{'status': 'success', 'experiments', 'total_count', 'page',
            'page_size'}``, or ``{'status': 'error', 'message'}``. Times are
            local (Asia/Shanghai) strings; ``last_update_time`` is the latest
            finished run's end time, or the creation time if there is none.
        """
        from mlflow.entities import ViewType

        try:
            # search_experiments() defaults to ACTIVE_ONLY, so deleted ones must be requested.
            view_type = ViewType.ALL if include_deleted else ViewType.ACTIVE_ONLY
            experiments = mlflow.search_experiments(view_type=view_type)
            if not include_deleted:
                experiments = [exp for exp in experiments if exp.lifecycle_stage == 'active']

            start_idx, end_idx = self._page_bounds(page, page_size)
            experiment_list = []
            for exp in experiments[start_idx:end_idx]:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                # creation_time is epoch milliseconds; run end_time is a timestamp.
                last_update = runs['end_time'].max() if len(runs) > 0 else exp.creation_time
                experiment_list.append({
                    'experiment_id': exp.experiment_id,
                    'name': exp.name,
                    'artifact_location': exp.artifact_location,
                    'lifecycle_stage': exp.lifecycle_stage,
                    'creation_time': self._format_local_time(exp.creation_time),
                    'last_update_time': self._format_local_time(last_update),
                    'tags': exp.tags,
                    'runs_count': len(runs)
                })
            self.logger.info("获取保存的实验列表")
            return {
                'status': 'success',
                'experiments': experiment_list,
                'total_count': len(experiments),
                'page': page,
                'page_size': page_size
            }
        except Exception as e:
            error_msg = f"Error getting experiments: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def delete_model(self, run_id: str) -> Dict:
        """Delete a trained model's MLflow run.

        Args:
            run_id: MLflow run ID.

        Returns:
            On success, details with the experiment ID, algorithm name and the
            run's top-level artifact paths (``deleted_artifacts``).
            NOTE(review): MLflow's delete_run marks the run deleted; artifact
            files are presumably only purged by ``mlflow gc`` — confirm.
        """
        try:
            run = self.client.get_run(run_id)
            if not run:
                return {
                    'status': 'error',
                    'message': f'未找到运行ID为 {run_id} 的模型'
                }
            experiment_id = run.info.experiment_id
            model_name = run.data.params.get('algorithm', 'Unknown')
            artifacts = [artifact.path for artifact in self.client.list_artifacts(run_id)]

            self.client.delete_run(run_id)
            self.logger.info(f"已删除模型 - Run ID: {run_id}, 实验ID: {experiment_id}, 模型: {model_name}")
            return {
                'status': 'success',
                'message': '模型删除成功',
                'details': {
                    'run_id': run_id,
                    'experiment_id': experiment_id,
                    'model_name': model_name,
                    'deleted_artifacts': artifacts
                }
            }
        except Exception as e:
            error_msg = f"删除模型时发生错误: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def _load_model(self, run_id: str):
        """Load the ``model`` artifact of a run.

        train_model() logs with ``mlflow.sklearn``; loading that native flavor
        keeps ``predict_proba`` available. Models without the sklearn flavor
        fall back to the generic pyfunc wrapper.
        """
        model_uri = f"runs:/{run_id}/model"
        try:
            return mlflow.sklearn.load_model(model_uri)
        except mlflow.exceptions.MlflowException:
            return mlflow.pyfunc.load_model(model_uri)

    @staticmethod
    def _predict_torch(model, features, batch_size: int, device: str, return_proba: bool):
        """Batch inference for a ``torch.nn.Module`` classifier.

        Returns ``(predictions, probas)``: predictions are the per-row argmax
        of the model output along dim 1; probas is the softmax of the output
        when ``return_proba`` is set and ``[]`` otherwise.
        """
        model.to(device)
        model.eval()
        loader = DataLoader(TensorDataset(torch.FloatTensor(features)), batch_size=batch_size)
        prediction_batches = []
        proba_batches = []
        with torch.no_grad():
            for (batch,) in loader:
                outputs = model(batch.to(device))
                if return_proba:
                    proba_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
                prediction_batches.append(outputs.argmax(dim=1).cpu().numpy())
        predictions = np.concatenate(prediction_batches)
        probas = np.concatenate(proba_batches) if return_proba else []
        return predictions, probas

    def _evaluate_metrics(self, metric_names, features, y_true, predictions) -> Dict:
        """Evaluate the requested ``self.metrics_map`` metrics on predictions.

        Unknown names are logged and skipped. Label-based metrics are skipped
        when ``y_true`` is None. Feature-based metrics (silhouette) are scored
        on ``features`` and need 2..n_samples-1 distinct predicted labels.
        """
        results = {}
        for name in metric_names or []:
            metric_fn = self.metrics_map.get(name)
            if metric_fn is None:
                self.logger.warning(f"Unsupported metric ignored: {name}")
                continue
            if name in self._FEATURE_BASED_METRICS:
                n_clusters = len(np.unique(predictions))
                if not 1 < n_clusters < len(predictions):
                    self.logger.warning(f"Metric {name} skipped: {n_clusters} distinct predicted labels")
                    continue
                results[name] = float(metric_fn(features, predictions))
            elif y_true is not None:
                results[name] = float(metric_fn(y_true, predictions))
        return results

    def predict(
        self,
        run_id: str,
        data_path: str,
        output_path: str,
        batch_size: int = 32,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        return_proba: bool = True,
        metrics: Optional[List[str]] = None
    ) -> Dict:
        """Predict with a trained model on a CSV file and save the results.

        A ``label`` column (checked first) or ``target`` column, if present, is
        removed from the features and used as ground truth for ``metrics``.

        Args:
            run_id: MLflow run ID whose ``model`` artifact is used.
            data_path: input CSV path.
            output_path: output CSV path; gets a ``prediction`` column plus
                ``probability_<i>`` columns when probabilities are available.
                Its parent directory is created if needed.
            batch_size: batch size (PyTorch models only).
            device: 'cuda' or 'cpu' (PyTorch models only).
            return_proba: also write class probabilities when the model
                provides them.
            metrics: metric names from ``self.metrics_map``.

        Returns:
            ``{'status': 'success', 'prediction': {...}}``; on a data-loading
            failure ``{'status': 'error', 'message', 'details'}``; on any other
            failure ``{'status': 'error', 'error'}``.
        """
        try:
            start_time = time.time()
            run = self.client.get_run(run_id)
            if not run:
                return {
                    'status': 'error',
                    'message': f'未找到运行ID为 {run_id} 的模型'
                }
            model = self._load_model(run_id)
            model_name = run.data.params.get('algorithm', 'Unknown')

            try:
                data = pd.read_csv(data_path)
                y_true = None
                for label_column in ('label', 'target'):
                    if label_column in data.columns:
                        y_true = data.pop(label_column).values
                        break
                X = data.values
            except Exception as e:
                return {
                    'status': 'error',
                    'message': '数据加载失败',
                    'details': {
                        'error_type': type(e).__name__,
                        'error_message': str(e)
                    }
                }

            pred_id = f"pred_{datetime.now():%Y%m%d_%H%M%S}"

            if isinstance(model, torch.nn.Module):
                predictions, probas = self._predict_torch(model, X, batch_size, device, return_proba)
            else:
                predictions = model.predict(X)
                if return_proba and hasattr(model, 'predict_proba'):
                    probas = model.predict_proba(X)
                else:
                    probas = []

            metrics_results = self._evaluate_metrics(metrics, X, y_true, predictions)

            results_df = pd.DataFrame({'prediction': predictions})
            if return_proba and len(probas) > 0:
                probas = np.asarray(probas)
                for i in range(probas.shape[1]):
                    results_df[f'probability_{i}'] = probas[:, i]
            self._ensure_parent_dir(output_path)
            results_df.to_csv(output_path, index=False)

            execution_time = time.time() - start_time
            self.logger.info(
                f"预测完成 - Run ID: {run_id}, 模型: {model_name}, "
                f"样本数: {len(predictions)}, 耗时: {execution_time:.2f}s"
            )
            return {
                'status': 'success',
                'prediction': {
                    'id': pred_id,
                    'run_id': run_id,
                    'model_name': model_name,
                    'output_file': output_path,
                    'prediction_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                    'samples_count': len(predictions),
                    'metrics': metrics_results,
                    'execution_time': f"{execution_time:.2f}s"
                }
            }
        except Exception as e:
            self.logger.error(f"Error model predict: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    # ------------------------------------------------------------------
    # Catalogue (algorithms / metrics)
    # ------------------------------------------------------------------

    def get_models(self) -> Dict:
        """List the configured algorithm names grouped by task type.

        Sections that are missing or empty in model.yaml are omitted.
        """
        try:
            models = []
            for category in self._ALGORITHM_CATEGORIES:
                methods = list((self.method_config.get(category) or {}).keys())
                if methods:
                    models.append({
                        "name": category,
                        "description": self._ALGORITHM_DESCRIPTIONS[category],
                        "method": methods
                    })
            self.logger.info("获取机器/深度学习方法列表")
            return {
                "status": "success",
                "models": models
            }
        except Exception as e:
            self.logger.error(f"Error get models: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    def get_model_details(self, method_name: str) -> Dict:
        """Return an algorithm's description (model.yaml) and parameters (parameter.yaml).

        Returns ``{'status': 'error', 'error'}`` when the algorithm is missing
        from either file.
        """
        try:
            method_info = self._find_algorithm_entry(self.method_config, method_name)
            if method_info is None:
                raise ValueError(f"Method {method_name} not found in method config")
            parameter_info = self._find_algorithm_entry(self.parameter_config, method_name)
            if parameter_info is None:
                raise ValueError(f"Method {method_name} not found in parameter config")
            self.logger.info(f"获取{method_name}算法的详细信息")
            return {
                "status": "success",
                "method": {
                    "name": method_name,
                    "description": parameter_info.get('description', ''),
                    "principle": method_info.get('principle', ''),
                    "advantages": method_info.get('advantages', []),
                    "disadvantages": method_info.get('disadvantages', []),
                    "applicable_scenarios": method_info.get('applicable_scenarios', []),
                    "parameters": parameter_info.get('parameters', [])
                }
            }
        except Exception as e:
            self.logger.error(f"Error getting method details: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    def get_metrics(self) -> Dict:
        """List the evaluation metrics configured in metrics.yaml per task type.

        Empty or missing sections are omitted.
        """
        try:
            metrics = []
            for section, name, description in self._METRIC_SECTIONS:
                entries = self.method_config.get(section, {})
                if entries:
                    metrics.append({
                        "name": name,
                        "description": description,
                        "metric": entries
                    })
            self.logger.info("获取评价指标列表")
            return {
                "status": "success",
                "metric": metrics
            }
        except Exception as e:
            self.logger.error(f"Error get metrics: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    @staticmethod
    def _predict_validation(model, features):
        """Predict labels for the validation features.

        Transductive clusterers (DBSCAN, AgglomerativeClustering,
        SpectralClustering) have no ``predict``; for them a fresh clone with
        the same hyper-parameters is fitted on the validation set, leaving the
        train-fitted model that gets logged untouched.
        """
        if hasattr(model, 'predict'):
            return model.predict(features)
        from sklearn.base import clone
        return clone(model).fit_predict(features)

    @staticmethod
    def _positive_class_scores(model, features, task_type: str):
        """Return positive-class probabilities for binary classifiers, else None.

        Column 1 of ``predict_proba`` is used; this assumes the estimator
        orders ``classes_`` ascending (as scikit-learn does), matching the
        label that ``roc_auc_score`` treats as positive.
        """
        if task_type != 'classification' or not hasattr(model, 'predict_proba'):
            return None
        try:
            proba = np.asarray(model.predict_proba(features))
        except (AttributeError, NotImplementedError, ValueError):
            return None
        if proba.ndim != 2 or proba.shape[1] != 2:
            return None
        return proba[:, 1]

    def train_model(
        self,
        train_path: str,
        val_path: str,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model, evaluate it on validation data and log it to MLflow.

        Args:
            train_path: training CSV; must contain a ``target`` column.
            val_path: validation CSV; must contain a ``target`` column.
            model_config: ``{'algorithm': name, 'task_type':
                'classification' | 'regression' | 'clustering', 'params':
                {...}}``; ``params`` (optional) are estimator keyword
                arguments. The dict is not modified.
            experiment_name: MLflow experiment. If it exists but is deleted, a
                new experiment ``<name>_<timestamp>`` is used instead.

        Returns:
            ``{'status': 'success', 'run_id', 'metrics', 'algorithm_info'}``
            or ``{'status': 'error', 'message'}``.
        """
        try:
            # Validate inputs before touching MLflow so a bad request creates no experiment.
            if not os.path.exists(train_path):
                return {
                    'status': 'error',
                    'message': '找不到训练集路径'
                }
            if not os.path.exists(val_path):
                return {
                    'status': 'error',
                    'message': '找不到验证集路径'
                }

            experiment = mlflow.get_experiment_by_name(experiment_name)
            if experiment and experiment.lifecycle_stage == 'deleted':
                # A deleted experiment name cannot be reused; derive a fresh one.
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                new_experiment_name = f"{experiment_name}_{timestamp}"
                self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
                experiment_name = new_experiment_name
            mlflow.set_experiment(experiment_name)

            train_data = pd.read_csv(train_path)
            val_data = pd.read_csv(val_path)
            X_train = train_data.drop('target', axis=1)
            y_train = train_data['target']
            X_val = val_data.drop('target', axis=1)
            y_val = val_data['target']

            algorithm = model_config['algorithm']
            task_type = model_config['task_type']
            # Copy so the caller's config is never mutated; KMeans++ is KMeans with a fixed init.
            params = dict(model_config.get('params') or {})
            if algorithm == 'KMeansPlusPlus':
                params['init'] = 'k-means++'

            with mlflow.start_run() as run:
                mlflow.log_param('algorithm', algorithm)
                mlflow.log_param('task_type', task_type)
                mlflow.log_param('dataset_train', train_path)
                mlflow.log_param('dataset_val', val_path)
                # Logged after the KMeans++ override so the run records the real parameters.
                for param_name, param_value in params.items():
                    mlflow.log_param(param_name, param_value)

                algorithm_info = self._get_algorithm_info(algorithm)
                mlflow.log_param('principle', algorithm_info.get('principle', ''))
                mlflow.log_param('advantages', str(algorithm_info.get('advantages', [])))
                mlflow.log_param('disadvantages', str(algorithm_info.get('disadvantages', [])))

                model = self._get_model_class(algorithm)(**params)
                self.logger.info(f"Starting training {algorithm}")
                model.fit(X_train, y_train)

                val_predictions = self._predict_validation(model, X_val)
                metrics = self._calculate_metrics(
                    y_val,
                    val_predictions,
                    task_type,
                    scores=self._positive_class_scores(model, X_val, task_type),
                    features=X_val
                )
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                mlflow.sklearn.log_model(model, "model")
                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics,
                    'algorithm_info': algorithm_info
                }
        except Exception as e:
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str,
        scores: Optional[np.ndarray] = None,
        features=None
    ) -> Dict:
        """Compute the evaluation metrics for a task type.

        Args:
            true_labels: ground truth (array-like; a pandas Series is accepted).
            predictions: predicted labels / values / cluster ids.
            task_type: 'classification', 'regression' or 'clustering'; any
                other value yields ``{}``.
            scores: optional positive-class probabilities for binary
                classification; ROC-AUC falls back to the hard predictions
                when omitted.
            features: optional feature matrix for the silhouette score; when
                omitted the true labels are used as a one-column feature.

        Returns:
            Mapping of metric name to value. precision/recall/f1 are
            support-weighted averages; ``roc_auc`` only for two-class ground
            truth; ``silhouette`` only for 2..n_samples-1 predicted clusters.
        """
        metrics = {}
        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
            if len(np.unique(true_labels)) == 2:
                metrics['roc_auc'] = roc_auc_score(
                    true_labels, predictions if scores is None else scores
                )
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['rmse'] = np.sqrt(metrics['mse'])
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            n_clusters = len(np.unique(predictions))
            if n_clusters > 1:
                # np.asarray: a pandas Series (what train_model passes) has no .reshape.
                points = features if features is not None else np.asarray(true_labels).reshape(-1, 1)
                # silhouette_score requires 2 <= n_labels <= n_samples - 1.
                if n_clusters < len(points):
                    metrics['silhouette'] = silhouette_score(points, predictions)
        return metrics