# 内网机器学习平台技术方案 ## 目录 1. 项目概述 2. 系统架构 3. 核心功能 4. 部署指南 5. API接口文档 6. 注意事项 --- ## 1. 项目概述 ### 1.1 目标 为科研单位提供内网机器学习平台,支持: - 算法选择与模型训练 - 超参数自动优化 - 模型评估与部署 - 批量预测服务 ### 1.2 技术栈 | 组件 | 技术选型 | | -------- | -------------- | | Web框架 | FastAPI | | 任务队列 | Celery + Redis | | 数据存储 | MySQL | | 模型管理 | MLflow | | 超参优化 | Optuna | ### 1.2 技术栈扩展 | 功能领域 | 开源库 | 版本 | 用途说明 | | ------------ | ----------------- | ----- | ---------------- | | **基础框架** | Python | 3.8+ | 主开发语言 | | | FastAPI | 0.88+ | REST API开发 | | **数据处理** | Pandas | 1.5+ | 数据操作与分析 | | | NumPy | 1.23+ | 数值计算基础 | | **传统ML** | Scikit-learn | 1.2+ | 经典机器学习算法 | | | XGBoost | 1.7+ | 梯度提升树模型 | | | LightGBM | 3.3+ | 高效梯度提升框架 | | **深度学习** | PyTorch | 2.0+ | 神经网络框架 | | | PyTorch Lightning | 2.0+ | 深度学习训练框架 | | | Transformers | 4.28+ | 预训练模型库 | | **时序分析** | Prophet | 1.1+ | 时间序列预测 | | **优化调参** | Optuna | 3.1+ | 超参数优化框架 | | **可视化** | Matplotlib | 3.7+ | 基础可视化 | | | Plotly | 5.14+ | 交互式可视化 | | **模型解释** | SHAP | 0.42+ | 模型可解释性分析 | --- ## 2. 系统架构 ```mermaid graph TD A[Web UI/API] --> B[FastAPI] B --> C[算法仓库] B --> D[MySQL] B --> E[Celery Worker] E --> F[GPU训练] C --> G[MLflow模型仓库] G --> H[批量预测服务] ``` --- ## 3. 核心功能 ### 3.1 算法管理 ```python 预置算法示例 ALGORITHMS = { "xgboost": { "type": "classification", "params": { "max_depth": {"type": "int", "range": [3,10]}, "learning_rate": {"type": "float", "range": [0.001,0.1]} } }, "random_forest": {...} } ``` ### 3.2 训练流程 1. 接收训练请求 2. 数据预处理 3. 自动超参优化(可选) 4. 模型训练与评估 5. 模型注册到MLflow ### 3.3 数据预处理 ```python class DataPreprocessor: SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5'] def __init__(self, config): self.pipeline = [ ('缺失值处理', SimpleImputer()), ('标准化', StandardScaler()), ('特征选择', SelectKBest(k=20)) ] def split_data(self, test_size=0.2, random_state=42): # 自动记录数据划分版本 train, test = train_test_split(self.data, test_size=test_size) return train, test, self._generate_data_hash() ``` ### 3.4 数据版本控制 ```mermaid graph LR A[原始数据] --> B[预处理] B --> C[训练集 v1] B --> D[测试集 v1] C --> E[模型训练] D --> F[模型评估] E --> G[元数据记录] F --> G ``` ### 3.5 硬件感知训练 ```python class HardwareAwareTraining: def __init__(self): self.gpu_mem = self.get_gpu_memory() def auto_batch_size(self, model_type): # 根据GPU显存自动调整批大小 base_sizes = { 'DNN': 1024, 'LSTM': 256, 'Transformer': 64 } return base_sizes[model_type] * (self.gpu_mem // 16384) # 按16GB基准缩放 ``` ### 3.6 显存监控 ```python class MemoryMonitor(Callback): def on_epoch_end(self, epoch, logs=None): used_mem = torch.cuda.memory_allocated() // 1024**2 print(f"当前显存占用: {used_mem}MB / {TOTAL_GPU_MEM}MB") if used_mem > SAFE_THRESHOLD: self.model.stop_training = True ``` ### 3.7 训练稳定性保障 ```python # 梯度裁剪实现 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) optimizer = torch.optim.ClipGradNorm(optimizer, max_norm=1.0) # 早停机制 early_stop = EarlyStopping( monitor='val_loss', patience=5, mode='min' ) ``` ### 3.8 数据安全处理 ```python class DataSanitizer: def __init__(self): self.sensitive_patterns = [ r'\d{18}', # 身份证号 r'\d{11}' # 手机号 ] def clean(self, data): # 数据脱敏处理 return data.replace(self.sensitive_patterns, '***') ``` ### 3.9 完整流程示意图 ```mermaid graph TD A[数据接入] --> B[数据质量检查] B --> C[数据预处理] C --> D[特征工程] D --> E[算法选择] E --> F[模型训练] F --> G[模型评估] G --> H[模型注册] H --> I[模型部署] I --> J[预测服务] J --> K[效果监控] K -->|反馈| A ``` ### 3.10 新增关键模块说明 **数据质量检查**: ```python class DataValidator: CHECKS = [ ('缺失值比例 < 30%', lambda df: df.isna().mean() < 0.3), ('特征方差 > 0.01', lambda df: df.var() > 0.01), ('类别平衡性 > 0.1', lambda df: check_class_balance(df)) ] def validate(self, data): return {check[0]: check[1](data) for check in self.CHECKS} ``` **实验跟踪**: ```python class ExperimentTracker: def __init__(self): self.runs = {} def log_run(self, params, metrics, data_version): run_id = generate_uuid() self.runs[run_id] = { 'timestamp': datetime.now(), 'data_version': data_version, 'params': params, 'metrics': metrics } ``` **模型监控**: ```python class ModelMonitor: def detect_drift(self, current_data, reference_data): # 计算数据分布差异 psi = calculate_psi(current_data, reference_data) return psi > 0.25 # 触发重新训练阈值 ``` ### 3.11 深度学习训练优化 ```python class LitModel(pl.LightningModule): def __init__(self, hidden_dim=128): super().__init__() self.model = nn.Sequential( nn.Linear(64, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 10) ) def training_step(self, batch, batch_idx): x, y = batch y_hat = self.model(x) loss = F.cross_entropy(y_hat, y) self.log('train_loss', loss) return loss ``` ### 3.12 Lightning 集成方案 ```python # 训练服务集成示例 class TrainingService: def start_training(self, config): model = LitModel(config.hidden_dim) trainer = pl.Trainer( max_epochs=config.epochs, accelerator="auto", callbacks=[ ModelCheckpoint(dirpath="./checkpoints"), MLFlowLogger(experiment_name=config.exp_name) ] ) trainer.fit(model, data_module) ``` ### 3.13 分布式训练支持 ```python # 多GPU训练配置 trainer = pl.Trainer( devices=2, strategy="ddp", precision="16-mixed" ) ``` --- ## 4. 部署指南 ### 4.1 环境要求 - Python 3.8+ - MySQL 5.7+ - NVIDIA驱动(CUDA 11.6) ### 4.2 离线安装步骤 ```bash:doc/ml_platform_spec.md # 1. 安装依赖 pip install --no-index --find-links=./offline_packages -r requirements.txt # 2. 数据库初始化 mysql -u root -p < scripts/init_db.sql # 3. 启动服务 docker-compose up -d ``` --- ## 5. API接口文档 ### 5.1 核心接口概览 | 类别 | 接口路径 | 方法 | 功能说明 | | -------- | ---------------------- | ------ | ------------------ | | 数据管理 | /api/data/upload | POST | 上传原始数据集 | | | /api/data/versions | GET | 查询数据版本历史 | | 模型训练 | /api/train | POST | 提交训练任务 | | | /api/train/status/{id} | GET | 查询训练状态 | | 模型管理 | /api/models | GET | 获取已注册模型列表 | | | /api/models/{id} | DELETE | 删除指定模型 | | 预测服务 | /api/predict/single | POST | 单条实时预测 | | | /api/predict/batch | POST | 批量异步预测 | | 系统管理 | /api/system/health | GET | 系统健康检查 | ### 5.2 训练接口详情 ```http POST /api/train Content-Type: application/json { "algorithm": "xgboost", "dataset": "/data/project1.csv", "hpo": true, "hpo_config": { "max_trials": 50, "timeout": 3600 } } Response: { "task_id": "train_01H9Z3X6V9", "status_url": "/api/train/status/train_01H9Z3X6V9" } ``` ### 5.3 模型查询接口 ```http GET /api/models?type=classification Response: { "models": [ { "id": "xgb_v1", "type": "classification", "metrics": { "accuracy": 0.92, "f1": 0.89 }, "registered_at": "2023-08-20T14:30:00" } ] } ``` ### 5.4 预测接口安全控制 ```http POST /api/predict/single Content-Type: application/json X-API-Key: your_api_key { "model_id": "xgb_v1", "features": [0.5, 1.2, 3.4] } Response: { "prediction": 1, "confidence": 0.87 } ``` ### 5.5 新增实验对比接口 ```http POST /api/experiments/compare Content-Type: application/json { "experiment_ids": ["exp001", "exp002"], "metrics": ["accuracy", "f1"] } Response: { "comparison": { "exp001": {"accuracy": 0.92, "f1": 0.89}, "exp002": {"accuracy": 0.89, "f1": 0.85} } } ``` --- ## 6. 注意事项 1. 数据安全 - 训练数据最大保留30天 - 模型文件加密存储 2. 性能优化 - 单次训练数据量建议不超过1GB - 批量预测支持CSV文件(最大100MB) 3. 扩展建议 - 未来可扩展AutoML模块 - 支持ONNX模型格式 ### 6.4 硬件优化建议 1. **显存管理策略** - 自动批处理调整(根据数据维度动态计算) - 梯度累积技术(累计步数≤4) 2. **训练稳定性措施** - 梯度裁剪阈值:1.0 - 学习率自动衰减(衰减率0.1) - 最大连续失败次数:3次(自动终止异常任务) 3. **容错机制** - 训练快照(每小时保存checkpoint) - 异常恢复(从最近checkpoint重启) ### 6.5 数据管理规范 1. **输入格式要求** - CSV文件必须包含header行 - 数值型字段空值用NaN表示 - 分类字段需预先编码 2. **数据生命周期** ```mermaid graph LR A[原始数据] --> B[预处理数据] B --> C[训练数据] C --> D[30天后删除] ``` 3. **版本控制规则** - 数据划分随机种子固定 - 每次预处理生成唯一数据指纹 - 测试集数据指纹与模型版本绑定 ### 6.6 流程完整性保障 1. **数据质量红线** - 缺失值超过30%的特征自动剔除 - 零方差特征自动过滤 - 类别不平衡数据触发警告 2. **实验可复现性** - 记录随机种子状态 - 保存完整依赖版本 - 存储数据预处理参数 3. **模型监控机制** - 预测结果分布监控 - 特征漂移检测(PSI指标) - 模型性能衰减报警 ### 6.7 依赖管理规范 1. **版本锁定**:使用`pip-tools`固定依赖版本 2. **许可证审查**:确保所有库符合BSD/MIT/Apache许可证 3. **离线打包**:通过`pip download`生成离线安装包 4. **依赖隔离**:使用虚拟环境(venv/conda) ### 6.8 深度学习训练规范 1. **训练模板**:使用PyTorch Lightning标准模板 2. **设备管理**:自动选择GPU/CPU 3. **混合精度**:默认启用`fp16`模式 4. **分布式训练**:支持单机多卡(需额外配置) ### 6.9 配置管理规范 ```yaml # configs/train_config.yaml defaults: - base - override /algorithm: xgboost algorithm: name: xgboost params: max_depth: 6 learning_rate: 0.1 data: path: /data/project1.csv split_ratio: 0.8 logging: experiment_name: baseline_v1 ``` ```python # 配置加载实现 import yaml from dataclasses import dataclass @dataclass class TrainConfig: algorithm: dict data: dict logging: dict def load_config(path): with open(path) as f: raw = yaml.safe_load(f) return TrainConfig(**raw) ``` **配置管理原则**: 1. 配置与代码分离 2. 支持环境差异(dev/test/prod) 3. Schema验证(使用Pydantic) 4. 版本控制(配置文件随代码库管理) --- ## 附录 ### A. 初始化脚本示例 ```sql -- scripts/init_db.sql CREATE DATABASE IF NOT EXISTS ml_platform; USE ml_platform; CREATE TABLE training_jobs ( id VARCHAR(36) PRIMARY KEY, algorithm VARCHAR(50) NOT NULL, status ENUM('pending', 'running', 'completed') DEFAULT 'pending' ); ``` ### B. 目录结构 ``` . ├── docker-compose.yml ├── app/ │ ├── algorithms/ # 算法实现 │ ├── services/ # 核心服务 │ └── api/ # 接口定义 └── data/ # 训练数据集 ``` 完整文档已生成到项目doc目录,包含详细配置示例和接口说明。需要补充其他内容请随时告知。 ### C. 资源监控指标 | 指标名称 | 采集方式 | 告警阈值 | | ---------------- | -------------- | -------------- | | GPU显存使用率 | NVIDIA-SMI API | >85% 持续5分钟 | | 训练内存占用 | psutil 库监控 | >70% 系统内存 | | 训练进程存活状态 | 心跳检测 | 连续3次丢失 | ### D. 稳定性测试用例 ```python def test_oom_recovery(): # 模拟显存溢出场景 try: train_model(oversize_data) except RuntimeError as e: assert "CUDA out of memory" in str(e) assert check_rollback_success() ### E. Lightning 训练模板 ```python import pytorch_lightning as pl from torch import nn class LitClassifier(pl.LightningModule): def __init__(self, input_dim=64, hidden_dim=128): super().__init__() self.save_hyperparameters() self.layers = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 10) ) def forward(self, x): return self.layers(x) def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = nn.CrossEntropyLoss()(y_hat, y) self.log("train_loss", loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.001) ``` ### F. 配置版本控制 ```python # 配置变更检测 class ConfigValidator: SCHEMA = { "algorithm": {"type": dict, "required": True}, "data": {"type": dict, "required": True} } def validate(self, config): for key, rules in self.SCHEMA.items(): if rules["required"] and key not in config: raise ValueError(f"Missing required config: {key}") ``` ### G. 多环境配置示例 ``` configs/ ├── base.yaml ├── dev.yaml ├── test.yaml └── prod.yaml ```