MLPlatform/doc/design.md

18 KiB

机器学习平台系统设计

0. 依赖库总览

# 基础依赖
import numpy as np
import pandas as pd
from typing import Dict, List
from dataclasses import dataclass
import uuid
import json
from datetime import datetime

# 数据处理
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest

# 机器学习框架
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier, CatBoostRegressor

# 深度学习
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline

# 超参数优化
import optuna
from optuna.pruners import MedianPruner

# 模型管理
import mlflow
from mlflow.tracking import MlflowClient

# API框架
from fastapi import FastAPI, BackgroundTasks, Depends
from pydantic import BaseModel

# 分布式训练
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group

# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge

# 安全相关
from cryptography.fernet import Fernet
from passlib.hash import bcrypt

# 任务队列
from celery import Celery
from redis import Redis

# 日志管理
import logging
from torch.utils.tensorboard import SummaryWriter

1. 数据处理模块

使用的库

# 数据读取与处理
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest

# 数据版本控制
import hashlib
import json
from datetime import datetime
import uuid

1.1 数据读取

  1. 支持多种数据格式读取:
# Mapping of supported input formats to the pandas reader that parses them.
# Extended beyond the original csv/parquet/hdf5 trio with other common
# interchange formats pandas reads natively.
SUPPORTED_FORMATS = {
    'csv': pd.read_csv,
    'parquet': pd.read_parquet, 
    'hdf5': pd.read_hdf,
    'json': pd.read_json,
    'excel': pd.read_excel,
    'feather': pd.read_feather
}
  1. 数据读取参数配置:
# Per-format keyword arguments forwarded to the matching reader in
# SUPPORTED_FORMATS when a file of that format is loaded.
READ_CONFIG = {
    'csv': {
        'encoding': 'utf-8',
        'sep': ',',
        # strings treated as missing values, on top of the pandas defaults
        'na_values': ['', 'NULL', 'null']
    }
}

1.2 数据预处理

  1. 数据清洗流水线:
class DataCleaner:
    """Sequential cleaning pipeline: impute missing values, drop outlier
    rows, then remove duplicates."""

    def __init__(self):
        # Ordered (name, processor) stages applied by clean().
        # NOTE(review): DuplicatesRemover is not imported here — presumably a
        # project-local transformer; confirm it exists.
        self.pipeline = [
            ('missing_handler', SimpleImputer(strategy='mean')),
            ('outlier_detector', IsolationForest()),
            ('duplicates_remover', DuplicatesRemover())
        ]
        
    def clean(self, data):
        """Run every pipeline stage in order and return the cleaned data.

        The original called processor.transform() unconditionally, but the
        stages were never fitted and IsolationForest has no transform();
        each stage is now fitted on the data it sees, and the outlier
        detector is applied as a row filter.
        """
        for name, processor in self.pipeline:
            if isinstance(processor, IsolationForest):
                # fit_predict labels inliers 1 and outliers -1; keep inliers
                data = data[processor.fit_predict(data) == 1]
            elif hasattr(processor, 'fit_transform'):
                data = processor.fit_transform(data)
            else:
                data = processor.transform(data)
        return data
  1. 数据转换:
class DataTransformer:
    """Column-type-specific transformation pipelines for numeric and
    categorical features."""

    def __init__(self):
        # Scale numeric columns, then drop near-constant ones.
        self.numerical_pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('selector', VarianceThreshold(threshold=0.01))
        ])
        
        # One-hot encode, then keep the 20 strongest indicator columns.
        # sparse_output (sklearn >= 1.2) replaces the deprecated-and-removed
        # `sparse` keyword used in the original; a dense array is what the
        # downstream SelectKBest expects.
        # NOTE(review): SelectKBest's default score_func requires a target y
        # at fit time — confirm callers fit this pipeline with labels.
        self.categorical_pipeline = Pipeline([
            ('encoder', OneHotEncoder(sparse_output=False)),
            ('selector', SelectKBest(k=20))
        ])
  1. 特征工程:
class FeatureEngineer:
    """Registry of feature-generation strategies keyed by feature family."""

    def __init__(self):
        # NOTE(review): PolynomialFeatures is presumably
        # sklearn.preprocessing.PolynomialFeatures (not in the imports above);
        # InteractionFeatures and TimeFeatures look like project-local
        # generators — confirm their definitions.
        self.feature_generators = {
            'polynomial': PolynomialFeatures(),
            'interaction': InteractionFeatures(),
            'time': TimeFeatures()
        }
  1. 数据集划分:
class DataSplitter:
    """Splits a dataset into train / validation / test partitions."""

    def split(self, data, test_size=0.2, val_size=0.1, random_state=None):
        """Return {'train', 'val', 'test'} partitions of *data*.

        NOTE: val_size is a fraction of the remaining train portion
        (i.e. of 1 - test_size of the data), not of the full dataset.
        Pass random_state for a reproducible split; the default None
        preserves the original non-deterministic behavior.
        """
        train, test = train_test_split(data, test_size=test_size,
                                       random_state=random_state)
        train, val = train_test_split(train, test_size=val_size,
                                      random_state=random_state)
        return {
            'train': train,
            'val': val, 
            'test': test
        }

1.3 数据版本控制

class DataVersionControl:
    """In-memory registry mapping version ids to content hashes and metadata."""

    def __init__(self):
        # version_id -> {'hash', 'timestamp', 'metadata'}
        self.version_db = {}
        
    def save_version(self, data, metadata):
        """Register *data* under a fresh UUID and return the version id."""
        version_id = str(uuid.uuid4())
        hash_value = self._calculate_hash(data)
        
        self.version_db[version_id] = {
            'hash': hash_value,
            'timestamp': datetime.now(),
            'metadata': metadata
        }
        return version_id

    @staticmethod
    def _calculate_hash(data):
        """Return a stable SHA-256 hex digest of *data*.

        The original called this method without ever defining it. Bytes are
        hashed directly; any other object is serialized via a sorted-key
        JSON dump (default=str covers non-JSON types such as timestamps)
        so the digest is deterministic across runs.
        """
        if isinstance(data, (bytes, bytearray)):
            payload = bytes(data)
        else:
            payload = json.dumps(data, sort_keys=True, default=str).encode('utf-8')
        return hashlib.sha256(payload).hexdigest()

1.1 数据预处理方法库

1.1.1 缺失值处理

# Catalog of supported missing-value imputation strategies, grouped by
# column type; values are human-readable (Chinese) descriptions shown in
# the platform UI/docs.
MISSING_VALUE_METHODS = {
    'numerical': {
        'mean': '使用均值填充',
        'median': '使用中位数填充',
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'interpolate': '使用插值方法填充(线性/多项式)',
        'knn': '使用K近邻方法填充',
        'iterative': '使用迭代方法填充(MICE)'
    },
    'categorical': {
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'new_category': '创建新类别',
        'knn': '使用K近邻方法填充'
    }
}

1.1.2 数据标准化/归一化

# Catalog of supported scaling/normalization methods; each entry records a
# description, the transform formula, and when to prefer it.
SCALING_METHODS = {
    'standardization': {
        'description': 'Z-Score标准化,均值为0,方差为1',
        'formula': '(x - mean) / std',
        'use_case': '适用于数据服从正态分布的情况'
    },
    'min_max': {
        'description': '将数据缩放到[0,1]区间',
        'formula': '(x - min) / (max - min)',
        'use_case': '适用于需要限定范围的情况'
    },
    'robust': {
        'description': '使用四分位数进行缩放,对异常值不敏感',
        'formula': '(x - Q2) / (Q3 - Q1)',
        'use_case': '存在异常值时使用'
    },
    'log': {
        'description': '对数变换',
        'formula': 'log(x)',
        'use_case': '处理长尾分布数据'
    }
}

1.1.3 特征编码

# Catalog of feature-encoding methods: categorical encoders with their
# trade-offs, plus text-vectorization options.
ENCODING_METHODS = {
    'categorical': {
        'one_hot': {
            'description': '独热编码',
            'use_case': '类别之间无序关系',
            'pros': '不引入大小关系',
            'cons': '特征维度增加'
        },
        'label': {
            'description': '标签编码',
            'use_case': '类别之间有序关系',
            'pros': '维度不变',
            'cons': '引入大小关系'
        },
        'target': {
            'description': '目标编码',
            'use_case': '类别较多时',
            'pros': '降低维度',
            'cons': '可能过拟合'
        }
    },
    'text': {
        'tfidf': '词频-逆文档频率',
        'word2vec': '词向量',
        'bert': 'BERT编码'
    }
}

1.1.4 异常值检测

# Catalog of outlier-detection methods: simple statistical rules (with their
# conventional thresholds) and model-based detectors.
OUTLIER_METHODS = {
    'statistical': {
        'zscore': {
            'description': 'Z分数法',
            'threshold': '通常为3个标准差'
        },
        'iqr': {
            'description': '四分位数法',
            'threshold': '1.5倍IQR'
        }
    },
    'model_based': {
        'isolation_forest': '基于孤立森林',
        'one_class_svm': '单类支持向量机',
        'local_outlier_factor': '局部异常因子'
    }
}

1.2 特征工程方法

1.2.1 特征选择

# Catalog of feature-selection strategies, grouped by the standard taxonomy:
# filter (model-free scoring), wrapper (search with a model in the loop),
# and embedded (selection as a by-product of training).
FEATURE_SELECTION = {
    'filter': {
        'variance': '方差选择法',
        'correlation': '相关系数法',
        'chi2': '卡方检验',
        'mutual_info': '互信息法'
    },
    'wrapper': {
        'rfe': '递归特征消除',
        'forward': '前向选择',
        'backward': '后向消除'
    },
    'embedded': {
        'lasso': 'L1正则化',
        'ridge': 'L2正则化',
        'tree_importance': '树模型特征重要性'
    }
}

1.2.2 特征构造

# Catalog of feature-construction recipes by source column type
# (numeric, datetime, text).
FEATURE_CONSTRUCTION = {
    'numerical': {
        'polynomial': '多项式特征',
        'interaction': '交互特征',
        'binning': '分箱特征'
    },
    'datetime': {
        'year': '年份提取',
        'month': '月份提取',
        'day': '日期提取',
        'weekday': '星期提取',
        'is_weekend': '是否周末'
    },
    'text': {
        'length': '文本长度',
        'word_count': '词数统计',
        'ngram': 'N元语法特征'
    }
}

2. 模型训练模块

使用的库

# 机器学习框架
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# 传统机器学习算法
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier

# 深度学习模型
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline

# 超参数优化
import optuna
from optuna.pruners import MedianPruner

2.1 算法管理

  1. 算法注册表:
# Algorithm registry: task type -> algorithm name -> spec holding the model
# class and its hyperparameter search space. The {...} entries are design
# placeholders still to be filled in.
ALGORITHMS = {
    'classification': {
        'xgboost': {
            # NOTE(review): the imports above bring in `xgboost as xgb`, so
            # this should presumably be xgb.XGBClassifier — confirm.
            'class': XGBClassifier,
            'params': {
                # each entry: value type plus an inclusive [low, high] range
                'max_depth': {'type': 'int', 'range': [3,10]},
                'learning_rate': {'type': 'float', 'range': [0.001,0.1]},
                'n_estimators': {'type': 'int', 'range': [100,1000]}
            }
        },
        'lightgbm': {...},
        'catboost': {...}
    },
    'regression': {
        'xgboost_regressor': {...},
        'lightgbm_regressor': {...}
    },
    'clustering': {
        'kmeans': {...},
        'dbscan': {...}
    }
}
  1. 模型工厂:
class ModelFactory:
    """Instantiates a registered algorithm from the ALGORITHMS registry."""

    @staticmethod
    def create_model(algorithm_name, params=None):
        """Build and return an instance of *algorithm_name* with *params*.

        Raises ValueError when the name is not registered.
        """
        # ALGORITHMS is nested task -> algorithm name -> spec, so search each
        # task category; the original indexed the top level directly with the
        # algorithm name, which could never match ('classification' etc. are
        # the only top-level keys).
        for task_algorithms in ALGORITHMS.values():
            if algorithm_name in task_algorithms:
                spec = task_algorithms[algorithm_name]
                return spec['class'](**(params or {}))
        raise ValueError(f"Algorithm {algorithm_name} not supported")

2.2 训练管理

  1. 训练配置:
@dataclass
class TrainingConfig:
    """Configuration bundle consumed by ModelTrainer."""
    algorithm: str                # key into the ALGORITHMS registry
    params: Dict                  # hyperparameters forwarded to the model class
    data_version: str             # DataVersionControl id of the training data
    device: str = 'cuda'          # torch device string
    batch_size: int = 32
    num_epochs: int = 100         # upper bound; early stopping may end sooner
    early_stopping: bool = True   # stop when the monitored loss plateaus
  1. 训练器:
class ModelTrainer:
    """Drives the epoch training loop for a configured algorithm."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = ModelFactory.create_model(config.algorithm, config.params)
        self.device = torch.device(config.device)
        
    def train(self, train_data, val_data=None):
        """Train for up to config.num_epochs, optionally stopping early.

        Early stopping monitors validation loss when val_data is given and
        training loss otherwise. The original dereferenced an undefined
        self.early_stopping attribute and ignored config.early_stopping;
        the patience logic is now inlined and gated on the config flag.
        """
        self.model.to(self.device)
        # NOTE(review): optimizer/scheduler are presumably consumed by
        # _train_epoch via self in the full implementation — confirm.
        optimizer = self._configure_optimizer()
        scheduler = self._configure_scheduler()
        
        best_loss = float('inf')
        patience = 10       # epochs without improvement before stopping
        stale_epochs = 0

        for epoch in range(self.config.num_epochs):
            train_loss = self._train_epoch(train_data)
            monitored_loss = train_loss
            if val_data is not None:
                monitored_loss = self._validate(val_data)
                
            if self.config.early_stopping:
                if monitored_loss < best_loss:
                    best_loss = monitored_loss
                    stale_epochs = 0
                else:
                    stale_epochs += 1
                    if stale_epochs >= patience:
                        break

2.3 超参数优化

class HPOptimizer:
    """Thin wrapper around an Optuna study for hyperparameter search."""

    def __init__(self, study_name, n_trials=100):
        # MedianPruner terminates unpromising trials early; the original
        # imported it but never used it, and instead passed an undefined
        # self._pruning_callback to optimize(). Pruning belongs on the
        # study, not in a callback.
        self.study = optuna.create_study(
            study_name=study_name,
            direction='minimize',
            pruner=MedianPruner()
        )
        self.n_trials = n_trials
        
    def optimize(self, objective_fn):
        """Run *objective_fn* for n_trials trials and return the best params."""
        self.study.optimize(
            objective_fn,
            n_trials=self.n_trials
        )
        return self.study.best_params

3. 模型管理模块

使用的库

# 模型管理与追踪
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.models import Model

# 模型序列化
import pickle
import torch.serialization
import json

# 模型版本控制
from datetime import datetime
import hashlib

3.1 模型注册

class ModelRegistry:
    """Registers trained models with the MLflow model registry."""

    def __init__(self, mlflow_uri):
        self.client = MlflowClient(mlflow_uri)
        
    def register_model(self, model_path, name, metadata=None):
        """Create a new version of model *name* from *model_path*.

        Requires an active MLflow run (its run_id is recorded on the
        version). Optional *metadata* is attached as a JSON-encoded tag.
        Returns the created ModelVersion — the original dropped it,
        leaving callers no handle to what was just registered.
        """
        model_version = self.client.create_model_version(
            name=name,
            source=model_path,
            run_id=mlflow.active_run().info.run_id
        )
        
        if metadata:
            self.client.set_model_version_tag(
                name=name,
                version=model_version.version,
                key='metadata',
                value=json.dumps(metadata)
            )
        return model_version

3.2 模型加载

class ModelLoader:
    """Loads serialized torch models from the local model store."""

    @staticmethod
    def load_model(model_id, model_dir="models", map_location="cpu"):
        """Deserialize model *model_id* from *model_dir*.

        model_dir generalizes the previously hard-coded "models" prefix
        (default keeps the old behavior). map_location defaults to CPU so
        models saved on a GPU host still load on CPU-only machines.
        NOTE(security): weights_only=False unpickles arbitrary objects;
        only load model files produced by this platform.
        """
        model_path = f"{model_dir}/{model_id}"
        return torch.load(model_path, map_location=map_location,
                          weights_only=False)
        
    @staticmethod
    def load_for_inference(model_id, model_dir="models", map_location="cpu"):
        """Load a model and switch it to eval mode (disables dropout/BN updates)."""
        model = ModelLoader.load_model(model_id, model_dir=model_dir,
                                       map_location=map_location)
        model.eval()
        return model

3.3 模型服务

class ModelService:
    """Caches loaded models in memory and serves predictions from them."""

    def __init__(self):
        # model_id -> eval-mode model, populated lazily on first request
        self.loaded_models = {}
        
    async def predict(self, model_id, data):
        """Run *data* through the cached model and return a numpy array."""
        if model_id not in self.loaded_models:
            self.loaded_models[model_id] = ModelLoader.load_for_inference(model_id)
            
        model = self.loaded_models[model_id]
        with torch.no_grad():
            # as_tensor avoids an extra copy when data is already a tensor/array
            predictions = model(torch.as_tensor(data))
        # .cpu() first: .numpy() raises on CUDA tensors, so the original
        # broke for any model running on GPU
        return predictions.cpu().numpy()

4. 系统监控模块

使用的库

# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge

# 日志管理
import logging
from torch.utils.tensorboard import SummaryWriter

# 性能分析
import cProfile
import line_profiler
import memory_profiler

4.1 资源监控

class ResourceMonitor:
    """Samples CPU / memory / GPU utilization for the metrics endpoint."""

    def get_metrics(self):
        """Return a snapshot of current host resource usage.

        The original constructed undefined GPUUtilization/MemoryUtilization
        helpers even though GPUtil and psutil are already imported; query
        those libraries directly instead.
        """
        gpus = GPUtil.getGPUs()
        return {
            # mean load across visible GPUs (0.0 on GPU-less hosts)
            'gpu': sum(g.load for g in gpus) / len(gpus) if gpus else 0.0,
            'memory': psutil.virtual_memory().percent,
            'cpu': psutil.cpu_percent()
        }

4.2 训练监控

class TrainingMonitor:
    """Accumulates per-metric value histories and summarizes them."""

    def __init__(self):
        # metric name -> list of reported values, in arrival order
        self.metrics_history = defaultdict(list)

    def log_metrics(self, metrics):
        """Append each value in *metrics* to its metric's history."""
        for metric_name, metric_value in metrics.items():
            self.metrics_history[metric_name].append(metric_value)

    def get_summary(self):
        """Return latest value, mean and std for every tracked metric."""
        summary = {}
        for metric_name, history in self.metrics_history.items():
            summary[metric_name] = {
                'current': history[-1],
                'mean': np.mean(history),
                'std': np.std(history),
            }
        return summary

5. API接口设计

使用的库

# Web框架
from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, validator

# 异步支持
import asyncio
from concurrent.futures import ThreadPoolExecutor

# API文档
from fastapi.openapi.utils import get_openapi

5.1 训练接口

@router.post("/train")
async def start_training(
    request: TrainingRequest,
    background_tasks: BackgroundTasks
):
    """Queue a training job and immediately return its task id.

    Training itself runs in a FastAPI background task, so the HTTP request
    does not block for the duration of training; clients poll with task_id.
    NOTE(review): `router`, `TrainingRequest` and `train_model` are defined
    elsewhere in the API module — confirm.
    """
    task_id = str(uuid.uuid4())
    background_tasks.add_task(
        train_model,
        task_id=task_id,
        config=request.dict()
    )
    return {"task_id": task_id}

5.2 预测接口

@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_service: ModelService = Depends()
):
    """Run inference with the requested model and return its predictions.

    ModelService resolves (and lazily caches) the model by id; the numpy
    result is converted to a plain list for JSON serialization.
    NOTE(review): `router` and `PredictionRequest` are defined elsewhere
    in the API module — confirm.
    """
    predictions = await model_service.predict(
        model_id=request.model_id,
        data=request.data
    )
    return {"predictions": predictions.tolist()}

6. 安全模块

使用的库

# 加密
from cryptography.fernet import Fernet
from passlib.hash import bcrypt
from jwt import encode, decode

# 访问控制
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt

# 数据脱敏
import re
from typing import Pattern

