18 KiB
18 KiB
机器学习平台系统设计
0. 依赖库总览
# 基础依赖
import numpy as np
import pandas as pd
from typing import Dict, List
from dataclasses import dataclass
import uuid
import json
from datetime import datetime
# 数据处理
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest
# 机器学习框架
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier, CatBoostRegressor
# 深度学习
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline
# 超参数优化
import optuna
from optuna.pruners import MedianPruner
# 模型管理
import mlflow
from mlflow.tracking import MlflowClient
# API框架
from fastapi import FastAPI, BackgroundTasks, Depends
from pydantic import BaseModel
# 分布式训练
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group
# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge
# 安全相关
from cryptography.fernet import Fernet
from passlib.hash import bcrypt
# 任务队列
from celery import Celery
from redis import Redis
# 日志管理
import logging
from tensorboard import SummaryWriter
1. 数据处理模块
使用的库
# 数据读取与处理
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest
# 数据版本控制
import hashlib
import json
from datetime import datetime
import uuid
1.1 数据读取
- 支持多种数据格式读取:
# Mapping of supported input formats to their pandas reader callables.
SUPPORTED_FORMATS = {
    'csv': pd.read_csv,
    'parquet': pd.read_parquet,
    'hdf5': pd.read_hdf,
    # Generalization: formats pandas also reads natively.
    'json': pd.read_json,
    'excel': pd.read_excel,
}
- 数据读取参数配置:
# Per-format keyword arguments forwarded to the readers in SUPPORTED_FORMATS.
# Only 'csv' is configured here; other formats use the readers' defaults.
READ_CONFIG = {
    'csv': {
        'encoding': 'utf-8',
        'sep': ',',
        # Treat empty strings and NULL markers as missing values.
        'na_values': ['', 'NULL', 'null']
    }
}
1.2 数据预处理
- 数据清洗流水线:
class DataCleaner:
    """Runs an ordered sequence of cleaning steps over a dataset."""

    def __init__(self):
        # Ordered (name, processor) cleaning steps.
        # NOTE(review): DuplicatesRemover is not defined in this file —
        # presumably a project helper following the sklearn API; confirm.
        self.pipeline = [
            ('missing_handler', SimpleImputer(strategy='mean')),
            ('outlier_detector', IsolationForest()),
            ('duplicates_remover', DuplicatesRemover())
        ]

    def clean(self, data):
        """Apply every pipeline step in order and return the cleaned data.

        Fix: the original called ``processor.transform(data)`` on unfitted
        estimators (raises NotFittedError), and IsolationForest has no
        ``transform`` at all. Use ``fit_transform`` when available; for
        outlier detectors use ``fit_predict`` and keep only rows labelled
        as inliers (+1).
        """
        for name, processor in self.pipeline:
            if hasattr(processor, 'fit_transform'):
                data = processor.fit_transform(data)
            elif hasattr(processor, 'fit_predict'):
                data = data[processor.fit_predict(data) == 1]
            else:
                data = processor.transform(data)
        return data
- 数据转换:
class DataTransformer:
    """Holds separate preprocessing pipelines for numerical and categorical columns."""

    def __init__(self):
        # Numerical columns: standardize, then drop near-constant features.
        self.numerical_pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('selector', VarianceThreshold(threshold=0.01))
        ])
        # Fix: OneHotEncoder's `sparse` kwarg was renamed `sparse_output`
        # in scikit-learn 1.2 and removed in 1.4; support both releases.
        try:
            encoder = OneHotEncoder(sparse_output=False)
        except TypeError:  # scikit-learn < 1.2
            encoder = OneHotEncoder(sparse=False)
        # Categorical columns: one-hot encode, then keep the 20 best features.
        self.categorical_pipeline = Pipeline([
            ('encoder', encoder),
            ('selector', SelectKBest(k=20))
        ])
- 特征工程:
class FeatureEngineer:
    """Registry of feature-generation strategies, keyed by family name."""

    def __init__(self):
        # NOTE(review): PolynomialFeatures, InteractionFeatures and
        # TimeFeatures are not defined in this file — confirm the project
        # provides them before relying on this registry.
        generators = {}
        generators['polynomial'] = PolynomialFeatures()
        generators['interaction'] = InteractionFeatures()
        generators['time'] = TimeFeatures()
        self.feature_generators = generators
