修改 API 测试程序,更新 README
This commit is contained in:
parent
e67da8eaed
commit
8cd0bb5c06
131
README.md
131
README.md
@ -1,77 +1,102 @@
|
||||
# 数据库配置说明
|
||||
# 装备成本预测系统
|
||||
|
||||
本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要:
|
||||
基于机器学习的装备成本预测系统,支持多种预测模型和数据分析功能。
|
||||
|
||||
1. 创建数据库用户
|
||||
## 功能特性
|
||||
|
||||
```sql
|
||||
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password';
|
||||
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost';
|
||||
FLUSH PRIVILEGES;
|
||||
```
|
||||
- 多模型成本预测
|
||||
- 机器学习模型 (XGBoost, LightGBM, RandomForest)
|
||||
- PLS 回归模型
|
||||
- 特征分析与数据可视化
|
||||
- 生产商分析
|
||||
- 数据集管理
|
||||
- 模型训练与评估
|
||||
|
||||
2. 配置数据库字符集
|
||||
确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置:
|
||||
## 系统要求
|
||||
|
||||
```ini
|
||||
[mysqld]
|
||||
character-set-server=utf8mb4
|
||||
collation-server=utf8mb4_unicode_ci
|
||||
- Python >= 3.9, < 3.12
|
||||
- MySQL >= 8.0
|
||||
- 其他依赖见 pyproject.toml
|
||||
|
||||
[client]
|
||||
default-character-set=utf8mb4
|
||||
```
|
||||
## 快速开始
|
||||
|
||||
## 环境配置
|
||||
|
||||
本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。
|
||||
|
||||
### 使用脚本自动配置(推荐)
|
||||
|
||||
Unix/macOS:
|
||||
1. 克隆项目
|
||||
|
||||
```bash
|
||||
chmod +x scripts/setup_env.sh
|
||||
./scripts/setup_env.sh
|
||||
git clone [repository-url]
|
||||
cd cost-prediction
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
2. 安装依赖
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
||||
.\scripts\setup_env.ps1
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### 手动配置
|
||||
3. 配置数据库
|
||||
|
||||
1. 安装 pyenv
|
||||
2. 安装 Python 3.11.8:
|
||||
```bash
|
||||
[Windows]
|
||||
scripts/setup_env.ps1
|
||||
|
||||
```bash
|
||||
pyenv install 3.11.8
|
||||
```
|
||||
[Linux/macOS]
|
||||
scripts/setup_env.sh
|
||||
```
|
||||
|
||||
3. 设置本地 Python 版本:
|
||||
4. 运行系统
|
||||
|
||||
```bash
|
||||
pyenv local 3.11.8
|
||||
```
|
||||
```bash
|
||||
python run.py
|
||||
```
|
||||
|
||||
4. 创建虚拟环境:
|
||||
## API 文档
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
```
|
||||
### 预测接口
|
||||
|
||||
5. 激活虚拟环境:
|
||||
- POST `/api/predict` - 使用最优机器学习模型预测
|
||||
- POST `/api/pls/predict` - 使用 PLS 模型预测
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # Unix
|
||||
.venv\Scripts\activate # Windows
|
||||
```
|
||||
### 数据管理
|
||||
|
||||
6. 安装依赖:
|
||||
- GET `/api/data` - 获取装备数据列表
|
||||
- GET `/api/data/details/<id>` - 获取装备详情
|
||||
- PUT `/api/data/<id>` - 更新装备数据
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
### 数据集管理
|
||||
|
||||
- GET `/api/datasets` - 获取数据集列表
|
||||
- POST `/api/datasets` - 创建数据集
|
||||
- GET `/api/datasets/<id>` - 获取数据集详情
|
||||
- PUT `/api/datasets/<id>` - 更新数据集
|
||||
- DELETE `/api/datasets/<id>` - 删除数据集
|
||||
|
||||
### 模型管理
|
||||
|
||||
- GET `/api/models` - 获取模型列表
|
||||
- POST `/api/train` - 训练模型
|
||||
- POST `/api/models/<id>/activate` - 激活模型
|
||||
- DELETE `/api/models/<id>` - 删除模型
|
||||
|
||||
### 分析功能
|
||||
|
||||
- POST `/api/analyze-features` - 特征分析
|
||||
- POST `/api/analyze-manufacturers` - 生产商分析
|
||||
|
||||
## 开发指南
|
||||
|
||||
详细的开发文档请参考 `docs/dev/` 目录:
|
||||
|
||||
- requirements.md - 项目需求文档
|
||||
- debug.md - 调试指南
|
||||
|
||||
## 测试
|
||||
|
||||
运行测试:
|
||||
|
||||
```bash
|
||||
python src/test_api.py
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 [LICENSE](LICENSE) 许可证。
|
||||
|
||||
@ -32,6 +32,7 @@ dependencies = [
|
||||
# 工具
|
||||
"openpyxl>=3.1.5", # Excel支持
|
||||
"python-dotenv>=1.0.0", # 环境变量
|
||||
"requests>=2.31.0", # API测试
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
118
src/test_api.py
118
src/test_api.py
@ -147,14 +147,124 @@ def test_api_endpoints():
|
||||
response = requests.get(f'{base_url}/models/巡飞弹/latest')
|
||||
print_response(response, "获取最新模型")
|
||||
|
||||
# 8. 测试多模型预测接口
|
||||
logger.info("\n8. 测试多模型预测接口")
|
||||
# 8. 测试预测接口
|
||||
logger.info("\n8. 测试预测接口")
|
||||
|
||||
# 8.1 测试普通预测接口
|
||||
logger.info("8.1 测试普通预测接口")
|
||||
predict_data = {
|
||||
"type": "巡飞弹",
|
||||
"length_m": 1.3,
|
||||
"width_m": 0.23,
|
||||
"height_m": 0.23,
|
||||
"weight_kg": 12.5,
|
||||
"max_range_km": 40,
|
||||
"max_speed_ms": 50,
|
||||
"cruise_speed_kmh": 100,
|
||||
"flight_time_min": 60,
|
||||
"folded_length_mm": 1300,
|
||||
"folded_width_mm": 230,
|
||||
"folded_height_mm": 230,
|
||||
"warhead_type": "破片杀伤战斗部",
|
||||
"launch_mode": "凭自身动力起飞"
|
||||
}
|
||||
response = requests.post(
|
||||
f'{base_url}/predict/all',
|
||||
f'{base_url}/predict',
|
||||
json=predict_data
|
||||
)
|
||||
print_response(response, "多模型预测")
|
||||
print_response(response, "普通预测")
|
||||
|
||||
# 8.2 测试 PLS 预测接口
|
||||
logger.info("8.2 测试 PLS 预测接口")
|
||||
response = requests.post(
|
||||
f'{base_url}/pls/predict',
|
||||
json=predict_data
|
||||
)
|
||||
print_response(response, "PLS 预测")
|
||||
|
||||
# 9. 测试生产商分析接口
|
||||
logger.info("\n9. 测试生产商分析接口")
|
||||
manufacturer_data = {
|
||||
"dataset_id": 1 # 使用已存在的数据集ID
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f'{base_url}/analyze-manufacturers',
|
||||
json=manufacturer_data
|
||||
)
|
||||
print_response(response, "生产商分析")
|
||||
|
||||
# 10. 测试模型激活接口
|
||||
logger.info("\n10. 测试模型激活接口")
|
||||
# 假设存在ID为1的模型
|
||||
response = requests.post(f'{base_url}/models/1/activate')
|
||||
print_response(response, "模型激活")
|
||||
|
||||
# 11. 测试获取最新模型接口
|
||||
logger.info("\n11. 测试获取最新模型接口")
|
||||
response = requests.get(f'{base_url}/models/巡飞弹/latest')
|
||||
print_response(response, "获取最新模型")
|
||||
|
||||
# 12. 测试数据集详情接口
|
||||
logger.info("\n12. 测试数据集详情接口")
|
||||
response = requests.get(f'{base_url}/datasets/1') # 假设存在ID为1的数据集
|
||||
print_response(response, "数据集详情")
|
||||
|
||||
# 13. 测试更新数据集接口
|
||||
logger.info("\n13. 测试更新数据集接口")
|
||||
if available_equipment_ids:
|
||||
update_dataset_data = {
|
||||
"name": "更新后的测试数据集",
|
||||
"description": "用于测试的更新数据集",
|
||||
"equipment_type": "巡飞弹",
|
||||
"purpose": "测试",
|
||||
"equipment_ids": available_equipment_ids[:2] # 使用前两个可用的装备ID
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f'{base_url}/datasets/1', # 假设更新ID为1的数据集
|
||||
json=update_dataset_data
|
||||
)
|
||||
print_response(response, "更新数据集")
|
||||
else:
|
||||
logger.warning("没有可用的装备ID,跳过数据集更新测试")
|
||||
|
||||
# 14. 测试装备详情接口
|
||||
logger.info("\n14. 测试装备详情接口")
|
||||
if available_equipment_ids:
|
||||
response = requests.get(f'{base_url}/data/details/{available_equipment_ids[0]}')
|
||||
print_response(response, "装备详情")
|
||||
|
||||
# 15. 测试更新装备接口
|
||||
logger.info("\n15. 测试更新装备接口")
|
||||
if available_equipment_ids:
|
||||
equipment_update_data = {
|
||||
"equipment_id": available_equipment_ids[0],
|
||||
"name": "更新后的装备名称",
|
||||
"type": "巡飞弹",
|
||||
"manufacturer": "测试厂商",
|
||||
"length_m": 1.5,
|
||||
"width_m": 0.3,
|
||||
"height_m": 0.3,
|
||||
"weight_kg": 15.0,
|
||||
"wingspan_m": 0.8,
|
||||
"warhead_weight_kg": 5.0,
|
||||
"max_speed_ms": 60,
|
||||
"cruise_speed_kmh": 120,
|
||||
"endurance_min": 45,
|
||||
"max_range_km": 50,
|
||||
"warhead_type": "高爆战斗部",
|
||||
"launch_mode": "弹射起飞",
|
||||
"power_system": "涡轮发动机",
|
||||
"guidance_system": "GPS/INS组合导航"
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f'{base_url}/data/{available_equipment_ids[0]}',
|
||||
json=equipment_update_data
|
||||
)
|
||||
print_response(response, "更新装备")
|
||||
|
||||
logger.info("所有测试完成")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user