Compare commits

..

15 Commits
v0.0.2 ... main

75 changed files with 8642 additions and 12932 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,3 @@
# 代码修改最佳实践
1. 修改前的准备
@ -56,3 +55,33 @@
- 处理异常情况
- 保护敏感信息
- 添加访问控制
9. 中文处理规则
- 不修改任何包含中文的注释行
- 使用 `// ... existing code ...` 跳过包含中文的代码块
- 如需修改中文附近的代码,应完整保留原有中文内容
示例:
```
// ... existing code ...
// 这是中文注释,保持不变
newCode = value;
// ... existing code ...
```
10. 编码规则
- 所有文件统一使用 UTF-8 编码
- 不使用 BOM 头
- 换行符统一使用 LF (\n)
- 文件末尾保留一个换行符
- 代码注释中的中文必须使用 UTF-8 编码
配置示例:
```json
{
"charset": "utf-8",
"end_of_line": "lf",
"insert_final_newline": true
}

View File

@ -1,25 +0,0 @@
# 数据库配置
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=your_password_here
MYSQL_DATABASE=equipment_cost_db
# 服务配置
PORT=5001
DEBUG=False
# 日志配置
LOG_LEVEL=INFO
LOG_DIR=logs
# 模型配置
MODEL_DIR=models
DATA_DIR=data
# 安全配置
SECRET_KEY=your_secret_key_here
ALLOWED_HOSTS=localhost,127.0.0.1
# 其他配置
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
ALLOWED_FILE_TYPES=.xlsx,.xls

49
.gitignore vendored
View File

@ -1,14 +1,54 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
.env
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
node_modules
/dist
/models
/logs
/uploads
/data
/data/*
!/data/demo_equipment_costs.csv
# local env files
.env.local
.env.*.local
.venv
# Log files
npm-debug.log*
@ -25,3 +65,10 @@ pnpm-debug.log*
*.sln
*.sw?
/frontend/node_modules
/frontend/dist
/release/
*.zip
*.tar
*.gz
*.whl

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.11.8

108
CLAUDE.md Normal file
View File

@ -0,0 +1,108 @@
# CLAUDE.md
本文件为 Claude Codeclaude.ai/code在此仓库中工作提供指引。
## 常用命令
```bash
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch
pip install -e .
pip install -e ".[dev]" # 含开发工具pytest, black, mypy
pip install -e ".[torch]" # 安装可选 PyTorch神经网络训练需要
# 启动后端Flask0.0.0.0:5001
python run.py
# 启动前端开发服务器另开终端localhost:3000
cd frontend && npm install && npm run serve
# 构建前端生产版本
cd frontend && npm run build
# 运行测试
python -m pytest tests/
python -m pytest tests/test_demo_service.py -q # 单个测试文件
python -m pytest tests/test_demo_routes.py -q
# 代码格式化 / 类型检查
black src/ tests/ # 自动格式化line-length=88
mypy src/ # 类型检查
# 运行独立演示模式(仅需 5 个依赖)
cd demo_standalone && pip install -r requirements.txt && python server.py
```
## 项目概述
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
### 技术栈
- **后端**Python 3.9-3.11, Flask 3.1+
- **数据库**SQLitePython 内置,零外部依赖),首次启动自动建表,无需手动安装配置
- **机器学习**scikit-learn核心、XGBoost、LightGBM
- **可选依赖**PyTorch仅神经网络训练需要约 800MB
- **前端**Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
### 关键文件
| 文件 | 用途 |
|---|---|
| `run.py` | 入口点,启动 Flask 服务器 |
| `src/app.py` | Flask 应用工厂 |
| `src/routes.py` | 所有 API 路由(约 1300 行) |
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModelPyTorch NN |
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
| `src/cost_prediction.py` | CostPredictor预测编排 |
| `src/demo_service.py` | DemoModelService基于 CSV无需数据库 |
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL内置 |
| `config.py` | 运行时配置 |
| `frontend/src/router/index.js` | Vue Router共 8 个路由 |
| `frontend/src/api/index.js` | Axios API 客户端 |
## 架构
### 与旧架构的核心变化
| 项目 | 旧架构MySQL | 新架构SQLite |
|------|----------------|-----------------|
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLitePython 内置,零依赖 |
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3 |
| PyTorch | 硬依赖(顶层 import无则崩溃 | 可选依赖try/except 保护,无 PyTorch 可启动) |
| 前端构建 | Vue CLI + Babel + SCSS | Vite无需 Babel/SCSS |
| Vuex | 存在但完全空 | 已移除 |
| 依赖总数(核心)| 15+ | 5flask, numpy, pandas, scikit-learn, openpyxl |
### 训练数据流
```
用户界面 → POST /api/train → 查询 SQLite → DataPreparation特征提取 + 标准化)
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS
→ 保存最优模型 → 写入 SQLite trained_models 表
```
### 预测数据流
```
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
```
### 机器学习模型
| 模型 | 标识 | 说明 |
|---|---|---|
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
| Random Forest | `rf` | sklearn.ensemble |
| Gradient Boosting | `gbm` | sklearn.ensemble |
| PyTorch NN | `pytorch` | 可选(需安装 torch针对不同装备类型定制网络结构 |
| PLS 回归 | `pls` | 不参与最优模型评选 |
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
## 重要注意事项
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
- **两种装备类型**火箭炮27+ 特征和巡飞弹24+ 特征),各自有独立数据库表和神经网络结构
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
- **SQLite 首次使用**`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
- **编码规范**UTF-8 无 BOMLF 换行符,文件末尾保留换行符
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
# MIT License
Copyright (c) 2024 Your Name or Your Organization
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

111
README.md
View File

@ -1,23 +1,102 @@
# 数据库配置说明
# 装备成本预测系统
本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要:
基于机器学习的装备成本预测系统,支持多种预测模型和数据分析功能。
1. 创建数据库用户
## 功能特性
```sql
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password';
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost';
FLUSH PRIVILEGES;
- 多模型成本预测
- 机器学习模型 (XGBoost, LightGBM, RandomForest)
- PLS 回归模型
- 特征分析与数据可视化
- 生产商分析
- 数据集管理
- 模型训练与评估
## 系统要求
- Python >= 3.9, < 3.12
- MySQL >= 8.0
- 其他依赖见 pyproject.toml
## 快速开始
1. 克隆项目
```bash
git clone [repository-url]
cd cost-prediction
```
2. 配置数据库字符集
确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置:
2. 安装依赖
```ini
[mysqld]
character-set-server=utf8mb4
collation-server=utf8mb4_unicode_ci
[client]
default-character-set=utf8mb4
```bash
pip install -e .
```
3. 配置数据库
```bash
[Windows]
scripts/setup_env.ps1
[Linux/macOS]
scripts/setup_env.sh
```
4. 运行系统
```bash
python run.py
```
## API 文档
### 预测接口
- POST `/api/predict` - 使用最优机器学习模型预测
- POST `/api/pls/predict` - 使用 PLS 模型预测
### 数据管理
- GET `/api/data` - 获取装备数据列表
- GET `/api/data/details/<id>` - 获取装备详情
- PUT `/api/data/<id>` - 更新装备数据
### 数据集管理
- GET `/api/datasets` - 获取数据集列表
- POST `/api/datasets` - 创建数据集
- GET `/api/datasets/<id>` - 获取数据集详情
- PUT `/api/datasets/<id>` - 更新数据集
- DELETE `/api/datasets/<id>` - 删除数据集
### 模型管理
- GET `/api/models` - 获取模型列表
- POST `/api/train` - 训练模型
- POST `/api/models/<id>/activate` - 激活模型
- DELETE `/api/models/<id>` - 删除模型
### 分析功能
- POST `/api/analyze-features` - 特征分析
- POST `/api/analyze-manufacturers` - 生产商分析
## 开发指南
详细的开发文档请参考 `docs/dev/` 目录:
- requirements.md - 项目需求文档
- debug.md - 调试指南
## 测试
运行测试:
```bash
python src/test_api.py
```
## 许可证
本项目采用 [LICENSE](LICENSE) 许可证。

126
config.py
View File

@ -1,32 +1,100 @@
import os
import secrets
# 数据库配置
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
class Config:
"""配置类"""
# 数据库配置(使用 SQLite
SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
# Flask配置
FLASK_HOST = '0.0.0.0'
FLASK_PORT = 5001
FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'
# 目录配置
MODEL_DIR = 'models'
DATA_DIR = 'data'
LOG_DIR = 'logs'
UPLOAD_DIR = 'uploads'
TEMPLATE_DIR = 'templates'
# 文件上传配置
ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
# API配置
API_VERSION = 'v1'
API_PREFIX = f'/api/{API_VERSION}'
# 日志配置
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOG_LEVEL = 'INFO'
LOG_FILE = os.path.join(LOG_DIR, 'app.log')
LOG_MAX_SIZE = 10 * 1024 * 1024 # 10MB
LOG_BACKUP_COUNT = 5
# PyTorch配置
DEVICE = 'cpu' # 或 'cuda' 如果要使用 GPU
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 100
# 模型训练配置
TRAIN_TEST_SPLIT = 0.2
RANDOM_SEED = 42
EARLY_STOPPING_PATIENCE = 10
MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')
# 缓存配置
CACHE_TYPE = 'simple'
CACHE_DEFAULT_TIMEOUT = 300
# 安全配置
SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here')
JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-jwt-secret-key-here')
JWT_ACCESS_TOKEN_EXPIRES = 3600 # 1小时
# 跨域配置
CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']
# 数据验证配置
MAX_EQUIPMENT_NAME_LENGTH = 100
MAX_MANUFACTURER_NAME_LENGTH = 100
@classmethod
def init_app(cls, app):
"""初始化应用配置"""
# 创建必要的目录
for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]:
os.makedirs(directory, exist_ok=True)
# 配置日志
import logging
from logging.handlers import RotatingFileHandler
formatter = logging.Formatter(cls.LOG_FORMAT)
file_handler = RotatingFileHandler(
cls.LOG_FILE,
maxBytes=cls.LOG_MAX_SIZE,
backupCount=cls.LOG_BACKUP_COUNT
)
file_handler.setFormatter(formatter)
file_handler.setLevel(cls.LOG_LEVEL)
app.logger.addHandler(file_handler)
app.logger.setLevel(cls.LOG_LEVEL)
# 配置上传目录
app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH
# 配置跨域
from flask_cors import CORS
CORS(app, resources={
r"/api/*": {"origins": cls.CORS_ORIGINS}
})
return app
# 安全密钥配置(自动生成随机密钥)
SECRET_KEY = secrets.token_hex(16)
# 环境配置
DEBUG = False
ENV = 'production'
# 文件上传配置
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
# API配置
API_VERSION = 'v1'
API_PREFIX = f'/api/{API_VERSION}'
# 跨域配置
CORS_ORIGINS = [
"http://localhost:8080",
"http://127.0.0.1:8080",
]
# 日志配置
LOG_LEVEL = 'DEBUG'
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
# 创建配置实例
config = Config()

View File

@ -0,0 +1,27 @@
name,type,length_m,width_m,height_m,weight_kg,max_range_km,payload_kg,max_speed_kmh,endurance_min,tech_level,scale_level,supply_chain_level,complexity_score,actual_cost
隼击-A,巡飞弹,1.2,1.8,0.32,18,35,4,145,55,6.4,5.8,6.2,5.9,420000
隼击-B,巡飞弹,1.5,2.1,0.36,26,48,6,160,70,6.8,6.2,6.4,6.6,610000
隼击-C,巡飞弹,1.8,2.5,0.42,34,65,8,175,85,7.2,6.4,6.8,7.1,830000
侦察-100,巡飞弹,0.9,1.4,0.25,9,18,2,110,35,5.4,5.1,5.5,4.8,190000
侦察-200,巡飞弹,1.1,1.7,0.29,14,28,3,125,48,5.9,5.4,5.7,5.3,310000
侦察-300,巡飞弹,1.4,2.0,0.34,22,42,5,150,62,6.3,5.9,6.0,6.1,520000
锐蛇-S,巡飞弹,1.7,2.4,0.38,30,58,7,185,76,7.5,6.7,6.9,7.4,940000
锐蛇-M,巡飞弹,2.0,2.8,0.46,44,82,10,205,94,8.0,7.1,7.3,8.0,1360000
锐蛇-L,巡飞弹,2.4,3.2,0.55,62,120,15,230,125,8.7,7.5,7.8,8.8,2100000
鹰眼-1,巡飞弹,1.3,1.9,0.31,20,40,4,155,58,6.6,5.7,6.3,6.0,470000
鹰眼-2,巡飞弹,1.6,2.2,0.37,29,57,7,172,78,7.1,6.1,6.6,6.9,760000
鹰眼-3,巡飞弹,2.1,2.9,0.49,51,95,12,215,105,8.2,7.0,7.2,8.1,1580000
雷霆-122,火箭炮,6.9,2.4,2.8,13500,22,480,72,0,5.8,6.6,6.0,5.5,980000
雷霆-160,火箭炮,7.6,2.6,3.0,16800,40,760,68,0,6.4,6.9,6.3,6.1,1450000
雷霆-220,火箭炮,8.3,2.8,3.2,21500,70,1200,65,0,7.0,7.1,6.8,7.0,2380000
雷霆-300,火箭炮,9.8,3.0,3.4,28500,120,1850,62,0,7.8,7.4,7.2,8.0,4200000
山猫-95,火箭炮,6.2,2.3,2.7,11800,18,360,78,0,5.4,6.0,5.7,5.0,740000
山猫-120,火箭炮,6.7,2.4,2.8,13000,30,520,75,0,5.9,6.2,6.0,5.6,1050000
山猫-200,火箭炮,7.9,2.7,3.1,19800,60,980,70,0,6.8,6.8,6.5,6.7,1980000
山猫-300,火箭炮,9.3,2.9,3.3,26000,105,1600,66,0,7.6,7.2,7.0,7.8,3560000
弓兵-L,火箭炮,8.8,2.9,3.2,23500,85,1350,69,0,7.2,7.0,6.9,7.3,2860000
弓兵-X,火箭炮,10.2,3.1,3.6,31000,150,2100,60,0,8.4,7.8,7.6,8.7,5400000
长矛-1,火箭炮,7.1,2.5,2.9,14200,28,560,73,0,6.1,6.4,6.1,5.8,1180000
长矛-2,火箭炮,8.1,2.7,3.1,20500,75,1120,68,0,7.1,6.9,6.7,7.1,2420000
长矛-3,火箭炮,9.6,3.0,3.5,29200,130,1900,63,0,8.1,7.5,7.4,8.3,4650000
擎天-M,火箭炮,10.8,3.2,3.8,34800,180,2450,58,0,8.9,8.0,7.9,9.2,6900000
1 name type length_m width_m height_m weight_kg max_range_km payload_kg max_speed_kmh endurance_min tech_level scale_level supply_chain_level complexity_score actual_cost
2 隼击-A 巡飞弹 1.2 1.8 0.32 18 35 4 145 55 6.4 5.8 6.2 5.9 420000
3 隼击-B 巡飞弹 1.5 2.1 0.36 26 48 6 160 70 6.8 6.2 6.4 6.6 610000
4 隼击-C 巡飞弹 1.8 2.5 0.42 34 65 8 175 85 7.2 6.4 6.8 7.1 830000
5 侦察-100 巡飞弹 0.9 1.4 0.25 9 18 2 110 35 5.4 5.1 5.5 4.8 190000
6 侦察-200 巡飞弹 1.1 1.7 0.29 14 28 3 125 48 5.9 5.4 5.7 5.3 310000
7 侦察-300 巡飞弹 1.4 2.0 0.34 22 42 5 150 62 6.3 5.9 6.0 6.1 520000
8 锐蛇-S 巡飞弹 1.7 2.4 0.38 30 58 7 185 76 7.5 6.7 6.9 7.4 940000
9 锐蛇-M 巡飞弹 2.0 2.8 0.46 44 82 10 205 94 8.0 7.1 7.3 8.0 1360000
10 锐蛇-L 巡飞弹 2.4 3.2 0.55 62 120 15 230 125 8.7 7.5 7.8 8.8 2100000
11 鹰眼-1 巡飞弹 1.3 1.9 0.31 20 40 4 155 58 6.6 5.7 6.3 6.0 470000
12 鹰眼-2 巡飞弹 1.6 2.2 0.37 29 57 7 172 78 7.1 6.1 6.6 6.9 760000
13 鹰眼-3 巡飞弹 2.1 2.9 0.49 51 95 12 215 105 8.2 7.0 7.2 8.1 1580000
14 雷霆-122 火箭炮 6.9 2.4 2.8 13500 22 480 72 0 5.8 6.6 6.0 5.5 980000
15 雷霆-160 火箭炮 7.6 2.6 3.0 16800 40 760 68 0 6.4 6.9 6.3 6.1 1450000
16 雷霆-220 火箭炮 8.3 2.8 3.2 21500 70 1200 65 0 7.0 7.1 6.8 7.0 2380000
17 雷霆-300 火箭炮 9.8 3.0 3.4 28500 120 1850 62 0 7.8 7.4 7.2 8.0 4200000
18 山猫-95 火箭炮 6.2 2.3 2.7 11800 18 360 78 0 5.4 6.0 5.7 5.0 740000
19 山猫-120 火箭炮 6.7 2.4 2.8 13000 30 520 75 0 5.9 6.2 6.0 5.6 1050000
20 山猫-200 火箭炮 7.9 2.7 3.1 19800 60 980 70 0 6.8 6.8 6.5 6.7 1980000
21 山猫-300 火箭炮 9.3 2.9 3.3 26000 105 1600 66 0 7.6 7.2 7.0 7.8 3560000
22 弓兵-L 火箭炮 8.8 2.9 3.2 23500 85 1350 69 0 7.2 7.0 6.9 7.3 2860000
23 弓兵-X 火箭炮 10.2 3.1 3.6 31000 150 2100 60 0 8.4 7.8 7.6 8.7 5400000
24 长矛-1 火箭炮 7.1 2.5 2.9 14200 28 560 73 0 6.1 6.4 6.1 5.8 1180000
25 长矛-2 火箭炮 8.1 2.7 3.1 20500 75 1120 68 0 7.1 6.9 6.7 7.1 2420000
26 长矛-3 火箭炮 9.6 3.0 3.5 29200 130 1900 63 0 8.1 7.5 7.4 8.3 4650000
27 擎天-M 火箭炮 10.8 3.2 3.8 34800 180 2450 58 0 8.9 8.0 7.9 9.2 6900000

13
demo_standalone/README.md Normal file
View File

@ -0,0 +1,13 @@
# 机器学习算法演示
## 运行方式
1. 解压 zip 文件。
2. 双击 `start_demo.bat`
3. 浏览器会自动打开 `http://127.0.0.1:5001/algorithm-demo`
## 说明
- 演示使用 `data/demo_equipment_costs.csv`,不需要 MySQL。
- 首次运行会创建 `.venv` 并安装最小 Python 依赖。
- 需要本机已安装 Python 3.9 至 3.11。

View File

@ -0,0 +1,5 @@
flask>=3.1.0
flask-cors>=5.0.0
numpy>=1.26.0,<2.0.0
pandas>=2.2.0
scikit-learn>=1.5.2

48
demo_standalone/server.py Normal file
View File

@ -0,0 +1,48 @@
from pathlib import Path
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from demo_service import DemoModelService
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR / "frontend"
DATASET_PATH = BASE_DIR / "data" / "demo_equipment_costs.csv"
def create_app():
app = Flask(__name__, static_folder=None)
CORS(app)
@app.get("/api/demo/algorithms")
def demo_algorithms():
service = DemoModelService(DATASET_PATH)
return jsonify({"algorithms": service.get_algorithms()})
@app.get("/api/demo/dataset")
def demo_dataset():
service = DemoModelService(DATASET_PATH)
return jsonify(service.get_dataset_summary())
@app.post("/api/demo/run")
def demo_run():
payload = request.get_json(silent=True) or {}
service = DemoModelService(DATASET_PATH)
return jsonify(service.run_demo(payload.get("algorithms")))
@app.get("/")
@app.get("/<path:path>")
def frontend(path=""):
file_path = STATIC_DIR / path
if path and file_path.exists() and file_path.is_file():
return send_from_directory(STATIC_DIR, path)
return send_from_directory(STATIC_DIR, "index.html")
return app
if __name__ == "__main__":
app = create_app()
print("算法演示服务已启动http://127.0.0.1:5001/algorithm-demo")
app.run(host="127.0.0.1", port=5001, debug=False)

View File

@ -0,0 +1,32 @@
@echo off
setlocal
cd /d "%~dp0"
where python >nul 2>nul
if errorlevel 1 (
echo 未找到 Python。请先安装 Python 3.9 至 3.11,然后重新运行本脚本。
pause
exit /b 1
)
if not exist ".venv\Scripts\python.exe" (
echo 正在创建演示环境...
python -m venv .venv
if errorlevel 1 (
echo 创建环境失败。
pause
exit /b 1
)
)
echo 正在安装或检查依赖...
".venv\Scripts\python.exe" -m pip install -r requirements.txt
if errorlevel 1 (
echo 依赖安装失败,请检查网络或 Python 环境。
pause
exit /b 1
)
start "" http://127.0.0.1:5001/algorithm-demo
".venv\Scripts\python.exe" server.py
pause

View File

