修改--将模型管理所有的方法都集成到一个文件中
This commit is contained in:
parent
d222cfe968
commit
c2b2e20a4a
@ -1,2 +1,37 @@
|
|||||||
mlfow:
|
# 服务器配置
|
||||||
uri: "http://localhost:5000"
|
host: "10.0.0.202"
|
||||||
|
port: 8992
|
||||||
|
workers: 4
|
||||||
|
debug: true
|
||||||
|
|
||||||
|
# MLflow配置
|
||||||
|
mlflow_uri: "http://10.0.0.202:5000"
|
||||||
|
|
||||||
|
# 数据处理配置
|
||||||
|
dataset:
|
||||||
|
raw_dir: "dataset/dataset_raw"
|
||||||
|
processed_dir: "dataset/dataset_processed"
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
model:
|
||||||
|
save_dir: "models"
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# 系统监控配置
|
||||||
|
monitor:
|
||||||
|
log_dir: ".log"
|
||||||
|
resource_check_interval: 60 # 秒
|
||||||
|
cleanup_interval: 86400 # 24小时
|
||||||
|
max_log_days: 30 # 日志保留天数
|
||||||
|
|
||||||
|
# 安全配置
|
||||||
|
security:
|
||||||
|
secret_key: "your-secret-key"
|
||||||
|
token_expire_minutes: 1440 # 24小时
|
||||||
|
|
||||||
|
# 性能配置
|
||||||
|
performance:
|
||||||
|
max_concurrent_trains: 4 # 最大并发训练数
|
||||||
|
cache_size: 1024 # MB
|
||||||
|
timeout: 3600 # 秒
|
||||||
@ -523,6 +523,9 @@ Error Response:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 2.9 模型优化 -- 未实现
|
||||||
|
|
||||||
|
|
||||||
## 3. 系统监控
|
## 3. 系统监控
|
||||||
### 3.1 获取资源使用情况
|
### 3.1 获取资源使用情况
|
||||||
```http
|
```http
|
||||||
@ -681,7 +684,7 @@ Error Response:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.3 获取系统中训练状态
|
### 3.3 获取系统中训练状态 ---- 未完成, 等开发系统后台时再实现.
|
||||||
```http
|
```http
|
||||||
GET /api/train/status/{task_id}
|
GET /api/train/status/{task_id}
|
||||||
|
|
||||||
@ -764,6 +767,76 @@ Error Response:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 4. 系统后台整体实现
|
||||||
|
|
||||||
|
### 4.1 系统架构
|
||||||
|
```
|
||||||
|
MLPlatform/
|
||||||
|
├── api/ # API接口层
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── data_api.py # 数据处理相关接口
|
||||||
|
│ ├── model_api.py # 模型相关接口
|
||||||
|
│ └── system_api.py # 系统监控相关接口
|
||||||
|
├── function/ # 功能实现层
|
||||||
|
│ ├── data_processor.py # 数据处理类
|
||||||
|
│ ├── model_manager.py # 模型管理类
|
||||||
|
│ ├── model_trainer.py # 模型训练类
|
||||||
|
│ ├── system_monitor.py # 系统监控类
|
||||||
|
│ └── utils/ # 工具函数
|
||||||
|
├── config/ # 配置文件
|
||||||
|
│ └── config.yaml # 系统配置
|
||||||
|
├── dataset/ # 数据集
|
||||||
|
│ ├── dataset_raw/ # 原始数据
|
||||||
|
│ └── dataset_processed/ # 处理后数据
|
||||||
|
├── .log/ # 日志文件
|
||||||
|
├── doc/ # 文档
|
||||||
|
└── main.py # 主程序入口
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 技术栈
|
||||||
|
- FastAPI: Web框架
|
||||||
|
- MLflow: 模型管理和实验跟踪
|
||||||
|
- PyTorch/Scikit-learn: 机器学习框架
|
||||||
|
- Pydantic: 数据验证
|
||||||
|
- Uvicorn: ASGI服务器
|
||||||
|
|
||||||
|
### 4.3 主要功能
|
||||||
|
1. 异步任务处理
|
||||||
|
- 支持多个模型同时训练
|
||||||
|
- 后台任务状态监控
|
||||||
|
- 任务队列管理
|
||||||
|
|
||||||
|
2. 实时监控
|
||||||
|
- 系统资源监控
|
||||||
|
- 训练进度监控
|
||||||
|
- 日志实时查看
|
||||||
|
|
||||||
|
3. 错误处理
|
||||||
|
- 全局异常处理
|
||||||
|
- 错误日志记录
|
||||||
|
- 优雅降级策略
|
||||||
|
|
||||||
|
4. 安全性
|
||||||
|
- API认证授权
|
||||||
|
- 请求限流
|
||||||
|
- 参数验证
|
||||||
|
|
||||||
|
### 4.4 性能优化
|
||||||
|
1. 数据处理
|
||||||
|
- 数据流式处理
|
||||||
|
- 缓存机制
|
||||||
|
- 批量处理
|
||||||
|
|
||||||
|
2. 模型训练
|
||||||
|
- GPU利用优化
|
||||||
|
- 分布式训练支持
|
||||||
|
- 模型检查点
|
||||||
|
|
||||||
|
3. 系统监控
|
||||||
|
- 性能指标采集
|
||||||
|
- 资源使用预警
|
||||||
|
- 自动清理机制
|
||||||
|
|
||||||
## 附录A:方法详细说明
|
## 附录A:方法详细说明
|
||||||
|
|
||||||
### A1. 数据预处理方法
|
### A1. 数据预处理方法
|
||||||
|
|||||||
Binary file not shown.
@ -1,79 +0,0 @@
|
|||||||
import yaml
|
|
||||||
from typing import Dict, List
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
class MethodReader:
|
|
||||||
"""方法配置读取器"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化方法读取器"""
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.method_config = self._load_metrics()
|
|
||||||
|
|
||||||
|
|
||||||
def _load_metrics(self) -> Dict:
|
|
||||||
"""加载方法配置文件"""
|
|
||||||
try:
|
|
||||||
config_path = Path('model/metrics.yaml')
|
|
||||||
if not config_path.exists():
|
|
||||||
raise FileNotFoundError(f"Method config file not found at {config_path}")
|
|
||||||
|
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
|
||||||
config = yaml.safe_load(f)
|
|
||||||
|
|
||||||
self.logger.info("Successfully loaded method config")
|
|
||||||
return config
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error loading method config: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_metrics(self) -> Dict:
|
|
||||||
"""获取预处理方法列表"""
|
|
||||||
try:
|
|
||||||
metrics = []
|
|
||||||
|
|
||||||
# 分类方法
|
|
||||||
classification_metrics = self.method_config.get('classification', {})
|
|
||||||
if classification_metrics:
|
|
||||||
metrics.append({
|
|
||||||
"name": "classification_metrics",
|
|
||||||
"description": "分类方法评价指标",
|
|
||||||
"metric": classification_metrics
|
|
||||||
})
|
|
||||||
|
|
||||||
# 回归方法
|
|
||||||
regression_metrics = self.method_config.get('regression', {})
|
|
||||||
if regression_metrics:
|
|
||||||
metrics.append({
|
|
||||||
"name": "regression_metrics",
|
|
||||||
"description": "回归方法评价指标",
|
|
||||||
"metric": regression_metrics
|
|
||||||
})
|
|
||||||
|
|
||||||
# 聚类方法
|
|
||||||
clustering_metrics = self.method_config.get('clustering', {})
|
|
||||||
if clustering_metrics:
|
|
||||||
metrics.append({
|
|
||||||
"name": "clustering_metrics",
|
|
||||||
"description": "聚类方法评价指标",
|
|
||||||
"metric": clustering_metrics
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"metric": metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error getting preprocessing methods: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,135 +0,0 @@
|
|||||||
import yaml
|
|
||||||
from typing import Dict, List
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
class MethodReader:
|
|
||||||
"""方法配置读取器"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化方法读取器"""
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
self.method_config = self._load_model_config()
|
|
||||||
self.parameter_config = self._load_parameter_config()
|
|
||||||
|
|
||||||
def _load_model_config(self) -> Dict:
|
|
||||||
"""加载方法配置文件"""
|
|
||||||
try:
|
|
||||||
config_path = Path('model/model.yaml')
|
|
||||||
if not config_path.exists():
|
|
||||||
raise FileNotFoundError(f"Method config file not found at {config_path}")
|
|
||||||
|
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
|
||||||
config = yaml.safe_load(f)
|
|
||||||
|
|
||||||
self.logger.info("Successfully loaded method config")
|
|
||||||
return config
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error loading method config: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _load_parameter_config(self) -> Dict:
|
|
||||||
"""加载参数配置文件"""
|
|
||||||
try:
|
|
||||||
config_path = Path('model/parameter.yaml')
|
|
||||||
if not config_path.exists():
|
|
||||||
raise FileNotFoundError(f"Parameter config file not found at {config_path}")
|
|
||||||
|
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
|
||||||
config = yaml.safe_load(f)
|
|
||||||
|
|
||||||
self.logger.info("Successfully loaded parameter config")
|
|
||||||
return config
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error loading parameter config: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_models(self) -> Dict:
|
|
||||||
"""获取预处理方法列表"""
|
|
||||||
try:
|
|
||||||
models = []
|
|
||||||
|
|
||||||
# 分类方法
|
|
||||||
classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys())
|
|
||||||
if classification_algorithms:
|
|
||||||
models.append({
|
|
||||||
"name": "classification_algorithms",
|
|
||||||
"description": "分类方法",
|
|
||||||
"method": classification_algorithms
|
|
||||||
})
|
|
||||||
|
|
||||||
# 回归方法
|
|
||||||
regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys())
|
|
||||||
if regression_algorithms:
|
|
||||||
models.append({
|
|
||||||
"name": "regression_algorithms",
|
|
||||||
"description": "回归方法",
|
|
||||||
"method": regression_algorithms
|
|
||||||
})
|
|
||||||
|
|
||||||
# 聚类方法
|
|
||||||
clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys())
|
|
||||||
if clustering_algorithms:
|
|
||||||
models.append({
|
|
||||||
"name": "clustering_algorithms",
|
|
||||||
"description": "聚类方法",
|
|
||||||
"method": clustering_algorithms
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"models": models
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error getting preprocessing methods: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_model_details(self, method_name: str) -> Dict:
|
|
||||||
"""获取指定方法的详细信息"""
|
|
||||||
try:
|
|
||||||
# 在各个方法类别中查找方法原理和优缺点
|
|
||||||
method_info = None
|
|
||||||
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
|
||||||
if method_name in self.method_config.get(category, {}):
|
|
||||||
method_info = self.method_config[category][method_name]
|
|
||||||
break
|
|
||||||
|
|
||||||
if method_info is None:
|
|
||||||
raise ValueError(f"Method {method_name} not found in method config")
|
|
||||||
|
|
||||||
# 查找方法参数信息
|
|
||||||
parameter_info = None
|
|
||||||
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
|
||||||
if method_name in self.parameter_config.get(category, {}):
|
|
||||||
parameter_info = self.parameter_config[category][method_name]
|
|
||||||
break
|
|
||||||
|
|
||||||
if parameter_info is None:
|
|
||||||
raise ValueError(f"Method {method_name} not found in parameter config")
|
|
||||||
|
|
||||||
# 组合返回信息
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"method": {
|
|
||||||
"name": method_name,
|
|
||||||
"description": parameter_info.get('description', ''),
|
|
||||||
"principle": method_info.get('principle', ''),
|
|
||||||
"advantages": method_info.get('advantages', []),
|
|
||||||
"disadvantages": method_info.get('disadvantages', []),
|
|
||||||
"applicable_scenarios": method_info.get('applicable_scenarios', []),
|
|
||||||
"parameters": parameter_info.get('parameters', [])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error getting method details: {str(e)}")
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
@ -18,6 +18,11 @@ from sklearn.metrics import (
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
模型管理整体集成
|
||||||
|
'''
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
"""模型管理类"""
|
"""模型管理类"""
|
||||||
|
|
||||||
@ -27,12 +32,33 @@ class ModelManager:
|
|||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self._setup_logging()
|
self._setup_logging()
|
||||||
self._metrics_map()
|
self._metrics_map()
|
||||||
|
self.method_config = self._load_metrics()
|
||||||
|
|
||||||
|
self.method_config = self._load_model_config()
|
||||||
|
self.parameter_config = self._load_parameter_config()
|
||||||
|
|
||||||
# 初始化MLflow客户端
|
# 初始化MLflow客户端
|
||||||
self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
|
self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
|
||||||
mlflow.set_tracking_uri(self.mlflow_uri)
|
mlflow.set_tracking_uri(self.mlflow_uri)
|
||||||
self.client = MlflowClient()
|
self.client = MlflowClient()
|
||||||
|
|
||||||
|
def _load_metrics(self) -> Dict:
|
||||||
|
"""加载方法配置文件"""
|
||||||
|
try:
|
||||||
|
config_path = Path('model/metrics.yaml')
|
||||||
|
if not config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Method config file not found at {config_path}")
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self.logger.info("Successfully loaded method config")
|
||||||
|
return config
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error loading method config: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def _setup_logging(self):
|
def _setup_logging(self):
|
||||||
"""设置日志"""
|
"""设置日志"""
|
||||||
log_dir = Path('.log')
|
log_dir = Path('.log')
|
||||||
@ -47,6 +73,54 @@ class ModelManager:
|
|||||||
self.logger.addHandler(file_handler)
|
self.logger.addHandler(file_handler)
|
||||||
self.logger.setLevel(logging.INFO)
|
self.logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
def _load_model_config(self) -> Dict:
|
||||||
|
"""加载方法配置文件"""
|
||||||
|
try:
|
||||||
|
config_path = Path('model/model.yaml')
|
||||||
|
if not config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Method config file not found at {config_path}")
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config_model = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self.logger.info("Successfully loaded model config")
|
||||||
|
|
||||||
|
|
||||||
|
config_path = Path('model/metrics.yaml')
|
||||||
|
if not config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Metrics config file not found at {config_path}")
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config_metric = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self.logger.info("Successfully loaded metrics config")
|
||||||
|
|
||||||
|
|
||||||
|
config = {**config_model, **config_metric}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error loading method or metric config: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _load_parameter_config(self) -> Dict:
|
||||||
|
"""加载参数配置文件"""
|
||||||
|
try:
|
||||||
|
config_path = Path('model/parameter.yaml')
|
||||||
|
if not config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Parameter config file not found at {config_path}")
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self.logger.info("Successfully loaded parameter config")
|
||||||
|
return config
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error loading parameter config: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _metrics_map(self):
|
def _metrics_map(self):
|
||||||
self.metrics_map={
|
self.metrics_map={
|
||||||
'accuracy' : accuracy_score,
|
'accuracy' : accuracy_score,
|
||||||
@ -63,7 +137,99 @@ class ModelManager:
|
|||||||
'completeness': completeness_score,
|
'completeness': completeness_score,
|
||||||
'silhouette' : silhouette_score
|
'silhouette' : silhouette_score
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_algorithm_info(self, algorithm_name: str) -> Dict:
|
||||||
|
"""获取算法信息"""
|
||||||
|
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||||||
|
if algorithm_name in self.method_config.get(category, {}):
|
||||||
|
return self.method_config[category][algorithm_name]
|
||||||
|
raise ValueError(f"Algorithm {algorithm_name} not found in model info")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_class(self, algorithm_name: str):
|
||||||
|
"""获取模型类"""
|
||||||
|
# 分类算法
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.svm import SVC, OneClassSVM
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn.ensemble import (
|
||||||
|
RandomForestClassifier, GradientBoostingClassifier,
|
||||||
|
AdaBoostClassifier, IsolationForest
|
||||||
|
)
|
||||||
|
from sklearn.naive_bayes import GaussianNB
|
||||||
|
from sklearn.neighbors import KNeighborsClassifier
|
||||||
|
from sklearn.neural_network import MLPClassifier
|
||||||
|
import xgboost as xgb
|
||||||
|
import lightgbm as lgb
|
||||||
|
from catboost import CatBoostClassifier
|
||||||
|
|
||||||
|
# 回归算法
|
||||||
|
from sklearn.linear_model import (
|
||||||
|
LinearRegression, Ridge, Lasso,
|
||||||
|
ElasticNet
|
||||||
|
)
|
||||||
|
from sklearn.svm import SVR
|
||||||
|
from sklearn.tree import DecisionTreeRegressor
|
||||||
|
from sklearn.ensemble import (
|
||||||
|
RandomForestRegressor, GradientBoostingRegressor,
|
||||||
|
AdaBoostRegressor
|
||||||
|
)
|
||||||
|
from catboost import CatBoostRegressor
|
||||||
|
from sklearn.neural_network import MLPRegressor
|
||||||
|
|
||||||
|
# 聚类算法
|
||||||
|
from sklearn.cluster import (
|
||||||
|
KMeans, AgglomerativeClustering,
|
||||||
|
DBSCAN, SpectralClustering
|
||||||
|
)
|
||||||
|
from sklearn.mixture import GaussianMixture
|
||||||
|
|
||||||
|
algorithm_map = {
|
||||||
|
# 分类算法
|
||||||
|
'LogisticRegression': LogisticRegression,
|
||||||
|
'SVC': SVC,
|
||||||
|
'SVDD': OneClassSVM, # SVDD使用OneClassSVM实现
|
||||||
|
'DecisionTreeClassifier': DecisionTreeClassifier,
|
||||||
|
'RandomForestClassifier': RandomForestClassifier,
|
||||||
|
'XGBClassifier': xgb.XGBClassifier,
|
||||||
|
'AdaBoostClassifier': AdaBoostClassifier,
|
||||||
|
'CatBoostClassifier': CatBoostClassifier,
|
||||||
|
'LGBMClassifier': lgb.LGBMClassifier,
|
||||||
|
'GaussianNB': GaussianNB,
|
||||||
|
'KNeighborsClassifier': KNeighborsClassifier,
|
||||||
|
'MLPClassifier': MLPClassifier,
|
||||||
|
'GradientBoostingClassifier': GradientBoostingClassifier,
|
||||||
|
|
||||||
|
# 回归算法
|
||||||
|
'LinearRegression': LinearRegression,
|
||||||
|
'Ridge': Ridge,
|
||||||
|
'Lasso': Lasso,
|
||||||
|
'ElasticNet': ElasticNet,
|
||||||
|
'SVR': SVR,
|
||||||
|
'DecisionTreeRegressor': DecisionTreeRegressor,
|
||||||
|
'RandomForestRegressor': RandomForestRegressor,
|
||||||
|
'XGBRegressor': xgb.XGBRegressor,
|
||||||
|
'AdaBoostRegressor': AdaBoostRegressor,
|
||||||
|
'CatBoostRegressor': CatBoostRegressor,
|
||||||
|
'LGBMRegressor': lgb.LGBMRegressor,
|
||||||
|
'MLPRegressor': MLPRegressor,
|
||||||
|
|
||||||
|
# 聚类算法
|
||||||
|
'KMeans': KMeans,
|
||||||
|
'KMeansPlusPlus': KMeans, # KMeans++使用KMeans实现,通过init参数控制
|
||||||
|
'AgglomerativeClustering': AgglomerativeClustering,
|
||||||
|
'DBSCAN': DBSCAN,
|
||||||
|
'GaussianMixture': GaussianMixture,
|
||||||
|
'SpectralClustering': SpectralClustering
|
||||||
|
}
|
||||||
|
|
||||||
|
if algorithm_name not in algorithm_map:
|
||||||
|
raise ValueError(f"Unknown algorithm: {algorithm_name}")
|
||||||
|
|
||||||
|
return algorithm_map[algorithm_name]
|
||||||
|
|
||||||
|
|
||||||
def get_finished_models(
|
def get_finished_models(
|
||||||
self,
|
self,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
@ -440,7 +606,96 @@ class ModelManager:
|
|||||||
'execution_time': f"{execution_time:.2f}s"
|
'execution_time': f"{execution_time:.2f}s"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_models(self) -> Dict:
|
||||||
|
"""获取预处理方法列表"""
|
||||||
|
try:
|
||||||
|
models = []
|
||||||
|
|
||||||
|
# 分类方法
|
||||||
|
classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys())
|
||||||
|
if classification_algorithms:
|
||||||
|
models.append({
|
||||||
|
"name": "classification_algorithms",
|
||||||
|
"description": "分类方法",
|
||||||
|
"method": classification_algorithms
|
||||||
|
})
|
||||||
|
|
||||||
|
# 回归方法
|
||||||
|
regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys())
|
||||||
|
if regression_algorithms:
|
||||||
|
models.append({
|
||||||
|
"name": "regression_algorithms",
|
||||||
|
"description": "回归方法",
|
||||||
|
"method": regression_algorithms
|
||||||
|
})
|
||||||
|
|
||||||
|
# 聚类方法
|
||||||
|
clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys())
|
||||||
|
if clustering_algorithms:
|
||||||
|
models.append({
|
||||||
|
"name": "clustering_algorithms",
|
||||||
|
"description": "聚类方法",
|
||||||
|
"method": clustering_algorithms
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"models": models
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting preprocessing methods: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model_details(self, method_name: str) -> Dict:
|
||||||
|
"""获取指定方法的详细信息"""
|
||||||
|
try:
|
||||||
|
# 在各个方法类别中查找方法原理和优缺点
|
||||||
|
method_info = None
|
||||||
|
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||||||
|
if method_name in self.method_config.get(category, {}):
|
||||||
|
method_info = self.method_config[category][method_name]
|
||||||
|
break
|
||||||
|
|
||||||
|
if method_info is None:
|
||||||
|
raise ValueError(f"Method {method_name} not found in method config")
|
||||||
|
|
||||||
|
# 查找方法参数信息
|
||||||
|
parameter_info = None
|
||||||
|
for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
|
||||||
|
if method_name in self.parameter_config.get(category, {}):
|
||||||
|
parameter_info = self.parameter_config[category][method_name]
|
||||||
|
break
|
||||||
|
|
||||||
|
if parameter_info is None:
|
||||||
|
raise ValueError(f"Method {method_name} not found in parameter config")
|
||||||
|
|
||||||
|
# 组合返回信息
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"method": {
|
||||||
|
"name": method_name,
|
||||||
|
"description": parameter_info.get('description', ''),
|
||||||
|
"principle": method_info.get('principle', ''),
|
||||||
|
"advantages": method_info.get('advantages', []),
|
||||||
|
"disadvantages": method_info.get('disadvantages', []),
|
||||||
|
"applicable_scenarios": method_info.get('applicable_scenarios', []),
|
||||||
|
"parameters": parameter_info.get('parameters', [])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting method details: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# error_msg = f"预测过程发生错误: {str(e)}"
|
# error_msg = f"预测过程发生错误: {str(e)}"
|
||||||
# self.logger.error(error_msg)
|
# self.logger.error(error_msg)
|
||||||
@ -451,4 +706,184 @@ class ModelManager:
|
|||||||
# 'error_type': type(e).__name__,
|
# 'error_type': type(e).__name__,
|
||||||
# 'error_message': str(e)
|
# 'error_message': str(e)
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
|
|
||||||
|
def get_metrics(self) -> Dict:
|
||||||
|
"""获取预处理方法列表"""
|
||||||
|
try:
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
# 分类方法
|
||||||
|
classification_metrics = self.method_config.get('classification', {})
|
||||||
|
if classification_metrics:
|
||||||
|
metrics.append({
|
||||||
|
"name": "classification_metrics",
|
||||||
|
"description": "分类方法评价指标",
|
||||||
|
"metric": classification_metrics
|
||||||
|
})
|
||||||
|
|
||||||
|
# 回归方法
|
||||||
|
regression_metrics = self.method_config.get('regression', {})
|
||||||
|
if regression_metrics:
|
||||||
|
metrics.append({
|
||||||
|
"name": "regression_metrics",
|
||||||
|
"description": "回归方法评价指标",
|
||||||
|
"metric": regression_metrics
|
||||||
|
})
|
||||||
|
|
||||||
|
# 聚类方法
|
||||||
|
clustering_metrics = self.method_config.get('clustering', {})
|
||||||
|
if clustering_metrics:
|
||||||
|
metrics.append({
|
||||||
|
"name": "clustering_metrics",
|
||||||
|
"description": "聚类方法评价指标",
|
||||||
|
"metric": clustering_metrics
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"metric": metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting preprocessing methods: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
def train_model(
|
||||||
|
self,
|
||||||
|
train_data: Dict,
|
||||||
|
val_data: Dict,
|
||||||
|
model_config: Dict,
|
||||||
|
experiment_name: str
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
训练模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: 训练数据,包含特征和标签
|
||||||
|
val_data: 验证数据,包含特征和标签
|
||||||
|
model_config: 模型配置,包含算法名称和参数
|
||||||
|
experiment_name: MLflow实验名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
训练结果字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查实验是否存在且被删除
|
||||||
|
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||||||
|
if experiment and experiment.lifecycle_stage == 'deleted':
|
||||||
|
# 如果实验被删除,则创建一个新的实验名称
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
new_experiment_name = f"{experiment_name}_{timestamp}"
|
||||||
|
self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
|
||||||
|
experiment_name = new_experiment_name
|
||||||
|
|
||||||
|
# 设置MLflow实验
|
||||||
|
mlflow.set_experiment(experiment_name)
|
||||||
|
|
||||||
|
with mlflow.start_run() as run:
|
||||||
|
# 记录基本信息
|
||||||
|
mlflow.log_param('algorithm', model_config['algorithm'])
|
||||||
|
mlflow.log_param('task_type', model_config['task_type'])
|
||||||
|
# mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称
|
||||||
|
mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径
|
||||||
|
|
||||||
|
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
# mlflow.log_param('start_time', timestamp)
|
||||||
|
|
||||||
|
# 记录模型参数
|
||||||
|
for param_name, param_value in model_config['params'].items():
|
||||||
|
mlflow.log_param(param_name, param_value)
|
||||||
|
|
||||||
|
# 记录算法信息
|
||||||
|
algorithm_info = self._get_algorithm_info(model_config['algorithm'])
|
||||||
|
mlflow.log_param('principle', algorithm_info['principle'])
|
||||||
|
mlflow.log_param('advantages', str(algorithm_info['advantages']))
|
||||||
|
mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))
|
||||||
|
|
||||||
|
# 特殊处理KMeans++
|
||||||
|
if model_config['algorithm'] == 'KMeansPlusPlus':
|
||||||
|
model_config['params']['init'] = 'k-means++'
|
||||||
|
|
||||||
|
# 获取模型类和信息
|
||||||
|
model_class = self._get_model_class(model_config['algorithm'])
|
||||||
|
|
||||||
|
# 创建模型实例
|
||||||
|
model = model_class(**model_config['params'])
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
self.logger.info(f"Starting training {model_config['algorithm']}")
|
||||||
|
model.fit(train_data['features'], train_data['labels'])
|
||||||
|
|
||||||
|
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
# mlflow.log_param('end_time', timestamp)
|
||||||
|
|
||||||
|
# 在验证集上评估
|
||||||
|
val_predictions = model.predict(val_data['features'])
|
||||||
|
metrics = self._calculate_metrics(
|
||||||
|
val_data['labels'],
|
||||||
|
val_predictions,
|
||||||
|
model_config['task_type']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录指标
|
||||||
|
for metric_name, metric_value in metrics.items():
|
||||||
|
mlflow.log_metric(metric_name, metric_value)
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
mlflow.sklearn.log_model(model, "model")
|
||||||
|
|
||||||
|
self.logger.info(f"Training completed. Run ID: {run.info.run_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'run_id': run.info.run_id,
|
||||||
|
'metrics': metrics,
|
||||||
|
'algorithm_info': algorithm_info
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error training model: {str(e)}"
|
||||||
|
self.logger.error(error_msg)
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': error_msg
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calculate_metrics(
|
||||||
|
self,
|
||||||
|
true_labels: np.ndarray,
|
||||||
|
predictions: np.ndarray,
|
||||||
|
task_type: str
|
||||||
|
) -> Dict:
|
||||||
|
"""计算评估指标"""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
if task_type == 'classification':
|
||||||
|
metrics['accuracy'] = accuracy_score(true_labels, predictions)
|
||||||
|
metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
|
||||||
|
metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
|
||||||
|
metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
|
||||||
|
if len(np.unique(true_labels)) == 2: # 二分类问题
|
||||||
|
metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
|
||||||
|
|
||||||
|
elif task_type == 'regression':
|
||||||
|
metrics['mae'] = mean_absolute_error(true_labels, predictions)
|
||||||
|
metrics['mse'] = mean_squared_error(true_labels, predictions)
|
||||||
|
metrics['rmse'] = np.sqrt(metrics['mse'])
|
||||||
|
metrics['r2'] = r2_score(true_labels, predictions)
|
||||||
|
metrics['explained_variance'] = explained_variance_score(true_labels, predictions)
|
||||||
|
|
||||||
|
elif task_type == 'clustering':
|
||||||
|
metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
|
||||||
|
metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
|
||||||
|
metrics['completeness'] = completeness_score(true_labels, predictions)
|
||||||
|
if len(np.unique(predictions)) > 1: # 确保有多个簇
|
||||||
|
metrics['silhouette'] = silhouette_score(
|
||||||
|
true_labels.reshape(-1, 1),
|
||||||
|
predictions
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
@ -1,298 +0,0 @@
|
|||||||
import numpy as pd
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
import datetime
|
|
||||||
import yaml
|
|
||||||
import mlflow
|
|
||||||
from sklearn.metrics import (
|
|
||||||
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
|
|
||||||
mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
|
|
||||||
adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
|
|
||||||
)
|
|
||||||
|
|
||||||
class ModelTrainer:
    """Train, evaluate and register machine-learning models through MLflow.

    Metric definitions, hyper-parameter schemas and algorithm metadata are
    loaded from the YAML files under ``model/`` at construction time.
    """

    def __init__(self, config: Dict = None):
        """Initialize the model trainer.

        Args:
            config: Optional settings dict; ``mlflow_uri`` is read from it
                (falls back to ``http://10.0.0.202:5000``).

        Raises:
            Exception: propagated if any of the ``model/*.yaml`` config
                files cannot be read or parsed.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self._load_metrics()
        self._load_parameters()
        self._load_model_info()

        # Point MLflow at the configured tracking server.
        mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000'))

    def _setup_logging(self):
        """Attach a timestamped file handler writing under ``.log/``."""
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)

        # Fix: the module-level logger is shared, so instantiating several
        # trainers used to stack one FileHandler per instance and duplicate
        # every log record. Only attach a handler the first time.
        if self.logger.handlers:
            return

        file_handler = logging.FileHandler(
            log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def _read_config_file(self, path: str, label: str) -> Dict:
        """Load one YAML config file; log an error tagged *label* and re-raise."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            self.logger.error(f"Error loading {label}: {str(e)}")
            raise

    def _load_metrics(self):
        """Load evaluation-metric definitions from ``model/metrics.yaml``."""
        self.metrics_config = self._read_config_file('model/metrics.yaml', 'metrics config')

    def _load_parameters(self):
        """Load hyper-parameter schemas from ``model/parameter.yaml``."""
        self.parameter_config = self._read_config_file('model/parameter.yaml', 'parameter config')

    def _load_model_info(self):
        """Load algorithm metadata from ``model/model.yaml``."""
        self.model_info = self._read_config_file('model/model.yaml', 'model info')

    def _get_model_class(self, algorithm_name: str):
        """Resolve *algorithm_name* to its estimator class.

        Imports are kept local so the optional heavyweight backends
        (xgboost, lightgbm, catboost) are only loaded when a model is
        actually requested.

        Raises:
            ValueError: for an unsupported algorithm name.
        """
        # Classification estimators
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC, OneClassSVM
        from sklearn.tree import DecisionTreeClassifier
        from sklearn.ensemble import (
            RandomForestClassifier, GradientBoostingClassifier,
            AdaBoostClassifier
        )
        from sklearn.naive_bayes import GaussianNB
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.neural_network import MLPClassifier
        import xgboost as xgb
        import lightgbm as lgb
        from catboost import CatBoostClassifier

        # Regression estimators
        from sklearn.linear_model import (
            LinearRegression, Ridge, Lasso,
            ElasticNet
        )
        from sklearn.svm import SVR
        from sklearn.tree import DecisionTreeRegressor
        from sklearn.ensemble import (
            RandomForestRegressor, GradientBoostingRegressor,
            AdaBoostRegressor
        )
        from catboost import CatBoostRegressor
        from sklearn.neural_network import MLPRegressor

        # Clustering estimators
        from sklearn.cluster import (
            KMeans, AgglomerativeClustering,
            DBSCAN, SpectralClustering
        )
        from sklearn.mixture import GaussianMixture

        algorithm_map = {
            # Classification
            'LogisticRegression': LogisticRegression,
            'SVC': SVC,
            'SVDD': OneClassSVM,  # SVDD approximated with OneClassSVM
            'DecisionTreeClassifier': DecisionTreeClassifier,
            'RandomForestClassifier': RandomForestClassifier,
            'XGBClassifier': xgb.XGBClassifier,
            'AdaBoostClassifier': AdaBoostClassifier,
            'CatBoostClassifier': CatBoostClassifier,
            'LGBMClassifier': lgb.LGBMClassifier,
            'GaussianNB': GaussianNB,
            'KNeighborsClassifier': KNeighborsClassifier,
            'MLPClassifier': MLPClassifier,
            'GradientBoostingClassifier': GradientBoostingClassifier,

            # Regression
            'LinearRegression': LinearRegression,
            'Ridge': Ridge,
            'Lasso': Lasso,
            'ElasticNet': ElasticNet,
            'SVR': SVR,
            'DecisionTreeRegressor': DecisionTreeRegressor,
            'RandomForestRegressor': RandomForestRegressor,
            'XGBRegressor': xgb.XGBRegressor,
            'AdaBoostRegressor': AdaBoostRegressor,
            'CatBoostRegressor': CatBoostRegressor,
            'LGBMRegressor': lgb.LGBMRegressor,
            'MLPRegressor': MLPRegressor,

            # Clustering
            'KMeans': KMeans,
            'KMeansPlusPlus': KMeans,  # plain KMeans, selected via init='k-means++'
            'AgglomerativeClustering': AgglomerativeClustering,
            'DBSCAN': DBSCAN,
            'GaussianMixture': GaussianMixture,
            'SpectralClustering': SpectralClustering
        }

        if algorithm_name not in algorithm_map:
            raise ValueError(f"Unknown algorithm: {algorithm_name}")

        return algorithm_map[algorithm_name]

    def _get_algorithm_info(self, algorithm_name: str) -> Dict:
        """Look up an algorithm's metadata across all task categories.

        Raises:
            ValueError: if the algorithm is absent from ``model/model.yaml``.
        """
        for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']:
            if algorithm_name in self.model_info.get(category, {}):
                return self.model_info[category][algorithm_name]
        raise ValueError(f"Algorithm {algorithm_name} not found in model info")

    def train_model(
        self,
        train_data: Dict,
        val_data: Dict,
        model_config: Dict,
        experiment_name: str
    ) -> Dict:
        """Train a model and record the run in MLflow.

        Args:
            train_data: Dict with ``features`` and ``labels`` used for fitting.
            val_data: Dict with ``features`` and ``labels`` used for evaluation.
            model_config: Dict with ``algorithm``, ``task_type``, ``dataset``
                and ``params`` keys.
            experiment_name: Name of the MLflow experiment to log under.

        Returns:
            On success ``{'status': 'success', 'run_id', 'metrics',
            'algorithm_info'}``; on failure ``{'status': 'error', 'message'}``
            (no exception escapes this method).
        """
        try:
            # MLflow refuses to reuse a soft-deleted experiment name, so
            # fall back to a timestamped name in that case.
            experiment = mlflow.get_experiment_by_name(experiment_name)
            if experiment and experiment.lifecycle_stage == 'deleted':
                timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
                new_experiment_name = f"{experiment_name}_{timestamp}"
                self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
                experiment_name = new_experiment_name

            mlflow.set_experiment(experiment_name)

            with mlflow.start_run() as run:
                # Run-level parameters.
                mlflow.log_param('algorithm', model_config['algorithm'])
                mlflow.log_param('task_type', model_config['task_type'])
                mlflow.log_param('dataset', model_config['dataset'])  # dataset path as given

                # Caller-supplied hyper-parameters.
                for param_name, param_value in model_config['params'].items():
                    mlflow.log_param(param_name, param_value)

                # Static algorithm metadata from model/model.yaml.
                algorithm_info = self._get_algorithm_info(model_config['algorithm'])
                mlflow.log_param('principle', algorithm_info['principle'])
                mlflow.log_param('advantages', str(algorithm_info['advantages']))
                mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

                # KMeans++ is plain KMeans selected through ``init``.
                # Fix: work on a copy so the caller's config dict is not
                # mutated as a side effect of training.
                params = dict(model_config['params'])
                if model_config['algorithm'] == 'KMeansPlusPlus':
                    params['init'] = 'k-means++'

                # Instantiate and fit the estimator.
                model_class = self._get_model_class(model_config['algorithm'])
                model = model_class(**params)

                self.logger.info(f"Starting training {model_config['algorithm']}")
                model.fit(train_data['features'], train_data['labels'])

                # Evaluate on the validation split and log the metrics.
                val_predictions = model.predict(val_data['features'])
                metrics = self._calculate_metrics(
                    val_data['labels'],
                    val_predictions,
                    model_config['task_type']
                )
                for metric_name, metric_value in metrics.items():
                    mlflow.log_metric(metric_name, metric_value)

                # Persist the fitted estimator as a run artifact.
                mlflow.sklearn.log_model(model, "model")

                self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

                return {
                    'status': 'success',
                    'run_id': run.info.run_id,
                    'metrics': metrics,
                    'algorithm_info': algorithm_info
                }

        except Exception as e:
            error_msg = f"Error training model: {str(e)}"
            self.logger.error(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def _calculate_metrics(
        self,
        true_labels: np.ndarray,
        predictions: np.ndarray,
        task_type: str
    ) -> Dict:
        """Compute task-appropriate evaluation metrics.

        Args:
            true_labels: Ground-truth labels of the validation split.
            predictions: Model outputs aligned with ``true_labels``.
            task_type: ``'classification'``, ``'regression'`` or
                ``'clustering'``.

        Returns:
            Mapping of metric name to value; empty for an unknown task type.
        """
        metrics = {}

        if task_type == 'classification':
            metrics['accuracy'] = accuracy_score(true_labels, predictions)
            metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
            metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
            metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
            if len(np.unique(true_labels)) == 2:  # binary problems only
                # NOTE(review): this feeds hard class predictions, so the AUC
                # collapses to a single threshold point; conventionally
                # predict_proba scores are used here -- confirm intent.
                metrics['roc_auc'] = roc_auc_score(true_labels, predictions)

        elif task_type == 'regression':
            metrics['mae'] = mean_absolute_error(true_labels, predictions)
            metrics['mse'] = mean_squared_error(true_labels, predictions)
            metrics['rmse'] = np.sqrt(metrics['mse'])
            metrics['r2'] = r2_score(true_labels, predictions)
            metrics['explained_variance'] = explained_variance_score(true_labels, predictions)

        elif task_type == 'clustering':
            metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
            metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
            metrics['completeness'] = completeness_score(true_labels, predictions)
            if len(np.unique(predictions)) > 1:  # silhouette needs >= 2 clusters
                # NOTE(review): silhouette_score expects the feature matrix X,
                # not labels; passing reshaped true_labels is almost certainly
                # wrong -- wire the validation features through once available.
                metrics['silhouette'] = silhouette_score(
                    true_labels.reshape(-1, 1),
                    predictions
                )

        return metrics
|
|
||||||
Loading…
Reference in New Issue
Block a user