haotian #1

Merged
haotian merged 59 commits from haotian into main 2025-02-26 03:10:13 +00:00
32 changed files with 937 additions and 0 deletions
Showing only changes of commit e2fcde42f7 - Show all commits

2
config/config.yaml Normal file
View File

@ -0,0 +1,2 @@
mlflow:
uri: "http://localhost:5000"

53
example_model_trainer.py Normal file
View File

@ -0,0 +1,53 @@
"""Example: train an XGBoost classifier on the breast-cancer dataset via ModelTrainer."""
from function.model_trainer import ModelTrainer
import pandas as pd
import numpy as np

# Create the trainer instance.
trainer = ModelTrainer()

# Load the pre-split train / validation CSVs.
df_train = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250218_094909/train_breast_cancer_20250218_094909.csv')
df_val = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250218_094909/val_breast_cancer_20250218_094909.csv')

# Separate each frame into a feature matrix and the 'target' label column.
features_train, labels_train = df_train.drop('target', axis=1), df_train['target']
features_val, labels_val = df_val.drop('target', axis=1), df_val['target']

# Model configuration: algorithm class name, task type and hyper-parameters.
model_config = {
    'algorithm': 'XGBClassifier',
    'task_type': 'classification',
    'params': {
        'n_estimators': 100,
        'learning_rate': 0.1,
        'max_depth': 6,
        'random_state': 42,
    },
}

# Run the training job; the result is a status dict (never raises).
result = trainer.train_model(
    {'features': features_train, 'labels': labels_train},
    {'features': features_val, 'labels': labels_val},
    model_config,
    'breast_cancer_classification',
)

# Report the outcome.
print("\n训练结果:")
print(f"状态: {result['status']}")
if result['status'] == 'success':
    print(f"\nMLflow运行ID: {result['run_id']}")
    print("\n评估指标:")
    for metric_name, metric_value in result['metrics'].items():
        print(f"{metric_name}: {metric_value:.4f}")
else:
    print(f"错误信息: {result['message']}")

182
function/model_trainer.py Normal file
View File

@ -0,0 +1,182 @@
# Standard library
import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