- 数据集划分:
class DataSplitter:
    """Splits a dataset into train / validation / test partitions."""

    def split(self, data, test_size=0.2, val_size=0.1, random_state=None):
        """Return {'train', 'val', 'test'} partitions of *data*.

        Fix: added *random_state* so splits are reproducible; the default
        (None) preserves the original non-deterministic behaviour.
        Note: *val_size* is a fraction of the data remaining after the
        test split, not of the full dataset.
        """
        train, test = train_test_split(
            data, test_size=test_size, random_state=random_state)
        train, val = train_test_split(
            train, test_size=val_size, random_state=random_state)
        return {
            'train': train,
            'val': val,
            'test': test
        }
1.3 数据版本控制
class DataVersionControl:
    """In-memory registry mapping version ids to dataset fingerprints."""

    def __init__(self):
        # version_id -> {'hash', 'timestamp', 'metadata'}
        self.version_db = {}

    def save_version(self, data, metadata):
        """Record *data* under a fresh UUID and return that version id."""
        version_id = str(uuid.uuid4())
        hash_value = self._calculate_hash(data)
        self.version_db[version_id] = {
            'hash': hash_value,
            'timestamp': datetime.now(),
            'metadata': metadata
        }
        return version_id

    @staticmethod
    def _calculate_hash(data):
        """Fix: this helper was called by save_version but never defined.

        Returns the SHA-256 hex digest of a stable text serialisation of
        *data* (pandas-like objects via to_csv, everything else via repr).
        """
        if hasattr(data, 'to_csv'):
            raw = data.to_csv(index=True).encode('utf-8')
        else:
            raw = repr(data).encode('utf-8')
        return hashlib.sha256(raw).hexdigest()
