From 8cd0bb5c06a854622f6bf0049832470f5f45a678 Mon Sep 17 00:00:00 2001 From: Tian jianyong <11429339@qq.com> Date: Wed, 27 Nov 2024 10:52:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20API=20=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=A8=8B=E5=BA=8F=EF=BC=8C=E6=9B=B4=E6=96=B0=20README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 131 ++++++++++++++++++++++++++++-------------------- pyproject.toml | 1 + src/test_api.py | 118 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 193 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 1ecc6d2..ef74541 100644 --- a/README.md +++ b/README.md @@ -1,77 +1,102 @@ -# 数据库配置说明 +# 装备成本预测系统 -本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要: +基于机器学习的装备成本预测系统,支持多种预测模型和数据分析功能。 -1. 创建数据库用户 +## 功能特性 -```sql -CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password'; -GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost'; -FLUSH PRIVILEGES; -``` +- 多模型成本预测 + - 机器学习模型 (XGBoost, LightGBM, RandomForest) + - PLS 回归模型 +- 特征分析与数据可视化 +- 生产商分析 +- 数据集管理 +- 模型训练与评估 -2. 配置数据库字符集 -确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置: +## 系统要求 -```ini -[mysqld] -character-set-server=utf8mb4 -collation-server=utf8mb4_unicode_ci +- Python >= 3.9, < 3.12 +- MySQL >= 8.0 +- 其他依赖见 pyproject.toml -[client] -default-character-set=utf8mb4 -``` +## 快速开始 -## 环境配置 - -本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。 - -### 使用脚本自动配置(推荐) - -Unix/macOS: +1. 克隆项目 ```bash -chmod +x scripts/setup_env.sh -./scripts/setup_env.sh +git clone [repository-url] +cd cost-prediction ``` -Windows (PowerShell): +2. 安装依赖 -```powershell -Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser -.\scripts\setup_env.ps1 +```bash +pip install -e . ``` -### 手动配置 +3. 配置数据库 -1. 安装 pyenv -2. 安装 Python 3.11.8: +```bash +[Windows] +scripts/setup_env.ps1 - ```bash - pyenv install 3.11.8 - ``` +[Linux/macOS] +scripts/setup_env.sh +``` -3. 设置本地 Python 版本: +4. 运行系统 - ```bash - pyenv local 3.11.8 - ``` +```bash +python run.py +``` -4. 创建虚拟环境: +## API 文档 - ```bash - python -m venv .venv - ``` +### 预测接口 -5. 激活虚拟环境: +- POST `/api/predict` - 使用最优机器学习模型预测 +- POST `/api/pls/predict` - 使用 PLS 模型预测 - ```bash - source .venv/bin/activate # Unix - .venv\Scripts\activate # Windows - ``` +### 数据管理 -6. 安装依赖: +- GET `/api/data` - 获取装备数据列表 +- GET `/api/data/details/` - 获取装备详情 +- PUT `/api/data/` - 更新装备数据 - ```bash - pip install -e ".[dev]" - ``` +### 数据集管理 + +- GET `/api/datasets` - 获取数据集列表 +- POST `/api/datasets` - 创建数据集 +- GET `/api/datasets/` - 获取数据集详情 +- PUT `/api/datasets/` - 更新数据集 +- DELETE `/api/datasets/` - 删除数据集 + +### 模型管理 + +- GET `/api/models` - 获取模型列表 +- POST `/api/train` - 训练模型 +- POST `/api/models//activate` - 激活模型 +- DELETE `/api/models/` - 删除模型 + +### 分析功能 + +- POST `/api/analyze-features` - 特征分析 +- POST `/api/analyze-manufacturers` - 生产商分析 + +## 开发指南 + +详细的开发文档请参考 `docs/dev/` 目录: + +- requirements.md - 项目需求文档 +- debug.md - 调试指南 + +## 测试 + +运行测试: + +```bash +python src/test_api.py +``` + +## 许可证 + +本项目采用 [LICENSE](LICENSE) 许可证。 diff --git a/pyproject.toml b/pyproject.toml index 7ca60a6..12ee67e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ # 工具 "openpyxl>=3.1.5", # Excel支持 "python-dotenv>=1.0.0", # 环境变量 + "requests>=2.31.0", # API测试 ] [project.optional-dependencies] diff --git a/src/test_api.py b/src/test_api.py index f68d88d..d90ea7f 100644 --- a/src/test_api.py +++ b/src/test_api.py @@ -147,14 +147,124 @@ def test_api_endpoints(): response = requests.get(f'{base_url}/models/巡飞弹/latest') print_response(response, "获取最新模型") - # 8. 测试多模型预测接口 - logger.info("\n8. 测试多模型预测接口") + # 8. 测试预测接口 + logger.info("\n8. 测试预测接口") + + # 8.1 测试普通预测接口 + logger.info("8.1 测试普通预测接口") + predict_data = { + "type": "巡飞弹", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "folded_length_mm": 1300, + "folded_width_mm": 230, + "folded_height_mm": 230, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "凭自身动力起飞" + } response = requests.post( - f'{base_url}/predict/all', + f'{base_url}/predict', json=predict_data ) - print_response(response, "多模型预测") + print_response(response, "普通预测") + # 8.2 测试 PLS 预测接口 + logger.info("8.2 测试 PLS 预测接口") + response = requests.post( + f'{base_url}/pls/predict', + json=predict_data + ) + print_response(response, "PLS 预测") + + # 9. 测试生产商分析接口 + logger.info("\n9. 测试生产商分析接口") + manufacturer_data = { + "dataset_id": 1 # 使用已存在的数据集ID + } + + response = requests.post( + f'{base_url}/analyze-manufacturers', + json=manufacturer_data + ) + print_response(response, "生产商分析") + + # 10. 测试模型激活接口 + logger.info("\n10. 测试模型激活接口") + # 假设存在ID为1的模型 + response = requests.post(f'{base_url}/models/1/activate') + print_response(response, "模型激活") + + # 11. 测试获取最新模型接口 + logger.info("\n11. 测试获取最新模型接口") + response = requests.get(f'{base_url}/models/巡飞弹/latest') + print_response(response, "获取最新模型") + + # 12. 测试数据集详情接口 + logger.info("\n12. 测试数据集详情接口") + response = requests.get(f'{base_url}/datasets/1') # 假设存在ID为1的数据集 + print_response(response, "数据集详情") + + # 13. 测试更新数据集接口 + logger.info("\n13. 测试更新数据集接口") + if available_equipment_ids: + update_dataset_data = { + "name": "更新后的测试数据集", + "description": "用于测试的更新数据集", + "equipment_type": "巡飞弹", + "purpose": "测试", + "equipment_ids": available_equipment_ids[:2] # 使用前两个可用的装备ID + } + + response = requests.put( + f'{base_url}/datasets/1', # 假设更新ID为1的数据集 + json=update_dataset_data + ) + print_response(response, "更新数据集") + else: + logger.warning("没有可用的装备ID,跳过数据集更新测试") + + # 14. 测试装备详情接口 + logger.info("\n14. 测试装备详情接口") + if available_equipment_ids: + response = requests.get(f'{base_url}/data/details/{available_equipment_ids[0]}') + print_response(response, "装备详情") + + # 15. 测试更新装备接口 + logger.info("\n15. 测试更新装备接口") + if available_equipment_ids: + equipment_update_data = { + "equipment_id": available_equipment_ids[0], + "name": "更新后的装备名称", + "type": "巡飞弹", + "manufacturer": "测试厂商", + "length_m": 1.5, + "width_m": 0.3, + "height_m": 0.3, + "weight_kg": 15.0, + "wingspan_m": 0.8, + "warhead_weight_kg": 5.0, + "max_speed_ms": 60, + "cruise_speed_kmh": 120, + "endurance_min": 45, + "max_range_km": 50, + "warhead_type": "高爆战斗部", + "launch_mode": "弹射起飞", + "power_system": "涡轮发动机", + "guidance_system": "GPS/INS组合导航" + } + + response = requests.put( + f'{base_url}/data/{available_equipment_ids[0]}', + json=equipment_update_data + ) + print_response(response, "更新装备") + logger.info("所有测试完成") except requests.exceptions.RequestException as e: