5.3 KiB
5.3 KiB
内网机器学习平台需求文档
版本:1.0
最后更新:2025-02-07
参与人员:AI研发团队
一、项目背景
为某科研单位建设内网机器学习平台,满足以下需求:
- 完全离线部署,无互联网依赖
- 支持常见机器学习算法快速实验
- 与现有科研数据管理系统集成
- 简化模型开发到部署流程
二、项目目标
| 维度 | 具体要求 |
|---|---|
| 功能目标 | 实现算法训练、超参优化、模型评估、批量预测全流程支持 |
| 性能目标 | 单任务训练数据量≤1GB,批量预测响应时间≤30s(100MB数据) |
| 安全目标 | 数据保留周期≤30天,模型文件AES加密存储 |
| 扩展性目标 | 支持后续扩展3-5种新算法 |
三、功能需求
3.1 核心功能
graph LR
A[数据接入] --> B[算法选择]
B --> C[训练配置]
C --> D[模型评估]
D --> E[模型部署]
E --> F[预测服务]
3.2 详细需求
3.2.1 算法管理
- 预置算法清单(按任务类型分类)
分类算法:
classification_algorithms = [
"逻辑回归",
"支持向量机(SVM)",
"随机森林",
"XGBoost",
"LightGBM",
"朴素贝叶斯",
"K近邻(KNN)",
"多层感知机(MLP)",
"梯度提升决策树(GBDT)",
"深度神经网络(DNN)"
]
回归算法:
regression_algorithms = [
"线性回归",
"岭回归",
"Lasso回归",
"弹性网络",
"支持向量回归(SVR)",
"随机森林回归",
"XGBoost回归",
"LightGBM回归",
"多层感知机回归"
]
聚类算法:
clustering_algorithms = [
"K均值(K-Means)",
"层次聚类",
"DBSCAN",
"高斯混合模型(GMM)",
"谱聚类"
]
时间序列:
timeseries_algorithms = [
"ARIMA",
"Prophet",
"LSTM(基础版)",
"GRU(基础版)"
]
降维算法:
dimensionality_reduction = [
"主成分分析(PCA)",
"线性判别分析(LDA)",
"t-SNE",
"UMAP"
]
推荐算法(新增类别):
recommendation_algorithms = [
"深度协同过滤(DeepCF)",
"神经矩阵分解(NeuMF)",
"Wide & Deep"
]
二、算法参数配置示例
algorithm_params = {
"XGBoost": {
"max_depth": {"type": "int", "range": [3, 10]},
"learning_rate": {"type": "float", "range": [0.001, 0.3]}
},
"LSTM": {
"hidden_units": {"type": "int", "range": [16, 128]},
"num_layers": {"type": "int", "range": [1, 3]}
},
"ResNet": {
"num_blocks": {"type": "int", "range": [2, 4]},
"dropout_rate": {"type": "float", "range": [0.0, 0.5]}
},
"Transformer": {
"num_heads": {"type": "int", "range": [2, 8]},
"ff_dim": {"type": "int", "range": [64, 256]}
}
}
3.2.2 训练服务
- 必需功能
- 异步训练任务提交
- 训练进度查询
- 自动超参优化(HPO)
- 训练日志记录
3.2.3 预测服务
- 接口规范
POST /api/predict Content-Type: multipart/form-data Form Data: - model_id: 已部署模型ID - file: 预测数据文件(CSV格式)
四、非功能需求
4.1 性能指标
| 场景 | 指标要求 |
|---|---|
| 模型训练(1GB数据) | ≤30分钟(使用GPU加速) |
| 批量预测(10万条) | ≤1分钟 |
| API响应延迟 | ≤500ms |
4.2 安全要求
- 数据安全
- 训练数据自动清理机制
- 数据库访问白名单控制
- 模型安全
- 模型文件加密存储
- 模型下载权限控制
4.3 兼容性要求
| 组件 | 版本要求 |
|---|---|
| Python | 3.8+ |
| MySQL | 5.7+ |
| CUDA | 11.6+ |
五、系统架构
5.1 逻辑架构
graph TB
subgraph 存储层
A[MySQL] --> B[模型元数据]
C[MLflow] --> D[模型文件]
end
subgraph 服务层
E[FastAPI] --> F[REST API]
G[Celery] --> H[任务队列]
end
subgraph 资源层
I[GPU服务器] --> J[计算资源]
end
5.2 物理部署
部署拓扑:
- 1台应用服务器(4核8G)
- 1台GPU计算节点(NVIDIA T4 16G)
- 1台MySQL数据库(SSD存储)
六、接口规范
6.1 训练接口
{
"operation": "train",
"params": {
"algorithm": "xgboost",
"dataset": "/data/project001.csv",
"hpo": true,
"hpo_trials": 50
}
}
6.2 预测接口
{
"model_id": "model_20230820",
"data_format": {
"columns": ["feature1", "feature2"],
"dtypes": ["float", "int"]
}
}
七、附录
7.1 术语表
| 术语 | 解释 |
|---|---|
| HPO | 超参数优化(Hyperparameter Optimization) |
| MLflow | 机器学习生命周期管理平台 |
7.2 变更记录
| 版本 | 日期 | 修改内容 |
|---|---|---|
| 1.0 | 2025-02-07 | 初始版本 |