1.1 数据预处理方法库
1.1.1 缺失值处理
# Catalogue of missing-value imputation strategies, grouped by column type.
# Values are human-readable descriptions (Chinese) for UI/documentation use.
MISSING_VALUE_METHODS = {
    'numerical': {
        'mean': '使用均值填充',
        'median': '使用中位数填充',
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'interpolate': '使用插值方法填充(线性/多项式)',
        'knn': '使用K近邻方法填充',
        'iterative': '使用迭代方法填充(MICE)'
    },
    'categorical': {
        'mode': '使用众数填充',
        'constant': '使用指定常数填充',
        'new_category': '创建新类别',
        'knn': '使用K近邻方法填充'
    }
}
1.1.2 数据标准化/归一化
# Catalogue of feature scaling/normalisation methods with description,
# formula and recommended use case. Note 'robust' uses Q2 (the median)
# as centre, matching sklearn's RobustScaler.
SCALING_METHODS = {
    'standardization': {
        'description': 'Z-Score标准化,均值为0,方差为1',
        'formula': '(x - mean) / std',
        'use_case': '适用于数据服从正态分布的情况'
    },
    'min_max': {
        'description': '将数据缩放到[0,1]区间',
        'formula': '(x - min) / (max - min)',
        'use_case': '适用于需要限定范围的情况'
    },
    'robust': {
        'description': '使用四分位数进行缩放,对异常值不敏感',
        'formula': '(x - Q2) / (Q3 - Q1)',
        'use_case': '存在异常值时使用'
    },
    'log': {
        'description': '对数变换',
        'formula': 'log(x)',
        'use_case': '处理长尾分布数据'
    }
}
1.1.3 特征编码
# Catalogue of feature-encoding methods. Categorical entries carry
# description / use case / pros / cons; text entries are short labels.
ENCODING_METHODS = {
    'categorical': {
        'one_hot': {
            'description': '独热编码',
            'use_case': '类别之间无序关系',
            'pros': '不引入大小关系',
            'cons': '特征维度增加'
        },
        'label': {
            'description': '标签编码',
            'use_case': '类别之间有序关系',
            'pros': '维度不变',
            'cons': '引入大小关系'
        },
        'target': {
            'description': '目标编码',
            'use_case': '类别较多时',
            'pros': '降低维度',
            'cons': '可能过拟合'
        }
    },
    'text': {
        'tfidf': '词频-逆文档频率',
        'word2vec': '词向量',
        'bert': 'BERT编码'
    }
}
1.1.4 异常值检测
# Catalogue of outlier-detection methods: simple statistical rules vs.
# model-based detectors.
OUTLIER_METHODS = {
    'statistical': {
        'zscore': {
            'description': 'Z分数法',
            'threshold': '通常为3个标准差'
        },
        'iqr': {
            'description': '四分位数法',
            'threshold': '1.5倍IQR'
        }
    },
    'model_based': {
        'isolation_forest': '基于孤立森林',
        'one_class_svm': '单类支持向量机',
        'local_outlier_factor': '局部异常因子'
    }
}
1.2 特征工程方法
1.2.1 特征选择
# Catalogue of feature-selection methods grouped by the standard taxonomy:
# filter (statistics only), wrapper (search with a model), embedded
# (selection as a by-product of model fitting).
FEATURE_SELECTION = {
    'filter': {
        'variance': '方差选择法',
        'correlation': '相关系数法',
        'chi2': '卡方检验',
        'mutual_info': '互信息法'
    },
    'wrapper': {
        'rfe': '递归特征消除',
        'forward': '前向选择',
        'backward': '后向消除'
    },
    'embedded': {
        'lasso': 'L1正则化',
        'ridge': 'L2正则化',
        'tree_importance': '树模型特征重要性'
    }
}
1.2.2 特征构造
# Catalogue of feature-construction techniques, grouped by source
# column type (numerical / datetime / text).
FEATURE_CONSTRUCTION = {
    'numerical': {
        'polynomial': '多项式特征',
        'interaction': '交互特征',
        'binning': '分箱特征'
    },
    'datetime': {
        'year': '年份提取',
        'month': '月份提取',
        'day': '日期提取',
        'weekday': '星期提取',
        'is_weekend': '是否周末'
    },
    'text': {
        'length': '文本长度',
        'word_count': '词数统计',
        'ngram': 'N元语法特征'
    }
}
2. 模型训练模块
使用的库
# 机器学习框架
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
# 传统机器学习算法
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier
# 深度学习模型
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline
# 超参数优化
import optuna
from optuna.pruners import MedianPruner
2.1 算法管理
- 算法注册表:
# Registry of trainable algorithms grouped by task type. Each concrete
# entry maps to its model class plus a hyperparameter search space.
ALGORITHMS = {
    'classification': {
        'xgboost': {
            # Fix: the file imports `xgboost as xgb`, so the bare name
            # XGBClassifier was undefined here.
            'class': xgb.XGBClassifier,
            'params': {
                'max_depth': {'type': 'int', 'range': [3,10]},
                'learning_rate': {'type': 'float', 'range': [0.001,0.1]},
                'n_estimators': {'type': 'int', 'range': [100,1000]}
            }
        },
        'lightgbm': {...},  # placeholder entries in this design document
        'catboost': {...}
    },
    'regression': {
        'xgboost_regressor': {...},
        'lightgbm_regressor': {...}
    },
    'clustering': {
        'kmeans': {...},
        'dbscan': {...}
    }
}
- 模型工厂:
class ModelFactory:
    """Creates model instances from the ALGORITHMS registry."""

    @staticmethod
    def create_model(algorithm_name, params=None):
        """Instantiate the algorithm registered as *algorithm_name*.

        Fix: ALGORITHMS is nested by task type ('classification' -> name
        -> spec), so the original top-level membership check never matched
        a concrete algorithm, and `algorithm['class']` would then KeyError
        on the task-type dict. Search every task group instead.
        """
        for task_algorithms in ALGORITHMS.values():
            if algorithm_name in task_algorithms:
                spec = task_algorithms[algorithm_name]
                return spec['class'](**(params or {}))
        raise ValueError(f"Algorithm {algorithm_name} not supported")
2.2 训练管理
- 训练配置:
@dataclass
class TrainingConfig:
    """Configuration bundle consumed by ModelTrainer."""
    algorithm: str        # key into the ALGORITHMS registry
    params: Dict          # hyperparameters forwarded to the model class
    data_version: str     # dataset version id (see DataVersionControl)
    device: str = 'cuda'  # torch device string; use 'cpu' for CPU-only runs
    batch_size: int = 32
    num_epochs: int = 100
    early_stopping: bool = True  # stop when validation loss stops improving
- 训练器:
class ModelTrainer:
    """Drives the epoch loop for a model built from a TrainingConfig."""

    def __init__(self, config: TrainingConfig, patience: int = 5):
        self.config = config
        self.model = ModelFactory.create_model(config.algorithm, config.params)
        self.device = torch.device(config.device)
        # New (backward-compatible): non-improving validation epochs
        # tolerated before early stopping triggers.
        self.patience = patience

    def train(self, train_data, val_data=None):
        """Run the training loop, early-stopping on validation loss.

        Fix: the original dereferenced self.early_stopping.should_stop(),
        but no such attribute was ever created (config.early_stopping is a
        plain bool) — track the best validation loss here instead.
        NOTE(review): _configure_optimizer/_configure_scheduler/
        _train_epoch/_validate are not defined in this file; presumably
        implemented elsewhere — confirm.
        """
        self.model.to(self.device)
        optimizer = self._configure_optimizer()
        scheduler = self._configure_scheduler()
        best_val_loss = float('inf')
        stale_epochs = 0
        for epoch in range(self.config.num_epochs):
            train_loss = self._train_epoch(train_data)
            if val_data is not None:
                val_loss = self._validate(val_data)
                if self.config.early_stopping:
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        stale_epochs = 0
                    else:
                        stale_epochs += 1
                        if stale_epochs >= self.patience:
                            break
2.3 超参数优化
class HPOptimizer:
    """Hyperparameter search wrapper around an Optuna study."""

    def __init__(self, study_name, n_trials=100):
        # Fix: MedianPruner was imported at the top of the file but never
        # attached to the study, so pruning silently did nothing.
        self.study = optuna.create_study(
            study_name=study_name,
            direction='minimize',
            pruner=MedianPruner()
        )
        self.n_trials = n_trials

    def optimize(self, objective_fn):
        """Run *objective_fn* for n_trials and return the best parameters.

        Fix: the original passed callbacks=[self._pruning_callback], a
        method that does not exist (AttributeError at runtime); pruning is
        handled by the study's pruner configured above.
        """
        self.study.optimize(objective_fn, n_trials=self.n_trials)
        return self.study.best_params
3. 模型管理模块
使用的库
# 模型管理与追踪
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.models import Model
# 模型序列化
import pickle
import torch.serialization
import json
# 模型版本控制
from datetime import datetime
import hashlib
3.1 模型注册
class ModelRegistry:
    """Registers trained models with an MLflow tracking server."""

    def __init__(self, mlflow_uri):
        self.client = MlflowClient(mlflow_uri)

    def register_model(self, model_path, name, metadata=None):
        """Create a new model version for *name* and return it.

        Fixes: mlflow.active_run() returns None outside a run, so guard
        the run-id lookup; also return the created version so callers can
        reference it (the original returned nothing).
        """
        active = mlflow.active_run()
        run_id = active.info.run_id if active is not None else None
        model_version = self.client.create_model_version(
            name=name,
            source=model_path,
            run_id=run_id
        )
        if metadata:
            # Metadata is stored as a JSON string tag on the version.
            self.client.set_model_version_tag(
                name=name,
                version=model_version.version,
                key='metadata',
                value=json.dumps(metadata)
            )
        return model_version
3.2 模型加载
class ModelLoader:
    """Loads serialized models from disk for training or inference."""

    @staticmethod
    def load_model(model_id, model_dir="models"):
        """Deserialize and return the model stored under *model_dir*.

        *model_dir* is new; its default preserves the previously
        hard-coded "models/" prefix. weights_only=False keeps the
        pre-torch-2.6 behaviour of loading full pickled modules —
        SECURITY: this executes pickle, so only load trusted files.
        """
        model_path = f"{model_dir}/{model_id}"
        return torch.load(model_path, weights_only=False)

    @staticmethod
    def load_for_inference(model_id, model_dir="models"):
        """Load a model and switch it to eval mode (freezes dropout/batch-norm)."""
        model = ModelLoader.load_model(model_id, model_dir)
        model.eval()
        return model
3.3 模型服务
class ModelService:
    """Serves predictions from an in-memory cache of loaded models."""

    def __init__(self):
        # model_id -> ready-to-run model, populated lazily on first request.
        self.loaded_models = {}

    async def predict(self, model_id, data):
        """Run inference for *model_id* on *data*, loading the model on demand."""
        if model_id not in self.loaded_models:
            self.loaded_models[model_id] = ModelLoader.load_for_inference(model_id)
        model = self.loaded_models[model_id]
        with torch.no_grad():
            output = model(torch.tensor(data))
        return output.numpy()
4. 系统监控模块
使用的库
# 系统监控
import psutil
import GPUtil
from prometheus_client import Counter, Gauge
# 日志管理
import logging
from tensorboard import SummaryWriter
# 性能分析
import cProfile
import line_profiler
import memory_profiler
4.1 资源监控
class ResourceMonitor:
    """Collects current CPU, memory and GPU utilisation metrics."""

    def get_metrics(self):
        """Return a snapshot of system resource usage.

        Fix: the original depended on GPUUtilization/MemoryUtilization,
        which are defined nowhere in this file; use the imported GPUtil
        and psutil directly instead. 'gpu' is a list of per-device loads
        in [0.0, 1.0] (empty when no GPU/driver is available); 'memory'
        and 'cpu' are percentages.
        """
        try:
            gpu_loads = [gpu.load for gpu in GPUtil.getGPUs()]
        except Exception:
            gpu_loads = []  # best effort: no NVIDIA driver / no GPUs
        return {
            'gpu': gpu_loads,
            'memory': psutil.virtual_memory().percent,
            'cpu': psutil.cpu_percent()
        }
4.2 训练监控
class TrainingMonitor:
    """Accumulates training metrics and reports summary statistics."""

    def __init__(self):
        # Fix: the original used defaultdict, which is never imported
        # anywhere in this file; a plain dict + setdefault behaves the
        # same for writers without the missing dependency.
        self.metrics_history = {}

    def log_metrics(self, metrics):
        """Append each value in the *metrics* mapping to its history list."""
        for name, value in metrics.items():
            self.metrics_history.setdefault(name, []).append(value)

    def get_summary(self):
        """Return current / mean / std for every metric logged so far."""
        return {
            name: {
                'current': values[-1],
                'mean': np.mean(values),
                'std': np.std(values)
            }
            for name, values in self.metrics_history.items()
        }
5. API接口设计
使用的库
# Web框架
from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, validator
# 异步支持
import asyncio
from concurrent.futures import ThreadPoolExecutor
# API文档
from fastapi.openapi.utils import get_openapi
5.1 训练接口
@router.post("/train")
async def start_training(
    request: TrainingRequest,
    background_tasks: BackgroundTasks
):
    """Queue a training job and return its task id immediately.

    The work runs in `train_model` via FastAPI BackgroundTasks, so the
    HTTP response does not wait for training to finish.
    NOTE(review): `router`, `TrainingRequest` and `train_model` are not
    defined in this file — presumably provided elsewhere; confirm.
    """
    task_id = str(uuid.uuid4())
    background_tasks.add_task(
        train_model,
        task_id=task_id,
        config=request.dict()  # NOTE(review): pydantic v2 renames this to model_dump()
    )
    return {"task_id": task_id}
5.2 预测接口
@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_service: ModelService = Depends()
):
    """Run inference with the requested model and return JSON-safe predictions.

    NOTE(review): `router` and `PredictionRequest` are not defined in
    this file — presumably provided elsewhere; confirm.
    """
    predictions = await model_service.predict(
        model_id=request.model_id,
        data=request.data
    )
    # .tolist() converts the numpy array into JSON-serialisable lists.
    return {"predictions": predictions.tolist()}
