增加深度学习训练规范和配置管理规范
This commit is contained in:
parent
d86a31c60e
commit
79f0a418af
@ -28,22 +28,23 @@
|
||||
| 超参优化 | Optuna |
|
||||
|
||||
### 1.2 技术栈扩展

| 功能领域 | 开源库 | 版本 | 用途说明 |
| ------------ | ----------------- | ----- | ---------------- |
| **基础框架** | Python | 3.8+ | 主开发语言 |
| | FastAPI | 0.88+ | REST API开发 |
| **数据处理** | Pandas | 1.5+ | 数据操作与分析 |
| | NumPy | 1.23+ | 数值计算基础 |
| **传统ML** | Scikit-learn | 1.2+ | 经典机器学习算法 |
| | XGBoost | 1.7+ | 梯度提升树模型 |
| | LightGBM | 3.3+ | 高效梯度提升框架 |
| **深度学习** | PyTorch | 2.0+ | 神经网络框架 |
| | PyTorch Lightning | 2.0+ | 深度学习训练框架 |
| | Transformers | 4.28+ | 预训练模型库 |
| **时序分析** | Prophet | 1.1+ | 时间序列预测 |
| **优化调参** | Optuna | 3.1+ | 超参数优化框架 |
| **可视化** | Matplotlib | 3.7+ | 基础可视化 |
| | Plotly | 5.14+ | 交互式可视化 |
| **模型解释** | SHAP | 0.42+ | 模型可解释性分析 |

