增加深度学习训练规范和配置管理规范

This commit is contained in:
Tian jianyong 2025-02-08 16:25:56 +08:00
parent d86a31c60e
commit 79f0a418af

View File

@ -28,22 +28,23 @@
| 超参优化 | Optuna |
### 1.2 技术栈扩展
| 功能领域 | 开源库 | 版本 | 用途说明 |
| ------------ | ------------ | ----- | ---------------- |
| **基础框架** | Python | 3.8+ | 主开发语言 |
| | FastAPI | 0.88+ | REST API开发 |
| **数据处理** | Pandas | 1.5+ | 数据操作与分析 |
| | NumPy | 1.23+ | 数值计算基础 |
| **传统ML** | Scikit-learn | 1.2+ | 经典机器学习算法 |
| | XGBoost | 1.7+ | 梯度提升树模型 |
| | LightGBM | 3.3+ | 高效梯度提升框架 |
| **深度学习** | PyTorch | 2.0+ | 神经网络框架 |
| | Transformers | 4.28+ | 预训练模型库 |
| **时序分析** | Prophet | 1.1+ | 时间序列预测 |
| **优化调参** | Optuna | 3.1+ | 超参数优化框架 |
| **可视化** | Matplotlib | 3.7+ | 基础可视化 |
| | Plotly | 5.14+ | 交互式可视化 |
| **模型解释** | SHAP | 0.42+ | 模型可解释性分析 |
| 功能领域 | 开源库 | 版本 | 用途说明 |
| ------------ | ----------------- | ----- | ---------------- |
| **基础框架** | Python | 3.8+ | 主开发语言 |
| | FastAPI | 0.88+ | REST API开发 |
| **数据处理** | Pandas | 1.5+ | 数据操作与分析 |
| | NumPy | 1.23+ | 数值计算基础 |
| **传统ML** | Scikit-learn | 1.2+ | 经典机器学习算法 |
| | XGBoost | 1.7+ | 梯度提升树模型 |
| | LightGBM | 3.3+ | 高效梯度提升框架 |
| **深度学习** | PyTorch | 2.0+ | 神经网络框架 |
| | PyTorch Lightning | 2.0+ | 深度学习训练框架 |
| | Transformers | 4.28+ | 预训练模型库 |
| **时序分析** | Prophet | 1.1+ | 时间序列预测 |
| **优化调参** | Optuna | 3.1+ | 超参数优化框架 |
| **可视化** | Matplotlib | 3.7+ | 基础可视化 |
| | Plotly | 5.14+ | 交互式可视化 |
| **模型解释** | SHAP | 0.42+ | 模型可解释性分析 |
---
@ -227,6 +228,52 @@ class ModelMonitor:
return psi > 0.25 # 触发重新训练阈值
```
### 3.11 深度学习训练优化
```python
class LitModel(pl.LightningModule):
def __init__(self, hidden_dim=128):
super().__init__()
self.model = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 10)
)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
```
### 3.12 Lightning 集成方案
```python
# 训练服务集成示例
class TrainingService:
def start_training(self, config):
model = LitModel(config.hidden_dim)
trainer = pl.Trainer(
max_epochs=config.epochs,
accelerator="auto",
callbacks=[
ModelCheckpoint(dirpath="./checkpoints"),
MLFlowLogger(experiment_name=config.exp_name)
]
)
trainer.fit(model, data_module)
```
### 3.13 分布式训练支持
```python
# 多GPU训练配置
trainer = pl.Trainer(
devices=2,
strategy="ddp",
precision="16-mixed"
)
```
---
## 4. 部署指南
@ -414,6 +461,56 @@ Response:
3. **离线打包**:通过`pip download`生成离线安装包
4. **依赖隔离**使用虚拟环境venv/conda
### 6.8 深度学习训练规范
1. **训练模板**使用PyTorch Lightning标准模板
2. **设备管理**自动选择GPU/CPU
3. **混合精度**:默认启用`fp16`模式
4. **分布式训练**:支持单机多卡(需额外配置)
### 6.9 配置管理规范
```yaml
# configs/train_config.yaml
defaults:
- base
- override /algorithm: xgboost
algorithm:
name: xgboost
params:
max_depth: 6
learning_rate: 0.1
data:
path: /data/project1.csv
split_ratio: 0.8
logging:
experiment_name: baseline_v1
```
```python
# 配置加载实现
import yaml
from dataclasses import dataclass
@dataclass
class TrainConfig:
algorithm: dict
data: dict
logging: dict
def load_config(path):
with open(path) as f:
raw = yaml.safe_load(f)
return TrainConfig(**raw)
```
**配置管理原则**
1. 配置与代码分离
2. 支持环境差异dev/test/prod
3. Schema验证使用Pydantic
4. 版本控制(配置文件随代码库管理)
---
## 附录
@ -459,3 +556,58 @@ def test_oom_recovery():
except RuntimeError as e:
assert "CUDA out of memory" in str(e)
assert check_rollback_success()
### E. Lightning 训练模板
```python
import pytorch_lightning as pl
from torch import nn
class LitClassifier(pl.LightningModule):
def __init__(self, input_dim=64, hidden_dim=128):
super().__init__()
self.save_hyperparameters()
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 10)
)
def forward(self, x):
return self.layers(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
```
### F. 配置版本控制
```python
# 配置变更检测
class ConfigValidator:
SCHEMA = {
"algorithm": {"type": dict, "required": True},
"data": {"type": dict, "required": True}
}
def validate(self, config):
for key, rules in self.SCHEMA.items():
if rules["required"] and key not in config:
raise ValueError(f"Missing required config: {key}")
```
### G. 多环境配置示例
```
configs/
├── base.yaml
├── dev.yaml
├── test.yaml
└── prod.yaml
```
</rewritten_file>