6.1 数据安全

class DataSecurity:
    """Sanitizes then encrypts data before it is persisted to disk."""

    def __init__(self):
        # Both collaborators are project-level services defined elsewhere.
        self.encryptor = DataEncryptor()
        self.sanitizer = DataSanitizer()

    def secure_save(self, data, path):
        """Scrub sensitive fields from *data*, encrypt the result, and
        write the ciphertext to *path*."""
        cleaned = self.sanitizer.clean(data)
        ciphertext = self.encryptor.encrypt(cleaned)
        self._save_to_disk(ciphertext, path)

6.2 模型安全

class ModelSecurity:
    """Access control and encryption for stored models."""

    def __init__(self):
        # AccessControl is a project-level service defined elsewhere.
        self.access_control = AccessControl()
        
    def check_permission(self, user_id, model_id, operation):
        """Return whether *user_id* may perform *operation* on *model_id*."""
        return self.access_control.has_permission(user_id, model_id, operation)
        
    def encrypt_model(self, model, key):
        """Encrypt the model's state dict with *key*.

        NOTE(review): AESEncryption is not imported in this design —
        presumably a project helper; confirm it exists.
        """
        return AESEncryption.encrypt(model.state_dict(), key)

7. 部署配置

使用的库

# 容器化
import docker
from docker.types import Mount

# 任务队列
from celery import Celery
from redis import Redis

# 配置管理
import yaml
from dynaconf import Dynaconf

# 环境管理
import os
from dotenv import load_dotenv

7.1 Docker配置

# docker-compose for the platform: one API container and one Celery worker,
# each pinned to its own GPU via CUDA_VISIBLE_DEVICES.
version: '3'
services:
  api:
    build: .
    ports:
      - "8000:8000"   # FastAPI service port
    environment:
      - CUDA_VISIBLE_DEVICES=0
    volumes:
      - ./models:/app/models   # model store shared with the host
      
  worker:
    build: .
    command: celery -A tasks worker
    environment:
      - CUDA_VISIBLE_DEVICES=1

7.2 资源限制

class ResourceLimiter:
    """Enforces platform-wide resource caps on submitted jobs."""

    def __init__(self):
        self.limits = {
            'max_data_size': 1024 * 1024 * 1024,  # 1 GiB
            'max_training_time': 3600,  # seconds (1 hour)
            'max_batch_size': 1024
        }
    
    def check_limits(self, config):
        """Raise ValueError if *config* exceeds any configured limit.

        The original only enforced data_size even though all three limits
        were declared. Attributes missing from *config* are treated as
        unconstrained so existing callers that only set data_size keep
        working.
        """
        if getattr(config, 'data_size', 0) > self.limits['max_data_size']:
            raise ValueError("Data size exceeds limit")
        if getattr(config, 'training_time', 0) > self.limits['max_training_time']:
            raise ValueError("Training time exceeds limit")
        if getattr(config, 'batch_size', 0) > self.limits['max_batch_size']:
            raise ValueError("Batch size exceeds limit")

8. 版本信息

# Minimum supported version for every third-party package used in this design.
DEPENDENCIES = {
    # Core numerics / data handling
    'numpy': '>=1.21.0',
    'pandas': '>=1.3.0',
    'scikit-learn': '>=1.0.0',
    
    # Machine-learning frameworks
    'torch': '>=2.0.0',
    'xgboost': '>=1.5.0',
    'lightgbm': '>=3.3.0',
    'catboost': '>=1.0.0',
    
    # Deep learning
    'pytorch-lightning': '>=2.0.0',
    'transformers': '>=4.30.0',
    'diffusers': '>=0.19.0',
    
    # Model management / hyperparameter optimization
    'mlflow': '>=2.4.0',
    'optuna': '>=3.2.0',
    
    # Web framework / task queue
    'fastapi': '>=0.95.0',
    'celery': '>=5.3.0',
    
    # Monitoring and logging
    'prometheus-client': '>=0.16.0',
    'tensorboard': '>=2.12.0',
    
    # Security
    'cryptography': '>=40.0.0',
    'passlib': '>=1.7.4',
    
    # Deployment
    'docker': '>=6.1.0',
    'redis': '>=4.5.0'
}