MLPlatform/doc/ml_platform_spec.md

15 KiB
Raw Blame History

内网机器学习平台技术方案

目录

  1. 项目概述
  2. 系统架构
  3. 核心功能
  4. 部署指南
  5. API接口文档
  6. 注意事项

1. 项目概述

1.1 目标

为科研单位提供内网机器学习平台,支持:

  • 算法选择与模型训练
  • 超参数自动优化
  • 模型评估与部署
  • 批量预测服务

1.2 技术栈

组件 技术选型
Web框架 FastAPI
任务队列 Celery + Redis
数据存储 MySQL
模型管理 MLflow
超参优化 Optuna

1.2 技术栈扩展

功能领域 开源库 版本 用途说明
基础框架 Python 3.8+ 主开发语言
FastAPI 0.88+ REST API开发
数据处理 Pandas 1.5+ 数据操作与分析
NumPy 1.23+ 数值计算基础
传统ML Scikit-learn 1.2+ 经典机器学习算法
XGBoost 1.7+ 梯度提升树模型
LightGBM 3.3+ 高效梯度提升框架
深度学习 PyTorch 2.0+ 神经网络框架
PyTorch Lightning 2.0+ 深度学习训练框架
Transformers 4.28+ 预训练模型库
时序分析 Prophet 1.1+ 时间序列预测
优化调参 Optuna 3.1+ 超参数优化框架
可视化 Matplotlib 3.7+ 基础可视化
Plotly 5.14+ 交互式可视化
模型解释 SHAP 0.42+ 模型可解释性分析

2. 系统架构

graph TD
A[Web UI/API] --> B[FastAPI]
B --> C[算法仓库]
B --> D[MySQL]
B --> E[Celery Worker]
E --> F[GPU训练]
C --> G[MLflow模型仓库]
G --> H[批量预测服务]

3. 核心功能

3.1 算法管理

预置算法示例
ALGORITHMS = {
"xgboost": {
"type": "classification",
"params": {
"max_depth": {"type": "int", "range": [3,10]},
"learning_rate": {"type": "float", "range": [0.001,0.1]}
}
},
"random_forest": {...}
}

3.2 训练流程

  1. 接收训练请求
  2. 数据预处理
  3. 自动超参优化(可选)
  4. 模型训练与评估
  5. 模型注册到MLflow

3.3 数据预处理

class DataPreprocessor:
    SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5']
    
    def __init__(self, config):
        self.pipeline = [
            ('缺失值处理', SimpleImputer()),
            ('标准化', StandardScaler()),
            ('特征选择', SelectKBest(k=20))
        ]
        
    def split_data(self, test_size=0.2, random_state=42):
        # 自动记录数据划分版本
        train, test = train_test_split(self.data, test_size=test_size)
        return train, test, self._generate_data_hash()

3.4 数据版本控制

graph LR
A[原始数据] --> B[预处理]
B --> C[训练集 v1]
B --> D[测试集 v1]
C --> E[模型训练]
D --> F[模型评估]
E --> G[元数据记录]
F --> G

3.5 硬件感知训练

class HardwareAwareTraining:
    def __init__(self):
        self.gpu_mem = self.get_gpu_memory()
        
    def auto_batch_size(self, model_type):
        # 根据GPU显存自动调整批大小
        base_sizes = {
            'DNN': 1024,
            'LSTM': 256,
            'Transformer': 64
        }
        return base_sizes[model_type] * (self.gpu_mem // 16384)  # 按16GB基准缩放

3.6 显存监控

class MemoryMonitor(Callback):
    def on_epoch_end(self, epoch, logs=None):
        used_mem = torch.cuda.memory_allocated() // 1024**2
        print(f"当前显存占用: {used_mem}MB / {TOTAL_GPU_MEM}MB")
        if used_mem > SAFE_THRESHOLD:
            self.model.stop_training = True

3.7 训练稳定性保障

# 梯度裁剪实现
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = torch.optim.ClipGradNorm(optimizer, max_norm=1.0)

# 早停机制
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

3.8 数据安全处理

class DataSanitizer:
    def __init__(self):
        self.sensitive_patterns = [
            r'\d{18}',       # 身份证号
            r'\d{11}'        # 手机号
        ]
        
    def clean(self, data):
        # 数据脱敏处理
        return data.replace(self.sensitive_patterns, '***')

3.9 完整流程示意图

graph TD
A[数据接入] --> B[数据质量检查]
B --> C[数据预处理]
C --> D[特征工程]
D --> E[算法选择]
E --> F[模型训练]
F --> G[模型评估]
G --> H[模型注册]
H --> I[模型部署]
I --> J[预测服务]
J --> K[效果监控]
K -->|反馈| A

3.10 新增关键模块说明

数据质量检查

class DataValidator:
    CHECKS = [
        ('缺失值比例 < 30%', lambda df: df.isna().mean() < 0.3),
        ('特征方差 > 0.01', lambda df: df.var() > 0.01),
        ('类别平衡性 > 0.1', lambda df: check_class_balance(df))
    ]
    
    def validate(self, data):
        return {check[0]: check[1](data) for check in self.CHECKS}

实验跟踪

class ExperimentTracker:
    def __init__(self):
        self.runs = {}
        
    def log_run(self, params, metrics, data_version):
        run_id = generate_uuid()
        self.runs[run_id] = {
            'timestamp': datetime.now(),
            'data_version': data_version,
            'params': params,
            'metrics': metrics
        }

模型监控

class ModelMonitor:
    def detect_drift(self, current_data, reference_data):
        # 计算数据分布差异
        psi = calculate_psi(current_data, reference_data)
        return psi > 0.25  # 触发重新训练阈值

3.11 深度学习训练优化

class LitModel(pl.LightningModule):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(64, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10)
        )
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

