MLPlatform/doc/requirements.md
2025-02-08 15:25:24 +08:00

247 lines
5.3 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 内网机器学习平台需求文档
**版本**1.0
**最后更新**2025-02-07
**参与人员**AI研发团队
---
## 一、项目背景
为某科研单位建设内网机器学习平台,满足以下需求:
1. 完全离线部署,无互联网依赖
2. 支持常见机器学习算法快速实验
3. 与现有科研数据管理系统集成
4. 简化模型开发到部署流程
---
## 二、项目目标
| 维度 | 具体要求 |
| ---------- | ------------------------------------------------------- |
| 功能目标 | 实现算法训练、超参优化、模型评估、批量预测全流程支持 |
| 性能目标 | 单任务训练数据量≤1GB批量预测响应时间≤30s100MB数据 |
| 安全目标 | 数据保留周期≤30天模型文件AES加密存储 |
| 扩展性目标 | 支持后续扩展3-5种新算法 |
---
## 三、功能需求
### 3.1 核心功能
```mermaid
graph LR
A[数据接入] --> B[算法选择]
B --> C[训练配置]
C --> D[模型评估]
D --> E[模型部署]
E --> F[预测服务]
```
### 3.2 详细需求
#### 3.2.1 算法管理
- 预置算法清单(按任务类型分类)
**分类算法**
```python
classification_algorithms = [
"逻辑回归",
"支持向量机(SVM)",
"随机森林",
"XGBoost",
"LightGBM",
"朴素贝叶斯",
"K近邻(KNN)",
"多层感知机(MLP)",
"梯度提升决策树(GBDT)",
"深度神经网络(DNN)"
]
```
**回归算法**
```python
regression_algorithms = [
"线性回归",
"岭回归",
"Lasso回归",
"弹性网络",
"支持向量回归(SVR)",
"随机森林回归",
"XGBoost回归",
"LightGBM回归",
"多层感知机回归"
]
```
**聚类算法**
```python
clustering_algorithms = [
"K均值(K-Means)",
"层次聚类",
"DBSCAN",
"高斯混合模型(GMM)",
"谱聚类"
]
```
**时间序列**
```python
timeseries_algorithms = [
"ARIMA",
"Prophet",
"LSTM(基础版)",
"GRU(基础版)"
]
```
**降维算法**
```python
dimensionality_reduction = [
"主成分分析(PCA)",
"线性判别分析(LDA)",
"t-SNE",
"UMAP"
]
```
**推荐算法**(新增类别):
```python
recommendation_algorithms = [
"深度协同过滤(DeepCF)",
"神经矩阵分解(NeuMF)",
"Wide & Deep"
]
```
### 二、算法参数配置示例
```python
algorithm_params = {
"XGBoost": {
"max_depth": {"type": "int", "range": [3, 10]},
"learning_rate": {"type": "float", "range": [0.001, 0.3]}
},
"LSTM": {
"hidden_units": {"type": "int", "range": [16, 128]},
"num_layers": {"type": "int", "range": [1, 3]}
},
"ResNet": {
"num_blocks": {"type": "int", "range": [2, 4]},
"dropout_rate": {"type": "float", "range": [0.0, 0.5]}
},
"Transformer": {
"num_heads": {"type": "int", "range": [2, 8]},
"ff_dim": {"type": "int", "range": [64, 256]}
}
}
```
#### 3.2.2 训练服务
- 必需功能
- 异步训练任务提交
- 训练进度查询
- 自动超参优化HPO
- 训练日志记录
#### 3.2.3 预测服务
- 接口规范
```http
POST /api/predict
Content-Type: multipart/form-data
Form Data:
- model_id: 已部署模型ID
- file: 预测数据文件CSV格式
```
---
## 四、非功能需求
### 4.1 性能指标
| 场景 | 指标要求 |
| ------------------- | ---------------------- |
| 模型训练1GB数据 | ≤30分钟使用GPU加速 |
| 批量预测10万条 | ≤1分钟 |
| API响应延迟 | ≤500ms |
### 4.2 安全要求
1. 数据安全
- 训练数据自动清理机制
- 数据库访问白名单控制
2. 模型安全
- 模型文件加密存储
- 模型下载权限控制
### 4.3 兼容性要求
| 组件 | 版本要求 |
| ------ | -------- |
| Python | 3.8+ |
| MySQL | 5.7+ |
| CUDA | 11.6+ |
---
## 五、系统架构
### 5.1 逻辑架构
```mermaid
graph TB
subgraph 存储层
A[MySQL] --> B[模型元数据]
C[MLflow] --> D[模型文件]
end
subgraph 服务层
E[FastAPI] --> F[REST API]
G[Celery] --> H[任务队列]
end
subgraph 资源层
I[GPU服务器] --> J[计算资源]
end
```
### 5.2 物理部署
```bash
部署拓扑:
- 1台应用服务器4核8G
- 1台GPU计算节点NVIDIA T4 16G
- 1台MySQL数据库SSD存储
```
---
## 六、接口规范
### 6.1 训练接口
```json
{
"operation": "train",
"params": {
"algorithm": "xgboost",
"dataset": "/data/project001.csv",
"hpo": true,
"hpo_trials": 50
}
}
```
### 6.2 预测接口
```json
{
"model_id": "model_20230820",
"data_format": {
"columns": ["feature1", "feature2"],
"dtypes": ["float", "int"]
}
}
```
---
## 七、附录
### 7.1 术语表
| 术语 | 解释 |
| ------ | ----------------------------------------- |
| HPO | 超参数优化Hyperparameter Optimization |
| MLflow | 机器学习生命周期管理平台 |
### 7.2 变更记录
| 版本 | 日期 | 修改内容 |
| ---- | ---------- | -------- |
| 1.0 | 2025-02-07 | 初始版本 |