创建项目,实现需求文档和设计方案初稿

This commit is contained in:
Tian jianyong 2025-02-08 15:23:53 +08:00
commit 6a7b7b7da6
2 changed files with 684 additions and 0 deletions

437
doc/ml_platform_spec.md Normal file
View File

@ -0,0 +1,437 @@
# 内网机器学习平台技术方案
## 目录
1. 项目概述
2. 系统架构
3. 核心功能
4. 部署指南
5. API接口文档
6. 注意事项
---
## 1. 项目概述
### 1.1 目标
为科研单位提供内网机器学习平台,支持:
- 算法选择与模型训练
- 超参数自动优化
- 模型评估与部署
- 批量预测服务
### 1.2 技术栈
| 组件 | 技术选型 |
| -------- | -------------- |
| Web框架 | FastAPI |
| 任务队列 | Celery + Redis |
| 数据存储 | MySQL |
| 模型管理 | MLflow |
| 超参优化 | Optuna |
---
## 2. 系统架构
```mermaid
graph TD
A[Web UI/API] --> B[FastAPI]
B --> C[算法仓库]
B --> D[MySQL]
B --> E[Celery Worker]
E --> F[GPU训练]
C --> G[MLflow模型仓库]
G --> H[批量预测服务]
```
---
## 3. 核心功能
### 3.1 算法管理
```python
预置算法示例
ALGORITHMS = {
"xgboost": {
"type": "classification",
"params": {
"max_depth": {"type": "int", "range": [3,10]},
"learning_rate": {"type": "float", "range": [0.001,0.1]}
}
},
"random_forest": {...}
}
```
### 3.2 训练流程
1. 接收训练请求
2. 数据预处理
3. 自动超参优化(可选)
4. 模型训练与评估
5. 模型注册到MLflow
### 3.3 数据预处理
```python
class DataPreprocessor:
SUPPORTED_FORMATS = ['csv', 'parquet', 'hdf5']
def __init__(self, config):
self.pipeline = [
('缺失值处理', SimpleImputer()),
('标准化', StandardScaler()),
('特征选择', SelectKBest(k=20))
]
def split_data(self, test_size=0.2, random_state=42):
# 自动记录数据划分版本
train, test = train_test_split(self.data, test_size=test_size)
return train, test, self._generate_data_hash()
```
### 3.4 数据版本控制
```mermaid
graph LR
A[原始数据] --> B[预处理]
B --> C[训练集 v1]
B --> D[测试集 v1]
C --> E[模型训练]
D --> F[模型评估]
E --> G[元数据记录]
F --> G
```
### 3.5 硬件感知训练
```python
class HardwareAwareTraining:
def __init__(self):
self.gpu_mem = self.get_gpu_memory()
def auto_batch_size(self, model_type):
# 根据GPU显存自动调整批大小
base_sizes = {
'DNN': 1024,
'LSTM': 256,
'Transformer': 64
}
return base_sizes[model_type] * (self.gpu_mem // 16384) # 按16GB基准缩放
```
### 3.6 显存监控
```python
class MemoryMonitor(Callback):
def on_epoch_end(self, epoch, logs=None):
used_mem = torch.cuda.memory_allocated() // 1024**2
print(f"当前显存占用: {used_mem}MB / {TOTAL_GPU_MEM}MB")
if used_mem > SAFE_THRESHOLD:
self.model.stop_training = True
```
### 3.7 训练稳定性保障
```python
# 梯度裁剪实现
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = torch.optim.ClipGradNorm(optimizer, max_norm=1.0)
# 早停机制
early_stop = EarlyStopping(
monitor='val_loss',
patience=5,
mode='min'
)
```
### 3.8 数据安全处理
```python
class DataSanitizer:
def __init__(self):
self.sensitive_patterns = [
r'\d{18}', # 身份证号
r'\d{11}' # 手机号
]
def clean(self, data):
# 数据脱敏处理
return data.replace(self.sensitive_patterns, '***')
```
### 3.9 完整流程示意图
```mermaid
graph TD
A[数据接入] --> B[数据质量检查]
B --> C[数据预处理]
C --> D[特征工程]
D --> E[算法选择]
E --> F[模型训练]
F --> G[模型评估]
G --> H[模型注册]
H --> I[模型部署]
I --> J[预测服务]
J --> K[效果监控]
K -->|反馈| A
```
### 3.10 新增关键模块说明
**数据质量检查**
```python
class DataValidator:
CHECKS = [
('缺失值比例 < 30%', lambda df: df.isna().mean() < 0.3),
('特征方差 > 0.01', lambda df: df.var() > 0.01),
('类别平衡性 > 0.1', lambda df: check_class_balance(df))
]
def validate(self, data):
return {check[0]: check[1](data) for check in self.CHECKS}
```
**实验跟踪**
```python
class ExperimentTracker:
def __init__(self):
self.runs = {}
def log_run(self, params, metrics, data_version):
run_id = generate_uuid()
self.runs[run_id] = {
'timestamp': datetime.now(),
'data_version': data_version,
'params': params,
'metrics': metrics
}
```
**模型监控**
```python
class ModelMonitor:
def detect_drift(self, current_data, reference_data):
# 计算数据分布差异
psi = calculate_psi(current_data, reference_data)
return psi > 0.25 # 触发重新训练阈值
```
---
## 4. 部署指南
### 4.1 环境要求
- Python 3.8+
- MySQL 5.7+
- NVIDIA驱动CUDA 11.6
### 4.2 离线安装步骤
```bash:doc/ml_platform_spec.md
# 1. 安装依赖
pip install --no-index --find-links=./offline_packages -r requirements.txt
# 2. 数据库初始化
mysql -u root -p < scripts/init_db.sql
# 3. 启动服务
docker-compose up -d
```
---
## 5. API接口文档
### 5.1 核心接口概览
| 类别 | 接口路径 | 方法 | 功能说明 |
| -------- | ---------------------- | ------ | ------------------ |
| 数据管理 | /api/data/upload | POST | 上传原始数据集 |
| | /api/data/versions | GET | 查询数据版本历史 |
| 模型训练 | /api/train | POST | 提交训练任务 |
| | /api/train/status/{id} | GET | 查询训练状态 |
| 模型管理 | /api/models | GET | 获取已注册模型列表 |
| | /api/models/{id} | DELETE | 删除指定模型 |
| 预测服务 | /api/predict/single | POST | 单条实时预测 |
| | /api/predict/batch | POST | 批量异步预测 |
| 系统管理 | /api/system/health | GET | 系统健康检查 |
### 5.2 训练接口详情
```http
POST /api/train
Content-Type: application/json
{
"algorithm": "xgboost",
"dataset": "/data/project1.csv",
"hpo": true,
"hpo_config": {
"max_trials": 50,
"timeout": 3600
}
}
Response:
{
"task_id": "train_01H9Z3X6V9",
"status_url": "/api/train/status/train_01H9Z3X6V9"
}
```
### 5.3 模型查询接口
```http
GET /api/models?type=classification
Response:
{
"models": [
{
"id": "xgb_v1",
"type": "classification",
"metrics": {
"accuracy": 0.92,
"f1": 0.89
},
"registered_at": "2023-08-20T14:30:00"
}
]
}
```
### 5.4 预测接口安全控制
```http
POST /api/predict/single
Content-Type: application/json
X-API-Key: your_api_key
{
"model_id": "xgb_v1",
"features": [0.5, 1.2, 3.4]
}
Response:
{
"prediction": 1,
"confidence": 0.87
}
```
### 5.5 新增实验对比接口
```http
POST /api/experiments/compare
Content-Type: application/json
{
"experiment_ids": ["exp001", "exp002"],
"metrics": ["accuracy", "f1"]
}
Response:
{
"comparison": {
"exp001": {"accuracy": 0.92, "f1": 0.89},
"exp002": {"accuracy": 0.89, "f1": 0.85}
}
}
```
---
## 6. 注意事项
1. 数据安全
- 训练数据最大保留30天
- 模型文件加密存储
2. 性能优化
- 单次训练数据量建议不超过1GB
- 批量预测支持CSV文件最大100MB
3. 扩展建议
- 未来可扩展AutoML模块
- 支持ONNX模型格式
### 6.4 硬件优化建议
1. **显存管理策略**
- 自动批处理调整(根据数据维度动态计算)
- 梯度累积技术累计步数≤4
2. **训练稳定性措施**
- 梯度裁剪阈值1.0
- 学习率自动衰减衰减率0.1
- 最大连续失败次数3次自动终止异常任务
3. **容错机制**
- 训练快照每小时保存checkpoint
- 异常恢复从最近checkpoint重启
### 6.5 数据管理规范
1. **输入格式要求**
- CSV文件必须包含header行
- 数值型字段空值用NaN表示
- 分类字段需预先编码
2. **数据生命周期**
```mermaid
graph LR
A[原始数据] --> B[预处理数据]
B --> C[训练数据]
C --> D[30天后删除]
```
3. **版本控制规则**
- 数据划分随机种子固定
- 每次预处理生成唯一数据指纹
- 测试集数据指纹与模型版本绑定
### 6.6 流程完整性保障
1. **数据质量红线**
- 缺失值超过30%的特征自动剔除
- 零方差特征自动过滤
- 类别不平衡数据触发警告
2. **实验可复现性**
- 记录随机种子状态
- 保存完整依赖版本
- 存储数据预处理参数
3. **模型监控机制**
- 预测结果分布监控
- 特征漂移检测PSI指标
- 模型性能衰减报警
---
## 附录
### A. 初始化脚本示例
```sql
-- scripts/init_db.sql
CREATE DATABASE IF NOT EXISTS ml_platform;
USE ml_platform;
CREATE TABLE training_jobs (
id VARCHAR(36) PRIMARY KEY,
algorithm VARCHAR(50) NOT NULL,
status ENUM('pending', 'running', 'completed') DEFAULT 'pending'
);
```
### B. 目录结构
```
.
├── docker-compose.yml
├── app/
│ ├── algorithms/ # 算法实现
│ ├── services/ # 核心服务
│ └── api/ # 接口定义
└── data/ # 训练数据集
```
完整文档已生成到项目doc目录包含详细配置示例和接口说明。需要补充其他内容请随时告知。
### C. 资源监控指标
| 指标名称 | 采集方式 | 告警阈值 |
| ---------------- | -------------- | -------------- |
| GPU显存使用率 | NVIDIA-SMI API | >85% 持续5分钟 |
| 训练内存占用 | psutil 库监控 | >70% 系统内存 |
| 训练进程存活状态 | 心跳检测 | 连续3次丢失 |
### D. 稳定性测试用例
```python
def test_oom_recovery():
# 模拟显存溢出场景
try:
train_model(oversize_data)
except RuntimeError as e:
assert "CUDA out of memory" in str(e)
assert check_rollback_success()

247
doc/requirements.md Normal file
View File

@ -0,0 +1,247 @@
# 内网机器学习平台需求文档
**版本**1.0
**最后更新**2025-02-07
**参与人员**AI研发团队
---
## 一、项目背景
为某科研单位建设内网机器学习平台,满足以下需求:
1. 完全离线部署,无互联网依赖
2. 支持常见机器学习算法快速实验
3. 与现有科研数据管理系统集成
4. 简化模型开发到部署流程
---
## 二、项目目标
| 维度 | 具体要求 |
| ---------- | ------------------------------------------------------- |
| 功能目标 | 实现算法训练、超参优化、模型评估、批量预测全流程支持 |
| 性能目标 | 单任务训练数据量≤1GB批量预测响应时间≤30s100MB数据 |
| 安全目标 | 数据保留周期≤30天模型文件AES加密存储 |
| 扩展性目标 | 支持后续扩展3-5种新算法 |
---
## 三、功能需求
### 3.1 核心功能
```mermaid
graph LR
A[数据接入] --> B[算法选择]
B --> C[训练配置]
C --> D[模型评估]
D --> E[模型部署]
E --> F[预测服务]
```
### 3.2 详细需求
#### 3.2.1 算法管理
- 预置算法清单(按任务类型分类)
**分类算法**
```python
classification_algorithms = [
"逻辑回归",
"支持向量机(SVM)",
"随机森林",
"XGBoost",
"LightGBM",
"朴素贝叶斯",
"K近邻(KNN)",
"多层感知机(MLP)",
"梯度提升决策树(GBDT)",
"深度神经网络(DNN)"
]
```
**回归算法**
```python
regression_algorithms = [
"线性回归",
"岭回归",
"Lasso回归",
"弹性网络",
"支持向量回归(SVR)",
"随机森林回归",
"XGBoost回归",
"LightGBM回归",
"多层感知机回归"
]
```
**聚类算法**
```python
clustering_algorithms = [
"K均值(K-Means)",
"层次聚类",
"DBSCAN",
"高斯混合模型(GMM)",
"谱聚类"
]
```
**时间序列**
```python
timeseries_algorithms = [
"ARIMA",
"Prophet",
"LSTM(基础版)",
"GRU(基础版)"
]
```
**降维算法**
```python
dimensionality_reduction = [
"主成分分析(PCA)",
"线性判别分析(LDA)",
"t-SNE",
"UMAP"
]
```
**推荐算法**(新增类别):
```python
recommendation_algorithms = [
"深度协同过滤(DeepCF)",
"神经矩阵分解(NeuMF)",
"Wide & Deep"
]
```
### 二、算法参数配置示例
```python
algorithm_params = {
"XGBoost": {
"max_depth": {"type": "int", "range": [3, 10]},
"learning_rate": {"type": "float", "range": [0.001, 0.3]}
},
"LSTM": {
"hidden_units": {"type": "int", "range": [16, 128]},
"num_layers": {"type": "int", "range": [1, 3]}
},
"ResNet": {
"num_blocks": {"type": "int", "range": [2, 4]},
"dropout_rate": {"type": "float", "range": [0.0, 0.5]}
},
"Transformer": {
"num_heads": {"type": "int", "range": [2, 8]},
"ff_dim": {"type": "int", "range": [64, 256]}
}
}
```
#### 3.2.2 训练服务
- 必需功能
- 异步训练任务提交
- 训练进度查询
- 自动超参优化HPO
- 训练日志记录
#### 3.2.3 预测服务
- 接口规范
```http
POST /api/predict
Content-Type: multipart/form-data
Form Data:
- model_id: 已部署模型ID
- file: 预测数据文件CSV格式
```
---
## 四、非功能需求
### 4.1 性能指标
| 场景 | 指标要求 |
| ------------------- | ---------------------- |
| 模型训练1GB数据 | ≤30分钟使用GPU加速 |
| 批量预测10万条 | ≤1分钟 |
| API响应延迟 | ≤500ms |
### 4.2 安全要求
1. 数据安全
- 训练数据自动清理机制
- 数据库访问白名单控制
2. 模型安全
- 模型文件加密存储
- 模型下载权限控制
### 4.3 兼容性要求
| 组件 | 版本要求 |
| ------ | -------- |
| Python | 3.8+ |
| MySQL | 5.7+ |
| CUDA | 11.6+ |
---
## 五、系统架构
### 5.1 逻辑架构
```mermaid
graph TB
subgraph 存储层
A[MySQL] --> B[模型元数据]
C[MLflow] --> D[模型文件]
end
subgraph 服务层
E[FastAPI] --> F[REST API]
G[Celery] --> H[任务队列]
end
subgraph 资源层
I[GPU服务器] --> J[计算资源]
end
```
### 5.2 物理部署
```bash
部署拓扑:
- 1台应用服务器4核8G
- 1台GPU计算节点NVIDIA T4 16G
- 1台MySQL数据库SSD存储
```
---
## 六、接口规范
### 6.1 训练接口
```json
{
"operation": "train",
"params": {
"algorithm": "xgboost",
"dataset": "/data/project001.csv",
"hpo": true,
"hpo_trials": 50
}
}
```
### 6.2 预测接口
```json
{
"model_id": "model_20230820",
"data_format": {
"columns": ["feature1", "feature2"],
"dtypes": ["float", "int"]
}
}
```
---
## 七、附录
### 7.1 术语表
| 术语 | 解释 |
| ------ | ----------------------------------------- |
| HPO | 超参数优化Hyperparameter Optimization |
| MLflow | 机器学习生命周期管理平台 |
### 7.2 变更记录
| 版本 | 日期 | 修改内容 |
| ---- | ---------- | -------- |
| 1.0 | 2023-08-20 | 初始版本 |