3.12 Lightning 集成方案

# 训练服务集成示例
class TrainingService:
    def start_training(self, config):
        model = LitModel(config.hidden_dim)
        trainer = pl.Trainer(
            max_epochs=config.epochs,
            accelerator="auto",
            callbacks=[
                ModelCheckpoint(dirpath="./checkpoints"),
                MLFlowLogger(experiment_name=config.exp_name)
            ]
        )
        trainer.fit(model, data_module)

3.13 分布式训练支持

# 多GPU训练配置
trainer = pl.Trainer(
    devices=2,
    strategy="ddp",
    precision="16-mixed"
)

4. 部署指南

4.1 环境要求

  • Python 3.8+
  • MySQL 5.7+
  • NVIDIA驱动CUDA 11.6

4.2 离线安装步骤

# 1. 安装依赖
pip install --no-index --find-links=./offline_packages -r requirements.txt

# 2. 数据库初始化
mysql -u root -p < scripts/init_db.sql

# 3. 启动服务
docker-compose up -d

5. API接口文档

5.1 核心接口概览

类别 接口路径 方法 功能说明
数据管理 /api/data/upload POST 上传原始数据集
/api/data/versions GET 查询数据版本历史
模型训练 /api/train POST 提交训练任务
/api/train/status/{id} GET 查询训练状态
模型管理 /api/models GET 获取已注册模型列表
/api/models/{id} DELETE 删除指定模型
预测服务 /api/predict/single POST 单条实时预测
/api/predict/batch POST 批量异步预测
系统管理 /api/system/health GET 系统健康检查

5.2 训练接口详情

POST /api/train
Content-Type: application/json

{
  "algorithm": "xgboost",
  "dataset": "/data/project1.csv",
  "hpo": true,
  "hpo_config": {
    "max_trials": 50,
    "timeout": 3600
  }
}

Response:
{
  "task_id": "train_01H9Z3X6V9",
  "status_url": "/api/train/status/train_01H9Z3X6V9"
}

5.3 模型查询接口

GET /api/models?type=classification

Response:
{
  "models": [
    {
      "id": "xgb_v1",
      "type": "classification",
      "metrics": {
        "accuracy": 0.92,
        "f1": 0.89
      },
      "registered_at": "2023-08-20T14:30:00"
    }
  ]
}

5.4 预测接口安全控制

POST /api/predict/single
Content-Type: application/json
X-API-Key: your_api_key

{
  "model_id": "xgb_v1",
      "features": [0.5, 1.2, 3.4]
}

Response:
{
  "prediction": 1,
  "confidence": 0.87
}

5.5 新增实验对比接口

POST /api/experiments/compare
Content-Type: application/json

{
  "experiment_ids": ["exp001", "exp002"],
  "metrics": ["accuracy", "f1"]
}

Response:
{
  "comparison": {
    "exp001": {"accuracy": 0.92, "f1": 0.89},
    "exp002": {"accuracy": 0.89, "f1": 0.85}
  }
}

6. 注意事项

  1. 数据安全

    • 训练数据最大保留30天
    • 模型文件加密存储
  2. 性能优化

    • 单次训练数据量建议不超过1GB
    • 批量预测支持CSV文件最大100MB
  3. 扩展建议

    • 未来可扩展AutoML模块
    • 支持ONNX模型格式