6. 安全模块
使用的库
# 加密
from cryptography.fernet import Fernet
from passlib.hash import bcrypt
from jwt import encode, decode
# 访问控制
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
# 数据脱敏
import re
from typing import Pattern
6.1 数据安全
class DataSecurity:
    """Sanitizes and encrypts datasets before they reach disk."""

    def __init__(self):
        # NOTE(review): DataEncryptor / DataSanitizer are not defined in
        # this file — presumably project helpers; confirm.
        self.encryptor = DataEncryptor()
        self.sanitizer = DataSanitizer()

    def secure_save(self, data, path):
        """Sanitize *data*, encrypt the result, then persist it at *path*."""
        cleaned = self.sanitizer.clean(data)
        payload = self.encryptor.encrypt(cleaned)
        self._save_to_disk(payload, path)
6.2 模型安全
class ModelSecurity:
    """Access control and at-rest encryption for stored models."""

    def __init__(self):
        # NOTE(review): AccessControl is not defined in this file —
        # presumably a project helper; confirm.
        self.access_control = AccessControl()

    def check_permission(self, user_id, model_id, operation):
        """True when *user_id* may perform *operation* on *model_id*."""
        allowed = self.access_control.has_permission(user_id, model_id, operation)
        return allowed

    def encrypt_model(self, model, key):
        """Encrypt the model's weights (state_dict) with *key*."""
        weights = model.state_dict()
        return AESEncryption.encrypt(weights, key)
7. 部署配置
使用的库
# 容器化
import docker
from docker.types import Mount
# 任务队列
from celery import Celery
from redis import Redis
# 配置管理
import yaml
from dynaconf import Dynaconf
# 环境管理
import os
from dotenv import load_dotenv
7.1 Docker配置
# docker-compose configuration: one API container and one Celery worker,
# each pinned to a separate GPU via CUDA_VISIBLE_DEVICES.
version: '3'
services:
  api:
    build: .
    ports:
      - "8000:8000"              # FastAPI service port
    environment:
      - CUDA_VISIBLE_DEVICES=0   # API container sees GPU 0 only
    volumes:
      - ./models:/app/models     # share trained models with the host
  worker:
    build: .
    command: celery -A tasks worker
    environment:
      - CUDA_VISIBLE_DEVICES=1   # worker container sees GPU 1 only
7.2 资源限制
class ResourceLimiter:
    """Enforces hard resource ceilings on training job configurations."""

    def __init__(self):
        self.limits = {
            'max_data_size': 1024 * 1024 * 1024,  # 1 GiB
            'max_training_time': 3600,  # seconds (1 hour)
            'max_batch_size': 1024
        }

    def check_limits(self, config):
        """Raise ValueError when *config* exceeds any configured limit.

        Fix: max_training_time and max_batch_size were declared but never
        checked; attributes absent from *config* are skipped so existing
        callers that only provide data_size keep working.
        """
        checks = [
            ('data_size', 'max_data_size', "Data size exceeds limit"),
            ('training_time', 'max_training_time', "Training time exceeds limit"),
            ('batch_size', 'max_batch_size', "Batch size exceeds limit"),
        ]
        for attr, limit_key, message in checks:
            value = getattr(config, attr, None)
            if value is not None and value > self.limits[limit_key]:
                raise ValueError(message)
8. 版本信息
# Minimum supported versions (pip-style specifiers) for every third-party
# dependency used by the platform.
DEPENDENCIES = {
    # Core numerics
    'numpy': '>=1.21.0',
    'pandas': '>=1.3.0',
    'scikit-learn': '>=1.0.0',
    # Machine-learning frameworks
    'torch': '>=2.0.0',
    'xgboost': '>=1.5.0',
    'lightgbm': '>=3.3.0',
    'catboost': '>=1.0.0',
    # Deep learning
    'pytorch-lightning': '>=2.0.0',
    'transformers': '>=4.30.0',
    'diffusers': '>=0.19.0',
    # Model management
    'mlflow': '>=2.4.0',
    'optuna': '>=3.2.0',
    # Web framework / task queue
    'fastapi': '>=0.95.0',
    'celery': '>=5.3.0',
    # Monitoring and logging
    'prometheus-client': '>=0.16.0',
    'tensorboard': '>=2.12.0',
    # Security
    'cryptography': '>=40.0.0',
    'passlib': '>=1.7.4',
    # Deployment
    'docker': '>=6.1.0',
    'redis': '>=4.5.0'
}