||||
---
|
||||
|
||||
@ -227,6 +228,52 @@ class ModelMonitor:
|
||||
return psi > 0.25 # 触发重新训练阈值
|
||||
```
|
||||
|
||||
### 3.11 深度学习训练优化
|
||||
```python
|
||||
class LitModel(pl.LightningModule):
    """Minimal LightningModule: a two-layer MLP classifier (64 -> hidden_dim -> 10)."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        # Build the MLP head from a list so the layer stack is easy to extend.
        mlp_layers = [
            nn.Linear(64, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10),
        ]
        self.model = nn.Sequential(*mlp_layers)

    def training_step(self, batch, batch_idx):
        """One optimisation step: forward pass, cross-entropy loss, logging."""
        inputs, targets = batch
        logits = self.model(inputs)
        loss = F.cross_entropy(logits, targets)
        # Logged under 'train_loss' so attached loggers / progress bar pick it up.
        self.log('train_loss', loss)
        return loss
|
||||
```
|
||||
|
||||
### 3.12 Lightning 集成方案
|
||||
```python
|
||||
# 训练服务集成示例
|
||||
class TrainingService:
    """Wires a LitModel into a PyTorch Lightning Trainer run."""

    def start_training(self, config):
        """Launch one training run described by `config`.

        Bug fix: `MLFlowLogger` is a logger, not a callback. Lightning
        requires callbacks to be `Callback` instances and takes loggers via
        the `logger=` argument; passing the logger inside `callbacks=[...]`
        fails at Trainer setup.
        """
        model = LitModel(config.hidden_dim)
        trainer = pl.Trainer(
            max_epochs=config.epochs,
            accelerator="auto",  # auto-selects GPU/CPU (see section 6.8)
            logger=MLFlowLogger(experiment_name=config.exp_name),
            callbacks=[ModelCheckpoint(dirpath="./checkpoints")],
        )
        # NOTE(review): `data_module` must come from surrounding context —
        # presumably a LightningDataModule; confirm against the full service.
        trainer.fit(model, data_module)
|
||||
```
|
||||
|
||||
### 3.13 分布式训练支持
|
||||
```python
|
||||
# Multi-GPU training configuration: two devices under DistributedDataParallel,
# with mixed 16-bit precision (the fp16 default from section 6.8).
trainer = pl.Trainer(
    strategy="ddp",
    devices=2,
    precision="16-mixed",
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 部署指南
|
||||
@ -414,6 +461,56 @@ Response:
|
||||
3. **离线打包**:通过`pip download`生成离线安装包
|
||||
4. **依赖隔离**:使用虚拟环境(venv/conda)
|
||||
|
||||
### 6.8 深度学习训练规范

1. **训练模板**:使用PyTorch Lightning标准模板
2. **设备管理**:自动选择GPU/CPU
3. **混合精度**:默认启用`fp16`模式
4. **分布式训练**:支持单机多卡(需额外配置)
|
||||
|
||||
### 6.9 配置管理规范

```yaml
# configs/train_config.yaml
defaults:
  - base
  - override /algorithm: xgboost

algorithm:
  name: xgboost
  params:
    max_depth: 6
    learning_rate: 0.1

data:
  path: /data/project1.csv
  split_ratio: 0.8

logging:
  experiment_name: baseline_v1
```
|
||||
|
||||
```python
|
||||
# Configuration loading implementation.
import yaml
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    """Typed view over the top-level sections of train_config.yaml."""
    algorithm: dict
    data: dict
    logging: dict


def load_config(path):
    """Parse a YAML config file into a TrainConfig.

    Bug fix: the example train_config.yaml carries a top-level `defaults:`
    key (Hydra-style), so the previous `TrainConfig(**raw)` crashed with
    "unexpected keyword argument 'defaults'". Only the sections declared
    on TrainConfig are picked out of the parsed mapping.

    Raises:
        KeyError: if a declared section is missing from the file.
    """
    with open(path) as f:
        raw = yaml.safe_load(f)
    # Select exactly the dataclass fields; extra top-level keys are ignored.
    return TrainConfig(**{fld.name: raw[fld.name] for fld in fields(TrainConfig)})
|
||||
```
|
||||
|
||||
**配置管理原则**:

1. 配置与代码分离
2. 支持环境差异(dev/test/prod)
3. Schema验证(使用Pydantic)
4. 版本控制(配置文件随代码库管理)
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
@ -459,3 +556,58 @@ def test_oom_recovery():
|
||||
except RuntimeError as e:
|
||||
assert "CUDA out of memory" in str(e)
|
||||
assert check_rollback_success()
|
||||
|
||||
### E. Lightning 训练模板
|
||||
```python
|
||||
import pytorch_lightning as pl
import torch
from torch import nn


class LitClassifier(pl.LightningModule):
    """Standard Lightning training template: MLP classifier trained with Adam.

    Bug fix: the template called `torch.optim.Adam` without importing
    `torch` (only `pytorch_lightning` and `torch.nn` were imported), so the
    snippet raised NameError in `configure_optimizers`.
    """

    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        # Records input_dim/hidden_dim into checkpoints for clean reloads.
        self.save_hyperparameters()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10),
        )

    def forward(self, x):
        """Return raw class logits for input batch `x`."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """One optimisation step: forward pass + cross-entropy loss."""
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters; lr matches the doc's baseline."""
        return torch.optim.Adam(self.parameters(), lr=0.001)
|
||||
```
|
||||
|
||||
### F. 配置版本控制
|
||||
```python
|
||||
# 配置变更检测
|
||||
class ConfigValidator:
    """Validates a loaded config mapping against a minimal declarative schema."""

    # Each entry declares required presence and the expected value type.
    SCHEMA = {
        "algorithm": {"type": dict, "required": True},
        "data": {"type": dict, "required": True}
    }

    def validate(self, config):
        """Check `config` for required keys and value types.

        Bug fix: SCHEMA declared a `type` rule that was never enforced;
        a present-but-wrong-typed section slipped through silently.

        Raises:
            ValueError: if a required key is absent, or a present value
                does not match the declared type.
        """
        for key, rules in self.SCHEMA.items():
            if key not in config:
                if rules["required"]:
                    raise ValueError(f"Missing required config: {key}")
                continue
            if not isinstance(config[key], rules["type"]):
                raise ValueError(
                    f"Config key '{key}' must be {rules['type'].__name__}, "
                    f"got {type(config[key]).__name__}"
                )
|
||||
```
|
||||
|
||||
### G. 多环境配置示例

```
configs/
├── base.yaml
├── dev.yaml
├── test.yaml
└── prod.yaml
```
|
||||
|
||||
</rewritten_file>
|
||||
Loading…
Reference in New Issue
Block a user