创建项目,实现需求文档和设计方案初稿
This commit is contained in:
commit
6a7b7b7da6
437
doc/ml_platform_spec.md
Normal file
437
doc/ml_platform_spec.md
Normal file
@ -0,0 +1,437 @@
|
||||
# 内网机器学习平台技术方案
|
||||
|
||||
## 目录
|
||||
1. 项目概述
|
||||
2. 系统架构
|
||||
3. 核心功能
|
||||
4. 部署指南
|
||||
5. API接口文档
|
||||
6. 注意事项
|
||||
|
||||
---
|
||||
|
||||
## 1. 项目概述
|
||||
### 1.1 目标
|
||||
为科研单位提供内网机器学习平台,支持:
|
||||
- 算法选择与模型训练
|
||||
- 超参数自动优化
|
||||
- 模型评估与部署
|
||||
- 批量预测服务
|
||||
|
||||
### 1.2 技术栈
|
||||
| 组件 | 技术选型 |
|
||||
| -------- | -------------- |
|
||||
| Web框架 | FastAPI |
|
||||
| 任务队列 | Celery + Redis |
|
||||
| 数据存储 | MySQL |
|
||||
| 模型管理 | MLflow |
|
||||
| 超参优化 | Optuna |
|
||||
|
||||
---
|
||||
|
||||
## 2. 系统架构
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Web UI/API] --> B[FastAPI]
|
||||
B --> C[算法仓库]
|
||||
B --> D[MySQL]
|
||||
B --> E[Celery Worker]
|
||||
E --> F[GPU训练]
|
||||
C --> G[MLflow模型仓库]
|
||||
G --> H[批量预测服务]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 核心功能
|
||||
|
||||
### 3.1 算法管理
|
||||
|
||||
```python
|
||||
预置算法示例
|
||||
ALGORITHMS = {
|
||||
"xgboost": {
|
||||
"type": "classification",
|
||||
"params": {
|
||||
"max_depth": {"type": "int", "range": [3,10]},
|
||||
"learning_rate": {"type": "float", "range": [0.001,0.1]}
|
||||
}
|
||||
},
|
||||
"random_forest": {...}
|
||||
}
|
||||
```
|
||||
|
||||
### 3.2 训练流程
|
||||
1. 接收训练请求
|
||||
2. 数据预处理
|
||||
3. 自动超参优化(可选)
|
||||
4. 模型训练与评估
|
||||
5. 模型注册到MLflow
|
||||
|
||||
### 3.3 数据预处理
|
||||
```python
|
||||
class DataPreprocessor:
|
||||
SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5']
|
||||
|
||||
def __init__(self, config):
|
||||
self.pipeline = [
|
||||
('缺失值处理', SimpleImputer()),
|
||||
('标准化', StandardScaler()),
|
||||
('特征选择', SelectKBest(k=20))
|
||||
]
|
||||
|
||||
def split_data(self, test_size=0.2, random_state=42):
|
||||
# 自动记录数据划分版本
|
||||
train, test = train_test_split(self.data, test_size=test_size)
|
||||
return train, test, self._generate_data_hash()
|
||||
```
|
||||
|
||||
### 3.4 数据版本控制
|
||||
```mermaid
|
||||
graph LR
|
||||
A[原始数据] --> B[预处理]
|
||||
B --> C[训练集 v1]
|
||||
B --> D[测试集 v1]
|
||||
C --> E[模型训练]
|
||||
D --> F[模型评估]
|
||||
E --> G[元数据记录]
|
||||
F --> G
|
||||
```
|
||||
|
||||
### 3.5 硬件感知训练
|
||||
```python
|
||||
class HardwareAwareTraining:
|
||||
def __init__(self):
|
||||
self.gpu_mem = self.get_gpu_memory()
|
||||
|
||||
def auto_batch_size(self, model_type):
|
||||
# 根据GPU显存自动调整批大小
|
||||
base_sizes = {
|
||||
'DNN': 1024,
|
||||
'LSTM': 256,
|
||||
'Transformer': 64
|
||||
}
|
||||
return base_sizes[model_type] * (self.gpu_mem // 16384) # 按16GB基准缩放
|
||||
```
|
||||
|
||||
### 3.6 显存监控
|
||||
```python
|
||||
class MemoryMonitor(Callback):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
used_mem = torch.cuda.memory_allocated() // 1024**2
|
||||
print(f"当前显存占用: {used_mem}MB / {TOTAL_GPU_MEM}MB")
|
||||
if used_mem > SAFE_THRESHOLD:
|
||||
self.model.stop_training = True
|
||||
```
|
||||
|
||||
### 3.7 训练稳定性保障
|
||||
```python
|
||||
# 梯度裁剪实现
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
optimizer = torch.optim.ClipGradNorm(optimizer, max_norm=1.0)
|
||||
|
||||
# 早停机制
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=5,
|
||||
mode='min'
|
||||
)
|
||||
```
|
||||
|
||||
### 3.8 数据安全处理
|
||||
```python
|
||||
class DataSanitizer:
|
||||
def __init__(self):
|
||||
self.sensitive_patterns = [
|
||||
r'\d{18}', # 身份证号
|
||||
r'\d{11}' # 手机号
|
||||
]
|
||||
|
||||
def clean(self, data):
|
||||
# 数据脱敏处理
|
||||
return data.replace(self.sensitive_patterns, '***')
|
||||
```
|
||||
|
||||
### 3.9 完整流程示意图
|
||||
```mermaid
|
||||
graph TD
|
||||
A[数据接入] --> B[数据质量检查]
|
||||
B --> C[数据预处理]
|
||||
C --> D[特征工程]
|
||||
D --> E[算法选择]
|
||||
E --> F[模型训练]
|
||||
F --> G[模型评估]
|
||||
G --> H[模型注册]
|
||||
H --> I[模型部署]
|
||||
I --> J[预测服务]
|
||||
J --> K[效果监控]
|
||||
K -->|反馈| A
|
||||
```
|
||||
|
||||
### 3.10 新增关键模块说明
|
||||
|
||||
**数据质量检查**:
|
||||
```python
|
||||
class DataValidator:
|
||||
CHECKS = [
|
||||
('缺失值比例 < 30%', lambda df: df.isna().mean() < 0.3),
|
||||
('特征方差 > 0.01', lambda df: df.var() > 0.01),
|
||||
('类别平衡性 > 0.1', lambda df: check_class_balance(df))
|
||||
]
|
||||
|
||||
def validate(self, data):
|
||||
return {check[0]: check[1](data) for check in self.CHECKS}
|
||||
```
|
||||
|
||||
**实验跟踪**:
|
||||
```python
|
||||
class ExperimentTracker:
|
||||
def __init__(self):
|
||||
self.runs = {}
|
||||
|
||||
def log_run(self, params, metrics, data_version):
|
||||
run_id = generate_uuid()
|
||||
self.runs[run_id] = {
|
||||
'timestamp': datetime.now(),
|
||||
'data_version': data_version,
|
||||
'params': params,
|
||||
'metrics': metrics
|
||||
}
|
||||
```
|
||||
|
||||
**模型监控**:
|
||||
```python
|
||||
class ModelMonitor:
|
||||
def detect_drift(self, current_data, reference_data):
|
||||
# 计算数据分布差异
|
||||
psi = calculate_psi(current_data, reference_data)
|
||||
return psi > 0.25 # 触发重新训练阈值
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 部署指南
|
||||
### 4.1 环境要求
|
||||
- Python 3.8+
|
||||
- MySQL 5.7+
|
||||
- NVIDIA驱动(CUDA 11.6)
|
||||
|
||||
### 4.2 离线安装步骤
|
||||
|
||||
```bash:doc/ml_platform_spec.md
|
||||
# 1. 安装依赖
|
||||
pip install --no-index --find-links=./offline_packages -r requirements.txt
|
||||
|
||||
# 2. 数据库初始化
|
||||
mysql -u root -p < scripts/init_db.sql
|
||||
|
||||
# 3. 启动服务
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. API接口文档
|
||||
|
||||
### 5.1 核心接口概览
|
||||
| 类别 | 接口路径 | 方法 | 功能说明 |
|
||||
| -------- | ---------------------- | ------ | ------------------ |
|
||||
| 数据管理 | /api/data/upload | POST | 上传原始数据集 |
|
||||
| | /api/data/versions | GET | 查询数据版本历史 |
|
||||
| 模型训练 | /api/train | POST | 提交训练任务 |
|
||||
| | /api/train/status/{id} | GET | 查询训练状态 |
|
||||
| 模型管理 | /api/models | GET | 获取已注册模型列表 |
|
||||
| | /api/models/{id} | DELETE | 删除指定模型 |
|
||||
| 预测服务 | /api/predict/single | POST | 单条实时预测 |
|
||||
| | /api/predict/batch | POST | 批量异步预测 |
|
||||
| 系统管理 | /api/system/health | GET | 系统健康检查 |
|
||||
|
||||
### 5.2 训练接口详情
|
||||
```http
|
||||
POST /api/train
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"algorithm": "xgboost",
|
||||
"dataset": "/data/project1.csv",
|
||||
"hpo": true,
|
||||
"hpo_config": {
|
||||
"max_trials": 50,
|
||||
"timeout": 3600
|
||||
}
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"task_id": "train_01H9Z3X6V9",
|
||||
"status_url": "/api/train/status/train_01H9Z3X6V9"
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 模型查询接口
|
||||
```http
|
||||
GET /api/models?type=classification
|
||||
|
||||
Response:
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"id": "xgb_v1",
|
||||
"type": "classification",
|
||||
"metrics": {
|
||||
"accuracy": 0.92,
|
||||
"f1": 0.89
|
||||
},
|
||||
"registered_at": "2023-08-20T14:30:00"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 5.4 预测接口安全控制
|
||||
```http
|
||||
POST /api/predict/single
|
||||
Content-Type: application/json
|
||||
X-API-Key: your_api_key
|
||||
|
||||
{
|
||||
"model_id": "xgb_v1",
|
||||
"features": [0.5, 1.2, 3.4]
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"prediction": 1,
|
||||
"confidence": 0.87
|
||||
}
|
||||
```
|
||||
|
||||
### 5.5 新增实验对比接口
|
||||
```http
|
||||
POST /api/experiments/compare
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"experiment_ids": ["exp001", "exp002"],
|
||||
"metrics": ["accuracy", "f1"]
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"comparison": {
|
||||
"exp001": {"accuracy": 0.92, "f1": 0.89},
|
||||
"exp002": {"accuracy": 0.89, "f1": 0.85}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 注意事项
|
||||
1. 数据安全
|
||||
- 训练数据最大保留30天
|
||||
- 模型文件加密存储
|
||||
|
||||
2. 性能优化
|
||||
- 单次训练数据量建议不超过1GB
|
||||
- 批量预测支持CSV文件(最大100MB)
|
||||
|
||||
3. 扩展建议
|
||||
- 未来可扩展AutoML模块
|
||||
- 支持ONNX模型格式
|
||||
|
||||
### 6.4 硬件优化建议
|
||||
1. **显存管理策略**
|
||||
- 自动批处理调整(根据数据维度动态计算)
|
||||
- 梯度累积技术(累计步数≤4)
|
||||
|
||||
2. **训练稳定性措施**
|
||||
- 梯度裁剪阈值:1.0
|
||||
- 学习率自动衰减(衰减率0.1)
|
||||
- 最大连续失败次数:3次(自动终止异常任务)
|
||||
|
||||
3. **容错机制**
|
||||
- 训练快照(每小时保存checkpoint)
|
||||
- 异常恢复(从最近checkpoint重启)
|
||||
|
||||
### 6.5 数据管理规范
|
||||
1. **输入格式要求**
|
||||
- CSV文件必须包含header行
|
||||
- 数值型字段空值用NaN表示
|
||||
- 分类字段需预先编码
|
||||
|
||||
2. **数据生命周期**
|
||||
```mermaid
|
||||
graph LR
|
||||
A[原始数据] --> B[预处理数据]
|
||||
B --> C[训练数据]
|
||||
C --> D[30天后删除]
|
||||
```
|
||||
|
||||
3. **版本控制规则**
|
||||
- 数据划分随机种子固定
|
||||
- 每次预处理生成唯一数据指纹
|
||||
- 测试集数据指纹与模型版本绑定
|
||||
|
||||
### 6.6 流程完整性保障
|
||||
1. **数据质量红线**
|
||||
- 缺失值超过30%的特征自动剔除
|
||||
- 零方差特征自动过滤
|
||||
- 类别不平衡数据触发警告
|
||||
|
||||
2. **实验可复现性**
|
||||
- 记录随机种子状态
|
||||
- 保存完整依赖版本
|
||||
- 存储数据预处理参数
|
||||
|
||||
3. **模型监控机制**
|
||||
- 预测结果分布监控
|
||||
- 特征漂移检测(PSI指标)
|
||||
- 模型性能衰减报警
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
### A. 初始化脚本示例
|
||||
```sql
|
||||
-- scripts/init_db.sql
|
||||
CREATE DATABASE IF NOT EXISTS ml_platform;
|
||||
USE ml_platform;
|
||||
|
||||
CREATE TABLE training_jobs (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
algorithm VARCHAR(50) NOT NULL,
|
||||
status ENUM('pending', 'running', 'completed') DEFAULT 'pending'
|
||||
);
|
||||
```
|
||||
|
||||
### B. 目录结构
|
||||
```
|
||||
.
|
||||
├── docker-compose.yml
|
||||
├── app/
|
||||
│ ├── algorithms/ # 算法实现
|
||||
│ ├── services/ # 核心服务
|
||||
│ └── api/ # 接口定义
|
||||
└── data/ # 训练数据集
|
||||
```
|
||||
|
||||
完整文档已生成到项目doc目录,包含详细配置示例和接口说明。需要补充其他内容请随时告知。
|
||||
|
||||
### C. 资源监控指标
|
||||
| 指标名称 | 采集方式 | 告警阈值 |
|
||||
| ---------------- | -------------- | -------------- |
|
||||
| GPU显存使用率 | NVIDIA-SMI API | >85% 持续5分钟 |
|
||||
| 训练内存占用 | psutil 库监控 | >70% 系统内存 |
|
||||
| 训练进程存活状态 | 心跳检测 | 连续3次丢失 |
|
||||
|
||||
### D. 稳定性测试用例
|
||||
```python
|
||||
def test_oom_recovery():
|
||||
# 模拟显存溢出场景
|
||||
try:
|
||||
train_model(oversize_data)
|
||||
except RuntimeError as e:
|
||||
assert "CUDA out of memory" in str(e)
|
||||
assert check_rollback_success()
|
||||
247
doc/requirements.md
Normal file
247
doc/requirements.md
Normal file
@ -0,0 +1,247 @@
|
||||
# 内网机器学习平台需求文档
|
||||
|
||||
**版本**:1.0
|
||||
**最后更新**:2025-02-07
|
||||
**参与人员**:AI研发团队
|
||||
|
||||
---
|
||||
|
||||
## 一、项目背景
|
||||
为某科研单位建设内网机器学习平台,满足以下需求:
|
||||
1. 完全离线部署,无互联网依赖
|
||||
2. 支持常见机器学习算法快速实验
|
||||
3. 与现有科研数据管理系统集成
|
||||
4. 简化模型开发到部署流程
|
||||
|
||||
---
|
||||
|
||||
## 二、项目目标
|
||||
| 维度 | 具体要求 |
|
||||
| ---------- | ------------------------------------------------------- |
|
||||
| 功能目标 | 实现算法训练、超参优化、模型评估、批量预测全流程支持 |
|
||||
| 性能目标 | 单任务训练数据量≤1GB,批量预测响应时间≤30s(100MB数据) |
|
||||
| 安全目标 | 数据保留周期≤30天,模型文件AES加密存储 |
|
||||
| 扩展性目标 | 支持后续扩展3-5种新算法 |
|
||||
|
||||
---
|
||||
|
||||
## 三、功能需求
|
||||
### 3.1 核心功能
|
||||
```mermaid
|
||||
graph LR
|
||||
A[数据接入] --> B[算法选择]
|
||||
B --> C[训练配置]
|
||||
C --> D[模型评估]
|
||||
D --> E[模型部署]
|
||||
E --> F[预测服务]
|
||||
```
|
||||
|
||||
### 3.2 详细需求
|
||||
#### 3.2.1 算法管理
|
||||
- 预置算法清单(按任务类型分类)
|
||||
|
||||
**分类算法**:
|
||||
```python
|
||||
classification_algorithms = [
|
||||
"逻辑回归",
|
||||
"支持向量机(SVM)",
|
||||
"随机森林",
|
||||
"XGBoost",
|
||||
"LightGBM",
|
||||
"朴素贝叶斯",
|
||||
"K近邻(KNN)",
|
||||
"多层感知机(MLP)",
|
||||
"梯度提升决策树(GBDT)",
|
||||
"深度神经网络(DNN)"
|
||||
]
|
||||
```
|
||||
|
||||
**回归算法**:
|
||||
```python
|
||||
regression_algorithms = [
|
||||
"线性回归",
|
||||
"岭回归",
|
||||
"Lasso回归",
|
||||
"弹性网络",
|
||||
"支持向量回归(SVR)",
|
||||
"随机森林回归",
|
||||
"XGBoost回归",
|
||||
"LightGBM回归",
|
||||
"多层感知机回归"
|
||||
]
|
||||
```
|
||||
|
||||
**聚类算法**:
|
||||
```python
|
||||
clustering_algorithms = [
|
||||
"K均值(K-Means)",
|
||||
"层次聚类",
|
||||
"DBSCAN",
|
||||
"高斯混合模型(GMM)",
|
||||
"谱聚类"
|
||||
]
|
||||
```
|
||||
|
||||
**时间序列**:
|
||||
```python
|
||||
timeseries_algorithms = [
|
||||
"ARIMA",
|
||||
"Prophet",
|
||||
"LSTM(基础版)",
|
||||
"GRU(基础版)"
|
||||
]
|
||||
```
|
||||
|
||||
**降维算法**:
|
||||
```python
|
||||
dimensionality_reduction = [
|
||||
"主成分分析(PCA)",
|
||||
"线性判别分析(LDA)",
|
||||
"t-SNE",
|
||||
"UMAP"
|
||||
]
|
||||
```
|
||||
|
||||
**推荐算法**(新增类别):
|
||||
```python
|
||||
recommendation_algorithms = [
|
||||
"深度协同过滤(DeepCF)",
|
||||
"神经矩阵分解(NeuMF)",
|
||||
"Wide & Deep"
|
||||
]
|
||||
```
|
||||
|
||||
### 二、算法参数配置示例
|
||||
```python
|
||||
algorithm_params = {
|
||||
"XGBoost": {
|
||||
"max_depth": {"type": "int", "range": [3, 10]},
|
||||
"learning_rate": {"type": "float", "range": [0.001, 0.3]}
|
||||
},
|
||||
"LSTM": {
|
||||
"hidden_units": {"type": "int", "range": [16, 128]},
|
||||
"num_layers": {"type": "int", "range": [1, 3]}
|
||||
},
|
||||
"ResNet": {
|
||||
"num_blocks": {"type": "int", "range": [2, 4]},
|
||||
"dropout_rate": {"type": "float", "range": [0.0, 0.5]}
|
||||
},
|
||||
"Transformer": {
|
||||
"num_heads": {"type": "int", "range": [2, 8]},
|
||||
"ff_dim": {"type": "int", "range": [64, 256]}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 3.2.2 训练服务
|
||||
- 必需功能
|
||||
- 异步训练任务提交
|
||||
- 训练进度查询
|
||||
- 自动超参优化(HPO)
|
||||
- 训练日志记录
|
||||
|
||||
#### 3.2.3 预测服务
|
||||
- 接口规范
|
||||
```http
|
||||
POST /api/predict
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
Form Data:
|
||||
- model_id: 已部署模型ID
|
||||
- file: 预测数据文件(CSV格式)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 四、非功能需求
|
||||
### 4.1 性能指标
|
||||
| 场景 | 指标要求 |
|
||||
| ------------------- | ---------------------- |
|
||||
| 模型训练(1GB数据) | ≤30分钟(使用GPU加速) |
|
||||
| 批量预测(10万条) | ≤1分钟 |
|
||||
| API响应延迟 | ≤500ms |
|
||||
|
||||
### 4.2 安全要求
|
||||
1. 数据安全
|
||||
- 训练数据自动清理机制
|
||||
- 数据库访问白名单控制
|
||||
2. 模型安全
|
||||
- 模型文件加密存储
|
||||
- 模型下载权限控制
|
||||
|
||||
### 4.3 兼容性要求
|
||||
| 组件 | 版本要求 |
|
||||
| ------ | -------- |
|
||||
| Python | 3.8+ |
|
||||
| MySQL | 5.7+ |
|
||||
| CUDA | 11.6+ |
|
||||
|
||||
---
|
||||
|
||||
## 五、系统架构
|
||||
### 5.1 逻辑架构
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph 存储层
|
||||
A[MySQL] --> B[模型元数据]
|
||||
C[MLflow] --> D[模型文件]
|
||||
end
|
||||
|
||||
subgraph 服务层
|
||||
E[FastAPI] --> F[REST API]
|
||||
G[Celery] --> H[任务队列]
|
||||
end
|
||||
|
||||
subgraph 资源层
|
||||
I[GPU服务器] --> J[计算资源]
|
||||
end
|
||||
```
|
||||
|
||||
### 5.2 物理部署
|
||||
```bash
|
||||
部署拓扑:
|
||||
- 1台应用服务器(4核8G)
|
||||
- 1台GPU计算节点(NVIDIA T4 16G)
|
||||
- 1台MySQL数据库(SSD存储)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 六、接口规范
|
||||
### 6.1 训练接口
|
||||
```json
|
||||
{
|
||||
"operation": "train",
|
||||
"params": {
|
||||
"algorithm": "xgboost",
|
||||
"dataset": "/data/project001.csv",
|
||||
"hpo": true,
|
||||
"hpo_trials": 50
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 6.2 预测接口
|
||||
```json
|
||||
{
|
||||
"model_id": "model_20230820",
|
||||
"data_format": {
|
||||
"columns": ["feature1", "feature2"],
|
||||
"dtypes": ["float", "int"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、附录
|
||||
### 7.1 术语表
|
||||
| 术语 | 解释 |
|
||||
| ------ | ----------------------------------------- |
|
||||
| HPO | 超参数优化(Hyperparameter Optimization) |
|
||||
| MLflow | 机器学习生命周期管理平台 |
|
||||
|
||||
### 7.2 变更记录
|
||||
| 版本 | 日期 | 修改内容 |
|
||||
| ---- | ---------- | -------- |
|
||||
| 1.0 | 2023-08-20 | 初始版本 |
|
||||
Loading…
Reference in New Issue
Block a user