# Intranet Machine Learning Platform: Technical Design

## Contents

1. Project Overview
2. System Architecture
3. Core Features
4. Deployment Guide
5. API Reference
6. Operational Notes

---

## 1. Project Overview

### 1.1 Goals

Provide an intranet (air-gapped) machine learning platform for research institutions, supporting:

- Algorithm selection and model training
- Automated hyperparameter optimization
- Model evaluation and deployment
- Batch prediction services

### 1.2 Technology Stack

| Component | Selection |
| -------- | -------------- |
| Web framework | FastAPI |
| Task queue | Celery + Redis |
| Data storage | MySQL |
| Model management | MLflow |
| Hyperparameter optimization | Optuna |

### 1.3 Extended Technology Stack

| Area | Library | Version | Purpose |
| ------------ | ----------------- | ----- | ---------------- |
| **Core** | Python | 3.8+ | Primary development language |
| | FastAPI | 0.88+ | REST API development |
| **Data processing** | Pandas | 1.5+ | Data manipulation and analysis |
| | NumPy | 1.23+ | Numerical computing |
| **Classical ML** | Scikit-learn | 1.2+ | Classical machine learning algorithms |
| | XGBoost | 1.7+ | Gradient-boosted tree models |
| | LightGBM | 3.3+ | Efficient gradient boosting framework |
| **Deep learning** | PyTorch | 2.0+ | Neural network framework |
| | PyTorch Lightning | 2.0+ | Deep learning training framework |
| | Transformers | 4.28+ | Pretrained model library |
| **Time series** | Prophet | 1.1+ | Time-series forecasting |
| **Hyperparameter tuning** | Optuna | 3.1+ | Hyperparameter optimization framework |
| **Visualization** | Matplotlib | 3.7+ | Basic plotting |
| | Plotly | 5.14+ | Interactive visualization |
| **Model interpretation** | SHAP | 0.42+ | Model explainability analysis |
---

## 2. System Architecture

```mermaid
graph TD
    A[Web UI/API] --> B[FastAPI]
    B --> C[Algorithm repository]
    B --> D[MySQL]
    B --> E[Celery worker]
    E --> F[GPU training]
    C --> G[MLflow model registry]
    G --> H[Batch prediction service]
```

---
## 3. Core Features

### 3.1 Algorithm Management

```python
# Built-in algorithm registry (excerpt)
ALGORITHMS = {
    "xgboost": {
        "type": "classification",
        "params": {
            "max_depth": {"type": "int", "range": [3, 10]},
            "learning_rate": {"type": "float", "range": [0.001, 0.1]}
        }
    },
    "random_forest": {...}
}
```
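The parameter specification above maps directly onto a hyperparameter sampler (in the platform, Optuna's `suggest_int`/`suggest_float`). A dependency-free sketch of that mapping, using a standalone copy of the registry excerpt:

```python
import random

# Standalone copy of the registry excerpt shown above
ALGORITHMS = {
    "xgboost": {
        "type": "classification",
        "params": {
            "max_depth": {"type": "int", "range": [3, 10]},
            "learning_rate": {"type": "float", "range": [0.001, 0.1]},
        }
    }
}

def sample_params(algorithm: str, rng: random.Random) -> dict:
    """Draw one configuration from an algorithm's declared search space."""
    sampled = {}
    for name, spec in ALGORITHMS[algorithm]["params"].items():
        lo, hi = spec["range"]
        if spec["type"] == "int":
            sampled[name] = rng.randint(lo, hi)   # inclusive integer bounds
        else:
            sampled[name] = rng.uniform(lo, hi)   # uniform float in range
    return sampled
```

With Optuna, the same loop would call `trial.suggest_int` / `trial.suggest_float` instead of the `random` functions.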

### 3.2 Training Workflow

1. Receive the training request
2. Preprocess the data
3. Run automated hyperparameter optimization (optional)
4. Train and evaluate the model
5. Register the model in MLflow
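In the platform these steps run inside a Celery task; a minimal dependency-free sketch of the orchestration, with each stage as a plain callable (the function names are illustrative, not the platform's actual API):

```python
# Illustrative orchestration of the five workflow steps. Each stage is
# injected as a callable so the sketch stays independent of Celery/MLflow.
def run_training_job(request, preprocess, optimize, train, evaluate, register):
    data = preprocess(request["dataset"])                  # step 2
    params = optimize(data) if request.get("hpo") else {}  # step 3 (optional)
    model = train(data, params)                            # step 4
    metrics = evaluate(model, data)                        # step 4
    return register(model, metrics)                        # step 5
```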

### 3.3 Data Preprocessing

```python
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest
from sklearn.model_selection import train_test_split

class DataPreprocessor:
    SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5']

    def __init__(self, config):
        self.pipeline = [
            ('impute_missing', SimpleImputer()),
            ('standardize', StandardScaler()),
            ('select_features', SelectKBest(k=20))
        ]

    def split_data(self, test_size=0.2, random_state=42):
        # Record the split version automatically; pass random_state
        # through so the split is reproducible
        train, test = train_test_split(
            self.data, test_size=test_size, random_state=random_state
        )
        return train, test, self._generate_data_hash()
```

### 3.4 Data Version Control

```mermaid
graph LR
    A[Raw data] --> B[Preprocessing]
    B --> C[Training set v1]
    B --> D[Test set v1]
    C --> E[Model training]
    D --> F[Model evaluation]
    E --> G[Metadata record]
    F --> G
```

### 3.5 Hardware-Aware Training

```python
class HardwareAwareTraining:
    def __init__(self):
        self.gpu_mem = self.get_gpu_memory()  # total GPU memory in MiB

    def auto_batch_size(self, model_type):
        # Scale the batch size with GPU memory relative to a 16 GiB
        # baseline; max(1, ...) keeps smaller GPUs from scaling to zero
        base_sizes = {
            'DNN': 1024,
            'LSTM': 256,
            'Transformer': 64
        }
        return base_sizes[model_type] * max(1, self.gpu_mem // 16384)
```

### 3.6 GPU Memory Monitoring

```python
# TOTAL_GPU_MEM and SAFE_THRESHOLD are module-level constants (MiB)
class MemoryMonitor(Callback):
    def on_epoch_end(self, epoch, logs=None):
        used_mem = torch.cuda.memory_allocated() // 1024**2
        print(f"GPU memory in use: {used_mem}MB / {TOTAL_GPU_MEM}MB")
        if used_mem > SAFE_THRESHOLD:
            self.model.stop_training = True
```

### 3.7 Training Stability Safeguards

```python
# Gradient clipping: PyTorch has no ClipGradNorm optimizer wrapper;
# clip in the training loop between backward() and step()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Early stopping
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)
```
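The `EarlyStopping` callback above reduces to simple patience bookkeeping; a dependency-free sketch of that logic (not the callback's actual implementation):

```python
class PatienceEarlyStopper:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, mode: str = "min"):
        self.patience = patience
        self.sign = 1 if mode == "min" else -1  # normalize to minimization
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, value: float) -> bool:
        if self.sign * value < self.best:
            self.best = self.sign * value  # improvement: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```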

### 3.8 Data Sanitization

```python
import re

class DataSanitizer:
    def __init__(self):
        self.sensitive_patterns = [
            r'\d{18}',  # national ID numbers
            r'\d{11}'   # mobile phone numbers
        ]

    def clean(self, text):
        # Mask sensitive substrings; str.replace does not accept regex
        # patterns, so apply each pattern with re.sub
        for pattern in self.sensitive_patterns:
            text = re.sub(pattern, '***', text)
        return text
```

### 3.9 End-to-End Workflow

```mermaid
graph TD
    A[Data ingestion] --> B[Data quality checks]
    B --> C[Preprocessing]
    C --> D[Feature engineering]
    D --> E[Algorithm selection]
    E --> F[Model training]
    F --> G[Model evaluation]
    G --> H[Model registration]
    H --> I[Model deployment]
    I --> J[Prediction service]
    J --> K[Performance monitoring]
    K -->|feedback| A
```

### 3.10 Key Supporting Modules

**Data quality checks**:
```python
class DataValidator:
    CHECKS = [
        # .all() reduces the per-column results to a single boolean
        ('missing ratio < 30%', lambda df: (df.isna().mean() < 0.3).all()),
        ('feature variance > 0.01', lambda df: (df.var() > 0.01).all()),
        ('class balance > 0.1', lambda df: check_class_balance(df))
    ]

    def validate(self, data):
        return {name: check(data) for name, check in self.CHECKS}
```

**Experiment tracking**:
```python
class ExperimentTracker:
    def __init__(self):
        self.runs = {}

    def log_run(self, params, metrics, data_version):
        run_id = generate_uuid()
        self.runs[run_id] = {
            'timestamp': datetime.now(),
            'data_version': data_version,
            'params': params,
            'metrics': metrics
        }
```

**Model monitoring**:
```python
class ModelMonitor:
    def detect_drift(self, current_data, reference_data):
        # Measure the distribution shift between serving and training data
        psi = calculate_psi(current_data, reference_data)
        return psi > 0.25  # threshold that triggers retraining
```
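`calculate_psi` is referenced above but not defined in this document; a minimal Population Stability Index sketch over pre-binned counts (the smoothing constant, and binning done upstream, are illustrative assumptions):

```python
import math

def calculate_psi(current_counts, reference_counts, eps=1e-6):
    """PSI over matching histogram bins: sum((p - q) * ln(p / q))."""
    cur_total = sum(current_counts)
    ref_total = sum(reference_counts)
    psi = 0.0
    for c, r in zip(current_counts, reference_counts):
        p = max(c / cur_total, eps)  # smooth empty bins to avoid log(0)
        q = max(r / ref_total, eps)
        psi += (p - q) * math.log(p / q)
    return psi
```

Identical distributions give PSI 0; each term is non-negative, so larger shifts only increase the score.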

### 3.11 Deep Learning Training Optimization

```python
class LitModel(pl.LightningModule):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(64, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10)
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Lightning requires this hook; see Appendix E for the full template
        return torch.optim.Adam(self.parameters(), lr=0.001)
```

### 3.12 Lightning Integration

```python
# Training service integration example
class TrainingService:
    def start_training(self, config):
        model = LitModel(config.hidden_dim)
        trainer = pl.Trainer(
            max_epochs=config.epochs,
            accelerator="auto",
            # MLFlowLogger is a logger, not a callback
            logger=MLFlowLogger(experiment_name=config.exp_name),
            callbacks=[ModelCheckpoint(dirpath="./checkpoints")]
        )
        trainer.fit(model, data_module)
```

### 3.13 Distributed Training

```python
# Multi-GPU training configuration
trainer = pl.Trainer(
    devices=2,
    strategy="ddp",
    precision="16-mixed"
)
```

---
## 4. Deployment Guide

### 4.1 Requirements

- Python 3.8+
- MySQL 5.7+
- NVIDIA driver with CUDA 11.7+ (PyTorch 2.0 ships CUDA 11.7/11.8 builds)

### 4.2 Offline Installation

```bash
# 1. Install dependencies from the local package mirror
pip install --no-index --find-links=./offline_packages -r requirements.txt

# 2. Initialize the database
mysql -u root -p < scripts/init_db.sql

# 3. Start the services
docker-compose up -d
```

---
## 5. API Reference

### 5.1 Endpoint Overview

| Category | Path | Method | Description |
| -------- | ---------------------- | ------ | ------------------ |
| Data management | /api/data/upload | POST | Upload a raw dataset |
| | /api/data/versions | GET | List dataset version history |
| Model training | /api/train | POST | Submit a training job |
| | /api/train/status/{id} | GET | Query training status |
| Model management | /api/models | GET | List registered models |
| | /api/models/{id} | DELETE | Delete a model |
| Prediction | /api/predict/single | POST | Single real-time prediction |
| | /api/predict/batch | POST | Asynchronous batch prediction |
| System | /api/system/health | GET | Health check |
### 5.2 Training Endpoint

```http
POST /api/train
Content-Type: application/json

{
  "algorithm": "xgboost",
  "dataset": "/data/project1.csv",
  "hpo": true,
  "hpo_config": {
    "max_trials": 50,
    "timeout": 3600
  }
}

Response:
{
  "task_id": "train_01H9Z3X6V9",
  "status_url": "/api/train/status/train_01H9Z3X6V9"
}
```

### 5.3 Model Query Endpoint

```http
GET /api/models?type=classification

Response:
{
  "models": [
    {
      "id": "xgb_v1",
      "type": "classification",
      "metrics": {
        "accuracy": 0.92,
        "f1": 0.89
      },
      "registered_at": "2023-08-20T14:30:00"
    }
  ]
}
```

### 5.4 Prediction Endpoint Access Control

```http
POST /api/predict/single
Content-Type: application/json
X-API-Key: your_api_key

{
  "model_id": "xgb_v1",
  "features": [0.5, 1.2, 3.4]
}

Response:
{
  "prediction": 1,
  "confidence": 0.87
}
```
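On the server side, the `X-API-Key` header should be compared in constant time to avoid timing side channels; a minimal framework-agnostic sketch (how the expected key is stored and looked up is outside this snippet):

```python
import hmac

def is_authorized(provided_key: str, expected_key: str) -> bool:
    # hmac.compare_digest runs in time independent of where the
    # strings first differ, unlike ==
    return hmac.compare_digest(provided_key.encode(), expected_key.encode())
```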

### 5.5 Experiment Comparison Endpoint

```http
POST /api/experiments/compare
Content-Type: application/json

{
  "experiment_ids": ["exp001", "exp002"],
  "metrics": ["accuracy", "f1"]
}

Response:
{
  "comparison": {
    "exp001": {"accuracy": 0.92, "f1": 0.89},
    "exp002": {"accuracy": 0.89, "f1": 0.85}
  }
}
```

---

## 6. Operational Notes

1. Data security
   - Training data is retained for at most 30 days
   - Model files are stored encrypted

2. Performance
   - Keep a single training run's data under 1 GB
   - Batch prediction accepts CSV files up to 100 MB

3. Future extensions
   - An AutoML module can be added later
   - ONNX model format support

### 6.4 Hardware Optimization

1. **GPU memory management**
   - Automatic batch-size adjustment (computed dynamically from data dimensions)
   - Gradient accumulation (at most 4 steps)

2. **Training stability**
   - Gradient clipping threshold: 1.0
   - Automatic learning-rate decay (factor 0.1)
   - At most 3 consecutive failures before a job is terminated automatically

3. **Fault tolerance**
   - Hourly training checkpoints
   - Automatic recovery from the most recent checkpoint

### 6.5 Data Management Rules

1. **Input format**
   - CSV files must include a header row
   - Missing numeric values are represented as NaN
   - Categorical fields must be pre-encoded

2. **Data lifecycle**
```mermaid
graph LR
    A[Raw data] --> B[Preprocessed data]
    B --> C[Training data]
    C --> D[Deleted after 30 days]
```

3. **Version control**
   - The data-split random seed is fixed
   - Each preprocessing run produces a unique data fingerprint
   - The test-set fingerprint is bound to the model version
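The data fingerprint mentioned above can simply be a content hash of the preprocessed file; a minimal sketch using SHA-256 (truncating to 12 hex characters is an illustrative choice, not the platform's fixed format):

```python
import hashlib

def data_fingerprint(payload: bytes) -> str:
    """Deterministic fingerprint of a dataset's raw bytes."""
    return hashlib.sha256(payload).hexdigest()[:12]
```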

### 6.6 Workflow Integrity

1. **Data quality gates**
   - Features with more than 30% missing values are dropped automatically
   - Zero-variance features are filtered out
   - Class imbalance triggers a warning

2. **Experiment reproducibility**
   - Random seed state is recorded
   - Full dependency versions are saved
   - Preprocessing parameters are stored

3. **Model monitoring**
   - Prediction distribution monitoring
   - Feature drift detection (PSI metric)
   - Alerts on model performance degradation
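For the reproducibility rules above, the usual approach is to seed every randomness source at job start; a sketch covering the standard library (the commented lines show where NumPy and PyTorch would be seeded in the platform environment):

```python
import os
import random

def fix_seeds(seed: int = 42) -> None:
    """Seed every randomness source used by a training job."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    # np.random.seed(seed)              # NumPy, if installed
    # torch.manual_seed(seed)           # PyTorch CPU (and current GPU)
    # torch.cuda.manual_seed_all(seed)  # all CUDA devices
```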

### 6.7 Dependency Management

1. **Version pinning**: pin dependency versions with `pip-tools`
2. **License review**: ensure every library carries a BSD/MIT/Apache license
3. **Offline packaging**: build offline install bundles with `pip download`
4. **Isolation**: use virtual environments (venv/conda)

### 6.8 Deep Learning Training Rules

1. **Training template**: use the standard PyTorch Lightning template
2. **Device management**: select GPU/CPU automatically
3. **Mixed precision**: `fp16` mode enabled by default
4. **Distributed training**: single-node multi-GPU supported (extra configuration required)

### 6.9 Configuration Management

```yaml
# configs/train_config.yaml
defaults:
  - base
  - override /algorithm: xgboost

algorithm:
  name: xgboost
  params:
    max_depth: 6
    learning_rate: 0.1

data:
  path: /data/project1.csv
  split_ratio: 0.8

logging:
  experiment_name: baseline_v1
```

```python
# Configuration loading
import yaml
from dataclasses import dataclass, fields

@dataclass
class TrainConfig:
    algorithm: dict
    data: dict
    logging: dict

def load_config(path):
    with open(path) as f:
        raw = yaml.safe_load(f)
    # Keep only the dataclass fields; extra keys such as `defaults`
    # would otherwise make TrainConfig(**raw) raise a TypeError
    known = {f.name for f in fields(TrainConfig)}
    return TrainConfig(**{k: v for k, v in raw.items() if k in known})
```

**Configuration principles**:
1. Keep configuration separate from code
2. Support per-environment overrides (dev/test/prod)
3. Validate against a schema (with Pydantic)
4. Version-control configuration files alongside the code

---

## Appendix

### A. Initialization Script

```sql
-- scripts/init_db.sql
CREATE DATABASE IF NOT EXISTS ml_platform;
USE ml_platform;

CREATE TABLE training_jobs (
    id VARCHAR(36) PRIMARY KEY,
    algorithm VARCHAR(50) NOT NULL,
    -- 'failed' covers jobs terminated by the failure limit in §6.4
    status ENUM('pending', 'running', 'completed', 'failed') DEFAULT 'pending'
);
```

### B. Directory Layout

```
.
├── docker-compose.yml
├── app/
│   ├── algorithms/   # algorithm implementations
│   ├── services/     # core services
│   └── api/          # API definitions
└── data/             # training datasets
```
### C. Resource Monitoring Metrics

| Metric | Collection | Alert threshold |
| ---------------- | -------------- | -------------- |
| GPU memory utilization | NVIDIA-SMI API | >85% for 5 minutes |
| Training memory usage | psutil | >70% of system RAM |
| Training process liveness | Heartbeat checks | 3 consecutive misses |
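The ">85% for 5 minutes" rule in the table above needs a little state to implement; a dependency-free sketch over timestamped utilization samples (the sampling source, e.g. polling NVIDIA-SMI, is outside this snippet):

```python
def should_alert(samples, threshold=0.85, duration=300):
    """samples: list of (unix_timestamp, utilization) pairs, oldest first.
    Alert when utilization stays above `threshold` for `duration` seconds."""
    breach_start = None
    for ts, util in samples:
        if util > threshold:
            if breach_start is None:
                breach_start = ts       # a breach begins
            if ts - breach_start >= duration:
                return True             # sustained for the full window
        else:
            breach_start = None         # breach broken; reset
    return False
```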

### D. Stability Test Case

```python
import pytest

def test_oom_recovery():
    # Simulate a GPU out-of-memory scenario; pytest.raises also fails
    # the test if no exception is raised at all
    with pytest.raises(RuntimeError, match="CUDA out of memory"):
        train_model(oversize_data)
    assert check_rollback_success()
```
### E. Lightning Training Template

```python
import torch
import pytorch_lightning as pl
from torch import nn

class LitClassifier(pl.LightningModule):
    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10)
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
```

### F. Configuration Version Control

```python
# Configuration change detection
class ConfigValidator:
    SCHEMA = {
        "algorithm": {"type": dict, "required": True},
        "data": {"type": dict, "required": True}
    }

    def validate(self, config):
        for key, rules in self.SCHEMA.items():
            if rules["required"] and key not in config:
                raise ValueError(f"Missing required config: {key}")
            if key in config and not isinstance(config[key], rules["type"]):
                raise ValueError(f"Config key {key} must be {rules['type'].__name__}")
```

### G. Multi-Environment Configuration Layout

```
configs/
├── base.yaml
├── dev.yaml
├── test.yaml
└── prod.yaml
```