# Third-party
import mlflow
import numpy as np
# FIX: was "import numpy as pd" — a typo'd pandas import that bound numpy to
# the alias "pd" (numpy is already imported as "np" above).
import pandas as pd
import yaml
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
from sklearn.metrics import adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
class ModelTrainer:
    """Train classification/regression/clustering models and log runs to MLflow.

    On construction the trainer loads metric definitions from
    ``model/metrics.yaml`` and parameter definitions from
    ``model/parameter.yaml``, then points MLflow at the configured
    tracking server.
    """

    def __init__(self, config: Dict = None):
        """Initialize the model trainer.

        Args:
            config: Optional settings dict. The key ``mlflow_uri`` overrides
                the default MLflow tracking-server address.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        # Initialize MLflow tracking.
        # NOTE(review): the fallback URI is a hard-coded LAN address; it
        # should probably be read from config/config.yaml instead — confirm.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/``.

        ``logging.getLogger(__name__)`` returns a process-wide shared logger,
        so a handler is only attached if none is present yet; previously every
        ModelTrainer instance added another FileHandler, duplicating each log
        line once per instance.
        """
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        if not any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            file_handler = logging.FileHandler(
                log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
            )
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _load_metrics(self):
        """Load the evaluation-metric configuration from ``model/metrics.yaml``.

        Raises:
            Exception: Re-raised after logging if the file cannot be read/parsed.
        """
        try:
            with open('model/metrics.yaml', 'r', encoding='utf-8') as f:
                self.metrics_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading metrics config: {str(e)}")
            raise

    def _load_parameters(self):
        """Load the model-parameter configuration from ``model/parameter.yaml``.

        Raises:
            Exception: Re-raised after logging if the file cannot be read/parsed.
        """
        try:
            with open('model/parameter.yaml', 'r', encoding='utf-8') as f:
                self.parameter_config = yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading parameter config: {str(e)}")
            raise

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model, evaluate it on the validation split, and log to MLflow.

        Args:
            train_data: Dict with ``features`` and ``labels`` for training.
            val_data: Dict with ``features`` and ``labels`` for validation.
            model_config: Dict with ``algorithm`` (estimator class name),
                ``task_type`` (``classification`` / ``regression`` /
                ``clustering``) and ``params`` (constructor kwargs).
            experiment_name: MLflow experiment name.

        Returns:
            On success: ``{'status': 'success', 'run_id': ..., 'metrics': ...}``.
            On any failure: ``{'status': 'error', 'message': ...}`` — this
            method never raises.
        """
        try:
            # Select/create the MLflow experiment for this run.
            mlflow.set_experiment(experiment_name)
            with mlflow.start_run() as run:
                # Record hyper-parameters.
                mlflow.log_params(model_config['params'])
                # Resolve and instantiate the estimator.
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**model_config['params'])
                # Fit on the training split.
                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])
                # Evaluate on the validation split.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type']
                )
                # Record metrics and persist the fitted model artifact.
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)
                mlflow.sklearn.log_model(model, "model")
                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics
                }
        except Exception as e:
            # Fold every failure into a uniform error dict so callers do not
            # have to distinguish exception types.
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _get_model_class(self, algorithm_name: str):
        """Map a supported algorithm name to its estimator class.

        Raises:
            ValueError: If ``algorithm_name`` is not in the supported set.
        """
        # Imported lazily so optional back-ends (xgboost, lightgbm, catboost)
        # are only required when a model is actually trained.
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import RandomForestClassifier
        from xgboost import XGBClassifier
        from lightgbm import LGBMClassifier
        from catboost import CatBoostClassifier
        algorithm_map = {
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': XGBClassifier,
            'LGBMClassifier': LGBMClassifier,
            'CatBoostClassifier': CatBoostClassifier
        }
        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")
        return algorithm_map[algorithm_name]

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str
    ) -> Dict:
        """Compute task-appropriate evaluation metrics.

        Args:
            true_labels: Ground-truth labels/targets for the validation split.
            predictions: Model outputs aligned with ``true_labels``.
            task_type: ``classification``, ``regression`` or ``clustering``.

        Returns:
            Metric name -> value dict.

        Raises:
            ValueError: For an unknown ``task_type`` (previously this silently
                returned an empty dict, so nothing was logged).
        """
        metrics = {}
        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            # Weighted averages so multi-class label sets are handled.
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            # FIXME(review): silhouette_score expects the feature matrix X as
            # its first argument, not the reshaped true labels; the validation
            # features are not passed into this method, so the call is kept
            # as-is, but the resulting value is not a meaningful silhouette
            # coefficient — consider plumbing the features through.
            metrics['silhouette'] = silhouette_score(true_labels.reshape(-1, 1), predictions)
        else:
            raise ValueError(f"Unknown task type: {task_type}")
        return metrics

View File

@ -0,0 +1,20 @@
artifact_path: model
flavors:
python_function:
env:
conda: conda.yaml
virtualenv: python_env.yaml
loader_module: mlflow.sklearn
model_path: model.pkl
predict_fn: predict
python_version: 3.9.19
sklearn:
code: null
pickled_model: model.pkl
serialization_format: cloudpickle
sklearn_version: 1.5.2
mlflow_version: 2.20.1
model_size_bytes: 755
model_uuid: 9b32eb2aebfd43128286145b0fa5d84d
run_id: ae04d91ac81e4fa9bcb85645b477098b
utc_time_created: '2025-02-18 08:12:44.699771'

View File

@ -0,0 +1,14 @@
channels:
- conda-forge
dependencies:
- python=3.9.19
- pip<=24.0
- pip:
- mlflow==2.20.1
- cloudpickle==3.1.0
- numpy==1.26.4
- pandas==2.2.2
- psutil==6.0.0
- scikit-learn==1.5.2
- scipy==1.13.1
name: mlflow-env

View File

@ -0,0 +1,7 @@
python: 3.9.19
build_dependencies:
- pip==24.0
- setuptools==60.2.0
- wheel==0.43.0
dependencies:
- -r requirements.txt

View File

@ -0,0 +1,7 @@
mlflow==2.20.1
cloudpickle==3.1.0
numpy==1.26.4
pandas==2.2.2
psutil==6.0.0
scikit-learn==1.5.2
scipy==1.13.1

View File

@ -0,0 +1,15 @@
artifact_uri: mlflow-artifacts:/903949115613983655/881e28aafaaa44008cecf2b67fda5969/artifacts
end_time: 1739866364568
entry_point_name: ''
experiment_id: '903949115613983655'
lifecycle_stage: active
run_id: 881e28aafaaa44008cecf2b67fda5969
run_name: fearless-conch-599
run_uuid: 881e28aafaaa44008cecf2b67fda5969
source_name: ''
source_type: 4
source_version: ''
start_time: 1739866364243
status: 4
tags: []
user_id: admin-root

View File

@ -0,0 +1 @@
fearless-conch-599

View File

@ -0,0 +1 @@
0fd77c8c14a553cad3b90999667861a9ead64df0

View File

@ -0,0 +1 @@
/home/admin-root/haotian/MLPlatform/test_model_trainer.py

View File

@ -0,0 +1 @@
admin-root

View File

@ -0,0 +1,15 @@
artifact_uri: mlflow-artifacts:/903949115613983655/ae04d91ac81e4fa9bcb85645b477098b/artifacts
end_time: 1739866367406
entry_point_name: ''
experiment_id: '903949115613983655'
lifecycle_stage: active
run_id: ae04d91ac81e4fa9bcb85645b477098b
run_name: defiant-steed-760
run_uuid: ae04d91ac81e4fa9bcb85645b477098b
source_name: ''
source_type: 4
source_version: ''
start_time: 1739866364662
status: 3
tags: []
user_id: admin-root

View File

@ -0,0 +1 @@
1739866364686 0.3 0

View File

@ -0,0 +1 @@
1739866364695 0.236662749706228 0

View File

@ -0,0 +1 @@
1739866364690 0.23888888888888887 0

View File

@ -0,0 +1 @@
1739866364692 0.3 0

View File

@ -0,0 +1 @@
[{"run_id": "ae04d91ac81e4fa9bcb85645b477098b", "artifact_path": "model", "utc_time_created": "2025-02-18 08:12:44.699771", "model_uuid": "9b32eb2aebfd43128286145b0fa5d84d", "flavors": {"python_function": {"model_path": "model.pkl", "predict_fn": "predict", "loader_module": "mlflow.sklearn", "python_version": "3.9.19", "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}}, "sklearn": {"pickled_model": "model.pkl", "sklearn_version": "1.5.2", "serialization_format": "cloudpickle", "code": null}}}]

View File

@ -0,0 +1 @@
defiant-steed-760

View File

@ -0,0 +1 @@
0fd77c8c14a553cad3b90999667861a9ead64df0

View File

@ -0,0 +1 @@
/home/admin-root/haotian/MLPlatform/test_model_trainer.py

View File

@ -0,0 +1 @@
admin-root

View File

@ -0,0 +1,6 @@
artifact_location: mlflow-artifacts:/903949115613983655
creation_time: 1739866364180
experiment_id: '903949115613983655'
last_update_time: 1739866364180
lifecycle_stage: active
name: test_experiment

Binary file not shown.

373
model/model.yaml Normal file
View File

@ -0,0 +1,373 @@
classification_algorithms:
LogisticRegression:
principle: "逻辑回归是一种线性分类模型,使用 Sigmoid 函数将线性回归的输出映射到 [0,1] 之间,适用于二分类问题。"
advantages:
- "计算效率高,适用于大规模数据。"
- "可解释性强,系数可用于判断特征的重要性。"
- "对线性可分数据表现良好。"
disadvantages:
- "对非线性数据的处理能力较弱。"
- "容易受到异常值的影响。"
applicable_scenarios:
- "医学诊断,如是否患有某种疾病。"
- "信用风险评估,如贷款违约预测。"
- "市场营销中的客户分类。"
SVC:
principle: "支持向量机SVM通过寻找最大化类别间隔的超平面进行分类支持线性和非线性分类通过核函数。"
advantages:
- "适用于高维数据,尤其是样本较少的情况。"
- "通过核函数可以处理非线性分类问题。"
disadvantages:
- "对大规模数据的计算复杂度较高。"
- "对参数和核函数的选择较敏感。"
applicable_scenarios:
- "文本分类,如垃圾邮件检测。"
- "图像识别,如手写数字分类。"
SVDD:
principle: "支持向量数据描述SVDD是一种基于支持向量机的单类分类方法通过寻找最小球面来包围正常数据点从而检测异常值。"
advantages:
- "适用于异常检测和单类分类问题。"
- "能够有效应对非线性分布数据。"
disadvantages:
- "对参数选择较敏感,训练时间较长。"
applicable_scenarios:
- "异常检测,如信用卡欺诈检测。"
- "工业设备故障检测。"
DecisionTreeClassifier:
principle: "决策树基于特征的划分规则构建一棵树,通过树结构进行分类。"
advantages:
- "可解释性强,易于理解和可视化。"
- "对数据的分布和尺度不敏感,无需归一化。"
disadvantages:
- "容易过拟合,泛化能力较弱。"
- "对噪声数据较敏感。"
applicable_scenarios:
- "客户细分,如电商用户分类。"
- "医学诊断,如肿瘤良恶性分类。"
RandomForestClassifier:
principle: "随机森林是一种集成学习方法,通过构建多个决策树并投票或平均进行分类,提高模型的稳定性和准确性。"
advantages:
- "较强的泛化能力,能有效防止过拟合。"
- "可用于高维数据,支持特征重要性评估。"
disadvantages:
- "训练时间较长,预测速度相对较慢。"
applicable_scenarios:
- "信用风险评估。"
- "医疗数据分析。"
XGBClassifier:
principle: "XGBoostExtreme Gradient Boosting是一种基于梯度提升树GBDT的改进算法具有更强的正则化和并行处理能力。"
advantages:
- "计算效率高,支持并行计算。"
- "具有内置的缺失值处理能力。"
disadvantages:
- "参数较多,调优较复杂。"
applicable_scenarios:
- "金融风险预测。"
- "搜索引擎排序。"
AdaBoostClassifier:
principle: "AdaBoost 是一种提升方法,它通过组合多个弱分类器来提高整体分类精度,赋予错误分类样本更高的权重。"
advantages:
- "能有效提升弱分类器的性能。"
- "适用于处理非均衡数据。"
disadvantages:
- "对噪声数据敏感。"
applicable_scenarios:
- "人脸检测。"
- "信用评分模型。"
CatBoostClassifier:
principle: "CatBoost 是一种专为类别特征优化的梯度提升决策树方法,适用于处理高维稀疏数据。"
advantages:
- "对类别型特征处理能力强。"
- "训练速度快,支持 GPU 加速。"
disadvantages:
- "需要较长的训练时间。"
applicable_scenarios:
- "推荐系统。"
- "搜索排序。"
LGBMClassifier:
principle: "LightGBM 是一种基于直方图优化的梯度提升树方法,优化了计算效率,适用于大规模数据。"
advantages:
- "计算速度快,适用于大规模数据集。"
- "对高维特征数据处理效果良好。"
disadvantages:
- "对小数据集容易过拟合。"
applicable_scenarios:
- "广告点击率预测。"
- "金融风控。"
GaussianNB:
principle: "高斯朴素贝叶斯基于贝叶斯定理,假设特征服从高斯分布进行分类。"
advantages:
- "计算效率高,适用于大规模数据。"
- "对小数据集具有良好效果。"
disadvantages:
- "对特征的独立性假设较强,可能不适用于某些数据。"
applicable_scenarios:
- "文本分类。"
- "医疗诊断。"
KNeighborsClassifier:
principle: "K 近邻KNN是一种基于距离度量的分类算法通过找到最近的 K 个邻居来进行分类。"
advantages:
- "易于理解,实现简单。"
- "对异常值不敏感。"
disadvantages:
- "计算复杂度高,预测速度较慢。"
applicable_scenarios:
- "推荐系统。"
- "生物信息学分类。"
MLPClassifier:
principle: "多层感知机MLP是一种前馈神经网络通过多个隐藏层和非线性激活函数实现复杂分类任务。"
advantages:
- "能够学习复杂的非线性关系。"
- "适用于高维数据。"
disadvantages:
- "需要大量数据进行训练,容易过拟合。"
applicable_scenarios:
- "图像分类。"
- "语音识别。"
GradientBoostingClassifier:
principle: "梯度提升决策树GBDT是一种集成学习方法通过迭代训练多个决策树使模型不断优化误差。"
advantages:
- "较强的泛化能力,适用于大规模数据集。"
- "支持特征重要性分析。"
disadvantages:
- "训练时间较长,调优复杂。"
applicable_scenarios:
- "搜索引擎排名。"
- "金融信用评分。"
DNN:
principle: "深度神经网络DNN是一种多层神经网络通过多层非线性变换提取数据的复杂特征。"
advantages:
- "适用于复杂非线性问题。"
- "能够处理大规模数据。"
disadvantages:
- "计算开销大,对硬件要求高。"
applicable_scenarios:
- "自动驾驶。"
- "自然语言处理。"
regression_algorithms:
LinearRegression:
principle: "线性回归通过最小化数据点与回归线之间的误差平方和,来拟合一条最佳的直线。"
advantages:
- "实现简单,易于理解"
- "计算效率高,适用于小规模数据集"
disadvantages:
- "对异常值敏感"
- "只能建模线性关系,无法处理非线性数据"
applicable_scenarios: "适用于线性关系明确,且数据量适中的问题。"
PolynomialRegression:
principle: "多项式回归是线性回归的一种扩展,利用多项式拟合数据,处理非线性关系。"
advantages:
- "能够拟合非线性关系"
- "灵活性强,能够处理复杂数据模式"
disadvantages:
- "容易过拟合,特别是在多项式度数较高时"
- "计算复杂度较高"
applicable_scenarios: "适用于非线性数据拟合,但需避免过拟合。"
Ridge:
principle: "岭回归在最小化误差的同时引入L2正则化约束模型的复杂度。"
advantages:
- "有效防止过拟合"
- "适用于特征多的情况"
disadvantages:
- "对于特征相关性较强的数据效果较差"
- "结果解释性较差"
applicable_scenarios: "适用于特征多且存在共线性的线性回归问题。"
Lasso:
principle: "Lasso回归通过L1正则化来限制模型的复杂度可以实现特征选择。"
advantages:
- "能够进行特征选择,减少不相关特征"
- "减少模型的复杂度,避免过拟合"
disadvantages:
- "可能会导致某些特征完全被剔除,造成信息丢失"
- "对于高相关特征可能不稳定"
applicable_scenarios: "适用于需要特征选择或特征较多的线性回归问题。"
ElasticNet:
principle: "弹性网络回归结合了Lasso回归的L1正则化和岭回归的L2正则化能够处理更多情况。"
advantages:
- "能够处理相关特征结合L1和L2的优点"
- "适用于特征数大于样本数的情形"
disadvantages:
- "计算复杂度较高"
- "可能需要调参来获得最佳效果"
applicable_scenarios: "适用于高维数据,且特征间存在线性相关的情况。"
SVR:
principle: "支持向量回归SVR通过在高维空间中寻找一个最优超平面来拟合数据能够处理非线性回归问题。"
advantages:
- "能够处理非线性数据,效果较好"
- "适用于高维数据"
disadvantages:
- "对超参数敏感,调参困难"
- "计算时间较长,尤其在大数据集上"
applicable_scenarios: "适用于数据具有非线性关系且数据量适中的问题。"
DecisionTreeRegressor:
principle: "决策树回归通过构建树状结构来拟合数据,节点上的划分基于特征的不同值。"
advantages:
- "易于理解和可视化"
- "可以处理非线性关系"
disadvantages:
- "容易过拟合"
- "对噪声数据敏感"
applicable_scenarios: "适用于非线性回归,且对数据的复杂性有较高的容忍度。"
RandomForestRegressor:
principle: "随机森林回归通过集成多棵决策树的结果来提高预测精度,能够处理高维数据。"
advantages:
- "不容易过拟合,适用于复杂问题"
- "能够处理缺失数据"
disadvantages:
- "计算复杂度较高"
- "结果难以解释"
applicable_scenarios: "适用于复杂的回归问题,尤其是特征维度较高的场景。"
XGBRegressor:
principle: "XGBoost回归通过梯度提升算法GBDT优化损失函数通过树的组合来预测目标值。"
advantages:
- "预测精度高,效果好"
- "处理大数据能力强"
disadvantages:
- "需要较长的训练时间"
- "对超参数敏感"
applicable_scenarios: "适用于大规模数据集且对预测精度要求较高的问题。"
AdaBoostRegressor:
principle: "AdaBoost回归通过多次训练弱学习器并加权组合改进模型的预测能力。"
advantages:
- "能够提高弱学习器的预测精度"
- "对噪声较为鲁棒"
disadvantages:
- "容易受到异常值影响"
- "对某些问题的性能较差"
applicable_scenarios: "适用于弱学习器的组合优化问题,特别是样本不平衡的场景。"
CatBoostRegressor:
principle: "CatBoost回归基于梯度提升决策树GBDT特别优化了类别特征的处理。"
advantages:
- "能够处理类别特征,减少数据预处理"
- "高效且精度较高"
disadvantages:
- "训练时间较长"
- "参数调节较为复杂"
applicable_scenarios: "适用于包含大量类别特征的数据集,且数据量较大的问题。"
LGBMRegressor:
principle: "LightGBM回归基于梯度提升决策树使用了直方图优化算法以加速训练过程。"
advantages:
- "训练速度快,能够处理大规模数据"
- "内存占用较少,适合高维数据"
disadvantages:
- "需要调参以获得最佳效果"
- "模型解释性较差"
applicable_scenarios: "适用于大规模数据集,尤其是处理海量数据时效果优越。"
MLPRegressor:
principle: "多层感知机回归是基于神经网络的回归模型,通过多个隐层来拟合复杂的非线性关系。"
advantages:
- "能够处理复杂的非线性关系"
- "在大数据和高维数据中表现良好"
disadvantages:
- "训练时间较长,且对计算资源要求较高"
- "容易过拟合,特别是在数据较少时"
applicable_scenarios: "适用于需要捕捉复杂非线性关系的回归问题,尤其在数据量大时效果较好。"
clustering_algorithms:
KMeans:
principle: "K均值聚类通过最小化样本点到其最近簇中心的距离来将数据分为K个簇。"
advantages:
- "实现简单,计算效率高"
- "适用于大数据集"
disadvantages:
- "需要预先指定簇的数量K"
- "对初始中心点敏感,容易受到噪声影响"
applicable_scenarios: "适用于簇形状较为规则且数据量较大的聚类问题。"
KMeansPlusPlus:
principle: "K-Means++是一种改进的初始化方法通过选择更远离当前簇中心的样本点作为初始中心提升K均值聚类的效果。"
advantages:
- "比传统K均值方法更稳定"
- "可以减少聚类结果的变异性"
disadvantages:
- "依然存在需要指定K的问题"
- "初始化仍然可能影响最终结果"
applicable_scenarios: "适用于K均值聚类方法并且希望改进初始中心选择的场景。"
HierarchicalKMeans:
principle: "层次化K均值结合了层次聚类和K均值聚类的方法逐步将样本合并到已有簇中形成层次化结构。"
advantages:
- "能够自动确定簇的数量"
- "生成的树状图有助于理解数据结构"
disadvantages:
- "计算量大,尤其是在数据量较大时"
- "对噪声和离群点敏感"
applicable_scenarios: "适用于不确定簇的数量且数据结构较复杂的情况。"
FCM:
principle: "模糊C均值FCM允许每个数据点属于多个簇基于隶属度来进行聚类。"
advantages:
- "能够处理数据点属于多个簇的情况"
- "适用于软聚类问题"
disadvantages:
- "对初始簇中心和隶属度设置较为敏感"
- "计算量大,尤其是簇数较多时"
applicable_scenarios: "适用于数据点可以属于多个簇的情况,如图像分割等问题。"
AgglomerativeClustering:
principle: "层次聚类通过将相似度较高的样本逐步合并,最终形成树状结构(树状图)。"
advantages:
- "无需预先指定簇的数量"
- "能够处理任意形状的簇"
disadvantages:
- "计算复杂度较高,数据量大时效率低"
- "容易受到噪声的影响"
applicable_scenarios: "适用于不确定簇的数量且簇的形状较复杂的场景。"
DBSCAN:
principle: "DBSCAN基于密度的聚类方法通过寻找密集区域来划分簇对于稀疏区域则标记为噪声点。"
advantages:
- "能够发现任意形状的簇"
- "不需要预先指定簇的数量,能够自动识别噪声"
disadvantages:
- "对参数设置敏感,尤其是邻域半径和最小样本数"
- "不适用于簇大小差异过大的数据集"
applicable_scenarios: "适用于具有明显密度差异的数据集,尤其适合处理含有噪声的数据。"
GaussianMixture:
principle: "高斯混合模型GMM假设数据是由多个高斯分布的混合体组成通过EM算法估计每个数据点的隶属概率。"
advantages:
- "能够处理重叠的簇"
- "可以自动估计每个簇的分布"
disadvantages:
- "计算复杂度高,容易陷入局部最优解"
- "需要预先指定簇的数量"
applicable_scenarios: "适用于需要拟合混合高斯分布的数据,尤其适合处理连续数据。"
SpectralClustering:
principle: "谱聚类通过构造样本之间的相似度矩阵,并计算其特征值与特征向量来实现聚类。"
advantages:
- "能够处理复杂的簇结构"
- "适用于非凸形状的簇"
disadvantages:
- "计算量大,尤其在大数据集上效率低"
- "需要计算样本间的相似度矩阵,存储和计算成本较高"
applicable_scenarios: "适用于样本之间相似度较强,但簇形状复杂或非凸的聚类任务。"

93
test_data_processor.py Normal file
View File

@ -0,0 +1,93 @@
import unittest
import pandas as pd
import numpy as np
from pathlib import Path
from function.data_processor_date import DataProcessor
class TestDataProcessor(unittest.TestCase):
    """Unit tests for DataProcessor.process_dataset."""

    def setUp(self):
        """Write a small CSV fixture (with one missing value) to disk."""
        self.processor = DataProcessor()
        self.test_data = pd.DataFrame({
            'feature1': [1, 2, np.nan, 4, 5],
            'feature2': [10, 20, 30, 40, 50],
            'target': [0, 1, 0, 1, 0]
        })
        # Persist the fixture where the processor expects raw datasets.
        self.input_path = 'dataset/dataset_raw/test_data.csv'
        Path(self.input_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_data.to_csv(self.input_path, index=False)
        # Directory for the processed output splits.
        self.output_dir = 'dataset/dataset_processed'

    def test_process_dataset(self):
        """A valid pipeline succeeds and writes train/validation/test files."""
        cleaning_methods = [
            {
                'method_name': 'SimpleImputer',
                'params': {'strategy': 'mean'}
            }
        ]
        feature_methods = [
            {
                'method_name': 'StandardScaler',
                'params': {}
            }
        ]
        split_params = {
            'test_size': 0.2,
            'val_size': 0.2
        }
        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            feature_methods,
            split_params
        )
        # The processor reports success and returns a processing record.
        self.assertEqual(result['status'], 'success')
        self.assertIn('process_record', result)
        # All three output splits must exist on disk.
        record = result['process_record']
        self.assertTrue(Path(record['output_files']['train']).exists())
        self.assertTrue(Path(record['output_files']['validation']).exists())
        self.assertTrue(Path(record['output_files']['test']).exists())

    def test_invalid_method(self):
        """An unknown cleaning-method name yields an error status."""
        cleaning_methods = [
            {
                'method_name': 'InvalidMethod',
                'params': {}
            }
        ]
        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            [],
            {'test_size': 0.2, 'val_size': 0.2}
        )
        self.assertEqual(result['status'], 'error')

    def tearDown(self):
        """Remove the CSV fixture.

        FIX: previously used a bare ``except: pass`` which swallowed every
        exception; only a missing file is an expected condition here.
        """
        Path(self.input_path).unlink(missing_ok=True)


if __name__ == '__main__':
    unittest.main()

49
test_method_reader.py Normal file
View File

@ -0,0 +1,49 @@
import unittest
from function.method_reader_date_process import MethodReader
class TestMethodReader(unittest.TestCase):
    """Tests for MethodReader's method listing and detail lookup."""

    def setUp(self):
        self.reader = MethodReader()

    def test_get_preprocessing_methods(self):
        """Listing succeeds and includes the known preprocessing groups."""
        result = self.reader.get_preprocessing_methods()
        self.assertEqual(result['status'], 'success')
        self.assertIsInstance(result['methods'], list)
        # Collect advertised names and check each expected one is present.
        names = [entry['name'] for entry in result['methods']]
        for expected in ('data_scaler', 'missing_value_handler', 'outlier_detector'):
            self.assertIn(expected, names)

    def test_get_method_details(self):
        """Detail lookup returns the full record for a known method."""
        result = self.reader.get_method_details('StandardScaler')
        self.assertEqual(result['status'], 'success')
        method = result['method']
        self.assertEqual(method['name'], 'StandardScaler')
        # Every descriptive field must be present.
        for field in ('description', 'principle', 'advantages',
                      'disadvantages', 'applicable_scenarios', 'parameters'):
            self.assertIn(field, method)
        # Parameter entries, when present, carry the standard keys.
        parameters = method['parameters']
        self.assertIsInstance(parameters, list)
        if parameters:
            first = parameters[0]
            for key in ('name', 'type', 'default', 'description'):
                self.assertIn(key, first)
        # An unknown method name yields an error status.
        missing = self.reader.get_method_details('NonExistentMethod')
        self.assertEqual(missing['status'], 'error')


