haotian #1
2
config/config.yaml
Normal file
2
config/config.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
mlfow:
|
||||
uri: "http://localhost:5000"
|
||||
53
example_model_trainer.py
Normal file
53
example_model_trainer.py
Normal file
@ -0,0 +1,53 @@
|
||||
from function.model_trainer import ModelTrainer
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# 创建训练器实例
|
||||
trainer = ModelTrainer()
|
||||
|
||||
# 加载数据
|
||||
train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250218_094909/train_breast_cancer_20250218_094909.csv')
|
||||
val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250218_094909/val_breast_cancer_20250218_094909.csv')
|
||||
|
||||
# 准备特征和标签
|
||||
X_train = train_data.drop('target', axis=1)
|
||||
y_train = train_data['target']
|
||||
X_val = val_data.drop('target', axis=1)
|
||||
y_val = val_data['target']
|
||||
|
||||
# 模型配置
|
||||
model_config = {
|
||||
'algorithm': 'XGBClassifier',
|
||||
'task_type': 'classification',
|
||||
'params': {
|
||||
'n_estimators': 100,
|
||||
'learning_rate': 0.1,
|
||||
'max_depth': 6,
|
||||
'random_state': 42
|
||||
}
|
||||
}
|
||||
|
||||
# 训练模型
|
||||
result = trainer.train_model(
|
||||
{
|
||||
'features': X_train,
|
||||
'labels': y_train
|
||||
},
|
||||
{
|
||||
'features': X_val,
|
||||
'labels': y_val
|
||||
},
|
||||
model_config,
|
||||
'breast_cancer_classification'
|
||||
)
|
||||
|
||||
# 打印结果
|
||||
print("\n训练结果:")
|
||||
print(f"状态: {result['status']}")
|
||||
if result['status'] == 'success':
|
||||
print(f"\nMLflow运行ID: {result['run_id']}")
|
||||
print("\n评估指标:")
|
||||
for metric_name, metric_value in result['metrics'].items():
|
||||
print(f"{metric_name}: {metric_value:.4f}")
|
||||
else:
|
||||
print(f"错误信息: {result['message']}")
|
||||
182
function/model_trainer.py
Normal file
182
function/model_trainer.py
Normal file
@ -0,0 +1,182 @@
|
||||
import numpy as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
import yaml
|
||||
import mlflow
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
|
||||
from sklearn.metrics import adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
|
||||
|
||||
class ModelTrainer:
|
||||
"""模型训练类"""
|
||||
|
||||
def __init__(self, config: Dict = None):
|
||||
"""初始化模型训练器"""
|
||||
self.config = config or {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._setup_logging()
|
||||
self._load_metrics()
|
||||
self._load_parameters()
|
||||
|
||||
# with open("confg/config.yaml", 'r', encoding='utf-8') as f:
|
||||
# config = yaml.safe_load(f)
|
||||
|
||||
# 初始化MLflow
|
||||
mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))
|
||||
|
||||
def _setup_logging(self):
|
||||
"""设置日志"""
|
||||
log_dir = Path('.log')
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
|
||||
)
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
)
|
||||
self.logger.addHandler(file_handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
def _load_metrics(self):
|
||||
"""加载评估指标配置"""
|
||||
try:
|
||||
with open('model/metrics.yaml', 'r', encoding='utf-8') as f:
|
||||
self.metrics_config = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error loading metrics config: {str(e)}")
|
||||
raise
|
||||
|
||||
def _load_parameters(self):
|
||||
"""加载模型参数配置"""
|
||||
try:
|
||||
with open('model/parameter.yaml', 'r', encoding='utf-8') as f:
|
||||
self.parameter_config = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error loading parameter config: {str(e)}")
|
||||
raise
|
||||
|
||||
def train_model(
|
||||
self,
|
||||
train_data: Dict,
|
||||
val_data: Dict,
|
||||
model_config: Dict,
|
||||
experiment_name: str
|
||||
) -> Dict:
|
||||
"""
|
||||
训练模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据,包含特征和标签
|
||||
val_data: 验证数据,包含特征和标签
|
||||
model_config: 模型配置,包含算法名称和参数
|
||||
experiment_name: MLflow实验名称
|
||||
|
||||
Returns:
|
||||
训练结果字典
|
||||
"""
|
||||
try:
|
||||
# 设置MLflow实验
|
||||
mlflow.set_experiment(experiment_name)
|
||||
|
||||
with mlflow.start_run() as run:
|
||||
# 记录参数
|
||||
mlflow.log_params(model_config['params'])
|
||||
|
||||
# 获取模型类
|
||||
model_class = self._get_model_class(model_config['algorithm'])
|
||||
|
||||
# 创建模型实例
|
||||
model = model_class(**model_config['params'])
|
||||
|
||||
# 训练模型
|
||||
self.logger.info(f"Starting training {model_config['algorithm']}")
|
||||
model.fit(train_data['features'], train_data['labels'])
|
||||
|
||||
# 在验证集上评估
|
||||
val_predictions = model.predict(val_data['features'])
|
||||
metrics = self._calculate_metrics(
|
||||
val_data['labels'],
|
||||
val_predictions,
|
||||
model_config['task_type']
|
||||
)
|
||||
|
||||
# 记录指标
|
||||
for metric_name, metric_value in metrics.items():
|
||||
mlflow.log_metric(metric_name, metric_value)
|
||||
|
||||
# 保存模型
|
||||
mlflow.sklearn.log_model(model, "model")
|
||||
|
||||
self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'run_id': run.info.run_id,
|
||||
'metrics': metrics
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error training model: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': error_msg
|
||||
}
|
||||
|
||||
def _get_model_class(self, algorithm_name: str):
|
||||
"""获取模型类"""
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from xgboost import XGBClassifier
|
||||
from lightgbm import LGBMClassifier
|
||||
from catboost import CatBoostClassifier
|
||||
|
||||
algorithm_map = {
|
||||
'LogisticRegression': LogisticRegression,
|
||||
'SVC': SVC,
|
||||
'DecisionTreeClassifier': DecisionTreeClassifier,
|
||||
'RandomForestClassifier': RandomForestClassifier,
|
||||
'XGBClassifier': XGBClassifier,
|
||||
'LGBMClassifier': LGBMClassifier,
|
||||
'CatBoostClassifier': CatBoostClassifier
|
||||
}
|
||||
|
||||
if algorithm_name not in algorithm_map:
|
||||
raise ValueError(f"Unknown algorithm: {algorithm_name}")
|
||||
|
||||
return algorithm_map[algorithm_name]
|
||||
|
||||
def _calculate_metrics(
|
||||
self,
|
||||
true_labels: np.ndarray,
|
||||
predictions: np.ndarray,
|
||||
task_type: str
|
||||
) -> Dict:
|
||||
"""计算评估指标"""
|
||||
metrics = {}
|
||||
|
||||
if task_type == 'classification':
|
||||
metrics['accuracy'] = accuracy_score(true_labels, predictions)
|
||||
metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
|
||||
metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
|
||||
metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
|
||||
|
||||
elif task_type == 'regression':
|
||||
metrics['mae'] = mean_absolute_error(true_labels, predictions)
|
||||
metrics['mse'] = mean_squared_error(true_labels, predictions)
|
||||
metrics['r2'] = r2_score(true_labels, predictions)
|
||||
metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
|
||||
|
||||
elif task_type == 'clustering':
|
||||
metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
|
||||
metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
|
||||
metrics['completeness'] = completeness_score(true_labels, predictions)
|
||||
metrics['silhouette'] = silhouette_score(true_labels.reshape(-1, 1), predictions)
|
||||
|
||||
return metrics
|
||||
@ -0,0 +1,20 @@
|
||||
artifact_path: model
|
||||
flavors:
|
||||
python_function:
|
||||
env:
|
||||
conda: conda.yaml
|
||||
virtualenv: python_env.yaml
|
||||
loader_module: mlflow.sklearn
|
||||
model_path: model.pkl
|
||||
predict_fn: predict
|
||||
python_version: 3.9.19
|
||||
sklearn:
|
||||
code: null
|
||||
pickled_model: model.pkl
|
||||
serialization_format: cloudpickle
|
||||
sklearn_version: 1.5.2
|
||||
mlflow_version: 2.20.1
|
||||
model_size_bytes: 755
|
||||
model_uuid: 9b32eb2aebfd43128286145b0fa5d84d
|
||||
run_id: ae04d91ac81e4fa9bcb85645b477098b
|
||||
utc_time_created: '2025-02-18 08:12:44.699771'
|
||||
@ -0,0 +1,14 @@
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- python=3.9.19
|
||||
- pip<=24.0
|
||||
- pip:
|
||||
- mlflow==2.20.1
|
||||
- cloudpickle==3.1.0
|
||||
- numpy==1.26.4
|
||||
- pandas==2.2.2
|
||||
- psutil==6.0.0
|
||||
- scikit-learn==1.5.2
|
||||
- scipy==1.13.1
|
||||
name: mlflow-env
|
||||
Binary file not shown.
@ -0,0 +1,7 @@
|
||||
python: 3.9.19
|
||||
build_dependencies:
|
||||
- pip==24.0
|
||||
- setuptools==60.2.0
|
||||
- wheel==0.43.0
|
||||
dependencies:
|
||||
- -r requirements.txt
|
||||
@ -0,0 +1,7 @@
|
||||
mlflow==2.20.1
|
||||
cloudpickle==3.1.0
|
||||
numpy==1.26.4
|
||||
pandas==2.2.2
|
||||
psutil==6.0.0
|
||||
scikit-learn==1.5.2
|
||||
scipy==1.13.1
|
||||
@ -0,0 +1,15 @@
|
||||
artifact_uri: mlflow-artifacts:/903949115613983655/881e28aafaaa44008cecf2b67fda5969/artifacts
|
||||
end_time: 1739866364568
|
||||
entry_point_name: ''
|
||||
experiment_id: '903949115613983655'
|
||||
lifecycle_stage: active
|
||||
run_id: 881e28aafaaa44008cecf2b67fda5969
|
||||
run_name: fearless-conch-599
|
||||
run_uuid: 881e28aafaaa44008cecf2b67fda5969
|
||||
source_name: ''
|
||||
source_type: 4
|
||||
source_version: ''
|
||||
start_time: 1739866364243
|
||||
status: 4
|
||||
tags: []
|
||||
user_id: admin-root
|
||||
@ -0,0 +1 @@
|
||||
fearless-conch-599
|
||||
@ -0,0 +1 @@
|
||||
0fd77c8c14a553cad3b90999667861a9ead64df0
|
||||
@ -0,0 +1 @@
|
||||
/home/admin-root/haotian/MLPlatform/test_model_trainer.py
|
||||
@ -0,0 +1 @@
|
||||
LOCAL
|
||||
@ -0,0 +1 @@
|
||||
admin-root
|
||||
@ -0,0 +1,15 @@
|
||||
artifact_uri: mlflow-artifacts:/903949115613983655/ae04d91ac81e4fa9bcb85645b477098b/artifacts
|
||||
end_time: 1739866367406
|
||||
entry_point_name: ''
|
||||
experiment_id: '903949115613983655'
|
||||
lifecycle_stage: active
|
||||
run_id: ae04d91ac81e4fa9bcb85645b477098b
|
||||
run_name: defiant-steed-760
|
||||
run_uuid: ae04d91ac81e4fa9bcb85645b477098b
|
||||
source_name: ''
|
||||
source_type: 4
|
||||
source_version: ''
|
||||
start_time: 1739866364662
|
||||
status: 3
|
||||
tags: []
|
||||
user_id: admin-root
|
||||
@ -0,0 +1 @@
|
||||
1739866364686 0.3 0
|
||||
@ -0,0 +1 @@
|
||||
1739866364695 0.236662749706228 0
|
||||
@ -0,0 +1 @@
|
||||
1739866364690 0.23888888888888887 0
|
||||
@ -0,0 +1 @@
|
||||
1739866364692 0.3 0
|
||||
@ -0,0 +1 @@
|
||||
42
|
||||
@ -0,0 +1 @@
|
||||
[{"run_id": "ae04d91ac81e4fa9bcb85645b477098b", "artifact_path": "model", "utc_time_created": "2025-02-18 08:12:44.699771", "model_uuid": "9b32eb2aebfd43128286145b0fa5d84d", "flavors": {"python_function": {"model_path": "model.pkl", "predict_fn": "predict", "loader_module": "mlflow.sklearn", "python_version": "3.9.19", "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"}}, "sklearn": {"pickled_model": "model.pkl", "sklearn_version": "1.5.2", "serialization_format": "cloudpickle", "code": null}}}]
|
||||
@ -0,0 +1 @@
|
||||
defiant-steed-760
|
||||
@ -0,0 +1 @@
|
||||
0fd77c8c14a553cad3b90999667861a9ead64df0
|
||||
@ -0,0 +1 @@
|
||||
/home/admin-root/haotian/MLPlatform/test_model_trainer.py
|
||||
@ -0,0 +1 @@
|
||||
LOCAL
|
||||
@ -0,0 +1 @@
|
||||
admin-root
|
||||
6
mlruns/903949115613983655/meta.yaml
Normal file
6
mlruns/903949115613983655/meta.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
artifact_location: mlflow-artifacts:/903949115613983655
|
||||
creation_time: 1739866364180
|
||||
experiment_id: '903949115613983655'
|
||||
last_update_time: 1739866364180
|
||||
lifecycle_stage: active
|
||||
name: test_experiment
|
||||
BIN
model/__pycache__/model_trainer.cpython-39.pyc
Normal file
BIN
model/__pycache__/model_trainer.cpython-39.pyc
Normal file
Binary file not shown.
373
model/model.yaml
Normal file
373
model/model.yaml
Normal file
@ -0,0 +1,373 @@
|
||||
classification_algorithms:
|
||||
LogisticRegression:
|
||||
principle: "逻辑回归是一种线性分类模型,使用 Sigmoid 函数将线性回归的输出映射到 [0,1] 之间,适用于二分类问题。"
|
||||
advantages:
|
||||
- "计算效率高,适用于大规模数据。"
|
||||
- "可解释性强,系数可用于判断特征的重要性。"
|
||||
- "对线性可分数据表现良好。"
|
||||
disadvantages:
|
||||
- "对非线性数据的处理能力较弱。"
|
||||
- "容易受到异常值的影响。"
|
||||
applicable_scenarios:
|
||||
- "医学诊断,如是否患有某种疾病。"
|
||||
- "信用风险评估,如贷款违约预测。"
|
||||
- "市场营销中的客户分类。"
|
||||
|
||||
SVC:
|
||||
principle: "支持向量机(SVM)通过寻找最大化类别间隔的超平面进行分类,支持线性和非线性分类(通过核函数)。"
|
||||
advantages:
|
||||
- "适用于高维数据,尤其是样本较少的情况。"
|
||||
- "通过核函数可以处理非线性分类问题。"
|
||||
disadvantages:
|
||||
- "对大规模数据的计算复杂度较高。"
|
||||
- "对参数和核函数的选择较敏感。"
|
||||
applicable_scenarios:
|
||||
- "文本分类,如垃圾邮件检测。"
|
||||
- "图像识别,如手写数字分类。"
|
||||
|
||||
SVDD:
|
||||
principle: "支持向量数据描述(SVDD)是一种基于支持向量机的单类分类方法,通过寻找最小球面来包围正常数据点,从而检测异常值。"
|
||||
advantages:
|
||||
- "适用于异常检测和单类分类问题。"
|
||||
- "能够有效应对非线性分布数据。"
|
||||
disadvantages:
|
||||
- "对参数选择较敏感,训练时间较长。"
|
||||
applicable_scenarios:
|
||||
- "异常检测,如信用卡欺诈检测。"
|
||||
- "工业设备故障检测。"
|
||||
|
||||
DecisionTreeClassifier:
|
||||
principle: "决策树基于特征的划分规则构建一棵树,通过树结构进行分类。"
|
||||
advantages:
|
||||
- "可解释性强,易于理解和可视化。"
|
||||
- "对数据的分布和尺度不敏感,无需归一化。"
|
||||
disadvantages:
|
||||
- "容易过拟合,泛化能力较弱。"
|
||||
- "对噪声数据较敏感。"
|
||||
applicable_scenarios:
|
||||
- "客户细分,如电商用户分类。"
|
||||
- "医学诊断,如肿瘤良恶性分类。"
|
||||
|
||||
RandomForestClassifier:
|
||||
principle: "随机森林是一种集成学习方法,通过构建多个决策树并投票或平均进行分类,提高模型的稳定性和准确性。"
|
||||
advantages:
|
||||
- "较强的泛化能力,能有效防止过拟合。"
|
||||
- "可用于高维数据,支持特征重要性评估。"
|
||||
disadvantages:
|
||||
- "训练时间较长,预测速度相对较慢。"
|
||||
applicable_scenarios:
|
||||
- "信用风险评估。"
|
||||
- "医疗数据分析。"
|
||||
|
||||
XGBClassifier:
|
||||
principle: "XGBoost(Extreme Gradient Boosting)是一种基于梯度提升树(GBDT)的改进算法,具有更强的正则化和并行处理能力。"
|
||||
advantages:
|
||||
- "计算效率高,支持并行计算。"
|
||||
- "具有内置的缺失值处理能力。"
|
||||
disadvantages:
|
||||
- "参数较多,调优较复杂。"
|
||||
applicable_scenarios:
|
||||
- "金融风险预测。"
|
||||
- "搜索引擎排序。"
|
||||
|
||||
AdaBoostClassifier:
|
||||
principle: "AdaBoost 是一种提升方法,它通过组合多个弱分类器来提高整体分类精度,赋予错误分类样本更高的权重。"
|
||||
advantages:
|
||||
- "能有效提升弱分类器的性能。"
|
||||
- "适用于处理非均衡数据。"
|
||||
disadvantages:
|
||||
- "对噪声数据敏感。"
|
||||
applicable_scenarios:
|
||||
- "人脸检测。"
|
||||
- "信用评分模型。"
|
||||
|
||||
CatBoostClassifier:
|
||||
principle: "CatBoost 是一种专为类别特征优化的梯度提升决策树方法,适用于处理高维稀疏数据。"
|
||||
advantages:
|
||||
- "对类别型特征处理能力强。"
|
||||
- "训练速度快,支持 GPU 加速。"
|
||||
disadvantages:
|
||||
- "需要较长的训练时间。"
|
||||
applicable_scenarios:
|
||||
- "推荐系统。"
|
||||
- "搜索排序。"
|
||||
|
||||
LGBMClassifier:
|
||||
principle: "LightGBM 是一种基于直方图优化的梯度提升树方法,优化了计算效率,适用于大规模数据。"
|
||||
advantages:
|
||||
- "计算速度快,适用于大规模数据集。"
|
||||
- "对高维特征数据处理效果良好。"
|
||||
disadvantages:
|
||||
- "对小数据集容易过拟合。"
|
||||
applicable_scenarios:
|
||||
- "广告点击率预测。"
|
||||
- "金融风控。"
|
||||
|
||||
GaussianNB:
|
||||
principle: "高斯朴素贝叶斯基于贝叶斯定理,假设特征服从高斯分布进行分类。"
|
||||
advantages:
|
||||
- "计算效率高,适用于大规模数据。"
|
||||
- "对小数据集具有良好效果。"
|
||||
disadvantages:
|
||||
- "对特征的独立性假设较强,可能不适用于某些数据。"
|
||||
applicable_scenarios:
|
||||
- "文本分类。"
|
||||
- "医疗诊断。"
|
||||
|
||||
KNeighborsClassifier:
|
||||
principle: "K 近邻(KNN)是一种基于距离度量的分类算法,通过找到最近的 K 个邻居来进行分类。"
|
||||
advantages:
|
||||
- "易于理解,实现简单。"
|
||||
- "对异常值不敏感。"
|
||||
disadvantages:
|
||||
- "计算复杂度高,预测速度较慢。"
|
||||
applicable_scenarios:
|
||||
- "推荐系统。"
|
||||
- "生物信息学分类。"
|
||||
|
||||
MLPClassifier:
|
||||
principle: "多层感知机(MLP)是一种前馈神经网络,通过多个隐藏层和非线性激活函数实现复杂分类任务。"
|
||||
advantages:
|
||||
- "能够学习复杂的非线性关系。"
|
||||
- "适用于高维数据。"
|
||||
disadvantages:
|
||||
- "需要大量数据进行训练,容易过拟合。"
|
||||
applicable_scenarios:
|
||||
- "图像分类。"
|
||||
- "语音识别。"
|
||||
|
||||
GradientBoostingClassifier:
|
||||
principle: "梯度提升决策树(GBDT)是一种集成学习方法,通过迭代训练多个决策树,使模型不断优化误差。"
|
||||
advantages:
|
||||
- "较强的泛化能力,适用于大规模数据集。"
|
||||
- "支持特征重要性分析。"
|
||||
disadvantages:
|
||||
- "训练时间较长,调优复杂。"
|
||||
applicable_scenarios:
|
||||
- "搜索引擎排名。"
|
||||
- "金融信用评分。"
|
||||
|
||||
DNN:
|
||||
principle: "深度神经网络(DNN)是一种多层神经网络,通过多层非线性变换提取数据的复杂特征。"
|
||||
advantages:
|
||||
- "适用于复杂非线性问题。"
|
||||
- "能够处理大规模数据。"
|
||||
disadvantages:
|
||||
- "计算开销大,对硬件要求高。"
|
||||
applicable_scenarios:
|
||||
- "自动驾驶。"
|
||||
- "自然语言处理。"
|
||||
|
||||
regression_algorithms:
|
||||
LinearRegression:
|
||||
principle: "线性回归通过最小化数据点与回归线之间的误差平方和,来拟合一条最佳的直线。"
|
||||
advantages:
|
||||
- "实现简单,易于理解"
|
||||
- "计算效率高,适用于小规模数据集"
|
||||
disadvantages:
|
||||
- "对异常值敏感"
|
||||
- "只能建模线性关系,无法处理非线性数据"
|
||||
applicable_scenarios: "适用于线性关系明确,且数据量适中的问题。"
|
||||
|
||||
PolynomialRegression:
|
||||
principle: "多项式回归是线性回归的一种扩展,利用多项式拟合数据,处理非线性关系。"
|
||||
advantages:
|
||||
- "能够拟合非线性关系"
|
||||
- "灵活性强,能够处理复杂数据模式"
|
||||
disadvantages:
|
||||
- "容易过拟合,特别是在多项式度数较高时"
|
||||
- "计算复杂度较高"
|
||||
applicable_scenarios: "适用于非线性数据拟合,但需避免过拟合。"
|
||||
|
||||
Ridge:
|
||||
principle: "岭回归在最小化误差的同时,引入L2正则化,约束模型的复杂度。"
|
||||
advantages:
|
||||
- "有效防止过拟合"
|
||||
- "适用于特征多的情况"
|
||||
disadvantages:
|
||||
- "对于特征相关性较强的数据效果较差"
|
||||
- "结果解释性较差"
|
||||
applicable_scenarios: "适用于特征多且存在共线性的线性回归问题。"
|
||||
|
||||
Lasso:
|
||||
principle: "Lasso回归通过L1正则化来限制模型的复杂度,可以实现特征选择。"
|
||||
advantages:
|
||||
- "能够进行特征选择,减少不相关特征"
|
||||
- "减少模型的复杂度,避免过拟合"
|
||||
disadvantages:
|
||||
- "可能会导致某些特征完全被剔除,造成信息丢失"
|
||||
- "对于高相关特征可能不稳定"
|
||||
applicable_scenarios: "适用于需要特征选择或特征较多的线性回归问题。"
|
||||
|
||||
ElasticNet:
|
||||
principle: "弹性网络回归结合了Lasso回归的L1正则化和岭回归的L2正则化,能够处理更多情况。"
|
||||
advantages:
|
||||
- "能够处理相关特征,结合L1和L2的优点"
|
||||
- "适用于特征数大于样本数的情形"
|
||||
disadvantages:
|
||||
- "计算复杂度较高"
|
||||
- "可能需要调参来获得最佳效果"
|
||||
applicable_scenarios: "适用于高维数据,且特征间存在线性相关的情况。"
|
||||
|
||||
SVR:
|
||||
principle: "支持向量回归(SVR)通过在高维空间中寻找一个最优超平面来拟合数据,能够处理非线性回归问题。"
|
||||
advantages:
|
||||
- "能够处理非线性数据,效果较好"
|
||||
- "适用于高维数据"
|
||||
disadvantages:
|
||||
- "对超参数敏感,调参困难"
|
||||
- "计算时间较长,尤其在大数据集上"
|
||||
applicable_scenarios: "适用于数据具有非线性关系且数据量适中的问题。"
|
||||
|
||||
DecisionTreeRegressor:
|
||||
principle: "决策树回归通过构建树状结构来拟合数据,节点上的划分基于特征的不同值。"
|
||||
advantages:
|
||||
- "易于理解和可视化"
|
||||
- "可以处理非线性关系"
|
||||
disadvantages:
|
||||
- "容易过拟合"
|
||||
- "对噪声数据敏感"
|
||||
applicable_scenarios: "适用于非线性回归,且对数据的复杂性有较高的容忍度。"
|
||||
|
||||
RandomForestRegressor:
|
||||
principle: "随机森林回归通过集成多棵决策树的结果来提高预测精度,能够处理高维数据。"
|
||||
advantages:
|
||||
- "不容易过拟合,适用于复杂问题"
|
||||
- "能够处理缺失数据"
|
||||
disadvantages:
|
||||
- "计算复杂度较高"
|
||||
- "结果难以解释"
|
||||
applicable_scenarios: "适用于复杂的回归问题,尤其是特征维度较高的场景。"
|
||||
|
||||
XGBRegressor:
|
||||
principle: "XGBoost回归通过梯度提升算法(GBDT)优化损失函数,通过树的组合来预测目标值。"
|
||||
advantages:
|
||||
- "预测精度高,效果好"
|
||||
- "处理大数据能力强"
|
||||
disadvantages:
|
||||
- "需要较长的训练时间"
|
||||
- "对超参数敏感"
|
||||
applicable_scenarios: "适用于大规模数据集且对预测精度要求较高的问题。"
|
||||
|
||||
AdaBoostRegressor:
|
||||
principle: "AdaBoost回归通过多次训练弱学习器并加权组合,改进模型的预测能力。"
|
||||
advantages:
|
||||
- "能够提高弱学习器的预测精度"
|
||||
- "对噪声较为鲁棒"
|
||||
disadvantages:
|
||||
- "容易受到异常值影响"
|
||||
- "对某些问题的性能较差"
|
||||
applicable_scenarios: "适用于弱学习器的组合优化问题,特别是样本不平衡的场景。"
|
||||
|
||||
CatBoostRegressor:
|
||||
principle: "CatBoost回归基于梯度提升决策树(GBDT),特别优化了类别特征的处理。"
|
||||
advantages:
|
||||
- "能够处理类别特征,减少数据预处理"
|
||||
- "高效且精度较高"
|
||||
disadvantages:
|
||||
- "训练时间较长"
|
||||
- "参数调节较为复杂"
|
||||
applicable_scenarios: "适用于包含大量类别特征的数据集,且数据量较大的问题。"
|
||||
|
||||
LGBMRegressor:
|
||||
principle: "LightGBM回归基于梯度提升决策树,使用了直方图优化算法以加速训练过程。"
|
||||
advantages:
|
||||
- "训练速度快,能够处理大规模数据"
|
||||
- "内存占用较少,适合高维数据"
|
||||
disadvantages:
|
||||
- "需要调参以获得最佳效果"
|
||||
- "模型解释性较差"
|
||||
applicable_scenarios: "适用于大规模数据集,尤其是处理海量数据时效果优越。"
|
||||
|
||||
MLPRegressor:
|
||||
principle: "多层感知机回归是基于神经网络的回归模型,通过多个隐层来拟合复杂的非线性关系。"
|
||||
advantages:
|
||||
- "能够处理复杂的非线性关系"
|
||||
- "在大数据和高维数据中表现良好"
|
||||
disadvantages:
|
||||
- "训练时间较长,且对计算资源要求较高"
|
||||
- "容易过拟合,特别是在数据较少时"
|
||||
applicable_scenarios: "适用于需要捕捉复杂非线性关系的回归问题,尤其在数据量大时效果较好。"
|
||||
|
||||
clustering_algorithms:
|
||||
KMeans:
|
||||
principle: "K均值聚类通过最小化样本点到其最近簇中心的距离来将数据分为K个簇。"
|
||||
advantages:
|
||||
- "实现简单,计算效率高"
|
||||
- "适用于大数据集"
|
||||
disadvantages:
|
||||
- "需要预先指定簇的数量K"
|
||||
- "对初始中心点敏感,容易受到噪声影响"
|
||||
applicable_scenarios: "适用于簇形状较为规则且数据量较大的聚类问题。"
|
||||
|
||||
KMeansPlusPlus:
|
||||
principle: "K-Means++是一种改进的初始化方法,通过选择更远离当前簇中心的样本点作为初始中心,提升K均值聚类的效果。"
|
||||
advantages:
|
||||
- "比传统K均值方法更稳定"
|
||||
- "可以减少聚类结果的变异性"
|
||||
disadvantages:
|
||||
- "依然存在需要指定K的问题"
|
||||
- "初始化仍然可能影响最终结果"
|
||||
applicable_scenarios: "适用于K均值聚类方法,并且希望改进初始中心选择的场景。"
|
||||
|
||||
HierarchicalKMeans:
|
||||
principle: "层次化K均值结合了层次聚类和K均值聚类的方法,逐步将样本合并到已有簇中,形成层次化结构。"
|
||||
advantages:
|
||||
- "能够自动确定簇的数量"
|
||||
- "生成的树状图有助于理解数据结构"
|
||||
disadvantages:
|
||||
- "计算量大,尤其是在数据量较大时"
|
||||
- "对噪声和离群点敏感"
|
||||
applicable_scenarios: "适用于不确定簇的数量且数据结构较复杂的情况。"
|
||||
|
||||
FCM:
|
||||
principle: "模糊C均值(FCM)允许每个数据点属于多个簇,基于隶属度来进行聚类。"
|
||||
advantages:
|
||||
- "能够处理数据点属于多个簇的情况"
|
||||
- "适用于软聚类问题"
|
||||
disadvantages:
|
||||
- "对初始簇中心和隶属度设置较为敏感"
|
||||
- "计算量大,尤其是簇数较多时"
|
||||
applicable_scenarios: "适用于数据点可以属于多个簇的情况,如图像分割等问题。"
|
||||
|
||||
AgglomerativeClustering:
|
||||
principle: "层次聚类通过将相似度较高的样本逐步合并,最终形成树状结构(树状图)。"
|
||||
advantages:
|
||||
- "无需预先指定簇的数量"
|
||||
- "能够处理任意形状的簇"
|
||||
disadvantages:
|
||||
- "计算复杂度较高,数据量大时效率低"
|
||||
- "容易受到噪声的影响"
|
||||
applicable_scenarios: "适用于不确定簇的数量且簇的形状较复杂的场景。"
|
||||
|
||||
DBSCAN:
|
||||
principle: "DBSCAN基于密度的聚类方法,通过寻找密集区域来划分簇,对于稀疏区域则标记为噪声点。"
|
||||
advantages:
|
||||
- "能够发现任意形状的簇"
|
||||
- "不需要预先指定簇的数量,能够自动识别噪声"
|
||||
disadvantages:
|
||||
- "对参数设置敏感,尤其是邻域半径和最小样本数"
|
||||
- "不适用于簇大小差异过大的数据集"
|
||||
applicable_scenarios: "适用于具有明显密度差异的数据集,尤其适合处理含有噪声的数据。"
|
||||
|
||||
GaussianMixture:
|
||||
principle: "高斯混合模型(GMM)假设数据是由多个高斯分布的混合体组成,通过EM算法估计每个数据点的隶属概率。"
|
||||
advantages:
|
||||
- "能够处理重叠的簇"
|
||||
- "可以自动估计每个簇的分布"
|
||||
disadvantages:
|
||||
- "计算复杂度高,容易陷入局部最优解"
|
||||
- "需要预先指定簇的数量"
|
||||
applicable_scenarios: "适用于需要拟合混合高斯分布的数据,尤其适合处理连续数据。"
|
||||
|
||||
SpectralClustering:
|
||||
principle: "谱聚类通过构造样本之间的相似度矩阵,并计算其特征值与特征向量来实现聚类。"
|
||||
advantages:
|
||||
- "能够处理复杂的簇结构"
|
||||
- "适用于非凸形状的簇"
|
||||
disadvantages:
|
||||
- "计算量大,尤其在大数据集上效率低"
|
||||
- "需要计算样本间的相似度矩阵,存储和计算成本较高"
|
||||
applicable_scenarios: "适用于样本之间相似度较强,但簇形状复杂或非凸的聚类任务。"
|
||||
|
||||
|
||||
93
test_data_processor.py
Normal file
93
test_data_processor.py
Normal file
@ -0,0 +1,93 @@
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from function.data_processor_date import DataProcessor
|
||||
|
||||
class TestDataProcessor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = DataProcessor()
|
||||
|
||||
# 创建测试数据
|
||||
self.test_data = pd.DataFrame({
|
||||
'feature1': [1, 2, np.nan, 4, 5],
|
||||
'feature2': [10, 20, 30, 40, 50],
|
||||
'target': [0, 1, 0, 1, 0]
|
||||
})
|
||||
|
||||
# 保存测试数据
|
||||
self.input_path = 'dataset/dataset_raw/test_data.csv'
|
||||
Path(self.input_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self.test_data.to_csv(self.input_path, index=False)
|
||||
|
||||
# 设置输出目录
|
||||
self.output_dir = 'dataset/dataset_processed'
|
||||
|
||||
def test_process_dataset(self):
|
||||
# 定义处理方法
|
||||
cleaning_methods = [
|
||||
{
|
||||
'method_name': 'SimpleImputer',
|
||||
'params': {'strategy': 'mean'}
|
||||
}
|
||||
]
|
||||
|
||||
feature_methods = [
|
||||
{
|
||||
'method_name': 'StandardScaler',
|
||||
'params': {}
|
||||
}
|
||||
]
|
||||
|
||||
split_params = {
|
||||
'test_size': 0.2,
|
||||
'val_size': 0.2
|
||||
}
|
||||
|
||||
# 处理数据集
|
||||
result = self.processor.process_dataset(
|
||||
self.input_path,
|
||||
self.output_dir,
|
||||
cleaning_methods,
|
||||
feature_methods,
|
||||
split_params
|
||||
)
|
||||
|
||||
# 验证结果
|
||||
self.assertEqual(result['status'], 'success')
|
||||
self.assertIn('process_record', result)
|
||||
|
||||
# 验证输出文件
|
||||
record = result['process_record']
|
||||
self.assertTrue(Path(record['output_files']['train']).exists())
|
||||
self.assertTrue(Path(record['output_files']['validation']).exists())
|
||||
self.assertTrue(Path(record['output_files']['test']).exists())
|
||||
|
||||
def test_invalid_method(self):
|
||||
# 测试无效的方法名
|
||||
cleaning_methods = [
|
||||
{
|
||||
'method_name': 'InvalidMethod',
|
||||
'params': {}
|
||||
}
|
||||
]
|
||||
|
||||
result = self.processor.process_dataset(
|
||||
self.input_path,
|
||||
self.output_dir,
|
||||
cleaning_methods,
|
||||
[],
|
||||
{'test_size': 0.2, 'val_size': 0.2}
|
||||
)
|
||||
|
||||
self.assertEqual(result['status'], 'error')
|
||||
|
||||
def tearDown(self):
|
||||
# 清理测试文件
|
||||
try:
|
||||
Path(self.input_path).unlink()
|
||||
except:
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
49
test_method_reader.py
Normal file
49
test_method_reader.py
Normal file
@ -0,0 +1,49 @@
|
||||
import unittest
|
||||
from function.method_reader_date_process import MethodReader
|
||||
|
||||
class TestMethodReader(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reader = MethodReader()
|
||||
|
||||
def test_get_preprocessing_methods(self):
|
||||
result = self.reader.get_preprocessing_methods()
|
||||
self.assertEqual(result['status'], 'success')
|
||||
self.assertIsInstance(result['methods'], list)
|
||||
|
||||
# 检查返回的方法列表
|
||||
methods = result['methods']
|
||||
self.assertTrue(any(m['name'] == 'data_scaler' for m in methods))
|
||||
self.assertTrue(any(m['name'] == 'missing_value_handler' for m in methods))
|
||||
self.assertTrue(any(m['name'] == 'outlier_detector' for m in methods))
|
||||
|
||||
def test_get_method_details(self):
|
||||
# 测试获取StandardScaler的详细信息
|
||||
result = self.reader.get_method_details('StandardScaler')
|
||||
self.assertEqual(result['status'], 'success')
|
||||
self.assertEqual(result['method']['name'], 'StandardScaler')
|
||||
|
||||
# 检查返回的详细信息字段
|
||||
method = result['method']
|
||||
self.assertIn('description', method)
|
||||
self.assertIn('principle', method)
|
||||
self.assertIn('advantages', method)
|
||||
self.assertIn('disadvantages', method)
|
||||
self.assertIn('applicable_scenarios', method)
|
||||
self.assertIn('parameters', method)
|
||||
|
||||
# 检查参数信息
|
||||
parameters = method['parameters']
|
||||
self.assertIsInstance(parameters, list)
|
||||
if parameters:
|
||||
param = parameters[0]
|
||||
self.assertIn('name', param)
|
||||
self.assertIn('type', param)
|
||||
self.assertIn('default', param)
|
||||
self.assertIn('description', param)
|
||||
|
||||
# 测试获取不存在的方法
|
||||
result = self.reader.get_method_details('NonExistentMethod')
|
||||
self.assertEqual(result['status'], 'error')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
85
test_model_trainer.py
Normal file
85
test_model_trainer.py
Normal file
@ -0,0 +1,85 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from function.model_trainer import ModelTrainer
|
||||
|
||||
class TestModelTrainer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.trainer = ModelTrainer()
|
||||
|
||||
# 创建测试数据
|
||||
np.random.seed(42)
|
||||
self.X_train = np.random.randn(100, 5)
|
||||
self.y_train = np.random.randint(0, 2, 100)
|
||||
self.X_val = np.random.randn(30, 5)
|
||||
self.y_val = np.random.randint(0, 2, 30)
|
||||
|
||||
def test_train_model(self):
|
||||
# 准备训练数据
|
||||
train_data = {
|
||||
'features': self.X_train,
|
||||
'labels': self.y_train
|
||||
}
|
||||
|
||||
val_data = {
|
||||
'features': self.X_val,
|
||||
'labels': self.y_val
|
||||
}
|
||||
|
||||
# 模型配置
|
||||
model_config = {
|
||||
'algorithm': 'LogisticRegression',
|
||||
'task_type': 'classification',
|
||||
'params': {
|
||||
'random_state': 42
|
||||
}
|
||||
}
|
||||
|
||||
# 训练模型
|
||||
result = self.trainer.train_model(
|
||||
train_data,
|
||||
val_data,
|
||||
model_config,
|
||||
'test_experiment'
|
||||
)
|
||||
|
||||
# 验证结果
|
||||
self.assertEqual(result['status'], 'success')
|
||||
self.assertIn('run_id', result)
|
||||
self.assertIn('metrics', result)
|
||||
|
||||
# 验证指标
|
||||
metrics = result['metrics']
|
||||
self.assertIn('accuracy', metrics)
|
||||
self.assertIn('precision', metrics)
|
||||
self.assertIn('recall', metrics)
|
||||
self.assertIn('f1', metrics)
|
||||
|
||||
def test_invalid_algorithm(self):
|
||||
# 测试无效的算法名
|
||||
train_data = {
|
||||
'features': self.X_train,
|
||||
'labels': self.y_train
|
||||
}
|
||||
|
||||
val_data = {
|
||||
'features': self.X_val,
|
||||
'labels': self.y_val
|
||||
}
|
||||
|
||||
model_config = {
|
||||
'algorithm': 'InvalidAlgorithm',
|
||||
'task_type': 'classification',
|
||||
'params': {}
|
||||
}
|
||||
|
||||
result = self.trainer.train_model(
|
||||
train_data,
|
||||
val_data,
|
||||
model_config,
|
||||
'test_experiment'
|
||||
)
|
||||
|
||||
self.assertEqual(result['status'], 'error')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user