diff --git a/config/config.yaml b/config/config.yaml index dca3239..9d11c9b 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,2 +1,37 @@ -mlfow: - uri: "http://localhost:5000" \ No newline at end of file +# 服务器配置 +host: "10.0.0.202" +port: 8992 +workers: 4 +debug: true + +# MLflow配置 +mlflow_uri: "http://10.0.0.202:5000" + +# 数据处理配置 +dataset: + raw_dir: "dataset/dataset_raw" + processed_dir: "dataset/dataset_processed" + +# 模型配置 +model: + save_dir: "models" + batch_size: 32 + num_workers: 4 + +# 系统监控配置 +monitor: + log_dir: ".log" + resource_check_interval: 60 # 秒 + cleanup_interval: 86400 # 24小时 + max_log_days: 30 # 日志保留天数 + +# 安全配置 +security: + secret_key: "your-secret-key" + token_expire_minutes: 1440 # 24小时 + +# 性能配置 +performance: + max_concurrent_trains: 4 # 最大并发训练数 + cache_size: 1024 # MB + timeout: 3600 # 秒 \ No newline at end of file diff --git a/doc/接口文档code.md b/doc/接口文档code.md index e874777..6124707 100644 --- a/doc/接口文档code.md +++ b/doc/接口文档code.md @@ -523,6 +523,9 @@ Error Response: } ``` +### 2.9 模型优化 -- 未实现 + + ## 3. 系统监控 ### 3.1 获取资源使用情况 ```http @@ -681,7 +684,7 @@ Error Response: } ``` -### 3.3 获取系统中训练状态 +### 3.3 获取系统中训练状态 ---- 未完成, 等开发系统后台时再实现. ```http GET /api/train/status/{task_id} @@ -764,6 +767,76 @@ Error Response: } ``` +## 4. 系统后台整体实现 + +### 4.1 系统架构 +``` +MLPlatform/ +├── api/ # API接口层 +│ ├── __init__.py +│ ├── data_api.py # 数据处理相关接口 +│ ├── model_api.py # 模型相关接口 +│ └── system_api.py # 系统监控相关接口 +├── function/ # 功能实现层 +│ ├── data_processor.py # 数据处理类 +│ ├── model_manager.py # 模型管理类 +│ ├── model_trainer.py # 模型训练类 +│ ├── system_monitor.py # 系统监控类 +│ └── utils/ # 工具函数 +├── config/ # 配置文件 +│ └── config.yaml # 系统配置 +├── dataset/ # 数据集 +│ ├── dataset_raw/ # 原始数据 +│ └── dataset_processed/ # 处理后数据 +├── .log/ # 日志文件 +├── doc/ # 文档 +└── main.py # 主程序入口 +``` + +### 4.2 技术栈 +- FastAPI: Web框架 +- MLflow: 模型管理和实验跟踪 +- PyTorch/Scikit-learn: 机器学习框架 +- Pydantic: 数据验证 +- Uvicorn: ASGI服务器 + +### 4.3 主要功能 +1. 
异步任务处理 + - 支持多个模型同时训练 + - 后台任务状态监控 + - 任务队列管理 + +2. 实时监控 + - 系统资源监控 + - 训练进度监控 + - 日志实时查看 + +3. 错误处理 + - 全局异常处理 + - 错误日志记录 + - 优雅降级策略 + +4. 安全性 + - API认证授权 + - 请求限流 + - 参数验证 + +### 4.4 性能优化 +1. 数据处理 + - 数据流式处理 + - 缓存机制 + - 批量处理 + +2. 模型训练 + - GPU利用优化 + - 分布式训练支持 + - 模型检查点 + +3. 系统监控 + - 性能指标采集 + - 资源使用预警 + - 自动清理机制 + ## 附录A:方法详细说明 ### A1. 数据预处理方法 diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index 23d59e7..30dc1d6 100644 Binary files a/function/__pycache__/model_manager.cpython-39.pyc and b/function/__pycache__/model_manager.cpython-39.pyc differ diff --git a/function/method_reader_metric.py b/function/method_reader_metric.py deleted file mode 100644 index bf43fea..0000000 --- a/function/method_reader_metric.py +++ /dev/null @@ -1,79 +0,0 @@ -import yaml -from typing import Dict, List -import os -import logging -from pathlib import Path - -class MethodReader: - """方法配置读取器""" - - def __init__(self): - """初始化方法读取器""" - self.logger = logging.getLogger(__name__) - self.method_config = self._load_metrics() - - - def _load_metrics(self) -> Dict: - """加载方法配置文件""" - try: - config_path = Path('model/metrics.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Method config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded method config") - return config - - except Exception as e: - self.logger.error(f"Error loading method config: {str(e)}") - raise - - - - def get_metrics(self) -> Dict: - """获取预处理方法列表""" - try: - metrics = [] - - # 分类方法 - classification_metrics = self.method_config.get('classification', {}) - if classification_metrics: - metrics.append({ - "name": "classification_metrics", - "description": "分类方法评价指标", - "metric": classification_metrics - }) - - # 回归方法 - regression_metrics = self.method_config.get('regression', {}) - if regression_metrics: - metrics.append({ - "name": 
"regression_metrics", - "description": "回归方法评价指标", - "metric": regression_metrics - }) - - # 聚类方法 - clustering_metrics = self.method_config.get('clustering', {}) - if clustering_metrics: - metrics.append({ - "name": "clustering_metrics", - "description": "聚类方法评价指标", - "metric": clustering_metrics - }) - - return { - "status": "success", - "metric": metrics - } - - except Exception as e: - self.logger.error(f"Error getting preprocessing methods: {str(e)}") - return { - "status": "error", - "error": str(e) - } - - \ No newline at end of file diff --git a/function/method_reader_model.py b/function/method_reader_model.py deleted file mode 100644 index 2233e87..0000000 --- a/function/method_reader_model.py +++ /dev/null @@ -1,135 +0,0 @@ -import yaml -from typing import Dict, List -import os -import logging -from pathlib import Path - -class MethodReader: - """方法配置读取器""" - - def __init__(self): - """初始化方法读取器""" - self.logger = logging.getLogger(__name__) - self.method_config = self._load_model_config() - self.parameter_config = self._load_parameter_config() - - def _load_model_config(self) -> Dict: - """加载方法配置文件""" - try: - config_path = Path('model/model.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Method config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded method config") - return config - - except Exception as e: - self.logger.error(f"Error loading method config: {str(e)}") - raise - - def _load_parameter_config(self) -> Dict: - """加载参数配置文件""" - try: - config_path = Path('model/parameter.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Parameter config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded parameter config") - return config - except Exception as e: - self.logger.error(f"Error loading parameter 
config: {str(e)}") - raise - - def get_models(self) -> Dict: - """获取预处理方法列表""" - try: - models = [] - - # 分类方法 - classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys()) - if classification_algorithms: - models.append({ - "name": "classification_algorithms", - "description": "分类方法", - "method": classification_algorithms - }) - - # 回归方法 - regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys()) - if regression_algorithms: - models.append({ - "name": "regression_algorithms", - "description": "回归方法", - "method": regression_algorithms - }) - - # 聚类方法 - clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys()) - if clustering_algorithms: - models.append({ - "name": "clustering_algorithms", - "description": "聚类方法", - "method": clustering_algorithms - }) - - return { - "status": "success", - "models": models - } - - except Exception as e: - self.logger.error(f"Error getting preprocessing methods: {str(e)}") - return { - "status": "error", - "error": str(e) - } - - def get_model_details(self, method_name: str) -> Dict: - """获取指定方法的详细信息""" - try: - # 在各个方法类别中查找方法原理和优缺点 - method_info = None - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if method_name in self.method_config.get(category, {}): - method_info = self.method_config[category][method_name] - break - - if method_info is None: - raise ValueError(f"Method {method_name} not found in method config") - - # 查找方法参数信息 - parameter_info = None - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if method_name in self.parameter_config.get(category, {}): - parameter_info = self.parameter_config[category][method_name] - break - - if parameter_info is None: - raise ValueError(f"Method {method_name} not found in parameter config") - - # 组合返回信息 - return { - "status": "success", - "method": { - "name": method_name, - "description": 
def _read_yaml_mapping(self, relative_path: str, label: str) -> Dict:
    """Load one YAML file and return its top-level mapping.

    Shared by the three ``_load_*`` methods so that the existence check,
    the encoding and the empty-file handling live in exactly one place.

    Args:
        relative_path: Location of the YAML file. It is resolved against
            the current working directory, exactly as the loaders did
            before this helper existed.
        label: Capitalised noun used in the "not found" message, e.g.
            ``"Metrics"`` or ``"Parameter"``.

    Returns:
        The parsed mapping, or an empty dict when the file is empty.

    Raises:
        FileNotFoundError: If ``relative_path`` does not exist.
        ValueError: If the document's top level is not a mapping.
    """
    config_path = Path(relative_path)
    if not config_path.exists():
        raise FileNotFoundError(f"{label} config file not found at {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)

    # safe_load yields None for an empty document; callers use `.get()` and
    # `**` unpacking on the result, both of which need a real mapping.
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"{label} config at {config_path} must be a mapping, "
            f"got {type(loaded).__name__}"
        )
    return loaded


def _load_metrics(self) -> Dict:
    """Load the evaluation-metric catalogue from ``model/metrics.yaml``.

    NOTE(review): the visible ``__init__`` hunk stores this result in
    ``self.method_config`` and then immediately overwrites it with
    ``_load_model_config()``, which already merges the same file. This call
    looks redundant there; confirm before removing it.

    Returns:
        The parsed metrics mapping (empty dict for an empty file).

    Raises:
        FileNotFoundError: If the metrics file is missing.
        ValueError: If the file does not contain a mapping.
    """
    try:
        config = self._read_yaml_mapping('model/metrics.yaml', 'Metrics')
        self.logger.info("Successfully loaded metrics config")
        return config
    except Exception as e:
        # Log for the operator, then let the caller decide how to fail.
        self.logger.error(f"Error loading metrics config: {str(e)}")
        raise


def _load_model_config(self) -> Dict:
    """Load ``model/model.yaml`` and ``model/metrics.yaml`` as one mapping.

    The two documents are merged into a single dict. When both define the
    same top-level key the metrics file wins, which is the precedence the
    original ``{**config_model, **config_metric}`` expression had; a warning
    is logged so that such a collision does not go unnoticed.

    Returns:
        The merged mapping.

    Raises:
        FileNotFoundError: If either file is missing.
        ValueError: If either file does not contain a mapping.
    """
    try:
        config_model = self._read_yaml_mapping('model/model.yaml', 'Method')
        self.logger.info("Successfully loaded model config")

        config_metric = self._read_yaml_mapping('model/metrics.yaml', 'Metrics')
        self.logger.info("Successfully loaded metrics config")

        shadowed = [key for key in config_model if key in config_metric]
        if shadowed:
            self.logger.warning(
                "Keys defined in both model and metrics config "
                "(metrics value wins): %s", shadowed
            )

        return {**config_model, **config_metric}
    except Exception as e:
        self.logger.error(f"Error loading method or metric config: {str(e)}")
        raise


def _load_parameter_config(self) -> Dict:
    """Load the per-algorithm parameter catalogue from ``model/parameter.yaml``.

    Returns:
        The parsed parameter mapping (empty dict for an empty file).

    Raises:
        FileNotFoundError: If the parameter file is missing.
        ValueError: If the file does not contain a mapping.
    """
    try:
        config = self._read_yaml_mapping('model/parameter.yaml', 'Parameter')
        self.logger.info("Successfully loaded parameter config")
        return config
    except Exception as e:
        self.logger.error(f"Error loading parameter config: {str(e)}")
        raise


def _get_algorithm_info(self, algorithm_name: str) -> Dict:
    """Return the catalogue entry of ``algorithm_name`` from ``self.method_config``.

    The three ``*_algorithms`` sections are searched in a fixed order and the
    first match is returned.

    Args:
        algorithm_name: Key to look up inside the algorithm sections.

    Returns:
        The value stored under that key in the first section containing it.

    Raises:
        ValueError: If no section contains ``algorithm_name``.
    """
    for category in ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms'):
        # A section that is present but empty in YAML parses to None.
        section = self.method_config.get(category) or {}
        if algorithm_name in section:
            return section[algorithm_name]
    raise ValueError(f"Algorithm {algorithm_name} not found in model info")
def _get_model_class(self, algorithm_name: str):
    """Resolve an algorithm name to its estimator class.

    Only the module that provides the requested estimator is imported. The
    earlier version imported scikit-learn, xgboost, lightgbm and catboost on
    every call, so one missing optional package made every algorithm
    unusable, and an unknown name was only rejected after those imports.

    Args:
        algorithm_name: One of the keys of the registry below.

    Returns:
        The estimator class (not an instance).

    Raises:
        ValueError: If ``algorithm_name`` is not in the registry.
        ImportError: If the package providing the estimator is not installed.
    """
    import importlib  # local import keeps this block self-contained

    # name -> (module path, attribute name)
    registry = {
        # Classification
        'LogisticRegression': ('sklearn.linear_model', 'LogisticRegression'),
        'SVC': ('sklearn.svm', 'SVC'),
        'SVDD': ('sklearn.svm', 'OneClassSVM'),  # SVDD is served by OneClassSVM
        'DecisionTreeClassifier': ('sklearn.tree', 'DecisionTreeClassifier'),
        'RandomForestClassifier': ('sklearn.ensemble', 'RandomForestClassifier'),
        'XGBClassifier': ('xgboost', 'XGBClassifier'),
        'AdaBoostClassifier': ('sklearn.ensemble', 'AdaBoostClassifier'),
        'CatBoostClassifier': ('catboost', 'CatBoostClassifier'),
        'LGBMClassifier': ('lightgbm', 'LGBMClassifier'),
        'GaussianNB': ('sklearn.naive_bayes', 'GaussianNB'),
        'KNeighborsClassifier': ('sklearn.neighbors', 'KNeighborsClassifier'),
        'MLPClassifier': ('sklearn.neural_network', 'MLPClassifier'),
        'GradientBoostingClassifier': ('sklearn.ensemble', 'GradientBoostingClassifier'),

        # Regression
        'LinearRegression': ('sklearn.linear_model', 'LinearRegression'),
        'Ridge': ('sklearn.linear_model', 'Ridge'),
        'Lasso': ('sklearn.linear_model', 'Lasso'),
        'ElasticNet': ('sklearn.linear_model', 'ElasticNet'),
        'SVR': ('sklearn.svm', 'SVR'),
        'DecisionTreeRegressor': ('sklearn.tree', 'DecisionTreeRegressor'),
        'RandomForestRegressor': ('sklearn.ensemble', 'RandomForestRegressor'),
        # Was imported by the previous version but never mapped.
        'GradientBoostingRegressor': ('sklearn.ensemble', 'GradientBoostingRegressor'),
        'XGBRegressor': ('xgboost', 'XGBRegressor'),
        'AdaBoostRegressor': ('sklearn.ensemble', 'AdaBoostRegressor'),
        'CatBoostRegressor': ('catboost', 'CatBoostRegressor'),
        'LGBMRegressor': ('lightgbm', 'LGBMRegressor'),
        'MLPRegressor': ('sklearn.neural_network', 'MLPRegressor'),

        # Clustering
        'KMeans': ('sklearn.cluster', 'KMeans'),
        # KMeans++ is plain KMeans; the caller selects it through `init`.
        'KMeansPlusPlus': ('sklearn.cluster', 'KMeans'),
        'AgglomerativeClustering': ('sklearn.cluster', 'AgglomerativeClustering'),
        'DBSCAN': ('sklearn.cluster', 'DBSCAN'),
        'GaussianMixture': ('sklearn.mixture', 'GaussianMixture'),
        'SpectralClustering': ('sklearn.cluster', 'SpectralClustering'),
    }

    if algorithm_name not in registry:
        raise ValueError(f"Unknown algorithm: {algorithm_name}")

    module_path, class_name = registry[algorithm_name]
    return getattr(importlib.import_module(module_path), class_name)
def get_models(self) -> Dict:
    """List the configured algorithm names, grouped by task family.

    Reads the three ``*_algorithms`` sections of ``self.method_config``.
    Families that are missing or empty are left out of the result.

    Returns:
        ``{"status": "success", "models": [...]}`` where every entry has the
        keys ``name``, ``description`` and ``method`` (list of algorithm
        names), or ``{"status": "error", "error": <message>}`` on failure.
    """
    families = (
        ("classification_algorithms", "分类方法"),
        ("regression_algorithms", "回归方法"),
        ("clustering_algorithms", "聚类方法"),
    )
    try:
        models = []
        for family, description in families:
            # A section that is present but empty in YAML parses to None.
            names = list(self.method_config.get(family) or {})
            if names:
                models.append({
                    "name": family,
                    "description": description,
                    "method": names
                })

        return {
            "status": "success",
            "models": models
        }

    except Exception as e:
        self.logger.error(f"Error getting models: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }


def get_model_details(self, method_name: str) -> Dict:
    """Return principle, trade-offs and parameters of one algorithm.

    The descriptive fields come from ``self.method_config`` and the
    ``description``/``parameters`` fields from ``self.parameter_config``;
    both are searched across the three ``*_algorithms`` sections.

    Args:
        method_name: Algorithm key to look up.

    Returns:
        ``{"status": "success", "method": {...}}`` on success, otherwise
        ``{"status": "error", "error": <message>}`` (for example when the
        algorithm is absent from either configuration).
    """
    categories = ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms')

    def find_entry(config: Dict, source_label: str):
        """Return the first entry for ``method_name`` or raise ValueError."""
        entry = None
        for category in categories:
            section = config.get(category) or {}
            if method_name in section:
                entry = section[method_name]
                break
        # A key whose value is empty counts as "not found", as before.
        if entry is None:
            raise ValueError(f"Method {method_name} not found in {source_label}")
        return entry

    try:
        method_info = find_entry(self.method_config, "method config")
        parameter_info = find_entry(self.parameter_config, "parameter config")

        return {
            "status": "success",
            "method": {
                "name": method_name,
                "description": parameter_info.get('description', ''),
                "principle": method_info.get('principle', ''),
                "advantages": method_info.get('advantages', []),
                "disadvantages": method_info.get('disadvantages', []),
                "applicable_scenarios": method_info.get('applicable_scenarios', []),
                "parameters": parameter_info.get('parameters', [])
            }
        }

    except Exception as e:
        self.logger.error(f"Error getting method details: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }


def get_metrics(self) -> Dict:
    """List the configured evaluation metrics, grouped by task family.

    Reads the ``classification``, ``regression`` and ``clustering`` keys of
    ``self.method_config``. Families that are missing or empty are left out.

    Returns:
        ``{"status": "success", "metric": [...]}`` where every entry has the
        keys ``name``, ``description`` and ``metric`` (the raw configured
        value), or ``{"status": "error", "error": <message>}`` on failure.
    """
    families = (
        ("classification", "classification_metrics", "分类方法评价指标"),
        ("regression", "regression_metrics", "回归方法评价指标"),
        ("clustering", "clustering_metrics", "聚类方法评价指标"),
    )
    try:
        metrics = []
        for config_key, name, description in families:
            family_metrics = self.method_config.get(config_key, {})
            if family_metrics:
                metrics.append({
                    "name": name,
                    "description": description,
                    "metric": family_metrics
                })

        return {
            "status": "success",
            "metric": metrics
        }

    except Exception as e:
        self.logger.error(f"Error getting metrics: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }
def train_model(
    self,
    train_data: Dict,
    val_data: Dict,
    model_config: Dict,
    experiment_name: str
) -> Dict:
    """Train one model, evaluate it on the validation split and log to MLflow.

    Everything that can fail without MLflow (configuration validation,
    algorithm lookup, estimator class resolution) runs before the experiment
    is selected, so a bad request no longer leaves a failed run behind.

    Args:
        train_data: Mapping read through the keys ``'features'`` and
            ``'labels'``.
        val_data: Mapping read through the keys ``'features'`` and
            ``'labels'``.
        model_config: Mapping with the keys ``'algorithm'``, ``'task_type'``,
            ``'dataset'`` and ``'params'``. It is not modified.
        experiment_name: MLflow experiment name. If an experiment with this
            name exists in the ``deleted`` lifecycle stage, a timestamp
            suffix is appended and a new experiment is used instead.

    Returns:
        On success ``{'status': 'success', 'run_id', 'metrics',
        'algorithm_info'}``; on any failure
        ``{'status': 'error', 'message': <text>}``.
    """
    # SOURCE does not show how this module imports datetime; a local import
    # makes `datetime.now()` valid regardless of the file-level style.
    from datetime import datetime

    try:
        required = ('algorithm', 'task_type', 'dataset', 'params')
        missing = [key for key in required if key not in model_config]
        if missing:
            raise ValueError(f"model_config is missing required keys: {missing}")

        algorithm = model_config['algorithm']
        task_type = model_config['task_type']

        # Work on a copy: forcing `init` below must not leak into the
        # caller's dict.
        params = dict(model_config['params'] or {})
        if algorithm == 'KMeansPlusPlus':
            params['init'] = 'k-means++'

        algorithm_info = self._get_algorithm_info(algorithm)
        model_class = self._get_model_class(algorithm)

        # A soft-deleted experiment cannot be reused under the same name.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment and experiment.lifecycle_stage == 'deleted':
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            new_experiment_name = f"{experiment_name}_{timestamp}"
            self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
            experiment_name = new_experiment_name

        mlflow.set_experiment(experiment_name)

        with mlflow.start_run() as run:
            mlflow.log_param('algorithm', algorithm)
            mlflow.log_param('task_type', task_type)
            mlflow.log_param('dataset', model_config['dataset'])

            # NOTE(review): a hyper-parameter that is itself called
            # 'algorithm' (KMeans and KNeighbors* accept one) would collide
            # with the key logged above - confirm how MLflow should record it.
            for param_name, param_value in params.items():
                mlflow.log_param(param_name, param_value)

            mlflow.log_param('principle', algorithm_info.get('principle', ''))
            mlflow.log_param('advantages', str(algorithm_info.get('advantages', [])))
            mlflow.log_param('disadvantages', str(algorithm_info.get('disadvantages', [])))

            model = model_class(**params)

            self.logger.info(f"Starting training {algorithm}")
            model.fit(train_data['features'], train_data['labels'])

            if hasattr(model, 'predict'):
                val_predictions = model.predict(val_data['features'])
            else:
                # Some clusterers (e.g. DBSCAN, AgglomerativeClustering,
                # SpectralClustering) only offer fit_predict. Label the
                # validation split with an unfitted clone so the logged
                # model stays the one fitted on the training data.
                from sklearn.base import clone
                val_predictions = clone(model).fit_predict(val_data['features'])

            metrics = self._calculate_metrics(
                val_data['labels'],
                val_predictions,
                task_type,
                features=val_data['features']
            )

            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            mlflow.sklearn.log_model(model, "model")

            self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

            return {
                'status': 'success',
                'run_id': run.info.run_id,
                'metrics': metrics,
                'algorithm_info': algorithm_info
            }

    except Exception as e:
        error_msg = f"Error training model: {str(e)}"
        # exception() keeps the traceback in the log file.
        self.logger.exception(error_msg)
        return {
            'status': 'error',
            'message': error_msg
        }


def _calculate_metrics(
    self,
    true_labels: np.ndarray,
    predictions: np.ndarray,
    task_type: str,
    features=None
) -> Dict:
    """Compute the evaluation metrics for one task type.

    Args:
        true_labels: Ground-truth targets of the evaluated samples.
        predictions: Model output for the same samples.
        task_type: ``'classification'``, ``'regression'`` or
            ``'clustering'``; any other value yields an empty dict.
        features: Optional feature matrix of the evaluated samples. When
            given, the silhouette score is computed on it. When omitted, the
            previous behaviour is kept for compatibility: the true labels,
            reshaped to one column, stand in for the features.

    Returns:
        Mapping of metric name to value.
    """
    metrics = {}

    if task_type == 'classification':
        metrics['accuracy'] = accuracy_score(true_labels, predictions)
        metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
        metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
        metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        if len(np.unique(true_labels)) == 2:  # binary problem
            # NOTE(review): this scores hard class predictions, not
            # probabilities or decision values - confirm that is intended.
            metrics['roc_auc'] = roc_auc_score(true_labels, predictions)

    elif task_type == 'regression':
        metrics['mae'] = mean_absolute_error(true_labels, predictions)
        metrics['mse'] = mean_squared_error(true_labels, predictions)
        metrics['rmse'] = np.sqrt(metrics['mse'])
        metrics['r2'] = r2_score(true_labels, predictions)
        metrics['explained_variance'] = explained_variance_score(true_labels, predictions)

    elif task_type == 'clustering':
        metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
        metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
        metrics['completeness'] = completeness_score(true_labels, predictions)

        cluster_count = len(np.unique(predictions))
        if cluster_count > 1:  # silhouette needs at least two clusters
            if features is not None:
                # ...and fewer clusters than samples.
                if cluster_count < len(predictions):
                    metrics['silhouette'] = silhouette_score(features, predictions)
            else:
                metrics['silhouette'] = silhouette_score(
                    np.asarray(true_labels).reshape(-1, 1),
                    predictions
                )

    return metrics
Exception as e: - self.logger.error(f"Error loading metrics config: {str(e)}") - raise - - def _load_parameters(self): - """加载模型参数配置""" - try: - with open('model/parameter.yaml', 'r', encoding='utf-8') as f: - self.parameter_config = yaml.safe_load(f) - except Exception as e: - self.logger.error(f"Error loading parameter config: {str(e)}") - raise - - def _load_model_info(self): - """加载模型信息配置""" - try: - with open('model/model.yaml', 'r', encoding='utf-8') as f: - self.model_info = yaml.safe_load(f) - except Exception as e: - self.logger.error(f"Error loading model info: {str(e)}") - raise - - def _get_model_class(self, algorithm_name: str): - """获取模型类""" - # 分类算法 - from sklearn.linear_model import LogisticRegression - from sklearn.svm import SVC, OneClassSVM - from sklearn.tree import DecisionTreeClassifier - from sklearn.ensemble import ( - RandomForestClassifier, GradientBoostingClassifier, - AdaBoostClassifier, IsolationForest - ) - from sklearn.naive_bayes import GaussianNB - from sklearn.neighbors import KNeighborsClassifier - from sklearn.neural_network import MLPClassifier - import xgboost as xgb - import lightgbm as lgb - from catboost import CatBoostClassifier - - # 回归算法 - from sklearn.linear_model import ( - LinearRegression, Ridge, Lasso, - ElasticNet - ) - from sklearn.svm import SVR - from sklearn.tree import DecisionTreeRegressor - from sklearn.ensemble import ( - RandomForestRegressor, GradientBoostingRegressor, - AdaBoostRegressor - ) - from catboost import CatBoostRegressor - from sklearn.neural_network import MLPRegressor - - # 聚类算法 - from sklearn.cluster import ( - KMeans, AgglomerativeClustering, - DBSCAN, SpectralClustering - ) - from sklearn.mixture import GaussianMixture - - algorithm_map = { - # 分类算法 - 'LogisticRegression': LogisticRegression, - 'SVC': SVC, - 'SVDD': OneClassSVM, # SVDD使用OneClassSVM实现 - 'DecisionTreeClassifier': DecisionTreeClassifier, - 'RandomForestClassifier': RandomForestClassifier, - 'XGBClassifier': xgb.XGBClassifier, 
- 'AdaBoostClassifier': AdaBoostClassifier, - 'CatBoostClassifier': CatBoostClassifier, - 'LGBMClassifier': lgb.LGBMClassifier, - 'GaussianNB': GaussianNB, - 'KNeighborsClassifier': KNeighborsClassifier, - 'MLPClassifier': MLPClassifier, - 'GradientBoostingClassifier': GradientBoostingClassifier, - - # 回归算法 - 'LinearRegression': LinearRegression, - 'Ridge': Ridge, - 'Lasso': Lasso, - 'ElasticNet': ElasticNet, - 'SVR': SVR, - 'DecisionTreeRegressor': DecisionTreeRegressor, - 'RandomForestRegressor': RandomForestRegressor, - 'XGBRegressor': xgb.XGBRegressor, - 'AdaBoostRegressor': AdaBoostRegressor, - 'CatBoostRegressor': CatBoostRegressor, - 'LGBMRegressor': lgb.LGBMRegressor, - 'MLPRegressor': MLPRegressor, - - # 聚类算法 - 'KMeans': KMeans, - 'KMeansPlusPlus': KMeans, # KMeans++使用KMeans实现,通过init参数控制 - 'AgglomerativeClustering': AgglomerativeClustering, - 'DBSCAN': DBSCAN, - 'GaussianMixture': GaussianMixture, - 'SpectralClustering': SpectralClustering - } - - if algorithm_name not in algorithm_map: - raise ValueError(f"Unknown algorithm: {algorithm_name}") - - return algorithm_map[algorithm_name] - - def _get_algorithm_info(self, algorithm_name: str) -> Dict: - """获取算法信息""" - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if algorithm_name in self.model_info.get(category, {}): - return self.model_info[category][algorithm_name] - raise ValueError(f"Algorithm {algorithm_name} not found in model info") - - def train_model( - self, - train_data: Dict, - val_data: Dict, - model_config: Dict, - experiment_name: str - ) -> Dict: - """ - 训练模型 - - Args: - train_data: 训练数据,包含特征和标签 - val_data: 验证数据,包含特征和标签 - model_config: 模型配置,包含算法名称和参数 - experiment_name: MLflow实验名称 - - Returns: - 训练结果字典 - """ - try: - # 检查实验是否存在且被删除 - experiment = mlflow.get_experiment_by_name(experiment_name) - if experiment and experiment.lifecycle_stage == 'deleted': - # 如果实验被删除,则创建一个新的实验名称 - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - 
new_experiment_name = f"{experiment_name}_{timestamp}" - self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}") - experiment_name = new_experiment_name - - # 设置MLflow实验 - mlflow.set_experiment(experiment_name) - - with mlflow.start_run() as run: - # 记录基本信息 - mlflow.log_param('algorithm', model_config['algorithm']) - mlflow.log_param('task_type', model_config['task_type']) - # mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称 - mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径 - - # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - # mlflow.log_param('start_time', timestamp) - - # 记录模型参数 - for param_name, param_value in model_config['params'].items(): - mlflow.log_param(param_name, param_value) - - # 记录算法信息 - algorithm_info = self._get_algorithm_info(model_config['algorithm']) - mlflow.log_param('principle', algorithm_info['principle']) - mlflow.log_param('advantages', str(algorithm_info['advantages'])) - mlflow.log_param('disadvantages', str(algorithm_info['disadvantages'])) - - # 特殊处理KMeans++ - if model_config['algorithm'] == 'KMeansPlusPlus': - model_config['params']['init'] = 'k-means++' - - # 获取模型类和信息 - model_class = self._get_model_class(model_config['algorithm']) - - # 创建模型实例 - model = model_class(**model_config['params']) - - # 训练模型 - self.logger.info(f"Starting training {model_config['algorithm']}") - model.fit(train_data['features'], train_data['labels']) - - # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - # mlflow.log_param('end_time', timestamp) - - # 在验证集上评估 - val_predictions = model.predict(val_data['features']) - metrics = self._calculate_metrics( - val_data['labels'], - val_predictions, - model_config['task_type'] - ) - - # 记录指标 - for metric_name, metric_value in metrics.items(): - mlflow.log_metric(metric_name, metric_value) - - # 保存模型 - mlflow.sklearn.log_model(model, "model") - - self.logger.info(f"Training completed. 
Run ID: {run.info.run_id}") - - return { - 'status': 'success', - 'run_id': run.info.run_id, - 'metrics': metrics, - 'algorithm_info': algorithm_info - } - - except Exception as e: - error_msg = f"Error training model: {str(e)}" - self.logger.error(error_msg) - return { - 'status': 'error', - 'message': error_msg - } - - def _calculate_metrics( - self, - true_labels: np.ndarray, - predictions: np.ndarray, - task_type: str - ) -> Dict: - """计算评估指标""" - metrics = {} - - if task_type == 'classification': - metrics['accuracy'] = accuracy_score(true_labels, predictions) - metrics['precision'] = precision_score(true_labels, predictions, average='weighted') - metrics['recall'] = recall_score(true_labels, predictions, average='weighted') - metrics['f1'] = f1_score(true_labels, predictions, average='weighted') - if len(np.unique(true_labels)) == 2: # 二分类问题 - metrics['roc_auc'] = roc_auc_score(true_labels, predictions) - - elif task_type == 'regression': - metrics['mae'] = mean_absolute_error(true_labels, predictions) - metrics['mse'] = mean_squared_error(true_labels, predictions) - metrics['rmse'] = np.sqrt(metrics['mse']) - metrics['r2'] = r2_score(true_labels, predictions) - metrics['explained_variance'] = explained_variance_score(true_labels, predictions) - - elif task_type == 'clustering': - metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions) - metrics['homogeneity'] = homogeneity_score(true_labels, predictions) - metrics['completeness'] = completeness_score(true_labels, predictions) - if len(np.unique(predictions)) > 1: # 确保有多个簇 - metrics['silhouette'] = silhouette_score( - true_labels.reshape(-1, 1), - predictions - ) - - return metrics \ No newline at end of file