添加系统设计

This commit is contained in:
haotian 2025-02-13 16:08:07 +08:00
parent b580e077b6
commit 41ed35349c

611
doc/design.md Normal file
View File

@ -0,0 +1,611 @@
# 机器学习平台系统设计
## 0. 依赖库总览
```python
# 基础依赖
import numpy as np
import pandas as pd
from typing import Dict, List
from dataclasses import dataclass
import uuid
import json
from datetime import datetime
# 数据处理
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest
# 机器学习框架
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier, CatBoostRegressor
# 深度学习
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline
# 超参数优化
import optuna
from optuna.pruners import MedianPruner
# 模型管理
import mlflow
from mlflow.tracking import MlflowClient
# API框架
from fastapi import FastAPI, BackgroundTasks, Depends
from pydantic import BaseModel
# 分布式训练
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group
# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge
# 安全相关
from cryptography.fernet import Fernet
from passlib.hash import bcrypt
# 任务队列
from celery import Celery
from redis import Redis
# 日志管理
import logging
from torch.utils.tensorboard import SummaryWriter
```
## 1. 数据处理模块
### 使用的库
```python
# 数据读取与处理
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest
# 数据版本控制
import hashlib
import json
from datetime import datetime
import uuid
```
### 1.1 数据读取
1. 支持多种数据格式读取:
```python
# Maps a file-format key to the pandas reader used to load that format.
SUPPORTED_FORMATS = {
    'csv': pd.read_csv,
    'parquet': pd.read_parquet,
    'hdf5': pd.read_hdf
}
```
2. 数据读取参数配置:
```python
# Per-format keyword arguments forwarded to the matching reader in
# SUPPORTED_FORMATS; only csv has overrides so far.
READ_CONFIG = {
    'csv': {
        'encoding': 'utf-8',
        'sep': ',',
        'na_values': ['', 'NULL', 'null']  # strings treated as missing values
    }
}
```
### 1.2 数据预处理
1. 数据清洗流水线:
```python
class DataCleaner:
    """Sequential data-cleaning pipeline: imputation -> outlier removal -> dedup."""

    def __init__(self):
        # Ordered (name, processor) steps applied one after another.
        # NOTE(review): DuplicatesRemover is a project-local class not shown
        # in this design -- confirm it exposes a `transform` method.
        self.pipeline = [
            ('missing_handler', SimpleImputer(strategy='mean')),
            ('outlier_detector', IsolationForest()),
            ('duplicates_remover', DuplicatesRemover())
        ]

    def clean(self, data):
        """Run every pipeline step over `data` and return the cleaned result.

        Fixes two sklearn API errors in the original sketch: transformers must
        be fitted before use, and IsolationForest has no `transform` method --
        it flags outliers via `fit_predict` (-1 = outlier), so we use its
        result as a row filter instead.
        """
        for name, processor in self.pipeline:
            if isinstance(processor, IsolationForest):
                # Keep only inliers (fit_predict returns 1 for inliers).
                # Assumes `data` supports boolean-mask indexing (numpy/pandas).
                mask = processor.fit_predict(data) == 1
                data = data[mask]
            elif hasattr(processor, 'fit_transform'):
                data = processor.fit_transform(data)
            else:
                data = processor.transform(data)
        return data
```
2. 数据转换:
```python
class DataTransformer:
    """Builds the numerical and categorical feature-transformation pipelines."""

    def __init__(self):
        # Numerical features: standardize, then drop near-constant columns.
        self.numerical_pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('selector', VarianceThreshold(threshold=0.01))
        ])
        # Categorical features: one-hot encode (dense output), then keep the
        # 20 best columns.  Fix: `sparse=False` was deprecated in
        # scikit-learn 1.2 and removed in 1.4; the keyword is `sparse_output`.
        self.categorical_pipeline = Pipeline([
            ('encoder', OneHotEncoder(sparse_output=False)),
            ('selector', SelectKBest(k=20))
        ])
```
3. 特征工程:
```python
class FeatureEngineer:
    """Registry of feature-generator strategies, keyed by feature family."""

    def __init__(self):
        # NOTE(review): PolynomialFeatures is presumably
        # sklearn.preprocessing.PolynomialFeatures, but it is not imported in
        # this design; InteractionFeatures and TimeFeatures look like
        # project-local classes that are not defined here -- confirm.
        self.feature_generators = {
            'polynomial': PolynomialFeatures(),
            'interaction': InteractionFeatures(),
            'time': TimeFeatures()
        }
```
4. 数据集划分:
```python
class DataSplitter:
    """Splits a dataset into train / validation / test partitions."""

    def split(self, data, test_size=0.2, val_size=0.1):
        """Return {'train', 'val', 'test'} partitions of `data`.

        Note: `val_size` is taken from the remainder after the test split, so
        with the defaults val is 10% of 80% = 8% of the full dataset.
        """
        remainder, test_part = train_test_split(data, test_size=test_size)
        train_part, val_part = train_test_split(remainder, test_size=val_size)
        return {'train': train_part, 'val': val_part, 'test': test_part}
```
### 1.3 数据版本控制
```python
class DataVersionControl:
    """In-memory registry mapping version ids to dataset fingerprints."""

    def __init__(self):
        # version_id -> {'hash', 'timestamp', 'metadata'}
        self.version_db = {}

    def save_version(self, data, metadata):
        """Record a new version of `data` and return its generated id."""
        version_id = str(uuid.uuid4())
        hash_value = self._calculate_hash(data)
        self.version_db[version_id] = {
            'hash': hash_value,
            'timestamp': datetime.now(),
            'metadata': metadata
        }
        return version_id

    def _calculate_hash(self, data):
        """Return a stable SHA-256 hex digest of `data`.

        Fix: the original sketch called `_calculate_hash` without defining it.
        DataFrames/Series are hashed via their CSV serialization; anything
        else via its sorted-key JSON representation (falling back to `str`
        for non-serializable objects).
        """
        if hasattr(data, 'to_csv'):  # pandas DataFrame / Series
            payload = data.to_csv(index=True).encode('utf-8')
        else:
            try:
                payload = json.dumps(data, sort_keys=True, default=str).encode('utf-8')
            except (TypeError, ValueError):
                payload = str(data).encode('utf-8')
        return hashlib.sha256(payload).hexdigest()
```
## 2. 模型训练模块
### 使用的库
```python
# 机器学习框架
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
# 传统机器学习算法
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier
# 深度学习模型
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline
# 超参数优化
import optuna
from optuna.pruners import MedianPruner
```
### 2.1 算法管理
1. 算法注册表:
```python
# Algorithm registry: task type -> algorithm name -> class + search space.
# Fix: this design imports `xgboost as xgb`, so the estimator class must be
# referenced through the module alias; the bare name `XGBClassifier` was
# never imported anywhere.
ALGORITHMS = {
    'classification': {
        'xgboost': {
            'class': xgb.XGBClassifier,
            'params': {
                'max_depth': {'type': 'int', 'range': [3, 10]},
                'learning_rate': {'type': 'float', 'range': [0.001, 0.1]},
                'n_estimators': {'type': 'int', 'range': [100, 1000]}
            }
        },
        'lightgbm': {...},  # TODO: lgb.LGBMClassifier + search space
        'catboost': {...}   # TODO: CatBoostClassifier + search space
    },
    'regression': {
        'xgboost_regressor': {...},
        'lightgbm_regressor': {...}
    },
    'clustering': {
        'kmeans': {...},
        'dbscan': {...}
    }
}
```
2. 模型工厂:
```python
class ModelFactory:
    """Instantiates models from the nested ALGORITHMS registry."""

    @staticmethod
    def create_model(algorithm_name, params=None, registry=None):
        """Create a model instance for `algorithm_name`.

        Fix: ALGORITHMS is nested by task type ('classification',
        'regression', ...), so the original top-level membership check
        `algorithm_name not in ALGORITHMS` could never find an algorithm and
        `ALGORITHMS[algorithm_name]` would raise KeyError.  We now search
        every task category.  `registry` defaults to the module-level
        ALGORITHMS; the parameter exists mainly for testing/injection.

        Raises:
            ValueError: if no task category contains `algorithm_name`.
        """
        if registry is None:
            registry = ALGORITHMS
        for task_algorithms in registry.values():
            spec = task_algorithms.get(algorithm_name)
            if spec is not None:
                return spec['class'](**(params or {}))
        raise ValueError(f"Algorithm {algorithm_name} not supported")
```
### 2.2 训练管理
1. 训练配置:
```python
@dataclass
class TrainingConfig:
    """Configuration for a single training run."""
    # Key into the ALGORITHMS registry (e.g. 'xgboost').
    algorithm: str
    # Hyperparameters forwarded to the model constructor.
    params: Dict
    # Identifier produced by DataVersionControl.save_version.
    data_version: str
    device: str = 'cuda'         # torch device string
    batch_size: int = 32
    num_epochs: int = 100
    early_stopping: bool = True  # stop early when validation stops improving
```
2. 训练器:
```python
class ModelTrainer:
    """Drives the epoch loop for a model built from a TrainingConfig."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = ModelFactory.create_model(config.algorithm, config.params)
        self.device = torch.device(config.device)
        # Fix: the original referenced an undefined `self.early_stopping`
        # attribute; track best-val-loss / patience state inline instead.
        self._best_val_loss = float('inf')
        self._epochs_without_improvement = 0
        self._patience = 10  # epochs to wait for a val-loss improvement

    def train(self, train_data, val_data=None):
        """Run up to `config.num_epochs` epochs, early-stopping on val loss.

        NOTE(review): `_configure_optimizer`, `_configure_scheduler`,
        `_train_epoch` and `_validate` are assumed to be defined elsewhere
        on this class -- confirm.
        """
        self.model.to(self.device)
        optimizer = self._configure_optimizer()
        scheduler = self._configure_scheduler()
        for epoch in range(self.config.num_epochs):
            train_loss = self._train_epoch(train_data)
            if val_data is not None:
                val_loss = self._validate(val_data)
                if self.config.early_stopping and self._should_stop(val_loss):
                    break

    def _should_stop(self, val_loss):
        """Return True once val loss has not improved for `_patience` epochs."""
        if val_loss < self._best_val_loss:
            self._best_val_loss = val_loss
            self._epochs_without_improvement = 0
        else:
            self._epochs_without_improvement += 1
        return self._epochs_without_improvement >= self._patience
```
### 2.3 超参数优化
```python
class HPOptimizer:
    """Thin wrapper around an Optuna study with median pruning."""

    def __init__(self, study_name, n_trials=100):
        # Fix: the original passed an undefined `self._pruning_callback` to
        # `study.optimize`; in Optuna, pruning is configured on the study
        # itself via the `pruner` argument, not via a callback.
        self.study = optuna.create_study(
            study_name=study_name,
            direction='minimize',
            pruner=MedianPruner()
        )
        self.n_trials = n_trials

    def optimize(self, objective_fn):
        """Run `n_trials` trials of `objective_fn`; return the best params."""
        self.study.optimize(objective_fn, n_trials=self.n_trials)
        return self.study.best_params
```
## 3. 模型管理模块
### 使用的库
```python
# 模型管理与追踪
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.models import Model
# 模型序列化
import pickle
import torch.serialization
import json
# 模型版本控制
from datetime import datetime
import hashlib
```
### 3.1 模型注册
```python
class ModelRegistry:
    """Registers trained models (plus optional metadata) with MLflow."""

    def __init__(self, mlflow_uri):
        self.client = MlflowClient(mlflow_uri)

    def register_model(self, model_path, name, metadata=None):
        """Create a new MLflow model version under `name` and return it.

        Fix: the original created the version but returned nothing, so
        callers had no way to learn the assigned version number.

        NOTE(review): this requires an active MLflow run
        (`mlflow.active_run()` returns None otherwise) -- confirm callers
        always invoke it inside a run context.
        """
        model_version = self.client.create_model_version(
            name=name,
            source=model_path,
            run_id=mlflow.active_run().info.run_id
        )
        if metadata:
            # Attach caller-supplied metadata as a JSON-encoded version tag.
            self.client.set_model_version_tag(
                name=name,
                version=model_version.version,
                key='metadata',
                value=json.dumps(metadata)
            )
        return model_version
```
### 3.2 模型加载
```python
class ModelLoader:
    """Loads serialized models from the local `models/` directory."""

    @staticmethod
    def load_model(model_id, map_location=None):
        """Load and return the object stored at `models/<model_id>`.

        `map_location` is forwarded to `torch.load` so a model saved on GPU
        can be loaded on a CPU-only host (pass 'cpu'); the default None
        preserves the original behavior.

        NOTE(review): torch.load unpickles data and can execute arbitrary
        code -- only load model files from trusted storage.
        """
        model_path = f"models/{model_id}"
        return torch.load(model_path, map_location=map_location)

    @staticmethod
    def load_for_inference(model_id):
        """Load a model onto CPU and switch it to eval mode."""
        # Map to CPU so inference hosts without the training GPU still work.
        model = ModelLoader.load_model(model_id, map_location='cpu')
        model.eval()
        return model
```
### 3.3 模型服务
```python
class ModelService:
    """Serves predictions, lazily caching one loaded model per model id."""

    def __init__(self):
        # model_id -> loaded model, populated on first request.
        self.loaded_models = {}

    async def predict(self, model_id, data):
        """Run `data` through the requested model; return a numpy array."""
        model = self.loaded_models.get(model_id)
        if model is None:
            model = ModelLoader.load_for_inference(model_id)
            self.loaded_models[model_id] = model
        inputs = torch.tensor(data)
        # Inference only -- no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = model(inputs)
        return outputs.numpy()
```
## 4. 系统监控模块
### 使用的库
```python
# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge
# 日志管理
import logging
from torch.utils.tensorboard import SummaryWriter
# 性能分析
import cProfile
import line_profiler
import memory_profiler
```
### 4.1 资源监控
```python
class ResourceMonitor:
    """Collects GPU / memory / CPU utilization for the monitoring module."""

    def __init__(self):
        # NOTE(review): GPUUtilization and MemoryUtilization are not defined
        # in this design -- presumably thin wrappers over GPUtil and psutil;
        # confirm before implementation.
        self.gpu_util = GPUUtilization()
        self.memory_util = MemoryUtilization()

    def get_metrics(self):
        """Return a snapshot dict with 'gpu', 'memory' and 'cpu' usage."""
        return {
            'gpu': self.gpu_util.get_usage(),
            'memory': self.memory_util.get_usage(),
            'cpu': psutil.cpu_percent()  # instantaneous CPU percent via psutil
        }
```
### 4.2 训练监控
```python
class TrainingMonitor:
    """Accumulates per-metric value histories and summarizes them."""

    def __init__(self):
        # metric name -> list of logged values (insertion order preserved).
        # Fix: the original used `defaultdict`, which is imported nowhere in
        # this design; a plain dict with setdefault needs no extra import.
        self.metrics_history = {}

    def log_metrics(self, metrics):
        """Append each value in the `metrics` mapping to its history."""
        for name, value in metrics.items():
            self.metrics_history.setdefault(name, []).append(value)

    def get_summary(self):
        """Return {metric: {'current', 'mean', 'std'}} over all logged values."""
        return {
            name: {
                'current': values[-1],
                'mean': np.mean(values),
                'std': np.std(values)
            }
            for name, values in self.metrics_history.items()
        }
```
## 5. API接口设计
### 使用的库
```python
# Web框架
from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, validator
# 异步支持
import asyncio
from concurrent.futures import ThreadPoolExecutor
# API文档
from fastapi.openapi.utils import get_openapi
```
### 5.1 训练接口
```python
# Kick off an asynchronous training job; responds immediately with a task id
# the client can poll.  NOTE(review): `router`, `TrainingRequest` and
# `train_model` are defined elsewhere in the project -- confirm.
@router.post("/train")
async def start_training(
    request: TrainingRequest,
    background_tasks: BackgroundTasks
):
    # Unique id so the client can track this training job.
    task_id = str(uuid.uuid4())
    # Run train_model after the HTTP response is sent (FastAPI background task).
    background_tasks.add_task(
        train_model,
        task_id=task_id,
        config=request.dict()
    )
    return {"task_id": task_id}
```
### 5.2 预测接口
```python
# Prediction endpoint; the injected ModelService caches loaded models across
# requests.  NOTE(review): `router` and `PredictionRequest` are defined
# elsewhere; the request is assumed to carry `model_id` and a numeric `data`
# payload -- confirm against the pydantic model.
@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_service: ModelService = Depends()
):
    predictions = await model_service.predict(
        model_id=request.model_id,
        data=request.data
    )
    # numpy array -> plain list so the response is JSON-serializable.
    return {"predictions": predictions.tolist()}
```
## 6. 安全模块
### 使用的库
```python
# 加密
from cryptography.fernet import Fernet
from passlib.hash import bcrypt
from jwt import encode, decode
# 访问控制
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
# 数据脱敏
import re
from typing import Pattern
```
### 6.1 数据安全
```python
class DataSecurity:
    """Sanitizes, then encrypts datasets before they are persisted."""

    def __init__(self):
        # NOTE(review): DataEncryptor and DataSanitizer are project-local
        # classes not defined in this design -- confirm their interfaces
        # (presumably Fernet-based encryption and PII masking).
        self.encryptor = DataEncryptor()
        self.sanitizer = DataSanitizer()

    def secure_save(self, data, path):
        """Clean `data`, encrypt the result, and write it to `path`.

        Order matters: sanitize first so raw sensitive values are never
        stored, even in encrypted form.
        """
        sanitized_data = self.sanitizer.clean(data)
        encrypted_data = self.encryptor.encrypt(sanitized_data)
        self._save_to_disk(encrypted_data, path)
```
### 6.2 模型安全
```python
class ModelSecurity:
    """Permission checks and at-rest encryption for model artifacts."""

    def __init__(self):
        # NOTE(review): AccessControl is not defined in this design -- confirm.
        self.access_control = AccessControl()

    def check_permission(self, user_id, model_id, operation):
        """Return whether `user_id` may perform `operation` on `model_id`."""
        return self.access_control.has_permission(user_id, model_id, operation)

    def encrypt_model(self, model, key):
        # Encrypts only the state_dict (weights), not the model class itself.
        # NOTE(review): AESEncryption is undefined here, and the security
        # imports above list Fernet/bcrypt rather than an AES helper -- confirm.
        return AESEncryption.encrypt(model.state_dict(), key)
```
## 7. 部署配置
### 使用的库
```python
# 容器化
import docker
from docker.types import Mount
# 任务队列
from celery import Celery
from redis import Redis
# 配置管理
import yaml
from dynaconf import Dynaconf
# 环境管理
import os
from dotenv import load_dotenv
```
### 7.1 Docker配置
```yaml
# docker-compose for the platform: one API container (GPU 0) and one Celery
# worker container (GPU 1), both built from the same image.
version: '3'
services:
  api:
    build: .
    ports:
      - "8000:8000"  # FastAPI service port
    environment:
      - CUDA_VISIBLE_DEVICES=0
    volumes:
      - ./models:/app/models  # share model artifacts with the host
  worker:
    build: .
    command: celery -A tasks worker
    environment:
      - CUDA_VISIBLE_DEVICES=1
```
### 7.2 资源限制
```python
class ResourceLimiter:
    """Enforces per-job resource ceilings before a training run is accepted."""

    def __init__(self):
        # Hard ceilings; adjust per deployment tier.
        self.limits = {
            'max_data_size': 1024 * 1024 * 1024,  # 1 GB
            'max_training_time': 3600,            # 1 hour, seconds
            'max_batch_size': 1024
        }

    def check_limits(self, config):
        """Raise ValueError if `config` exceeds any configured limit.

        Fix: the original defined three limits but only ever checked
        `data_size`.  Attributes absent from `config` are skipped (getattr
        default 0) so older callers remain compatible.
        """
        if getattr(config, 'data_size', 0) > self.limits['max_data_size']:
            raise ValueError("Data size exceeds limit")
        if getattr(config, 'batch_size', 0) > self.limits['max_batch_size']:
            raise ValueError("Batch size exceeds limit")
        if getattr(config, 'training_time', 0) > self.limits['max_training_time']:
            raise ValueError("Training time exceeds limit")
```
## 8. 版本信息
```python
# Minimum supported versions for every pinned third-party package, grouped by
# role.  NOTE(review): several modules imported in this design (psutil,
# GPUtil, pydantic, PyJWT/python-jose, dynaconf, python-dotenv) have no pin
# here -- confirm whether they should be added.
DEPENDENCIES = {
    # Core numerics / dataframes
    'numpy': '>=1.21.0',
    'pandas': '>=1.3.0',
    'scikit-learn': '>=1.0.0',
    # Machine-learning frameworks
    'torch': '>=2.0.0',
    'xgboost': '>=1.5.0',
    'lightgbm': '>=3.3.0',
    'catboost': '>=1.0.0',
    # Deep learning
    'pytorch-lightning': '>=2.0.0',
    'transformers': '>=4.30.0',
    'diffusers': '>=0.19.0',
    # Model management / HPO
    'mlflow': '>=2.4.0',
    'optuna': '>=3.2.0',
    # Web framework / task queue
    'fastapi': '>=0.95.0',
    'celery': '>=5.3.0',
    # Monitoring & logging
    'prometheus-client': '>=0.16.0',
    'tensorboard': '>=2.12.0',
    # Security
    'cryptography': '>=40.0.0',
    'passlib': '>=1.7.4',
    # Deployment
    'docker': '>=6.1.0',
    'redis': '>=4.5.0'
}
```