diff --git a/doc/design.md b/doc/design.md
new file mode 100644
index 0000000..4007473
--- /dev/null
+++ b/doc/design.md
@@ -0,0 +1,611 @@
+# 机器学习平台系统设计
+
+## 0. 依赖库总览
+```python
+# 基础依赖
+import numpy as np
+import pandas as pd
+from typing import Dict, List
+from dataclasses import dataclass
+import uuid
+import json
+from datetime import datetime
+
+# 数据处理
+from sklearn.preprocessing import StandardScaler, OneHotEncoder
+from sklearn.impute import SimpleImputer
+from sklearn.feature_selection import SelectKBest, VarianceThreshold
+from sklearn.pipeline import Pipeline
+from sklearn.model_selection import train_test_split
+from sklearn.ensemble import IsolationForest
+
+# 机器学习框架
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import xgboost as xgb
+import lightgbm as lgb
+from catboost import CatBoostClassifier, CatBoostRegressor
+
+# 深度学习
+import pytorch_lightning as pl
+from transformers import AutoModel, AutoTokenizer
+from diffusers import StableDiffusionPipeline
+
+# 超参数优化
+import optuna
+from optuna.pruners import MedianPruner
+
+# 模型管理
+import mlflow
+from mlflow.tracking import MlflowClient
+
+# API框架
+from fastapi import FastAPI, BackgroundTasks, Depends
+from pydantic import BaseModel
+
+# 分布式训练
+from torch.nn.parallel import DistributedDataParallel
+from torch.distributed import init_process_group
+
+# 系统监控
+import psutil
+import GPUtil
+from prometheus_client import Counter, Gauge
+
+# 安全相关
+from cryptography.fernet import Fernet
+from passlib.hash import bcrypt
+
+# 任务队列
+from celery import Celery
+from redis import Redis
+
+# 日志管理
+import logging
+from torch.utils.tensorboard import SummaryWriter
+```
+
+## 1. 
数据处理模块 +### 使用的库 +```python +# 数据读取与处理 +import pandas as pd +import numpy as np +from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder +from sklearn.impute import SimpleImputer +from sklearn.feature_selection import SelectKBest, VarianceThreshold +from sklearn.pipeline import Pipeline +from sklearn.model_selection import train_test_split +from sklearn.ensemble import IsolationForest + +# 数据版本控制 +import hashlib +import json +from datetime import datetime +import uuid +``` + +### 1.1 数据读取 +1. 支持多种数据格式读取: +```python +SUPPORTED_FORMATS = { + 'csv': pd.read_csv, + 'parquet': pd.read_parquet, + 'hdf5': pd.read_hdf +} +``` + +2. 数据读取参数配置: +```python +READ_CONFIG = { + 'csv': { + 'encoding': 'utf-8', + 'sep': ',', + 'na_values': ['', 'NULL', 'null'] + } +} +``` + +### 1.2 数据预处理 +1. 数据清洗流水线: +```python +class DataCleaner: + def __init__(self): + self.pipeline = [ + ('missing_handler', SimpleImputer(strategy='mean')), + ('outlier_detector', IsolationForest()), + ('duplicates_remover', DuplicatesRemover()) + ] + + def clean(self, data): + for name, processor in self.pipeline: + data = processor.transform(data) + return data +``` + +2. 数据转换: +```python +class DataTransformer: + def __init__(self): + self.numerical_pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('selector', VarianceThreshold(threshold=0.01)) + ]) + + self.categorical_pipeline = Pipeline([ + ('encoder', OneHotEncoder(sparse=False)), + ('selector', SelectKBest(k=20)) + ]) +``` + +3. 特征工程: +```python +class FeatureEngineer: + def __init__(self): + self.feature_generators = { + 'polynomial': PolynomialFeatures(), + 'interaction': InteractionFeatures(), + 'time': TimeFeatures() + } +``` + +4. 
数据集划分:
+```python
+class DataSplitter:
+    def split(self, data, test_size=0.2, val_size=0.1):
+        train, test = train_test_split(data, test_size=test_size)
+        train, val = train_test_split(train, test_size=val_size)
+        return {
+            'train': train,
+            'val': val,
+            'test': test
+        }
+```
+
+### 1.3 数据版本控制
+```python
+class DataVersionControl:
+    def __init__(self):
+        self.version_db = {}
+
+    def save_version(self, data, metadata):
+        version_id = str(uuid.uuid4())
+        hash_value = self._calculate_hash(data)
+
+        self.version_db[version_id] = {
+            'hash': hash_value,
+            'timestamp': datetime.now(),
+            'metadata': metadata
+        }
+        return version_id
+```
+
+## 2. 模型训练模块
+### 使用的库
+```python
+# 机器学习框架
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import pytorch_lightning as pl
+
+# 传统机器学习算法
+import xgboost as xgb
+import lightgbm as lgb
+from catboost import CatBoostClassifier
+from sklearn.ensemble import RandomForestClassifier
+
+# 深度学习模型
+from transformers import AutoModel, AutoTokenizer
+from diffusers import StableDiffusionPipeline
+
+# 超参数优化
+import optuna
+from optuna.pruners import MedianPruner
+```
+
+### 2.1 算法管理
+1. 算法注册表:
+```python
+ALGORITHMS = {
+    'classification': {
+        'xgboost': {
+            'class': xgb.XGBClassifier,
+            'params': {
+                'max_depth': {'type': 'int', 'range': [3,10]},
+                'learning_rate': {'type': 'float', 'range': [0.001,0.1]},
+                'n_estimators': {'type': 'int', 'range': [100,1000]}
+            }
+        },
+        'lightgbm': {...},
+        'catboost': {...}
+    },
+    'regression': {
+        'xgboost_regressor': {...},
+        'lightgbm_regressor': {...}
+    },
+    'clustering': {
+        'kmeans': {...},
+        'dbscan': {...}
+    }
+}
+```
+
+2. 模型工厂:
+```python
+class ModelFactory:
+    @staticmethod
+    def create_model(algorithm_name, params=None):
+        # ALGORITHMS 按任务类型分组, 需在各任务组内查找算法名
+        for task_algorithms in ALGORITHMS.values():
+            if algorithm_name in task_algorithms:
+                algorithm = task_algorithms[algorithm_name]
+                return algorithm['class'](**(params or {}))
+        raise ValueError(f"Algorithm {algorithm_name} not supported")
+```
+
+### 2.2 训练管理
+1. 
训练配置:
+```python
+@dataclass
+class TrainingConfig:
+    algorithm: str
+    params: Dict
+    data_version: str
+    device: str = 'cuda'
+    batch_size: int = 32
+    num_epochs: int = 100
+    early_stopping: bool = True
+```
+
+2. 训练器:
+```python
+class ModelTrainer:
+    def __init__(self, config: TrainingConfig):
+        self.config = config
+        self.model = ModelFactory.create_model(config.algorithm, config.params)
+        self.device = torch.device(config.device)
+
+    def train(self, train_data, val_data=None):
+        self.model.to(self.device)
+        optimizer = self._configure_optimizer()
+        scheduler = self._configure_scheduler()
+
+        for epoch in range(self.config.num_epochs):
+            train_loss = self._train_epoch(train_data)
+            if val_data is not None:
+                val_loss = self._validate(val_data)
+                # 仅在提供验证集且配置开启 early_stopping 时判断是否提前停止
+                if self.config.early_stopping and self._should_stop(val_loss):
+                    break
+```
+
+### 2.3 超参数优化
+```python
+class HPOptimizer:
+    def __init__(self, study_name, n_trials=100):
+        self.study = optuna.create_study(
+            study_name=study_name,
+            direction='minimize'
+        )
+        self.n_trials = n_trials
+
+    def optimize(self, objective_fn):
+        self.study.optimize(
+            objective_fn,
+            n_trials=self.n_trials,
+            callbacks=[self._pruning_callback]
+        )
+        return self.study.best_params
+```
+
+## 3. 
模型管理模块 +### 使用的库 +```python +# 模型管理与追踪 +import mlflow +from mlflow.tracking import MlflowClient +from mlflow.models import Model + +# 模型序列化 +import pickle +import torch.serialization +import json + +# 模型版本控制 +from datetime import datetime +import hashlib +``` + +### 3.1 模型注册 +```python +class ModelRegistry: + def __init__(self, mlflow_uri): + self.client = MlflowClient(mlflow_uri) + + def register_model(self, model_path, name, metadata=None): + model_version = self.client.create_model_version( + name=name, + source=model_path, + run_id=mlflow.active_run().info.run_id + ) + + if metadata: + self.client.set_model_version_tag( + name=name, + version=model_version.version, + key='metadata', + value=json.dumps(metadata) + ) +``` + +### 3.2 模型加载 +```python +class ModelLoader: + @staticmethod + def load_model(model_id): + model_path = f"models/{model_id}" + return torch.load(model_path) + + @staticmethod + def load_for_inference(model_id): + model = ModelLoader.load_model(model_id) + model.eval() + return model +``` + +### 3.3 模型服务 +```python +class ModelService: + def __init__(self): + self.loaded_models = {} + + async def predict(self, model_id, data): + if model_id not in self.loaded_models: + self.loaded_models[model_id] = ModelLoader.load_for_inference(model_id) + + model = self.loaded_models[model_id] + with torch.no_grad(): + predictions = model(torch.tensor(data)) + return predictions.numpy() +``` + +## 4. 
系统监控模块
+### 使用的库
+```python
+# 系统监控
+import psutil
+import GPUtil
+from prometheus_client import Counter, Gauge
+
+# 指标聚合
+from collections import defaultdict
+import numpy as np
+
+# 日志管理
+import logging
+from torch.utils.tensorboard import SummaryWriter
+
+# 性能分析
+import cProfile
+import line_profiler
+import memory_profiler
+```
+
+### 4.1 资源监控
+```python
+class ResourceMonitor:
+    def __init__(self):
+        self.gpu_util = GPUUtilization()
+        self.memory_util = MemoryUtilization()
+
+    def get_metrics(self):
+        return {
+            'gpu': self.gpu_util.get_usage(),
+            'memory': self.memory_util.get_usage(),
+            'cpu': psutil.cpu_percent()
+        }
+```
+
+### 4.2 训练监控
+```python
+class TrainingMonitor:
+    def __init__(self):
+        self.metrics_history = defaultdict(list)
+
+    def log_metrics(self, metrics):
+        for name, value in metrics.items():
+            self.metrics_history[name].append(value)
+
+    def get_summary(self):
+        return {
+            name: {
+                'current': values[-1],
+                'mean': np.mean(values),
+                'std': np.std(values)
+            }
+            for name, values in self.metrics_history.items()
+        }
+```
+
+## 5. API接口设计
+### 使用的库
+```python
+# Web框架
+from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from pydantic import BaseModel, validator
+
+# 异步支持
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+# API文档
+from fastapi.openapi.utils import get_openapi
+```
+
+### 5.1 训练接口
+```python
+@router.post("/train")
+async def start_training(
+    request: TrainingRequest,
+    background_tasks: BackgroundTasks
+):
+    task_id = str(uuid.uuid4())
+    background_tasks.add_task(
+        train_model,
+        task_id=task_id,
+        config=request.dict()
+    )
+    return {"task_id": task_id}
+```
+
+### 5.2 预测接口
+```python
+@router.post("/predict")
+async def predict(
+    request: PredictionRequest,
+    model_service: ModelService = Depends()
+):
+    predictions = await model_service.predict(
+        model_id=request.model_id,
+        data=request.data
+    )
+    return {"predictions": predictions.tolist()}
+```
+
+## 6. 
安全模块 +### 使用的库 +```python +# 加密 +from cryptography.fernet import Fernet +from passlib.hash import bcrypt +from jwt import encode, decode + +# 访问控制 +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt + +# 数据脱敏 +import re +from typing import Pattern +``` + +### 6.1 数据安全 +```python +class DataSecurity: + def __init__(self): + self.encryptor = DataEncryptor() + self.sanitizer = DataSanitizer() + + def secure_save(self, data, path): + sanitized_data = self.sanitizer.clean(data) + encrypted_data = self.encryptor.encrypt(sanitized_data) + self._save_to_disk(encrypted_data, path) +``` + +### 6.2 模型安全 +```python +class ModelSecurity: + def __init__(self): + self.access_control = AccessControl() + + def check_permission(self, user_id, model_id, operation): + return self.access_control.has_permission(user_id, model_id, operation) + + def encrypt_model(self, model, key): + return AESEncryption.encrypt(model.state_dict(), key) +``` + +## 7. 部署配置 +### 使用的库 +```python +# 容器化 +import docker +from docker.types import Mount + +# 任务队列 +from celery import Celery +from redis import Redis + +# 配置管理 +import yaml +from dynaconf import Dynaconf + +# 环境管理 +import os +from dotenv import load_dotenv +``` + +### 7.1 Docker配置 +```yaml +version: '3' +services: + api: + build: . + ports: + - "8000:8000" + environment: + - CUDA_VISIBLE_DEVICES=0 + volumes: + - ./models:/app/models + + worker: + build: . + command: celery -A tasks worker + environment: + - CUDA_VISIBLE_DEVICES=1 +``` + +### 7.2 资源限制 +```python +class ResourceLimiter: + def __init__(self): + self.limits = { + 'max_data_size': 1024 * 1024 * 1024, # 1GB + 'max_training_time': 3600, # 1 hour + 'max_batch_size': 1024 + } + + def check_limits(self, config): + if config.data_size > self.limits['max_data_size']: + raise ValueError("Data size exceeds limit") +``` + +## 8. 
版本信息 +```python +DEPENDENCIES = { + # 基础依赖 + 'numpy': '>=1.21.0', + 'pandas': '>=1.3.0', + 'scikit-learn': '>=1.0.0', + + # 机器学习框架 + 'torch': '>=2.0.0', + 'xgboost': '>=1.5.0', + 'lightgbm': '>=3.3.0', + 'catboost': '>=1.0.0', + + # 深度学习 + 'pytorch-lightning': '>=2.0.0', + 'transformers': '>=4.30.0', + 'diffusers': '>=0.19.0', + + # 模型管理 + 'mlflow': '>=2.4.0', + 'optuna': '>=3.2.0', + + # Web框架 + 'fastapi': '>=0.95.0', + 'celery': '>=5.3.0', + + # 监控与日志 + 'prometheus-client': '>=0.16.0', + 'tensorboard': '>=2.12.0', + + # 安全 + 'cryptography': '>=40.0.0', + 'passlib': '>=1.7.4', + + # 部署 + 'docker': '>=6.1.0', + 'redis': '>=4.5.0' +} +```