6.4 硬件优化建议

  1. 显存管理策略

    • 自动批处理调整(根据数据维度动态计算)
    • 梯度累积技术累计步数≤4
  2. 训练稳定性措施

    • 梯度裁剪阈值1.0
    • 学习率自动衰减衰减率0.1
    • 最大连续失败次数3次自动终止异常任务
  3. 容错机制

    • 训练快照每小时保存checkpoint
    • 异常恢复从最近checkpoint重启

6.5 数据管理规范

  1. 输入格式要求

    • CSV文件必须包含header行
    • 数值型字段空值用NaN表示
    • 分类字段需预先编码
  2. 数据生命周期

    graph LR
    A[原始数据] --> B[预处理数据]
    B --> C[训练数据]
    C --> D[30天后删除]
    
  3. 版本控制规则

    • 数据划分随机种子固定
    • 每次预处理生成唯一数据指纹
    • 测试集数据指纹与模型版本绑定

6.6 流程完整性保障

  1. 数据质量红线

    • 缺失值超过30%的特征自动剔除
    • 零方差特征自动过滤
    • 类别不平衡数据触发警告
  2. 实验可复现性

    • 记录随机种子状态
    • 保存完整依赖版本
    • 存储数据预处理参数
  3. 模型监控机制

    • 预测结果分布监控
    • 特征漂移检测PSI指标
    • 模型性能衰减报警

6.7 依赖管理规范

  1. 版本锁定:使用pip-tools固定依赖版本
  2. 许可证审查确保所有库符合BSD/MIT/Apache许可证
  3. 离线打包:通过pip download生成离线安装包
  4. 依赖隔离使用虚拟环境venv/conda

6.8 深度学习训练规范

  1. 训练模板使用PyTorch Lightning标准模板
  2. 设备管理自动选择GPU/CPU
  3. 混合精度:默认启用fp16模式
  4. 分布式训练:支持单机多卡(需额外配置)

6.9 配置管理规范

# configs/train_config.yaml
defaults:
  - base
  - override /algorithm: xgboost

algorithm: 
  name: xgboost
  params:
    max_depth: 6
    learning_rate: 0.1

data:
  path: /data/project1.csv
  split_ratio: 0.8

logging:
  experiment_name: baseline_v1
# 配置加载实现
import yaml
from dataclasses import dataclass

@dataclass
class TrainConfig:
    algorithm: dict
    data: dict
    logging: dict

def load_config(path):
    with open(path) as f:
        raw = yaml.safe_load(f)
    return TrainConfig(**raw)

配置管理原则

  1. 配置与代码分离
  2. 支持环境差异dev/test/prod
  3. Schema验证使用Pydantic
  4. 版本控制(配置文件随代码库管理)

附录

A. 初始化脚本示例

-- scripts/init_db.sql
CREATE DATABASE IF NOT EXISTS ml_platform;
USE ml_platform;

CREATE TABLE training_jobs (
    id VARCHAR(36) PRIMARY KEY,
    algorithm VARCHAR(50) NOT NULL,
    status ENUM('pending', 'running', 'completed') DEFAULT 'pending'
);

B. 目录结构

.
├── docker-compose.yml
├── app/
│   ├── algorithms/    # 算法实现
│   ├── services/      # 核心服务
│   └── api/           # 接口定义
└── data/              # 训练数据集

完整文档已生成到项目doc目录包含详细配置示例和接口说明。需要补充其他内容请随时告知。

C. 资源监控指标

指标名称 采集方式 告警阈值
GPU显存使用率 NVIDIA-SMI API >85% 持续5分钟
训练内存占用 psutil 库监控 >70% 系统内存
训练进程存活状态 心跳检测 连续3次丢失

D. 稳定性测试用例

def test_oom_recovery():
    # 模拟显存溢出场景
    try:
        train_model(oversize_data)
    except RuntimeError as e:
        assert "CUDA out of memory" in str(e)
        assert check_rollback_success()

### E. Lightning 训练模板
```python
import pytorch_lightning as pl
from torch import nn

class LitClassifier(pl.LightningModule):
    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10)
        )
    
    def forward(self, x):
        return self.layers(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("train_loss", loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

F. 配置版本控制

# 配置变更检测
class ConfigValidator:
    SCHEMA = {
        "algorithm": {"type": dict, "required": True},
        "data": {"type": dict, "required": True}
    }

    def validate(self, config):
        for key, rules in self.SCHEMA.items():
            if rules["required"] and key not in config:
                raise ValueError(f"Missing required config: {key}")

G. 多环境配置示例

configs/
├── base.yaml
├── dev.yaml
├── test.yaml
└── prod.yaml

</rewritten_file>