MLPlatform/doc/ml_platform_spec.md

613 lines
15 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 内网机器学习平台技术方案
## 目录
1. 项目概述
2. 系统架构
3. 核心功能
4. 部署指南
5. API接口文档
6. 注意事项
---
## 1. 项目概述
### 1.1 目标
为科研单位提供内网机器学习平台,支持:
- 算法选择与模型训练
- 超参数自动优化
- 模型评估与部署
- 批量预测服务
### 1.2 技术栈
| 组件 | 技术选型 |
| -------- | -------------- |
| Web框架 | FastAPI |
| 任务队列 | Celery + Redis |
| 数据存储 | MySQL |
| 模型管理 | MLflow |
| 超参优化 | Optuna |
### 1.2 技术栈扩展
| 功能领域 | 开源库 | 版本 | 用途说明 |
| ------------ | ----------------- | ----- | ---------------- |
| **基础框架** | Python | 3.8+ | 主开发语言 |
| | FastAPI | 0.88+ | REST API开发 |
| **数据处理** | Pandas | 1.5+ | 数据操作与分析 |
| | NumPy | 1.23+ | 数值计算基础 |
| **传统ML** | Scikit-learn | 1.2+ | 经典机器学习算法 |
| | XGBoost | 1.7+ | 梯度提升树模型 |
| | LightGBM | 3.3+ | 高效梯度提升框架 |
| **深度学习** | PyTorch | 2.0+ | 神经网络框架 |
| | PyTorch Lightning | 2.0+ | 深度学习训练框架 |
| | Transformers | 4.28+ | 预训练模型库 |
| **时序分析** | Prophet | 1.1+ | 时间序列预测 |
| **优化调参** | Optuna | 3.1+ | 超参数优化框架 |
| **可视化** | Matplotlib | 3.7+ | 基础可视化 |
| | Plotly | 5.14+ | 交互式可视化 |
| **模型解释** | SHAP | 0.42+ | 模型可解释性分析 |
---
## 2. 系统架构
```mermaid
graph TD
A[Web UI/API] --> B[FastAPI]
B --> C[算法仓库]
B --> D[MySQL]
B --> E[Celery Worker]
E --> F[GPU训练]
C --> G[MLflow模型仓库]
G --> H[批量预测服务]
```
---
## 3. 核心功能
### 3.1 算法管理
```python
预置算法示例
ALGORITHMS = {
"xgboost": {
"type": "classification",
"params": {
"max_depth": {"type": "int", "range": [3,10]},
"learning_rate": {"type": "float", "range": [0.001,0.1]}
}
},
"random_forest": {...}
}
```
### 3.2 训练流程
1. 接收训练请求
2. 数据预处理
3. 自动超参优化(可选)
4. 模型训练与评估
5. 模型注册到MLflow
### 3.3 数据预处理
```python
class DataPreprocessor:
SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5']
def __init__(self, config):
self.pipeline = [
('缺失值处理', SimpleImputer()),
('标准化', StandardScaler()),
('特征选择', SelectKBest(k=20))
]
def split_data(self, test_size=0.2, random_state=42):
# 自动记录数据划分版本
train, test = train_test_split(self.data, test_size=test_size)
return train, test, self._generate_data_hash()
```
### 3.4 数据版本控制
```mermaid
graph LR
A[原始数据] --> B[预处理]
B --> C[训练集 v1]
B --> D[测试集 v1]
C --> E[模型训练]
D --> F[模型评估]
E --> G[元数据记录]
F --> G
```
### 3.5 硬件感知训练
```python
class HardwareAwareTraining:
def __init__(self):
self.gpu_mem = self.get_gpu_memory()
def auto_batch_size(self, model_type):
# 根据GPU显存自动调整批大小
base_sizes = {
'DNN': 1024,
'LSTM': 256,
'Transformer': 64
}
return base_sizes[model_type] * (self.gpu_mem // 16384) # 按16GB基准缩放
```
### 3.6 显存监控
```python
class MemoryMonitor(Callback):
def on_epoch_end(self, epoch, logs=None):
used_mem = torch.cuda.memory_allocated() // 1024**2
print(f"当前显存占用: {used_mem}MB / {TOTAL_GPU_MEM}MB")
if used_mem > SAFE_THRESHOLD:
self.model.stop_training = True
```
### 3.7 训练稳定性保障
```python
# 梯度裁剪实现
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = torch.optim.ClipGradNorm(optimizer, max_norm=1.0)
# 早停机制
early_stop = EarlyStopping(
monitor='val_loss',
patience=5,
mode='min'
)
```
### 3.8 数据安全处理
```python
class DataSanitizer:
def __init__(self):
self.sensitive_patterns = [
r'\d{18}', # 身份证号
r'\d{11}' # 手机号
]
def clean(self, data):
# 数据脱敏处理
return data.replace(self.sensitive_patterns, '***')
```
### 3.9 完整流程示意图
```mermaid
graph TD
A[数据接入] --> B[数据质量检查]
B --> C[数据预处理]
C --> D[特征工程]
D --> E[算法选择]
E --> F[模型训练]
F --> G[模型评估]
G --> H[模型注册]
H --> I[模型部署]
I --> J[预测服务]
J --> K[效果监控]
K -->|反馈| A
```
### 3.10 新增关键模块说明
**数据质量检查**
```python
class DataValidator:
CHECKS = [
('缺失值比例 < 30%', lambda df: df.isna().mean() < 0.3),
('特征方差 > 0.01', lambda df: df.var() > 0.01),
('类别平衡性 > 0.1', lambda df: check_class_balance(df))
]
def validate(self, data):
return {check[0]: check[1](data) for check in self.CHECKS}
```
**实验跟踪**
```python
class ExperimentTracker:
def __init__(self):
self.runs = {}
def log_run(self, params, metrics, data_version):
run_id = generate_uuid()
self.runs[run_id] = {
'timestamp': datetime.now(),
'data_version': data_version,
'params': params,
'metrics': metrics
}
```
**模型监控**
```python
class ModelMonitor:
def detect_drift(self, current_data, reference_data):
# 计算数据分布差异
psi = calculate_psi(current_data, reference_data)
return psi > 0.25 # 触发重新训练阈值
```
### 3.11 深度学习训练优化
```python
class LitModel(pl.LightningModule):
def __init__(self, hidden_dim=128):
super().__init__()
self.model = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 10)
)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
```
### 3.12 Lightning 集成方案
```python
# 训练服务集成示例
class TrainingService:
def start_training(self, config):
model = LitModel(config.hidden_dim)
trainer = pl.Trainer(
max_epochs=config.epochs,
accelerator="auto",
callbacks=[
ModelCheckpoint(dirpath="./checkpoints"),
MLFlowLogger(experiment_name=config.exp_name)
]
)
trainer.fit(model, data_module)
```
### 3.13 分布式训练支持
```python
# 多GPU训练配置
trainer = pl.Trainer(
devices=2,
strategy="ddp",
precision="16-mixed"
)
```
---
## 4. 部署指南
### 4.1 环境要求
- Python 3.8+
- MySQL 5.7+
- NVIDIA驱动CUDA 11.6
### 4.2 离线安装步骤
```bash:doc/ml_platform_spec.md
# 1. 安装依赖
pip install --no-index --find-links=./offline_packages -r requirements.txt
# 2. 数据库初始化
mysql -u root -p < scripts/init_db.sql
# 3. 启动服务
docker-compose up -d
```
---
## 5. API接口文档
### 5.1 核心接口概览
| 类别 | 接口路径 | 方法 | 功能说明 |
| -------- | ---------------------- | ------ | ------------------ |
| 数据管理 | /api/data/upload | POST | 上传原始数据集 |
| | /api/data/versions | GET | 查询数据版本历史 |
| 模型训练 | /api/train | POST | 提交训练任务 |
| | /api/train/status/{id} | GET | 查询训练状态 |
| 模型管理 | /api/models | GET | 获取已注册模型列表 |
| | /api/models/{id} | DELETE | 删除指定模型 |
| 预测服务 | /api/predict/single | POST | 单条实时预测 |
| | /api/predict/batch | POST | 批量异步预测 |
| 系统管理 | /api/system/health | GET | 系统健康检查 |
### 5.2 训练接口详情
```http
POST /api/train
Content-Type: application/json
{
"algorithm": "xgboost",
"dataset": "/data/project1.csv",
"hpo": true,
"hpo_config": {
"max_trials": 50,
"timeout": 3600
}
}
Response:
{
"task_id": "train_01H9Z3X6V9",
"status_url": "/api/train/status/train_01H9Z3X6V9"
}
```
### 5.3 模型查询接口
```http
GET /api/models?type=classification
Response:
{
"models": [
{
"id": "xgb_v1",
"type": "classification",
"metrics": {
"accuracy": 0.92,
"f1": 0.89
},
"registered_at": "2023-08-20T14:30:00"
}
]
}
```
### 5.4 预测接口安全控制
```http
POST /api/predict/single
Content-Type: application/json
X-API-Key: your_api_key
{
"model_id": "xgb_v1",
"features": [0.5, 1.2, 3.4]
}
Response:
{
"prediction": 1,
"confidence": 0.87
}
```
### 5.5 新增实验对比接口
```http
POST /api/experiments/compare
Content-Type: application/json
{
"experiment_ids": ["exp001", "exp002"],
"metrics": ["accuracy", "f1"]
}
Response:
{
"comparison": {
"exp001": {"accuracy": 0.92, "f1": 0.89},
"exp002": {"accuracy": 0.89, "f1": 0.85}
}
}
```
---
## 6. 注意事项
1. 数据安全
- 训练数据最大保留30天
- 模型文件加密存储
2. 性能优化
- 单次训练数据量建议不超过1GB
- 批量预测支持CSV文件最大100MB
3. 扩展建议
- 未来可扩展AutoML模块
- 支持ONNX模型格式
### 6.4 硬件优化建议
1. **显存管理策略**
- 自动批处理调整根据数据维度动态计算
- 梯度累积技术累计步数4
2. **训练稳定性措施**
- 梯度裁剪阈值1.0
- 学习率自动衰减衰减率0.1
- 最大连续失败次数3次自动终止异常任务
3. **容错机制**
- 训练快照每小时保存checkpoint
- 异常恢复从最近checkpoint重启
### 6.5 数据管理规范
1. **输入格式要求**
- CSV文件必须包含header行
- 数值型字段空值用NaN表示
- 分类字段需预先编码
2. **数据生命周期**
```mermaid
graph LR
A[原始数据] --> B[预处理数据]
B --> C[训练数据]
C --> D[30天后删除]
```
3. **版本控制规则**
- 数据划分随机种子固定
- 每次预处理生成唯一数据指纹
- 测试集数据指纹与模型版本绑定
### 6.6 流程完整性保障
1. **数据质量红线**
- 缺失值超过30%的特征自动剔除
- 零方差特征自动过滤
- 类别不平衡数据触发警告
2. **实验可复现性**
- 记录随机种子状态
- 保存完整依赖版本
- 存储数据预处理参数
3. **模型监控机制**
- 预测结果分布监控
- 特征漂移检测PSI指标
- 模型性能衰减报警
### 6.7 依赖管理规范
1. **版本锁定**:使用`pip-tools`固定依赖版本
2. **许可证审查**确保所有库符合BSD/MIT/Apache许可证
3. **离线打包**:通过`pip download`生成离线安装包
4. **依赖隔离**使用虚拟环境venv/conda
### 6.8 深度学习训练规范
1. **训练模板**使用PyTorch Lightning标准模板
2. **设备管理**自动选择GPU/CPU
3. **混合精度**:默认启用`fp16`模式
4. **分布式训练**:支持单机多卡(需额外配置)
### 6.9 配置管理规范
```yaml
# configs/train_config.yaml
defaults:
- base
- override /algorithm: xgboost
algorithm:
name: xgboost
params:
max_depth: 6
learning_rate: 0.1
data:
path: /data/project1.csv
split_ratio: 0.8
logging:
experiment_name: baseline_v1
```
```python
# 配置加载实现
import yaml
from dataclasses import dataclass
@dataclass
class TrainConfig:
algorithm: dict
data: dict
logging: dict
def load_config(path):
with open(path) as f:
raw = yaml.safe_load(f)
return TrainConfig(**raw)
```
**配置管理原则**
1. 配置与代码分离
2. 支持环境差异dev/test/prod
3. Schema验证使用Pydantic
4. 版本控制(配置文件随代码库管理)
---
## 附录
### A. 初始化脚本示例
```sql
-- scripts/init_db.sql
CREATE DATABASE IF NOT EXISTS ml_platform;
USE ml_platform;
CREATE TABLE training_jobs (
id VARCHAR(36) PRIMARY KEY,
algorithm VARCHAR(50) NOT NULL,
status ENUM('pending', 'running', 'completed') DEFAULT 'pending'
);
```
### B. 目录结构
```
.
├── docker-compose.yml
├── app/
│ ├── algorithms/ # 算法实现
│ ├── services/ # 核心服务
│ └── api/ # 接口定义
└── data/ # 训练数据集
```
完整文档已生成到项目doc目录包含详细配置示例和接口说明。需要补充其他内容请随时告知。
### C. 资源监控指标
| 指标名称 | 采集方式 | 告警阈值 |
| ---------------- | -------------- | -------------- |
| GPU显存使用率 | NVIDIA-SMI API | >85% 持续5分钟 |
| 训练内存占用 | psutil 库监控 | >70% 系统内存 |
| 训练进程存活状态 | 心跳检测 | 连续3次丢失 |
### D. 稳定性测试用例
```python
def test_oom_recovery():
# 模拟显存溢出场景
try:
train_model(oversize_data)
except RuntimeError as e:
assert "CUDA out of memory" in str(e)
assert check_rollback_success()
### E. Lightning 训练模板
```python
import pytorch_lightning as pl
from torch import nn
class LitClassifier(pl.LightningModule):
def __init__(self, input_dim=64, hidden_dim=128):
super().__init__()
self.save_hyperparameters()
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 10)
)
def forward(self, x):
return self.layers(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
```
### F. 配置版本控制
```python
# 配置变更检测
class ConfigValidator:
SCHEMA = {
"algorithm": {"type": dict, "required": True},
"data": {"type": dict, "required": True}
}
def validate(self, config):
for key, rules in self.SCHEMA.items():
if rules["required"] and key not in config:
raise ValueError(f"Missing required config: {key}")
```
### G. 多环境配置示例
```
configs/
├── base.yaml
├── dev.yaml
├── test.yaml
└── prod.yaml
```
</rewritten_file>