修改--将模型管理所有的方法都集成到一个文件中
This commit is contained in:
parent
d222cfe968
commit
c2b2e20a4a
@ -1,2 +1,37 @@
|
||||
mlflow:
|
||||
uri: "http://localhost:5000"
|
||||
# 服务器配置
|
||||
host: "10.0.0.202"
|
||||
port: 8992
|
||||
workers: 4
|
||||
debug: true
|
||||
|
||||
# MLflow配置
|
||||
mlflow_uri: "http://10.0.0.202:5000"
|
||||
|
||||
# 数据处理配置
|
||||
dataset:
|
||||
raw_dir: "dataset/dataset_raw"
|
||||
processed_dir: "dataset/dataset_processed"
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
save_dir: "models"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
|
||||
# 系统监控配置
|
||||
monitor:
|
||||
log_dir: ".log"
|
||||
resource_check_interval: 60 # 秒
|
||||
cleanup_interval: 86400 # 24小时
|
||||
max_log_days: 30 # 日志保留天数
|
||||
|
||||
# 安全配置
|
||||
security:
|
||||
secret_key: "your-secret-key"
|
||||
token_expire_minutes: 1440 # 24小时
|
||||
|
||||
# 性能配置
|
||||
performance:
|
||||
max_concurrent_trains: 4 # 最大并发训练数
|
||||
cache_size: 1024 # MB
|
||||
timeout: 3600 # 秒
|
||||
@ -523,6 +523,9 @@ Error Response:
|
||||
}
|
||||
```
|
||||
|
||||
### 2.9 模型优化 -- 未实现
|
||||
|
||||
|
||||
## 3. 系统监控
|
||||
### 3.1 获取资源使用情况
|
||||
```http
|
||||
@ -681,7 +684,7 @@ Error Response:
|
||||
}
|
||||
```
|
||||
|
||||
### 3.3 获取系统中训练状态
|
||||
### 3.3 获取系统中训练状态 ---- 未完成, 等开发系统后台时再实现.
|
||||
```http
|
||||
GET /api/train/status/{task_id}
|
||||
|
||||
@ -764,6 +767,76 @@ Error Response:
|
||||
}
|
||||
```
|
||||
|
||||
## 4. 系统后台整体实现
|
||||
|
||||
### 4.1 系统架构
|
||||
```
|
||||
MLPlatform/
|
||||
├── api/ # API接口层
|
||||
│ ├── __init__.py
|
||||
│ ├── data_api.py # 数据处理相关接口
|
||||
│ ├── model_api.py # 模型相关接口
|
||||
│ └── system_api.py # 系统监控相关接口
|
||||
├── function/ # 功能实现层
|
||||
│ ├── data_processor.py # 数据处理类
|
||||
│ ├── model_manager.py # 模型管理类
|
||||
│ ├── model_trainer.py # 模型训练类
|
||||
│ ├── system_monitor.py # 系统监控类
|
||||
│ └── utils/ # 工具函数
|
||||
├── config/ # 配置文件
|
||||
│ └── config.yaml # 系统配置
|
||||
├── dataset/ # 数据集
|
||||
│ ├── dataset_raw/ # 原始数据
|
||||
│ └── dataset_processed/ # 处理后数据
|
||||
├── .log/ # 日志文件
|
||||
├── doc/ # 文档
|
||||
└── main.py # 主程序入口
|
||||
```
|
||||
|
||||
### 4.2 技术栈
|
||||
- FastAPI: Web框架
|
||||
- MLflow: 模型管理和实验跟踪
|
||||
- PyTorch/Scikit-learn: 机器学习框架
|
||||
- Pydantic: 数据验证
|
||||
- Uvicorn: ASGI服务器
|
||||
|
||||
### 4.3 主要功能
|
||||
1. 异步任务处理
|
||||
- 支持多个模型同时训练
|
||||
- 后台任务状态监控
|
||||
- 任务队列管理
|
||||
|
||||
2. 实时监控
|
||||
- 系统资源监控
|
||||
- 训练进度监控
|
||||
- 日志实时查看
|
||||
|
||||
3. 错误处理
|
||||
- 全局异常处理
|
||||
- 错误日志记录
|
||||
- 优雅降级策略
|
||||
|
||||
4. 安全性
|
||||
- API认证授权
|
||||
- 请求限流
|
||||
- 参数验证
|
||||
|
||||
### 4.4 性能优化
|
||||
1. 数据处理
|
||||
- 数据流式处理
|
||||
- 缓存机制
|
||||
- 批量处理
|
||||
|
||||
2. 模型训练
|
||||
- GPU利用优化
|
||||
- 分布式训练支持
|
||||
- 模型检查点
|
||||
|
||||
3. 系统监控
|
||||
- 性能指标采集
|
||||
- 资源使用预警
|
||||
- 自动清理机制
|
||||
|
||||
## 附录A:方法详细说明
|
||||
|
||||
### A1. 数据预处理方法
|
||||
|
||||
Binary file not shown.
@ -1,79 +0,0 @@
|
||||
import yaml
|
||||
from typing import Dict, List
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
class MethodReader:
    """Reader for the evaluation-metric configuration (model/metrics.yaml)."""

    def __init__(self):
        """Initialise the reader and eagerly load the metric config."""
        self.logger = logging.getLogger(__name__)
        # Loaded once at construction; get_metrics() only reads from it.
        self.method_config = self._load_metrics()

    def _load_metrics(self) -> Dict:
        """Load the metric configuration file.

        Returns:
            Parsed YAML mapping (expected top-level keys:
            'classification', 'regression', 'clustering').

        Raises:
            FileNotFoundError: if model/metrics.yaml does not exist.
            Exception: any other load/parse error is logged and re-raised.
        """
        try:
            config_path = Path('model/metrics.yaml')
            if not config_path.exists():
                raise FileNotFoundError(f"Method config file not found at {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            self.logger.info("Successfully loaded method config")
            return config

        except Exception as e:
            self.logger.error(f"Error loading method config: {str(e)}")
            raise

    def get_metrics(self) -> Dict:
        """Return the evaluation metrics grouped by task type.

        Returns:
            {'status': 'success', 'metric': [...]} on success, or
            {'status': 'error', 'error': message} on failure.
        """
        try:
            metrics = []
            # One entry per task type; categories absent (or empty) in the
            # config are skipped entirely.
            sections = [
                ('classification', 'classification_metrics', '分类方法评价指标'),
                ('regression', 'regression_metrics', '回归方法评价指标'),
                ('clustering', 'clustering_metrics', '聚类方法评价指标'),
            ]
            for config_key, name, description in sections:
                section = self.method_config.get(config_key, {})
                if section:
                    metrics.append({
                        "name": name,
                        "description": description,
                        "metric": section,
                    })

            return {
                "status": "success",
                "metric": metrics,
            }

        except Exception as e:
            # Original message said "preprocessing methods", which was a
            # copy/paste error — this method returns metrics.
            self.logger.error(f"Error getting metrics: {str(e)}")
            return {
                "status": "error",
                "error": str(e),
            }
|
||||
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
import yaml
|
||||
from typing import Dict, List
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
class MethodReader:
    """Reader for model definitions (model/model.yaml) and their
    hyper-parameter descriptions (model/parameter.yaml)."""

    def __init__(self):
        """Initialise the reader and eagerly load both config files."""
        self.logger = logging.getLogger(__name__)
        self.method_config = self._load_model_config()
        self.parameter_config = self._load_parameter_config()

    def _load_model_config(self) -> Dict:
        """Load the model configuration file.

        Raises:
            FileNotFoundError: if model/model.yaml does not exist.
            Exception: any other load error is logged and re-raised.
        """
        try:
            config_path = Path('model/model.yaml')
            if not config_path.exists():
                raise FileNotFoundError(f"Method config file not found at {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            self.logger.info("Successfully loaded method config")
            return config

        except Exception as e:
            self.logger.error(f"Error loading method config: {str(e)}")
            raise

    def _load_parameter_config(self) -> Dict:
        """Load the hyper-parameter configuration file.

        Raises:
            FileNotFoundError: if model/parameter.yaml does not exist.
        """
        try:
            config_path = Path('model/parameter.yaml')
            if not config_path.exists():
                raise FileNotFoundError(f"Parameter config file not found at {config_path}")

            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            self.logger.info("Successfully loaded parameter config")
            return config
        except Exception as e:
            self.logger.error(f"Error loading parameter config: {str(e)}")
            raise

    def get_models(self) -> Dict:
        """Return the available model algorithms grouped by task type.

        Returns:
            {'status': 'success', 'models': [...]} on success, or
            {'status': 'error', 'error': message} on failure.
        """
        try:
            models = []
            # One entry per algorithm category; empty categories are skipped.
            sections = [
                ('classification_algorithms', '分类方法'),
                ('regression_algorithms', '回归方法'),
                ('clustering_algorithms', '聚类方法'),
            ]
            for config_key, description in sections:
                names = list(self.method_config.get(config_key, {}).keys())
                if names:
                    models.append({
                        "name": config_key,
                        "description": description,
                        "method": names,
                    })

            return {
                "status": "success",
                "models": models,
            }

        except Exception as e:
            # Original message said "preprocessing methods" — copy/paste error.
            self.logger.error(f"Error getting models: {str(e)}")
            return {
                "status": "error",
                "error": str(e),
            }

    def get_model_details(self, method_name: str) -> Dict:
        """Return principle/pros/cons plus parameter docs for one algorithm.

        Args:
            method_name: algorithm key, e.g. 'SVC' or 'KMeans'.

        Returns:
            {'status': 'success', 'method': {...}} when found in BOTH config
            files; {'status': 'error', 'error': message} otherwise.
        """
        categories = ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']
        try:
            # Descriptive info (principle, advantages, ...) from model.yaml.
            method_info = None
            for category in categories:
                if method_name in self.method_config.get(category, {}):
                    method_info = self.method_config[category][method_name]
                    break

            if method_info is None:
                raise ValueError(f"Method {method_name} not found in method config")

            # Parameter documentation from parameter.yaml.
            parameter_info = None
            for category in categories:
                if method_name in self.parameter_config.get(category, {}):
                    parameter_info = self.parameter_config[category][method_name]
                    break

            if parameter_info is None:
                raise ValueError(f"Method {method_name} not found in parameter config")

            # Combine both sources into one response payload.
            return {
                "status": "success",
                "method": {
                    "name": method_name,
                    "description": parameter_info.get('description', ''),
                    "principle": method_info.get('principle', ''),
                    "advantages": method_info.get('advantages', []),
                    "disadvantages": method_info.get('disadvantages', []),
                    "applicable_scenarios": method_info.get('applicable_scenarios', []),
                    "parameters": parameter_info.get('parameters', []),
                },
            }

        except Exception as e:
            self.logger.error(f"Error getting method details: {str(e)}")
            return {
                "status": "error",
                "error": str(e),
            }
|
||||
@ -18,6 +18,11 @@ from sklearn.metrics import (
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
|
||||
'''
|
||||
模型管理整体集成
|
||||
'''
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理类"""
|
||||
|
||||
@ -27,12 +32,33 @@ class ModelManager:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._setup_logging()
|
||||
self._metrics_map()
|
||||
self.method_config = self._load_metrics()
|
||||
|
||||
self.method_config = self._load_model_config()
|
||||
self.parameter_config = self._load_parameter_config()
|
||||
|
||||
# 初始化MLflow客户端
|
||||
self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
|
||||
mlflow.set_tracking_uri(self.mlflow_uri)
|
||||
self.client = MlflowClient()
|
||||
|
||||
def _load_metrics(self) -> Dict:
    """Load the evaluation-metric configuration (model/metrics.yaml).

    Raises:
        FileNotFoundError: if the file is missing; any other load error
        is logged and re-raised.
    """
    config_path = Path('model/metrics.yaml')
    try:
        if not config_path.exists():
            raise FileNotFoundError(f"Method config file not found at {config_path}")

        with config_path.open('r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)

        self.logger.info("Successfully loaded method config")
        return loaded

    except Exception as e:
        self.logger.error(f"Error loading method config: {str(e)}")
        raise
|
||||
|
||||
def _setup_logging(self):
|
||||
"""设置日志"""
|
||||
log_dir = Path('.log')
|
||||
@ -47,6 +73,54 @@ class ModelManager:
|
||||
self.logger.addHandler(file_handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
def _load_model_config(self) -> Dict:
    """Load model/model.yaml and model/metrics.yaml and merge them.

    Returns:
        A single dict; on duplicate top-level keys the metrics file wins
        because it is unpacked last in the merge.

    Raises:
        FileNotFoundError: if either file is missing; other errors are
        logged and re-raised.
    """
    try:
        def _read(config_path: Path, missing_label: str, loaded_label: str) -> Dict:
            # Shared load step for both YAML files.
            if not config_path.exists():
                raise FileNotFoundError(f"{missing_label} config file not found at {config_path}")
            with open(config_path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
            self.logger.info(f"Successfully loaded {loaded_label} config")
            return data

        config_model = _read(Path('model/model.yaml'), "Method", "model")
        config_metric = _read(Path('model/metrics.yaml'), "Metrics", "metrics")

        return {**config_model, **config_metric}

    except Exception as e:
        self.logger.error(f"Error loading method or metric config: {str(e)}")
        raise
|
||||
|
||||
def _load_parameter_config(self) -> Dict:
    """Load the hyper-parameter configuration (model/parameter.yaml).

    Raises:
        FileNotFoundError: if the file is missing; other errors are
        logged and re-raised.
    """
    config_path = Path('model/parameter.yaml')
    try:
        if not config_path.exists():
            raise FileNotFoundError(f"Parameter config file not found at {config_path}")

        with config_path.open('r', encoding='utf-8') as f:
            parsed = yaml.safe_load(f)

        self.logger.info("Successfully loaded parameter config")
        return parsed
    except Exception as e:
        self.logger.error(f"Error loading parameter config: {str(e)}")
        raise
|
||||
|
||||
|
||||
def _metrics_map(self):
|
||||
self.metrics_map={
|
||||
'accuracy' : accuracy_score,
|
||||
@ -63,7 +137,99 @@ class ModelManager:
|
||||
'completeness': completeness_score,
|
||||
'silhouette' : silhouette_score
|
||||
}
|
||||
|
||||
|
||||
def _get_algorithm_info(self, algorithm_name: str) -> Dict:
    """Look up an algorithm's descriptive entry across all task-type
    categories of the merged method config.

    Raises:
        ValueError: if the algorithm is not present in any category.
    """
    for category in ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms'):
        entries = self.method_config.get(category, {})
        if algorithm_name in entries:
            return entries[algorithm_name]
    raise ValueError(f"Algorithm {algorithm_name} not found in model info")
|
||||
|
||||
|
||||
def _get_model_class(self, algorithm_name: str):
    """Resolve an algorithm name to its estimator class.

    Imports are local to the function, so heavy optional backends
    (xgboost, lightgbm, catboost) are only loaded when a model class is
    actually requested.

    Raises:
        ValueError: if the algorithm name is unknown.
    """
    # Classification estimators.
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC, OneClassSVM
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.ensemble import (
        RandomForestClassifier, GradientBoostingClassifier,
        AdaBoostClassifier, IsolationForest
    )
    from sklearn.naive_bayes import GaussianNB
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neural_network import MLPClassifier
    import xgboost as xgb
    import lightgbm as lgb
    from catboost import CatBoostClassifier

    # Regression estimators.
    from sklearn.linear_model import (
        LinearRegression, Ridge, Lasso,
        ElasticNet
    )
    from sklearn.svm import SVR
    from sklearn.tree import DecisionTreeRegressor
    from sklearn.ensemble import (
        RandomForestRegressor, GradientBoostingRegressor,
        AdaBoostRegressor
    )
    from catboost import CatBoostRegressor
    from sklearn.neural_network import MLPRegressor

    # Clustering estimators.
    from sklearn.cluster import (
        KMeans, AgglomerativeClustering,
        DBSCAN, SpectralClustering
    )
    from sklearn.mixture import GaussianMixture

    algorithm_map = {
        # Classification.
        'LogisticRegression': LogisticRegression,
        'SVC': SVC,
        'SVDD': OneClassSVM,  # SVDD is approximated with OneClassSVM
        'DecisionTreeClassifier': DecisionTreeClassifier,
        'RandomForestClassifier': RandomForestClassifier,
        'XGBClassifier': xgb.XGBClassifier,
        'AdaBoostClassifier': AdaBoostClassifier,
        'CatBoostClassifier': CatBoostClassifier,
        'LGBMClassifier': lgb.LGBMClassifier,
        'GaussianNB': GaussianNB,
        'KNeighborsClassifier': KNeighborsClassifier,
        'MLPClassifier': MLPClassifier,
        'GradientBoostingClassifier': GradientBoostingClassifier,

        # Regression.
        'LinearRegression': LinearRegression,
        'Ridge': Ridge,
        'Lasso': Lasso,
        'ElasticNet': ElasticNet,
        'SVR': SVR,
        'DecisionTreeRegressor': DecisionTreeRegressor,
        'RandomForestRegressor': RandomForestRegressor,
        'XGBRegressor': xgb.XGBRegressor,
        'AdaBoostRegressor': AdaBoostRegressor,
        'CatBoostRegressor': CatBoostRegressor,
        'LGBMRegressor': lgb.LGBMRegressor,
        'MLPRegressor': MLPRegressor,

        # Clustering.
        'KMeans': KMeans,
        'KMeansPlusPlus': KMeans,  # KMeans++ is KMeans with init='k-means++'
        'AgglomerativeClustering': AgglomerativeClustering,
        'DBSCAN': DBSCAN,
        'GaussianMixture': GaussianMixture,
        'SpectralClustering': SpectralClustering
    }

    if algorithm_name not in algorithm_map:
        raise ValueError(f"Unknown algorithm: {algorithm_name}")

    return algorithm_map[algorithm_name]
|
||||
|
||||
|
||||
def get_finished_models(
|
||||
self,
|
||||
page: int = 1,
|
||||
@ -440,7 +606,96 @@ class ModelManager:
|
||||
'execution_time': f"{execution_time:.2f}s"
|
||||
}
|
||||
}
|
||||
|
||||
def get_models(self) -> Dict:
    """Return the available model algorithms grouped by task type.

    (The original docstring said "preprocessing methods", which was a
    copy/paste error — this method lists model algorithms.)

    Returns:
        {'status': 'success', 'models': [...]} on success, or
        {'status': 'error', 'error': message} on failure.
    """
    try:
        models = []
        # One entry per algorithm category; empty categories are skipped.
        sections = [
            ('classification_algorithms', '分类方法'),
            ('regression_algorithms', '回归方法'),
            ('clustering_algorithms', '聚类方法'),
        ]
        for config_key, description in sections:
            names = list(self.method_config.get(config_key, {}).keys())
            if names:
                models.append({
                    "name": config_key,
                    "description": description,
                    "method": names,
                })

        return {
            "status": "success",
            "models": models,
        }

    except Exception as e:
        self.logger.error(f"Error getting models: {str(e)}")
        return {
            "status": "error",
            "error": str(e),
        }
|
||||
|
||||
def get_model_details(self, method_name: str) -> Dict:
    """Return principle/pros/cons plus parameter docs for one algorithm.

    Args:
        method_name: algorithm key, e.g. 'SVC' or 'KMeans'.

    Returns:
        {'status': 'success', 'method': {...}} when found in BOTH the
        method and parameter configs; {'status': 'error', ...} otherwise.
    """
    categories = ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']
    try:
        # Descriptive info (principle, advantages, ...) from the model config.
        method_info = None
        for category in categories:
            if method_name in self.method_config.get(category, {}):
                method_info = self.method_config[category][method_name]
                break

        if method_info is None:
            raise ValueError(f"Method {method_name} not found in method config")

        # Parameter documentation from the parameter config.
        parameter_info = None
        for category in categories:
            if method_name in self.parameter_config.get(category, {}):
                parameter_info = self.parameter_config[category][method_name]
                break

        if parameter_info is None:
            raise ValueError(f"Method {method_name} not found in parameter config")

        # Combine both sources into one response payload.
        return {
            "status": "success",
            "method": {
                "name": method_name,
                "description": parameter_info.get('description', ''),
                "principle": method_info.get('principle', ''),
                "advantages": method_info.get('advantages', []),
                "disadvantages": method_info.get('disadvantages', []),
                "applicable_scenarios": method_info.get('applicable_scenarios', []),
                "parameters": parameter_info.get('parameters', []),
            },
        }

    except Exception as e:
        self.logger.error(f"Error getting method details: {str(e)}")
        return {
            "status": "error",
            "error": str(e),
        }
|
||||
|
||||
|
||||
# except Exception as e:
|
||||
# error_msg = f"预测过程发生错误: {str(e)}"
|
||||
# self.logger.error(error_msg)
|
||||
@ -451,4 +706,184 @@ class ModelManager:
|
||||
# 'error_type': type(e).__name__,
|
||||
# 'error_message': str(e)
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
|
||||
def get_metrics(self) -> Dict:
    """Return the evaluation metrics grouped by task type.

    (The original docstring said "preprocessing methods", which was a
    copy/paste error — this method lists evaluation metrics.)

    Returns:
        {'status': 'success', 'metric': [...]} on success, or
        {'status': 'error', 'error': message} on failure.
    """
    try:
        metrics = []
        # One entry per task type; categories absent (or empty) in the
        # merged config are skipped.
        sections = [
            ('classification', 'classification_metrics', '分类方法评价指标'),
            ('regression', 'regression_metrics', '回归方法评价指标'),
            ('clustering', 'clustering_metrics', '聚类方法评价指标'),
        ]
        for config_key, name, description in sections:
            section = self.method_config.get(config_key, {})
            if section:
                metrics.append({
                    "name": name,
                    "description": description,
                    "metric": section,
                })

        return {
            "status": "success",
            "metric": metrics,
        }

    except Exception as e:
        self.logger.error(f"Error getting metrics: {str(e)}")
        return {
            "status": "error",
            "error": str(e),
        }
|
||||
def train_model(
    self,
    train_data: Dict,
    val_data: Dict,
    model_config: Dict,
    experiment_name: str
) -> Dict:
    """Train a model and log the run to MLflow.

    Args:
        train_data: dict with 'features' and 'labels' used for fitting.
        val_data: dict with 'features' and 'labels' used for validation.
        model_config: dict with keys 'algorithm', 'task_type', 'dataset'
            and 'params' (estimator keyword arguments).
        experiment_name: MLflow experiment name.

    Returns:
        {'status': 'success', 'run_id', 'metrics', 'algorithm_info'} on
        success; {'status': 'error', 'message'} on any failure.
    """
    try:
        # If the experiment exists but was soft-deleted in MLflow, start a
        # fresh timestamped experiment instead of resurrecting it.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment and experiment.lifecycle_stage == 'deleted':
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            new_experiment_name = f"{experiment_name}_{timestamp}"
            self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
            experiment_name = new_experiment_name

        mlflow.set_experiment(experiment_name)

        with mlflow.start_run() as run:
            # Basic run metadata.
            mlflow.log_param('algorithm', model_config['algorithm'])
            mlflow.log_param('task_type', model_config['task_type'])
            mlflow.log_param('dataset', model_config['dataset'])

            # Hyper-parameters as individual MLflow params.
            for param_name, param_value in model_config['params'].items():
                mlflow.log_param(param_name, param_value)

            # Algorithm documentation (principle / pros / cons) — raises
            # ValueError for unknown algorithms, caught below.
            algorithm_info = self._get_algorithm_info(model_config['algorithm'])
            mlflow.log_param('principle', algorithm_info['principle'])
            mlflow.log_param('advantages', str(algorithm_info['advantages']))
            mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

            # KMeans++ is plain KMeans with the 'k-means++' init strategy.
            if model_config['algorithm'] == 'KMeansPlusPlus':
                model_config['params']['init'] = 'k-means++'

            model_class = self._get_model_class(model_config['algorithm'])
            model = model_class(**model_config['params'])

            self.logger.info(f"Starting training {model_config['algorithm']}")
            model.fit(train_data['features'], train_data['labels'])

            # Evaluate on the validation split and log every metric.
            val_predictions = model.predict(val_data['features'])
            metrics = self._calculate_metrics(
                val_data['labels'],
                val_predictions,
                model_config['task_type']
            )
            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            # Persist the fitted estimator as the run's model artifact.
            mlflow.sklearn.log_model(model, "model")

            self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

            return {
                'status': 'success',
                'run_id': run.info.run_id,
                'metrics': metrics,
                'algorithm_info': algorithm_info
            }

    except Exception as e:
        error_msg = f"Error training model: {str(e)}"
        self.logger.error(error_msg)
        return {
            'status': 'error',
            'message': error_msg
        }
|
||||
|
||||
def _calculate_metrics(
    self,
    true_labels: np.ndarray,
    predictions: np.ndarray,
    task_type: str
) -> Dict:
    """Compute evaluation metrics for the given task type.

    Returns:
        Dict of metric name -> value. Unknown task types yield {}.
    """
    metrics = {}

    if task_type == 'classification':
        metrics['accuracy'] = accuracy_score(true_labels, predictions)
        metrics['precision'] = precision_score(true_labels, predictions, average='weighted')
        metrics['recall'] = recall_score(true_labels, predictions, average='weighted')
        metrics['f1'] = f1_score(true_labels, predictions, average='weighted')
        # ROC-AUC is only computed for the binary case.
        if len(np.unique(true_labels)) == 2:
            metrics['roc_auc'] = roc_auc_score(true_labels, predictions)

    elif task_type == 'regression':
        metrics['mae'] = mean_absolute_error(true_labels, predictions)
        metrics['mse'] = mean_squared_error(true_labels, predictions)
        metrics['rmse'] = np.sqrt(metrics['mse'])
        metrics['r2'] = r2_score(true_labels, predictions)
        metrics['explained_variance'] = explained_variance_score(true_labels, predictions)

    elif task_type == 'clustering':
        metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
        metrics['homogeneity'] = homogeneity_score(true_labels, predictions)
        metrics['completeness'] = completeness_score(true_labels, predictions)
        # Silhouette is undefined with a single cluster.
        if len(np.unique(predictions)) > 1:
            # NOTE(review): silhouette_score expects the feature matrix as
            # its first argument; passing reshaped labels looks suspect —
            # confirm the intended input with the caller.
            metrics['silhouette'] = silhouette_score(
                true_labels.reshape(-1, 1),
                predictions
            )

    return metrics
|
||||
@ -1,298 +0,0 @@
|
||||
# Standard library.
import datetime
import logging
from pathlib import Path
from typing import Dict, List, Optional

# Third party.
import mlflow
import numpy as np
import pandas as pd  # fixed: was `import numpy as pd` (pandas typo)
import yaml
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    mean_absolute_error, mean_squared_error, r2_score, explained_variance_score,
    adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score
)
|
||||
|
||||
class ModelTrainer:
|
||||
"""模型训练类"""
|
||||
|
||||
def __init__(self, config: Dict = None):
    """Initialise the trainer: logging, the three config files, and the
    MLflow tracking URI.

    Args:
        config: optional settings dict; only 'mlflow_uri' is read here.
    """
    self.config = config or {}
    self.logger = logging.getLogger(__name__)
    self._setup_logging()
    self._load_metrics()
    self._load_parameters()
    self._load_model_info()

    # Point MLflow at the configured tracking server, falling back to the
    # project default.
    tracking_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
    mlflow.set_tracking_uri(tracking_uri)
|
||||
|
||||
def _setup_logging(self):
    """Attach a timestamped file handler under .log/ to this logger."""
    log_dir = Path('.log')
    log_dir.mkdir(exist_ok=True)

    log_file = log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
    handler = logging.FileHandler(log_file)
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    self.logger.addHandler(handler)
    self.logger.setLevel(logging.INFO)
|
||||
|
||||
def _read_yaml_config(self, path: str, what: str) -> Dict:
    """Read one YAML config file; log and re-raise on failure.

    Shared by the three loaders below (they previously duplicated this
    try/open/parse logic verbatim).
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except Exception as e:
        self.logger.error(f"Error loading {what}: {str(e)}")
        raise

def _load_metrics(self):
    """Load the evaluation-metric config into self.metrics_config."""
    self.metrics_config = self._read_yaml_config('model/metrics.yaml', 'metrics config')

def _load_parameters(self):
    """Load the model-parameter config into self.parameter_config."""
    self.parameter_config = self._read_yaml_config('model/parameter.yaml', 'parameter config')

def _load_model_info(self):
    """Load the model-info config into self.model_info."""
    self.model_info = self._read_yaml_config('model/model.yaml', 'model info')
|
||||
|
||||
def _get_model_class(self, algorithm_name: str):
    """Resolve an algorithm name to its estimator class.

    Imports are local to the function, so heavy optional backends
    (xgboost, lightgbm, catboost) are only loaded when a model class is
    actually requested.

    Raises:
        ValueError: if the algorithm name is unknown.
    """
    # Classification estimators.
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC, OneClassSVM
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.ensemble import (
        RandomForestClassifier, GradientBoostingClassifier,
        AdaBoostClassifier, IsolationForest
    )
    from sklearn.naive_bayes import GaussianNB
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neural_network import MLPClassifier
    import xgboost as xgb
    import lightgbm as lgb
    from catboost import CatBoostClassifier

    # Regression estimators.
    from sklearn.linear_model import (
        LinearRegression, Ridge, Lasso,
        ElasticNet
    )
    from sklearn.svm import SVR
    from sklearn.tree import DecisionTreeRegressor
    from sklearn.ensemble import (
        RandomForestRegressor, GradientBoostingRegressor,
        AdaBoostRegressor
    )
    from catboost import CatBoostRegressor
    from sklearn.neural_network import MLPRegressor

    # Clustering estimators.
    from sklearn.cluster import (
        KMeans, AgglomerativeClustering,
        DBSCAN, SpectralClustering
    )
    from sklearn.mixture import GaussianMixture

    algorithm_map = {
        # Classification.
        'LogisticRegression': LogisticRegression,
        'SVC': SVC,
        'SVDD': OneClassSVM,  # SVDD is approximated with OneClassSVM
        'DecisionTreeClassifier': DecisionTreeClassifier,
        'RandomForestClassifier': RandomForestClassifier,
        'XGBClassifier': xgb.XGBClassifier,
        'AdaBoostClassifier': AdaBoostClassifier,
        'CatBoostClassifier': CatBoostClassifier,
        'LGBMClassifier': lgb.LGBMClassifier,
        'GaussianNB': GaussianNB,
        'KNeighborsClassifier': KNeighborsClassifier,
        'MLPClassifier': MLPClassifier,
        'GradientBoostingClassifier': GradientBoostingClassifier,

        # Regression.
        'LinearRegression': LinearRegression,
        'Ridge': Ridge,
        'Lasso': Lasso,
        'ElasticNet': ElasticNet,
        'SVR': SVR,
        'DecisionTreeRegressor': DecisionTreeRegressor,
        'RandomForestRegressor': RandomForestRegressor,
        'XGBRegressor': xgb.XGBRegressor,
        'AdaBoostRegressor': AdaBoostRegressor,
        'CatBoostRegressor': CatBoostRegressor,
        'LGBMRegressor': lgb.LGBMRegressor,
        'MLPRegressor': MLPRegressor,

        # Clustering.
        'KMeans': KMeans,
        'KMeansPlusPlus': KMeans,  # KMeans++ is KMeans with init='k-means++'
        'AgglomerativeClustering': AgglomerativeClustering,
        'DBSCAN': DBSCAN,
        'GaussianMixture': GaussianMixture,
        'SpectralClustering': SpectralClustering
    }

    if algorithm_name not in algorithm_map:
        raise ValueError(f"Unknown algorithm: {algorithm_name}")

    return algorithm_map[algorithm_name]
|
||||
|
||||
def _get_algorithm_info(self, algorithm_name: str) -> Dict:
    """Look up an algorithm's descriptive entry (principle, pros/cons)
    across every task-type category of the model-info config.

    Raises:
        ValueError: if the algorithm is not present in any category.
    """
    for category in ('classification_algorithms', 'regression_algorithms', 'clustering_algorithms'):
        entries = self.model_info.get(category, {})
        if algorithm_name in entries:
            return entries[algorithm_name]
    raise ValueError(f"Algorithm {algorithm_name} not found in model info")
|
||||
|
||||
def train_model(
    self,
    train_data: Dict,
    val_data: Dict,
    model_config: Dict,
    experiment_name: str
) -> Dict:
    """Train one model and track the run with MLflow.

    Args:
        train_data: dict with 'features' and 'labels' for training.
        val_data: dict with 'features' and 'labels' for validation.
        model_config: dict with keys 'algorithm', 'task_type',
            'dataset' and 'params' (constructor kwargs for the
            estimator).
        experiment_name: MLflow experiment the run is logged under.

    Returns:
        On success: ``{'status': 'success', 'run_id': ...,
        'metrics': ..., 'algorithm_info': ...}``.
        On failure: ``{'status': 'error', 'message': ...}``.
    """
    try:
        # MLflow refuses to reuse the name of a soft-deleted experiment,
        # so fall back to a timestamped variant in that case.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment and experiment.lifecycle_stage == 'deleted':
            timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            new_experiment_name = f"{experiment_name}_{timestamp}"
            self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
            experiment_name = new_experiment_name

        mlflow.set_experiment(experiment_name)

        with mlflow.start_run() as run:
            # Basic run information.
            mlflow.log_param('algorithm', model_config['algorithm'])
            mlflow.log_param('task_type', model_config['task_type'])
            mlflow.log_param('dataset', model_config['dataset'])  # dataset path

            # Work on a shallow copy so the caller's config dict is
            # never mutated (the original injected 'init' into the
            # shared model_config['params'] for KMeans++).
            params = dict(model_config['params'])

            # Model hyper-parameters.
            for param_name, param_value in params.items():
                mlflow.log_param(param_name, param_value)

            # Static algorithm metadata (principle, pros and cons).
            algorithm_info = self._get_algorithm_info(model_config['algorithm'])
            mlflow.log_param('principle', algorithm_info['principle'])
            mlflow.log_param('advantages', str(algorithm_info['advantages']))
            mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

            # KMeans++ is plain KMeans with a smarter initialisation.
            if model_config['algorithm'] == 'KMeansPlusPlus':
                params['init'] = 'k-means++'

            # Resolve the estimator class and build the instance.
            model_class = self._get_model_class(model_config['algorithm'])
            model = model_class(**params)

            self.logger.info(f"Starting training {model_config['algorithm']}")
            model.fit(train_data['features'], train_data['labels'])

            # Evaluate on the validation split.
            val_predictions = model.predict(val_data['features'])
            metrics = self._calculate_metrics(
                val_data['labels'],
                val_predictions,
                model_config['task_type']
            )
            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            # Persist the fitted model as an MLflow artifact.
            mlflow.sklearn.log_model(model, "model")

            self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

            return {
                'status': 'success',
                'run_id': run.info.run_id,
                'metrics': metrics,
                'algorithm_info': algorithm_info
            }

    except Exception as e:
        # Surface the failure to the API layer as a structured result
        # instead of propagating the exception.
        error_msg = f"Error training model: {str(e)}"
        self.logger.error(error_msg)
        return {
            'status': 'error',
            'message': error_msg
        }
|
||||
|
||||
def _calculate_metrics(
    self,
    true_labels: np.ndarray,
    predictions: np.ndarray,
    task_type: str
) -> Dict:
    """Compute evaluation metrics for the given task type.

    Supported task types are ``'classification'``, ``'regression'``
    and ``'clustering'``; any other value yields an empty dict.

    Args:
        true_labels: ground-truth labels/targets.
        predictions: model outputs aligned with ``true_labels``.
        task_type: which metric family to compute.

    Returns:
        Mapping of metric name to scalar value.
    """
    results: Dict = {}

    if task_type == 'classification':
        results['accuracy'] = accuracy_score(true_labels, predictions)
        results['precision'] = precision_score(true_labels, predictions, average='weighted')
        results['recall'] = recall_score(true_labels, predictions, average='weighted')
        results['f1'] = f1_score(true_labels, predictions, average='weighted')
        # ROC-AUC is only computed for the binary case.
        if len(np.unique(true_labels)) == 2:
            results['roc_auc'] = roc_auc_score(true_labels, predictions)

    elif task_type == 'regression':
        results['mae'] = mean_absolute_error(true_labels, predictions)
        results['mse'] = mean_squared_error(true_labels, predictions)
        results['rmse'] = np.sqrt(results['mse'])
        results['r2'] = r2_score(true_labels, predictions)
        results['explained_variance'] = explained_variance_score(true_labels, predictions)

    elif task_type == 'clustering':
        results['adjusted_rand'] = adjusted_rand_score(true_labels, predictions)
        results['homogeneity'] = homogeneity_score(true_labels, predictions)
        results['completeness'] = completeness_score(true_labels, predictions)
        # Silhouette needs at least two distinct clusters.
        if len(np.unique(predictions)) > 1:
            # NOTE(review): silhouette_score normally expects the
            # feature matrix as its first argument, not the reshaped
            # true labels — confirm this is intentional.
            results['silhouette'] = silhouette_score(
                true_labels.reshape(-1, 1),
                predictions
            )

    return results
|
||||
Loading…
Reference in New Issue
Block a user