if __name__ == '__main__':
    unittest.main()

85
test_model_trainer.py Normal file
View File

@ -0,0 +1,85 @@
import unittest
import numpy as np
from function.model_trainer import ModelTrainer
class TestModelTrainer(unittest.TestCase):
    """Tests for ModelTrainer.train_model."""

    def setUp(self):
        """Build a reproducible random binary-classification fixture."""
        self.trainer = ModelTrainer()
        np.random.seed(42)
        self.X_train = np.random.randn(100, 5)
        self.y_train = np.random.randint(0, 2, 100)
        self.X_val = np.random.randn(30, 5)
        self.y_val = np.random.randint(0, 2, 30)

    def _datasets(self):
        """Package the fixture arrays in the trainer's expected dict shape."""
        train = {'features': self.X_train, 'labels': self.y_train}
        val = {'features': self.X_val, 'labels': self.y_val}
        return train, val

    def test_train_model(self):
        """Training a LogisticRegression succeeds and reports the metrics."""
        train, val = self._datasets()
        config = {
            'algorithm': 'LogisticRegression',
            'task_type': 'classification',
            'params': {
                'random_state': 42
            }
        }
        result = self.trainer.train_model(train, val, config, 'test_experiment')
        # A successful run carries a run id and the metrics dict.
        self.assertEqual(result['status'], 'success')
        self.assertIn('run_id', result)
        self.assertIn('metrics', result)
        for metric in ('accuracy', 'precision', 'recall', 'f1'):
            self.assertIn(metric, result['metrics'])

    def test_invalid_algorithm(self):
        """An unknown algorithm name must produce an error result."""
        train, val = self._datasets()
        config = {
            'algorithm': 'InvalidAlgorithm',
            'task_type': 'classification',
            'params': {}
        }
        result = self.trainer.train_model(train, val, config, 'test_experiment')
        self.assertEqual(result['status'], 'error')


if __name__ == '__main__':
    unittest.main()