@ -4,8 +4,8 @@
### 1. 基础软件
- Linux 操作系统 (推荐 Ubuntu 20.04+)
- Python 3.8+ 及相关组件
- Linux 操作系统 (推荐 Ubuntu 22.04+)
- Python 3.12 及相关组件
```bash
sudo apt update
@ -23,6 +23,19 @@
nvm use 14
```
- Windows 操作系统 (推荐 Windows 10+)
- Python 3.12 及相关组件
参考:<https://www.python.org/downloads/>
- Node.js 14+ 及 npm
参考:<https://learn.microsoft.com/en-us/windows/dev-environment/javascript/nodejs-on-windows>
```bash
# 设置执行策略
set-executionpolicy remotesigned
```
### 2. 数据库
- MySQL 8.0+
@ -32,16 +45,58 @@
sudo apt install libmysqlclient-dev
```
Windows 参考:<https://dev.mysql.com/downloads/installer/>
### 3. Python包依赖
```bash
# 科学计算相关
# Windows系统下安装依赖
# 1. 创建并激活虚拟环境
python -m venv venv
.\venv\Scripts\activate
# 2. 设置pip源为国内镜像可选但推荐
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
# 3. 更新pip
python -m pip install --upgrade pip
# 4. 安装依赖包使用UTF-8编码
# PowerShell命令行
$env:PYTHONUTF8=1
pip install -r requirements.txt
# Linux系统下安装依赖
# 1. 创建并激活虚拟环境
python3 -m venv venv
source venv/bin/activate
# 2. 安装依赖包
pip install -r requirements.txt
# 解析Excel文件需要安装以下依赖
pip install pandas
pip install openpyxl
pip install xlrd
# 常见问题解决:
# 1. 如果遇到编码错误请确保使用UTF-8编码
# 2. 如果安装过程中出现权限问题,请使用管理员权限运行命令行
# 3. 如果下载速度慢,建议使用国内镜像源
# 4. 如果出现SSL证书错误可以尝试添加--trusted-host参数
pip install -r requirements.txt --trusted-host pypi.tuna.tsinghua.edu.cn
```
### 4. 科学计算相关
sudo apt install libatlas-base-dev # numpy依赖
sudo apt install libopenblas-dev # 线性代数库
sudo apt install liblapack-dev # 线性代数包
sudo apt install gfortran # Fortran编译器(scipy依赖)
# XML处理相关(用于Excel文件处理)
## XML处理相关(用于Excel文件处理)
```bash
sudo apt install libxml2-dev
sudo apt install libxslt1-dev
```

View File

@ -0,0 +1,224 @@
# 模型超参数设置
## 1. PyTorch神经网络
### 火箭炮配置
1. 输入层 -> 隐藏层1Linear(input_size -> 32) + ReLU + BatchNorm
2. 隐藏层1 -> 隐藏层2Linear(32 -> 16) + ReLU + BatchNorm
3. 隐藏层2 -> 隐藏层3Linear(16 -> 8) + ReLU + BatchNorm
4. 隐藏层3 -> 输出层Linear(8 -> 1)
```python
learning_rate = 0.0003
weight_decay = 0.001
optimizer = AdamW(
betas=(0.8, 0.9),
eps=1e-8
)
loss_function = SmoothL1Loss(beta=0.1)
scheduler = 带预热的余弦退火
gradient_clip = max_norm=0.1
```
### 巡飞弹配置
生产商特征网络2层
1. Linear(5 -> 4) + ReLU + BatchNorm + Dropout(0.2)
装备特征网络4层
1. Linear(input_size-5 -> 64) + LeakyReLU + BatchNorm + Dropout
2. Linear(64 -> 32) + LeakyReLU + BatchNorm + Dropout
3. Linear(32 -> 16) + LeakyReLU + BatchNorm + Dropout
合并网络4层
1. Linear(20 -> 32) + LeakyReLU + BatchNorm + Dropout
2. Linear(32 -> 16) + LeakyReLU + BatchNorm + Dropout
3. Linear(16 -> 8) + LeakyReLU + BatchNorm
4. Linear(8 -> 1)
```python
learning_rate = 0.001
weight_decay = 0.001
optimizer = Adam(betas=(0.9, 0.999))
loss_function = MSELoss()
scheduler = 余弦退火
```
## 2. XGBoost
```python
n_estimators = 50
learning_rate = 0.03
max_depth = 3
min_child_weight = 5
subsample = 0.6
colsample_bytree = 0.6
reg_alpha = 0.5
reg_lambda = 2.0
gamma = 1
random_state = 42
```
## 3. LightGBM
```python
n_estimators = 50
learning_rate = 0.03
max_depth = 3
num_leaves = 8
subsample = 0.6
colsample_bytree = 0.6
reg_alpha = 0.5
reg_lambda = 2.0
min_child_samples = 10
min_split_gain = 1.0
random_state = 42
```
## 4. GBM梯度提升机
```python
n_estimators = 50
learning_rate = 0.03
max_depth = 3
min_samples_split = 10
min_samples_leaf = 5
subsample = 0.6
min_impurity_decrease = 0.01
random_state = 42
```
## 5. Random Forest随机森林
```python
n_estimators = 100
max_depth = 4
min_samples_split = 5
min_samples_leaf = 3
max_features = 'sqrt'
bootstrap = True
random_state = 42
```
## 6. PLS回归
```python
n_components = min(3, 特征数量//5)
scale = True
max_iter = 500
tol = 1e-6
```
## 超参数调优策略
### 1. 样本量增加时的调整策略
#### PyTorch神经网络
- 增加网络深度和宽度
- 可以在现有层之间添加更多隐藏层
- 适当增加每层神经元数量
- 调整学习率和优化器
- 可以使用更大的学习率如0.001-0.005
- 减小weight_decay如0.0005
- 减少正则化强度
- 降低Dropout率如0.1
- 可以移除部分BatchNorm层
#### 树模型XGBoost/LightGBM/GBM
- 增加树的数量n_estimators100-500
- 增加树的深度max_depth4-6
- 减小正则化参数
- reg_alpha0.3
- reg_lambda1.0
- 增大子采样比例subsample0.8-0.9
#### Random Forest
- 增加树的数量n_estimators200-500
- 增加树的深度max_depth6-8
- 减小最小分裂样本数
- min_samples_split3
- min_samples_leaf2
#### PLS回归
- 增加组件数量n_components
- 可以考虑使用非线性核函数
### 2. 特征数量变化的调整策略
#### 特征数量增加时
- 增强特征选择和降维
- 增加正则化强度
- 考虑使用特征筛选方法
- 可以使用自动特征选择算法
#### 特征数量减少时
- 简化模型结构
- 减少正则化强度
- 增加每个特征的权重
### 3. 自动化调优建议
1. 使用网格搜索Grid Search
- 适用于参数空间较小时
- 可以详尽搜索最优参数
2. 使用随机搜索Random Search
- 适用于参数空间较大时
- 比网格搜索更高效
3. 使用贝叶斯优化
- 适用于计算资源有限时
- 能更智能地搜索参数空间
4. 交叉验证策略
- 样本量大时使用K折交叉验证K=5或10
- 样本量小时:使用留一法交叉验证
### 4. 性能监控指标
在调参过程中需要监控:
1. 训练集和验证集的损失曲线
2. 模型复杂度vs性能提升
3. 训练时间vs性能提升
4. 过拟合风险
### 5. 调优注意事项
1. 保持可解释性
- 模型复杂度增加时,确保结果仍可解释
- 记录参数调整的原因和效果
2. 计算资源平衡
- 在性能提升和计算成本间找到平衡点
- 考虑模型部署的实际环境限制
3. 稳定性要求
- 确保模型在不同数据分布下仍能稳定工作
- 定期使用新数据验证模型性能
## 参数说明
所有模型都设置了 `random_state=42` 以确保结果可重现。这些参数是经过调优的,针对小样本量的特点,采用了较为保守的设置:
- 较小的学习率:避免过拟合,提高模型稳定性
- 较浅的树深度:防止模型过于复杂
- 较强的正则化:增强模型泛化能力
- 适当的子采样比例:提高模型鲁棒性
这些参数设置主要考虑了以下因素:
1. 样本量较小
2. 特征维度适中
3. 需要较强的泛化能力
4. 预测稳定性要求高

30
docs/release_guide.md Normal file
View File

@ -0,0 +1,30 @@
# Windows 发布包制作指南
1. 准备工作
1.1 安装必要软件
- Python 3.11.8: <https://www.python.org/downloads/>
- Visual Studio Build Tools: <https://visualstudio.microsoft.com/visual-cpp-build-tools/>
1.2 下载安装程序
- Python 3.11.8: <https://www.python.org/ftp/python/3.11.8/python-3.11.8-amd64.exe>
- Visual C++ Redistributable: <https://aka.ms/vs/17/release/vc_redist.x64.exe可选如果系统已安装则不需要>
1. 克隆项目
```powershell
git clone [repository-url]
cd cost-prediction
```
2. 打包步骤
```powershell
# 运行打包脚本
.\scripts\build_win.ps1
```
打包完成后会在项目根目录生成 `cost-prediction-[version]-win64.zip`

View File

@ -0,0 +1,57 @@
# ML Algorithm Demo Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a modern demo page that compares common machine learning algorithms using a local data file instead of MySQL.
**Architecture:** Add an isolated backend demo service that reads `data/demo_equipment_costs.csv`, trains selected regressors in memory, and returns metrics, prediction points, feature importance, and a sample prediction. Add a Vue route that calls the demo API and renders algorithm switching, charts, metrics, and data preview. Existing database-backed pages remain unchanged.
**Tech Stack:** Flask, pandas, scikit-learn, optional xgboost/lightgbm, Vue 3, Element Plus, ECharts.
---
### Task 1: Backend Demo Service
**Files:**
- Create: `tests/test_demo_service.py`
- Create: `src/demo_service.py`
- Create: `data/demo_equipment_costs.csv`
- [ ] Write failing tests for data loading, algorithm availability, and training payload shape.
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it fails because `src.demo_service` is missing.
- [ ] Implement `DemoModelService` with local CSV loading, selected algorithm training, metric calculation, top feature importance, and fallback algorithms when optional libraries are unavailable.
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it passes.
### Task 2: Demo API
**Files:**
- Modify: `src/routes.py`
- Test: `tests/test_demo_routes.py`
- [ ] Write Flask route tests for `GET /api/demo/algorithms`, `GET /api/demo/dataset`, and `POST /api/demo/run`.
- [ ] Run `python -m pytest tests/test_demo_routes.py -q` and verify missing routes fail.
- [ ] Add demo routes that call `DemoModelService` and do not access MySQL.
- [ ] Run the route tests and demo service tests.
### Task 3: Vue Demo Page
**Files:**
- Create: `frontend/src/views/AlgorithmDemoPage.vue`
- Modify: `frontend/src/router/index.js`
- Modify: `frontend/src/App.vue`
- Modify: `frontend/src/api/index.js`
- Modify: `frontend/src/views/HomePage.vue`
- [ ] Add API helpers for demo algorithms, dataset, and run.
- [ ] Add `/algorithm-demo` route and navigation label `算法演示`.
- [ ] Build a modern dashboard-style page with algorithm toggles, metric cards, comparison chart, predicted-vs-actual chart, feature importance chart, sample prediction panel, and data preview table.
- [ ] Add a home page entry that links to the demo.
### Task 4: Verification
**Files:**
- No new files.
- [ ] Run `python -m pytest tests/test_demo_service.py tests/test_demo_routes.py -q`.
- [ ] Run `npm run build` in `frontend`.
- [ ] Start the app if feasible and confirm the new route is available.

73
docs/windows_setup.md Normal file
View File

@ -0,0 +1,73 @@
# Windows 开发环境设置
1. 安装必要软件
- Python 3.11.8: <https://www.python.org/downloads/>
- Git: <https://git-scm.com/download/win>
- MySQL 8.0+: <https://dev.mysql.com/downloads/mysql/>
- Visual Studio Build Tools: <https://visualstudio.microsoft.com/visual-cpp-build-tools/>
- Node.js 18+ LTS: <https://nodejs.org/download/>安装时Chocolatey不是必需的
- npm 9+: (随 Node.js 一起安装)
2. 克隆项目
```powershell
git clone [repository-url]
cd cost-prediction
```
3. 设置前端环境
```powershell
# 进入前端目录
cd frontend
# 安装依赖
npm install 22
nvm use 22
# 构建生产版本
npm run build
# 返回项目根目录
cd ..
```
4. 设置 Python 环境
```powershell
# 创建虚拟环境
python -m venv .venv
# 激活虚拟环境
.\.venv\Scripts\Activate.ps1
# 安装依赖
pip install -e .
```
5. 配置数据库
```powershell
# 确保 MySQL 服务已启动
# 初始化数据库和导入数据
```
6. 运行测试
```powershell
python src/test_api.py
```
7. 打包项目
```powershell
# 先下载所有依赖
.\scripts\download_deps.ps1
# 然后运行打包脚本
.\scripts\build_win.ps1
```
## 注意:如果需要制作发布包,请参考 docs/release_guide.md

View File

@ -1,5 +0,0 @@
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}

17
frontend/index.html Normal file
View File

@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="/favicon.ico">
<title>装备成本估算系统</title>
</head>
<body>
<noscript>
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
</noscript>
<div id="app"></div>
<script type="module" src="/src/main.js"></script>
</body>
</html>

11740
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -7,33 +7,24 @@
"npm": ">=8"
},
"scripts": {
"serve": "vue-cli-service serve",
"build": "vue-cli-service build",
"lint": "vue-cli-service lint"
"serve": "vite",
"build": "vite build",
"lint": "eslint --ext .js,.vue src/"
},
"dependencies": {
"axios": "^1.6.0",
"core-js": "^3.8.3",
"echarts": "^5.4.3",
"element-plus": "^2.4.2",
"vue": "^3.2.13",
"vue-router": "^4.0.3",
"vuex": "^4.0.0"
"vue-router": "^4.0.3"
},
"devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@element-plus/icons-vue": "^2.3.1",
"@vue/cli-plugin-babel": "~5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-plugin-router": "~5.0.0",
"@vue/cli-plugin-vuex": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"@vue/compiler-sfc": "^3.2.13",
"@vitejs/plugin-vue": "^5.0.0",
"eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3",
"sass": "^1.32.7",
"sass-loader": "^12.0.0"
"sass-embedded": "^1.99.0",
"vite": "^5.0.0"
},
"eslintConfig": {
"root": true,
@ -45,17 +36,12 @@
"eslint:recommended"
],
"parserOptions": {
"parser": "@babel/eslint-parser"
"ecmaVersion": "latest",
"sourceType": "module"
},
"rules": {
"vue/multi-word-component-names": "off",
"no-unused-vars": "warn"
}
},
"browserslist": [
"> 1%",
"last 2 versions",
"not dead",
"not ie 11"
]
}
}

View File

@ -1,17 +0,0 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title><%= htmlWebpackPlugin.options.title %></title>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>

View File

@ -9,6 +9,7 @@
<el-menu-item index="/">首页</el-menu-item>
<el-menu-item index="/predict">成本预测</el-menu-item>
<el-menu-item index="/analysis">特征分析</el-menu-item>
<el-menu-item index="/algorithm-demo">算法演示</el-menu-item>
<el-menu-item index="/training">模型训练</el-menu-item>
<el-menu-item index="/models">模型管理</el-menu-item>
<el-menu-item index="/datasets">数据集管理</el-menu-item>

View File

@ -40,4 +40,16 @@ export const updateEquipment = (id, data) => {
export const deleteEquipment = (id) => {
return api.delete(`/data/${id}`)
}
}
export const getDemoAlgorithms = () => {
return api.get('/demo/algorithms')
}
export const getDemoDataset = () => {
return api.get('/demo/dataset')
}
export const runAlgorithmDemo = (data) => {
return api.post('/demo/run', data)
}

View File

@ -1,8 +1,12 @@
export const API_BASE_URL = 'http://localhost:5001/api';
const isLocalDevServer = window.location.port === '8080'
export const API_BASE_URL = isLocalDevServer
? 'http://localhost:5001/api'
: `${window.location.origin}/api`;
export const DB_CONFIG = {
host: 'localhost',
user: 'root',
password: '123456',
database: 'equipment_cost_db'
};
};

View File

@ -1,7 +1,6 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import store from './store'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import './assets/styles/global.css'
@ -15,7 +14,6 @@ app.use(ElementPlus, {
size: 'default'
})
app.use(router)
app.use(store)
// 注册图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {

View File

@ -5,6 +5,7 @@ import DatasetPage from '@/views/DatasetPage.vue'
import PredictPage from '@/views/PredictPage.vue'
import AnalysisPage from '@/views/AnalysisPage.vue'
import TrainingPage from '@/views/TrainingPage.vue'
import AlgorithmDemoPage from '@/views/AlgorithmDemoPage.vue'
const routes = [
{
@ -37,6 +38,11 @@ const routes = [
name: 'Training',
component: TrainingPage
},
{
path: '/algorithm-demo',
name: 'AlgorithmDemo',
component: AlgorithmDemoPage
},
{
path: '/models',
name: 'Models',
@ -49,4 +55,4 @@ const router = createRouter({
routes
})
export default router
export default router

View File

@ -1,14 +0,0 @@
import { createStore } from 'vuex'
export default createStore({
state: {
},
getters: {
},
mutations: {
},
actions: {
},
modules: {
}
})

View File

@ -0,0 +1,644 @@
<template>
<div class="algorithm-demo-page">
<section class="demo-hero">
<div>
<p class="eyebrow">本地文件算法演示</p>
<h1>机器学习算法演示</h1>
<p class="hero-copy">
使用本地数据文件快速训练和比较常用回归算法适合客户演示部署
</p>
</div>
<div class="hero-actions">
<el-button type="primary" :loading="loading" @click="runDemo">
<el-icon><VideoPlay /></el-icon>
运行演示
</el-button>
<el-tag effect="plain" type="success">无需 MySQL</el-tag>
</div>
</section>
<section class="control-band">
<div class="panel algorithm-panel">
<div class="panel-header">
<div>
<p class="eyebrow">算法选择</p>
<h2>选择算法</h2>
</div>
<el-button text type="primary" @click="selectRecommended">推荐组合</el-button>
</div>
<el-checkbox-group v-model="selectedAlgorithms" class="algorithm-grid">
<el-checkbox
v-for="item in algorithms"
:key="item.key"
:value="item.key"
border
>
<span class="algorithm-name">{{ item.name }}</span>
<small>{{ item.english_name }} · {{ item.family }}</small>
</el-checkbox>
</el-checkbox-group>
<el-alert
v-if="warnings.length"
class="warning-strip"
type="warning"
:closable="false"
show-icon
>
<template #title>{{ warnings.join(' ') }}</template>
</el-alert>
</div>
<div class="panel dataset-panel">
<div class="panel-header">
<div>
<p class="eyebrow">数据来源</p>
<h2>本地演示数据</h2>
</div>
<el-tag>{{ dataset.row_count || 0 }} </el-tag>
</div>
<div class="dataset-stats">
<div>
<strong>{{ dataset.features?.length || 0 }}</strong>
<span>特征数</span>
</div>
<div>
<strong>{{ dataset.equipment_types?.length || 0 }}</strong>
<span>装备类型</span>
</div>
<div>
<strong>{{ dataset.target_label || '-' }}</strong>
<span>预测目标</span>
</div>
</div>
</div>
</section>
<section v-if="result" class="metrics-grid">
<article
v-for="row in metricRows"
:key="row.key"
class="metric-card"
:class="{ active: row.key === result.best_model }"
@click="activeAlgorithm = row.key"
>
<div class="metric-title">
<span>{{ row.name }}</span>
<el-tag v-if="row.key === result.best_model" size="small" type="success">最佳</el-tag>
</div>
<strong>{{ formatScore(row.r2) }}</strong>
<div class="metric-values">
<span>平均绝对误差 {{ formatMoney(row.mae) }}</span>
<span>均方根误差 {{ formatMoney(row.rmse) }}</span>
</div>
</article>
</section>
<section v-if="result" class="visual-grid">
<div class="panel chart-panel wide">
<div class="panel-header">
<div>
<p class="eyebrow">效果对比</p>
<h2>模型指标对比</h2>
</div>
</div>
<div ref="metricsChartRef" class="chart"></div>
</div>
<div class="panel chart-panel wide">
<div class="panel-header">
<div>
<p class="eyebrow">预测结果</p>
<h2>预测值与真实值</h2>
</div>
<el-select v-model="activeAlgorithm" size="small" class="algorithm-select">
<el-option
v-for="row in metricRows"
:key="row.key"
:label="row.name"
:value="row.key"
/>
</el-select>
</div>
<div ref="predictionChartRef" class="chart"></div>
</div>
<div class="panel chart-panel">
<div class="panel-header">
<div>
<p class="eyebrow">模型解释</p>
<h2>特征重要性</h2>
</div>
</div>
<div ref="importanceChartRef" class="chart compact"></div>
</div>
<div class="panel sample-panel">
<div class="panel-header">
<div>
<p class="eyebrow">样例场景</p>
<h2>样例装备预测</h2>
</div>
</div>
<dl>
<dt>装备名称</dt>
<dd>{{ result.sample_prediction.input.name }}</dd>
<dt>真实成本</dt>
<dd>{{ formatMoney(result.sample_prediction.actual) }}</dd>
<dt>当前算法预测</dt>
<dd>{{ formatMoney(result.sample_prediction.predictions[activeAlgorithm]) }}</dd>
</dl>
</div>
</section>
<section class="panel data-preview">
<div class="panel-header">
<div>
<p class="eyebrow">数据预览</p>
<h2>数据文件预览</h2>
</div>
</div>
<el-table :data="dataset.preview || []" height="320" stripe>
<el-table-column prop="name" label="名称" min-width="130" fixed />
<el-table-column prop="type" label="类型" min-width="150" />
<el-table-column prop="weight_kg" label="重量(kg)" min-width="100" />
<el-table-column prop="max_range_km" label="射程(km)" min-width="100" />
<el-table-column prop="tech_level" label="技术水平" min-width="100" />
<el-table-column prop="complexity_score" label="复杂度" min-width="100" />
<el-table-column prop="actual_cost" label="实际成本" min-width="130">
<template #default="scope">{{ formatMoney(scope.row.actual_cost) }}</template>
</el-table-column>
</el-table>
</section>
</div>
</template>
<script setup>
import { computed, nextTick, onMounted, onUnmounted, ref, watch } from 'vue'
import { ElMessage } from 'element-plus'
import { VideoPlay } from '@element-plus/icons-vue'
import * as echarts from 'echarts'
import { getDemoAlgorithms, getDemoDataset, runAlgorithmDemo } from '@/api'
const algorithms = ref([])
const dataset = ref({})
const selectedAlgorithms = ref(['linear', 'ridge', 'random_forest', 'gradient_boosting'])
const activeAlgorithm = ref('random_forest')
const result = ref(null)
const loading = ref(false)
const warnings = ref([])
const metricsChartRef = ref(null)
const predictionChartRef = ref(null)
const importanceChartRef = ref(null)
const charts = []
const metricRows = computed(() => {
if (!result.value?.metrics) return []
return Object.entries(result.value.metrics).map(([key, value]) => ({
key,
...value
}))
})
const activeMetric = computed(() => {
return metricRows.value.find((row) => row.key === activeAlgorithm.value) || metricRows.value[0]
})
const selectRecommended = () => {
selectedAlgorithms.value = ['linear', 'ridge', 'random_forest', 'gradient_boosting']
}
const loadInitialData = async () => {
try {
const [algorithmResponse, datasetResponse] = await Promise.all([
getDemoAlgorithms(),
getDemoDataset()
])
algorithms.value = algorithmResponse.data.algorithms
dataset.value = datasetResponse.data
await runDemo()
} catch (error) {
ElMessage.error('加载演示数据失败')
console.error(error)
}
}
const runDemo = async () => {
if (!selectedAlgorithms.value.length) {
ElMessage.warning('请至少选择一个算法')
return
}
loading.value = true
try {
const response = await runAlgorithmDemo({ algorithms: selectedAlgorithms.value })
result.value = response.data
dataset.value = response.data.dataset
warnings.value = response.data.warnings || []
activeAlgorithm.value = response.data.best_model
await nextTick()
renderCharts()
} catch (error) {
ElMessage.error(error.response?.data?.error || '运行演示失败')
console.error(error)
} finally {
loading.value = false
}
}
const disposeCharts = () => {
while (charts.length) {
const chart = charts.pop()
if (chart && !chart.isDisposed()) chart.dispose()
}
}
const renderCharts = () => {
if (!result.value) return
disposeCharts()
renderMetricsChart()
renderPredictionChart()
renderImportanceChart()
}
const renderMetricsChart = () => {
if (!metricsChartRef.value) return
const chart = echarts.init(metricsChartRef.value)
charts.push(chart)
chart.setOption({
tooltip: { trigger: 'axis' },
legend: { top: 0 },
grid: { top: 48, left: 56, right: 24, bottom: 36 },
xAxis: { type: 'category', data: metricRows.value.map((row) => row.name) },
yAxis: [
{ type: 'value', name: '决定系数', min: 0 },
{ type: 'value', name: '误差' }
],
series: [
{
name: '决定系数',
type: 'bar',
data: metricRows.value.map((row) => Number(row.r2.toFixed(4))),
itemStyle: { color: '#2f6fdd' }
},
{
name: '平均绝对误差',
type: 'line',
yAxisIndex: 1,
data: metricRows.value.map((row) => Math.round(row.mae)),
itemStyle: { color: '#16a085' }
},
{
name: '均方根误差',
type: 'line',
yAxisIndex: 1,
data: metricRows.value.map((row) => Math.round(row.rmse)),
itemStyle: { color: '#d98b18' }
}
]
})
}
const renderPredictionChart = () => {
if (!predictionChartRef.value || !activeMetric.value) return
const chart = echarts.init(predictionChartRef.value)
charts.push(chart)
const points = result.value.prediction_points
chart.setOption({
tooltip: { trigger: 'axis' },
legend: { top: 0 },
grid: { top: 48, left: 68, right: 24, bottom: 46 },
xAxis: {
type: 'category',
data: points.map((point) => point.name),
axisLabel: { rotate: 25 }
},
yAxis: { type: 'value', name: '成本' },
series: [
{
name: '真实值',
type: 'line',
smooth: true,
data: points.map((point) => point.actual),
itemStyle: { color: '#202938' }
},
{
name: activeMetric.value.name,
type: 'bar',
data: points.map((point) => point[activeAlgorithm.value]),
itemStyle: { color: '#2f6fdd' }
}
]
})
}
const renderImportanceChart = () => {
if (!importanceChartRef.value || !activeAlgorithm.value) return
const chart = echarts.init(importanceChartRef.value)
charts.push(chart)
const rows = [...(result.value.feature_importance[activeAlgorithm.value] || [])].reverse()
chart.setOption({
tooltip: { trigger: 'axis' },
grid: { top: 20, left: 108, right: 20, bottom: 24 },
xAxis: { type: 'value' },
yAxis: { type: 'category', data: rows.map((row) => featureName(row.feature)) },
series: [
{
type: 'bar',
data: rows.map((row) => Number(row.importance.toFixed(4))),
itemStyle: { color: '#16a085' }
}
]
})
}
const featureName = (key) => {
const names = {
length_m: '长度',
width_m: '宽度',
height_m: '高度',
weight_kg: '重量',
max_range_km: '最大射程',
payload_kg: '载荷',
max_speed_kmh: '最大速度',
endurance_min: '续航',
tech_level: '技术水平',
scale_level: '规模水平',
supply_chain_level: '供应链',
complexity_score: '复杂度'
}
return names[key] || key
}
const formatMoney = (value) => {
if (value === undefined || value === null) return '-'
return Number(value).toLocaleString('zh-CN', {
style: 'currency',
currency: 'CNY',
maximumFractionDigits: 0
})
}
const formatScore = (value) => {
if (value === undefined || value === null) return '-'
return Number(value).toFixed(3)
}
watch(activeAlgorithm, async () => {
await nextTick()
renderCharts()
})
window.addEventListener('resize', () => {
charts.forEach((chart) => {
if (chart && !chart.isDisposed()) chart.resize()
})
})
onMounted(loadInitialData)
onUnmounted(disposeCharts)
</script>
<style lang="scss" scoped>
.algorithm-demo-page {
min-height: calc(100vh - 60px);
padding: 24px;
color: #202938;
background:
linear-gradient(180deg, #eef3f8 0%, #f7f9fb 280px),
#f7f9fb;
}
.demo-hero,
.control-band,
.metrics-grid,
.visual-grid,
.data-preview {
max-width: 1440px;
margin: 0 auto 18px;
}
.demo-hero {
display: flex;
align-items: center;
justify-content: space-between;
gap: 20px;
min-height: 176px;
h1 {
margin: 6px 0 10px;
font-size: 36px;
line-height: 1.2;
letter-spacing: 0;
}
}
.hero-copy {
max-width: 680px;
margin: 0;
color: #536273;
font-size: 16px;
line-height: 1.7;
}
.hero-actions {
display: flex;
align-items: center;
gap: 12px;
}
.eyebrow {
margin: 0;
color: #2f6fdd;
font-size: 12px;
font-weight: 700;
letter-spacing: 0;
text-transform: uppercase;
}
.control-band,
.visual-grid {
display: grid;
grid-template-columns: minmax(0, 1.35fr) minmax(320px, 0.65fr);
gap: 16px;
}
.panel,
.metric-card {
border: 1px solid #dfe6ef;
border-radius: 8px;
background: #fff;
box-shadow: 0 10px 28px rgba(32, 41, 56, 0.06);
}
.panel {
padding: 18px;
}
.panel-header,
.metric-title {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
h2 {
margin: 4px 0 0;
font-size: 18px;
line-height: 1.3;
letter-spacing: 0;
}
}
.algorithm-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
gap: 10px;
margin-top: 16px;
:deep(.el-checkbox) {
width: 100%;
height: 64px;
margin: 0;
border-radius: 8px;
}
:deep(.el-checkbox__label) {
display: flex;
flex-direction: column;
gap: 4px;
line-height: 1.2;
}
}
.algorithm-name {
font-weight: 700;
}
.algorithm-grid small {
color: #6b7786;
}
.warning-strip {
margin-top: 14px;
}
.dataset-stats {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 10px;
margin-top: 18px;
div {
padding: 14px;
border-radius: 8px;
background: #f2f6fa;
}
strong {
display: block;
margin-bottom: 6px;
font-size: 20px;
}
span {
color: #667485;
font-size: 13px;
}
}
.metrics-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
gap: 12px;
}
.metric-card {
padding: 16px;
cursor: pointer;
transition: border-color 0.2s ease, transform 0.2s ease;
&.active,
&:hover {
border-color: #2f6fdd;
transform: translateY(-2px);
}
strong {
display: block;
margin: 14px 0;
font-size: 30px;
letter-spacing: 0;
}
}
.metric-values {
display: flex;
flex-wrap: wrap;
gap: 8px;
color: #617080;
font-size: 13px;
}
.chart-panel.wide {
grid-column: span 2;
}
.chart {
width: 100%;
height: 360px;
&.compact {
height: 330px;
}
}
.algorithm-select {
width: 220px;
}
.sample-panel dl {
display: grid;
grid-template-columns: 110px minmax(0, 1fr);
gap: 14px 10px;
margin: 20px 0 0;
}
.sample-panel dt {
color: #667485;
}
.sample-panel dd {
margin: 0;
font-weight: 700;
}
@media (max-width: 900px) {
.algorithm-demo-page {
padding: 16px;
}
.demo-hero,
.control-band,
.visual-grid {
grid-template-columns: 1fr;
}
.demo-hero {
flex-direction: column;
align-items: flex-start;
h1 {
font-size: 28px;
}
}
.chart-panel.wide {
grid-column: span 1;
}
.dataset-stats {
grid-template-columns: 1fr;
}
}
</style>

View File

@ -67,12 +67,18 @@
<h3>特征重要性</h3>
<div class="chart-container">
<div ref="importanceChartRef" style="width: 100%; height: 600px"></div>
<div class="chart-note">
<p>说明F-特征分数是基于F统计量的特征重要性度量用于评估各个特征与预测目标之间的相关程度F分数越高表示该特征与预测目标之间的相关性越强但不一定是线性关系F分数没有固定的上限其值取决于数据的特征分布和样本量</p>
</div>
</div>
<!-- 相关性分析 -->
<h3>相关性分析</h3>
<div class="chart-container">
<div ref="correlationChartRef" style="width: 100%; height: 800px"></div>
<div class="chart-note">
<p>说明热力图展示了各特征之间的相关系数范围从-1到1正值蓝色表示正相关负值红色表示负相关0白色表示无相关性相关系数的绝对值越接近1表示相关性越强越接近0表示相关性越弱</p>
</div>
</div>
<!-- 火箭炮特有的图表 -->
@ -103,7 +109,31 @@
<div class="chart-container">
<div ref="engineChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 制导性能分析 -->
<h3>制导性能分析</h3>
<div class="chart-container">
<div ref="guidanceChartRef" style="width: 100%; height: 600px"></div>
</div>
</template>
<!-- 生产商分析 -->
<h3>生产商分析</h3>
<div class="chart-container">
<div ref="manufacturerChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 生产商地区分布 -->
<h3>生产商地区分布</h3>
<div class="chart-container">
<div ref="regionChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 生产商综合评分 -->
<h3>生产商综合评分</h3>
<div class="chart-container">
<div ref="scoreChartRef" style="width: 100%; height: 600px"></div>
</div>
</div>
</el-card>
</div>
@ -131,6 +161,10 @@ const newFeatureChartRef = ref(null)
const engineChartRef = ref(null)
const fireChartRef = ref(null)
const mobilityChartRef = ref(null)
const manufacturerChartRef = ref(null)
const regionChartRef = ref(null)
const scoreChartRef = ref(null)
const guidanceChartRef = ref(null)
//
const importanceChart = ref(null)
@ -139,6 +173,10 @@ const newFeatureChart = ref(null)
const engineChart = ref(null)
const fireChart = ref(null)
const mobilityChart = ref(null)
const manufacturerChart = ref(null)
const regionChart = ref(null)
const scoreChart = ref(null)
const guidanceChart = ref(null)
//
watch(() => analysisResult.value, async (newResult) => {
@ -236,69 +274,28 @@ const startAnalysis = async () => {
analyzing.value = true
try {
//
console.log('Analysis request params:', {
dataset_id: analysisForm.value.dataset_id,
equipment_type: analysisForm.value.equipment_type
})
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
//
const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
//
console.log('Raw API response:', response)
console.log('Response data type:', typeof response.data)
console.log('Response data:', response.data)
//
if (!response.data) {
throw new Error('API返回的数据为空')
}
//
analysisResult.value = response.data
//
console.log('Analysis result after assignment:', {
value: analysisResult.value,
important_features: analysisResult.value?.important_features,
correlation_analysis: analysisResult.value?.correlation_analysis,
equipment_names: analysisResult.value?.equipment_names,
length_width_ratio: analysisResult.value?.length_width_ratio
//
const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, {
dataset_id: analysisForm.value.dataset_id
})
//
if (analysisForm.value.equipment_type === '巡飞弹') {
const missileData = {
equipment_names: analysisResult.value?.equipment_names || [],
length_width_ratio: analysisResult.value?.length_width_ratio || [],
engine_power_kw: analysisResult.value?.engine_power_kw || [],
guidance_system_score: analysisResult.value?.guidance_system_score || [],
warhead_power_score: analysisResult.value?.warhead_power_score || []
}
console.log('Missile specific data:', missileData)
//
const missingFields = Object.entries(missileData)
.filter(([key, value]) => !Array.isArray(value) || value.length === 0)
.map(([key]) => key)
if (missingFields.length > 0) {
console.warn('Missing or empty missile data fields:', missingFields)
ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
}
//
analysisResult.value = {
...featureResponse.data,
...manufacturerResponse.data
}
//
console.log('Combined analysis result:', analysisResult.value)
} catch (error) {
console.error('Analysis error:', error)
console.error('Error details:', {
message: error.message,
response: error.response?.data,
status: error.response?.status
})
ElMessage.error(error.message || '特征析失败')
ElMessage.error(error.message || '分析失败')
} finally {
analyzing.value = false
}
@ -329,6 +326,18 @@ const createResizeHandler = () => {
if (mobilityChart.value && !mobilityChart.value.isDisposed()) {
mobilityChart.value.resize()
}
if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) {
manufacturerChart.value.resize()
}
if (regionChart.value && !regionChart.value.isDisposed()) {
regionChart.value.resize()
}
if (scoreChart.value && !scoreChart.value.isDisposed()) {
scoreChart.value.resize()
}
if (guidanceChart.value && !guidanceChart.value.isDisposed()) {
guidanceChart.value.resize()
}
} catch (error) {
console.error('Error in resize handler:', error)
}
@ -360,7 +369,7 @@ onUnmounted(() => {
//
[importanceChart, correlationChart, newFeatureChart, engineChart,
fireChart, mobilityChart].forEach(chart => {
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
try {
chart.value.dispose()
@ -384,7 +393,7 @@ const renderCharts = () => {
try {
//
[importanceChart, correlationChart, newFeatureChart, engineChart,
fireChart, mobilityChart].forEach(chart => {
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
chart.value.dispose()
chart.value = null
@ -401,20 +410,27 @@ const renderCharts = () => {
//
const importanceOption = {
title: { text: '特征重要性排序' },
title: {
text: '特征重要性排序',
left: 'center'
},
tooltip: {
trigger: 'axis',
trigger: 'item',
axisPointer: {
type: 'shadow'
},
formatter: function(params) {
const data = params[0]
return `${data.name}: ${data.value.toFixed(4)}`
return `${params.name}: ${params.value.toFixed(2)}`
}
},
xAxis: {
type: 'value',
name: '重要性得分'
name: 'F-特征分数',
axisLabel: {
formatter: function(value) {
return value.toFixed(1)
}
}
},
yAxis: {
type: 'category',
@ -899,6 +915,156 @@ const renderCharts = () => {
mobilityChart.value.setOption(mobilityOption, { notMerge: true })
}
//
if (manufacturerChartRef.value) {
manufacturerChart.value = echarts.init(manufacturerChartRef.value)
const manufacturerOption = {
title: { text: '生产商特征影响分析' },
tooltip: {
trigger: 'axis',
axisPointer: { type: 'shadow' }
},
legend: {
data: ['技术水平', '规模水平', '供应链水平', '综合得分']
},
xAxis: {
type: 'category',
data: analysisResult.value.manufacturer_names || []
},
yAxis: {
type: 'value',
name: '评分',
min: 0,
max: 10
},
series: [
{
name: '技术水平',
type: 'bar',
data: analysisResult.value.manufacturer_tech_levels || []
},
{
name: '规模水平',
type: 'bar',
data: analysisResult.value.manufacturer_scale_levels || []
},
{
name: '供应链水平',
type: 'bar',
data: analysisResult.value.manufacturer_supply_chain_levels || []
},
{
name: '综合得分',
type: 'line',
data: analysisResult.value.manufacturer_composite_scores || []
}
]
}
manufacturerChart.value.setOption(manufacturerOption)
}
//
if (regionChartRef.value) {
regionChart.value = echarts.init(regionChartRef.value)
const regionOption = {
title: { text: '生产商地区分布' },
tooltip: {
trigger: 'item',
formatter: '{b}: {c} ({d}%)'
},
series: [
{
type: 'pie',
radius: '65%',
data: analysisResult.value.region_distribution || [],
emphasis: {
itemStyle: {
shadowBlur: 10,
shadowOffsetX: 0,
shadowColor: 'rgba(0, 0, 0, 0.5)'
}
}
}
]
}
regionChart.value.setOption(regionOption)
}
//
if (scoreChartRef.value) {
scoreChart.value = echarts.init(scoreChartRef.value)
const scoreOption = {
title: { text: '生产商综合评分雷达图' },
tooltip: {},
radar: {
indicator: [
{ name: '技术水平', max: 10 },
{ name: '规模水平', max: 10 },
{ name: '供应链水平', max: 10 },
{ name: '区域系数', max: 1.5 },
{ name: '综合得分', max: 10 }
]
},
series: [
{
type: 'radar',
data: analysisResult.value.manufacturer_scores || []
}
]
}
scoreChart.value.setOption(scoreOption)
}
//
if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') {
guidanceChart.value = echarts.init(guidanceChartRef.value)
const guidanceOption = {
title: { text: '制导性能分析' },
tooltip: {
trigger: 'axis',
axisPointer: { type: 'cross' }
},
legend: {
data: ['制导精度(m)', '数据链距离(km)', '制导系统评分']
},
xAxis: {
type: 'category',
data: analysisResult.value.equipment_names || []
},
yAxis: [
{
type: 'value',
name: '制导精度(m)',
position: 'left'
},
{
type: 'value',
name: '距离(km)',
position: 'right'
}
],
series: [
{
name: '制导精度(m)',
type: 'bar',
data: analysisResult.value.guidance_accuracy_m || []
},
{
name: '数据链距离(km)',
type: 'line',
yAxisIndex: 1,
data: analysisResult.value.datalink_range_km || []
},
{
name: '制导系统评分',
type: 'line',
data: analysisResult.value.guidance_system_score || []
}
]
}
guidanceChart.value.setOption(guidanceOption)
}
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error in chart rendering:', error)
@ -954,4 +1120,12 @@ function debounce(fn, delay) {
}
}
}
.chart-note {
margin-top: 10px;
padding: 10px;
color: #666;
font-size: 14px;
line-height: 1.5;
}
</style>

View File

@ -795,38 +795,6 @@ onMounted(() => {
loadData()
})
//
const isNumberInput = (key) => {
const numberFields = [
'length_m',
'width_m',
'height_m',
'weight_kg',
'max_range_km',
'firing_angle_horizontal',
'firing_angle_vertical',
'rocket_length_m',
'rocket_diameter_mm',
'rocket_weight_kg',
'rate_of_fire',
'combat_weight_kg',
'speed_kmh',
'min_range_km',
'power_hp',
'travel_range_km',
'max_speed_ms',
'cruise_speed_kmh',
'flight_time_min',
'folded_length_mm',
'folded_width_mm',
'folded_height_mm',
'actual_cost',
'predicted_cost'
]
return numberFields.includes(key)
}
//
const getSelectOptions = (field) => {
switch (field) {

View File

@ -91,6 +91,8 @@
<el-select v-model="datasetForm.purpose">
<el-option label="训练" value="训练"></el-option>
<el-option label="验证" value="验证"></el-option>
<el-option label="分析" value="分析"></el-option>
<el-option label="测试" value="测试"></el-option>
</el-select>
</el-form-item>
<el-form-item label="描述">
@ -100,6 +102,7 @@
<!-- 选择装备数据 -->
<el-form-item label="选择装备" required>
<el-table
v-if="datasetForm.equipment_type"
ref="equipmentTable"
:data="availableEquipment"
border
@ -115,6 +118,9 @@
</template>
</el-table-column>
</el-table>
<div v-else class="empty-tip">
请先选择装备类型
</div>
</el-form-item>
</el-form>
<template #footer>
@ -139,7 +145,8 @@ const selectedDataset = ref(null) // 选中的数据集
const detailsVisible = ref(false) //
const editVisible = ref(false) //
const availableEquipment = ref([]) //
const selectedEquipment = ref([]) //
const selectedEquipment = ref([]) //
const currentSelection = ref([]) //
//
const equipmentTable = ref(null)
@ -182,6 +189,7 @@ const editDataset = async (dataset) => {
try {
//
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
console.log('Dataset details:', response.data) //
//
datasetForm.value = {
@ -192,21 +200,34 @@ const editDataset = async (dataset) => {
description: response.data.description
}
//
// - 使
selectedEquipment.value = response.data.equipment
availableEquipment.value = response.data.equipment
console.log('Selected equipment:', selectedEquipment.value) //
//
editVisible.value = true
//
await nextTick()
//
await loadAvailableEquipment()
//
await nextTick()
//
nextTick(() => {
if (equipmentTable.value) {
equipmentTable.value.clearSelection()
availableEquipment.value.forEach(item => {
if (equipmentTable.value) {
console.log('Setting table selections')
equipmentTable.value.clearSelection()
availableEquipment.value.forEach(item => {
if (selectedEquipment.value.find(e => e.equipment_id === item.equipment_id)) {
equipmentTable.value.toggleRowSelection(item, true)
})
}
})
editVisible.value = true
}
})
} else {
console.warn('Equipment table not ready')
}
} catch (error) {
console.error('Error getting dataset details:', error)
@ -245,34 +266,56 @@ const deleteDataset = async (dataset) => {
//
const loadAvailableEquipment = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/data`)
availableEquipment.value = datasetForm.value.equipment_type === '火箭炮'
? response.data.rocket_artillery
: response.data.loitering_munition
const response = await axios.get(`${API_BASE_URL}/data`)
//
availableEquipment.value = response.data.filter(item =>
item.type === datasetForm.value.equipment_type
)
} catch (error) {
console.error('Error loading equipment:', error) //
ElMessage.error('获取装备列表失败')
}
}
//
const handleEquipmentTypeChange = () => {
console.log('Equipment type changed:', datasetForm.value.equipment_type) //
selectedEquipment.value = [] //
loadAvailableEquipment() //
}
//
const handleSelectionChange = (selection) => {
selectedEquipment.value = selection
//
currentSelection.value = selection
}
//
const saveDataset = async () => {
try {
//
if (!datasetForm.value.name || !datasetForm.value.equipment_type || !datasetForm.value.purpose) {
ElMessage.warning('请填写必要信息')
return
}
// 使
selectedEquipment.value = currentSelection.value
//
if (!selectedEquipment.value.length) {
ElMessage.warning('请选择装备')
return
}
//
const data = {
...datasetForm.value,
equipment_ids: selectedEquipment.value.map(item => item.id)
equipment_ids: selectedEquipment.value.map(item => item.equipment_id) // 使 equipment_id
}
console.log('Saving dataset:', data) //
if (data.id) {
await axios.put(`${API_BASE_URL}/datasets/${data.id}`, data)
} else {
@ -283,6 +326,7 @@ const saveDataset = async () => {
editVisible.value = false
loadDatasets()
} catch (error) {
console.error('Error saving dataset:', error) //
ElMessage.error('保存失败')
}
}
@ -315,7 +359,8 @@ const formatDateTime = (value) => {
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
hour12: false,
timeZone: 'Asia/Shanghai'
})
}

View File

@ -26,6 +26,13 @@
<p>训练和优化预测模型</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/algorithm-demo')">
<el-icon><TrendCharts /></el-icon>
<h3>算法演示</h3>
<p>切换常用机器学习算法并对比预测效果</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/models')">
<el-icon><Management /></el-icon>
@ -53,7 +60,7 @@
</template>
<script setup>
import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-plus/icons-vue'
import { Money, DataAnalysis, Monitor, Management, Collection, TrendCharts } from '@element-plus/icons-vue'
</script>
<style lang="scss" scoped>
@ -98,4 +105,4 @@ import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-p
}
}
}
</style>
</style>

View File

@ -31,7 +31,7 @@
{{ scope.row.rmse.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="training_date" label="训练时间">
<el-table-column prop="training_date" label="训练时间" width="180">
<template #default="scope">
{{ formatDateTime(scope.row.training_date) }}
</template>
@ -211,10 +211,12 @@ const renderImportanceChart = () => {
//
const formatModelType = (type) => {
const typeMap = {
'pytorch': 'PyTorch',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbdt': 'GBDT',
'rf': 'Random Forest'
'gbm': 'GBM',
'rf': 'Random Forest',
'pls': 'PLS回归'
}
return typeMap[type] || type
}
@ -230,7 +232,8 @@ const formatDateTime = (value) => {
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
hour12: false,
timeZone: 'Asia/Shanghai'
})
}

View File

@ -223,72 +223,46 @@ import { API_BASE_URL } from '@/config'
const formData = reactive({
type: '',
length_m: null,
width_m: null,
height_m: null,
weight_kg: null,
max_range_km: null,
//
firing_angle_horizontal: null,
firing_angle_vertical: null,
rocket_length_m: null,
rocket_diameter_mm: null,
rocket_weight_kg: null,
rate_of_fire: null,
combat_weight_kg: null,
speed_kmh: null,
min_range_km: null,
mobility_type: '',
structure_layout: '',
engine_model: '',
engine_params: '',
power_hp: null,
travel_range_km: null,
// -
max_payload_kg: null, //
ceiling_altitude_m: null, //
combat_radius_km: null, //
engine_power_kw: null, //
engine_thrust_n: null, //
datalink_range_km: null, //
guidance_accuracy_m: null, //
min_altitude_m: null, //
max_altitude_m: null, //
//
length_width_ratio: null, //
weight_range_ratio: null, // /
speed_weight_ratio: null, // /
guidance_system_score: null, //
warhead_power_score: null //
length_m: 7.35,
width_m: 2.4,
height_m: 3.1,
weight_kg: 13700,
max_range_km: 20.4,
firing_angle_horizontal: 102,
firing_angle_vertical: 55,
rocket_length_m: 2.87,
rocket_diameter_mm: 122,
rocket_weight_kg: 66.6,
rate_of_fire: 40,
combat_weight_kg: 15000,
speed_kmh: 60,
min_range_km: 5,
mobility_type: '轮式',
structure_layout: '6x6轮式底盘',
engine_model: 'WD615',
engine_params: '6缸直列柴油机',
power_hp: 280,
travel_range_km: 600,
wingspan_m: 2.5,
warhead_weight_kg: 20,
max_speed_ms: 200,
cruise_speed_kmh: 720,
endurance_min: 30,
warhead_type: '破片杀伤战斗部',
launch_mode: '箱式发射',
power_system: '电动机',
guidance_system: 'GPS/INS/光电',
max_payload_kg: 25,
ceiling_altitude_m: 5000,
combat_radius_km: 100,
datalink_range_km: 150,
guidance_accuracy_m: 3
})
const predictionResults = ref(null)
const mlPrediction = ref(null)
const plsPrediction = ref(null)
const handleTypeChange = () => {
//
if (formData.type === '火箭炮') {
formData.firing_angle_horizontal = null
formData.firing_angle_vertical = null
formData.rocket_length_m = null
formData.rocket_diameter_mm = null
formData.rocket_weight_kg = null
formData.rate_of_fire = null
} else if (formData.type === '巡飞弹') {
formData.wingspan_m = null
formData.warhead_weight_kg = null
formData.max_speed_ms = null
formData.cruise_speed_kmh = null
formData.endurance_min = null
formData.warhead_type = ''
formData.launch_mode = ''
formData.power_system = ''
formData.guidance_system = ''
}
}
const submitForm = async () => {
try {
//
@ -327,7 +301,7 @@ const submitForm = async () => {
}
}
//
//
const [mlResponse, plsResponse] = await Promise.all([
axios.post(`${API_BASE_URL}/predict`, formData),
axios.post(`${API_BASE_URL}/pls/predict`, formData)
@ -396,9 +370,96 @@ const getModelName = (modelType) => {
}
return modelNames[modelType] || modelType
}
const handleTypeChange = () => {
//
predictionResults.value = false
mlPrediction.value = null
plsPrediction.value = null
//
if (formData.type === '火箭炮') {
//
formData.length_m = 7.35
formData.width_m = 2.4
formData.height_m = 3.1
formData.weight_kg = 13700
formData.max_range_km = 20.4
formData.firing_angle_horizontal = 102
formData.firing_angle_vertical = 55
formData.rocket_length_m = 2.87
formData.rocket_diameter_mm = 122
formData.rocket_weight_kg = 66.6
formData.rate_of_fire = 40
formData.combat_weight_kg = 15000
formData.speed_kmh = 60
formData.min_range_km = 5
formData.mobility_type = '轮式'
formData.structure_layout = '6x6轮式底盘'
formData.engine_model = 'WD615'
formData.engine_params = '6缸直列柴油机'
formData.power_hp = 280
formData.travel_range_km = 600
//
formData.wingspan_m = null
formData.warhead_weight_kg = null
formData.max_speed_ms = null
formData.cruise_speed_kmh = null
formData.endurance_min = null
formData.warhead_type = ''
formData.launch_mode = ''
formData.power_system = ''
formData.guidance_system = ''
formData.max_payload_kg = null
formData.ceiling_altitude_m = null
formData.combat_radius_km = null
formData.datalink_range_km = null
formData.guidance_accuracy_m = null
} else if (formData.type === '巡飞弹') {
//
formData.length_m = 2.5
formData.width_m = 0.4
formData.height_m = 0.4
formData.weight_kg = 120
formData.max_range_km = 100
formData.wingspan_m = 2.5
formData.warhead_weight_kg = 20
formData.max_speed_ms = 200
formData.cruise_speed_kmh = 720
formData.endurance_min = 30
formData.warhead_type = '破片杀伤战斗部'
formData.launch_mode = '箱式发射'
formData.power_system = '电动机'
formData.guidance_system = 'GPS/INS/光电'
formData.max_payload_kg = 25
formData.ceiling_altitude_m = 5000
formData.combat_radius_km = 100
formData.datalink_range_km = 150
formData.guidance_accuracy_m = 3
//
formData.firing_angle_horizontal = null
formData.firing_angle_vertical = null
formData.rocket_length_m = null
formData.rocket_diameter_mm = null
formData.rocket_weight_kg = null
formData.rate_of_fire = null
formData.combat_weight_kg = null
formData.speed_kmh = null
formData.min_range_km = null
formData.mobility_type = ''
formData.structure_layout = ''
formData.engine_model = ''
formData.engine_params = ''
formData.power_hp = null
formData.travel_range_km = null
}
}
</script>
<style scoped>
<style lang="scss" scoped>
.predict-page {
padding: 20px;
}

View File

@ -1,5 +1,6 @@
<template>
<div class="training-page">
<!-- 上部分模型训练区域 -->
<el-card class="training-card">
<template #header>
<h2>模型训练</h2>
@ -38,11 +39,12 @@
<el-form-item label="选择模型">
<el-checkbox-group v-model="trainingConfig.models">
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
<el-checkbox value="pytorch" checked>PyTorch</el-checkbox>
<el-checkbox value="xgboost" checked>XGBoost</el-checkbox>
<el-checkbox value="lightgbm" checked>LightGBM</el-checkbox>
<el-checkbox value="gbm" checked>GBM</el-checkbox>
<el-checkbox value="rf" checked>Random Forest</el-checkbox>
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
</el-checkbox-group>
</el-form-item>
@ -61,8 +63,8 @@
<div class="best-model-info" v-if="trainingResult.best_model">
<h4>最佳模型: {{ getModelName(trainingResult.best_model.type) }}</h4>
<p>R²分数: {{ formatNumber(trainingResult.best_model.r2) }}</p>
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} </p>
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} </p>
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} </p>
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} </p>
</div>
<!-- 所有模型评估结果 -->
@ -80,12 +82,12 @@
{{ formatNumber(scope.row.train.r2) }}
</template>
</el-table-column>
<el-table-column prop="train.mae" label="MAE (元)" width="150">
<el-table-column prop="train.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.mae) }}
</template>
</el-table-column>
<el-table-column prop="train.rmse" label="RMSE (元)" width="150">
<el-table-column prop="train.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.rmse) }}
</template>
@ -99,12 +101,12 @@
{{ formatNumber(scope.row.validation.r2) }}
</template>
</el-table-column>
<el-table-column prop="validation.mae" label="MAE (元)" width="150">
<el-table-column prop="validation.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.mae) }}
</template>
</el-table-column>
<el-table-column prop="validation.rmse" label="RMSE (元)" width="150">
<el-table-column prop="validation.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.rmse) }}
</template>
@ -130,6 +132,159 @@
</div>
</div>
</el-card>
<!-- 下部分模型简介区域 -->
<el-card class="model-intro-card">
<template #header>
<h2>模型简介</h2>
</template>
<el-collapse>
<el-collapse-item name="pytorch">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">PyTorch</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>深度学习框架可以构建复杂的神经网络结构</li>
<li>分别处理装备特征和生产商特征然后合并进行预测</li>
<li>使用批量归一化和Dropout防止过拟合</li>
<li>适合处理非线性关系和复杂特征交互</li>
</ul>
<h4>优势</h4>
<ul>
<li>强大的特征学习能力</li>
<li>可以自动学习特征之间的复杂关系</li>
<li>灵活的网络结构设计</li>
<li>支持GPU加速训练</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="xgboost">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">XGBoost</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>基于梯度提升树的集成学习算法</li>
<li>使用二阶导数进行优化</li>
<li>内置正则化机制防止过拟合</li>
<li>支持特征重要性评估</li>
</ul>
<h4>优势</h4>
<ul>
<li>优秀的预测性能</li>
<li>处理缺失值的能力强</li>
<li>训练速度快</li>
<li>可解释性好</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="lightgbm">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">LightGBM</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>微软开发的轻量级梯度提升框架</li>
<li>使用直方图算法优化训练速度</li>
<li>支持类别特征的高效处理</li>
<li>叶子优先的生长策略</li>
</ul>
<h4>优势</h4>
<ul>
<li>训练速度非常快</li>
<li>内存占用低</li>
<li>支持大规模数据训练</li>
<li>准确率高</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="gbm">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">Gradient Boosting (GBM)</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>经典的梯度提升算法</li>
<li>逐步减少残差的思想</li>
<li>可以使用不同的损失函数</li>
<li>支持特征重要性分析</li>
</ul>
<h4>优势</h4>
<ul>
<li>稳定的性能表现</li>
<li>较好的可解释性</li>
<li>对异常值不敏感</li>
<li>适合各种回归问题</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="rf">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">Random Forest</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>基于决策树的集成学习方法</li>
<li>使用随机采样和特征选择</li>
<li>多个决策树投票或平均</li>
<li>自带特征重要性评估</li>
</ul>
<h4>优势</h4>
<ul>
<li>不易过拟合</li>
<li>训练过程可并行化</li>
<li>对噪声数据鲁棒</li>
<li>较少的参数调整</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="pls">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">PLS回归</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>偏最小二乘回归</li>
<li>同时考虑自变量和因变量的变异</li>
<li>处理多重共线性问题</li>
<li>降维和回归的结合</li>
</ul>
<h4>优势</h4>
<ul>
<li>适合小样本数据</li>
<li>处理变量间相关性强的数据</li>
<li>计算效率高</li>
<li>结果稳定可靠</li>
</ul>
</div>
</el-collapse-item>
</el-collapse>
</el-card>
</div>
</template>
@ -144,7 +299,7 @@ const trainingConfig = ref({
type: '',
train_dataset_id: null,
validation_dataset_id: null,
models: ['xgboost', 'lightgbm', 'gbm', 'rf']
models: ['pytorch', 'xgboost', 'lightgbm', 'gbm', 'rf']
})
//
@ -241,10 +396,12 @@ const formatNumber = (value) => {
//
const getModelName = (modelType) => {
const modelNames = {
'pytorch': 'PyTorch',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
'rf': 'Random Forest',
'pls': 'PLS回归'
}
return modelNames[modelType] || modelType
}
@ -334,8 +491,13 @@ onMounted(() => {
<style lang="scss" scoped>
.training-page {
padding: 20px;
display: flex;
flex-direction: column;
gap: 20px;
.training-card {
width: 100%;
.training-result {
margin-top: 20px;
@ -366,5 +528,62 @@ onMounted(() => {
}
}
}
.model-intro-card {
width: 100%;
.model-title {
.el-link {
font-size: 16px;
font-weight: 500;
&:hover {
opacity: 0.8;
}
&:active {
opacity: 0.6;
}
}
}
.model-intro {
padding: 10px;
h4 {
margin: 10px 0;
color: #409EFF;
font-size: 15px;
}
ul {
padding-left: 20px;
margin: 5px 0;
li {
line-height: 1.8;
color: #606266;
font-size: 14px;
&:hover {
color: #409EFF;
}
}
}
}
:deep(.el-collapse-item__header) {
padding: 12px 0;
font-size: 16px;
&:hover {
background-color: #f5f7fa;
}
}
:deep(.el-collapse-item__content) {
padding: 10px 20px;
}
}
}
</style>

View File

@ -19,7 +19,7 @@ export default defineConfig({
port: 3000,
proxy: {
'/api': {
target: 'http://localhost:5000',
target: 'http://localhost:5001',
changeOrigin: true
}
}

View File

@ -0,0 +1,11 @@
智能成本预测系统 - HTML5离线版
运行方式:
1. 解压 zip 文件。
2. 双击 index.html。
说明:
- 不需要 Python。
- 不需要数据库。
- 不需要联网。
- 页面内置样例数据和模型效果,用于客户现场展示不同模型的预测差异。

File diff suppressed because one or more lines are too long

57
pyproject.toml Normal file
View File

@ -0,0 +1,57 @@
[project]
name = "cost-prediction"
version = "0.1.0"
description = "装备成本预测系统"
requires-python = ">=3.9,<3.12"
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
# Web框架
"flask>=3.1.0",
"flask-cors>=5.0.0",
# 数据处理
"numpy>=1.26.0,<2.0.0",
"pandas>=2.2.0",
# 机器学习
"scikit-learn>=1.5.2",
"xgboost>=2.1.0",
"lightgbm>=4.5.0",
# 工具
"openpyxl>=3.1.5",
"python-dotenv>=1.0.0",
"requests>=2.31.0",
]
[project.optional-dependencies]
# PyTorch 为可选依赖(安装约 800MB仅训练神经网络时需要
torch = [
"torch==2.5.1",
]
dev = [
# 测试工具
"pytest>=7.0",
"black>=22.0", # 代码格式化
"mypy>=1.0", # 类型检查
]
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
[tool.black]
line-length = 88
target-version = ["py39", "py310", "py311"]
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true

View File

@ -1,12 +1,12 @@
flask==2.0.1
flask-cors==3.0.10
sqlalchemy==1.4.23
pymysql==1.0.2
cryptography==3.4.7 # MySQL 8.0+ 认证需要
numpy==1.21.2
pandas==1.3.3
scikit-learn==0.24.2
tensorflow==2.6.0
urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
openpyxl==3.1.2 # 用于读取 .xlsx 文件
xlrd==2.0.1 # 用于读取 .xls 文件
flask>=3.1.0
flask-cors>=5.0.0
numpy>=1.26.0,<2.0.0
pandas>=2.2.0
xgboost>=2.1.0
lightgbm>=4.5.0
scikit-learn>=1.5.2
openpyxl>=3.1.5
python-dotenv>=1.0.0

View File

@ -1,29 +0,0 @@
# 火箭炮系统技术参数示例
## 伊朗“胜利”-2 240mm 12管火箭炮系统
产品类别: 多管火箭炮
型号: “胜利”-2 240mm   多管火箭炮
尺寸与重量
总长: 10m(393.7in)
宽(行军状态): 2.5m(98.4in)
高(行军状态) 3.34m131.5in
标准重: 15000kg(33069 lb)(15.0t)
战斗重: 19900kg(43871 lb)(19.9t)
机动性
行走装置: 轮式
布局: 6×6
两栖: 无
火力
方向射界: 100º(1778mils)(左/90°右
高低射界(武器前方): 57°(1013mils)
型号: “胜利”2火箭弹
尺寸与重量
总长: 3.550m(11ft)
弹体直径: 512mm(20.16in)(尾翼展开)
发射(重量): 275kg(606 lb)
性能
速度(最大速度): 1302kt(2412km/h;1499mph;670m/s)
最大射程 12.4n miles(23km;14.3miles)
武器组成
战斗部: 85kg(187 lb)

33
run.py
View File

@ -1,13 +1,26 @@
from src.app import create_app
import logging
from src import create_app
from src.logger import setup_logger
from config import config
import os
# 创建应用实例
app = create_app()
logger = setup_logger(__name__)
def main():
try:
# 创建必要的目录
os.makedirs(config.MODEL_DIR, exist_ok=True)
os.makedirs(config.LOG_DIR, exist_ok=True)
os.makedirs(config.DATA_DIR, exist_ok=True)
app = create_app()
app.run(
host=config.FLASK_HOST,
port=config.FLASK_PORT,
debug=config.FLASK_DEBUG
)
except Exception as e:
logger.error(f"Error starting application: {str(e)}")
raise
if __name__ == '__main__':
# 设置日志
logging.basicConfig(level=logging.INFO)
logging.info('=== Server Starting ===')
logging.info('Initializing directories...')
app.run(host='0.0.0.0', port=5001, debug=True)
main()

View File

@ -0,0 +1,57 @@
param(
[string]$OutputPath = "release\algorithm-demo-standalone.zip"
)
$ErrorActionPreference = "Stop"
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
$releaseRoot = Join-Path $repoRoot "release"
$stageDir = Join-Path $releaseRoot "algorithm-demo-standalone"
$zipPath = Join-Path $repoRoot $OutputPath
function Assert-InRepo([string]$PathToCheck) {
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
$root = [System.IO.Path]::GetFullPath($repoRoot)
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
throw "Refusing to operate outside repository: $resolved"
}
}
Assert-InRepo $stageDir
Assert-InRepo $zipPath
Push-Location (Join-Path $repoRoot "frontend")
try {
npm run build
}
finally {
Pop-Location
}
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
if (Test-Path $stageDir) {
Remove-Item -LiteralPath $stageDir -Recurse -Force
}
if (Test-Path $zipPath) {
Remove-Item -LiteralPath $zipPath -Force
}
New-Item -ItemType Directory -Force -Path $stageDir | Out-Null
New-Item -ItemType Directory -Force -Path (Join-Path $stageDir "data") | Out-Null
Copy-Item -Recurse -Path (Join-Path $repoRoot "frontend\dist") -Destination (Join-Path $stageDir "frontend")
Copy-Item -Path (Join-Path $repoRoot "src\demo_service.py") -Destination (Join-Path $stageDir "demo_service.py")
Copy-Item -Path (Join-Path $repoRoot "data\demo_equipment_costs.csv") -Destination (Join-Path $stageDir "data\demo_equipment_costs.csv")
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\server.py") -Destination (Join-Path $stageDir "server.py")
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\requirements.txt") -Destination (Join-Path $stageDir "requirements.txt")
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\start_demo.bat") -Destination (Join-Path $stageDir "start_demo.bat")
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\README.md") -Destination (Join-Path $stageDir "README.md")
Get-ChildItem -Path $stageDir -Recurse -Include "*.map", "__pycache__" | ForEach-Object {
Remove-Item -LiteralPath $_.FullName -Recurse -Force
}
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
Write-Host "Demo zip created: $zipPath"

View File

@ -0,0 +1,36 @@
param(
[string]$OutputPath = "release\intelligent-cost-prediction-html5.zip"
)
$ErrorActionPreference = "Stop"
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
$sourceDir = Join-Path $repoRoot "html5_cost_prediction"
$releaseRoot = Join-Path $repoRoot "release"
$stageDir = Join-Path $releaseRoot "intelligent-cost-prediction-html5"
$zipPath = Join-Path $repoRoot $OutputPath
function Assert-InRepo([string]$PathToCheck) {
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
$root = [System.IO.Path]::GetFullPath($repoRoot)
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
throw "Refusing to operate outside repository: $resolved"
}
}
Assert-InRepo $stageDir
Assert-InRepo $zipPath
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
if (Test-Path $stageDir) {
Remove-Item -LiteralPath $stageDir -Recurse -Force
}
if (Test-Path $zipPath) {
Remove-Item -LiteralPath $zipPath -Force
}
Copy-Item -Recurse -Path $sourceDir -Destination $stageDir
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
Write-Host "HTML5 zip created: $zipPath"

87
scripts/build_linux.sh Normal file
View File

@ -0,0 +1,87 @@
#!/bin/bash
# 确保脚本在错误时退出
set -e
echo "Starting packaging for Linux..."
# 创建虚拟环境
python3 -m venv .venv
source .venv/bin/activate
# 安装依赖
echo "Installing dependencies..."
pip install -e .
# 构建前端
echo "Building frontend..."
cd frontend
npm install
npm run build
# 把构建好的文件直接复制到 frontend 目录
cp -r dist/* .
rm -rf dist
cd ..
# 创建必要的目录
mkdir -p logs data models
# 使用 PyInstaller 打包
echo "Packaging with PyInstaller..."
pyinstaller --clean \
--add-data "src/loitering_munition_data.sql:data" \
--add-data "src/rocket_artillery_data.sql:data" \
--add-data "src/manufacturer_data.sql:data" \
--add-data "src/schema.sql:data" \
--add-data "config.py:." \
--add-data "src:src" \
--add-data "frontend:frontend" \
--add-data "logs:logs" \
--add-data "data:data" \
--add-data "models:models" \
--collect-all "xgboost" \
--collect-all "lightgbm" \
--collect-all "torch" \
--collect-all "sklearn" \
--collect-all "numpy" \
--collect-all "pandas" \
--collect-all "sqlalchemy" \
--collect-all "pymysql" \
--collect-all "cryptography" \
--collect-all "flask" \
--collect-all "flask_cors" \
--hidden-import "xgboost.testing" \
--hidden-import "torch.utils.tensorboard" \
--hidden-import "pytest" \
run.py
# 创建启动脚本
echo "Creating start script..."
cat > src/start.sh << 'EOF'
#!/bin/bash
export FLASK_DEBUG=false
export MYSQL_HOST=localhost
export MYSQL_USER=root
export MYSQL_PASSWORD=123456
export MYSQL_DB=equipment_cost_db
echo "Starting Cost Prediction System..."
./run
xdg-open http://localhost:5001
EOF
# 复制启动脚本
cp src/start.sh dist/run/
chmod +x dist/run/start.sh
# 创建发布包
echo "Creating release package..."
version=$(grep "version" pyproject.toml | cut -d'"' -f2)
mkdir -p dist/release
cp -r dist/run/* dist/release/
# 创建 tar.gz 包
cd dist
tar czf "cost-prediction-${version}-linux.tar.gz" release/
cd ..
echo "Package completed: dist/cost-prediction-${version}-linux.tar.gz"

70
scripts/build_win.ps1 Normal file
View File

@ -0,0 +1,70 @@
# Set console encoding to UTF-8
[Console]::OutputEncoding = [System.Text.Encoding]::UTF8
$OutputEncoding = [System.Text.Encoding]::UTF8
# Ensure PowerShell stops on error
$ErrorActionPreference = "Stop"
# Create virtual environment
Write-Host "Creating virtual environment..."
python -m venv .venv
.\.venv\Scripts\Activate.ps1
# Install dependencies
Write-Host "Installing dependencies..."
pip install -e .
# Build frontend
Write-Host "Building frontend..."
Push-Location frontend
npm install
npm run build
Copy-Item dist/* . -Recurse -Force
Remove-Item dist -Recurse -Force
Pop-Location
# Package with PyInstaller
Write-Host "Starting packaging..."
# Create necessary directories
New-Item -ItemType Directory -Force -Path "logs"
New-Item -ItemType Directory -Force -Path "data"
New-Item -ItemType Directory -Force -Path "models"
pyinstaller --clean `
--add-data "src/loitering_munition_data.sql;data" `
--add-data "src/rocket_artillery_data.sql;data" `
--add-data "src/manufacturer_data.sql;data" `
--add-data "config.py;." `
--add-data "src;src" `
--add-data "frontend;frontend" `
--add-data "logs;logs" `
--add-data "data;data" `
--add-data "models;models" `
--collect-all "xgboost" `
--collect-all "lightgbm" `
--collect-all "sklearn" `
--collect-all "numpy" `
--collect-all "pandas" `
--collect-all "flask" `
--collect-all "flask_cors" `
run.py
# Copy necessary files
Write-Host "Copying configuration files..."
Copy-Item "src/start.bat" -Destination "dist/run"
# Create release package
Write-Host "Creating release package..."
$version = (Get-Content pyproject.toml | Select-String 'version = "(.*?)"').Matches.Groups[1].Value
# Create complete offline installation package directory
$RELEASE_DIR = "dist/release"
New-Item -ItemType Directory -Force -Path $RELEASE_DIR
# Copy application files
Copy-Item "dist/run/*" -Destination $RELEASE_DIR -Recurse
# Create final zip package
Compress-Archive -Path "$RELEASE_DIR/*" -DestinationPath "cost-prediction-$version-win64.zip" -Force
Write-Host "Package completed: cost-prediction-$version-win64.zip"

121
scripts/setup_env.ps1 Normal file
View File

@ -0,0 +1,121 @@
# 设置错误操作首选项
$ErrorActionPreference = "Stop"
# 检查管理员权限
$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
if (-not $isAdmin) {
Write-Warning "建议使用管理员权限运行此脚本"
Start-Sleep -Seconds 3
}
# 检查 pyenv-win 是否安装
if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) {
Write-Host "pyenv not found. Installing..."
try {
# 下载并安装 pyenv-win
Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1"
& ./install-pyenv-win.ps1
# 添加环境变量
$env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win"
$env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path"
# 刷新环境变量
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
}
catch {
Write-Error "Failed to install pyenv: $_"
exit 1
}
}
try {
# 安装指定版本的 Python
Write-Host "Installing Python 3.11.8..."
pyenv install 3.11.8
if ($LASTEXITCODE -ne 0) {
throw "Failed to install Python 3.11.8"
}
# 设置本地 Python 版本
Write-Host "Setting local Python version..."
pyenv local 3.11.8
if ($LASTEXITCODE -ne 0) {
throw "Failed to set local Python version"
}
# 验证 Python 版本
$pythonVersion = python -V
if (-not $pythonVersion.Contains("3.11.8")) {
throw "Wrong Python version: $pythonVersion"
}
Write-Host "Using Python version: $pythonVersion"
# 创建虚拟环境
Write-Host "Creating virtual environment..."
python -m venv .venv
# 激活虚拟环境
Write-Host "Activating virtual environment..."
.\.venv\Scripts\Activate.ps1
# 升级 pip 和构建工具
Write-Host "Upgrading pip and build tools..."
python -m pip install --upgrade pip setuptools wheel
# 分步安装依赖以确保正确的顺序和版本
Write-Host "Installing database dependencies..."
pip install mysql-connector-python==8.0.33
Write-Host "Installing PyTorch and related packages..."
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
Write-Host "Installing basic dependencies..."
pip install numpy==1.26.4 pandas==2.2.1
Write-Host "Installing machine learning packages..."
pip install scikit-learn==1.5.2
# 安装开发依赖
Write-Host "Installing development dependencies..."
pip install -e ".[dev]"
if ($LASTEXITCODE -ne 0) {
Write-Warning "Failed to install development dependencies. Installing core package..."
pip install -e .
}
# 验证安装
Write-Host "Verifying installations..."
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
python -c "import pandas; print(f'Pandas version: {pandas.__version__}')"
python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')"
Write-Host "Environment setup complete!" -ForegroundColor Green
}
catch {
Write-Error "An error occurred: $_"
exit 1
}
finally {
# 清理临时文件
if (Test-Path "./install-pyenv-win.ps1") {
Remove-Item "./install-pyenv-win.ps1"
}
}
# 显示使用说明
Write-Host @"
环境设置完成使用说明
1. 虚拟环境已激活命令提示符前应该显示 (.venv)
2. 要退出虚拟环境运行: deactivate
3. 要重新激活虚拟环境运行: .\.venv\Scripts\Activate.ps1
4. 项目依赖已安装可以开始开发了
如果遇到问题请检查
- Python 版本: python -V
- PyTorch 安装: python -c "import torch; print(torch.__version__)"
- 虚拟环境状态: 确保看到 (.venv) 前缀
"@ -ForegroundColor Cyan

73
scripts/setup_env.sh Executable file
View File

@ -0,0 +1,73 @@
#!/bin/bash
# 此脚本用于设置 Python 开发环境
# 主要用于:
# 1. 开发环境初始化
# 2. 确保正确的 Python 版本
# 3. 安装项目依赖
# 注意:在运行此脚本前,请先运行 setup_linux.sh 安装系统依赖
# 检查 pyenv 是否安装
if ! command -v pyenv &> /dev/null; then
echo "pyenv not found. Installing..."
if [[ "$OSTYPE" == "darwin"* ]]; then
brew install pyenv
else
curl https://pyenv.run | bash
fi
fi
# 安装指定版本的 Python
pyenv install 3.11.8 || true
# 设置本地 Python 版本
pyenv local 3.11.8
# 确保使用正确的 Python 版本
eval "$(pyenv init -)"
pyenv shell 3.11.8
# 验证 Python 版本
python_version=$(python -V 2>&1)
if [[ $python_version != *"3.11.8"* ]]; then
echo "Error: Wrong Python version: $python_version"
echo "Please ensure pyenv is properly configured in your shell"
exit 1
fi
# 创建虚拟环境
python -m venv .venv
# 激活虚拟环境
source .venv/bin/activate
# 升级 pip 和构建工具
python -m pip install --upgrade pip setuptools wheel
# 分步安装依赖以确保正确的顺序和版本
echo "Installing database dependencies..."
pip install mysql-connector-python==8.0.33
echo "Installing PyTorch and related packages..."
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
echo "Installing basic dependencies..."
pip install numpy==1.26.4 pandas==2.2.1
echo "Installing machine learning packages..."
pip install scikit-learn==1.5.2
# 安装开发依赖
if ! pip install -e ".[dev]"; then
echo "Warning: Failed to install development dependencies. Installing core package..."
pip install -e .
fi
# 验证安装
echo "Verifying Python version..."
python --version
echo "Verifying PyTorch installation..."
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
echo "Environment setup complete!"

36
scripts/setup_linux.sh Normal file
View File

@ -0,0 +1,36 @@
#!/bin/bash
# 此脚本用于安装 Linux 系统级依赖
# 主要用于:
# 1. 打包环境准备
# 2. 新系统的初始化
# 3. CI/CD 环境设置
# 更新包列表
sudo apt update
# 安装系统依赖
echo "Installing system dependencies..."
sudo apt install -y \
python3.11 \ # Python 3.11 解释器
python3.11-venv \ # Python 虚拟环境支持
python3-pip \ # Python 包管理器
build-essential \ # 编译工具
python3.11-dev \ # Python 开发库
nodejs \ # Node.js用于前端构建
npm \ # Node.js 包管理器
binutils \ # 二进制工具(用于 PyInstaller
tar \ # 打包工具
gzip \ # 压缩工具
mysql-client \ # MySQL 客户端
libmysqlclient-dev \ # MySQL 开发库
gcc \ # C 编译器
g++ \ # C++ 编译器
libssl-dev \ # SSL 支持
xdg-utils # 用于打开浏览器
# 验证安装
echo "Verifying installations..."
python3.11 --version
node --version
npm --version

View File

@ -1 +1,3 @@
# 这个文件可以为空,但必须存在
from .app import create_app
__all__ = ['create_app']

Binary file not shown.

Binary file not shown.

View File

@ -1,50 +1,57 @@
from flask import Flask
from flask_cors import CORS
from flask import send_from_directory
from .routes import api_bp
from .logger import setup_logger
from config import config
import os
import sys
# 获取logger
logger = setup_logger(__name__)
def create_app():
"""
创建并配置Flask应用
"""
"""创建并配置 Flask 应用"""
try:
# 创建必要的目录
os.makedirs('logs', exist_ok=True)
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
logger.info("=== Server Starting ===")
logger.info("Initializing directories...")
# 创建Flask应用
app = Flask(__name__)
# 配置CORS
CORS(app)
logger.info("CORS enabled")
# 注册API蓝图
# 注册路由
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")
# 配置数据库连接
app.config['MYSQL_HOST'] = 'localhost'
app.config['MYSQL_USER'] = 'root'
app.config['MYSQL_PASSWORD'] = '123456'
app.config['MYSQL_DB'] = 'equipment_cost_db'
# 获取前端文件路径
if getattr(sys, 'frozen', False):
# PyInstaller 打包后的路径
frontend_path = os.path.join(sys._MEIPASS, 'frontend')
logger.info(f"Running in frozen mode, frontend path: {frontend_path}")
logger.info(f"MEIPASS path: {sys._MEIPASS}")
logger.info(f"Files in frontend dir: {os.listdir(frontend_path)}")
else:
# 开发环境路径
frontend_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'frontend', 'dist')
logger.info("Starting server...")
# 服务前端文件
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve_frontend(path):
logger.info(f"Serving path: {path}")
logger.info(f"Frontend path: {frontend_path}")
logger.info(f"Full file path: {os.path.join(frontend_path, path)}")
logger.info(f"File exists: {os.path.exists(os.path.join(frontend_path, path))}")
try:
if path == "":
return send_from_directory(frontend_path, 'index.html')
file_path = os.path.join(frontend_path, path)
if os.path.exists(file_path):
return send_from_directory(frontend_path, path)
return send_from_directory(frontend_path, 'index.html')
except Exception as e:
logger.error(f"Error serving file {path}: {str(e)}")
return str(e), 500
return app
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
raise
if __name__ == '__main__':
app = create_app()
app.run(host='localhost', port=5001)
logger.error(f"Error creating application: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise

View File

@ -1,17 +1,23 @@
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import stats
import joblib
import os
import pandas as pd
from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection
from src.feature_analysis import FeatureAnalysis
from .logger import setup_logger
# PyTorch 为可选依赖
try:
import torch
import torch.nn as nn
_HAS_TORCH = True
except ImportError:
torch = None
nn = None
_HAS_TORCH = False
logger = setup_logger(__name__)
class CostPredictor:
@ -21,161 +27,120 @@ class CostPredictor:
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
# 添加 TensorFlow 配置
tf.config.run_functions_eagerly(False) # 启用图执行模式
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
if _HAS_TORCH:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = None
self.load_model()
def load_model(self):
"""
加载预训练型和标准化器
加载预训练模型和标准化器
"""
try:
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 创建默认模型
self._create_default_model()
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
if _HAS_TORCH:
try:
self._create_default_model()
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
def _create_default_model(self):
"""
创建默认模型并进行初始化训练
"""
# 创建输入层
inputs = tf.keras.Input(shape=(11,))
# 创建隐藏层
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
x = tf.keras.layers.Dense(32, activation='relu')(x)
# 创建输出层
outputs = tf.keras.layers.Dense(1)(x)
# 创建模型
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 编译模型
self.model.compile(
optimizer='adam',
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_absolute_error]
)
# 创建示例数据
example_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed.")
class DefaultModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_size, 64), nn.ReLU(),
nn.Linear(64, 32), nn.ReLU(),
nn.Linear(32, 1)
)
def forward(self, x):
return self.layers(x)
example_features = {
'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 训练标准化器
self.scaler_X.fit(example_data)
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
# 设置默认装备类型
'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
}
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
self.scaler_X.fit(X.numpy())
self.scaler_y.fit(y.numpy())
self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮'
def _create_example_data(self):
"""
创建示例数据来训练标准化器
"""
# 火箭炮示例数据
rocket_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 巡飞弹示例数据
missile_data = pd.DataFrame({
'length_m': [1.3, 2.5],
'width_m': [0.23, 0.6],
'height_m': [0.23, 0.6],
'weight_kg': [12.5, 135],
'max_range_km': [40, 250],
'max_speed_kmh': [180, 185],
'cruise_speed_kmh': [100, 110],
'flight_time_min': [60, 120],
'folded_length_mm': [1300, 2500],
'folded_width_mm': [230, 600],
'folded_height_mm': [230, 600]
})
# 训练标准化器
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
# 设置默认装备类型
self.equipment_type = '火箭炮'
def predict(self, data):
"""
使用训练好的最优模型进行预测
"""
def predict(self, data, model_record):
"""使用训练好的模型进行预测"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
logger.info(f"Starting prediction for {data.get('type')} using {model_record['model_type']}")
equipment_type = data.get('type')
# 加载已训练的最优模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type):
raise ValueError(f"No trained model found for {equipment_type}")
# 使用ModelTrainer加载模型
model_trainer = ModelTrainer()
success = model_trainer.load_model(equipment_type, model_record['model_type'])
if not success:
raise ValueError(f"Failed to load model for {equipment_type}")
# 从ModelTrainer获取模型和标准化器
model = model_trainer.model
feature_scaler = model_trainer.feature_scaler
target_scaler = model_trainer.target_scaler
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
feature_analyzer = FeatureAnalysis()
features = feature_analyzer.get_equipment_specific_features(equipment_type)
X = []
for feature in features:
value = data.get(feature, 0.0)
X.append(float(value))
# 预测
y_pred = trainer.predict(X)
# 转换为numpy数组并标准化
X = np.array([X])
X_scaled = feature_scaler.transform(X)
# 根据模型类型进行预测
if isinstance(model, torch.nn.Module):
# PyTorch模型预测
model.eval()
with torch.no_grad():
X_tensor = torch.FloatTensor(X_scaled).to(self.device)
y_pred = model(X_tensor)
y_pred = y_pred.cpu().numpy()
elif model_record['model_type'] == 'pls':
# PLS模型预测
y_pred = model.predict(X_scaled).reshape(-1, 1)
else:
# 其他sklearn模型预测
y_pred = model.predict(X_scaled).reshape(-1, 1)
# 转换回原始尺度并确保为正数
y_pred_original = target_scaler.inverse_transform(y_pred)
predicted_cost = abs(float(y_pred_original[0][0])) # 确保预测值为正数
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
std = predicted_cost * 0.2 # 使用预测值的20%作为标准差
confidence_interval = {
'lower': max(predicted_cost - std, predicted_cost * 0.5), # 至少是预测值的50%
'upper': predicted_cost + std
}
return {
'predicted_cost': float(y_pred[0]),
'model_type': model_type, # 返回使用的模型类型
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
'predicted_cost': predicted_cost,
'confidence_interval': confidence_interval
}
except Exception as e:
@ -187,11 +152,10 @@ class CostPredictor:
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
# 使用预测值的20%作为标准差
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
@ -213,130 +177,15 @@ class CostPredictor:
"""
模型评估
"""
# 确保输入是 numpy 数组
if torch.is_tensor(y_true):
y_true = y_true.cpu().numpy()
if torch.is_tensor(y_pred):
y_pred = y_pred.cpu().numpy()
return {
'mae': float(mean_absolute_error(y_true, y_pred)),
'mse': float(mean_squared_error(y_true, y_pred)),
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
'r2': float(r2_score(y_true, y_pred))
}
def predict_pls(self, data):
"""
使用 PLS 型预测成本
"""
try:
logger.info(f"Starting PLS prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载 PLS 模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
raise ValueError(f"No trained PLS model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
return {
'predicted_cost': float(y_pred[0]),
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"PLS prediction error: {str(e)}")
raise
def predict_all(self, data):
"""
使用所有可用模型进行预测
"""
try:
logger.info(f"Starting multi-model prediction for {data.get('type')}")
equipment_type = data.get('type')
results = {}
# 1. 获取所有激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT id, model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
""", (equipment_type,))
active_models = cursor.fetchall()
if not active_models:
raise ValueError(f"No active models found for {equipment_type}")
# 2. 使用每个模型进行预测
trainer = ModelTrainer()
for model_info in active_models:
try:
# 加载特定模型
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
logger.warning(f"Failed to load model: {model_info['model_name']}")
continue
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 保存结果
results[model_info['model_type']] = {
'predicted_cost': float(y_pred[0]),
'model_info': {
'name': model_info['model_name'],
'type': model_info['model_type'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
continue
if not results:
raise ValueError("No successful predictions from any model")
# 3. 计算综合预测结果
all_predictions = [result['predicted_cost'] for result in results.values()]
ensemble_prediction = float(np.mean(all_predictions))
prediction_std = float(np.std(all_predictions))
# 4. 返回所有结果
return {
'individual_predictions': results,
'ensemble_prediction': {
'predicted_cost': ensemble_prediction,
'standard_deviation': prediction_std,
'confidence_interval': {
'lower': float(ensemble_prediction - 1.96 * prediction_std),
'upper': float(ensemble_prediction + 1.96 * prediction_std)
}
}
}
except Exception as e:
logger.error(f"Error in multi-model prediction: {str(e)}")
raise
}

View File

@ -1,83 +1,89 @@
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import os
import joblib
import pandas as pd
import numpy as np
from src.feature_analysis import FeatureAnalysis
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.model_selection import cross_val_score, LeaveOneOut
import json
import logging
from src.database.db_connection import get_db_connection
from sklearn.metrics import mean_absolute_error, mean_squared_error
from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
from .logger import setup_logger
logger = setup_logger(__name__)
# PyTorch 为可选依赖
try:
import torch
from torch.utils.data import Dataset, DataLoader
_HAS_TORCH = True
except ImportError:
torch = None
Dataset = object
DataLoader = None
_HAS_TORCH = False
class EquipmentDataset(Dataset if _HAS_TORCH else object):
"""装备数据集类"""
def __init__(self, features, targets=None):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
self.features = torch.FloatTensor(features)
self.targets = torch.FloatTensor(targets) if targets is not None else None
def __len__(self):
return len(self.features)
def __getitem__(self, idx):
if self.targets is not None:
return self.features[idx], self.targets[idx]
return self.features[idx]
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
self.target_scaler = StandardScaler() # 添加目标值标准化器
self.target_scaler = StandardScaler()
def prepare_training_data(self, equipment_data, equipment_type):
"""
准备训练数据
"""
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
"""准备训练数据"""
try:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,直接返回
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
return {
'X': X,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
# 从原始数据中提取特征和目标值
# 获取特征名称(包含生产商特征)
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in equipment_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 获取数据库连接
with get_db_connection() as conn:
cursor = conn.cursor()
# 提取目标值(成本)
try:
cost = float(item['actual_cost'])
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError, KeyError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 获取所有生产商数据,用于计算特征
cursor.execute("""
SELECT * FROM manufacturers
""")
manufacturers = {row['id']: row for row in cursor.fetchall()}
for item in equipment_data:
# 获取生产商数据
manufacturer = manufacturers.get(item['manufacturer_id'], {})
# 计算生产商特征
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer)
# 合并装备特征和生产商特征
feature_values = []
for name in feature_names:
if name.startswith('manufacturer_'):
value = manufacturer_features.get(name, 0.0)
else:
value = item.get(name)
feature_values.append(float(value) if value is not None else 0.0)
features.append(feature_values)
targets.append(float(item['actual_cost']))
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录原始数据范围
# 记录数据范围
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
@ -85,25 +91,18 @@ class DataPreparation:
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 记录标准化后的数据范围
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 记录标准化器参数
logger.info("Feature scaler params:")
logger.info(f"Mean: {self.feature_scaler.mean_}")
logger.info(f"Scale: {self.feature_scaler.scale_}")
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
# 创建数据集和数据加载器
dataset = EquipmentDataset(X_scaled, y_scaled)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return {
'X': X_scaled,
'y': y_scaled,
'dataloader': dataloader,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
'target_scaler': self.target_scaler,
'raw_shape': X.shape,
'X': X_scaled,
'y': y_scaled
}
except Exception as e:
@ -162,7 +161,6 @@ class DataPreparation:
# 提取目标值(成本)并验证范围
try:
cost = float(item['actual_cost'])
logger.info(f"Raw cost value: {cost}")
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:

View File

@ -1,37 +1,227 @@
import mysql.connector
from mysql.connector import Error
import sqlite3
from contextlib import contextmanager
import os
from dotenv import load_dotenv
from ..logger import setup_logger
# 获取logger
logger = setup_logger(__name__)
# 加载环境变量
load_dotenv()
# SQLite 数据库文件路径(相对于项目根目录)
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
# 建表 SQL
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS equipments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
type TEXT,
manufacturer TEXT,
manufacturer_id INTEGER,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS common_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
length_m REAL,
width_m REAL,
height_m REAL,
weight_kg REAL,
max_range_km REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
firing_angle_horizontal REAL,
firing_angle_vertical REAL,
rocket_length_m REAL,
rocket_diameter_mm REAL,
rocket_weight_kg REAL,
rate_of_fire REAL,
combat_weight_kg REAL,
speed_kmh REAL,
min_range_km REAL,
max_range_km REAL,
mobility_type TEXT,
structure_layout TEXT,
engine_model TEXT,
engine_params TEXT,
power_hp REAL,
travel_range_km REAL,
fire_density REAL,
range_ratio REAL,
mobility_score INTEGER,
combat_readiness_score INTEGER,
deployment_score INTEGER,
terrain_adaptability_score INTEGER,
rocket_power_ratio REAL,
platform_efficiency REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS loitering_munition_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
wingspan_m REAL,
warhead_weight_kg REAL,
max_speed_ms REAL,
cruise_speed_kmh REAL,
endurance_min REAL,
flight_time_min REAL,
max_range_km REAL,
max_payload_kg REAL,
ceiling_altitude_m REAL,
combat_radius_km REAL,
folded_length_mm REAL,
folded_width_mm REAL,
folded_height_mm REAL,
warhead_type TEXT,
launch_mode TEXT,
power_system TEXT,
guidance_system TEXT,
engine_power_kw REAL,
engine_thrust_n REAL,
datalink_range_km REAL,
guidance_accuracy_m REAL,
min_altitude_m REAL,
max_altitude_m REAL,
length_width_ratio REAL,
weight_range_ratio REAL,
speed_weight_ratio REAL,
guidance_system_score INTEGER,
warhead_power_score INTEGER,
warhead_type_code INTEGER,
launch_mode_code INTEGER,
power_system_code INTEGER,
guidance_system_code INTEGER,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS feature_encoding (
id INTEGER PRIMARY KEY AUTOINCREMENT,
feature_type TEXT,
feature_value TEXT,
code INTEGER,
UNIQUE(feature_type, feature_value)
);
CREATE TABLE IF NOT EXISTS cost_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
actual_cost REAL,
predicted_cost REAL,
prediction_date TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS custom_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
param_name TEXT,
param_value TEXT,
param_unit TEXT,
description TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS datasets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT,
equipment_type TEXT NOT NULL,
purpose TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS dataset_equipments (
dataset_id INTEGER NOT NULL,
equipment_id INTEGER NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS trained_models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
model_type TEXT NOT NULL,
equipment_type TEXT NOT NULL,
model_path TEXT NOT NULL,
scaler_path TEXT NOT NULL,
r2_score REAL,
mae REAL,
rmse REAL,
feature_importance TEXT,
training_data_size INTEGER,
training_date TEXT DEFAULT (datetime('now','localtime')),
is_active INTEGER DEFAULT 0,
created_by TEXT
);
CREATE TABLE IF NOT EXISTS manufacturers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
country TEXT NOT NULL,
tech_level INTEGER NOT NULL,
scale_level INTEGER NOT NULL,
supply_chain_level INTEGER NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime')),
UNIQUE(name)
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
"""
def init_db():
"""初始化数据库:确保数据库文件和表存在"""
os.makedirs(DB_DIR, exist_ok=True)
conn = sqlite3.connect(DB_PATH)
conn.executescript(SCHEMA_SQL)
conn.commit()
conn.close()
logger.info(f"Database initialized at {DB_PATH}")
@contextmanager
def get_db_connection():
"""
数据库连接上下文管理器
返回的 connection 已设置 dict row_factory
以便按列名访问
"""
connection = None
conn = None
try:
connection = mysql.connector.connect(
host=os.getenv('MYSQL_HOST', 'localhost'),
user=os.getenv('MYSQL_USER', 'root'),
password=os.getenv('MYSQL_PASSWORD', '123456'),
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
)
logger.info("Database connection established")
yield connection
except Error as e:
logger.error(f"Error connecting to MySQL: {str(e)}")
# 确保数据库已初始化
if not os.path.exists(DB_PATH):
logger.info("Database file not found, initializing...")
init_db()
conn = sqlite3.connect(DB_PATH)
conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
conn.execute("PRAGMA foreign_keys = ON")
logger.debug("Database connection established")
yield conn
except sqlite3.Error as e:
logger.error(f"Database error: {str(e)}")
raise
finally:
if connection and connection.is_connected():
connection.close()
logger.info("Database connection closed")
if conn:
conn.close()
logger.debug("Database connection closed")

290
src/demo_service.py Normal file
View File

@ -0,0 +1,290 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
@dataclass(frozen=True)
class AlgorithmDefinition:
key: str
name: str
english_name: str
family: str
description: str
estimator: Any
class DemoModelService:
target_column = "actual_cost"
ignored_columns = {"name", "type", target_column}
def __init__(self, dataset_path: Path | str | None = None):
root = Path(__file__).resolve().parent.parent
self.dataset_path = Path(dataset_path) if dataset_path else root / "data" / "demo_equipment_costs.csv"
def get_algorithms(self) -> list[dict[str, str]]:
algorithms, _ = self._available_algorithms()
return [
{
"key": item.key,
"name": item.name,
"english_name": item.english_name,
"family": item.family,
"description": item.description,
}
for item in algorithms.values()
]
def get_dataset_summary(self) -> dict[str, Any]:
frame = self._load_dataset()
feature_columns = self._feature_columns(frame)
return {
"source": "local-file",
"path": str(self.dataset_path),
"row_count": int(len(frame)),
"columns": list(frame.columns),
"features": feature_columns,
"target": self.target_column,
"target_label": "实际成本",
"equipment_types": sorted(frame["type"].dropna().unique().tolist()),
"preview": frame.head(8).to_dict(orient="records"),
}
def run_demo(self, selected_algorithms: list[str] | None = None) -> dict[str, Any]:
frame = self._load_dataset()
feature_columns = self._feature_columns(frame)
algorithms, availability_warnings = self._available_algorithms()
requested = selected_algorithms or list(algorithms.keys())
warnings = list(availability_warnings)
selected = []
for key in requested:
if key in algorithms:
selected.append(key)
else:
warnings.append(f"算法 '{key}' 不可用,已自动跳过。")
if not selected:
selected = ["linear"]
warnings.append("所选算法均不可用,已自动使用线性回归。")
X = frame[feature_columns]
y = frame[self.target_column]
train_x, test_x, train_y, test_y = train_test_split(
X,
y,
test_size=0.3,
random_state=42,
)
metrics: dict[str, dict[str, float | str]] = {}
predictions: dict[str, list[float]] = {}
feature_importance: dict[str, list[dict[str, float | str]]] = {}
for key in selected:
definition = algorithms[key]
model = definition.estimator
model.fit(train_x, train_y)
predicted = model.predict(test_x)
predictions[key] = [float(value) for value in predicted]
metrics[key] = {
"name": definition.name,
"r2": float(r2_score(test_y, predicted)),
"mae": float(mean_absolute_error(test_y, predicted)),
"rmse": float(np.sqrt(mean_squared_error(test_y, predicted))),
}
feature_importance[key] = self._feature_importance(model, feature_columns)
best_model = min(metrics, key=lambda key: float(metrics[key]["rmse"]))
ordered_test = test_x.copy()
ordered_test["actual"] = test_y
ordered_test["name"] = frame.loc[test_x.index, "name"]
prediction_points = []
for position, (index, row) in enumerate(ordered_test.sort_values("actual").iterrows()):
point = {
"name": row["name"],
"actual": float(row["actual"]),
}
for key in selected:
original_position = list(test_x.index).index(index)
point[key] = predictions[key][original_position]
prediction_points.append(point)
sample = frame.sort_values(self.target_column).iloc[len(frame) // 2]
sample_x = pd.DataFrame([sample[feature_columns].to_dict()])
sample_predictions = {
key: float(algorithms[key].estimator.fit(X, y).predict(sample_x)[0])
for key in selected
}
return {
"source": "local-file",
"dataset": self.get_dataset_summary(),
"algorithms": self.get_algorithms(),
"selected_algorithms": selected,
"best_model": best_model,
"metrics": metrics,
"feature_importance": feature_importance,
"prediction_points": prediction_points,
"sample_prediction": {
"input": sample.drop(labels=[self.target_column]).to_dict(),
"actual": float(sample[self.target_column]),
"predictions": sample_predictions,
},
"warnings": warnings,
}
def _load_dataset(self) -> pd.DataFrame:
if not self.dataset_path.exists():
raise FileNotFoundError(f"Demo dataset not found: {self.dataset_path}")
frame = pd.read_csv(self.dataset_path)
if self.target_column not in frame.columns:
raise ValueError(f"Demo dataset must include '{self.target_column}'.")
return frame
def _feature_columns(self, frame: pd.DataFrame) -> list[str]:
columns = [
column
for column in frame.columns
if column not in self.ignored_columns and pd.api.types.is_numeric_dtype(frame[column])
]
if not columns:
raise ValueError("Demo dataset has no numeric feature columns.")
return columns
def _available_algorithms(self) -> tuple[dict[str, AlgorithmDefinition], list[str]]:
algorithms = {
"linear": AlgorithmDefinition(
"linear",
"线性回归",
"Linear Regression",
"线性模型",
"快速建立基准模型,用于展示参数与成本之间的线性关系。",
Pipeline([("scaler", StandardScaler()), ("model", LinearRegression())]),
),
"ridge": AlgorithmDefinition(
"ridge",
"岭回归",
"Ridge Regression",
"线性模型",
"带正则化的线性模型,适合特征存在相关性的场景。",
Pipeline([("scaler", StandardScaler()), ("model", Ridge(alpha=1.0))]),
),
"random_forest": AlgorithmDefinition(
"random_forest",
"随机森林",
"Random Forest",
"树模型集成",
"通过多棵决策树集成预测,能够捕捉非线性特征影响。",
RandomForestRegressor(n_estimators=160, max_depth=6, random_state=42),
),
"gradient_boosting": AlgorithmDefinition(
"gradient_boosting",
"梯度提升树",
"Gradient Boosting",
"树模型集成",
"逐步修正误差的提升模型,常用于表格数据回归任务。",
GradientBoostingRegressor(n_estimators=120, learning_rate=0.06, max_depth=3, random_state=42),
),
"svr": AlgorithmDefinition(
"svr",
"支持向量回归",
"Support Vector Regression",
"核方法",
"使用核函数拟合平滑回归关系,适合展示不同算法偏好。",
Pipeline([("scaler", StandardScaler()), ("model", SVR(C=500000, epsilon=50000))]),
),
"knn": AlgorithmDefinition(
"knn",
"近邻回归",
"KNN Regression",
"实例学习",
"基于相似样本进行预测,便于解释局部相似性。",
Pipeline([("scaler", StandardScaler()), ("model", KNeighborsRegressor(n_neighbors=4))]),
),
}
warnings = []
try:
from xgboost import XGBRegressor
algorithms["xgboost"] = AlgorithmDefinition(
"xgboost",
"XGBoost",
"XGBoost",
"提升模型",
"面向表格数据的高性能梯度提升实现。",
XGBRegressor(
n_estimators=120,
max_depth=3,
learning_rate=0.05,
subsample=0.9,
colsample_bytree=0.9,
random_state=42,
objective="reg:squarederror",
),
)
except Exception:
warnings.append("当前环境未安装 XGBoost页面已自动隐藏该算法。")
try:
from lightgbm import LGBMRegressor
algorithms["lightgbm"] = AlgorithmDefinition(
"lightgbm",
"LightGBM",
"LightGBM",
"提升模型",
"基于直方图优化的快速梯度提升模型。",
LGBMRegressor(
n_estimators=120,
learning_rate=0.05,
max_depth=4,
random_state=42,
verbose=-1,
),
)
except Exception:
warnings.append("当前环境未安装 LightGBM页面已自动隐藏该算法。")
return algorithms, warnings
def _feature_importance(self, model: Any, feature_columns: list[str]) -> list[dict[str, float | str]]:
estimator = model
if isinstance(model, Pipeline):
estimator = model.named_steps["model"]
if hasattr(estimator, "feature_importances_"):
values = estimator.feature_importances_
elif hasattr(estimator, "coef_"):
values = np.abs(np.ravel(estimator.coef_))
else:
values = np.zeros(len(feature_columns))
total = float(np.sum(values))
if total > 0:
values = values / total
ranked = sorted(
[
{"feature": feature, "importance": float(value)}
for feature, value in zip(feature_columns, values)
],
key=lambda item: item["importance"],
reverse=True,
)
return ranked[:8]

View File

@ -15,9 +15,9 @@ class FeatureAnalysis:
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'max_range_km': '最大射程(km)',
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
@ -39,6 +39,7 @@ class FeatureAnalysis:
'terrain_adaptability_score': '地形适应性评分',
# 巡飞弹特有参数
'max_range_km': '最大射程(km)',
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
@ -57,7 +58,14 @@ class FeatureAnalysis:
'weight_range_ratio': '重量射程比',
'speed_weight_ratio': '速度重量比',
'guidance_system_score': '制导系统评分',
'warhead_power_score': '战斗部威力评分'
'warhead_power_score': '战斗部威力评分',
# 添加生产商特征映射
'manufacturer_tech_level': '生产商技术水平',
'manufacturer_scale_level': '生产商规模水平',
'manufacturer_supply_chain_level': '生产商供应链水平',
'manufacturer_composite_score': '生产商综合得分',
'manufacturer_region_factor': '生产商区域系数'
}
def get_equipment_specific_features(self, equipment_type):
@ -121,6 +129,17 @@ class FeatureAnalysis:
'guidance_system_score',
'warhead_power_score'
])
# 添加生产商特征
manufacturer_features = [
'manufacturer_tech_level',
'manufacturer_scale_level',
'manufacturer_supply_chain_level',
'manufacturer_composite_score',
'manufacturer_region_factor'
]
numeric_features.extend(manufacturer_features)
return numeric_features
def analyze_features(self, features, target, feature_names):
@ -193,11 +212,20 @@ class FeatureAnalysis:
# 创建特征重要性列表(使用中文名称)
important_features = []
# 过滤掉无效的分数
valid_scores = importance_scores[~np.isnan(importance_scores)]
# 记录一些统计信息
logger.info(f"F-score statistics:")
logger.info(f"min={np.min(valid_scores):.2f}, max={np.max(valid_scores):.2f}, "
f"mean={np.mean(valid_scores):.2f}, median={np.median(valid_scores):.2f}")
for idx, (name, score) in enumerate(zip(feature_names, importance_scores)):
if not np.isnan(score):
important_features.append({
'name': self.feature_name_map.get(name, name), # 使用中文名称
'importance': float(score)
'name': self.feature_name_map.get(name, name),
'importance': float(score) # 保持原始F-score
})
# 按重要性排序
@ -234,4 +262,67 @@ class FeatureAnalysis:
except Exception as e:
logger.error(f"Error in analyze_features: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
raise
def calculate_manufacturer_features(self, manufacturer_data):
"""计算生产商相关的特征"""
try:
# 处理 None 值(数据库 NULL使用默认值
raw_tech = manufacturer_data.get('tech_level')
raw_scale = manufacturer_data.get('scale_level')
raw_supply = manufacturer_data.get('supply_chain_level')
tech_level = float(raw_tech) if raw_tech is not None else 0
scale_level = float(raw_scale) if raw_scale is not None else 0
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
country = manufacturer_data.get('country', '未知') or '未知'
# 计算综合得分
composite_score = (
tech_level * 0.4 + # 技术水平权重最高
scale_level * 0.3 + # 规模水平次之
supply_chain_level * 0.3 # 供应链水平
)
# 计算区域系数(基于不同地区的成本差异)
region_factors = {
'美国': 1.2,
'英国': 1.15,
'德国': 1.15,
'法国': 1.15,
'以色列': 1.1,
'中国': 0.8,
'俄罗斯': 0.85,
'韩国': 0.9,
'日本': 1.1
}
region_factor = region_factors.get(country, 1.0)
# 记录计算过程
logger.info(f"Manufacturer features calculation:")
logger.info(f"Tech level: {tech_level}")
logger.info(f"Scale level: {scale_level}")
logger.info(f"Supply chain level: {supply_chain_level}")
logger.info(f"Country: {country}")
logger.info(f"Composite score: {composite_score}")
logger.info(f"Region factor: {region_factor}")
return {
'manufacturer_tech_level': tech_level,
'manufacturer_scale_level': scale_level,
'manufacturer_supply_chain_level': supply_chain_level,
'manufacturer_composite_score': composite_score,
'manufacturer_region_factor': region_factor
}
except Exception as e:
logger.error(f"Error calculating manufacturer features: {str(e)}")
# 返回默认值而不是抛出异常,确保分析过程可以继续
return {
'manufacturer_tech_level': 0,
'manufacturer_scale_level': 0,
'manufacturer_supply_chain_level': 0,
'manufacturer_composite_score': 0,
'manufacturer_region_factor': 1.0
}

View File

@ -26,8 +26,8 @@ def import_training_data(excel_file):
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '火箭炮'
SELECT id FROM equipments
WHERE name = ? AND type = '火箭炮'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
@ -37,8 +37,8 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
INSERT INTO equipments (name, type, manufacturer)
VALUES (?, ?, ?)
""", (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid
@ -47,7 +47,7 @@ def import_training_data(excel_file):
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?)
""", (
equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None,
@ -65,7 +65,7 @@ def import_training_data(excel_file):
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp,
travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
@ -89,7 +89,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
VALUES (?, ?)
""", (equipment_id, row['成本_元']))
logger.info("火箭炮数据导入完成")
@ -105,8 +105,8 @@ def import_training_data(excel_file):
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '巡飞弹'
SELECT id FROM equipments
WHERE name = ? AND type = '巡飞弹'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
@ -116,8 +116,8 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
INSERT INTO equipments (name, type, manufacturer)
VALUES (?, ?, ?)
""", (
row['名称'],
'巡飞弹',
@ -130,7 +130,7 @@ def import_training_data(excel_file):
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?)
""", (
equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
@ -147,7 +147,7 @@ def import_training_data(excel_file):
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
@ -168,7 +168,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
VALUES (?, ?)
""", (equipment_id, float(row['成本_元'])))
logger.info("巡飞弹数据导入完成")
@ -190,48 +190,48 @@ def import_training_data(excel_file):
# 获取装备ID - 使用新的游标
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipment WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
id_cursor = conn.cursor()
id_cursor.execute("""
SELECT id FROM equipments WHERE name = ?
""", (equipment_name,))
result = id_cursor.fetchone()
if not result:
logger.warning(f"未找到装备: {equipment_name}")
continue
equipment_id = result[0]
equipment_id = result['id']
logger.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor:
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
check_cursor = conn.cursor()
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = ? AND param_name = ?
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
if exists:
logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
continue
# 插入新的参数 - 使用新的游标
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor:
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s)
""", (
equipment_id,
param_name,
param_value,
param_unit,
insert_cursor = conn.cursor()
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (?, ?, ?, ?, ?)
""", (
equipment_id,
param_name,
param_value,
param_unit,
param_desc
))
logger.debug(f"成功插入参数记录")

237
src/import_sql_data.py Normal file
View File

@ -0,0 +1,237 @@
"""
导入 SQL 数据文件到 SQLite 数据库
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
- rocket_artillery_data.sql (96 条火箭炮数据)
- loitering_munition_data.sql (100 条巡飞弹数据)
- manufacturer_data.sql (生产商数据)
"""
import re
import os
import sys
from src.database.db_connection import get_db_connection, DB_PATH
from src.logger import setup_logger
logger = setup_logger(__name__)
def parse_mysql_insert(sql_text):
"""
解析 MySQL INSERT 语句返回 (table_name, columns, values_list)
columns 是列名列表, values_list 是每行的值列表
"""
# 去掉块注释 /* ... */
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
# 去掉行注释 --
sql_text = re.sub(r'--.*', '', sql_text)
results = []
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
pattern = re.compile(
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
re.DOTALL | re.IGNORECASE
)
for match in pattern.finditer(sql_text):
table_name = match.group(1).strip()
columns_str = match.group(2).strip()
values_str = match.group(3).strip()
# 提取列名(去掉注释部分)
columns = []
for col in columns_str.split(','):
col = col.strip()
# 去掉行内注释
col = re.sub(r'\s*--.*', '', col).strip()
columns.append(col)
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
values_list = []
current_tuple = []
depth = 0
current_val = ''
in_string = False
string_char = None
for ch in values_str:
if in_string:
if ch == string_char:
in_string = False
current_val += ch
continue
if ch in ("'", '"'):
in_string = True
string_char = ch
current_val += ch
continue
if ch == '(':
if depth > 0:
current_val += ch
depth += 1
if depth == 1:
current_val = ''
continue
if ch == ')':
depth -= 1
if depth == 0:
current_tuple.append(current_val.strip())
values_list.append(current_tuple)
current_tuple = []
else:
current_val += ch
continue
if ch == ',' and depth == 1:
current_tuple.append(current_val.strip())
current_val = ''
continue
if depth >= 1:
current_val += ch
results.append((table_name, columns, values_list))
return results
def convert_value(val):
"""将 MySQL 值字符串转为 Python 对象"""
val = val.strip()
if val.upper() == 'NULL' or val == '':
return None
if val.upper() == 'TRUE':
return 1
if val.upper() == 'FALSE':
return 0
# 字符串
if (val.startswith("'") and val.endswith("'")) or \
(val.startswith('"') and val.endswith('"')):
return val[1:-1]
# 数字
try:
if '.' in val:
return float(val)
return int(val)
except ValueError:
return val
def import_sql_file(filepath):
"""导入单个 SQL 文件"""
logger.info(f"Reading {filepath}...")
with open(filepath, 'r', encoding='utf-8') as f:
sql_text = f.read()
parsed = parse_mysql_insert(sql_text)
total_rows = 0
with get_db_connection() as conn:
cursor = conn.cursor()
for table_name, columns, values_list in parsed:
if not values_list:
continue
# 构建参数化 INSERT
placeholders = ','.join(['?'] * len(columns))
col_names = ','.join(columns)
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
row_count = 0
for values in values_list:
if len(values) != len(columns):
logger.warning(f"Column mismatch in {table_name}: "
f"expected {len(columns)} values, got {len(values)}: {values}")
continue
converted = [convert_value(v) for v in values]
try:
cursor.execute(sql, converted)
if cursor.rowcount > 0:
row_count += 1
except Exception as e:
logger.warning(f"Error inserting into {table_name}: {e}")
logger.warning(f" Values: {converted}")
if row_count > 0:
logger.info(f" {table_name}: inserted {row_count} rows")
total_rows += row_count
conn.commit()
return total_rows
def run_manufacturer_update():
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE equipments
SET manufacturer_id = (
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
)
WHERE manufacturer_id IS NULL
AND manufacturer IS NOT NULL
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
""")
conn.commit()
updated = cursor.rowcount
if updated > 0:
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
else:
logger.info("No manufacturer_id updates needed")
def import_all():
"""导入所有 SQL 数据文件"""
base_dir = os.path.dirname(os.path.dirname(__file__))
# 清除现有数据库,重新开始
if os.path.exists(DB_PATH):
os.remove(DB_PATH)
logger.info(f"Removed existing database: {DB_PATH}")
files = [
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
]
total = 0
for filepath, label in files:
if not os.path.exists(filepath):
logger.warning(f"File not found: {filepath}")
continue
logger.info(f"正在导入 {label}...")
rows = import_sql_file(filepath)
logger.info(f" {label}: 共导入 {rows}")
total += rows
# 更新 manufacturer_id
run_manufacturer_update()
# 统计结果
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
eq_count = cursor.fetchone()['cnt']
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
mf_count = cursor.fetchone()['cnt']
logger.info("=" * 50)
logger.info(f"导入完成!")
logger.info(f" 装备总计: {eq_count}")
for t, c in by_type.items():
logger.info(f" {t}: {c}")
logger.info(f" 生产商: {mf_count}")
logger.info("=" * 50)
return eq_count
if __name__ == '__main__':
import_all()

View File

@ -1,319 +0,0 @@
/*
使
1.
2.
3.
*/
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-- 插入巡飞弹技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES (
1, -- 终结者巡飞弹
0.56,
0.15,
0.20,
2.72,
160.93,
96.56,
24,
15,
'破片杀伤战斗部',
'凭自身动力起飞',
560,
150,
200
);
-- 插入火箭炮技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES (
2, -- 胜利-2火箭炮
10,
2.5,
3.34,
15000,
23
);
-- 插入成本数据(示例数据)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 1000000), -- 终结者巡飞弹成本
(2, 5000000); -- 胜利-2火箭炮成本
-- 插入更多巡飞弹变体数据用于训练
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-- 插入变体技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 终结者-A稍大型号
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
-- 终结者-B稍小型号
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
-- 终结者-C标准型号的改进版
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-- 插入变体成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(3, 1100000), -- 终结者-A成本较高
(4, 900000), -- 终结者-B成本较低
(5, 1050000); -- 终结者-C成本中等
-- 添加更多巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-- 添加它们的技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 哈比
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
-- 游荡者
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
-- 凤凰
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
-- 弹簧刀
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
-- 彩虹-4
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-- 添加成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(6, 800000), -- 哈比
(7, 500000), -- 游荡者
(8, 450000), -- 凤凰
(9, 480000), -- 弹簧刀
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('BM-21', '火箭炮', '俄罗斯'),
('SR5', '火箭炮', '中国'),
('HIMARS', '火箭炮', '美国'),
('LAR-160', '火箭炮', '以色列'),
('T-122', '火箭炮', '土耳其'),
('RM-70', '火箭炮', '捷克'),
('ASTROS II', '火箭炮', '巴西');
-- 火箭炮通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- BM-21
(1, 7.35, 2.4, 3.1, 13700, 20.4),
-- SR5
(2, 10.2, 2.8, 3.2, 28500, 70),
-- HIMARS
(3, 7.0, 2.4, 3.2, 16250, 70),
-- LAR-160
(4, 6.7, 2.5, 2.8, 15000, 45),
-- T-122
(5, 7.2, 2.5, 2.9, 18000, 40),
-- RM-70
(6, 7.5, 2.5, 3.0, 17200, 20.3),
-- ASTROS II
(7, 8.0, 2.7, 3.1, 24500, 90);
-- 火箭炮特有参数
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
firing_angle_vertical,
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire,
combat_weight_kg,
speed_kmh,
min_range_km,
mobility_type,
structure_layout,
engine_model,
engine_params,
power_hp,
travel_range_km
) VALUES
-- BM-21
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
-- SR5
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
-- HIMARS
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
-- LAR-160
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
-- T-122
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
-- RM-70
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
-- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-- 巡飞弹数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('Hero-120', '巡飞弹', '以色列'),
('Switchblade 600', '巡飞弹', '美国'),
('Warmate', '巡飞弹', '波兰'),
('CH-901', '巡飞弹', '中国'),
('HAROP', '巡飞弹', '以色列'),
('Coyote', '巡飞弹', '美国'),
('WS-43', '巡飞弹', '中国');
-- 巡飞弹通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- Hero-120
(8, 1.3, 0.23, 0.23, 12.5, 40),
-- Switchblade 600
(9, 1.3, 0.22, 0.22, 15.0, 40),
-- Warmate
(10, 1.1, 0.15, 0.15, 5.7, 15),
-- CH-901
(11, 1.2, 0.18, 0.18, 9.0, 20),
-- HAROP
(12, 2.5, 0.43, 0.43, 135, 1000),
-- Coyote
(13, 0.9, 0.12, 0.12, 5.9, 20),
-- WS-43
(14, 1.8, 0.35, 0.35, 20, 60);
-- 巡飞弹特有参数
INSERT INTO loitering_munition_params (
equipment_id,
wingspan_m,
warhead_weight_kg,
max_speed_ms,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm,
power_system,
guidance_system
) VALUES
-- Hero-120
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
-- Switchblade 600
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
-- Warmate
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
-- CH-901
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
-- HAROP
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
-- Coyote
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
-- WS-43
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-- 火箭炮
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000), -- ASTROS II
-- 巡飞弹
(8, 150000), -- Hero-120
(9, 180000), -- Switchblade 600
(10, 80000), -- Warmate
(11, 100000), -- CH-901
(12, 850000), -- HAROP
(13, 75000), -- Coyote
(14, 120000); -- WS-43
-- 创建初始数据集
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-- 关联装备到数据集
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 火箭炮训练集
(1, 1), (1, 2), (1, 3), (1, 4),
-- 巡飞弹训练集
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
-- 火箭炮验证集
(3, 5), (3, 6), (3, 7),
-- 巡飞弹验证集
(4, 13), (4, 14);

View File

@ -26,15 +26,15 @@
*/
-- 插入装备基本信息
INSERT INTO equipment (
INSERT INTO equipments (
id, -- 装备ID
name, -- 装备名称
type, -- 装备类型
manufacturer -- 制造商
) VALUES
(1, 'IAI Harop', '巡飞弹', '以色列'),
(2, 'IAI Harpy', '巡飞弹', '以色列'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列'),
(1, 'IAI Harop', '巡飞弹', '以色列 IAI'),
(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'),
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
@ -65,11 +65,11 @@ INSERT INTO equipment (
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'),
(34, 'Shahed-131', '巡飞弹', '伊朗'),
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
(36, 'Shahed-136', '巡飞弹', '伊朗'),
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
(38, 'Shahed-136C', '巡飞弹', '伊朗'),
(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'),
(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'),
(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'),
(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'),
(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'),
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params (
(24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
(28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
(29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
(30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(2, '巡飞弹验证集 2024', '包含20个巡飞弹型号用于验证模型性能', '巡飞弹', '验证');
-- 训练集80个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列8/10
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列
@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列
-- 验证集20个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列2/10
(2, 9), -- Hero-900
(2, 48), -- Rotem L
@ -574,8 +574,8 @@ SET description = '包含20个巡飞弹型号覆盖所有主要制造国
WHERE id = 2;
-- 更新巡飞弹特征工程字段
UPDATE loitering_munition_params l
JOIN common_params c ON l.equipment_id = c.equipment_id
-- 第一步更新基于common_params的特征
UPDATE loitering_munition_params l, common_params c, equipments e
SET
-- 长宽比(反映气动布局特点)
l.length_width_ratio = c.length_m / NULLIF(c.width_m, 0),
@ -586,61 +586,6 @@ SET
-- 速度重量比反映动力性能m/s/kg
l.speed_weight_ratio = l.max_speed_ms / NULLIF(c.weight_kg, 0),
-- 制导系统评分(1-10)
l.guidance_system_score =
CASE
WHEN l.guidance_system LIKE '%卫通%' THEN 10
WHEN l.guidance_system LIKE '%AI辅助%' AND l.guidance_system LIKE '%红外%' THEN 9
WHEN l.guidance_system LIKE '%AI辅助%' THEN 8
WHEN l.guidance_system LIKE '%数据链%' AND l.guidance_system LIKE '%光电%' THEN 7
WHEN l.guidance_system LIKE '%数据链%' THEN 6
WHEN l.guidance_system LIKE '%光电%' THEN 5
WHEN l.guidance_system LIKE '%GPS/INS%' THEN 4
ELSE 3
END,
-- 战斗部威力评分(1-10)
l.warhead_power_score =
CASE
-- 大型战斗部(>30kg
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%模块化%' THEN 10
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%破甲%' THEN 9
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%破片%' THEN 8
-- 中型战斗部10-30kg
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%模块化%' THEN 8
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%破甲%' THEN 7
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%破片%' THEN 6
-- 小型战斗部3-10kg
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%模块化%' THEN 6
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%破甲%' THEN 5
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%破片%' THEN 4
-- 微型战斗部(<3kg
WHEN l.warhead_type LIKE '%破甲%' THEN 3
WHEN l.warhead_type LIKE '%破片%' THEN 2
ELSE 1
END,
-- 发动机功率kW根据重量估算
l.engine_power_kw =
CASE
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
WHEN l.power_system = '涡轮喷气' THEN c.weight_kg * 0.35
ELSE c.weight_kg * 0.2
END,
-- 发动机推力N根据重量估算
l.engine_thrust_n =
CASE
WHEN l.power_system = '电动机' THEN c.weight_kg * 9.8 * 0.3
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 9.8 * 0.4
WHEN l.power_system = '涡轮喷气' THEN c.weight_kg * 9.8 * 0.5
ELSE c.weight_kg * 9.8 * 0.35
END,
-- 最小作战高度m根据体型和任务类型估算
l.min_altitude_m =
CASE
@ -660,10 +605,133 @@ SET
ELSE 30
END,
-- 最大作战高度m根据航程估算)
l.max_altitude_m =
-- 发动机功率kW根据重量估算)
l.engine_power_kw =
CASE
WHEN l.max_range_km > 500 THEN 5000
WHEN l.max_range_km > 100 THEN 3000
WHEN power_system = '电动机' THEN c.weight_kg * 0.15
WHEN power_system = '活塞发动机' THEN c.weight_kg * 0.25
WHEN power_system = '涡轮喷气' THEN c.weight_kg * 0.35
ELSE c.weight_kg * 0.2
END,
-- 发动机推力N根据重量估算
l.engine_thrust_n =
CASE
WHEN power_system = '电动机' THEN c.weight_kg * 9.8 * 0.3
WHEN power_system = '活塞发动机' THEN c.weight_kg * 9.8 * 0.4
WHEN power_system = '涡轮喷气' THEN c.weight_kg * 9.8 * 0.5
ELSE c.weight_kg * 9.8 * 0.35
END
WHERE
l.equipment_id = c.equipment_id
AND l.equipment_id = e.id
AND e.type = '巡飞弹';
-- 第二步:更新基于自身参数的特征
UPDATE loitering_munition_params
SET
-- 制导系统评分(1-10)
guidance_system_score =
CASE
WHEN guidance_system LIKE '%卫通%' THEN 10
WHEN guidance_system LIKE '%AI辅助%' AND guidance_system LIKE '%红外%' THEN 9
WHEN guidance_system LIKE '%AI辅助%' THEN 8
WHEN guidance_system LIKE '%数据链%' AND guidance_system LIKE '%光电%' THEN 7
WHEN guidance_system LIKE '%数据链%' THEN 6
WHEN guidance_system LIKE '%光电%' THEN 5
WHEN guidance_system LIKE '%GPS/INS%' THEN 4
ELSE 3
END,
-- 战斗部威力评分(1-10)
warhead_power_score =
CASE
-- 大型战斗部(>30kg
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%模块化%' THEN 10
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%破甲%' THEN 9
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%破片%' THEN 8
-- 中型战斗部10-30kg
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%模块化%' THEN 8
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%破甲%' THEN 7
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%破片%' THEN 6
-- 小型战斗部3-10kg
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%模块化%' THEN 6
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%破甲%' THEN 5
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%破片%' THEN 4
-- 微型战斗部(<3kg
WHEN warhead_type LIKE '%破甲%' THEN 3
WHEN warhead_type LIKE '%破片%' THEN 2
ELSE 1
END,
-- 数据链范围km
datalink_range_km =
CASE
-- 大型巡飞弹(通常具有卫星通信能力)
WHEN guidance_system LIKE '%卫通%' THEN max_range_km
-- 中大型巡飞弹(具有较强数据链能力)
WHEN guidance_system LIKE '%数据链%' AND max_range_km > 100 THEN LEAST(max_range_km, 200)
-- 中型巡飞弹
WHEN guidance_system LIKE '%数据链%' AND max_range_km > 50 THEN LEAST(max_range_km, 100)
-- 小型巡飞弹
WHEN guidance_system LIKE '%数据链%' THEN LEAST(max_range_km, 50)
-- 无数据链的情况(使用光电或其他制导方式)
ELSE LEAST(max_range_km * 0.5, 30)
END,
-- 最大作战高度m根据航程估算
max_altitude_m =
CASE
WHEN max_range_km > 500 THEN 5000
WHEN max_range_km > 100 THEN 3000
ELSE 1500
END;
END
WHERE equipment_id IN (
SELECT id FROM equipments WHERE type = '巡飞弹'
);
-- 更新巡飞弹的制导精度
UPDATE loitering_munition_params
SET guidance_accuracy_m =
CASE
-- 基础精度(根据制导系统类型)
WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0
WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0
WHEN guidance_system LIKE '%激光制导%' THEN 1.0
WHEN guidance_system LIKE '%红外制导%' THEN 2.0
WHEN guidance_system LIKE '%卫星制导%' THEN 2.5
ELSE 5.0
END *
-- 速度影响因子(速度越快,精度略微降低)
CASE
WHEN max_speed_ms > 200 THEN 1.2
WHEN max_speed_ms > 150 THEN 1.1
WHEN max_speed_ms > 100 THEN 1.0
ELSE 0.9
END *
-- 重量影响因子(重量越大,精度略微降低)
CASE
WHEN warhead_weight_kg > 100 THEN 1.2
WHEN warhead_weight_kg > 50 THEN 1.1
WHEN warhead_weight_kg > 20 THEN 1.0
ELSE 0.9
END *
-- 飞行高度影响因子(高度越高,精度略微降低)
CASE
WHEN ceiling_altitude_m > 5000 THEN 1.2
WHEN ceiling_altitude_m > 3000 THEN 1.1
WHEN ceiling_altitude_m > 1000 THEN 1.0
ELSE 0.9
END
WHERE equipment_id IN (
SELECT id FROM equipments WHERE type = '巡飞弹'
);

80
src/manufacturer_data.sql Normal file
View File

@ -0,0 +1,80 @@
-- 插入供应商数据
INSERT INTO manufacturers (
name, -- 供应商名称
country, -- 所属国家
tech_level, -- 技术水平评分(1-10)
scale_level, -- 规模评分(1-10)
supply_chain_level -- 供应链成熟度评分(1-10)
) VALUES
-- 美国供应商
('美国洛克希德·马丁', '美国', 10, 10, 10), -- 全球最大军工企业
('美国 AeroVironment', '美国', 9, 8, 9), -- 无人机和导弹领域领先
('美国 Raytheon', '美国', 9, 9, 9), -- 导弹技术领先
('美国 AEVEX', '美国', 8, 7, 8), -- 新兴军工企业
('美国 AREA-I', '美国', 8, 7, 8), -- 专注无人机系统
('美国 Northrop Grumman', '美国', 9, 9, 9), -- 大型军工企业
-- 欧洲供应商
('英国 BAE Systems', '英国', 8, 9, 9), -- 欧洲最大军工企业
('英国 MBDA', '英国', 8, 8, 8), -- 导弹系统专家
('德国 KMW', '德国', 9, 8, 9), -- 陆军装备主要供应商
('德国 MBDA', '德国', 8, 8, 8), -- 导弹系统制造商
('德国 Rheinmetall', '德国', 8, 8, 8), -- 综合军工企业
('法国 Nexter', '法国', 8, 8, 8), -- 陆军装备制造商
('法国 MBDA', '法国', 8, 8, 8), -- 导弹系统制造商
('法国 Safran', '法国', 8, 8, 8), -- 航空航天企业
('意大利 Leonardo', '意大利', 7, 7, 7), -- 综合军工企业
('意大利 OTO Melara', '意大利', 7, 7, 7), -- 火炮系统制造商
-- 以色列供应商
('以色列军事工业', '以色列', 9, 7, 7), -- 技术先进
('以色列 IAI', '以色列', 9, 7, 7), -- 航空航天领先
('以色列 UVision', '以色列', 8, 6, 7), -- 无人机专家
-- 中国供应商
('中国兵器工业集团', '中国', 8, 9, 8), -- 陆军装备制造商
('中国航天科工', '中国', 8, 9, 8), -- 导弹制造商
-- 亚洲供应商
('韩国韩华防务', '韩国', 7, 7, 7), -- 韩国主要军工企业
('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业
-- 俄罗斯供应商
('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商
('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商
-- 其他欧洲供应商
('波兰 WB Electronics', '波兰', 6, 6, 6), -- 电子系统制造商
('波兰 WB Group', '波兰', 6, 6, 6), -- 军工集团
('波兰胡塔斯塔洛瓦', '波兰', 6, 6, 6), -- 装备制造商
('瑞典 UMS Skeldar', '瑞典', 7, 6, 7), -- 无人机系统
('瑞典 Saab', '瑞典', 7, 7, 7), -- 综合军工企业
('捷克 RETIA', '捷克', 6, 5, 6), -- 电子系统制造商
('斯洛伐克 ZTS', '斯洛伐克', 5, 5, 5), -- 装备制造商
('捷克 Excalibur Army', '捷克', 6, 5, 6), -- 陆军装备制造商
('克罗地亚 RH ALAN', '克罗地亚', 5, 4, 5), -- 军工企业
('塞尔维亚 Yugoimport', '塞尔维亚', 5, 4, 5), -- 军工出口企业
('芬兰 Patria', '芬兰', 7, 6, 7), -- 装甲车辆制造商
('奥地利 Hirtenberger', '奥地利', 7, 6, 7), -- 火炮系统制造商
-- 其他供应商
('土耳其洛克特桑', '土耳其', 6, 6, 6), -- 新兴军工企业
('土耳其 STM', '土耳其', 6, 6, 6), -- 防务技术公司
('巴西航空工业', '巴西', 6, 6, 5), -- 南美最大军工企业
('印度DRDO', '印度', 5, 5, 5), -- 国防研究机构
('伊朗国防工业', '伊朗', 4, 4, 4), -- 受制裁影响
('埃及 AOI', '埃及', 4, 4, 4), -- 军工企业
('罗马尼亚 ROMARM', '罗马尼亚', 5, 4, 5), -- 国营军工企业
('乌克兰尤日马什', '乌克兰', 6, 5, 5), -- 航天企业
('白俄罗斯国家军工委员会', '白俄罗斯', 5, 5, 5), -- 国家军工管理机构
('阿联酋国际金龙', '阿联酋', 6, 6, 6), -- 新兴军工企业
('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业
-- 更新装备表中的供应商ID
UPDATE equipments e
SET manufacturer_id = (
SELECT id
FROM manufacturers m
WHERE m.name = e.manufacturer
);

File diff suppressed because it is too large Load Diff

View File

@ -1,485 +0,0 @@
-- 清空现有数据
SET FOREIGN_KEY_CHECKS=0;
TRUNCATE TABLE dataset_equipment;
TRUNCATE TABLE datasets;
TRUNCATE TABLE cost_data;
TRUNCATE TABLE loitering_munition_params;
TRUNCATE TABLE common_params;
TRUNCATE TABLE equipment;
SET FOREIGN_KEY_CHECKS=1;
-- 按系列插入装备数据确保ID连续
-- 1. HAROP/Harpy 系列 (ID: 1-3)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1, 'IAI Harop', '巡飞弹', '以色列'),
(2, 'IAI Harpy', '巡飞弹', '以色列'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
-- 2. Hero 系列 (ID: 4-9)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
(9, 'Hero-900', '巡飞弹', '以色列 UVision');
-- 3. Switchblade 系列 (ID: 10-13)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
-- 4. Warmate 系列 (ID: 14-18)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
-- 5. CH-901/902 系列 (ID: 19-23)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(19, 'CH-901', '巡飞弹', '中国航天科工'),
(20, 'CH-901A', '巡飞弹', '中国航天科工'),
(21, 'CH-901H', '巡飞弹', '中国航天科工'),
(22, 'CH-902', '巡飞弹', '中国航天科工'),
(23, 'CH-902A', '巡飞弹', '中国航天科工');
-- 6. WS-43/61 系列 (ID: 24-28)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(24, 'WS-43', '巡飞弹', '中国航天科工'),
(25, 'WS-43A', '巡飞弹', '中国航天科工'),
(26, 'WS-43B', '巡飞弹', '中国航天科工'),
(27, 'WS-61', '巡飞弹', '中国航天科工'),
(28, 'WS-61A', '巡飞弹', '中国航天科工');
-- 7. Kargu/Alpagu 系列 (ID: 29-33)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(29, 'Kargu', '巡飞弹', '土耳其 STM'),
(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
-- 8. Shahed 系列 (ID: 34-38)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(34, 'Shahed-131', '巡飞弹', '伊朗'),
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
(36, 'Shahed-136', '巡飞弹', '伊朗'),
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
(38, 'Shahed-136C', '巡飞弹', '伊朗');
-- 9. Green Dragon 系列 (ID: 39-43)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
-- 10. Phoenix Ghost 系列 (ID: 44-48)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
-- 11. ZALA Lancet 系列 (ID: 49-52)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
-- 12. Rotem L 系列 (ID: 53-56)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
-- 13. KUB-BLA 系列 (ID: 57-60)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
-- 插入通用参数
INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
(10, 0.58, 0.12, 0.12, 2.5, 10),
(11, 1.30, 0.22, 0.22, 15.0, 40),
(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
(14, 0.68, 0.12, 0.12, 2.5, 10),
(15, 1.30, 0.22, 0.22, 15.0, 40),
(16, 0.68, 0.12, 0.12, 2.5, 10),
(17, 1.30, 0.22, 0.22, 15.0, 40),
(18, 0.68, 0.12, 0.12, 2.5, 10),
(19, 1.2, 0.18, 0.18, 9.0, 20),
(20, 1.2, 0.18, 0.18, 9.3, 25),
(21, 1.2, 0.18, 0.18, 9.5, 20),
(22, 1.4, 0.22, 0.22, 15.0, 30),
(23, 1.4, 0.22, 0.22, 15.5, 35),
(24, 1.8, 0.35, 0.35, 20, 60),
(25, 1.8, 0.35, 0.35, 21, 70),
(26, 1.9, 0.35, 0.35, 22, 80),
(27, 2.2, 0.40, 0.40, 35, 100),
(28, 2.2, 0.40, 0.40, 37, 120),
(29, 0.6, 0.35, 0.35, 7.0, 10),
(30, 0.6, 0.35, 0.35, 7.2, 15),
(31, 1.0, 0.23, 0.23, 3.7, 5),
(32, 1.0, 0.23, 0.23, 3.9, 8),
(33, 0.6, 0.35, 0.35, 7.5, 15),
(34, 2.6, 0.34, 0.34, 135, 900),
(35, 2.6, 0.34, 0.34, 140, 1000),
(36, 3.5, 0.42, 0.42, 200, 2000),
(37, 3.5, 0.42, 0.42, 210, 2200),
(38, 3.5, 0.42, 0.42, 215, 2500),
(39, 1.5, 0.20, 0.20, 15, 40),
(40, 1.6, 0.20, 0.20, 16, 50),
(41, 1.5, 0.20, 0.20, 15.5, 45),
(42, 1.5, 0.20, 0.20, 15.8, 40),
(43, 1.2, 0.18, 0.18, 12, 30),
(44, 1.5, 0.25, 0.25, 14.0, 30),
(45, 1.5, 0.25, 0.25, 14.5, 35),
(46, 1.6, 0.26, 0.26, 15.0, 40),
(47, 1.5, 0.25, 0.25, 14.8, 30),
(48, 1.7, 0.27, 0.27, 16.0, 50),
(49, 1.0, 0.20, 0.20, 5.0, 40),
(50, 1.65, 0.35, 0.35, 12.0, 70),
(51, 1.65, 0.35, 0.35, 12.5, 80),
(52, 1.80, 0.40, 0.40, 15.0, 100),
(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
-- 插入特有参数
INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
endurance_min,
warhead_type,
launch_mode,
power_system,
guidance_system
) VALUES
-- HAROP/Harpy系列
(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
-- Hero系列
(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
-- Switchblade系列
(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-- Warmate系列
(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-- CH-901/902系列
(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
-- 插入成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 800000), -- IAI Harop
(2, 700000), -- IAI Harpy
(3, 350000), -- IAI Mini Harpy
(4, 70000), -- Hero-30
(5, 120000), -- Hero-70
(6, 150000), -- Hero-120
(7, 300000), -- Hero-250
(8, 400000), -- Hero-400EC
(9, 650000), -- Hero-900
(10, 60000), -- Switchblade 300
(11, 180000), -- Switchblade 600
(12, 75000), -- Switchblade 300 Block 10
(13, 200000), -- Switchblade 600 Extended Range
(14, 60000), -- Warmate 1.0
(15, 180000), -- Warmate 2.0
(16, 60000), -- Warmate-V
(17, 180000), -- Warmate-L
(18, 60000), -- Warmate 3.0
(19, 100000), -- CH-901
(20, 120000), -- CH-901A
(21, 130000), -- CH-901H
(22, 180000), -- CH-902
(23, 200000), -- CH-902A
(24, 120000), -- WS-43
(25, 150000), -- WS-43A
(26, 180000), -- WS-43B
(27, 300000), -- WS-61
(28, 350000), -- WS-61A
(29, 70000), -- Kargu
(30, 85000), -- Kargu-2
(31, 45000), -- Alpagu
(32, 55000), -- Alpagu Block-II
(33, 95000), -- Kargu Autonomous
(34, 20000), -- Shahed-131
(35, 25000), -- Shahed-131B
(36, 40000), -- Shahed-136
(37, 45000), -- Shahed-136B
(38, 50000), -- Shahed-136C
(39, 160000), -- Green Dragon
(40, 200000), -- Green Dragon Extended Range
(41, 180000), -- Green Dragon Block 2
(42, 190000), -- Green Dragon Maritime
(43, 140000), -- Green Dragon-S
(44, 150000), -- Phoenix Ghost
(45, 180000), -- Phoenix Ghost Block I
(46, 220000), -- Phoenix Ghost Block II
(47, 190000), -- Phoenix Ghost Maritime
(48, 250000), -- Phoenix Ghost-ER
(49, 80000), -- Lancet-1
(50, 150000), -- Lancet-3
(51, 180000), -- Lancet-3M
(52, 250000), -- Lancet-4
(53, 65000), -- Rotem L
(54, 85000), -- Rotem L-X
(55, 75000), -- Rotem L-M
(56, 95000), -- Rotem L-ER
(57, 95000), -- KUB-BLA
(58, 120000), -- KUB-BLA-E
(59, 110000), -- KUB-BLA-M
(60, 150000); -- KUB-BLA-ER
-- 创建数据集
INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
-- 关联装备到数据集(按照制造商和型号分配)
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 训练集约80%的数据48个型号
-- 以色列系列
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
-- 美国系列
(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
-- 中国系列
(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
-- 波兰和土耳其系列
(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
-- 俄罗斯系列
(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
-- 验证集约20%的数据12个型号
-- 混合系列
(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
-- 添加分类特征编码
INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
-- 战斗部类型编码
('warhead_type', '破片杀伤战斗部', 1),
('warhead_type', '破甲战斗部', 2),
('warhead_type', '高爆战斗部', 3),
('warhead_type', '破片杀伤/破甲双用战斗部', 4),
('warhead_type', '模块化战斗部', 5),
-- 发射方式编码
('launch_mode', '箱式发射', 1),
('launch_mode', '弹射式发射', 2),
('launch_mode', '垂直起降', 3),
('launch_mode', '单兵发射管', 4),
('launch_mode', '箱式发射/弹射式', 5),
('launch_mode', '箱式发射/空中发射', 6),
-- 动力装置编码(按复杂度递增)
('power_system', '电动机', 1),
('power_system', '活塞发动机', 2),
-- 制导系统编码(按复杂度递增)
('guidance_system', 'GPS/INS', 1),
('guidance_system', 'GPS/INS/光电', 2),
('guidance_system', 'GPS/INS/光电/数据链', 3),
('guidance_system', 'GPS/INS/光电/AI识别', 4),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
UPDATE loitering_munition_params l
JOIN common_params c ON l.equipment_id = c.equipment_id
SET
-- 新增关键参数
l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
l.engine_power_kw =
CASE
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
END,
l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
l.guidance_accuracy_m =
CASE
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
ELSE 3.0
END,
l.min_altitude_m = -- 最小作战高度
CASE
-- 大型巡飞弹(体型大、重量大)
WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
-- 中型巡飞弹
WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
-- 中小型巡飞弹
WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
-- 小型巡飞弹
WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
-- 超小型巡飞弹
WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
-- 其他型号使用默认值
ELSE 50
END,
l.max_altitude_m =
CASE
WHEN c.max_range_km > 500 THEN 5000
WHEN c.max_range_km > 100 THEN 3000
ELSE 1500
END,
-- 特征工程字段
l.length_width_ratio = c.length_m / c.width_m,
l.weight_range_ratio = c.weight_kg / c.max_range_km,
l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
l.guidance_system_score =
CASE
WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
ELSE 2
END,
l.warhead_power_score =
CASE
WHEN l.warhead_type = '模块化战斗部' THEN 10
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
WHEN l.warhead_type = '高爆战斗部' THEN 7
WHEN l.warhead_type = '破甲战斗部' THEN 6
WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
ELSE 4
END,
-- 分类特征编码
l.warhead_type_code =
CASE
WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
WHEN l.warhead_type = '破甲战斗部' THEN 2
WHEN l.warhead_type = '高爆战斗部' THEN 3
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
WHEN l.warhead_type = '模块化战斗部' THEN 5
ELSE 0
END,
l.launch_mode_code =
CASE
WHEN l.launch_mode = '箱式发射' THEN 1
WHEN l.launch_mode = '弹射式发射' THEN 2
WHEN l.launch_mode = '垂直起降' THEN 3
WHEN l.launch_mode = '单兵发射管' THEN 4
WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
ELSE 0
END,
l.power_system_code =
CASE
WHEN l.power_system = '电动机' THEN 1
WHEN l.power_system = '活塞发动机' THEN 2
ELSE 0
END,
l.guidance_system_code =
CASE
WHEN l.guidance_system = 'GPS/INS' THEN 1
WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
ELSE 0
END;

View File

@ -29,7 +29,7 @@
*/
-- 中国系列火箭炮数据
INSERT INTO equipment (id, name, type, manufacturer) VALUES
INSERT INTO equipments (id, name, type, manufacturer) VALUES
(1001, 'PCL-191', '火箭炮', '中国兵器工业集团'),
(1002, 'PHL-03', '火箭炮', '中国兵器工业集团'),
(1003, 'AR-3', '火箭炮', '中国航天科工'),
@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1007, 'WS-2', '火箭炮', '中国航天科工'),
(1008, 'WS-3', '火箭炮', '中国航天科工'),
(1009, 'Type 63', '火箭炮', '中国兵器工业集团'),
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'),
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'),
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'),
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'),
(1014, 'TOS-1A', '火箭炮', '俄罗斯'),
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'),
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'),
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'),
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'),
(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'),
(1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'),
(1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'),
(1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'),
@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'),
(1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'),
(1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'),
(1033, 'Pinaka', '火箭炮', '印度DRDO'),
(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'),
(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'),
(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'),
(1033, 'Pinaka', '火箭炮', '印度 DRDO'),
(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'),
(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'),
(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'),
(1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'),
@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(4, '火箭炮验证集 2024', '包含19个火箭炮型号用于验证模型性能', '火箭炮', '验证');
-- 训练集约80%的数据77个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列7/9
(3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007),
@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(3, 1094), (3, 1095);
-- 验证集约20%的数据19个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列2/9
(4, 1008), (4, 1009),

File diff suppressed because it is too large Load Diff

View File

@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci;
USE equipment_cost_db;
-- 装备基本信息表
CREATE TABLE equipment (
CREATE TABLE equipments (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- 名称
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
manufacturer VARCHAR(100), -- 制造商
manufacturer_id INT, -- 制造商ID
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
@ -26,8 +27,7 @@ CREATE TABLE common_params (
width_m FLOAT, -- 宽度(m)
height_m FLOAT, -- 高度(m)
weight_kg FLOAT, -- 重量(kg)
max_range_km FLOAT, -- 最大射程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 火箭炮特有参数表
@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params (
deployment_score INT, -- 部署评分(1-10)
terrain_adaptability_score INT, -- 地形适应性评分(1-10)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 巡飞弹特有参数表
@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params (
power_system_code INT, -- 动力装置编码
guidance_system_code INT, -- 制导系统编码
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 分类特征编码表
@ -122,7 +122,7 @@ CREATE TABLE cost_data (
actual_cost DECIMAL(15,2), -- 实际成本(元)
predicted_cost DECIMAL(15,2), -- 预测成本(元)
prediction_date TIMESTAMP, -- 预测日期
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 特殊参数表
@ -133,12 +133,12 @@ CREATE TABLE custom_params (
param_value VARCHAR(255), -- 参数值
param_unit VARCHAR(50), -- 参数单位
description TEXT, -- 参数说明
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_equipment_type ON equipments(type);
CREATE INDEX idx_equipment_name ON equipments(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- 数据集表
@ -153,12 +153,12 @@ CREATE TABLE datasets (
);
-- 数据集-装备关联表
CREATE TABLE dataset_equipment (
CREATE TABLE dataset_equipments (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
-- 训练模型表
@ -175,10 +175,34 @@ CREATE TABLE trained_models (
feature_importance JSON, -- 特征重要性
training_data_size INT, -- 训练数据量
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
created_by VARCHAR(50) -- 创建者
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);
CREATE INDEX idx_model_active ON trained_models(is_active);
-- 生产商表
CREATE TABLE manufacturers (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL, -- 生产商名称
country VARCHAR(50) NOT NULL, -- 所属国家
tech_level INT NOT NULL, -- 技术水平评分(1-10)
scale_level INT NOT NULL, -- 规模评分(1-10)
supply_chain_level INT NOT NULL, -- 供应链成熟度评分(1-10)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_name (name)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加生产商外键
ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id);
-- 添加索引
CREATE INDEX idx_manufacturer_country ON manufacturers(country);
CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level);
CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level);
CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id);

5
src/start.bat Normal file
View File

@ -0,0 +1,5 @@
@echo off
set FLASK_DEBUG=false
echo Starting Cost Prediction System...
start /B run.exe
start http://localhost:5001

View File

@ -147,14 +147,124 @@ def test_api_endpoints():
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 8. 测试多模型预测接口
logger.info("\n8. 测试多模型预测接口")
# 8. 测试预测接口
logger.info("\n8. 测试预测接口")
# 8.1 测试普通预测接口
logger.info("8.1 测试普通预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict/all',
f'{base_url}/predict',
json=predict_data
)
print_response(response, "多模型预测")
print_response(response, "普通预测")
# 8.2 测试 PLS 预测接口
logger.info("8.2 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print_response(response, "PLS 预测")
# 9. 测试生产商分析接口
logger.info("\n9. 测试生产商分析接口")
manufacturer_data = {
"dataset_id": 1 # 使用已存在的数据集ID
}
response = requests.post(
f'{base_url}/analyze-manufacturers',
json=manufacturer_data
)
print_response(response, "生产商分析")
# 10. 测试模型激活接口
logger.info("\n10. 测试模型激活接口")
# 假设存在ID为1的模型
response = requests.post(f'{base_url}/models/1/activate')
print_response(response, "模型激活")
# 11. 测试获取最新模型接口
logger.info("\n11. 测试获取最新模型接口")
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 12. 测试数据集详情接口
logger.info("\n12. 测试数据集详情接口")
response = requests.get(f'{base_url}/datasets/1') # 假设存在ID为1的数据集
print_response(response, "数据集详情")
# 13. 测试更新数据集接口
logger.info("\n13. 测试更新数据集接口")
if available_equipment_ids:
update_dataset_data = {
"name": "更新后的测试数据集",
"description": "用于测试的更新数据集",
"equipment_type": "巡飞弹",
"purpose": "测试",
"equipment_ids": available_equipment_ids[:2] # 使用前两个可用的装备ID
}
response = requests.put(
f'{base_url}/datasets/1', # 假设更新ID为1的数据集
json=update_dataset_data
)
print_response(response, "更新数据集")
else:
logger.warning("没有可用的装备ID跳过数据集更新测试")
# 14. 测试装备详情接口
logger.info("\n14. 测试装备详情接口")
if available_equipment_ids:
response = requests.get(f'{base_url}/data/details/{available_equipment_ids[0]}')
print_response(response, "装备详情")
# 15. 测试更新装备接口
logger.info("\n15. 测试更新装备接口")
if available_equipment_ids:
equipment_update_data = {
"equipment_id": available_equipment_ids[0],
"name": "更新后的装备名称",
"type": "巡飞弹",
"manufacturer": "测试厂商",
"length_m": 1.5,
"width_m": 0.3,
"height_m": 0.3,
"weight_kg": 15.0,
"wingspan_m": 0.8,
"warhead_weight_kg": 5.0,
"max_speed_ms": 60,
"cruise_speed_kmh": 120,
"endurance_min": 45,
"max_range_km": 50,
"warhead_type": "高爆战斗部",
"launch_mode": "弹射起飞",
"power_system": "涡轮发动机",
"guidance_system": "GPS/INS组合导航"
}
response = requests.put(
f'{base_url}/data/{available_equipment_ids[0]}',
json=equipment_update_data
)
print_response(response, "更新装备")
logger.info("所有测试完成")
except requests.exceptions.RequestException as e:

40
tests/test_demo_routes.py Normal file
View File

@ -0,0 +1,40 @@
from src import create_app
def test_demo_algorithms_route_returns_available_models():
app = create_app()
client = app.test_client()
response = client.get("/api/demo/algorithms")
assert response.status_code == 200
payload = response.get_json()
assert any(item["key"] == "random_forest" for item in payload["algorithms"])
def test_demo_dataset_route_returns_local_file_summary():
app = create_app()
client = app.test_client()
response = client.get("/api/demo/dataset")
assert response.status_code == 200
payload = response.get_json()
assert payload["source"] == "local-file"
assert payload["row_count"] >= 20
def test_demo_run_route_returns_metrics_without_mysql():
app = create_app()
client = app.test_client()
response = client.post(
"/api/demo/run",
json={"algorithms": ["linear", "random_forest"]},
)
assert response.status_code == 200
payload = response.get_json()
assert payload["source"] == "local-file"
assert set(payload["metrics"]) == {"linear", "random_forest"}
assert payload["prediction_points"]

View File

@ -0,0 +1,49 @@
from pathlib import Path
from src.demo_service import DemoModelService
def test_demo_service_loads_local_dataset():
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
summary = service.get_dataset_summary()
assert summary["row_count"] >= 20
assert "actual_cost" in summary["columns"]
assert summary["target"] == "actual_cost"
assert summary["preview"][0]["name"]
assert summary["preview"][0]["type"] in {"巡飞弹", "火箭炮"}
def test_demo_service_returns_chinese_algorithm_names_with_english_notes():
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
algorithms = service.get_algorithms()
linear = next(item for item in algorithms if item["key"] == "linear")
assert linear["name"] == "线性回归"
assert linear["english_name"] == "Linear Regression"
assert linear["family"] == "线性模型"
def test_demo_service_runs_multiple_algorithms():
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
result = service.run_demo(["linear", "random_forest", "gradient_boosting"])
assert result["source"] == "local-file"
assert result["best_model"] in result["metrics"]
assert len(result["metrics"]) == 3
assert len(result["prediction_points"]) > 0
assert len(result["sample_prediction"]["predictions"]) == 3
for metrics in result["metrics"].values():
assert {"r2", "mae", "rmse"}.issubset(metrics)
def test_demo_service_ignores_unavailable_algorithms():
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
result = service.run_demo(["linear", "does_not_exist"])
assert list(result["metrics"].keys()) == ["linear"]
assert result["warnings"]