修改 API 测试程序,更新 README

This commit is contained in:
Tian jianyong 2024-11-27 10:52:47 +08:00
parent e67da8eaed
commit 8cd0bb5c06
3 changed files with 193 additions and 57 deletions

131
README.md
View File

@ -1,77 +1,102 @@
# 数据库配置说明 # 装备成本预测系统
本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要: 基于机器学习的装备成本预测系统,支持多种预测模型和数据分析功能。
1. 创建数据库用户 ## 功能特性
```sql - 多模型成本预测
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password'; - 机器学习模型 (XGBoost, LightGBM, RandomForest)
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost'; - PLS 回归模型
FLUSH PRIVILEGES; - 特征分析与数据可视化
``` - 生产商分析
- 数据集管理
- 模型训练与评估
2. 配置数据库字符集 ## 系统要求
确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置:
```ini - Python >= 3.9, < 3.12
[mysqld] - MySQL >= 8.0
character-set-server=utf8mb4 - 其他依赖见 pyproject.toml
collation-server=utf8mb4_unicode_ci
[client] ## 快速开始
default-character-set=utf8mb4
```
## 环境配置 1. 克隆项目
本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。
### 使用脚本自动配置(推荐)
Unix/macOS:
```bash ```bash
chmod +x scripts/setup_env.sh git clone [repository-url]
./scripts/setup_env.sh cd cost-prediction
``` ```
Windows (PowerShell): 2. 安装依赖
```powershell ```bash
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser pip install -e .
.\scripts\setup_env.ps1
``` ```
### 手动配置 3. 配置数据库
1. 安装 pyenv ```bash
2. 安装 Python 3.11.8: [Windows]
scripts/setup_env.ps1
```bash [Linux/macOS]
pyenv install 3.11.8 scripts/setup_env.sh
``` ```
3. 设置本地 Python 版本: 4. 运行系统
```bash ```bash
pyenv local 3.11.8 python run.py
``` ```
4. 创建虚拟环境: ## API 文档
```bash ### 预测接口
python -m venv .venv
```
5. 激活虚拟环境: - POST `/api/predict` - 使用最优机器学习模型预测
- POST `/api/pls/predict` - 使用 PLS 模型预测
```bash ### 数据管理
source .venv/bin/activate # Unix
.venv\Scripts\activate # Windows
```
6. 安装依赖: - GET `/api/data` - 获取装备数据列表
- GET `/api/data/details/<id>` - 获取装备详情
- PUT `/api/data/<id>` - 更新装备数据
```bash ### 数据集管理
pip install -e ".[dev]"
``` - GET `/api/datasets` - 获取数据集列表
- POST `/api/datasets` - 创建数据集
- GET `/api/datasets/<id>` - 获取数据集详情
- PUT `/api/datasets/<id>` - 更新数据集
- DELETE `/api/datasets/<id>` - 删除数据集
### 模型管理
- GET `/api/models` - 获取模型列表
- POST `/api/train` - 训练模型
- POST `/api/models/<id>/activate` - 激活模型
- DELETE `/api/models/<id>` - 删除模型
### 分析功能
- POST `/api/analyze-features` - 特征分析
- POST `/api/analyze-manufacturers` - 生产商分析
## 开发指南
详细的开发文档请参考 `docs/dev/` 目录:
- requirements.md - 项目需求文档
- debug.md - 调试指南
## 测试
运行测试:
```bash
python src/test_api.py
```
## 许可证
本项目采用 [LICENSE](LICENSE) 许可证。

View File

@ -32,6 +32,7 @@ dependencies = [
# 工具 # 工具
"openpyxl>=3.1.5", # Excel支持 "openpyxl>=3.1.5", # Excel支持
"python-dotenv>=1.0.0", # 环境变量 "python-dotenv>=1.0.0", # 环境变量
"requests>=2.31.0", # API测试
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@ -147,14 +147,124 @@ def test_api_endpoints():
response = requests.get(f'{base_url}/models/巡飞弹/latest') response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型") print_response(response, "获取最新模型")
# 8. 测试多模型预测接口 # 8. 测试预测接口
logger.info("\n8. 测试多模型预测接口") logger.info("\n8. 测试预测接口")
# 8.1 测试普通预测接口
logger.info("8.1 测试普通预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post( response = requests.post(
f'{base_url}/predict/all', f'{base_url}/predict',
json=predict_data json=predict_data
) )
print_response(response, "多模型预测") print_response(response, "普通预测")
# 8.2 测试 PLS 预测接口
logger.info("8.2 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print_response(response, "PLS 预测")
# 9. 测试生产商分析接口
logger.info("\n9. 测试生产商分析接口")
manufacturer_data = {
"dataset_id": 1 # 使用已存在的数据集ID
}
response = requests.post(
f'{base_url}/analyze-manufacturers',
json=manufacturer_data
)
print_response(response, "生产商分析")
# 10. 测试模型激活接口
logger.info("\n10. 测试模型激活接口")
# 假设存在ID为1的模型
response = requests.post(f'{base_url}/models/1/activate')
print_response(response, "模型激活")
# 11. 测试获取最新模型接口
logger.info("\n11. 测试获取最新模型接口")
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 12. 测试数据集详情接口
logger.info("\n12. 测试数据集详情接口")
response = requests.get(f'{base_url}/datasets/1') # 假设存在ID为1的数据集
print_response(response, "数据集详情")
# 13. 测试更新数据集接口
logger.info("\n13. 测试更新数据集接口")
if available_equipment_ids:
update_dataset_data = {
"name": "更新后的测试数据集",
"description": "用于测试的更新数据集",
"equipment_type": "巡飞弹",
"purpose": "测试",
"equipment_ids": available_equipment_ids[:2] # 使用前两个可用的装备ID
}
response = requests.put(
f'{base_url}/datasets/1', # 假设更新ID为1的数据集
json=update_dataset_data
)
print_response(response, "更新数据集")
else:
logger.warning("没有可用的装备ID跳过数据集更新测试")
# 14. 测试装备详情接口
logger.info("\n14. 测试装备详情接口")
if available_equipment_ids:
response = requests.get(f'{base_url}/data/details/{available_equipment_ids[0]}')
print_response(response, "装备详情")
# 15. 测试更新装备接口
logger.info("\n15. 测试更新装备接口")
if available_equipment_ids:
equipment_update_data = {
"equipment_id": available_equipment_ids[0],
"name": "更新后的装备名称",
"type": "巡飞弹",
"manufacturer": "测试厂商",
"length_m": 1.5,
"width_m": 0.3,
"height_m": 0.3,
"weight_kg": 15.0,
"wingspan_m": 0.8,
"warhead_weight_kg": 5.0,
"max_speed_ms": 60,
"cruise_speed_kmh": 120,
"endurance_min": 45,
"max_range_km": 50,
"warhead_type": "高爆战斗部",
"launch_mode": "弹射起飞",
"power_system": "涡轮发动机",
"guidance_system": "GPS/INS组合导航"
}
response = requests.put(
f'{base_url}/data/{available_equipment_ids[0]}',
json=equipment_update_data
)
print_response(response, "更新装备")
logger.info("所有测试完成") logger.info("所有测试完成")
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e: