# 机器学习平台系统设计
|
|
|
|
## 0. 依赖库总览
|
|
```python
|
|
# Base dependencies
import numpy as np
import pandas as pd
from typing import Dict, List
from dataclasses import dataclass
from collections import defaultdict  # used by TrainingMonitor (section 4.2)
import uuid
import json
from datetime import datetime

# Data processing
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest

# Machine-learning frameworks
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier, CatBoostRegressor

# Deep learning
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline

# Hyperparameter optimization
import optuna
from optuna.pruners import MedianPruner

# Model management
import mlflow
from mlflow.tracking import MlflowClient

# API framework
from fastapi import FastAPI, BackgroundTasks, Depends
from pydantic import BaseModel

# Distributed training
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group

# System monitoring
import psutil
import GPUtil
from prometheus_client import Counter, Gauge

# Security
from cryptography.fernet import Fernet
from passlib.hash import bcrypt

# Task queue
from celery import Celery
from redis import Redis

# Logging
import logging
# FIX: SummaryWriter is provided by torch.utils.tensorboard; the top-level
# `tensorboard` package does not export it.
from torch.utils.tensorboard import SummaryWriter
|
|
```
|
|
|
|
## 1. 数据处理模块
|
|
### 使用的库
|
|
```python
|
|
# 数据读取与处理
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
|
|
from sklearn.impute import SimpleImputer
|
|
from sklearn.feature_selection import SelectKBest, VarianceThreshold
|
|
from sklearn.pipeline import Pipeline
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.ensemble import IsolationForest
|
|
|
|
# 数据版本控制
|
|
import hashlib
|
|
import json
|
|
from datetime import datetime
|
|
import uuid
|
|
```
|
|
|
|
### 1.1 数据读取
|
|
1. 支持多种数据格式读取:
|
|
```python
|
|
# Mapping from file-format identifier to the pandas reader used to load it.
# NOTE(review): pd.read_hdf additionally requires the optional 'tables'
# package at runtime — confirm it is part of the deployment image.
SUPPORTED_FORMATS = {
    'csv': pd.read_csv,
    'parquet': pd.read_parquet,
    'hdf5': pd.read_hdf
}
|
|
```
|
|
|
|
2. 数据读取参数配置:
|
|
```python
|
|
# Per-format default keyword arguments passed to the reader functions above.
# Only 'csv' has overrides today; other formats use the pandas defaults.
READ_CONFIG = {
    'csv': {
        'encoding': 'utf-8',
        'sep': ',',
        # Treat empty strings and NULL markers as missing values on read.
        'na_values': ['', 'NULL', 'null']
    }
}
|
|
```
|
|
|
|
### 1.2 数据预处理
|
|
1. 数据清洗流水线:
|
|
```python
|
|
class DataCleaner:
    """Cleans a pandas DataFrame: drops duplicates, imputes missing numeric
    values, and removes outlier rows.

    The original version chained `.transform()` over a list of processors,
    but SimpleImputer must be fitted first, IsolationForest has no
    `transform` method, and DuplicatesRemover was never defined.
    """

    def __init__(self):
        # Mean imputation for numeric columns.
        self.imputer = SimpleImputer(strategy='mean')
        # IsolationForest flags outliers via fit_predict (1 = inlier, -1 = outlier).
        self.outlier_detector = IsolationForest()

    def clean(self, data):
        """Return a cleaned copy of *data* (pandas DataFrame)."""
        if len(data) == 0:
            return data

        data = data.drop_duplicates()

        numeric_cols = data.select_dtypes(include='number').columns
        if len(numeric_cols) > 0:
            data = data.copy()
            data[numeric_cols] = self.imputer.fit_transform(data[numeric_cols])
            inlier_mask = self.outlier_detector.fit_predict(data[numeric_cols]) == 1
            data = data[inlier_mask]
        return data
|
|
```
|
|
|
|
2. 数据转换:
|
|
```python
|
|
class DataTransformer:
    """Builds the sklearn pipelines used to transform numerical and
    categorical features before model training."""

    def __init__(self):
        # Numerical features: standardize, then drop near-constant columns.
        self.numerical_pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('selector', VarianceThreshold(threshold=0.01))
        ])

        # Categorical features: one-hot encode, then keep the 20 best columns.
        # NOTE(review): 'sparse=False' was renamed to 'sparse_output' in
        # scikit-learn 1.2 — confirm the pinned sklearn version.
        # NOTE(review): SelectKBest needs a target y when fitted and requires
        # k <= n_features after encoding — verify at the call site.
        self.categorical_pipeline = Pipeline([
            ('encoder', OneHotEncoder(sparse=False)),
            ('selector', SelectKBest(k=20))
        ])
|
|
```
|
|
|
|
3. 特征工程:
|
|
```python
|
|
class FeatureEngineer:
    """Registry of feature-generation strategies, keyed by strategy name."""

    def __init__(self):
        # NOTE(review): PolynomialFeatures comes from sklearn.preprocessing but
        # is not in the import list above; InteractionFeatures and TimeFeatures
        # are project classes not defined anywhere in this document — confirm
        # where they live before implementing.
        self.feature_generators = {
            'polynomial': PolynomialFeatures(),
            'interaction': InteractionFeatures(),
            'time': TimeFeatures()
        }
|
|
```
|
|
|
|
4. 数据集划分:
|
|
```python
|
|
class DataSplitter:
    """Splits a dataset into train / validation / test partitions."""

    def split(self, data, test_size=0.2, val_size=0.1, random_state=None):
        """Split *data* so that *test_size* and *val_size* are fractions of the
        FULL dataset.

        The original passed val_size straight into the second split, which made
        it a fraction of the remaining training data (e.g. 0.1 of 0.8 = 8% of
        the total instead of 10%). *random_state* is new and defaults to None,
        preserving the previous non-deterministic behavior.
        """
        train_val, test = train_test_split(
            data, test_size=test_size, random_state=random_state
        )
        # Rescale: val_size is relative to the whole set, but we now split only
        # the remaining (1 - test_size) share.
        relative_val = val_size / (1.0 - test_size)
        train, val = train_test_split(
            train_val, test_size=relative_val, random_state=random_state
        )
        return {
            'train': train,
            'val': val,
            'test': test
        }
|
|
```
|
|
|
|
### 1.3 数据版本控制
|
|
```python
|
|
class DataVersionControl:
    """In-memory registry of dataset versions, keyed by a random UUID."""

    def __init__(self):
        # version_id -> {'hash', 'timestamp', 'metadata'}
        self.version_db = {}

    def save_version(self, data, metadata):
        """Record *data* under a fresh version id and return that id.

        The content hash lets callers detect whether two versions hold
        identical data without storing the data itself.
        """
        version_id = str(uuid.uuid4())
        hash_value = self._calculate_hash(data)

        self.version_db[version_id] = {
            'hash': hash_value,
            'timestamp': datetime.now(),
            'metadata': metadata
        }
        return version_id

    def _calculate_hash(self, data):
        """Return a stable SHA-256 hex digest of *data*.

        The original called this helper but never defined it. DataFrames are
        serialized via to_csv; anything else falls back to repr().
        """
        if hasattr(data, 'to_csv'):
            payload = data.to_csv(index=True)
        else:
            payload = repr(data)
        return hashlib.sha256(payload.encode('utf-8')).hexdigest()
|
|
```
|
|
|
|
### 1.4 数据预处理方法库
|
|
|
|
#### 1.4.1 缺失值处理
|
|
```python
|
|
# Catalogue of supported missing-value imputation strategies, grouped by
# column type. Keys are strategy identifiers used by the platform; values are
# user-facing Chinese descriptions (data, not comments — kept verbatim).
MISSING_VALUE_METHODS = {
    'numerical': {
        'mean': '使用均值填充',
        'median': '使用中位数填充',
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'interpolate': '使用插值方法填充(线性/多项式)',
        'knn': '使用K近邻方法填充',
        'iterative': '使用迭代方法填充(MICE)'
    },
    'categorical': {
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'new_category': '创建新类别',
        'knn': '使用K近邻方法填充'
    }
}
|
|
```
|
|
|
|
#### 1.4.2 数据标准化/归一化
|
|
```python
|
|
# Registry of feature-scaling methods. Each entry documents the transform,
# its formula, and when to prefer it; consumed by the preprocessing config/UI.
SCALING_METHODS = {
    'standardization': {
        'description': 'Z-Score标准化,均值为0,方差为1',
        'formula': '(x - mean) / std',
        'use_case': '适用于数据服从正态分布的情况'
    },
    'min_max': {
        'description': '将数据缩放到[0,1]区间',
        'formula': '(x - min) / (max - min)',
        'use_case': '适用于需要限定范围的情况'
    },
    'robust': {
        # Q2 is the median; scaling by the IQR makes this outlier-tolerant.
        'description': '使用四分位数进行缩放,对异常值不敏感',
        'formula': '(x - Q2) / (Q3 - Q1)',
        'use_case': '存在异常值时使用'
    },
    'log': {
        'description': '对数变换',
        'formula': 'log(x)',
        'use_case': '处理长尾分布数据'
    }
}
|
|
```
|
|
|
|
#### 1.4.3 特征编码
|
|
```python
|
|
# Registry of feature-encoding methods for categorical and text columns.
# Categorical entries carry description / use-case / pros / cons; text entries
# are short descriptions only.
ENCODING_METHODS = {
    'categorical': {
        'one_hot': {
            'description': '独热编码',
            'use_case': '类别之间无序关系',
            'pros': '不引入大小关系',
            'cons': '特征维度增加'
        },
        'label': {
            'description': '标签编码',
            'use_case': '类别之间有序关系',
            'pros': '维度不变',
            'cons': '引入大小关系'
        },
        'target': {
            'description': '目标编码',
            'use_case': '类别较多时',
            'pros': '降低维度',
            'cons': '可能过拟合'
        }
    },
    'text': {
        'tfidf': '词频-逆文档频率',
        'word2vec': '词向量',
        'bert': 'BERT编码'
    }
}
|
|
```
|
|
|
|
#### 1.4.4 异常值检测
|
|
```python
|
|
# Registry of outlier-detection methods, split into statistical rules
# (with their conventional thresholds) and model-based detectors.
OUTLIER_METHODS = {
    'statistical': {
        'zscore': {
            'description': 'Z分数法',
            'threshold': '通常为3个标准差'
        },
        'iqr': {
            'description': '四分位数法',
            'threshold': '1.5倍IQR'
        }
    },
    'model_based': {
        'isolation_forest': '基于孤立森林',
        'one_class_svm': '单类支持向量机',
        'local_outlier_factor': '局部异常因子'
    }
}
|
|
```
|
|
|
|
### 1.5 特征工程方法
|
|
|
|
#### 1.5.1 特征选择
|
|
```python
|
|
# Registry of feature-selection methods, grouped by the classic taxonomy:
# filter (model-free scoring), wrapper (search with a model in the loop),
# and embedded (selection as a by-product of training).
FEATURE_SELECTION = {
    'filter': {
        'variance': '方差选择法',
        'correlation': '相关系数法',
        'chi2': '卡方检验',
        'mutual_info': '互信息法'
    },
    'wrapper': {
        'rfe': '递归特征消除',
        'forward': '前向选择',
        'backward': '后向消除'
    },
    'embedded': {
        'lasso': 'L1正则化',
        'ridge': 'L2正则化',
        'tree_importance': '树模型特征重要性'
    }
}
|
|
```
|
|
|
|
#### 1.5.2 特征构造
|
|
```python
|
|
# Registry of feature-construction techniques, grouped by source column type
# (numerical, datetime, text).
FEATURE_CONSTRUCTION = {
    'numerical': {
        'polynomial': '多项式特征',
        'interaction': '交互特征',
        'binning': '分箱特征'
    },
    'datetime': {
        'year': '年份提取',
        'month': '月份提取',
        'day': '日期提取',
        'weekday': '星期提取',
        'is_weekend': '是否周末'
    },
    'text': {
        'length': '文本长度',
        'word_count': '词数统计',
        'ngram': 'N元语法特征'
    }
}
|
|
```
|
|
|
|
## 2. 模型训练模块
|
|
### 使用的库
|
|
```python
|
|
# 机器学习框架
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
import pytorch_lightning as pl
|
|
|
|
# 传统机器学习算法
|
|
import xgboost as xgb
|
|
import lightgbm as lgb
|
|
from catboost import CatBoostClassifier
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
# 深度学习模型
|
|
from transformers import AutoModel, AutoTokenizer
|
|
from diffusers import StableDiffusionPipeline
|
|
|
|
# 超参数优化
|
|
import optuna
|
|
from optuna.pruners import MedianPruner
|
|
```
|
|
|
|
### 2.1 算法管理
|
|
1. 算法注册表:
|
|
```python
|
|
# Registry of trainable algorithms grouped by task type. Each concrete entry
# provides the estimator class and an Optuna-style hyperparameter search space
# ({'type', 'range'} per parameter).
# NOTE(review): XGBClassifier is referenced, but only `import xgboost as xgb`
# appears above — needs `from xgboost import XGBClassifier`.
# The `{...}` entries are placeholders to be filled in the same format.
ALGORITHMS = {
    'classification': {
        'xgboost': {
            'class': XGBClassifier,
            'params': {
                'max_depth': {'type': 'int', 'range': [3,10]},
                'learning_rate': {'type': 'float', 'range': [0.001,0.1]},
                'n_estimators': {'type': 'int', 'range': [100,1000]}
            }
        },
        'lightgbm': {...},
        'catboost': {...}
    },
    'regression': {
        'xgboost_regressor': {...},
        'lightgbm_regressor': {...}
    },
    'clustering': {
        'kmeans': {...},
        'dbscan': {...}
    }
}
|
|
```
|
|
|
|
2. 模型工厂:
|
|
```python
|
|
class ModelFactory:
    """Instantiates a model from the ALGORITHMS registry by algorithm name."""

    @staticmethod
    def create_model(algorithm_name, params=None):
        """Create and return the model registered under *algorithm_name*.

        ALGORITHMS is nested by task type ('classification', 'regression',
        'clustering'), so the lookup must search each task group. The original
        indexed the top level directly, which only contains task-type keys and
        therefore rejected every real algorithm name.

        Raises ValueError when the algorithm is not registered.
        """
        for task_algorithms in ALGORITHMS.values():
            if algorithm_name in task_algorithms:
                spec = task_algorithms[algorithm_name]
                return spec['class'](**(params or {}))
        raise ValueError(f"Algorithm {algorithm_name} not supported")
|
|
```
|
|
|
|
### 2.2 训练管理
|
|
1. 训练配置:
|
|
```python
|
|
@dataclass
class TrainingConfig:
    """Configuration bundle describing a single training run."""
    algorithm: str      # key into the ALGORITHMS registry
    params: Dict        # hyperparameters forwarded to the model constructor
    data_version: str   # dataset version id (see DataVersionControl)
    device: str = 'cuda'        # torch device string; assumes GPU by default
    batch_size: int = 32
    num_epochs: int = 100       # upper bound; early stopping may end sooner
    early_stopping: bool = True # stop when validation loss stops improving
|
|
```
|
|
|
|
2. 训练器:
|
|
```python
|
|
class ModelTrainer:
    """Drives the training loop for a model built from a TrainingConfig."""

    # Consecutive non-improving validation epochs tolerated before stopping.
    PATIENCE = 10

    def __init__(self, config: "TrainingConfig"):
        self.config = config
        self.model = ModelFactory.create_model(config.algorithm, config.params)
        self.device = torch.device(config.device)
        # Early-stopping state. The original referenced a never-initialized
        # `self.early_stopping` object, which raised AttributeError at runtime.
        self._best_val_loss = float('inf')
        self._bad_epochs = 0

    def train(self, train_data, val_data=None):
        """Run up to config.num_epochs epochs; stop early when validation
        loss has not improved for PATIENCE consecutive epochs (only when
        config.early_stopping is set and validation data is provided)."""
        self.model.to(self.device)
        optimizer = self._configure_optimizer()
        scheduler = self._configure_scheduler()

        for epoch in range(self.config.num_epochs):
            train_loss = self._train_epoch(train_data)
            if val_data is not None:
                val_loss = self._validate(val_data)
                # Honor the config flag — the original ignored it entirely.
                if self.config.early_stopping and self._should_stop(val_loss):
                    break

    def _should_stop(self, val_loss):
        """Track the best validation loss; return True once PATIENCE epochs
        pass without improvement."""
        if val_loss < self._best_val_loss:
            self._best_val_loss = val_loss
            self._bad_epochs = 0
            return False
        self._bad_epochs += 1
        return self._bad_epochs >= self.PATIENCE
|
|
```
|
|
|
|
### 2.3 超参数优化
|
|
```python
|
|
class HPOptimizer:
    """Thin wrapper around an Optuna study with median-rule pruning."""

    def __init__(self, study_name, n_trials=100):
        # MedianPruner stops unpromising trials early. The original imported it
        # but never wired it in, and instead passed an undefined
        # `self._pruning_callback`, which raised AttributeError at runtime.
        self.study = optuna.create_study(
            study_name=study_name,
            direction='minimize',
            pruner=MedianPruner()
        )
        self.n_trials = n_trials

    def optimize(self, objective_fn):
        """Run *objective_fn* for n_trials trials and return the best params."""
        self.study.optimize(
            objective_fn,
            n_trials=self.n_trials
        )
        return self.study.best_params
|
|
```
|
|
|
|
## 3. 模型管理模块
|
|
### 使用的库
|
|
```python
|
|
# 模型管理与追踪
|
|
import mlflow
|
|
from mlflow.tracking import MlflowClient
|
|
from mlflow.models import Model
|
|
|
|
# 模型序列化
|
|
import pickle
|
|
import torch.serialization
|
|
import json
|
|
|
|
# 模型版本控制
|
|
from datetime import datetime
|
|
import hashlib
|
|
```
|
|
|
|
### 3.1 模型注册
|
|
```python
|
|
class ModelRegistry:
    """Registers trained models (plus optional metadata) with an MLflow
    tracking server."""

    def __init__(self, mlflow_uri):
        self.client = MlflowClient(mlflow_uri)

    def register_model(self, model_path, name, metadata=None):
        """Create a new version of model *name* from *model_path* and return it.

        Raises RuntimeError when called outside an active MLflow run — the
        original dereferenced mlflow.active_run().info and crashed with an
        opaque AttributeError in that case. Also returns the created version
        (the original discarded it), so callers can reference it.
        """
        run = mlflow.active_run()
        if run is None:
            raise RuntimeError(
                "register_model must be called inside an active MLflow run"
            )

        model_version = self.client.create_model_version(
            name=name,
            source=model_path,
            run_id=run.info.run_id
        )

        if metadata:
            # Metadata is stored as a JSON-encoded version tag.
            self.client.set_model_version_tag(
                name=name,
                version=model_version.version,
                key='metadata',
                value=json.dumps(metadata)
            )
        return model_version
|
|
```
|
|
|
|
### 3.2 模型加载
|
|
```python
|
|
class ModelLoader:
    """Loads serialized models from the local 'models/' directory."""

    @staticmethod
    def load_model(model_id, map_location=None):
        """Deserialize and return the model stored at models/<model_id>.

        *map_location* (new, optional, default None preserves old behavior)
        lets callers force e.g. 'cpu' when the checkpoint was saved on a GPU
        machine.

        SECURITY: torch.load unpickles arbitrary objects — only load trusted
        files. weights_only=False is stated explicitly so full-module loading
        keeps working on torch >= 2.6, where the default flipped to True.
        """
        model_path = f"models/{model_id}"
        return torch.load(model_path, map_location=map_location,
                          weights_only=False)

    @staticmethod
    def load_for_inference(model_id, map_location=None):
        """Load a model and switch it to eval mode (disables dropout and
        batch-norm statistic updates)."""
        model = ModelLoader.load_model(model_id, map_location)
        model.eval()
        return model
|
|
```
|
|
|
|
### 3.3 模型服务
|
|
```python
|
|
class ModelService:
    """Serves cached models for inference; each model is lazily loaded on its
    first request and then reused."""

    def __init__(self):
        # model_id -> model in eval mode
        self.loaded_models = {}

    async def predict(self, model_id, data):
        """Run inference for *model_id* on *data*; return a numpy array."""
        if model_id not in self.loaded_models:
            self.loaded_models[model_id] = ModelLoader.load_for_inference(model_id)

        model = self.loaded_models[model_id]
        with torch.no_grad():
            inputs = torch.as_tensor(data)
            predictions = model(inputs)
        # detach().cpu() makes .numpy() safe for GPU outputs — the original
        # called .numpy() directly, which raises on CUDA tensors.
        return predictions.detach().cpu().numpy()
|
|
```
|
|
|
|
## 4. 系统监控模块
|
|
### 使用的库
|
|
```python
|
|
# System monitoring
import psutil
import GPUtil
from prometheus_client import Counter, Gauge

# Logging
import logging
# FIX: SummaryWriter is provided by torch.utils.tensorboard; the top-level
# `tensorboard` package does not export it.
from torch.utils.tensorboard import SummaryWriter

# Profiling
import cProfile
import line_profiler
import memory_profiler
|
|
```
|
|
|
|
### 4.1 资源监控
|
|
```python
|
|
class ResourceMonitor:
    """Aggregates GPU / memory / CPU utilization into one metrics dict."""

    def __init__(self):
        # NOTE(review): GPUUtilization and MemoryUtilization are not defined in
        # this document — presumably thin wrappers over the GPUtil and psutil
        # imports above; confirm where they are implemented.
        self.gpu_util = GPUUtilization()
        self.memory_util = MemoryUtilization()

    def get_metrics(self):
        """Return a point-in-time snapshot with 'gpu', 'memory', 'cpu' keys."""
        return {
            'gpu': self.gpu_util.get_usage(),
            'memory': self.memory_util.get_usage(),
            # psutil.cpu_percent() reports system-wide CPU usage in percent.
            'cpu': psutil.cpu_percent()
        }
|
|
```
|
|
|
|
### 4.2 训练监控
|
|
```python
|
|
class TrainingMonitor:
    """Accumulates per-metric training history and reports basic statistics."""

    def __init__(self):
        # metric name -> list of logged values, in logging order
        self.metrics_history = defaultdict(list)

    def log_metrics(self, metrics):
        """Append each entry of *metrics* (name -> number) to its history."""
        for metric_name, value in metrics.items():
            self.metrics_history[metric_name].append(value)

    def get_summary(self):
        """Return {name: {'current', 'mean', 'std'}} for every logged metric."""
        summary = {}
        for metric_name, history in self.metrics_history.items():
            summary[metric_name] = {
                'current': history[-1],
                'mean': np.mean(history),
                'std': np.std(history)
            }
        return summary
|
|
```
|
|
|
|
## 5. API接口设计
|
|
### 使用的库
|
|
```python
|
|
# Web框架
|
|
from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from pydantic import BaseModel, validator
|
|
|
|
# 异步支持
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
# API文档
|
|
from fastapi.openapi.utils import get_openapi
|
|
```
|
|
|
|
### 5.1 训练接口
|
|
```python
|
|
@router.post("/train")
async def start_training(
    request: TrainingRequest,
    background_tasks: BackgroundTasks
):
    """Queue a training job and return its task id immediately.

    The actual work runs in train_model via FastAPI's BackgroundTasks, so the
    HTTP request does not block for the duration of training.
    """
    task_id = str(uuid.uuid4())
    # NOTE(review): request.dict() is the pydantic v1 API; v2 renamed it to
    # model_dump() — confirm the pinned pydantic version.
    background_tasks.add_task(
        train_model,
        task_id=task_id,
        config=request.dict()
    )
    return {"task_id": task_id}
|
|
```
|
|
|
|
### 5.2 预测接口
|
|
```python
|
|
@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_service: ModelService = Depends()
):
    """Run inference with the requested model and return the predictions."""
    predictions = await model_service.predict(
        model_id=request.model_id,
        data=request.data
    )
    # tolist() converts the numpy array into plain Python for JSON encoding.
    return {"predictions": predictions.tolist()}
|
|
```
|
|
|
|
## 6. 安全模块
|
|
### 使用的库
|
|
```python
|
|
# 加密
|
|
from cryptography.fernet import Fernet
|
|
from passlib.hash import bcrypt
|
|
from jwt import encode, decode
|
|
|
|
# 访问控制
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from jose import JWTError, jwt
|
|
|
|
# 数据脱敏
|
|
import re
|
|
from typing import Pattern
|
|
```
|
|
|
|
### 6.1 数据安全
|
|
```python
|
|
class DataSecurity:
    """Sanitizes and encrypts data before it is persisted to disk."""

    def __init__(self):
        # NOTE(review): DataEncryptor and DataSanitizer are not defined in this
        # document — presumably built on the Fernet import above; confirm.
        self.encryptor = DataEncryptor()
        self.sanitizer = DataSanitizer()

    def secure_save(self, data, path):
        """Sanitize *data*, encrypt the result, and write it to *path*.

        NOTE(review): _save_to_disk is referenced but not defined here.
        """
        sanitized_data = self.sanitizer.clean(data)
        encrypted_data = self.encryptor.encrypt(sanitized_data)
        self._save_to_disk(encrypted_data, path)
|
|
```
|
|
|
|
### 6.2 模型安全
|
|
```python
|
|
class ModelSecurity:
    """Access control and encryption for stored models."""

    def __init__(self):
        # NOTE(review): AccessControl is not defined in this document.
        self.access_control = AccessControl()

    def check_permission(self, user_id, model_id, operation):
        """Return whether *user_id* may perform *operation* on *model_id*."""
        return self.access_control.has_permission(user_id, model_id, operation)

    def encrypt_model(self, model, key):
        """Encrypt the model's weights (state_dict) with *key*.

        NOTE(review): AESEncryption is not defined here, and only Fernet is
        imported in the security section — confirm the intended primitive.
        """
        return AESEncryption.encrypt(model.state_dict(), key)
|
|
```
|
|
|
|
## 7. 部署配置
|
|
### 使用的库
|
|
```python
|
|
# 容器化
|
|
import docker
|
|
from docker.types import Mount
|
|
|
|
# 任务队列
|
|
from celery import Celery
|
|
from redis import Redis
|
|
|
|
# 配置管理
|
|
import yaml
|
|
from dynaconf import Dynaconf
|
|
|
|
# 环境管理
|
|
import os
|
|
from dotenv import load_dotenv
|
|
```
|
|
|
|
### 7.1 Docker配置
|
|
```yaml
|
|
# Two-service deployment: the HTTP API and a Celery worker,
# pinned to separate GPUs via CUDA_VISIBLE_DEVICES.
version: '3'
services:
  api:
    build: .
    ports:
      - "8000:8000"              # FastAPI service port
    environment:
      - CUDA_VISIBLE_DEVICES=0   # API/inference uses GPU 0
    volumes:
      - ./models:/app/models     # share the model store with the host

  worker:
    build: .
    command: celery -A tasks worker
    environment:
      - CUDA_VISIBLE_DEVICES=1   # training worker uses GPU 1
|
|
```
|
|
|
|
### 7.2 资源限制
|
|
```python
|
|
class ResourceLimiter:
    """Validates a training configuration against hard platform limits."""

    def __init__(self):
        self.limits = {
            'max_data_size': 1024 * 1024 * 1024,  # 1 GiB
            'max_training_time': 3600,            # seconds (1 hour)
            'max_batch_size': 1024
        }

    def check_limits(self, config):
        """Raise ValueError if *config* exceeds any configured limit.

        The original declared three limits but only enforced max_data_size.
        Attributes missing from *config* are skipped, so existing callers that
        only provide data_size remain compatible.
        """
        checks = (
            ('data_size', 'max_data_size', "Data size exceeds limit"),
            ('training_time', 'max_training_time', "Training time exceeds limit"),
            ('batch_size', 'max_batch_size', "Batch size exceeds limit"),
        )
        for attr, limit_key, message in checks:
            value = getattr(config, attr, None)
            if value is not None and value > self.limits[limit_key]:
                raise ValueError(message)
|
|
```
|
|
|
|
## 8. 版本信息
|
|
```python
|
|
# Minimum-version constraints for every third-party dependency used in this
# design, grouped by subsystem (mirrors the import overview in section 0).
DEPENDENCIES = {
    # Base dependencies
    'numpy': '>=1.21.0',
    'pandas': '>=1.3.0',
    'scikit-learn': '>=1.0.0',

    # Machine-learning frameworks
    'torch': '>=2.0.0',
    'xgboost': '>=1.5.0',
    'lightgbm': '>=3.3.0',
    'catboost': '>=1.0.0',

    # Deep learning
    'pytorch-lightning': '>=2.0.0',
    'transformers': '>=4.30.0',
    'diffusers': '>=0.19.0',

    # Model management
    'mlflow': '>=2.4.0',
    'optuna': '>=3.2.0',

    # Web framework
    'fastapi': '>=0.95.0',
    'celery': '>=5.3.0',

    # Monitoring and logging
    'prometheus-client': '>=0.16.0',
    'tensorboard': '>=2.12.0',

    # Security
    'cryptography': '>=40.0.0',
    'passlib': '>=1.7.4',

    # Deployment
    'docker': '>=6.1.0',
    'redis': '>=4.5.0'
}
|
|
```
|