Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 48ba547c36 | |||
| 137451ba7a | |||
| 485b4e497a | |||
| 1c91ba42ab | |||
| 04e43760ae | |||
| 092398c6df | |||
| e0249d65d8 | |||
| ed91df6607 | |||
| 047cafa7c7 | |||
| 8cd0bb5c06 | |||
| e67da8eaed | |||
| dba9f2fcc9 | |||
| 9421512677 | |||
| 38c543d599 | |||
| 06aac27f20 |
29
.claude/settings.local.json
Normal file
29
.claude/settings.local.json
Normal file
File diff suppressed because one or more lines are too long
31
.cursorrules
31
.cursorrules
@ -1,4 +1,3 @@
|
||||
|
||||
# 代码修改最佳实践
|
||||
|
||||
1. 修改前的准备
|
||||
@ -56,3 +55,33 @@
|
||||
- 处理异常情况
|
||||
- 保护敏感信息
|
||||
- 添加访问控制
|
||||
|
||||
9. 中文处理规则
|
||||
|
||||
- 不修改任何包含中文的注释行
|
||||
- 使用 `// ... existing code ...` 跳过包含中文的代码块
|
||||
- 如需修改中文附近的代码,应完整保留原有中文内容
|
||||
|
||||
示例:
|
||||
```
|
||||
// ... existing code ...
|
||||
// 这是中文注释,保持不变
|
||||
newCode = value;
|
||||
// ... existing code ...
|
||||
```
|
||||
|
||||
10. 编码规则
|
||||
|
||||
- 所有文件统一使用 UTF-8 编码
|
||||
- 不使用 BOM 头
|
||||
- 换行符统一使用 LF (\n)
|
||||
- 文件末尾保留一个换行符
|
||||
- 代码注释中的中文必须使用 UTF-8 编码
|
||||
|
||||
配置示例:
|
||||
```json
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"end_of_line": "lf",
|
||||
"insert_final_newline": true
|
||||
}
|
||||
25
.env.example
25
.env.example
@ -1,25 +0,0 @@
|
||||
# 数据库配置
|
||||
MYSQL_HOST=localhost
|
||||
MYSQL_USER=root
|
||||
MYSQL_PASSWORD=your_password_here
|
||||
MYSQL_DATABASE=equipment_cost_db
|
||||
|
||||
# 服务配置
|
||||
PORT=5001
|
||||
DEBUG=False
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL=INFO
|
||||
LOG_DIR=logs
|
||||
|
||||
# 模型配置
|
||||
MODEL_DIR=models
|
||||
DATA_DIR=data
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY=your_secret_key_here
|
||||
ALLOWED_HOSTS=localhost,127.0.0.1
|
||||
|
||||
# 其他配置
|
||||
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
|
||||
ALLOWED_FILE_TYPES=.xlsx,.xls
|
||||
49
.gitignore
vendored
49
.gitignore
vendored
@ -1,14 +1,54 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual Environment
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
node_modules
|
||||
/dist
|
||||
/models
|
||||
/logs
|
||||
/uploads
|
||||
/data
|
||||
/data/*
|
||||
!/data/demo_equipment_costs.csv
|
||||
|
||||
# local env files
|
||||
.env.local
|
||||
.env.*.local
|
||||
.venv
|
||||
|
||||
# Log files
|
||||
npm-debug.log*
|
||||
@ -25,3 +65,10 @@ pnpm-debug.log*
|
||||
*.sln
|
||||
*.sw?
|
||||
/frontend/node_modules
|
||||
/frontend/dist
|
||||
/release/
|
||||
|
||||
*.zip
|
||||
*.tar
|
||||
*.gz
|
||||
*.whl
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.11.8
|
||||
108
CLAUDE.md
Normal file
108
CLAUDE.md
Normal file
@ -0,0 +1,108 @@
|
||||
# CLAUDE.md
|
||||
|
||||
本文件为 Claude Code(claude.ai/code)在此仓库中工作提供指引。
|
||||
|
||||
## 常用命令
|
||||
|
||||
```bash
|
||||
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch)
|
||||
pip install -e .
|
||||
pip install -e ".[dev]" # 含开发工具(pytest, black, mypy)
|
||||
pip install -e ".[torch]" # 安装可选 PyTorch(神经网络训练需要)
|
||||
|
||||
# 启动后端(Flask,0.0.0.0:5001)
|
||||
python run.py
|
||||
|
||||
# 启动前端开发服务器(另开终端,localhost:3000)
|
||||
cd frontend && npm install && npm run serve
|
||||
|
||||
# 构建前端生产版本
|
||||
cd frontend && npm run build
|
||||
|
||||
# 运行测试
|
||||
python -m pytest tests/
|
||||
python -m pytest tests/test_demo_service.py -q # 单个测试文件
|
||||
python -m pytest tests/test_demo_routes.py -q
|
||||
|
||||
# 代码格式化 / 类型检查
|
||||
black src/ tests/ # 自动格式化(line-length=88)
|
||||
mypy src/ # 类型检查
|
||||
|
||||
# 运行独立演示模式(仅需 5 个依赖)
|
||||
cd demo_standalone && pip install -r requirements.txt && python server.py
|
||||
```
|
||||
|
||||
## 项目概述
|
||||
|
||||
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
|
||||
|
||||
### 技术栈
|
||||
|
||||
- **后端**:Python 3.9-3.11, Flask 3.1+
|
||||
- **数据库**:SQLite(Python 内置,零外部依赖),首次启动自动建表,无需手动安装配置
|
||||
- **机器学习**:scikit-learn(核心)、XGBoost、LightGBM
|
||||
- **可选依赖**:PyTorch(仅神经网络训练需要,约 800MB)
|
||||
- **前端**:Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
|
||||
|
||||
### 关键文件
|
||||
|
||||
| 文件 | 用途 |
|
||||
|---|---|
|
||||
| `run.py` | 入口点,启动 Flask 服务器 |
|
||||
| `src/app.py` | Flask 应用工厂 |
|
||||
| `src/routes.py` | 所有 API 路由(约 1300 行) |
|
||||
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModel(PyTorch NN) |
|
||||
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
|
||||
| `src/cost_prediction.py` | CostPredictor(预测编排) |
|
||||
| `src/demo_service.py` | DemoModelService(基于 CSV,无需数据库) |
|
||||
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL(内置) |
|
||||
| `config.py` | 运行时配置 |
|
||||
| `frontend/src/router/index.js` | Vue Router,共 8 个路由 |
|
||||
| `frontend/src/api/index.js` | Axios API 客户端 |
|
||||
|
||||
## 架构
|
||||
|
||||
### 与旧架构的核心变化
|
||||
|
||||
| 项目 | 旧架构(MySQL) | 新架构(SQLite) |
|
||||
|------|----------------|-----------------|
|
||||
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLite,Python 内置,零依赖 |
|
||||
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3) |
|
||||
| PyTorch | 硬依赖(顶层 import,无则崩溃) | 可选依赖(try/except 保护,无 PyTorch 可启动) |
|
||||
| 前端构建 | Vue CLI + Babel + SCSS | Vite(无需 Babel/SCSS) |
|
||||
| Vuex | 存在但完全空 | 已移除 |
|
||||
| 依赖总数(核心)| 15+ | 5(flask, numpy, pandas, scikit-learn, openpyxl) |
|
||||
|
||||
### 训练数据流
|
||||
```
|
||||
用户界面 → POST /api/train → 查询 SQLite → DataPreparation(特征提取 + 标准化)
|
||||
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS)
|
||||
→ 保存最优模型 → 写入 SQLite trained_models 表
|
||||
```
|
||||
|
||||
### 预测数据流
|
||||
```
|
||||
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
|
||||
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
|
||||
```
|
||||
|
||||
### 机器学习模型
|
||||
|
||||
| 模型 | 标识 | 说明 |
|
||||
|---|---|---|
|
||||
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
|
||||
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
|
||||
| Random Forest | `rf` | sklearn.ensemble |
|
||||
| Gradient Boosting | `gbm` | sklearn.ensemble |
|
||||
| PyTorch NN | `pytorch` | 可选(需安装 torch),针对不同装备类型定制网络结构 |
|
||||
| PLS 回归 | `pls` | 不参与最优模型评选 |
|
||||
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
|
||||
|
||||
## 重要注意事项
|
||||
|
||||
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
|
||||
- **两种装备类型**:火箭炮(27+ 特征)和巡飞弹(24+ 特征),各自有独立数据库表和神经网络结构
|
||||
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
|
||||
- **SQLite 首次使用**:`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
|
||||
- **编码规范**:UTF-8 无 BOM,LF 换行符,文件末尾保留换行符
|
||||
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
# MIT License
|
||||
|
||||
Copyright (c) 2024 Your Name or Your Organization
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
111
README.md
111
README.md
@ -1,23 +1,102 @@
|
||||
# 数据库配置说明
|
||||
# 装备成本预测系统
|
||||
|
||||
本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要:
|
||||
基于机器学习的装备成本预测系统,支持多种预测模型和数据分析功能。
|
||||
|
||||
1. 创建数据库用户
|
||||
## 功能特性
|
||||
|
||||
```sql
|
||||
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password';
|
||||
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost';
|
||||
FLUSH PRIVILEGES;
|
||||
- 多模型成本预测
|
||||
- 机器学习模型 (XGBoost, LightGBM, RandomForest)
|
||||
- PLS 回归模型
|
||||
- 特征分析与数据可视化
|
||||
- 生产商分析
|
||||
- 数据集管理
|
||||
- 模型训练与评估
|
||||
|
||||
## 系统要求
|
||||
|
||||
- Python >= 3.9, < 3.12
|
||||
- MySQL >= 8.0
|
||||
- 其他依赖见 pyproject.toml
|
||||
|
||||
## 快速开始
|
||||
|
||||
1. 克隆项目
|
||||
|
||||
```bash
|
||||
git clone [repository-url]
|
||||
cd cost-prediction
|
||||
```
|
||||
|
||||
2. 配置数据库字符集
|
||||
确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置:
|
||||
2. 安装依赖
|
||||
|
||||
```ini
|
||||
[mysqld]
|
||||
character-set-server=utf8mb4
|
||||
collation-server=utf8mb4_unicode_ci
|
||||
|
||||
[client]
|
||||
default-character-set=utf8mb4
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
3. 配置数据库
|
||||
|
||||
```bash
|
||||
[Windows]
|
||||
scripts/setup_env.ps1
|
||||
|
||||
[Linux/macOS]
|
||||
scripts/setup_env.sh
|
||||
```
|
||||
|
||||
4. 运行系统
|
||||
|
||||
```bash
|
||||
python run.py
|
||||
```
|
||||
|
||||
## API 文档
|
||||
|
||||
### 预测接口
|
||||
|
||||
- POST `/api/predict` - 使用最优机器学习模型预测
|
||||
- POST `/api/pls/predict` - 使用 PLS 模型预测
|
||||
|
||||
### 数据管理
|
||||
|
||||
- GET `/api/data` - 获取装备数据列表
|
||||
- GET `/api/data/details/<id>` - 获取装备详情
|
||||
- PUT `/api/data/<id>` - 更新装备数据
|
||||
|
||||
### 数据集管理
|
||||
|
||||
- GET `/api/datasets` - 获取数据集列表
|
||||
- POST `/api/datasets` - 创建数据集
|
||||
- GET `/api/datasets/<id>` - 获取数据集详情
|
||||
- PUT `/api/datasets/<id>` - 更新数据集
|
||||
- DELETE `/api/datasets/<id>` - 删除数据集
|
||||
|
||||
### 模型管理
|
||||
|
||||
- GET `/api/models` - 获取模型列表
|
||||
- POST `/api/train` - 训练模型
|
||||
- POST `/api/models/<id>/activate` - 激活模型
|
||||
- DELETE `/api/models/<id>` - 删除模型
|
||||
|
||||
### 分析功能
|
||||
|
||||
- POST `/api/analyze-features` - 特征分析
|
||||
- POST `/api/analyze-manufacturers` - 生产商分析
|
||||
|
||||
## 开发指南
|
||||
|
||||
详细的开发文档请参考 `docs/dev/` 目录:
|
||||
|
||||
- requirements.md - 项目需求文档
|
||||
- debug.md - 调试指南
|
||||
|
||||
## 测试
|
||||
|
||||
运行测试:
|
||||
|
||||
```bash
|
||||
python src/test_api.py
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 [LICENSE](LICENSE) 许可证。
|
||||
|
||||
126
config.py
126
config.py
@ -1,32 +1,100 @@
|
||||
import os
|
||||
import secrets
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
|
||||
class Config:
|
||||
"""配置类"""
|
||||
# 数据库配置(使用 SQLite)
|
||||
SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
|
||||
|
||||
# Flask配置
|
||||
FLASK_HOST = '0.0.0.0'
|
||||
FLASK_PORT = 5001
|
||||
FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'
|
||||
|
||||
# 目录配置
|
||||
MODEL_DIR = 'models'
|
||||
DATA_DIR = 'data'
|
||||
LOG_DIR = 'logs'
|
||||
UPLOAD_DIR = 'uploads'
|
||||
TEMPLATE_DIR = 'templates'
|
||||
|
||||
# 文件上传配置
|
||||
ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
|
||||
|
||||
# API配置
|
||||
API_VERSION = 'v1'
|
||||
API_PREFIX = f'/api/{API_VERSION}'
|
||||
|
||||
# 日志配置
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOG_LEVEL = 'INFO'
|
||||
LOG_FILE = os.path.join(LOG_DIR, 'app.log')
|
||||
LOG_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
LOG_BACKUP_COUNT = 5
|
||||
|
||||
# PyTorch配置
|
||||
DEVICE = 'cpu' # 或 'cuda' 如果要使用 GPU
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 0.001
|
||||
NUM_EPOCHS = 100
|
||||
|
||||
# 模型训练配置
|
||||
TRAIN_TEST_SPLIT = 0.2
|
||||
RANDOM_SEED = 42
|
||||
EARLY_STOPPING_PATIENCE = 10
|
||||
MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')
|
||||
|
||||
# 缓存配置
|
||||
CACHE_TYPE = 'simple'
|
||||
CACHE_DEFAULT_TIMEOUT = 300
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here')
|
||||
JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-jwt-secret-key-here')
|
||||
JWT_ACCESS_TOKEN_EXPIRES = 3600 # 1小时
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']
|
||||
|
||||
# 数据验证配置
|
||||
MAX_EQUIPMENT_NAME_LENGTH = 100
|
||||
MAX_MANUFACTURER_NAME_LENGTH = 100
|
||||
|
||||
@classmethod
|
||||
def init_app(cls, app):
|
||||
"""初始化应用配置"""
|
||||
# 创建必要的目录
|
||||
for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
|
||||
cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
# 配置日志
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
formatter = logging.Formatter(cls.LOG_FORMAT)
|
||||
file_handler = RotatingFileHandler(
|
||||
cls.LOG_FILE,
|
||||
maxBytes=cls.LOG_MAX_SIZE,
|
||||
backupCount=cls.LOG_BACKUP_COUNT
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(cls.LOG_LEVEL)
|
||||
|
||||
app.logger.addHandler(file_handler)
|
||||
app.logger.setLevel(cls.LOG_LEVEL)
|
||||
|
||||
# 配置上传目录
|
||||
app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
|
||||
app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH
|
||||
|
||||
# 配置跨域
|
||||
from flask_cors import CORS
|
||||
CORS(app, resources={
|
||||
r"/api/*": {"origins": cls.CORS_ORIGINS}
|
||||
})
|
||||
|
||||
return app
|
||||
|
||||
# 安全密钥配置(自动生成随机密钥)
|
||||
SECRET_KEY = secrets.token_hex(16)
|
||||
|
||||
# 环境配置
|
||||
DEBUG = False
|
||||
ENV = 'production'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
|
||||
|
||||
# API配置
|
||||
API_VERSION = 'v1'
|
||||
API_PREFIX = f'/api/{API_VERSION}'
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS = [
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
]
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL = 'DEBUG'
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
|
||||
# 创建配置实例
|
||||
config = Config()
|
||||
27
data/demo_equipment_costs.csv
Normal file
27
data/demo_equipment_costs.csv
Normal file
@ -0,0 +1,27 @@
|
||||
name,type,length_m,width_m,height_m,weight_kg,max_range_km,payload_kg,max_speed_kmh,endurance_min,tech_level,scale_level,supply_chain_level,complexity_score,actual_cost
|
||||
隼击-A,巡飞弹,1.2,1.8,0.32,18,35,4,145,55,6.4,5.8,6.2,5.9,420000
|
||||
隼击-B,巡飞弹,1.5,2.1,0.36,26,48,6,160,70,6.8,6.2,6.4,6.6,610000
|
||||
隼击-C,巡飞弹,1.8,2.5,0.42,34,65,8,175,85,7.2,6.4,6.8,7.1,830000
|
||||
侦察-100,巡飞弹,0.9,1.4,0.25,9,18,2,110,35,5.4,5.1,5.5,4.8,190000
|
||||
侦察-200,巡飞弹,1.1,1.7,0.29,14,28,3,125,48,5.9,5.4,5.7,5.3,310000
|
||||
侦察-300,巡飞弹,1.4,2.0,0.34,22,42,5,150,62,6.3,5.9,6.0,6.1,520000
|
||||
锐蛇-S,巡飞弹,1.7,2.4,0.38,30,58,7,185,76,7.5,6.7,6.9,7.4,940000
|
||||
锐蛇-M,巡飞弹,2.0,2.8,0.46,44,82,10,205,94,8.0,7.1,7.3,8.0,1360000
|
||||
锐蛇-L,巡飞弹,2.4,3.2,0.55,62,120,15,230,125,8.7,7.5,7.8,8.8,2100000
|
||||
鹰眼-1,巡飞弹,1.3,1.9,0.31,20,40,4,155,58,6.6,5.7,6.3,6.0,470000
|
||||
鹰眼-2,巡飞弹,1.6,2.2,0.37,29,57,7,172,78,7.1,6.1,6.6,6.9,760000
|
||||
鹰眼-3,巡飞弹,2.1,2.9,0.49,51,95,12,215,105,8.2,7.0,7.2,8.1,1580000
|
||||
雷霆-122,火箭炮,6.9,2.4,2.8,13500,22,480,72,0,5.8,6.6,6.0,5.5,980000
|
||||
雷霆-160,火箭炮,7.6,2.6,3.0,16800,40,760,68,0,6.4,6.9,6.3,6.1,1450000
|
||||
雷霆-220,火箭炮,8.3,2.8,3.2,21500,70,1200,65,0,7.0,7.1,6.8,7.0,2380000
|
||||
雷霆-300,火箭炮,9.8,3.0,3.4,28500,120,1850,62,0,7.8,7.4,7.2,8.0,4200000
|
||||
山猫-95,火箭炮,6.2,2.3,2.7,11800,18,360,78,0,5.4,6.0,5.7,5.0,740000
|
||||
山猫-120,火箭炮,6.7,2.4,2.8,13000,30,520,75,0,5.9,6.2,6.0,5.6,1050000
|
||||
山猫-200,火箭炮,7.9,2.7,3.1,19800,60,980,70,0,6.8,6.8,6.5,6.7,1980000
|
||||
山猫-300,火箭炮,9.3,2.9,3.3,26000,105,1600,66,0,7.6,7.2,7.0,7.8,3560000
|
||||
弓兵-L,火箭炮,8.8,2.9,3.2,23500,85,1350,69,0,7.2,7.0,6.9,7.3,2860000
|
||||
弓兵-X,火箭炮,10.2,3.1,3.6,31000,150,2100,60,0,8.4,7.8,7.6,8.7,5400000
|
||||
长矛-1,火箭炮,7.1,2.5,2.9,14200,28,560,73,0,6.1,6.4,6.1,5.8,1180000
|
||||
长矛-2,火箭炮,8.1,2.7,3.1,20500,75,1120,68,0,7.1,6.9,6.7,7.1,2420000
|
||||
长矛-3,火箭炮,9.6,3.0,3.5,29200,130,1900,63,0,8.1,7.5,7.4,8.3,4650000
|
||||
擎天-M,火箭炮,10.8,3.2,3.8,34800,180,2450,58,0,8.9,8.0,7.9,9.2,6900000
|
||||
|
13
demo_standalone/README.md
Normal file
13
demo_standalone/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
# 机器学习算法演示
|
||||
|
||||
## 运行方式
|
||||
|
||||
1. 解压 zip 文件。
|
||||
2. 双击 `start_demo.bat`。
|
||||
3. 浏览器会自动打开 `http://127.0.0.1:5001/algorithm-demo`。
|
||||
|
||||
## 说明
|
||||
|
||||
- 演示使用 `data/demo_equipment_costs.csv`,不需要 MySQL。
|
||||
- 首次运行会创建 `.venv` 并安装最小 Python 依赖。
|
||||
- 需要本机已安装 Python 3.9 至 3.11。
|
||||
5
demo_standalone/requirements.txt
Normal file
5
demo_standalone/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
flask>=3.1.0
|
||||
flask-cors>=5.0.0
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
scikit-learn>=1.5.2
|
||||
48
demo_standalone/server.py
Normal file
48
demo_standalone/server.py
Normal file
@ -0,0 +1,48 @@
|
||||
from pathlib import Path
|
||||
|
||||
from flask import Flask, jsonify, request, send_from_directory
|
||||
from flask_cors import CORS
|
||||
|
||||
from demo_service import DemoModelService
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
STATIC_DIR = BASE_DIR / "frontend"
|
||||
DATASET_PATH = BASE_DIR / "data" / "demo_equipment_costs.csv"
|
||||
|
||||
|
||||
def create_app():
|
||||
app = Flask(__name__, static_folder=None)
|
||||
CORS(app)
|
||||
|
||||
@app.get("/api/demo/algorithms")
|
||||
def demo_algorithms():
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify({"algorithms": service.get_algorithms()})
|
||||
|
||||
@app.get("/api/demo/dataset")
|
||||
def demo_dataset():
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify(service.get_dataset_summary())
|
||||
|
||||
@app.post("/api/demo/run")
|
||||
def demo_run():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify(service.run_demo(payload.get("algorithms")))
|
||||
|
||||
@app.get("/")
|
||||
@app.get("/<path:path>")
|
||||
def frontend(path=""):
|
||||
file_path = STATIC_DIR / path
|
||||
if path and file_path.exists() and file_path.is_file():
|
||||
return send_from_directory(STATIC_DIR, path)
|
||||
return send_from_directory(STATIC_DIR, "index.html")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
print("算法演示服务已启动:http://127.0.0.1:5001/algorithm-demo")
|
||||
app.run(host="127.0.0.1", port=5001, debug=False)
|
||||
32
demo_standalone/start_demo.bat
Normal file
32
demo_standalone/start_demo.bat
Normal file
@ -0,0 +1,32 @@
|
||||
@echo off
|
||||
setlocal
|
||||
cd /d "%~dp0"
|
||||
|
||||
where python >nul 2>nul
|
||||
if errorlevel 1 (
|
||||
echo 未找到 Python。请先安装 Python 3.9 至 3.11,然后重新运行本脚本。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if not exist ".venv\Scripts\python.exe" (
|
||||
echo 正在创建演示环境...
|
||||
python -m venv .venv
|
||||
if errorlevel 1 (
|
||||
echo 创建环境失败。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
echo 正在安装或检查依赖...
|
||||
".venv\Scripts\python.exe" -m pip install -r requirements.txt
|
||||
if errorlevel 1 (
|
||||
echo 依赖安装失败,请检查网络或 Python 环境。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
start "" http://127.0.0.1:5001/algorithm-demo
|
||||
".venv\Scripts\python.exe" server.py
|
||||
pause
|
||||
@ -4,8 +4,8 @@
|
||||
|
||||
### 1. 基础软件
|
||||
|
||||
- Linux 操作系统 (推荐 Ubuntu 20.04+)
|
||||
- Python 3.8+ 及相关组件
|
||||
- Linux 操作系统 (推荐 Ubuntu 22.04+)
|
||||
- Python 3.12 及相关组件
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
@ -23,6 +23,19 @@
|
||||
nvm use 14
|
||||
```
|
||||
|
||||
- Windows 操作系统 (推荐 Windows 10+)
|
||||
|
||||
- Python 3.12 及相关组件
|
||||
参考:<https://www.python.org/downloads/>
|
||||
|
||||
- Node.js 14+ 及 npm
|
||||
参考:<https://learn.microsoft.com/en-us/windows/dev-environment/javascript/nodejs-on-windows>
|
||||
|
||||
```bash
|
||||
# 设置执行策略
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
### 2. 数据库
|
||||
|
||||
- MySQL 8.0+
|
||||
@ -32,16 +45,58 @@
|
||||
sudo apt install libmysqlclient-dev
|
||||
```
|
||||
|
||||
Windows 参考:<https://dev.mysql.com/downloads/installer/>
|
||||
|
||||
### 3. Python包依赖
|
||||
|
||||
```bash
|
||||
# 科学计算相关
|
||||
# Windows系统下安装依赖
|
||||
# 1. 创建并激活虚拟环境
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
# 2. 设置pip源为国内镜像(可选,但推荐)
|
||||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
# 3. 更新pip
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# 4. 安装依赖包(使用UTF-8编码)
|
||||
|
||||
# PowerShell命令行
|
||||
$env:PYTHONUTF8=1
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Linux系统下安装依赖
|
||||
# 1. 创建并激活虚拟环境
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# 2. 安装依赖包
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 解析Excel文件需要安装以下依赖
|
||||
pip install pandas
|
||||
pip install openpyxl
|
||||
pip install xlrd
|
||||
# 常见问题解决:
|
||||
# 1. 如果遇到编码错误,请确保使用UTF-8编码
|
||||
# 2. 如果安装过程中出现权限问题,请使用管理员权限运行命令行
|
||||
# 3. 如果下载速度慢,建议使用国内镜像源
|
||||
# 4. 如果出现SSL证书错误,可以尝试添加--trusted-host参数:
|
||||
pip install -r requirements.txt --trusted-host pypi.tuna.tsinghua.edu.cn
|
||||
```
|
||||
|
||||
### 4. 科学计算相关
|
||||
|
||||
sudo apt install libatlas-base-dev # numpy依赖
|
||||
sudo apt install libopenblas-dev # 线性代数库
|
||||
sudo apt install liblapack-dev # 线性代数包
|
||||
sudo apt install gfortran # Fortran编译器(scipy依赖)
|
||||
|
||||
# XML处理相关(用于Excel文件处理)
|
||||
## XML处理相关(用于Excel文件处理)
|
||||
|
||||
```bash
|
||||
sudo apt install libxml2-dev
|
||||
sudo apt install libxslt1-dev
|
||||
```
|
||||
|
||||
224
docs/dev/hyper_parameters.md
Normal file
224
docs/dev/hyper_parameters.md
Normal file
@ -0,0 +1,224 @@
|
||||
# 模型超参数设置
|
||||
|
||||
## 1. PyTorch神经网络
|
||||
|
||||
### 火箭炮配置
|
||||
|
||||
1. 输入层 -> 隐藏层1:Linear(input_size -> 32) + ReLU + BatchNorm
|
||||
2. 隐藏层1 -> 隐藏层2:Linear(32 -> 16) + ReLU + BatchNorm
|
||||
3. 隐藏层2 -> 隐藏层3:Linear(16 -> 8) + ReLU + BatchNorm
|
||||
4. 隐藏层3 -> 输出层:Linear(8 -> 1)
|
||||
|
||||
```python
|
||||
learning_rate = 0.0003
|
||||
weight_decay = 0.001
|
||||
optimizer = AdamW(
|
||||
betas=(0.8, 0.9),
|
||||
eps=1e-8
|
||||
)
|
||||
loss_function = SmoothL1Loss(beta=0.1)
|
||||
scheduler = 带预热的余弦退火
|
||||
gradient_clip = max_norm=0.1
|
||||
```
|
||||
|
||||
### 巡飞弹配置
|
||||
|
||||
生产商特征网络(2层):
|
||||
|
||||
1. Linear(5 -> 4) + ReLU + BatchNorm + Dropout(0.2)
|
||||
|
||||
装备特征网络(4层):
|
||||
|
||||
1. Linear(input_size-5 -> 64) + LeakyReLU + BatchNorm + Dropout
|
||||
2. Linear(64 -> 32) + LeakyReLU + BatchNorm + Dropout
|
||||
3. Linear(32 -> 16) + LeakyReLU + BatchNorm + Dropout
|
||||
|
||||
合并网络(4层):
|
||||
|
||||
1. Linear(20 -> 32) + LeakyReLU + BatchNorm + Dropout
|
||||
2. Linear(32 -> 16) + LeakyReLU + BatchNorm + Dropout
|
||||
3. Linear(16 -> 8) + LeakyReLU + BatchNorm
|
||||
4. Linear(8 -> 1)
|
||||
|
||||
```python
|
||||
learning_rate = 0.001
|
||||
weight_decay = 0.001
|
||||
optimizer = Adam(betas=(0.9, 0.999))
|
||||
loss_function = MSELoss()
|
||||
scheduler = 余弦退火
|
||||
```
|
||||
|
||||
## 2. XGBoost
|
||||
|
||||
```python
|
||||
n_estimators = 50
|
||||
learning_rate = 0.03
|
||||
max_depth = 3
|
||||
min_child_weight = 5
|
||||
subsample = 0.6
|
||||
colsample_bytree = 0.6
|
||||
reg_alpha = 0.5
|
||||
reg_lambda = 2.0
|
||||
gamma = 1
|
||||
random_state = 42
|
||||
```
|
||||
|
||||
## 3. LightGBM
|
||||
|
||||
```python
|
||||
n_estimators = 50
|
||||
learning_rate = 0.03
|
||||
max_depth = 3
|
||||
num_leaves = 8
|
||||
subsample = 0.6
|
||||
colsample_bytree = 0.6
|
||||
reg_alpha = 0.5
|
||||
reg_lambda = 2.0
|
||||
min_child_samples = 10
|
||||
min_split_gain = 1.0
|
||||
random_state = 42
|
||||
```
|
||||
|
||||
## 4. GBM(梯度提升机)
|
||||
|
||||
```python
|
||||
n_estimators = 50
|
||||
learning_rate = 0.03
|
||||
max_depth = 3
|
||||
min_samples_split = 10
|
||||
min_samples_leaf = 5
|
||||
subsample = 0.6
|
||||
min_impurity_decrease = 0.01
|
||||
random_state = 42
|
||||
```
|
||||
|
||||
## 5. Random Forest(随机森林)
|
||||
|
||||
```python
|
||||
n_estimators = 100
|
||||
max_depth = 4
|
||||
min_samples_split = 5
|
||||
min_samples_leaf = 3
|
||||
max_features = 'sqrt'
|
||||
bootstrap = True
|
||||
random_state = 42
|
||||
```
|
||||
|
||||
## 6. PLS回归
|
||||
|
||||
```python
|
||||
n_components = min(3, 特征数量//5)
|
||||
scale = True
|
||||
max_iter = 500
|
||||
tol = 1e-6
|
||||
```
|
||||
|
||||
## 超参数调优策略
|
||||
|
||||
### 1. 样本量增加时的调整策略
|
||||
|
||||
#### PyTorch神经网络
|
||||
|
||||
- 增加网络深度和宽度
|
||||
- 可以在现有层之间添加更多隐藏层
|
||||
- 适当增加每层神经元数量
|
||||
- 调整学习率和优化器
|
||||
- 可以使用更大的学习率(如0.001-0.005)
|
||||
- 减小weight_decay(如0.0005)
|
||||
- 减少正则化强度
|
||||
- 降低Dropout率(如0.1)
|
||||
- 可以移除部分BatchNorm层
|
||||
|
||||
#### 树模型(XGBoost/LightGBM/GBM)
|
||||
|
||||
- 增加树的数量(n_estimators:100-500)
|
||||
- 增加树的深度(max_depth:4-6)
|
||||
- 减小正则化参数
|
||||
- reg_alpha:0.3
|
||||
- reg_lambda:1.0
|
||||
- 增大子采样比例(subsample:0.8-0.9)
|
||||
|
||||
#### Random Forest
|
||||
|
||||
- 增加树的数量(n_estimators:200-500)
|
||||
- 增加树的深度(max_depth:6-8)
|
||||
- 减小最小分裂样本数
|
||||
- min_samples_split:3
|
||||
- min_samples_leaf:2
|
||||
|
||||
#### PLS回归
|
||||
|
||||
- 增加组件数量(n_components)
|
||||
- 可以考虑使用非线性核函数
|
||||
|
||||
### 2. 特征数量变化的调整策略
|
||||
|
||||
#### 特征数量增加时
|
||||
|
||||
- 增强特征选择和降维
|
||||
- 增加正则化强度
|
||||
- 考虑使用特征筛选方法
|
||||
- 可以使用自动特征选择算法
|
||||
|
||||
#### 特征数量减少时
|
||||
|
||||
- 简化模型结构
|
||||
- 减少正则化强度
|
||||
- 增加每个特征的权重
|
||||
|
||||
### 3. 自动化调优建议
|
||||
|
||||
1. 使用网格搜索(Grid Search)
|
||||
- 适用于参数空间较小时
|
||||
- 可以详尽搜索最优参数
|
||||
|
||||
2. 使用随机搜索(Random Search)
|
||||
- 适用于参数空间较大时
|
||||
- 比网格搜索更高效
|
||||
|
||||
3. 使用贝叶斯优化
|
||||
- 适用于计算资源有限时
|
||||
- 能更智能地搜索参数空间
|
||||
|
||||
4. 交叉验证策略
|
||||
- 样本量大时:使用K折交叉验证(K=5或10)
|
||||
- 样本量小时:使用留一法交叉验证
|
||||
|
||||
### 4. 性能监控指标
|
||||
|
||||
在调参过程中需要监控:
|
||||
|
||||
1. 训练集和验证集的损失曲线
|
||||
2. 模型复杂度vs性能提升
|
||||
3. 训练时间vs性能提升
|
||||
4. 过拟合风险
|
||||
|
||||
### 5. 调优注意事项
|
||||
|
||||
1. 保持可解释性
|
||||
- 模型复杂度增加时,确保结果仍可解释
|
||||
- 记录参数调整的原因和效果
|
||||
|
||||
2. 计算资源平衡
|
||||
- 在性能提升和计算成本间找到平衡点
|
||||
- 考虑模型部署的实际环境限制
|
||||
|
||||
3. 稳定性要求
|
||||
- 确保模型在不同数据分布下仍能稳定工作
|
||||
- 定期使用新数据验证模型性能
|
||||
|
||||
## 参数说明
|
||||
|
||||
所有模型都设置了 `random_state=42` 以确保结果可重现。这些参数是经过调优的,针对小样本量的特点,采用了较为保守的设置:
|
||||
|
||||
- 较小的学习率:避免过拟合,提高模型稳定性
|
||||
- 较浅的树深度:防止模型过于复杂
|
||||
- 较强的正则化:增强模型泛化能力
|
||||
- 适当的子采样比例:提高模型鲁棒性
|
||||
|
||||
这些参数设置主要考虑了以下因素:
|
||||
|
||||
1. 样本量较小
|
||||
2. 特征维度适中
|
||||
3. 需要较强的泛化能力
|
||||
4. 预测稳定性要求高
|
||||
30
docs/release_guide.md
Normal file
30
docs/release_guide.md
Normal file
@ -0,0 +1,30 @@
|
||||
# Windows 发布包制作指南
|
||||
|
||||
1. 准备工作
|
||||
|
||||
1.1 安装必要软件
|
||||
|
||||
- Python 3.11.8: <https://www.python.org/downloads/>
|
||||
- Visual Studio Build Tools: <https://visualstudio.microsoft.com/visual-cpp-build-tools/>
|
||||
|
||||
1.2 下载安装程序
|
||||
|
||||
- Python 3.11.8: <https://www.python.org/ftp/python/3.11.8/python-3.11.8-amd64.exe>
|
||||
- Visual C++ Redistributable: <https://aka.ms/vs/17/release/vc_redist.x64.exe(可选,如果系统已安装则不需要)>
|
||||
|
||||
|
||||
1. 克隆项目
|
||||
|
||||
```powershell
|
||||
git clone [repository-url]
|
||||
cd cost-prediction
|
||||
```
|
||||
|
||||
2. 打包步骤
|
||||
|
||||
```powershell
|
||||
# 运行打包脚本
|
||||
.\scripts\build_win.ps1
|
||||
```
|
||||
|
||||
打包完成后会在项目根目录生成 `cost-prediction-[version]-win64.zip`。
|
||||
57
docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md
Normal file
57
docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md
Normal file
@ -0,0 +1,57 @@
|
||||
# ML Algorithm Demo Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add a modern demo page that compares common machine learning algorithms using a local data file instead of MySQL.
|
||||
|
||||
**Architecture:** Add an isolated backend demo service that reads `data/demo_equipment_costs.csv`, trains selected regressors in memory, and returns metrics, prediction points, feature importance, and a sample prediction. Add a Vue route that calls the demo API and renders algorithm switching, charts, metrics, and data preview. Existing database-backed pages remain unchanged.
|
||||
|
||||
**Tech Stack:** Flask, pandas, scikit-learn, optional xgboost/lightgbm, Vue 3, Element Plus, ECharts.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Backend Demo Service
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/test_demo_service.py`
|
||||
- Create: `src/demo_service.py`
|
||||
- Create: `data/demo_equipment_costs.csv`
|
||||
|
||||
- [ ] Write failing tests for data loading, algorithm availability, and training payload shape.
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it fails because `src.demo_service` is missing.
|
||||
- [ ] Implement `DemoModelService` with local CSV loading, selected algorithm training, metric calculation, top feature importance, and fallback algorithms when optional libraries are unavailable.
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it passes.
|
||||
|
||||
### Task 2: Demo API
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/routes.py`
|
||||
- Test: `tests/test_demo_routes.py`
|
||||
|
||||
- [ ] Write Flask route tests for `GET /api/demo/algorithms`, `GET /api/demo/dataset`, and `POST /api/demo/run`.
|
||||
- [ ] Run `python -m pytest tests/test_demo_routes.py -q` and verify missing routes fail.
|
||||
- [ ] Add demo routes that call `DemoModelService` and do not access MySQL.
|
||||
- [ ] Run the route tests and demo service tests.
|
||||
|
||||
### Task 3: Vue Demo Page
|
||||
|
||||
**Files:**
|
||||
- Create: `frontend/src/views/AlgorithmDemoPage.vue`
|
||||
- Modify: `frontend/src/router/index.js`
|
||||
- Modify: `frontend/src/App.vue`
|
||||
- Modify: `frontend/src/api/index.js`
|
||||
- Modify: `frontend/src/views/HomePage.vue`
|
||||
|
||||
- [ ] Add API helpers for demo algorithms, dataset, and run.
|
||||
- [ ] Add `/algorithm-demo` route and navigation label `算法演示`.
|
||||
- [ ] Build a modern dashboard-style page with algorithm toggles, metric cards, comparison chart, predicted-vs-actual chart, feature importance chart, sample prediction panel, and data preview table.
|
||||
- [ ] Add a home page entry that links to the demo.
|
||||
|
||||
### Task 4: Verification
|
||||
|
||||
**Files:**
|
||||
- No new files.
|
||||
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py tests/test_demo_routes.py -q`.
|
||||
- [ ] Run `npm run build` in `frontend`.
|
||||
- [ ] Start the app if feasible and confirm the new route is available.
|
||||
73
docs/windows_setup.md
Normal file
73
docs/windows_setup.md
Normal file
@ -0,0 +1,73 @@
|
||||
# Windows 开发环境设置
|
||||
|
||||
1. 安装必要软件
|
||||
- Python 3.11.8: <https://www.python.org/downloads/>
|
||||
- Git: <https://git-scm.com/download/win>
|
||||
- MySQL 8.0+: <https://dev.mysql.com/downloads/mysql/>
|
||||
- Visual Studio Build Tools: <https://visualstudio.microsoft.com/visual-cpp-build-tools/>
|
||||
- Node.js 18+ LTS: <https://nodejs.org/download/>,安装时,Chocolatey不是必需的
|
||||
- npm 9+: (随 Node.js 一起安装)
|
||||
|
||||
2. 克隆项目
|
||||
|
||||
```powershell
|
||||
git clone [repository-url]
|
||||
cd cost-prediction
|
||||
```
|
||||
|
||||
3. 设置前端环境
|
||||
|
||||
```powershell
|
||||
# 进入前端目录
|
||||
cd frontend
|
||||
|
||||
# 安装依赖
|
||||
npm install 22
|
||||
nvm use 22
|
||||
|
||||
# 构建生产版本
|
||||
npm run build
|
||||
|
||||
# 返回项目根目录
|
||||
cd ..
|
||||
```
|
||||
|
||||
4. 设置 Python 环境
|
||||
|
||||
```powershell
|
||||
# 创建虚拟环境
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
|
||||
# 安装依赖
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
5. 配置数据库
|
||||
|
||||
```powershell
|
||||
# 确保 MySQL 服务已启动
|
||||
|
||||
# 初始化数据库和导入数据
|
||||
|
||||
```
|
||||
|
||||
6. 运行测试
|
||||
|
||||
```powershell
|
||||
python src/test_api.py
|
||||
```
|
||||
|
||||
7. 打包项目
|
||||
|
||||
```powershell
|
||||
# 先下载所有依赖
|
||||
.\scripts\download_deps.ps1
|
||||
|
||||
# 然后运行打包脚本
|
||||
.\scripts\build_win.ps1
|
||||
```
|
||||
|
||||
## 注意:如果需要制作发布包,请参考 docs/release_guide.md
|
||||
@ -1,5 +0,0 @@
|
||||
module.exports = {
|
||||
presets: [
|
||||
'@vue/cli-plugin-babel/preset'
|
||||
]
|
||||
}
|
||||
17
frontend/index.html
Normal file
17
frontend/index.html
Normal file
@ -0,0 +1,17 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<link rel="icon" href="/favicon.ico">
|
||||
<title>装备成本估算系统</title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>
|
||||
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
|
||||
</noscript>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
11740
frontend/package-lock.json
generated
11740
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -7,33 +7,24 @@
|
||||
"npm": ">=8"
|
||||
},
|
||||
"scripts": {
|
||||
"serve": "vue-cli-service serve",
|
||||
"build": "vue-cli-service build",
|
||||
"lint": "vue-cli-service lint"
|
||||
"serve": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint --ext .js,.vue src/"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.6.0",
|
||||
"core-js": "^3.8.3",
|
||||
"echarts": "^5.4.3",
|
||||
"element-plus": "^2.4.2",
|
||||
"vue": "^3.2.13",
|
||||
"vue-router": "^4.0.3",
|
||||
"vuex": "^4.0.0"
|
||||
"vue-router": "^4.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.12.16",
|
||||
"@babel/eslint-parser": "^7.12.16",
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"@vue/cli-plugin-babel": "~5.0.0",
|
||||
"@vue/cli-plugin-eslint": "~5.0.0",
|
||||
"@vue/cli-plugin-router": "~5.0.0",
|
||||
"@vue/cli-plugin-vuex": "~5.0.0",
|
||||
"@vue/cli-service": "~5.0.0",
|
||||
"@vue/compiler-sfc": "^3.2.13",
|
||||
"@vitejs/plugin-vue": "^5.0.0",
|
||||
"eslint": "^7.32.0",
|
||||
"eslint-plugin-vue": "^8.0.3",
|
||||
"sass": "^1.32.7",
|
||||
"sass-loader": "^12.0.0"
|
||||
"sass-embedded": "^1.99.0",
|
||||
"vite": "^5.0.0"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"root": true,
|
||||
@ -45,17 +36,12 @@
|
||||
"eslint:recommended"
|
||||
],
|
||||
"parserOptions": {
|
||||
"parser": "@babel/eslint-parser"
|
||||
"ecmaVersion": "latest",
|
||||
"sourceType": "module"
|
||||
},
|
||||
"rules": {
|
||||
"vue/multi-word-component-names": "off",
|
||||
"no-unused-vars": "warn"
|
||||
}
|
||||
},
|
||||
"browserslist": [
|
||||
"> 1%",
|
||||
"last 2 versions",
|
||||
"not dead",
|
||||
"not ie 11"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
|
||||
<title><%= htmlWebpackPlugin.options.title %></title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>
|
||||
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
|
||||
</noscript>
|
||||
<div id="app"></div>
|
||||
<!-- built files will be auto injected -->
|
||||
</body>
|
||||
</html>
|
||||
@ -9,6 +9,7 @@
|
||||
<el-menu-item index="/">首页</el-menu-item>
|
||||
<el-menu-item index="/predict">成本预测</el-menu-item>
|
||||
<el-menu-item index="/analysis">特征分析</el-menu-item>
|
||||
<el-menu-item index="/algorithm-demo">算法演示</el-menu-item>
|
||||
<el-menu-item index="/training">模型训练</el-menu-item>
|
||||
<el-menu-item index="/models">模型管理</el-menu-item>
|
||||
<el-menu-item index="/datasets">数据集管理</el-menu-item>
|
||||
|
||||
@ -40,4 +40,16 @@ export const updateEquipment = (id, data) => {
|
||||
|
||||
export const deleteEquipment = (id) => {
|
||||
return api.delete(`/data/${id}`)
|
||||
}
|
||||
}
|
||||
|
||||
export const getDemoAlgorithms = () => {
|
||||
return api.get('/demo/algorithms')
|
||||
}
|
||||
|
||||
export const getDemoDataset = () => {
|
||||
return api.get('/demo/dataset')
|
||||
}
|
||||
|
||||
export const runAlgorithmDemo = (data) => {
|
||||
return api.post('/demo/run', data)
|
||||
}
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
export const API_BASE_URL = 'http://localhost:5001/api';
|
||||
const isLocalDevServer = window.location.port === '8080'
|
||||
|
||||
export const API_BASE_URL = isLocalDevServer
|
||||
? 'http://localhost:5001/api'
|
||||
: `${window.location.origin}/api`;
|
||||
|
||||
export const DB_CONFIG = {
|
||||
host: 'localhost',
|
||||
user: 'root',
|
||||
password: '123456',
|
||||
database: 'equipment_cost_db'
|
||||
};
|
||||
};
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import { createApp } from 'vue'
|
||||
import App from './App.vue'
|
||||
import router from './router'
|
||||
import store from './store'
|
||||
import ElementPlus from 'element-plus'
|
||||
import 'element-plus/dist/index.css'
|
||||
import './assets/styles/global.css'
|
||||
@ -15,7 +14,6 @@ app.use(ElementPlus, {
|
||||
size: 'default'
|
||||
})
|
||||
app.use(router)
|
||||
app.use(store)
|
||||
|
||||
// 注册图标
|
||||
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
|
||||
|
||||
@ -5,6 +5,7 @@ import DatasetPage from '@/views/DatasetPage.vue'
|
||||
import PredictPage from '@/views/PredictPage.vue'
|
||||
import AnalysisPage from '@/views/AnalysisPage.vue'
|
||||
import TrainingPage from '@/views/TrainingPage.vue'
|
||||
import AlgorithmDemoPage from '@/views/AlgorithmDemoPage.vue'
|
||||
|
||||
const routes = [
|
||||
{
|
||||
@ -37,6 +38,11 @@ const routes = [
|
||||
name: 'Training',
|
||||
component: TrainingPage
|
||||
},
|
||||
{
|
||||
path: '/algorithm-demo',
|
||||
name: 'AlgorithmDemo',
|
||||
component: AlgorithmDemoPage
|
||||
},
|
||||
{
|
||||
path: '/models',
|
||||
name: 'Models',
|
||||
@ -49,4 +55,4 @@ const router = createRouter({
|
||||
routes
|
||||
})
|
||||
|
||||
export default router
|
||||
export default router
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
import { createStore } from 'vuex'
|
||||
|
||||
export default createStore({
|
||||
state: {
|
||||
},
|
||||
getters: {
|
||||
},
|
||||
mutations: {
|
||||
},
|
||||
actions: {
|
||||
},
|
||||
modules: {
|
||||
}
|
||||
})
|
||||
644
frontend/src/views/AlgorithmDemoPage.vue
Normal file
644
frontend/src/views/AlgorithmDemoPage.vue
Normal file
@ -0,0 +1,644 @@
|
||||
<template>
|
||||
<div class="algorithm-demo-page">
|
||||
<section class="demo-hero">
|
||||
<div>
|
||||
<p class="eyebrow">本地文件算法演示</p>
|
||||
<h1>机器学习算法演示</h1>
|
||||
<p class="hero-copy">
|
||||
使用本地数据文件快速训练和比较常用回归算法,适合客户演示部署。
|
||||
</p>
|
||||
</div>
|
||||
<div class="hero-actions">
|
||||
<el-button type="primary" :loading="loading" @click="runDemo">
|
||||
<el-icon><VideoPlay /></el-icon>
|
||||
运行演示
|
||||
</el-button>
|
||||
<el-tag effect="plain" type="success">无需 MySQL</el-tag>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="control-band">
|
||||
<div class="panel algorithm-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">算法选择</p>
|
||||
<h2>选择算法</h2>
|
||||
</div>
|
||||
<el-button text type="primary" @click="selectRecommended">推荐组合</el-button>
|
||||
</div>
|
||||
<el-checkbox-group v-model="selectedAlgorithms" class="algorithm-grid">
|
||||
<el-checkbox
|
||||
v-for="item in algorithms"
|
||||
:key="item.key"
|
||||
:value="item.key"
|
||||
border
|
||||
>
|
||||
<span class="algorithm-name">{{ item.name }}</span>
|
||||
<small>{{ item.english_name }} · {{ item.family }}</small>
|
||||
</el-checkbox>
|
||||
</el-checkbox-group>
|
||||
<el-alert
|
||||
v-if="warnings.length"
|
||||
class="warning-strip"
|
||||
type="warning"
|
||||
:closable="false"
|
||||
show-icon
|
||||
>
|
||||
<template #title>{{ warnings.join(' ') }}</template>
|
||||
</el-alert>
|
||||
</div>
|
||||
|
||||
<div class="panel dataset-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">数据来源</p>
|
||||
<h2>本地演示数据</h2>
|
||||
</div>
|
||||
<el-tag>{{ dataset.row_count || 0 }} 条</el-tag>
|
||||
</div>
|
||||
<div class="dataset-stats">
|
||||
<div>
|
||||
<strong>{{ dataset.features?.length || 0 }}</strong>
|
||||
<span>特征数</span>
|
||||
</div>
|
||||
<div>
|
||||
<strong>{{ dataset.equipment_types?.length || 0 }}</strong>
|
||||
<span>装备类型</span>
|
||||
</div>
|
||||
<div>
|
||||
<strong>{{ dataset.target_label || '-' }}</strong>
|
||||
<span>预测目标</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section v-if="result" class="metrics-grid">
|
||||
<article
|
||||
v-for="row in metricRows"
|
||||
:key="row.key"
|
||||
class="metric-card"
|
||||
:class="{ active: row.key === result.best_model }"
|
||||
@click="activeAlgorithm = row.key"
|
||||
>
|
||||
<div class="metric-title">
|
||||
<span>{{ row.name }}</span>
|
||||
<el-tag v-if="row.key === result.best_model" size="small" type="success">最佳</el-tag>
|
||||
</div>
|
||||
<strong>{{ formatScore(row.r2) }}</strong>
|
||||
<div class="metric-values">
|
||||
<span>平均绝对误差 {{ formatMoney(row.mae) }}</span>
|
||||
<span>均方根误差 {{ formatMoney(row.rmse) }}</span>
|
||||
</div>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<section v-if="result" class="visual-grid">
|
||||
<div class="panel chart-panel wide">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">效果对比</p>
|
||||
<h2>模型指标对比</h2>
|
||||
</div>
|
||||
</div>
|
||||
<div ref="metricsChartRef" class="chart"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel chart-panel wide">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">预测结果</p>
|
||||
<h2>预测值与真实值</h2>
|
||||
</div>
|
||||
<el-select v-model="activeAlgorithm" size="small" class="algorithm-select">
|
||||
<el-option
|
||||
v-for="row in metricRows"
|
||||
:key="row.key"
|
||||
:label="row.name"
|
||||
:value="row.key"
|
||||
/>
|
||||
</el-select>
|
||||
</div>
|
||||
<div ref="predictionChartRef" class="chart"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel chart-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">模型解释</p>
|
||||
<h2>特征重要性</h2>
|
||||
</div>
|
||||
</div>
|
||||
<div ref="importanceChartRef" class="chart compact"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel sample-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">样例场景</p>
|
||||
<h2>样例装备预测</h2>
|
||||
</div>
|
||||
</div>
|
||||
<dl>
|
||||
<dt>装备名称</dt>
|
||||
<dd>{{ result.sample_prediction.input.name }}</dd>
|
||||
<dt>真实成本</dt>
|
||||
<dd>{{ formatMoney(result.sample_prediction.actual) }}</dd>
|
||||
<dt>当前算法预测</dt>
|
||||
<dd>{{ formatMoney(result.sample_prediction.predictions[activeAlgorithm]) }}</dd>
|
||||
</dl>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="panel data-preview">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">数据预览</p>
|
||||
<h2>数据文件预览</h2>
|
||||
</div>
|
||||
</div>
|
||||
<el-table :data="dataset.preview || []" height="320" stripe>
|
||||
<el-table-column prop="name" label="名称" min-width="130" fixed />
|
||||
<el-table-column prop="type" label="类型" min-width="150" />
|
||||
<el-table-column prop="weight_kg" label="重量(kg)" min-width="100" />
|
||||
<el-table-column prop="max_range_km" label="射程(km)" min-width="100" />
|
||||
<el-table-column prop="tech_level" label="技术水平" min-width="100" />
|
||||
<el-table-column prop="complexity_score" label="复杂度" min-width="100" />
|
||||
<el-table-column prop="actual_cost" label="实际成本" min-width="130">
|
||||
<template #default="scope">{{ formatMoney(scope.row.actual_cost) }}</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { computed, nextTick, onMounted, onUnmounted, ref, watch } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { VideoPlay } from '@element-plus/icons-vue'
|
||||
import * as echarts from 'echarts'
|
||||
import { getDemoAlgorithms, getDemoDataset, runAlgorithmDemo } from '@/api'
|
||||
|
||||
const algorithms = ref([])
|
||||
const dataset = ref({})
|
||||
const selectedAlgorithms = ref(['linear', 'ridge', 'random_forest', 'gradient_boosting'])
|
||||
const activeAlgorithm = ref('random_forest')
|
||||
const result = ref(null)
|
||||
const loading = ref(false)
|
||||
const warnings = ref([])
|
||||
|
||||
const metricsChartRef = ref(null)
|
||||
const predictionChartRef = ref(null)
|
||||
const importanceChartRef = ref(null)
|
||||
const charts = []
|
||||
|
||||
const metricRows = computed(() => {
|
||||
if (!result.value?.metrics) return []
|
||||
return Object.entries(result.value.metrics).map(([key, value]) => ({
|
||||
key,
|
||||
...value
|
||||
}))
|
||||
})
|
||||
|
||||
const activeMetric = computed(() => {
|
||||
return metricRows.value.find((row) => row.key === activeAlgorithm.value) || metricRows.value[0]
|
||||
})
|
||||
|
||||
const selectRecommended = () => {
|
||||
selectedAlgorithms.value = ['linear', 'ridge', 'random_forest', 'gradient_boosting']
|
||||
}
|
||||
|
||||
const loadInitialData = async () => {
|
||||
try {
|
||||
const [algorithmResponse, datasetResponse] = await Promise.all([
|
||||
getDemoAlgorithms(),
|
||||
getDemoDataset()
|
||||
])
|
||||
algorithms.value = algorithmResponse.data.algorithms
|
||||
dataset.value = datasetResponse.data
|
||||
await runDemo()
|
||||
} catch (error) {
|
||||
ElMessage.error('加载演示数据失败')
|
||||
console.error(error)
|
||||
}
|
||||
}
|
||||
|
||||
const runDemo = async () => {
|
||||
if (!selectedAlgorithms.value.length) {
|
||||
ElMessage.warning('请至少选择一个算法')
|
||||
return
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await runAlgorithmDemo({ algorithms: selectedAlgorithms.value })
|
||||
result.value = response.data
|
||||
dataset.value = response.data.dataset
|
||||
warnings.value = response.data.warnings || []
|
||||
activeAlgorithm.value = response.data.best_model
|
||||
await nextTick()
|
||||
renderCharts()
|
||||
} catch (error) {
|
||||
ElMessage.error(error.response?.data?.error || '运行演示失败')
|
||||
console.error(error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const disposeCharts = () => {
|
||||
while (charts.length) {
|
||||
const chart = charts.pop()
|
||||
if (chart && !chart.isDisposed()) chart.dispose()
|
||||
}
|
||||
}
|
||||
|
||||
const renderCharts = () => {
|
||||
if (!result.value) return
|
||||
disposeCharts()
|
||||
renderMetricsChart()
|
||||
renderPredictionChart()
|
||||
renderImportanceChart()
|
||||
}
|
||||
|
||||
const renderMetricsChart = () => {
|
||||
if (!metricsChartRef.value) return
|
||||
const chart = echarts.init(metricsChartRef.value)
|
||||
charts.push(chart)
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
legend: { top: 0 },
|
||||
grid: { top: 48, left: 56, right: 24, bottom: 36 },
|
||||
xAxis: { type: 'category', data: metricRows.value.map((row) => row.name) },
|
||||
yAxis: [
|
||||
{ type: 'value', name: '决定系数', min: 0 },
|
||||
{ type: 'value', name: '误差' }
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '决定系数',
|
||||
type: 'bar',
|
||||
data: metricRows.value.map((row) => Number(row.r2.toFixed(4))),
|
||||
itemStyle: { color: '#2f6fdd' }
|
||||
},
|
||||
{
|
||||
name: '平均绝对误差',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: metricRows.value.map((row) => Math.round(row.mae)),
|
||||
itemStyle: { color: '#16a085' }
|
||||
},
|
||||
{
|
||||
name: '均方根误差',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: metricRows.value.map((row) => Math.round(row.rmse)),
|
||||
itemStyle: { color: '#d98b18' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const renderPredictionChart = () => {
|
||||
if (!predictionChartRef.value || !activeMetric.value) return
|
||||
const chart = echarts.init(predictionChartRef.value)
|
||||
charts.push(chart)
|
||||
const points = result.value.prediction_points
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
legend: { top: 0 },
|
||||
grid: { top: 48, left: 68, right: 24, bottom: 46 },
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: points.map((point) => point.name),
|
||||
axisLabel: { rotate: 25 }
|
||||
},
|
||||
yAxis: { type: 'value', name: '成本' },
|
||||
series: [
|
||||
{
|
||||
name: '真实值',
|
||||
type: 'line',
|
||||
smooth: true,
|
||||
data: points.map((point) => point.actual),
|
||||
itemStyle: { color: '#202938' }
|
||||
},
|
||||
{
|
||||
name: activeMetric.value.name,
|
||||
type: 'bar',
|
||||
data: points.map((point) => point[activeAlgorithm.value]),
|
||||
itemStyle: { color: '#2f6fdd' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const renderImportanceChart = () => {
|
||||
if (!importanceChartRef.value || !activeAlgorithm.value) return
|
||||
const chart = echarts.init(importanceChartRef.value)
|
||||
charts.push(chart)
|
||||
const rows = [...(result.value.feature_importance[activeAlgorithm.value] || [])].reverse()
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
grid: { top: 20, left: 108, right: 20, bottom: 24 },
|
||||
xAxis: { type: 'value' },
|
||||
yAxis: { type: 'category', data: rows.map((row) => featureName(row.feature)) },
|
||||
series: [
|
||||
{
|
||||
type: 'bar',
|
||||
data: rows.map((row) => Number(row.importance.toFixed(4))),
|
||||
itemStyle: { color: '#16a085' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const featureName = (key) => {
|
||||
const names = {
|
||||
length_m: '长度',
|
||||
width_m: '宽度',
|
||||
height_m: '高度',
|
||||
weight_kg: '重量',
|
||||
max_range_km: '最大射程',
|
||||
payload_kg: '载荷',
|
||||
max_speed_kmh: '最大速度',
|
||||
endurance_min: '续航',
|
||||
tech_level: '技术水平',
|
||||
scale_level: '规模水平',
|
||||
supply_chain_level: '供应链',
|
||||
complexity_score: '复杂度'
|
||||
}
|
||||
return names[key] || key
|
||||
}
|
||||
|
||||
const formatMoney = (value) => {
|
||||
if (value === undefined || value === null) return '-'
|
||||
return Number(value).toLocaleString('zh-CN', {
|
||||
style: 'currency',
|
||||
currency: 'CNY',
|
||||
maximumFractionDigits: 0
|
||||
})
|
||||
}
|
||||
|
||||
const formatScore = (value) => {
|
||||
if (value === undefined || value === null) return '-'
|
||||
return Number(value).toFixed(3)
|
||||
}
|
||||
|
||||
watch(activeAlgorithm, async () => {
|
||||
await nextTick()
|
||||
renderCharts()
|
||||
})
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
charts.forEach((chart) => {
|
||||
if (chart && !chart.isDisposed()) chart.resize()
|
||||
})
|
||||
})
|
||||
|
||||
onMounted(loadInitialData)
|
||||
onUnmounted(disposeCharts)
|
||||
</script>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
.algorithm-demo-page {
|
||||
min-height: calc(100vh - 60px);
|
||||
padding: 24px;
|
||||
color: #202938;
|
||||
background:
|
||||
linear-gradient(180deg, #eef3f8 0%, #f7f9fb 280px),
|
||||
#f7f9fb;
|
||||
}
|
||||
|
||||
.demo-hero,
|
||||
.control-band,
|
||||
.metrics-grid,
|
||||
.visual-grid,
|
||||
.data-preview {
|
||||
max-width: 1440px;
|
||||
margin: 0 auto 18px;
|
||||
}
|
||||
|
||||
.demo-hero {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 20px;
|
||||
min-height: 176px;
|
||||
|
||||
h1 {
|
||||
margin: 6px 0 10px;
|
||||
font-size: 36px;
|
||||
line-height: 1.2;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.hero-copy {
|
||||
max-width: 680px;
|
||||
margin: 0;
|
||||
color: #536273;
|
||||
font-size: 16px;
|
||||
line-height: 1.7;
|
||||
}
|
||||
|
||||
.hero-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.eyebrow {
|
||||
margin: 0;
|
||||
color: #2f6fdd;
|
||||
font-size: 12px;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.control-band,
|
||||
.visual-grid {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(0, 1.35fr) minmax(320px, 0.65fr);
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.panel,
|
||||
.metric-card {
|
||||
border: 1px solid #dfe6ef;
|
||||
border-radius: 8px;
|
||||
background: #fff;
|
||||
box-shadow: 0 10px 28px rgba(32, 41, 56, 0.06);
|
||||
}
|
||||
|
||||
.panel {
|
||||
padding: 18px;
|
||||
}
|
||||
|
||||
.panel-header,
|
||||
.metric-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 12px;
|
||||
|
||||
h2 {
|
||||
margin: 4px 0 0;
|
||||
font-size: 18px;
|
||||
line-height: 1.3;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
|
||||
gap: 10px;
|
||||
margin-top: 16px;
|
||||
|
||||
:deep(.el-checkbox) {
|
||||
width: 100%;
|
||||
height: 64px;
|
||||
margin: 0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
:deep(.el-checkbox__label) {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
line-height: 1.2;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-name {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.algorithm-grid small {
|
||||
color: #6b7786;
|
||||
}
|
||||
|
||||
.warning-strip {
|
||||
margin-top: 14px;
|
||||
}
|
||||
|
||||
.dataset-stats {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 10px;
|
||||
margin-top: 18px;
|
||||
|
||||
div {
|
||||
padding: 14px;
|
||||
border-radius: 8px;
|
||||
background: #f2f6fa;
|
||||
}
|
||||
|
||||
strong {
|
||||
display: block;
|
||||
margin-bottom: 6px;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
span {
|
||||
color: #667485;
|
||||
font-size: 13px;
|
||||
}
|
||||
}
|
||||
|
||||
.metrics-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
padding: 16px;
|
||||
cursor: pointer;
|
||||
transition: border-color 0.2s ease, transform 0.2s ease;
|
||||
|
||||
&.active,
|
||||
&:hover {
|
||||
border-color: #2f6fdd;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
strong {
|
||||
display: block;
|
||||
margin: 14px 0;
|
||||
font-size: 30px;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.metric-values {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
color: #617080;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.chart-panel.wide {
|
||||
grid-column: span 2;
|
||||
}
|
||||
|
||||
.chart {
|
||||
width: 100%;
|
||||
height: 360px;
|
||||
|
||||
&.compact {
|
||||
height: 330px;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-select {
|
||||
width: 220px;
|
||||
}
|
||||
|
||||
.sample-panel dl {
|
||||
display: grid;
|
||||
grid-template-columns: 110px minmax(0, 1fr);
|
||||
gap: 14px 10px;
|
||||
margin: 20px 0 0;
|
||||
}
|
||||
|
||||
.sample-panel dt {
|
||||
color: #667485;
|
||||
}
|
||||
|
||||
.sample-panel dd {
|
||||
margin: 0;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
.algorithm-demo-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.demo-hero,
|
||||
.control-band,
|
||||
.visual-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.demo-hero {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
}
|
||||
}
|
||||
|
||||
.chart-panel.wide {
|
||||
grid-column: span 1;
|
||||
}
|
||||
|
||||
.dataset-stats {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@ -67,12 +67,18 @@
|
||||
<h3>特征重要性</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="importanceChartRef" style="width: 100%; height: 600px"></div>
|
||||
<div class="chart-note">
|
||||
<p>说明:F-特征分数是基于F统计量的特征重要性度量,用于评估各个特征与预测目标之间的相关程度。F分数越高,表示该特征与预测目标之间的相关性越强,但不一定是线性关系。F分数没有固定的上限,其值取决于数据的特征分布和样本量。</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 相关性分析 -->
|
||||
<h3>相关性分析</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="correlationChartRef" style="width: 100%; height: 800px"></div>
|
||||
<div class="chart-note">
|
||||
<p>说明:热力图展示了各特征之间的相关系数,范围从-1到1。正值(蓝色)表示正相关,负值(红色)表示负相关,0(白色)表示无相关性。相关系数的绝对值越接近1,表示相关性越强;越接近0,表示相关性越弱。</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 火箭炮特有的图表 -->
|
||||
@ -103,7 +109,31 @@
|
||||
<div class="chart-container">
|
||||
<div ref="engineChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 制导性能分析 -->
|
||||
<h3>制导性能分析</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="guidanceChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- 生产商分析 -->
|
||||
<h3>生产商分析</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="manufacturerChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 生产商地区分布 -->
|
||||
<h3>生产商地区分布</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="regionChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 生产商综合评分 -->
|
||||
<h3>生产商综合评分</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="scoreChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</div>
|
||||
@ -131,6 +161,10 @@ const newFeatureChartRef = ref(null)
|
||||
const engineChartRef = ref(null)
|
||||
const fireChartRef = ref(null)
|
||||
const mobilityChartRef = ref(null)
|
||||
const manufacturerChartRef = ref(null)
|
||||
const regionChartRef = ref(null)
|
||||
const scoreChartRef = ref(null)
|
||||
const guidanceChartRef = ref(null)
|
||||
|
||||
// 图表实例引用
|
||||
const importanceChart = ref(null)
|
||||
@ -139,6 +173,10 @@ const newFeatureChart = ref(null)
|
||||
const engineChart = ref(null)
|
||||
const fireChart = ref(null)
|
||||
const mobilityChart = ref(null)
|
||||
const manufacturerChart = ref(null)
|
||||
const regionChart = ref(null)
|
||||
const scoreChart = ref(null)
|
||||
const guidanceChart = ref(null)
|
||||
|
||||
// 监听分析结果变化
|
||||
watch(() => analysisResult.value, async (newResult) => {
|
||||
@ -236,69 +274,28 @@ const startAnalysis = async () => {
|
||||
|
||||
analyzing.value = true
|
||||
try {
|
||||
// 打印请求参数
|
||||
console.log('Analysis request params:', {
|
||||
dataset_id: analysisForm.value.dataset_id,
|
||||
equipment_type: analysisForm.value.equipment_type
|
||||
})
|
||||
|
||||
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
|
||||
// 调用特征分析接口
|
||||
const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, {
|
||||
dataset_id: analysisForm.value.dataset_id
|
||||
})
|
||||
|
||||
// 打印原始响应数据
|
||||
console.log('Raw API response:', response)
|
||||
console.log('Response data type:', typeof response.data)
|
||||
console.log('Response data:', response.data)
|
||||
|
||||
// 检查响应数据的结构
|
||||
if (!response.data) {
|
||||
throw new Error('API返回的数据为空')
|
||||
}
|
||||
|
||||
// 确保数据正确赋值
|
||||
analysisResult.value = response.data
|
||||
|
||||
// 验证数赋值是否成功
|
||||
console.log('Analysis result after assignment:', {
|
||||
value: analysisResult.value,
|
||||
important_features: analysisResult.value?.important_features,
|
||||
correlation_analysis: analysisResult.value?.correlation_analysis,
|
||||
equipment_names: analysisResult.value?.equipment_names,
|
||||
length_width_ratio: analysisResult.value?.length_width_ratio
|
||||
// 调用生产商分析接口
|
||||
const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, {
|
||||
dataset_id: analysisForm.value.dataset_id
|
||||
})
|
||||
|
||||
// 如果是巡飞弹类型,检查特定数据
|
||||
if (analysisForm.value.equipment_type === '巡飞弹') {
|
||||
const missileData = {
|
||||
equipment_names: analysisResult.value?.equipment_names || [],
|
||||
length_width_ratio: analysisResult.value?.length_width_ratio || [],
|
||||
engine_power_kw: analysisResult.value?.engine_power_kw || [],
|
||||
guidance_system_score: analysisResult.value?.guidance_system_score || [],
|
||||
warhead_power_score: analysisResult.value?.warhead_power_score || []
|
||||
}
|
||||
|
||||
console.log('Missile specific data:', missileData)
|
||||
|
||||
// 验证数据完整性
|
||||
const missingFields = Object.entries(missileData)
|
||||
.filter(([key, value]) => !Array.isArray(value) || value.length === 0)
|
||||
.map(([key]) => key)
|
||||
|
||||
if (missingFields.length > 0) {
|
||||
console.warn('Missing or empty missile data fields:', missingFields)
|
||||
ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
|
||||
}
|
||||
|
||||
// 合并两个接口的结果
|
||||
analysisResult.value = {
|
||||
...featureResponse.data,
|
||||
...manufacturerResponse.data
|
||||
}
|
||||
|
||||
|
||||
// 验证数据
|
||||
console.log('Combined analysis result:', analysisResult.value)
|
||||
|
||||
} catch (error) {
|
||||
console.error('Analysis error:', error)
|
||||
console.error('Error details:', {
|
||||
message: error.message,
|
||||
response: error.response?.data,
|
||||
status: error.response?.status
|
||||
})
|
||||
ElMessage.error(error.message || '特征析失败')
|
||||
ElMessage.error(error.message || '分析失败')
|
||||
} finally {
|
||||
analyzing.value = false
|
||||
}
|
||||
@ -329,6 +326,18 @@ const createResizeHandler = () => {
|
||||
if (mobilityChart.value && !mobilityChart.value.isDisposed()) {
|
||||
mobilityChart.value.resize()
|
||||
}
|
||||
if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) {
|
||||
manufacturerChart.value.resize()
|
||||
}
|
||||
if (regionChart.value && !regionChart.value.isDisposed()) {
|
||||
regionChart.value.resize()
|
||||
}
|
||||
if (scoreChart.value && !scoreChart.value.isDisposed()) {
|
||||
scoreChart.value.resize()
|
||||
}
|
||||
if (guidanceChart.value && !guidanceChart.value.isDisposed()) {
|
||||
guidanceChart.value.resize()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in resize handler:', error)
|
||||
}
|
||||
@ -360,7 +369,7 @@ onUnmounted(() => {
|
||||
|
||||
// 销毁所有图表实例
|
||||
[importanceChart, correlationChart, newFeatureChart, engineChart,
|
||||
fireChart, mobilityChart].forEach(chart => {
|
||||
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
|
||||
if (chart.value && !chart.value.isDisposed()) {
|
||||
try {
|
||||
chart.value.dispose()
|
||||
@ -384,7 +393,7 @@ const renderCharts = () => {
|
||||
try {
|
||||
// 先销毁所有现有的图表实例
|
||||
[importanceChart, correlationChart, newFeatureChart, engineChart,
|
||||
fireChart, mobilityChart].forEach(chart => {
|
||||
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
|
||||
if (chart.value && !chart.value.isDisposed()) {
|
||||
chart.value.dispose()
|
||||
chart.value = null
|
||||
@ -401,20 +410,27 @@ const renderCharts = () => {
|
||||
|
||||
// 设置基本图表的选项
|
||||
const importanceOption = {
|
||||
title: { text: '特征重要性排序' },
|
||||
title: {
|
||||
text: '特征重要性排序',
|
||||
left: 'center'
|
||||
},
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
trigger: 'item',
|
||||
axisPointer: {
|
||||
type: 'shadow'
|
||||
},
|
||||
formatter: function(params) {
|
||||
const data = params[0]
|
||||
return `${data.name}: ${data.value.toFixed(4)}`
|
||||
return `${params.name}: ${params.value.toFixed(2)}`
|
||||
}
|
||||
},
|
||||
xAxis: {
|
||||
type: 'value',
|
||||
name: '重要性得分'
|
||||
name: 'F-特征分数',
|
||||
axisLabel: {
|
||||
formatter: function(value) {
|
||||
return value.toFixed(1)
|
||||
}
|
||||
}
|
||||
},
|
||||
yAxis: {
|
||||
type: 'category',
|
||||
@ -899,6 +915,156 @@ const renderCharts = () => {
|
||||
mobilityChart.value.setOption(mobilityOption, { notMerge: true })
|
||||
}
|
||||
|
||||
// 渲染生产商分析图表
|
||||
if (manufacturerChartRef.value) {
|
||||
manufacturerChart.value = echarts.init(manufacturerChartRef.value)
|
||||
const manufacturerOption = {
|
||||
title: { text: '生产商特征影响分析' },
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: { type: 'shadow' }
|
||||
},
|
||||
legend: {
|
||||
data: ['技术水平', '规模水平', '供应链水平', '综合得分']
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: analysisResult.value.manufacturer_names || []
|
||||
},
|
||||
yAxis: {
|
||||
type: 'value',
|
||||
name: '评分',
|
||||
min: 0,
|
||||
max: 10
|
||||
},
|
||||
series: [
|
||||
{
|
||||
name: '技术水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_tech_levels || []
|
||||
},
|
||||
{
|
||||
name: '规模水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_scale_levels || []
|
||||
},
|
||||
{
|
||||
name: '供应链水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_supply_chain_levels || []
|
||||
},
|
||||
{
|
||||
name: '综合得分',
|
||||
type: 'line',
|
||||
data: analysisResult.value.manufacturer_composite_scores || []
|
||||
}
|
||||
]
|
||||
}
|
||||
manufacturerChart.value.setOption(manufacturerOption)
|
||||
}
|
||||
|
||||
// 渲染地区分布图表
|
||||
if (regionChartRef.value) {
|
||||
regionChart.value = echarts.init(regionChartRef.value)
|
||||
const regionOption = {
|
||||
title: { text: '生产商地区分布' },
|
||||
tooltip: {
|
||||
trigger: 'item',
|
||||
formatter: '{b}: {c} ({d}%)'
|
||||
},
|
||||
series: [
|
||||
{
|
||||
type: 'pie',
|
||||
radius: '65%',
|
||||
data: analysisResult.value.region_distribution || [],
|
||||
emphasis: {
|
||||
itemStyle: {
|
||||
shadowBlur: 10,
|
||||
shadowOffsetX: 0,
|
||||
shadowColor: 'rgba(0, 0, 0, 0.5)'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
regionChart.value.setOption(regionOption)
|
||||
}
|
||||
|
||||
// 渲染综合评分图表
|
||||
if (scoreChartRef.value) {
|
||||
scoreChart.value = echarts.init(scoreChartRef.value)
|
||||
const scoreOption = {
|
||||
title: { text: '生产商综合评分雷达图' },
|
||||
tooltip: {},
|
||||
radar: {
|
||||
indicator: [
|
||||
{ name: '技术水平', max: 10 },
|
||||
{ name: '规模水平', max: 10 },
|
||||
{ name: '供应链水平', max: 10 },
|
||||
{ name: '区域系数', max: 1.5 },
|
||||
{ name: '综合得分', max: 10 }
|
||||
]
|
||||
},
|
||||
series: [
|
||||
{
|
||||
type: 'radar',
|
||||
data: analysisResult.value.manufacturer_scores || []
|
||||
}
|
||||
]
|
||||
}
|
||||
scoreChart.value.setOption(scoreOption)
|
||||
}
|
||||
|
||||
// 渲染制导性能分析图表
|
||||
if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') {
|
||||
guidanceChart.value = echarts.init(guidanceChartRef.value)
|
||||
const guidanceOption = {
|
||||
title: { text: '制导性能分析' },
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: { type: 'cross' }
|
||||
},
|
||||
legend: {
|
||||
data: ['制导精度(m)', '数据链距离(km)', '制导系统评分']
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: analysisResult.value.equipment_names || []
|
||||
},
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value',
|
||||
name: '制导精度(m)',
|
||||
position: 'left'
|
||||
},
|
||||
{
|
||||
type: 'value',
|
||||
name: '距离(km)',
|
||||
position: 'right'
|
||||
}
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '制导精度(m)',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.guidance_accuracy_m || []
|
||||
},
|
||||
{
|
||||
name: '数据链距离(km)',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: analysisResult.value.datalink_range_km || []
|
||||
},
|
||||
{
|
||||
name: '制导系统评分',
|
||||
type: 'line',
|
||||
data: analysisResult.value.guidance_system_score || []
|
||||
}
|
||||
]
|
||||
}
|
||||
guidanceChart.value.setOption(guidanceOption)
|
||||
}
|
||||
|
||||
console.log('Charts rendered successfully')
|
||||
} catch (error) {
|
||||
console.error('Error in chart rendering:', error)
|
||||
@ -954,4 +1120,12 @@ function debounce(fn, delay) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.chart-note {
|
||||
margin-top: 10px;
|
||||
padding: 10px;
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
</style>
|
||||
@ -795,38 +795,6 @@ onMounted(() => {
|
||||
loadData()
|
||||
})
|
||||
|
||||
// 判断是否为数值类型输入字段
|
||||
const isNumberInput = (key) => {
|
||||
const numberFields = [
|
||||
'length_m',
|
||||
'width_m',
|
||||
'height_m',
|
||||
'weight_kg',
|
||||
'max_range_km',
|
||||
'firing_angle_horizontal',
|
||||
'firing_angle_vertical',
|
||||
'rocket_length_m',
|
||||
'rocket_diameter_mm',
|
||||
'rocket_weight_kg',
|
||||
'rate_of_fire',
|
||||
'combat_weight_kg',
|
||||
'speed_kmh',
|
||||
'min_range_km',
|
||||
'power_hp',
|
||||
'travel_range_km',
|
||||
'max_speed_ms',
|
||||
'cruise_speed_kmh',
|
||||
'flight_time_min',
|
||||
'folded_length_mm',
|
||||
'folded_width_mm',
|
||||
'folded_height_mm',
|
||||
'actual_cost',
|
||||
'predicted_cost'
|
||||
]
|
||||
|
||||
return numberFields.includes(key)
|
||||
}
|
||||
|
||||
// 获取下拉选项
|
||||
const getSelectOptions = (field) => {
|
||||
switch (field) {
|
||||
|
||||
@ -91,6 +91,8 @@
|
||||
<el-select v-model="datasetForm.purpose">
|
||||
<el-option label="训练" value="训练"></el-option>
|
||||
<el-option label="验证" value="验证"></el-option>
|
||||
<el-option label="分析" value="分析"></el-option>
|
||||
<el-option label="测试" value="测试"></el-option>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item label="描述">
|
||||
@ -100,6 +102,7 @@
|
||||
<!-- 选择装备数据 -->
|
||||
<el-form-item label="选择装备" required>
|
||||
<el-table
|
||||
v-if="datasetForm.equipment_type"
|
||||
ref="equipmentTable"
|
||||
:data="availableEquipment"
|
||||
border
|
||||
@ -115,6 +118,9 @@
|
||||
</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
<div v-else class="empty-tip">
|
||||
请先选择装备类型
|
||||
</div>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
<template #footer>
|
||||
@ -139,7 +145,8 @@ const selectedDataset = ref(null) // 选中的数据集
|
||||
const detailsVisible = ref(false) // 详情对话框显示状态
|
||||
const editVisible = ref(false) // 编辑对话框显示状态
|
||||
const availableEquipment = ref([]) // 可选装备列表
|
||||
const selectedEquipment = ref([]) // 已选装备列表
|
||||
const selectedEquipment = ref([]) // 用于保存最终选中的设备
|
||||
const currentSelection = ref([]) // 用于处理表格的选中状态
|
||||
|
||||
// 表格引用
|
||||
const equipmentTable = ref(null)
|
||||
@ -182,6 +189,7 @@ const editDataset = async (dataset) => {
|
||||
try {
|
||||
// 获取数据集详情
|
||||
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
|
||||
console.log('Dataset details:', response.data) // 添加日志
|
||||
|
||||
// 设置表单数据
|
||||
datasetForm.value = {
|
||||
@ -192,21 +200,34 @@ const editDataset = async (dataset) => {
|
||||
description: response.data.description
|
||||
}
|
||||
|
||||
// 设置已选装备
|
||||
// 设置已选装备 - 直接使用后端返回的设备列表
|
||||
selectedEquipment.value = response.data.equipment
|
||||
availableEquipment.value = response.data.equipment
|
||||
console.log('Selected equipment:', selectedEquipment.value) // 添加日志
|
||||
|
||||
// 先显示对话框
|
||||
editVisible.value = true
|
||||
|
||||
// 等待对话框显示并且表格组件挂载完成
|
||||
await nextTick()
|
||||
|
||||
// 加载可选装备
|
||||
await loadAvailableEquipment()
|
||||
|
||||
// 再次等待表格数据更新完成
|
||||
await nextTick()
|
||||
|
||||
// 设置表格选中状态
|
||||
nextTick(() => {
|
||||
if (equipmentTable.value) {
|
||||
equipmentTable.value.clearSelection()
|
||||
availableEquipment.value.forEach(item => {
|
||||
if (equipmentTable.value) {
|
||||
console.log('Setting table selections')
|
||||
equipmentTable.value.clearSelection()
|
||||
availableEquipment.value.forEach(item => {
|
||||
if (selectedEquipment.value.find(e => e.equipment_id === item.equipment_id)) {
|
||||
equipmentTable.value.toggleRowSelection(item, true)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
editVisible.value = true
|
||||
}
|
||||
})
|
||||
} else {
|
||||
console.warn('Equipment table not ready')
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error getting dataset details:', error)
|
||||
@ -245,34 +266,56 @@ const deleteDataset = async (dataset) => {
|
||||
// 加载可选装备
|
||||
const loadAvailableEquipment = async () => {
|
||||
try {
|
||||
const response = await axios.get(`${API_BASE_URL}/data`)
|
||||
availableEquipment.value = datasetForm.value.equipment_type === '火箭炮'
|
||||
? response.data.rocket_artillery
|
||||
: response.data.loitering_munition
|
||||
const response = await axios.get(`${API_BASE_URL}/data`)
|
||||
// 根据装备类型过滤数据
|
||||
availableEquipment.value = response.data.filter(item =>
|
||||
item.type === datasetForm.value.equipment_type
|
||||
)
|
||||
} catch (error) {
|
||||
console.error('Error loading equipment:', error) // 添加错误日志
|
||||
ElMessage.error('获取装备列表失败')
|
||||
}
|
||||
}
|
||||
|
||||
// 处理装备类型变化
|
||||
const handleEquipmentTypeChange = () => {
|
||||
console.log('Equipment type changed:', datasetForm.value.equipment_type) // 添加日志
|
||||
selectedEquipment.value = [] // 清空已选装备
|
||||
loadAvailableEquipment() // 重新加载可选装备
|
||||
}
|
||||
|
||||
// 处理装备选择变化
|
||||
const handleSelectionChange = (selection) => {
|
||||
selectedEquipment.value = selection
|
||||
// 更新当前表格选中状态
|
||||
currentSelection.value = selection
|
||||
}
|
||||
|
||||
// 保存数据集
|
||||
const saveDataset = async () => {
|
||||
try {
|
||||
// 验证必填字段
|
||||
if (!datasetForm.value.name || !datasetForm.value.equipment_type || !datasetForm.value.purpose) {
|
||||
ElMessage.warning('请填写必要信息')
|
||||
return
|
||||
}
|
||||
|
||||
// 使用当前表格选中状态更新最终选中的设备
|
||||
selectedEquipment.value = currentSelection.value
|
||||
|
||||
// 验证是否选择了装备
|
||||
if (!selectedEquipment.value.length) {
|
||||
ElMessage.warning('请选择装备')
|
||||
return
|
||||
}
|
||||
|
||||
// 准备要保存的数据
|
||||
const data = {
|
||||
...datasetForm.value,
|
||||
equipment_ids: selectedEquipment.value.map(item => item.id)
|
||||
equipment_ids: selectedEquipment.value.map(item => item.equipment_id) // 使用 equipment_id
|
||||
}
|
||||
|
||||
console.log('Saving dataset:', data) // 添加日志
|
||||
|
||||
if (data.id) {
|
||||
await axios.put(`${API_BASE_URL}/datasets/${data.id}`, data)
|
||||
} else {
|
||||
@ -283,6 +326,7 @@ const saveDataset = async () => {
|
||||
editVisible.value = false
|
||||
loadDatasets()
|
||||
} catch (error) {
|
||||
console.error('Error saving dataset:', error) // 添加错误日志
|
||||
ElMessage.error('保存失败')
|
||||
}
|
||||
}
|
||||
@ -315,7 +359,8 @@ const formatDateTime = (value) => {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
hour12: false,
|
||||
timeZone: 'Asia/Shanghai'
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,13 @@
|
||||
<p>训练和优化预测模型</p>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-card @click="$router.push('/algorithm-demo')">
|
||||
<el-icon><TrendCharts /></el-icon>
|
||||
<h3>算法演示</h3>
|
||||
<p>切换常用机器学习算法并对比预测效果</p>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-card @click="$router.push('/models')">
|
||||
<el-icon><Management /></el-icon>
|
||||
@ -53,7 +60,7 @@
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-plus/icons-vue'
|
||||
import { Money, DataAnalysis, Monitor, Management, Collection, TrendCharts } from '@element-plus/icons-vue'
|
||||
</script>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
@ -98,4 +105,4 @@ import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-p
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</style>
|
||||
|
||||
@ -31,7 +31,7 @@
|
||||
{{ scope.row.rmse.toFixed(2) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="training_date" label="训练时间">
|
||||
<el-table-column prop="training_date" label="训练时间" width="180">
|
||||
<template #default="scope">
|
||||
{{ formatDateTime(scope.row.training_date) }}
|
||||
</template>
|
||||
@ -211,10 +211,12 @@ const renderImportanceChart = () => {
|
||||
// 格式化模型类型
|
||||
const formatModelType = (type) => {
|
||||
const typeMap = {
|
||||
'pytorch': 'PyTorch',
|
||||
'xgboost': 'XGBoost',
|
||||
'lightgbm': 'LightGBM',
|
||||
'gbdt': 'GBDT',
|
||||
'rf': 'Random Forest'
|
||||
'gbm': 'GBM',
|
||||
'rf': 'Random Forest',
|
||||
'pls': 'PLS回归'
|
||||
}
|
||||
return typeMap[type] || type
|
||||
}
|
||||
@ -230,7 +232,8 @@ const formatDateTime = (value) => {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
hour12: false,
|
||||
timeZone: 'Asia/Shanghai'
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -223,72 +223,46 @@ import { API_BASE_URL } from '@/config'
|
||||
|
||||
const formData = reactive({
|
||||
type: '',
|
||||
length_m: null,
|
||||
width_m: null,
|
||||
height_m: null,
|
||||
weight_kg: null,
|
||||
max_range_km: null,
|
||||
// 火箭炮特有参数
|
||||
firing_angle_horizontal: null,
|
||||
firing_angle_vertical: null,
|
||||
rocket_length_m: null,
|
||||
rocket_diameter_mm: null,
|
||||
rocket_weight_kg: null,
|
||||
rate_of_fire: null,
|
||||
combat_weight_kg: null,
|
||||
speed_kmh: null,
|
||||
min_range_km: null,
|
||||
mobility_type: '',
|
||||
structure_layout: '',
|
||||
engine_model: '',
|
||||
engine_params: '',
|
||||
power_hp: null,
|
||||
travel_range_km: null,
|
||||
// 巡飞弹特有参数 - 补充
|
||||
max_payload_kg: null, // 最大载荷
|
||||
ceiling_altitude_m: null, // 升限
|
||||
combat_radius_km: null, // 作战半径
|
||||
engine_power_kw: null, // 发动机功率
|
||||
engine_thrust_n: null, // 发动机推力
|
||||
datalink_range_km: null, // 通信链路距离
|
||||
guidance_accuracy_m: null, // 制导精度
|
||||
min_altitude_m: null, // 最小作战高度
|
||||
max_altitude_m: null, // 最大作战高度
|
||||
|
||||
// 特征工程参数
|
||||
length_width_ratio: null, // 长宽比
|
||||
weight_range_ratio: null, // 重量/射程比
|
||||
speed_weight_ratio: null, // 速度/重量比
|
||||
guidance_system_score: null, // 制导系统复杂度评分
|
||||
warhead_power_score: null // 战斗部威力评分
|
||||
length_m: 7.35,
|
||||
width_m: 2.4,
|
||||
height_m: 3.1,
|
||||
weight_kg: 13700,
|
||||
max_range_km: 20.4,
|
||||
firing_angle_horizontal: 102,
|
||||
firing_angle_vertical: 55,
|
||||
rocket_length_m: 2.87,
|
||||
rocket_diameter_mm: 122,
|
||||
rocket_weight_kg: 66.6,
|
||||
rate_of_fire: 40,
|
||||
combat_weight_kg: 15000,
|
||||
speed_kmh: 60,
|
||||
min_range_km: 5,
|
||||
mobility_type: '轮式',
|
||||
structure_layout: '6x6轮式底盘',
|
||||
engine_model: 'WD615',
|
||||
engine_params: '6缸直列柴油机',
|
||||
power_hp: 280,
|
||||
travel_range_km: 600,
|
||||
wingspan_m: 2.5,
|
||||
warhead_weight_kg: 20,
|
||||
max_speed_ms: 200,
|
||||
cruise_speed_kmh: 720,
|
||||
endurance_min: 30,
|
||||
warhead_type: '破片杀伤战斗部',
|
||||
launch_mode: '箱式发射',
|
||||
power_system: '电动机',
|
||||
guidance_system: 'GPS/INS/光电',
|
||||
max_payload_kg: 25,
|
||||
ceiling_altitude_m: 5000,
|
||||
combat_radius_km: 100,
|
||||
datalink_range_km: 150,
|
||||
guidance_accuracy_m: 3
|
||||
})
|
||||
|
||||
const predictionResults = ref(null)
|
||||
const mlPrediction = ref(null)
|
||||
const plsPrediction = ref(null)
|
||||
|
||||
const handleTypeChange = () => {
|
||||
// 重置特有参数
|
||||
if (formData.type === '火箭炮') {
|
||||
formData.firing_angle_horizontal = null
|
||||
formData.firing_angle_vertical = null
|
||||
formData.rocket_length_m = null
|
||||
formData.rocket_diameter_mm = null
|
||||
formData.rocket_weight_kg = null
|
||||
formData.rate_of_fire = null
|
||||
} else if (formData.type === '巡飞弹') {
|
||||
formData.wingspan_m = null
|
||||
formData.warhead_weight_kg = null
|
||||
formData.max_speed_ms = null
|
||||
formData.cruise_speed_kmh = null
|
||||
formData.endurance_min = null
|
||||
formData.warhead_type = ''
|
||||
formData.launch_mode = ''
|
||||
formData.power_system = ''
|
||||
formData.guidance_system = ''
|
||||
}
|
||||
}
|
||||
|
||||
const submitForm = async () => {
|
||||
try {
|
||||
// 验证必填字段
|
||||
@ -327,7 +301,7 @@ const submitForm = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取预测结果
|
||||
// 同时调用两个预测接口
|
||||
const [mlResponse, plsResponse] = await Promise.all([
|
||||
axios.post(`${API_BASE_URL}/predict`, formData),
|
||||
axios.post(`${API_BASE_URL}/pls/predict`, formData)
|
||||
@ -396,9 +370,96 @@ const getModelName = (modelType) => {
|
||||
}
|
||||
return modelNames[modelType] || modelType
|
||||
}
|
||||
|
||||
const handleTypeChange = () => {
|
||||
// 清空预测结果
|
||||
predictionResults.value = false
|
||||
mlPrediction.value = null
|
||||
plsPrediction.value = null
|
||||
|
||||
// 重置特有参数
|
||||
if (formData.type === '火箭炮') {
|
||||
// 设置火箭炮的默认值
|
||||
formData.length_m = 7.35
|
||||
formData.width_m = 2.4
|
||||
formData.height_m = 3.1
|
||||
formData.weight_kg = 13700
|
||||
formData.max_range_km = 20.4
|
||||
formData.firing_angle_horizontal = 102
|
||||
formData.firing_angle_vertical = 55
|
||||
formData.rocket_length_m = 2.87
|
||||
formData.rocket_diameter_mm = 122
|
||||
formData.rocket_weight_kg = 66.6
|
||||
formData.rate_of_fire = 40
|
||||
formData.combat_weight_kg = 15000
|
||||
formData.speed_kmh = 60
|
||||
formData.min_range_km = 5
|
||||
formData.mobility_type = '轮式'
|
||||
formData.structure_layout = '6x6轮式底盘'
|
||||
formData.engine_model = 'WD615'
|
||||
formData.engine_params = '6缸直列柴油机'
|
||||
formData.power_hp = 280
|
||||
formData.travel_range_km = 600
|
||||
|
||||
// 清空巡飞弹参数
|
||||
formData.wingspan_m = null
|
||||
formData.warhead_weight_kg = null
|
||||
formData.max_speed_ms = null
|
||||
formData.cruise_speed_kmh = null
|
||||
formData.endurance_min = null
|
||||
formData.warhead_type = ''
|
||||
formData.launch_mode = ''
|
||||
formData.power_system = ''
|
||||
formData.guidance_system = ''
|
||||
formData.max_payload_kg = null
|
||||
formData.ceiling_altitude_m = null
|
||||
formData.combat_radius_km = null
|
||||
formData.datalink_range_km = null
|
||||
formData.guidance_accuracy_m = null
|
||||
|
||||
} else if (formData.type === '巡飞弹') {
|
||||
// 设置巡飞弹的默认值
|
||||
formData.length_m = 2.5
|
||||
formData.width_m = 0.4
|
||||
formData.height_m = 0.4
|
||||
formData.weight_kg = 120
|
||||
formData.max_range_km = 100
|
||||
formData.wingspan_m = 2.5
|
||||
formData.warhead_weight_kg = 20
|
||||
formData.max_speed_ms = 200
|
||||
formData.cruise_speed_kmh = 720
|
||||
formData.endurance_min = 30
|
||||
formData.warhead_type = '破片杀伤战斗部'
|
||||
formData.launch_mode = '箱式发射'
|
||||
formData.power_system = '电动机'
|
||||
formData.guidance_system = 'GPS/INS/光电'
|
||||
formData.max_payload_kg = 25
|
||||
formData.ceiling_altitude_m = 5000
|
||||
formData.combat_radius_km = 100
|
||||
formData.datalink_range_km = 150
|
||||
formData.guidance_accuracy_m = 3
|
||||
|
||||
// 清空火箭炮参数
|
||||
formData.firing_angle_horizontal = null
|
||||
formData.firing_angle_vertical = null
|
||||
formData.rocket_length_m = null
|
||||
formData.rocket_diameter_mm = null
|
||||
formData.rocket_weight_kg = null
|
||||
formData.rate_of_fire = null
|
||||
formData.combat_weight_kg = null
|
||||
formData.speed_kmh = null
|
||||
formData.min_range_km = null
|
||||
formData.mobility_type = ''
|
||||
formData.structure_layout = ''
|
||||
formData.engine_model = ''
|
||||
formData.engine_params = ''
|
||||
formData.power_hp = null
|
||||
formData.travel_range_km = null
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
<style lang="scss" scoped>
|
||||
.predict-page {
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
<template>
|
||||
<div class="training-page">
|
||||
<!-- 上部分:模型训练区域 -->
|
||||
<el-card class="training-card">
|
||||
<template #header>
|
||||
<h2>模型训练</h2>
|
||||
@ -38,11 +39,12 @@
|
||||
|
||||
<el-form-item label="选择模型">
|
||||
<el-checkbox-group v-model="trainingConfig.models">
|
||||
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
|
||||
<el-checkbox value="pytorch" checked>PyTorch</el-checkbox>
|
||||
<el-checkbox value="xgboost" checked>XGBoost</el-checkbox>
|
||||
<el-checkbox value="lightgbm" checked>LightGBM</el-checkbox>
|
||||
<el-checkbox value="gbm" checked>GBM</el-checkbox>
|
||||
<el-checkbox value="rf" checked>Random Forest</el-checkbox>
|
||||
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
|
||||
</el-checkbox-group>
|
||||
</el-form-item>
|
||||
|
||||
@ -61,8 +63,8 @@
|
||||
<div class="best-model-info" v-if="trainingResult.best_model">
|
||||
<h4>最佳模型: {{ getModelName(trainingResult.best_model.type) }}</h4>
|
||||
<p>R²分数: {{ formatNumber(trainingResult.best_model.r2) }}</p>
|
||||
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} 元</p>
|
||||
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} 元</p>
|
||||
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} 美元</p>
|
||||
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} 美元</p>
|
||||
</div>
|
||||
|
||||
<!-- 所有模型评估结果 -->
|
||||
@ -80,12 +82,12 @@
|
||||
{{ formatNumber(scope.row.train.r2) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="train.mae" label="MAE (元)" width="150">
|
||||
<el-table-column prop="train.mae" label="MAE (美元)" width="150">
|
||||
<template #default="scope">
|
||||
{{ formatNumber(scope.row.train.mae) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="train.rmse" label="RMSE (元)" width="150">
|
||||
<el-table-column prop="train.rmse" label="RMSE (美元)" width="150">
|
||||
<template #default="scope">
|
||||
{{ formatNumber(scope.row.train.rmse) }}
|
||||
</template>
|
||||
@ -99,12 +101,12 @@
|
||||
{{ formatNumber(scope.row.validation.r2) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="validation.mae" label="MAE (元)" width="150">
|
||||
<el-table-column prop="validation.mae" label="MAE (美元)" width="150">
|
||||
<template #default="scope">
|
||||
{{ formatNumber(scope.row.validation.mae) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="validation.rmse" label="RMSE (元)" width="150">
|
||||
<el-table-column prop="validation.rmse" label="RMSE (美元)" width="150">
|
||||
<template #default="scope">
|
||||
{{ formatNumber(scope.row.validation.rmse) }}
|
||||
</template>
|
||||
@ -130,6 +132,159 @@
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
|
||||
<!-- 下部分:模型简介区域 -->
|
||||
<el-card class="model-intro-card">
|
||||
<template #header>
|
||||
<h2>模型简介</h2>
|
||||
</template>
|
||||
|
||||
<el-collapse>
|
||||
<el-collapse-item name="pytorch">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">PyTorch</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>深度学习框架,可以构建复杂的神经网络结构</li>
|
||||
<li>分别处理装备特征和生产商特征,然后合并进行预测</li>
|
||||
<li>使用批量归一化和Dropout防止过拟合</li>
|
||||
<li>适合处理非线性关系和复杂特征交互</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>强大的特征学习能力</li>
|
||||
<li>可以自动学习特征之间的复杂关系</li>
|
||||
<li>灵活的网络结构设计</li>
|
||||
<li>支持GPU加速训练</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="xgboost">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">XGBoost</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>基于梯度提升树的集成学习算法</li>
|
||||
<li>使用二阶导数进行优化</li>
|
||||
<li>内置正则化机制防止过拟合</li>
|
||||
<li>支持特征重要性评估</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>优秀的预测性能</li>
|
||||
<li>处理缺失值的能力强</li>
|
||||
<li>训练速度快</li>
|
||||
<li>可解释性好</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="lightgbm">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">LightGBM</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>微软开发的轻量级梯度提升框架</li>
|
||||
<li>使用直方图算法优化训练速度</li>
|
||||
<li>支持类别特征的高效处理</li>
|
||||
<li>叶子优先的生长策略</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>训练速度非常快</li>
|
||||
<li>内存占用低</li>
|
||||
<li>支持大规模数据训练</li>
|
||||
<li>准确率高</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="gbm">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">Gradient Boosting (GBM)</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>经典的梯度提升算法</li>
|
||||
<li>逐步减少残差的思想</li>
|
||||
<li>可以使用不同的损失函数</li>
|
||||
<li>支持特征重要性分析</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>稳定的性能表现</li>
|
||||
<li>较好的可解释性</li>
|
||||
<li>对异常值不敏感</li>
|
||||
<li>适合各种回归问题</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="rf">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">Random Forest</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>基于决策树的集成学习方法</li>
|
||||
<li>使用随机采样和特征选择</li>
|
||||
<li>多个决策树投票或平均</li>
|
||||
<li>自带特征重要性评估</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>不易过拟合</li>
|
||||
<li>训练过程可并行化</li>
|
||||
<li>对噪声数据鲁棒</li>
|
||||
<li>较少的参数调整</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="pls">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">PLS回归</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>偏最小二乘回归</li>
|
||||
<li>同时考虑自变量和因变量的变异</li>
|
||||
<li>处理多重共线性问题</li>
|
||||
<li>降维和回归的结合</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>适合小样本数据</li>
|
||||
<li>处理变量间相关性强的数据</li>
|
||||
<li>计算效率高</li>
|
||||
<li>结果稳定可靠</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
</el-collapse>
|
||||
</el-card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@ -144,7 +299,7 @@ const trainingConfig = ref({
|
||||
type: '',
|
||||
train_dataset_id: null,
|
||||
validation_dataset_id: null,
|
||||
models: ['xgboost', 'lightgbm', 'gbm', 'rf']
|
||||
models: ['pytorch', 'xgboost', 'lightgbm', 'gbm', 'rf']
|
||||
})
|
||||
|
||||
// 数据集列表
|
||||
@ -241,10 +396,12 @@ const formatNumber = (value) => {
|
||||
// 获取模型中文名称
|
||||
const getModelName = (modelType) => {
|
||||
const modelNames = {
|
||||
'pytorch': 'PyTorch',
|
||||
'xgboost': 'XGBoost',
|
||||
'lightgbm': 'LightGBM',
|
||||
'gbm': 'GBM',
|
||||
'rf': 'Random Forest'
|
||||
'rf': 'Random Forest',
|
||||
'pls': 'PLS回归'
|
||||
}
|
||||
return modelNames[modelType] || modelType
|
||||
}
|
||||
@ -334,8 +491,13 @@ onMounted(() => {
|
||||
<style lang="scss" scoped>
|
||||
.training-page {
|
||||
padding: 20px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
|
||||
.training-card {
|
||||
width: 100%;
|
||||
|
||||
.training-result {
|
||||
margin-top: 20px;
|
||||
|
||||
@ -366,5 +528,62 @@ onMounted(() => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.model-intro-card {
|
||||
width: 100%;
|
||||
|
||||
.model-title {
|
||||
.el-link {
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
|
||||
&:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
&:active {
|
||||
opacity: 0.6;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.model-intro {
|
||||
padding: 10px;
|
||||
|
||||
h4 {
|
||||
margin: 10px 0;
|
||||
color: #409EFF;
|
||||
font-size: 15px;
|
||||
}
|
||||
|
||||
ul {
|
||||
padding-left: 20px;
|
||||
margin: 5px 0;
|
||||
|
||||
li {
|
||||
line-height: 1.8;
|
||||
color: #606266;
|
||||
font-size: 14px;
|
||||
|
||||
&:hover {
|
||||
color: #409EFF;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.el-collapse-item__header) {
|
||||
padding: 12px 0;
|
||||
font-size: 16px;
|
||||
|
||||
&:hover {
|
||||
background-color: #f5f7fa;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.el-collapse-item__content) {
|
||||
padding: 10px 20px;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@ -19,7 +19,7 @@ export default defineConfig({
|
||||
port: 3000,
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:5000',
|
||||
target: 'http://localhost:5001',
|
||||
changeOrigin: true
|
||||
}
|
||||
}
|
||||
|
||||
11
html5_cost_prediction/README.txt
Normal file
11
html5_cost_prediction/README.txt
Normal file
@ -0,0 +1,11 @@
|
||||
智能成本预测系统 - HTML5离线版
|
||||
|
||||
运行方式:
|
||||
1. 解压 zip 文件。
|
||||
2. 双击 index.html。
|
||||
|
||||
说明:
|
||||
- 不需要 Python。
|
||||
- 不需要数据库。
|
||||
- 不需要联网。
|
||||
- 页面内置样例数据和模型效果,用于客户现场展示不同模型的预测差异。
|
||||
1324
html5_cost_prediction/index.html
Normal file
1324
html5_cost_prediction/index.html
Normal file
File diff suppressed because one or more lines are too long
57
pyproject.toml
Normal file
57
pyproject.toml
Normal file
@ -0,0 +1,57 @@
|
||||
[project]
|
||||
name = "cost-prediction"
|
||||
version = "0.1.0"
|
||||
description = "装备成本预测系统"
|
||||
requires-python = ">=3.9,<3.12"
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
|
||||
dependencies = [
|
||||
# Web框架
|
||||
"flask>=3.1.0",
|
||||
"flask-cors>=5.0.0",
|
||||
|
||||
# 数据处理
|
||||
"numpy>=1.26.0,<2.0.0",
|
||||
"pandas>=2.2.0",
|
||||
|
||||
# 机器学习
|
||||
"scikit-learn>=1.5.2",
|
||||
"xgboost>=2.1.0",
|
||||
"lightgbm>=4.5.0",
|
||||
|
||||
# 工具
|
||||
"openpyxl>=3.1.5",
|
||||
"python-dotenv>=1.0.0",
|
||||
"requests>=2.31.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# PyTorch 为可选依赖(安装约 800MB,仅训练神经网络时需要)
|
||||
torch = [
|
||||
"torch==2.5.1",
|
||||
]
|
||||
dev = [
|
||||
# 测试工具
|
||||
"pytest>=7.0",
|
||||
"black>=22.0", # 代码格式化
|
||||
"mypy>=1.0", # 类型检查
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py39", "py310", "py311"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
flask==2.0.1
|
||||
flask-cors==3.0.10
|
||||
sqlalchemy==1.4.23
|
||||
pymysql==1.0.2
|
||||
cryptography==3.4.7 # MySQL 8.0+ 认证需要
|
||||
numpy==1.21.2
|
||||
pandas==1.3.3
|
||||
scikit-learn==0.24.2
|
||||
tensorflow==2.6.0
|
||||
urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
|
||||
openpyxl==3.1.2 # 用于读取 .xlsx 文件
|
||||
xlrd==2.0.1 # 用于读取 .xls 文件
|
||||
flask>=3.1.0
|
||||
flask-cors>=5.0.0
|
||||
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
|
||||
xgboost>=2.1.0
|
||||
lightgbm>=4.5.0
|
||||
scikit-learn>=1.5.2
|
||||
|
||||
openpyxl>=3.1.5
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
# 火箭炮系统技术参数示例
|
||||
|
||||
## 伊朗“胜利”-2 240mm (12管)火箭炮系统
|
||||
|
||||
产品类别: 多管火箭炮
|
||||
型号: “胜利”-2 240mm 多管火箭炮
|
||||
尺寸与重量
|
||||
总长: 10m(393.7in)
|
||||
宽(行军状态): 2.5m(98.4in)
|
||||
高(行军状态) 3.34m(131.5in)
|
||||
标准重: 15000kg(33069 lb)(15.0t)
|
||||
战斗重: 19900kg(43871 lb)(19.9t)
|
||||
机动性
|
||||
行走装置: 轮式
|
||||
布局: 6×6
|
||||
两栖: 无
|
||||
火力
|
||||
方向射界: 100º(1778mils)(左/90°右)
|
||||
高低射界(武器前方): 57°(1013mils)
|
||||
型号: “胜利”2火箭弹
|
||||
尺寸与重量
|
||||
总长: 3.550m(11ft)
|
||||
弹体直径: 512mm(20.16in)(尾翼展开)
|
||||
发射(重量): 275kg(606 lb)
|
||||
性能
|
||||
速度(最大速度): 1302kt(2412km/h;1499mph;670m/s)
|
||||
最大射程 12.4n miles(23km;14.3miles)
|
||||
武器组成
|
||||
战斗部: 85kg(187 lb)
|
||||
33
run.py
33
run.py
@ -1,13 +1,26 @@
|
||||
from src.app import create_app
|
||||
import logging
|
||||
from src import create_app
|
||||
from src.logger import setup_logger
|
||||
from config import config
|
||||
import os
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
def main():
|
||||
try:
|
||||
# 创建必要的目录
|
||||
os.makedirs(config.MODEL_DIR, exist_ok=True)
|
||||
os.makedirs(config.LOG_DIR, exist_ok=True)
|
||||
os.makedirs(config.DATA_DIR, exist_ok=True)
|
||||
|
||||
app = create_app()
|
||||
app.run(
|
||||
host=config.FLASK_HOST,
|
||||
port=config.FLASK_PORT,
|
||||
debug=config.FLASK_DEBUG
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting application: {str(e)}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('=== Server Starting ===')
|
||||
logging.info('Initializing directories...')
|
||||
|
||||
app.run(host='0.0.0.0', port=5001, debug=True)
|
||||
main()
|
||||
57
scripts/build_demo_zip.ps1
Normal file
57
scripts/build_demo_zip.ps1
Normal file
@ -0,0 +1,57 @@
|
||||
param(
|
||||
[string]$OutputPath = "release\algorithm-demo-standalone.zip"
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
|
||||
$releaseRoot = Join-Path $repoRoot "release"
|
||||
$stageDir = Join-Path $releaseRoot "algorithm-demo-standalone"
|
||||
$zipPath = Join-Path $repoRoot $OutputPath
|
||||
|
||||
function Assert-InRepo([string]$PathToCheck) {
|
||||
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
|
||||
$root = [System.IO.Path]::GetFullPath($repoRoot)
|
||||
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
|
||||
throw "Refusing to operate outside repository: $resolved"
|
||||
}
|
||||
}
|
||||
|
||||
Assert-InRepo $stageDir
|
||||
Assert-InRepo $zipPath
|
||||
|
||||
Push-Location (Join-Path $repoRoot "frontend")
|
||||
try {
|
||||
npm run build
|
||||
}
|
||||
finally {
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
|
||||
|
||||
if (Test-Path $stageDir) {
|
||||
Remove-Item -LiteralPath $stageDir -Recurse -Force
|
||||
}
|
||||
if (Test-Path $zipPath) {
|
||||
Remove-Item -LiteralPath $zipPath -Force
|
||||
}
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $stageDir | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path (Join-Path $stageDir "data") | Out-Null
|
||||
|
||||
Copy-Item -Recurse -Path (Join-Path $repoRoot "frontend\dist") -Destination (Join-Path $stageDir "frontend")
|
||||
Copy-Item -Path (Join-Path $repoRoot "src\demo_service.py") -Destination (Join-Path $stageDir "demo_service.py")
|
||||
Copy-Item -Path (Join-Path $repoRoot "data\demo_equipment_costs.csv") -Destination (Join-Path $stageDir "data\demo_equipment_costs.csv")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\server.py") -Destination (Join-Path $stageDir "server.py")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\requirements.txt") -Destination (Join-Path $stageDir "requirements.txt")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\start_demo.bat") -Destination (Join-Path $stageDir "start_demo.bat")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\README.md") -Destination (Join-Path $stageDir "README.md")
|
||||
|
||||
Get-ChildItem -Path $stageDir -Recurse -Include "*.map", "__pycache__" | ForEach-Object {
|
||||
Remove-Item -LiteralPath $_.FullName -Recurse -Force
|
||||
}
|
||||
|
||||
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
|
||||
|
||||
Write-Host "Demo zip created: $zipPath"
|
||||
36
scripts/build_html5_zip.ps1
Normal file
36
scripts/build_html5_zip.ps1
Normal file
@ -0,0 +1,36 @@
|
||||
param(
|
||||
[string]$OutputPath = "release\intelligent-cost-prediction-html5.zip"
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
|
||||
$sourceDir = Join-Path $repoRoot "html5_cost_prediction"
|
||||
$releaseRoot = Join-Path $repoRoot "release"
|
||||
$stageDir = Join-Path $releaseRoot "intelligent-cost-prediction-html5"
|
||||
$zipPath = Join-Path $repoRoot $OutputPath
|
||||
|
||||
function Assert-InRepo([string]$PathToCheck) {
|
||||
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
|
||||
$root = [System.IO.Path]::GetFullPath($repoRoot)
|
||||
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
|
||||
throw "Refusing to operate outside repository: $resolved"
|
||||
}
|
||||
}
|
||||
|
||||
Assert-InRepo $stageDir
|
||||
Assert-InRepo $zipPath
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
|
||||
|
||||
if (Test-Path $stageDir) {
|
||||
Remove-Item -LiteralPath $stageDir -Recurse -Force
|
||||
}
|
||||
if (Test-Path $zipPath) {
|
||||
Remove-Item -LiteralPath $zipPath -Force
|
||||
}
|
||||
|
||||
Copy-Item -Recurse -Path $sourceDir -Destination $stageDir
|
||||
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
|
||||
|
||||
Write-Host "HTML5 zip created: $zipPath"
|
||||
87
scripts/build_linux.sh
Normal file
87
scripts/build_linux.sh
Normal file
@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 确保脚本在错误时退出
|
||||
set -e
|
||||
|
||||
echo "Starting packaging for Linux..."
|
||||
|
||||
# 创建虚拟环境
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# 安装依赖
|
||||
echo "Installing dependencies..."
|
||||
pip install -e .
|
||||
|
||||
# 构建前端
|
||||
echo "Building frontend..."
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
# 把构建好的文件直接复制到 frontend 目录
|
||||
cp -r dist/* .
|
||||
rm -rf dist
|
||||
cd ..
|
||||
|
||||
# 创建必要的目录
|
||||
mkdir -p logs data models
|
||||
|
||||
# 使用 PyInstaller 打包
|
||||
echo "Packaging with PyInstaller..."
|
||||
pyinstaller --clean \
|
||||
--add-data "src/loitering_munition_data.sql:data" \
|
||||
--add-data "src/rocket_artillery_data.sql:data" \
|
||||
--add-data "src/manufacturer_data.sql:data" \
|
||||
--add-data "src/schema.sql:data" \
|
||||
--add-data "config.py:." \
|
||||
--add-data "src:src" \
|
||||
--add-data "frontend:frontend" \
|
||||
--add-data "logs:logs" \
|
||||
--add-data "data:data" \
|
||||
--add-data "models:models" \
|
||||
--collect-all "xgboost" \
|
||||
--collect-all "lightgbm" \
|
||||
--collect-all "torch" \
|
||||
--collect-all "sklearn" \
|
||||
--collect-all "numpy" \
|
||||
--collect-all "pandas" \
|
||||
--collect-all "sqlalchemy" \
|
||||
--collect-all "pymysql" \
|
||||
--collect-all "cryptography" \
|
||||
--collect-all "flask" \
|
||||
--collect-all "flask_cors" \
|
||||
--hidden-import "xgboost.testing" \
|
||||
--hidden-import "torch.utils.tensorboard" \
|
||||
--hidden-import "pytest" \
|
||||
run.py
|
||||
|
||||
# 创建启动脚本
|
||||
echo "Creating start script..."
|
||||
cat > src/start.sh << 'EOF'
|
||||
#!/bin/bash
|
||||
export FLASK_DEBUG=false
|
||||
export MYSQL_HOST=localhost
|
||||
export MYSQL_USER=root
|
||||
export MYSQL_PASSWORD=123456
|
||||
export MYSQL_DB=equipment_cost_db
|
||||
echo "Starting Cost Prediction System..."
|
||||
./run
|
||||
xdg-open http://localhost:5001
|
||||
EOF
|
||||
|
||||
# 复制启动脚本
|
||||
cp src/start.sh dist/run/
|
||||
chmod +x dist/run/start.sh
|
||||
|
||||
# 创建发布包
|
||||
echo "Creating release package..."
|
||||
version=$(grep "version" pyproject.toml | cut -d'"' -f2)
|
||||
mkdir -p dist/release
|
||||
cp -r dist/run/* dist/release/
|
||||
|
||||
# 创建 tar.gz 包
|
||||
cd dist
|
||||
tar czf "cost-prediction-${version}-linux.tar.gz" release/
|
||||
cd ..
|
||||
|
||||
echo "Package completed: dist/cost-prediction-${version}-linux.tar.gz"
|
||||
70
scripts/build_win.ps1
Normal file
70
scripts/build_win.ps1
Normal file
@ -0,0 +1,70 @@
|
||||
# Set console encoding to UTF-8
|
||||
[Console]::OutputEncoding = [System.Text.Encoding]::UTF8
|
||||
$OutputEncoding = [System.Text.Encoding]::UTF8
|
||||
|
||||
# Ensure PowerShell stops on error
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# Create virtual environment
|
||||
Write-Host "Creating virtual environment..."
|
||||
python -m venv .venv
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
|
||||
# Install dependencies
|
||||
Write-Host "Installing dependencies..."
|
||||
pip install -e .
|
||||
|
||||
# Build frontend
|
||||
Write-Host "Building frontend..."
|
||||
Push-Location frontend
|
||||
npm install
|
||||
npm run build
|
||||
Copy-Item dist/* . -Recurse -Force
|
||||
Remove-Item dist -Recurse -Force
|
||||
Pop-Location
|
||||
|
||||
# Package with PyInstaller
|
||||
Write-Host "Starting packaging..."
|
||||
# Create necessary directories
|
||||
New-Item -ItemType Directory -Force -Path "logs"
|
||||
New-Item -ItemType Directory -Force -Path "data"
|
||||
New-Item -ItemType Directory -Force -Path "models"
|
||||
|
||||
pyinstaller --clean `
|
||||
--add-data "src/loitering_munition_data.sql;data" `
|
||||
--add-data "src/rocket_artillery_data.sql;data" `
|
||||
--add-data "src/manufacturer_data.sql;data" `
|
||||
--add-data "config.py;." `
|
||||
--add-data "src;src" `
|
||||
--add-data "frontend;frontend" `
|
||||
--add-data "logs;logs" `
|
||||
--add-data "data;data" `
|
||||
--add-data "models;models" `
|
||||
--collect-all "xgboost" `
|
||||
--collect-all "lightgbm" `
|
||||
--collect-all "sklearn" `
|
||||
--collect-all "numpy" `
|
||||
--collect-all "pandas" `
|
||||
--collect-all "flask" `
|
||||
--collect-all "flask_cors" `
|
||||
run.py
|
||||
|
||||
# Copy necessary files
|
||||
Write-Host "Copying configuration files..."
|
||||
Copy-Item "src/start.bat" -Destination "dist/run"
|
||||
|
||||
# Create release package
|
||||
Write-Host "Creating release package..."
|
||||
$version = (Get-Content pyproject.toml | Select-String 'version = "(.*?)"').Matches.Groups[1].Value
|
||||
|
||||
# Create complete offline installation package directory
|
||||
$RELEASE_DIR = "dist/release"
|
||||
New-Item -ItemType Directory -Force -Path $RELEASE_DIR
|
||||
|
||||
# Copy application files
|
||||
Copy-Item "dist/run/*" -Destination $RELEASE_DIR -Recurse
|
||||
|
||||
# Create final zip package
|
||||
Compress-Archive -Path "$RELEASE_DIR/*" -DestinationPath "cost-prediction-$version-win64.zip" -Force
|
||||
|
||||
Write-Host "Package completed: cost-prediction-$version-win64.zip"
|
||||
121
scripts/setup_env.ps1
Normal file
121
scripts/setup_env.ps1
Normal file
@ -0,0 +1,121 @@
|
||||
# 设置错误操作首选项
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# 检查管理员权限
|
||||
$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
|
||||
if (-not $isAdmin) {
|
||||
Write-Warning "建议使用管理员权限运行此脚本"
|
||||
Start-Sleep -Seconds 3
|
||||
}
|
||||
|
||||
# 检查 pyenv-win 是否安装
|
||||
if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) {
|
||||
Write-Host "pyenv not found. Installing..."
|
||||
try {
|
||||
# 下载并安装 pyenv-win
|
||||
Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1"
|
||||
& ./install-pyenv-win.ps1
|
||||
|
||||
# 添加环境变量
|
||||
$env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win"
|
||||
$env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path"
|
||||
|
||||
# 刷新环境变量
|
||||
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
}
|
||||
catch {
|
||||
Write-Error "Failed to install pyenv: $_"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
# 安装指定版本的 Python
|
||||
Write-Host "Installing Python 3.11.8..."
|
||||
pyenv install 3.11.8
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "Failed to install Python 3.11.8"
|
||||
}
|
||||
|
||||
# 设置本地 Python 版本
|
||||
Write-Host "Setting local Python version..."
|
||||
pyenv local 3.11.8
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "Failed to set local Python version"
|
||||
}
|
||||
|
||||
# 验证 Python 版本
|
||||
$pythonVersion = python -V
|
||||
if (-not $pythonVersion.Contains("3.11.8")) {
|
||||
throw "Wrong Python version: $pythonVersion"
|
||||
}
|
||||
Write-Host "Using Python version: $pythonVersion"
|
||||
|
||||
# 创建虚拟环境
|
||||
Write-Host "Creating virtual environment..."
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
Write-Host "Activating virtual environment..."
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
|
||||
# 升级 pip 和构建工具
|
||||
Write-Host "Upgrading pip and build tools..."
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
# 分步安装依赖以确保正确的顺序和版本
|
||||
Write-Host "Installing database dependencies..."
|
||||
pip install mysql-connector-python==8.0.33
|
||||
|
||||
Write-Host "Installing PyTorch and related packages..."
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
Write-Host "Installing basic dependencies..."
|
||||
pip install numpy==1.26.4 pandas==2.2.1
|
||||
|
||||
Write-Host "Installing machine learning packages..."
|
||||
pip install scikit-learn==1.5.2
|
||||
|
||||
# 安装开发依赖
|
||||
Write-Host "Installing development dependencies..."
|
||||
pip install -e ".[dev]"
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Warning "Failed to install development dependencies. Installing core package..."
|
||||
pip install -e .
|
||||
}
|
||||
|
||||
# 验证安装
|
||||
Write-Host "Verifying installations..."
|
||||
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
|
||||
python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
|
||||
python -c "import pandas; print(f'Pandas version: {pandas.__version__}')"
|
||||
python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')"
|
||||
|
||||
Write-Host "Environment setup complete!" -ForegroundColor Green
|
||||
}
|
||||
catch {
|
||||
Write-Error "An error occurred: $_"
|
||||
exit 1
|
||||
}
|
||||
finally {
|
||||
# 清理临时文件
|
||||
if (Test-Path "./install-pyenv-win.ps1") {
|
||||
Remove-Item "./install-pyenv-win.ps1"
|
||||
}
|
||||
}
|
||||
|
||||
# 显示使用说明
|
||||
Write-Host @"
|
||||
|
||||
环境设置完成!使用说明:
|
||||
1. 虚拟环境已激活,命令提示符前应该显示 (.venv)
|
||||
2. 要退出虚拟环境,运行: deactivate
|
||||
3. 要重新激活虚拟环境,运行: .\.venv\Scripts\Activate.ps1
|
||||
4. 项目依赖已安装,可以开始开发了
|
||||
|
||||
如果遇到问题,请检查:
|
||||
- Python 版本: python -V
|
||||
- PyTorch 安装: python -c "import torch; print(torch.__version__)"
|
||||
- 虚拟环境状态: 确保看到 (.venv) 前缀
|
||||
|
||||
"@ -ForegroundColor Cyan
|
||||
73
scripts/setup_env.sh
Executable file
73
scripts/setup_env.sh
Executable file
@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 此脚本用于设置 Python 开发环境
|
||||
# 主要用于:
|
||||
# 1. 开发环境初始化
|
||||
# 2. 确保正确的 Python 版本
|
||||
# 3. 安装项目依赖
|
||||
# 注意:在运行此脚本前,请先运行 setup_linux.sh 安装系统依赖
|
||||
|
||||
# 检查 pyenv 是否安装
|
||||
if ! command -v pyenv &> /dev/null; then
|
||||
echo "pyenv not found. Installing..."
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
brew install pyenv
|
||||
else
|
||||
curl https://pyenv.run | bash
|
||||
fi
|
||||
fi
|
||||
|
||||
# 安装指定版本的 Python
|
||||
pyenv install 3.11.8 || true
|
||||
|
||||
# 设置本地 Python 版本
|
||||
pyenv local 3.11.8
|
||||
|
||||
# 确保使用正确的 Python 版本
|
||||
eval "$(pyenv init -)"
|
||||
pyenv shell 3.11.8
|
||||
|
||||
# 验证 Python 版本
|
||||
python_version=$(python -V 2>&1)
|
||||
if [[ $python_version != *"3.11.8"* ]]; then
|
||||
echo "Error: Wrong Python version: $python_version"
|
||||
echo "Please ensure pyenv is properly configured in your shell"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 创建虚拟环境
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
source .venv/bin/activate
|
||||
|
||||
# 升级 pip 和构建工具
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
# 分步安装依赖以确保正确的顺序和版本
|
||||
echo "Installing database dependencies..."
|
||||
pip install mysql-connector-python==8.0.33
|
||||
|
||||
echo "Installing PyTorch and related packages..."
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
echo "Installing basic dependencies..."
|
||||
pip install numpy==1.26.4 pandas==2.2.1
|
||||
|
||||
echo "Installing machine learning packages..."
|
||||
pip install scikit-learn==1.5.2
|
||||
|
||||
# 安装开发依赖
|
||||
if ! pip install -e ".[dev]"; then
|
||||
echo "Warning: Failed to install development dependencies. Installing core package..."
|
||||
pip install -e .
|
||||
fi
|
||||
|
||||
# 验证安装
|
||||
echo "Verifying Python version..."
|
||||
python --version
|
||||
|
||||
echo "Verifying PyTorch installation..."
|
||||
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
|
||||
|
||||
echo "Environment setup complete!"
|
||||
36
scripts/setup_linux.sh
Normal file
36
scripts/setup_linux.sh
Normal file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 此脚本用于安装 Linux 系统级依赖
|
||||
# 主要用于:
|
||||
# 1. 打包环境准备
|
||||
# 2. 新系统的初始化
|
||||
# 3. CI/CD 环境设置
|
||||
|
||||
# 更新包列表
|
||||
sudo apt update
|
||||
|
||||
# 安装系统依赖
|
||||
echo "Installing system dependencies..."
|
||||
sudo apt install -y \
|
||||
python3.11 \ # Python 3.11 解释器
|
||||
python3.11-venv \ # Python 虚拟环境支持
|
||||
python3-pip \ # Python 包管理器
|
||||
build-essential \ # 编译工具
|
||||
python3.11-dev \ # Python 开发库
|
||||
nodejs \ # Node.js(用于前端构建)
|
||||
npm \ # Node.js 包管理器
|
||||
binutils \ # 二进制工具(用于 PyInstaller)
|
||||
tar \ # 打包工具
|
||||
gzip \ # 压缩工具
|
||||
mysql-client \ # MySQL 客户端
|
||||
libmysqlclient-dev \ # MySQL 开发库
|
||||
gcc \ # C 编译器
|
||||
g++ \ # C++ 编译器
|
||||
libssl-dev \ # SSL 支持
|
||||
xdg-utils # 用于打开浏览器
|
||||
|
||||
# 验证安装
|
||||
echo "Verifying installations..."
|
||||
python3.11 --version
|
||||
node --version
|
||||
npm --version
|
||||
@ -1 +1,3 @@
|
||||
# 这个文件可以为空,但必须存在
|
||||
from .app import create_app
|
||||
|
||||
__all__ = ['create_app']
|
||||
|
||||
BIN
src/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
src/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/app.cpython-313.pyc
Normal file
BIN
src/__pycache__/app.cpython-313.pyc
Normal file
Binary file not shown.
65
src/app.py
65
src/app.py
@ -1,50 +1,57 @@
|
||||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
from flask import send_from_directory
|
||||
from .routes import api_bp
|
||||
from .logger import setup_logger
|
||||
from config import config
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 获取logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
def create_app():
|
||||
"""
|
||||
创建并配置Flask应用
|
||||
"""
|
||||
"""创建并配置 Flask 应用"""
|
||||
try:
|
||||
# 创建必要的目录
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
os.makedirs('data', exist_ok=True)
|
||||
os.makedirs('models', exist_ok=True)
|
||||
|
||||
logger.info("=== Server Starting ===")
|
||||
logger.info("Initializing directories...")
|
||||
|
||||
# 创建Flask应用
|
||||
app = Flask(__name__)
|
||||
|
||||
# 配置CORS
|
||||
CORS(app)
|
||||
logger.info("CORS enabled")
|
||||
|
||||
# 注册API蓝图
|
||||
# 注册路由
|
||||
app.register_blueprint(api_bp, url_prefix='/api')
|
||||
logger.info("API blueprint registered")
|
||||
|
||||
# 配置数据库连接
|
||||
app.config['MYSQL_HOST'] = 'localhost'
|
||||
app.config['MYSQL_USER'] = 'root'
|
||||
app.config['MYSQL_PASSWORD'] = '123456'
|
||||
app.config['MYSQL_DB'] = 'equipment_cost_db'
|
||||
# 获取前端文件路径
|
||||
if getattr(sys, 'frozen', False):
|
||||
# PyInstaller 打包后的路径
|
||||
frontend_path = os.path.join(sys._MEIPASS, 'frontend')
|
||||
logger.info(f"Running in frozen mode, frontend path: {frontend_path}")
|
||||
logger.info(f"MEIPASS path: {sys._MEIPASS}")
|
||||
logger.info(f"Files in frontend dir: {os.listdir(frontend_path)}")
|
||||
else:
|
||||
# 开发环境路径
|
||||
frontend_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'frontend', 'dist')
|
||||
|
||||
logger.info("Starting server...")
|
||||
# 服务前端文件
|
||||
@app.route('/', defaults={'path': ''})
|
||||
@app.route('/<path:path>')
|
||||
def serve_frontend(path):
|
||||
logger.info(f"Serving path: {path}")
|
||||
logger.info(f"Frontend path: {frontend_path}")
|
||||
logger.info(f"Full file path: {os.path.join(frontend_path, path)}")
|
||||
logger.info(f"File exists: {os.path.exists(os.path.join(frontend_path, path))}")
|
||||
try:
|
||||
if path == "":
|
||||
return send_from_directory(frontend_path, 'index.html')
|
||||
file_path = os.path.join(frontend_path, path)
|
||||
if os.path.exists(file_path):
|
||||
return send_from_directory(frontend_path, path)
|
||||
return send_from_directory(frontend_path, 'index.html')
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving file {path}: {str(e)}")
|
||||
return str(e), 500
|
||||
|
||||
return app
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
app.run(host='localhost', port=5001)
|
||||
logger.error(f"Error creating application: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
@ -1,17 +1,23 @@
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import tensorflow as tf
|
||||
from scipy import stats
|
||||
import joblib
|
||||
import os
|
||||
import pandas as pd
|
||||
from .feature_analysis import FeatureAnalysis
|
||||
import logging
|
||||
from src.model_trainer import ModelTrainer
|
||||
from src.database import get_db_connection
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from .logger import setup_logger
|
||||
|
||||
# PyTorch 为可选依赖
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
_HAS_TORCH = False
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class CostPredictor:
|
||||
@ -21,161 +27,120 @@ class CostPredictor:
|
||||
self.model = None
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
self.equipment_type = None
|
||||
|
||||
# 添加 TensorFlow 配置
|
||||
tf.config.run_functions_eagerly(False) # 启用图执行模式
|
||||
|
||||
# 创建预测函数
|
||||
@tf.function(reduce_retracing=True, jit_compile=True)
|
||||
def predict_fn(x):
|
||||
return self.model(x, training=False)
|
||||
|
||||
self._predict_fn = predict_fn
|
||||
|
||||
|
||||
if _HAS_TORCH:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = None
|
||||
|
||||
self.load_model()
|
||||
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
加载预训练型和标准化器
|
||||
加载预训练模型和标准化器
|
||||
"""
|
||||
try:
|
||||
model_dir = 'models'
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 创建默认模型
|
||||
self._create_default_model()
|
||||
|
||||
# 创建预测函数
|
||||
@tf.function(reduce_retracing=True, jit_compile=True)
|
||||
def predict_fn(x):
|
||||
return self.model(x, training=False)
|
||||
|
||||
self._predict_fn = predict_fn
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading model: {str(e)}")
|
||||
self._create_default_model()
|
||||
|
||||
if _HAS_TORCH:
|
||||
try:
|
||||
self._create_default_model()
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading model: {str(e)}")
|
||||
self._create_default_model()
|
||||
|
||||
def _create_default_model(self):
|
||||
"""
|
||||
创建默认模型并进行初始化训练
|
||||
"""
|
||||
# 创建输入层
|
||||
inputs = tf.keras.Input(shape=(11,))
|
||||
|
||||
# 创建隐藏层
|
||||
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
|
||||
x = tf.keras.layers.Dense(32, activation='relu')(x)
|
||||
|
||||
# 创建输出层
|
||||
outputs = tf.keras.layers.Dense(1)(x)
|
||||
|
||||
# 创建模型
|
||||
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
# 编译模型
|
||||
self.model.compile(
|
||||
optimizer='adam',
|
||||
loss=tf.keras.losses.mean_squared_error,
|
||||
metrics=[tf.keras.metrics.mean_absolute_error]
|
||||
)
|
||||
|
||||
# 创建示例数据
|
||||
example_data = pd.DataFrame({
|
||||
'length_m': [7.35, 10.2],
|
||||
'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2],
|
||||
'weight_kg': [13700, 28500],
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed.")
|
||||
|
||||
class DefaultModel(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(input_size, 64), nn.ReLU(),
|
||||
nn.Linear(64, 32), nn.ReLU(),
|
||||
nn.Linear(32, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
example_features = {
|
||||
'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
|
||||
'max_range_km': [20.4, 70],
|
||||
'firing_angle_horizontal': [102, 110],
|
||||
'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1],
|
||||
'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150],
|
||||
'rate_of_fire': [40, 60]
|
||||
})
|
||||
|
||||
# 训练标准化器
|
||||
self.scaler_X.fit(example_data)
|
||||
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
|
||||
|
||||
# 设置默认装备类型
|
||||
'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
|
||||
}
|
||||
|
||||
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
|
||||
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
|
||||
|
||||
self.scaler_X.fit(X.numpy())
|
||||
self.scaler_y.fit(y.numpy())
|
||||
|
||||
self.model = DefaultModel(X.shape[1]).to(self.device)
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
def _create_example_data(self):
|
||||
"""
|
||||
创建示例数据来训练标准化器
|
||||
"""
|
||||
# 火箭炮示例数据
|
||||
rocket_data = pd.DataFrame({
|
||||
'length_m': [7.35, 10.2],
|
||||
'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2],
|
||||
'weight_kg': [13700, 28500],
|
||||
'max_range_km': [20.4, 70],
|
||||
'firing_angle_horizontal': [102, 110],
|
||||
'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1],
|
||||
'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150],
|
||||
'rate_of_fire': [40, 60]
|
||||
})
|
||||
|
||||
# 巡飞弹示例数据
|
||||
missile_data = pd.DataFrame({
|
||||
'length_m': [1.3, 2.5],
|
||||
'width_m': [0.23, 0.6],
|
||||
'height_m': [0.23, 0.6],
|
||||
'weight_kg': [12.5, 135],
|
||||
'max_range_km': [40, 250],
|
||||
'max_speed_kmh': [180, 185],
|
||||
'cruise_speed_kmh': [100, 110],
|
||||
'flight_time_min': [60, 120],
|
||||
'folded_length_mm': [1300, 2500],
|
||||
'folded_width_mm': [230, 600],
|
||||
'folded_height_mm': [230, 600]
|
||||
})
|
||||
|
||||
# 训练标准化器
|
||||
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
|
||||
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
|
||||
|
||||
# 设置默认装备类型
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
def predict(self, data):
|
||||
"""
|
||||
使用训练好的最优模型进行预测
|
||||
"""
|
||||
def predict(self, data, model_record):
|
||||
"""使用训练好的模型进行预测"""
|
||||
try:
|
||||
logger.info(f"Starting prediction for {data.get('type')}")
|
||||
logger.info(f"Starting prediction for {data.get('type')} using {model_record['model_type']}")
|
||||
equipment_type = data.get('type')
|
||||
|
||||
# 加载已训练的最优模型
|
||||
trainer = ModelTrainer()
|
||||
if not trainer.load_model(equipment_type):
|
||||
raise ValueError(f"No trained model found for {equipment_type}")
|
||||
# 使用ModelTrainer加载模型
|
||||
model_trainer = ModelTrainer()
|
||||
success = model_trainer.load_model(equipment_type, model_record['model_type'])
|
||||
if not success:
|
||||
raise ValueError(f"Failed to load model for {equipment_type}")
|
||||
|
||||
# 从ModelTrainer获取模型和标准化器
|
||||
model = model_trainer.model
|
||||
feature_scaler = model_trainer.feature_scaler
|
||||
target_scaler = model_trainer.target_scaler
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
feature_analyzer = FeatureAnalysis()
|
||||
features = feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = []
|
||||
for feature in features:
|
||||
value = data.get(feature, 0.0)
|
||||
X.append(float(value))
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
# 转换为numpy数组并标准化
|
||||
X = np.array([X])
|
||||
X_scaled = feature_scaler.transform(X)
|
||||
|
||||
# 根据模型类型进行预测
|
||||
if isinstance(model, torch.nn.Module):
|
||||
# PyTorch模型预测
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
X_tensor = torch.FloatTensor(X_scaled).to(self.device)
|
||||
y_pred = model(X_tensor)
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
elif model_record['model_type'] == 'pls':
|
||||
# PLS模型预测
|
||||
y_pred = model.predict(X_scaled).reshape(-1, 1)
|
||||
else:
|
||||
# 其他sklearn模型预测
|
||||
y_pred = model.predict(X_scaled).reshape(-1, 1)
|
||||
|
||||
# 转换回原始尺度并确保为正数
|
||||
y_pred_original = target_scaler.inverse_transform(y_pred)
|
||||
predicted_cost = abs(float(y_pred_original[0][0])) # 确保预测值为正数
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
# 获取模型类型
|
||||
model_type = trainer.get_model_type()
|
||||
std = predicted_cost * 0.2 # 使用预测值的20%作为标准差
|
||||
confidence_interval = {
|
||||
'lower': max(predicted_cost - std, predicted_cost * 0.5), # 至少是预测值的50%
|
||||
'upper': predicted_cost + std
|
||||
}
|
||||
|
||||
return {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'model_type': model_type, # 返回使用的模型类型
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
'predicted_cost': predicted_cost,
|
||||
'confidence_interval': confidence_interval
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -187,11 +152,10 @@ class CostPredictor:
|
||||
计算预测值的置信区间
|
||||
"""
|
||||
try:
|
||||
# 使用预测值的20%作为标准差(增加不确定性)
|
||||
# 使用预测值的20%作为标准差
|
||||
std = abs(prediction) * 0.2
|
||||
|
||||
# 计算置信区间
|
||||
from scipy import stats
|
||||
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
|
||||
|
||||
# 确保区间值为正数且合理
|
||||
@ -213,130 +177,15 @@ class CostPredictor:
|
||||
"""
|
||||
模型评估
|
||||
"""
|
||||
# 确保输入是 numpy 数组
|
||||
if torch.is_tensor(y_true):
|
||||
y_true = y_true.cpu().numpy()
|
||||
if torch.is_tensor(y_pred):
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
|
||||
return {
|
||||
'mae': float(mean_absolute_error(y_true, y_pred)),
|
||||
'mse': float(mean_squared_error(y_true, y_pred)),
|
||||
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
|
||||
'r2': float(r2_score(y_true, y_pred))
|
||||
}
|
||||
|
||||
def predict_pls(self, data):
|
||||
"""
|
||||
使用 PLS 型预测成本
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting PLS prediction for {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
|
||||
# 加载 PLS 模型
|
||||
trainer = ModelTrainer()
|
||||
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
|
||||
raise ValueError(f"No trained PLS model found for {equipment_type}")
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
return {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PLS prediction error: {str(e)}")
|
||||
raise
|
||||
|
||||
def predict_all(self, data):
|
||||
"""
|
||||
使用所有可用模型进行预测
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting multi-model prediction for {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
results = {}
|
||||
|
||||
# 1. 获取所有激活的模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("""
|
||||
SELECT id, model_type, model_name, r2_score, mae, rmse
|
||||
FROM trained_models
|
||||
WHERE equipment_type = %s AND is_active = TRUE
|
||||
""", (equipment_type,))
|
||||
active_models = cursor.fetchall()
|
||||
|
||||
if not active_models:
|
||||
raise ValueError(f"No active models found for {equipment_type}")
|
||||
|
||||
# 2. 使用每个模型进行预测
|
||||
trainer = ModelTrainer()
|
||||
for model_info in active_models:
|
||||
try:
|
||||
# 加载特定模型
|
||||
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
|
||||
logger.warning(f"Failed to load model: {model_info['model_name']}")
|
||||
continue
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
# 保存结果
|
||||
results[model_info['model_type']] = {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'model_info': {
|
||||
'name': model_info['model_name'],
|
||||
'type': model_info['model_type'],
|
||||
'r2_score': float(model_info['r2_score']),
|
||||
'mae': float(model_info['mae']),
|
||||
'rmse': float(model_info['rmse'])
|
||||
},
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not results:
|
||||
raise ValueError("No successful predictions from any model")
|
||||
|
||||
# 3. 计算综合预测结果
|
||||
all_predictions = [result['predicted_cost'] for result in results.values()]
|
||||
ensemble_prediction = float(np.mean(all_predictions))
|
||||
prediction_std = float(np.std(all_predictions))
|
||||
|
||||
# 4. 返回所有结果
|
||||
return {
|
||||
'individual_predictions': results,
|
||||
'ensemble_prediction': {
|
||||
'predicted_cost': ensemble_prediction,
|
||||
'standard_deviation': prediction_std,
|
||||
'confidence_interval': {
|
||||
'lower': float(ensemble_prediction - 1.96 * prediction_std),
|
||||
'upper': float(ensemble_prediction + 1.96 * prediction_std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in multi-model prediction: {str(e)}")
|
||||
raise
|
||||
}
|
||||
@ -1,83 +1,89 @@
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from datetime import datetime
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
from xgboost import XGBRegressor
|
||||
from lightgbm import LGBMRegressor
|
||||
from sklearn.model_selection import cross_val_score, LeaveOneOut
|
||||
import json
|
||||
import logging
|
||||
from src.database.db_connection import get_db_connection
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
from .logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
# PyTorch 为可选依赖
|
||||
try:
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
Dataset = object
|
||||
DataLoader = None
|
||||
_HAS_TORCH = False
|
||||
|
||||
class EquipmentDataset(Dataset if _HAS_TORCH else object):
|
||||
"""装备数据集类"""
|
||||
def __init__(self, features, targets=None):
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||
self.features = torch.FloatTensor(features)
|
||||
self.targets = torch.FloatTensor(targets) if targets is not None else None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.targets is not None:
|
||||
return self.features[idx], self.targets[idx]
|
||||
return self.features[idx]
|
||||
|
||||
class DataPreparation:
|
||||
def __init__(self):
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
self.feature_scaler = StandardScaler()
|
||||
self.target_scaler = StandardScaler() # 添加目标值标准化器
|
||||
self.target_scaler = StandardScaler()
|
||||
|
||||
def prepare_training_data(self, equipment_data, equipment_type):
|
||||
"""
|
||||
准备训练数据
|
||||
"""
|
||||
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
|
||||
"""准备训练数据"""
|
||||
try:
|
||||
logger.info(f"Preparing training data for {equipment_type}")
|
||||
logger.info(f"Raw data size: {len(equipment_data)}")
|
||||
|
||||
# 如果输入已经是 numpy 数组,直接返回
|
||||
if isinstance(equipment_data, np.ndarray):
|
||||
X = equipment_data
|
||||
logger.info(f"Input is already numpy array with shape: {X.shape}")
|
||||
|
||||
# 处理无效值
|
||||
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
return {
|
||||
'X': X,
|
||||
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
}
|
||||
|
||||
# 从原始数据中提取特征和目标值
|
||||
# 获取特征名称(包含生产商特征)
|
||||
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
features = []
|
||||
targets = []
|
||||
|
||||
for item in equipment_data:
|
||||
# 提取特征值
|
||||
feature_values = []
|
||||
for name in feature_names:
|
||||
value = item.get(name)
|
||||
try:
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
except (ValueError, TypeError):
|
||||
feature_values.append(0.0)
|
||||
features.append(feature_values)
|
||||
# 获取数据库连接
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 提取目标值(成本)
|
||||
try:
|
||||
cost = float(item['actual_cost'])
|
||||
if cost > 0: # 只使用正数成本值
|
||||
targets.append(cost)
|
||||
else:
|
||||
logger.warning(f"Skipping non-positive cost value: {cost}")
|
||||
except (ValueError, TypeError, KeyError):
|
||||
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
|
||||
continue
|
||||
# 获取所有生产商数据,用于计算特征
|
||||
cursor.execute("""
|
||||
SELECT * FROM manufacturers
|
||||
""")
|
||||
manufacturers = {row['id']: row for row in cursor.fetchall()}
|
||||
|
||||
for item in equipment_data:
|
||||
# 获取生产商数据
|
||||
manufacturer = manufacturers.get(item['manufacturer_id'], {})
|
||||
|
||||
# 计算生产商特征
|
||||
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer)
|
||||
|
||||
# 合并装备特征和生产商特征
|
||||
feature_values = []
|
||||
for name in feature_names:
|
||||
if name.startswith('manufacturer_'):
|
||||
value = manufacturer_features.get(name, 0.0)
|
||||
else:
|
||||
value = item.get(name)
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
|
||||
features.append(feature_values)
|
||||
targets.append(float(item['actual_cost']))
|
||||
|
||||
# 转换为numpy数组
|
||||
X = np.array(features, dtype=float)
|
||||
y = np.array(targets, dtype=float)
|
||||
|
||||
# 记录原始数据范围
|
||||
# 记录数据范围
|
||||
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
|
||||
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
|
||||
|
||||
@ -85,25 +91,18 @@ class DataPreparation:
|
||||
X_scaled = self.feature_scaler.fit_transform(X)
|
||||
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
|
||||
|
||||
# 记录标准化后的数据范围
|
||||
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
|
||||
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
|
||||
|
||||
# 记录标准化器参数
|
||||
logger.info("Feature scaler params:")
|
||||
logger.info(f"Mean: {self.feature_scaler.mean_}")
|
||||
logger.info(f"Scale: {self.feature_scaler.scale_}")
|
||||
|
||||
logger.info("Target scaler params:")
|
||||
logger.info(f"Mean: {self.target_scaler.mean_}")
|
||||
logger.info(f"Scale: {self.target_scaler.scale_}")
|
||||
# 创建数据集和数据加载器
|
||||
dataset = EquipmentDataset(X_scaled, y_scaled)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
return {
|
||||
'X': X_scaled,
|
||||
'y': y_scaled,
|
||||
'dataloader': dataloader,
|
||||
'feature_names': feature_names,
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
'target_scaler': self.target_scaler,
|
||||
'raw_shape': X.shape,
|
||||
'X': X_scaled,
|
||||
'y': y_scaled
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -162,7 +161,6 @@ class DataPreparation:
|
||||
# 提取目标值(成本)并验证范围
|
||||
try:
|
||||
cost = float(item['actual_cost'])
|
||||
logger.info(f"Raw cost value: {cost}")
|
||||
if cost > 0: # 只使用正数成本值
|
||||
targets.append(cost)
|
||||
else:
|
||||
|
||||
@ -1,37 +1,227 @@
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from ..logger import setup_logger
|
||||
|
||||
# 获取logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# SQLite 数据库文件路径(相对于项目根目录)
|
||||
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
|
||||
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
|
||||
|
||||
# 建表 SQL
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS equipments (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
type TEXT,
|
||||
manufacturer TEXT,
|
||||
manufacturer_id INTEGER,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS common_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
length_m REAL,
|
||||
width_m REAL,
|
||||
height_m REAL,
|
||||
weight_kg REAL,
|
||||
max_range_km REAL,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
firing_angle_horizontal REAL,
|
||||
firing_angle_vertical REAL,
|
||||
rocket_length_m REAL,
|
||||
rocket_diameter_mm REAL,
|
||||
rocket_weight_kg REAL,
|
||||
rate_of_fire REAL,
|
||||
combat_weight_kg REAL,
|
||||
speed_kmh REAL,
|
||||
min_range_km REAL,
|
||||
max_range_km REAL,
|
||||
mobility_type TEXT,
|
||||
structure_layout TEXT,
|
||||
engine_model TEXT,
|
||||
engine_params TEXT,
|
||||
power_hp REAL,
|
||||
travel_range_km REAL,
|
||||
fire_density REAL,
|
||||
range_ratio REAL,
|
||||
mobility_score INTEGER,
|
||||
combat_readiness_score INTEGER,
|
||||
deployment_score INTEGER,
|
||||
terrain_adaptability_score INTEGER,
|
||||
rocket_power_ratio REAL,
|
||||
platform_efficiency REAL,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS loitering_munition_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
wingspan_m REAL,
|
||||
warhead_weight_kg REAL,
|
||||
max_speed_ms REAL,
|
||||
cruise_speed_kmh REAL,
|
||||
endurance_min REAL,
|
||||
flight_time_min REAL,
|
||||
max_range_km REAL,
|
||||
max_payload_kg REAL,
|
||||
ceiling_altitude_m REAL,
|
||||
combat_radius_km REAL,
|
||||
folded_length_mm REAL,
|
||||
folded_width_mm REAL,
|
||||
folded_height_mm REAL,
|
||||
warhead_type TEXT,
|
||||
launch_mode TEXT,
|
||||
power_system TEXT,
|
||||
guidance_system TEXT,
|
||||
engine_power_kw REAL,
|
||||
engine_thrust_n REAL,
|
||||
datalink_range_km REAL,
|
||||
guidance_accuracy_m REAL,
|
||||
min_altitude_m REAL,
|
||||
max_altitude_m REAL,
|
||||
length_width_ratio REAL,
|
||||
weight_range_ratio REAL,
|
||||
speed_weight_ratio REAL,
|
||||
guidance_system_score INTEGER,
|
||||
warhead_power_score INTEGER,
|
||||
warhead_type_code INTEGER,
|
||||
launch_mode_code INTEGER,
|
||||
power_system_code INTEGER,
|
||||
guidance_system_code INTEGER,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS feature_encoding (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
feature_type TEXT,
|
||||
feature_value TEXT,
|
||||
code INTEGER,
|
||||
UNIQUE(feature_type, feature_value)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS cost_data (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
actual_cost REAL,
|
||||
predicted_cost REAL,
|
||||
prediction_date TEXT,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS custom_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
param_name TEXT,
|
||||
param_value TEXT,
|
||||
param_unit TEXT,
|
||||
description TEXT,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS datasets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
equipment_type TEXT NOT NULL,
|
||||
purpose TEXT NOT NULL,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
updated_at TEXT DEFAULT (datetime('now','localtime'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dataset_equipments (
|
||||
dataset_id INTEGER NOT NULL,
|
||||
equipment_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (dataset_id, equipment_id),
|
||||
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trained_models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
model_type TEXT NOT NULL,
|
||||
equipment_type TEXT NOT NULL,
|
||||
model_path TEXT NOT NULL,
|
||||
scaler_path TEXT NOT NULL,
|
||||
r2_score REAL,
|
||||
mae REAL,
|
||||
rmse REAL,
|
||||
feature_importance TEXT,
|
||||
training_data_size INTEGER,
|
||||
training_date TEXT DEFAULT (datetime('now','localtime')),
|
||||
is_active INTEGER DEFAULT 0,
|
||||
created_by TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS manufacturers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
country TEXT NOT NULL,
|
||||
tech_level INTEGER NOT NULL,
|
||||
scale_level INTEGER NOT NULL,
|
||||
supply_chain_level INTEGER NOT NULL,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
updated_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
UNIQUE(name)
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
|
||||
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
|
||||
"""
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库:确保数据库文件和表存在"""
|
||||
os.makedirs(DB_DIR, exist_ok=True)
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.executescript(SCHEMA_SQL)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Database initialized at {DB_PATH}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection():
|
||||
"""
|
||||
数据库连接上下文管理器
|
||||
返回的 connection 已设置 dict row_factory,
|
||||
以便按列名访问。
|
||||
"""
|
||||
connection = None
|
||||
conn = None
|
||||
try:
|
||||
connection = mysql.connector.connect(
|
||||
host=os.getenv('MYSQL_HOST', 'localhost'),
|
||||
user=os.getenv('MYSQL_USER', 'root'),
|
||||
password=os.getenv('MYSQL_PASSWORD', '123456'),
|
||||
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
|
||||
)
|
||||
logger.info("Database connection established")
|
||||
yield connection
|
||||
|
||||
except Error as e:
|
||||
logger.error(f"Error connecting to MySQL: {str(e)}")
|
||||
# 确保数据库已初始化
|
||||
if not os.path.exists(DB_PATH):
|
||||
logger.info("Database file not found, initializing...")
|
||||
init_db()
|
||||
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
logger.debug("Database connection established")
|
||||
yield conn
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"Database error: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
finally:
|
||||
if connection and connection.is_connected():
|
||||
connection.close()
|
||||
logger.info("Database connection closed")
|
||||
if conn:
|
||||
conn.close()
|
||||
logger.debug("Database connection closed")
|
||||
|
||||
290
src/demo_service.py
Normal file
290
src/demo_service.py
Normal file
@ -0,0 +1,290 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
from sklearn.linear_model import LinearRegression, Ridge
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVR
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AlgorithmDefinition:
|
||||
key: str
|
||||
name: str
|
||||
english_name: str
|
||||
family: str
|
||||
description: str
|
||||
estimator: Any
|
||||
|
||||
|
||||
class DemoModelService:
|
||||
target_column = "actual_cost"
|
||||
ignored_columns = {"name", "type", target_column}
|
||||
|
||||
def __init__(self, dataset_path: Path | str | None = None):
|
||||
root = Path(__file__).resolve().parent.parent
|
||||
self.dataset_path = Path(dataset_path) if dataset_path else root / "data" / "demo_equipment_costs.csv"
|
||||
|
||||
def get_algorithms(self) -> list[dict[str, str]]:
|
||||
algorithms, _ = self._available_algorithms()
|
||||
return [
|
||||
{
|
||||
"key": item.key,
|
||||
"name": item.name,
|
||||
"english_name": item.english_name,
|
||||
"family": item.family,
|
||||
"description": item.description,
|
||||
}
|
||||
for item in algorithms.values()
|
||||
]
|
||||
|
||||
def get_dataset_summary(self) -> dict[str, Any]:
|
||||
frame = self._load_dataset()
|
||||
feature_columns = self._feature_columns(frame)
|
||||
return {
|
||||
"source": "local-file",
|
||||
"path": str(self.dataset_path),
|
||||
"row_count": int(len(frame)),
|
||||
"columns": list(frame.columns),
|
||||
"features": feature_columns,
|
||||
"target": self.target_column,
|
||||
"target_label": "实际成本",
|
||||
"equipment_types": sorted(frame["type"].dropna().unique().tolist()),
|
||||
"preview": frame.head(8).to_dict(orient="records"),
|
||||
}
|
||||
|
||||
def run_demo(self, selected_algorithms: list[str] | None = None) -> dict[str, Any]:
|
||||
frame = self._load_dataset()
|
||||
feature_columns = self._feature_columns(frame)
|
||||
algorithms, availability_warnings = self._available_algorithms()
|
||||
|
||||
requested = selected_algorithms or list(algorithms.keys())
|
||||
warnings = list(availability_warnings)
|
||||
selected = []
|
||||
for key in requested:
|
||||
if key in algorithms:
|
||||
selected.append(key)
|
||||
else:
|
||||
warnings.append(f"算法 '{key}' 不可用,已自动跳过。")
|
||||
|
||||
if not selected:
|
||||
selected = ["linear"]
|
||||
warnings.append("所选算法均不可用,已自动使用线性回归。")
|
||||
|
||||
X = frame[feature_columns]
|
||||
y = frame[self.target_column]
|
||||
train_x, test_x, train_y, test_y = train_test_split(
|
||||
X,
|
||||
y,
|
||||
test_size=0.3,
|
||||
random_state=42,
|
||||
)
|
||||
|
||||
metrics: dict[str, dict[str, float | str]] = {}
|
||||
predictions: dict[str, list[float]] = {}
|
||||
feature_importance: dict[str, list[dict[str, float | str]]] = {}
|
||||
|
||||
for key in selected:
|
||||
definition = algorithms[key]
|
||||
model = definition.estimator
|
||||
model.fit(train_x, train_y)
|
||||
predicted = model.predict(test_x)
|
||||
predictions[key] = [float(value) for value in predicted]
|
||||
metrics[key] = {
|
||||
"name": definition.name,
|
||||
"r2": float(r2_score(test_y, predicted)),
|
||||
"mae": float(mean_absolute_error(test_y, predicted)),
|
||||
"rmse": float(np.sqrt(mean_squared_error(test_y, predicted))),
|
||||
}
|
||||
feature_importance[key] = self._feature_importance(model, feature_columns)
|
||||
|
||||
best_model = min(metrics, key=lambda key: float(metrics[key]["rmse"]))
|
||||
ordered_test = test_x.copy()
|
||||
ordered_test["actual"] = test_y
|
||||
ordered_test["name"] = frame.loc[test_x.index, "name"]
|
||||
|
||||
prediction_points = []
|
||||
for position, (index, row) in enumerate(ordered_test.sort_values("actual").iterrows()):
|
||||
point = {
|
||||
"name": row["name"],
|
||||
"actual": float(row["actual"]),
|
||||
}
|
||||
for key in selected:
|
||||
original_position = list(test_x.index).index(index)
|
||||
point[key] = predictions[key][original_position]
|
||||
prediction_points.append(point)
|
||||
|
||||
sample = frame.sort_values(self.target_column).iloc[len(frame) // 2]
|
||||
sample_x = pd.DataFrame([sample[feature_columns].to_dict()])
|
||||
sample_predictions = {
|
||||
key: float(algorithms[key].estimator.fit(X, y).predict(sample_x)[0])
|
||||
for key in selected
|
||||
}
|
||||
|
||||
return {
|
||||
"source": "local-file",
|
||||
"dataset": self.get_dataset_summary(),
|
||||
"algorithms": self.get_algorithms(),
|
||||
"selected_algorithms": selected,
|
||||
"best_model": best_model,
|
||||
"metrics": metrics,
|
||||
"feature_importance": feature_importance,
|
||||
"prediction_points": prediction_points,
|
||||
"sample_prediction": {
|
||||
"input": sample.drop(labels=[self.target_column]).to_dict(),
|
||||
"actual": float(sample[self.target_column]),
|
||||
"predictions": sample_predictions,
|
||||
},
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
def _load_dataset(self) -> pd.DataFrame:
|
||||
if not self.dataset_path.exists():
|
||||
raise FileNotFoundError(f"Demo dataset not found: {self.dataset_path}")
|
||||
|
||||
frame = pd.read_csv(self.dataset_path)
|
||||
if self.target_column not in frame.columns:
|
||||
raise ValueError(f"Demo dataset must include '{self.target_column}'.")
|
||||
return frame
|
||||
|
||||
def _feature_columns(self, frame: pd.DataFrame) -> list[str]:
|
||||
columns = [
|
||||
column
|
||||
for column in frame.columns
|
||||
if column not in self.ignored_columns and pd.api.types.is_numeric_dtype(frame[column])
|
||||
]
|
||||
if not columns:
|
||||
raise ValueError("Demo dataset has no numeric feature columns.")
|
||||
return columns
|
||||
|
||||
def _available_algorithms(self) -> tuple[dict[str, AlgorithmDefinition], list[str]]:
|
||||
algorithms = {
|
||||
"linear": AlgorithmDefinition(
|
||||
"linear",
|
||||
"线性回归",
|
||||
"Linear Regression",
|
||||
"线性模型",
|
||||
"快速建立基准模型,用于展示参数与成本之间的线性关系。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", LinearRegression())]),
|
||||
),
|
||||
"ridge": AlgorithmDefinition(
|
||||
"ridge",
|
||||
"岭回归",
|
||||
"Ridge Regression",
|
||||
"线性模型",
|
||||
"带正则化的线性模型,适合特征存在相关性的场景。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", Ridge(alpha=1.0))]),
|
||||
),
|
||||
"random_forest": AlgorithmDefinition(
|
||||
"random_forest",
|
||||
"随机森林",
|
||||
"Random Forest",
|
||||
"树模型集成",
|
||||
"通过多棵决策树集成预测,能够捕捉非线性特征影响。",
|
||||
RandomForestRegressor(n_estimators=160, max_depth=6, random_state=42),
|
||||
),
|
||||
"gradient_boosting": AlgorithmDefinition(
|
||||
"gradient_boosting",
|
||||
"梯度提升树",
|
||||
"Gradient Boosting",
|
||||
"树模型集成",
|
||||
"逐步修正误差的提升模型,常用于表格数据回归任务。",
|
||||
GradientBoostingRegressor(n_estimators=120, learning_rate=0.06, max_depth=3, random_state=42),
|
||||
),
|
||||
"svr": AlgorithmDefinition(
|
||||
"svr",
|
||||
"支持向量回归",
|
||||
"Support Vector Regression",
|
||||
"核方法",
|
||||
"使用核函数拟合平滑回归关系,适合展示不同算法偏好。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", SVR(C=500000, epsilon=50000))]),
|
||||
),
|
||||
"knn": AlgorithmDefinition(
|
||||
"knn",
|
||||
"近邻回归",
|
||||
"KNN Regression",
|
||||
"实例学习",
|
||||
"基于相似样本进行预测,便于解释局部相似性。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", KNeighborsRegressor(n_neighbors=4))]),
|
||||
),
|
||||
}
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
from xgboost import XGBRegressor
|
||||
|
||||
algorithms["xgboost"] = AlgorithmDefinition(
|
||||
"xgboost",
|
||||
"XGBoost",
|
||||
"XGBoost",
|
||||
"提升模型",
|
||||
"面向表格数据的高性能梯度提升实现。",
|
||||
XGBRegressor(
|
||||
n_estimators=120,
|
||||
max_depth=3,
|
||||
learning_rate=0.05,
|
||||
subsample=0.9,
|
||||
colsample_bytree=0.9,
|
||||
random_state=42,
|
||||
objective="reg:squarederror",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
warnings.append("当前环境未安装 XGBoost,页面已自动隐藏该算法。")
|
||||
|
||||
try:
|
||||
from lightgbm import LGBMRegressor
|
||||
|
||||
algorithms["lightgbm"] = AlgorithmDefinition(
|
||||
"lightgbm",
|
||||
"LightGBM",
|
||||
"LightGBM",
|
||||
"提升模型",
|
||||
"基于直方图优化的快速梯度提升模型。",
|
||||
LGBMRegressor(
|
||||
n_estimators=120,
|
||||
learning_rate=0.05,
|
||||
max_depth=4,
|
||||
random_state=42,
|
||||
verbose=-1,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
warnings.append("当前环境未安装 LightGBM,页面已自动隐藏该算法。")
|
||||
|
||||
return algorithms, warnings
|
||||
|
||||
def _feature_importance(self, model: Any, feature_columns: list[str]) -> list[dict[str, float | str]]:
|
||||
estimator = model
|
||||
if isinstance(model, Pipeline):
|
||||
estimator = model.named_steps["model"]
|
||||
|
||||
if hasattr(estimator, "feature_importances_"):
|
||||
values = estimator.feature_importances_
|
||||
elif hasattr(estimator, "coef_"):
|
||||
values = np.abs(np.ravel(estimator.coef_))
|
||||
else:
|
||||
values = np.zeros(len(feature_columns))
|
||||
|
||||
total = float(np.sum(values))
|
||||
if total > 0:
|
||||
values = values / total
|
||||
|
||||
ranked = sorted(
|
||||
[
|
||||
{"feature": feature, "importance": float(value)}
|
||||
for feature, value in zip(feature_columns, values)
|
||||
],
|
||||
key=lambda item: item["importance"],
|
||||
reverse=True,
|
||||
)
|
||||
return ranked[:8]
|
||||
@ -15,9 +15,9 @@ class FeatureAnalysis:
|
||||
'width_m': '宽度(m)',
|
||||
'height_m': '高度(m)',
|
||||
'weight_kg': '重量(kg)',
|
||||
'max_range_km': '最大射程(km)',
|
||||
|
||||
# 火箭炮特有参数
|
||||
'max_range_km': '最大射程(km)',
|
||||
'firing_angle_horizontal': '方向射界(度)',
|
||||
'firing_angle_vertical': '高低射界(度)',
|
||||
'rocket_length_m': '火箭弹长度(m)',
|
||||
@ -39,6 +39,7 @@ class FeatureAnalysis:
|
||||
'terrain_adaptability_score': '地形适应性评分',
|
||||
|
||||
# 巡飞弹特有参数
|
||||
'max_range_km': '最大射程(km)',
|
||||
'wingspan_m': '翼展(m)',
|
||||
'warhead_weight_kg': '战斗部重量(kg)',
|
||||
'max_speed_ms': '最大速度(m/s)',
|
||||
@ -57,7 +58,14 @@ class FeatureAnalysis:
|
||||
'weight_range_ratio': '重量射程比',
|
||||
'speed_weight_ratio': '速度重量比',
|
||||
'guidance_system_score': '制导系统评分',
|
||||
'warhead_power_score': '战斗部威力评分'
|
||||
'warhead_power_score': '战斗部威力评分',
|
||||
|
||||
# 添加生产商特征映射
|
||||
'manufacturer_tech_level': '生产商技术水平',
|
||||
'manufacturer_scale_level': '生产商规模水平',
|
||||
'manufacturer_supply_chain_level': '生产商供应链水平',
|
||||
'manufacturer_composite_score': '生产商综合得分',
|
||||
'manufacturer_region_factor': '生产商区域系数'
|
||||
}
|
||||
|
||||
def get_equipment_specific_features(self, equipment_type):
|
||||
@ -121,6 +129,17 @@ class FeatureAnalysis:
|
||||
'guidance_system_score',
|
||||
'warhead_power_score'
|
||||
])
|
||||
|
||||
# 添加生产商特征
|
||||
manufacturer_features = [
|
||||
'manufacturer_tech_level',
|
||||
'manufacturer_scale_level',
|
||||
'manufacturer_supply_chain_level',
|
||||
'manufacturer_composite_score',
|
||||
'manufacturer_region_factor'
|
||||
]
|
||||
|
||||
numeric_features.extend(manufacturer_features)
|
||||
return numeric_features
|
||||
|
||||
def analyze_features(self, features, target, feature_names):
|
||||
@ -193,11 +212,20 @@ class FeatureAnalysis:
|
||||
|
||||
# 创建特征重要性列表(使用中文名称)
|
||||
important_features = []
|
||||
|
||||
# 过滤掉无效的分数
|
||||
valid_scores = importance_scores[~np.isnan(importance_scores)]
|
||||
|
||||
# 记录一些统计信息
|
||||
logger.info(f"F-score statistics:")
|
||||
logger.info(f"min={np.min(valid_scores):.2f}, max={np.max(valid_scores):.2f}, "
|
||||
f"mean={np.mean(valid_scores):.2f}, median={np.median(valid_scores):.2f}")
|
||||
|
||||
for idx, (name, score) in enumerate(zip(feature_names, importance_scores)):
|
||||
if not np.isnan(score):
|
||||
important_features.append({
|
||||
'name': self.feature_name_map.get(name, name), # 使用中文名称
|
||||
'importance': float(score)
|
||||
'name': self.feature_name_map.get(name, name),
|
||||
'importance': float(score) # 保持原始F-score
|
||||
})
|
||||
|
||||
# 按重要性排序
|
||||
@ -234,4 +262,67 @@ class FeatureAnalysis:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in analyze_features: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
raise
|
||||
|
||||
def calculate_manufacturer_features(self, manufacturer_data):
|
||||
"""计算生产商相关的特征"""
|
||||
try:
|
||||
# 处理 None 值(数据库 NULL),使用默认值
|
||||
raw_tech = manufacturer_data.get('tech_level')
|
||||
raw_scale = manufacturer_data.get('scale_level')
|
||||
raw_supply = manufacturer_data.get('supply_chain_level')
|
||||
|
||||
tech_level = float(raw_tech) if raw_tech is not None else 0
|
||||
scale_level = float(raw_scale) if raw_scale is not None else 0
|
||||
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
|
||||
country = manufacturer_data.get('country', '未知') or '未知'
|
||||
|
||||
# 计算综合得分
|
||||
composite_score = (
|
||||
tech_level * 0.4 + # 技术水平权重最高
|
||||
scale_level * 0.3 + # 规模水平次之
|
||||
supply_chain_level * 0.3 # 供应链水平
|
||||
)
|
||||
|
||||
# 计算区域系数(基于不同地区的成本差异)
|
||||
region_factors = {
|
||||
'美国': 1.2,
|
||||
'英国': 1.15,
|
||||
'德国': 1.15,
|
||||
'法国': 1.15,
|
||||
'以色列': 1.1,
|
||||
'中国': 0.8,
|
||||
'俄罗斯': 0.85,
|
||||
'韩国': 0.9,
|
||||
'日本': 1.1
|
||||
}
|
||||
|
||||
region_factor = region_factors.get(country, 1.0)
|
||||
|
||||
# 记录计算过程
|
||||
logger.info(f"Manufacturer features calculation:")
|
||||
logger.info(f"Tech level: {tech_level}")
|
||||
logger.info(f"Scale level: {scale_level}")
|
||||
logger.info(f"Supply chain level: {supply_chain_level}")
|
||||
logger.info(f"Country: {country}")
|
||||
logger.info(f"Composite score: {composite_score}")
|
||||
logger.info(f"Region factor: {region_factor}")
|
||||
|
||||
return {
|
||||
'manufacturer_tech_level': tech_level,
|
||||
'manufacturer_scale_level': scale_level,
|
||||
'manufacturer_supply_chain_level': supply_chain_level,
|
||||
'manufacturer_composite_score': composite_score,
|
||||
'manufacturer_region_factor': region_factor
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating manufacturer features: {str(e)}")
|
||||
# 返回默认值而不是抛出异常,确保分析过程可以继续
|
||||
return {
|
||||
'manufacturer_tech_level': 0,
|
||||
'manufacturer_scale_level': 0,
|
||||
'manufacturer_supply_chain_level': 0,
|
||||
'manufacturer_composite_score': 0,
|
||||
'manufacturer_region_factor': 1.0
|
||||
}
|
||||
@ -26,8 +26,8 @@ def import_training_data(excel_file):
|
||||
equipment_names.add(row['名称'])
|
||||
# 检查是否已存在相同名称的装备
|
||||
cursor.execute("""
|
||||
SELECT id FROM equipment
|
||||
WHERE name = %s AND type = '火箭炮'
|
||||
SELECT id FROM equipments
|
||||
WHERE name = ? AND type = '火箭炮'
|
||||
""", (row['名称'],))
|
||||
|
||||
existing_equipment = cursor.fetchone()
|
||||
@ -37,8 +37,8 @@ def import_training_data(excel_file):
|
||||
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipment (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (?, ?, ?)
|
||||
""", (row['名称'], '火箭炮', row['制造商']))
|
||||
|
||||
equipment_id = cursor.lastrowid
|
||||
@ -47,7 +47,7 @@ def import_training_data(excel_file):
|
||||
cursor.execute("""
|
||||
INSERT INTO common_params
|
||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
row['总长_m'] if pd.notna(row['总长_m']) else None,
|
||||
@ -65,7 +65,7 @@ def import_training_data(excel_file):
|
||||
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
|
||||
structure_layout, engine_model, engine_params, power_hp,
|
||||
travel_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
|
||||
@ -89,7 +89,7 @@ def import_training_data(excel_file):
|
||||
if pd.notna(row['成本_元']):
|
||||
cursor.execute("""
|
||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", (equipment_id, row['成本_元']))
|
||||
|
||||
logger.info("火箭炮数据导入完成")
|
||||
@ -105,8 +105,8 @@ def import_training_data(excel_file):
|
||||
equipment_names.add(row['名称'])
|
||||
# 检查是否已存在相同名称的装备
|
||||
cursor.execute("""
|
||||
SELECT id FROM equipment
|
||||
WHERE name = %s AND type = '巡飞弹'
|
||||
SELECT id FROM equipments
|
||||
WHERE name = ? AND type = '巡飞弹'
|
||||
""", (row['名称'],))
|
||||
|
||||
existing_equipment = cursor.fetchone()
|
||||
@ -116,8 +116,8 @@ def import_training_data(excel_file):
|
||||
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipment (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (?, ?, ?)
|
||||
""", (
|
||||
row['名称'],
|
||||
'巡飞弹',
|
||||
@ -130,7 +130,7 @@ def import_training_data(excel_file):
|
||||
cursor.execute("""
|
||||
INSERT INTO common_params
|
||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
|
||||
@ -147,7 +147,7 @@ def import_training_data(excel_file):
|
||||
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
|
||||
folded_length_mm, folded_width_mm, folded_height_mm,
|
||||
power_system, guidance_system)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
|
||||
@ -168,7 +168,7 @@ def import_training_data(excel_file):
|
||||
if pd.notna(row['成本_元']):
|
||||
cursor.execute("""
|
||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", (equipment_id, float(row['成本_元'])))
|
||||
|
||||
logger.info("巡飞弹数据导入完成")
|
||||
@ -190,48 +190,48 @@ def import_training_data(excel_file):
|
||||
|
||||
# 获取装备ID - 使用新的游标
|
||||
logger.debug(f"查询装备ID: {equipment_name}")
|
||||
with conn.cursor() as id_cursor:
|
||||
id_cursor.execute("""
|
||||
SELECT id FROM equipment WHERE name = %s
|
||||
""", (equipment_name,))
|
||||
result = id_cursor.fetchone()
|
||||
|
||||
id_cursor = conn.cursor()
|
||||
id_cursor.execute("""
|
||||
SELECT id FROM equipments WHERE name = ?
|
||||
""", (equipment_name,))
|
||||
result = id_cursor.fetchone()
|
||||
|
||||
if not result:
|
||||
logger.warning(f"未找到装备: {equipment_name}")
|
||||
continue
|
||||
|
||||
equipment_id = result[0]
|
||||
|
||||
equipment_id = result['id']
|
||||
logger.debug(f"找到装备ID: {equipment_id}")
|
||||
|
||||
|
||||
# 检查参数是否存在 - 使用新的游标
|
||||
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
|
||||
with conn.cursor() as check_cursor:
|
||||
check_cursor.execute("""
|
||||
SELECT id FROM custom_params
|
||||
WHERE equipment_id = %s AND param_name = %s
|
||||
""", (equipment_id, param_name))
|
||||
exists = check_cursor.fetchone()
|
||||
|
||||
check_cursor = conn.cursor()
|
||||
check_cursor.execute("""
|
||||
SELECT id FROM custom_params
|
||||
WHERE equipment_id = ? AND param_name = ?
|
||||
""", (equipment_id, param_name))
|
||||
exists = check_cursor.fetchone()
|
||||
|
||||
if exists:
|
||||
logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
|
||||
continue
|
||||
|
||||
|
||||
# 插入新的参数 - 使用新的游标
|
||||
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
|
||||
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
|
||||
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
|
||||
|
||||
|
||||
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
|
||||
with conn.cursor() as insert_cursor:
|
||||
insert_cursor.execute("""
|
||||
INSERT INTO custom_params
|
||||
(equipment_id, param_name, param_value, param_unit, description)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""", (
|
||||
equipment_id,
|
||||
param_name,
|
||||
param_value,
|
||||
param_unit,
|
||||
insert_cursor = conn.cursor()
|
||||
insert_cursor.execute("""
|
||||
INSERT INTO custom_params
|
||||
(equipment_id, param_name, param_value, param_unit, description)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
param_name,
|
||||
param_value,
|
||||
param_unit,
|
||||
param_desc
|
||||
))
|
||||
logger.debug(f"成功插入参数记录")
|
||||
|
||||
237
src/import_sql_data.py
Normal file
237
src/import_sql_data.py
Normal file
@ -0,0 +1,237 @@
|
||||
"""
|
||||
导入 SQL 数据文件到 SQLite 数据库
|
||||
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
|
||||
- rocket_artillery_data.sql (96 条火箭炮数据)
|
||||
- loitering_munition_data.sql (100 条巡飞弹数据)
|
||||
- manufacturer_data.sql (生产商数据)
|
||||
"""
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
from src.database.db_connection import get_db_connection, DB_PATH
|
||||
from src.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def parse_mysql_insert(sql_text):
|
||||
"""
|
||||
解析 MySQL INSERT 语句,返回 (table_name, columns, values_list)
|
||||
columns 是列名列表, values_list 是每行的值列表
|
||||
"""
|
||||
# 去掉块注释 /* ... */
|
||||
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
|
||||
# 去掉行注释 --
|
||||
sql_text = re.sub(r'--.*', '', sql_text)
|
||||
|
||||
results = []
|
||||
|
||||
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
|
||||
pattern = re.compile(
|
||||
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
for match in pattern.finditer(sql_text):
|
||||
table_name = match.group(1).strip()
|
||||
columns_str = match.group(2).strip()
|
||||
values_str = match.group(3).strip()
|
||||
|
||||
# 提取列名(去掉注释部分)
|
||||
columns = []
|
||||
for col in columns_str.split(','):
|
||||
col = col.strip()
|
||||
# 去掉行内注释
|
||||
col = re.sub(r'\s*--.*', '', col).strip()
|
||||
columns.append(col)
|
||||
|
||||
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
|
||||
values_list = []
|
||||
current_tuple = []
|
||||
depth = 0
|
||||
current_val = ''
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
for ch in values_str:
|
||||
if in_string:
|
||||
if ch == string_char:
|
||||
in_string = False
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch in ("'", '"'):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch == '(':
|
||||
if depth > 0:
|
||||
current_val += ch
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
current_val = ''
|
||||
continue
|
||||
|
||||
if ch == ')':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
current_tuple.append(current_val.strip())
|
||||
values_list.append(current_tuple)
|
||||
current_tuple = []
|
||||
else:
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch == ',' and depth == 1:
|
||||
current_tuple.append(current_val.strip())
|
||||
current_val = ''
|
||||
continue
|
||||
|
||||
if depth >= 1:
|
||||
current_val += ch
|
||||
|
||||
results.append((table_name, columns, values_list))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def convert_value(val):
|
||||
"""将 MySQL 值字符串转为 Python 对象"""
|
||||
val = val.strip()
|
||||
if val.upper() == 'NULL' or val == '':
|
||||
return None
|
||||
if val.upper() == 'TRUE':
|
||||
return 1
|
||||
if val.upper() == 'FALSE':
|
||||
return 0
|
||||
# 字符串
|
||||
if (val.startswith("'") and val.endswith("'")) or \
|
||||
(val.startswith('"') and val.endswith('"')):
|
||||
return val[1:-1]
|
||||
# 数字
|
||||
try:
|
||||
if '.' in val:
|
||||
return float(val)
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
|
||||
def import_sql_file(filepath):
|
||||
"""导入单个 SQL 文件"""
|
||||
logger.info(f"Reading {filepath}...")
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
sql_text = f.read()
|
||||
|
||||
parsed = parse_mysql_insert(sql_text)
|
||||
total_rows = 0
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
for table_name, columns, values_list in parsed:
|
||||
if not values_list:
|
||||
continue
|
||||
|
||||
# 构建参数化 INSERT
|
||||
placeholders = ','.join(['?'] * len(columns))
|
||||
col_names = ','.join(columns)
|
||||
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
|
||||
|
||||
row_count = 0
|
||||
for values in values_list:
|
||||
if len(values) != len(columns):
|
||||
logger.warning(f"Column mismatch in {table_name}: "
|
||||
f"expected {len(columns)} values, got {len(values)}: {values}")
|
||||
continue
|
||||
converted = [convert_value(v) for v in values]
|
||||
try:
|
||||
cursor.execute(sql, converted)
|
||||
if cursor.rowcount > 0:
|
||||
row_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error inserting into {table_name}: {e}")
|
||||
logger.warning(f" Values: {converted}")
|
||||
|
||||
if row_count > 0:
|
||||
logger.info(f" {table_name}: inserted {row_count} rows")
|
||||
total_rows += row_count
|
||||
|
||||
conn.commit()
|
||||
|
||||
return total_rows
|
||||
|
||||
|
||||
def run_manufacturer_update():
|
||||
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id)"""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE equipments
|
||||
SET manufacturer_id = (
|
||||
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
|
||||
)
|
||||
WHERE manufacturer_id IS NULL
|
||||
AND manufacturer IS NOT NULL
|
||||
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
|
||||
""")
|
||||
conn.commit()
|
||||
updated = cursor.rowcount
|
||||
if updated > 0:
|
||||
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
|
||||
else:
|
||||
logger.info("No manufacturer_id updates needed")
|
||||
|
||||
|
||||
def import_all():
|
||||
"""导入所有 SQL 数据文件"""
|
||||
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
# 清除现有数据库,重新开始
|
||||
if os.path.exists(DB_PATH):
|
||||
os.remove(DB_PATH)
|
||||
logger.info(f"Removed existing database: {DB_PATH}")
|
||||
|
||||
files = [
|
||||
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
|
||||
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
|
||||
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
|
||||
]
|
||||
|
||||
total = 0
|
||||
for filepath, label in files:
|
||||
if not os.path.exists(filepath):
|
||||
logger.warning(f"File not found: {filepath}")
|
||||
continue
|
||||
logger.info(f"正在导入 {label}...")
|
||||
rows = import_sql_file(filepath)
|
||||
logger.info(f" {label}: 共导入 {rows} 行")
|
||||
total += rows
|
||||
|
||||
# 更新 manufacturer_id
|
||||
run_manufacturer_update()
|
||||
|
||||
# 统计结果
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
|
||||
eq_count = cursor.fetchone()['cnt']
|
||||
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
|
||||
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
|
||||
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
|
||||
mf_count = cursor.fetchone()['cnt']
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info(f"导入完成!")
|
||||
logger.info(f" 装备总计: {eq_count}")
|
||||
for t, c in by_type.items():
|
||||
logger.info(f" {t}: {c}")
|
||||
logger.info(f" 生产商: {mf_count}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
return eq_count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import_all()
|
||||
@ -1,319 +0,0 @@
|
||||
/*
|
||||
这是用于开发和测试环境的示例数据。
|
||||
生产环境请使用系统的数据导入功能添加实际数据。
|
||||
|
||||
主要用途:
|
||||
1. 提供开发测试数据
|
||||
2. 作为数据格式参考
|
||||
3. 用于系统功能验证
|
||||
*/
|
||||
|
||||
-- 插入装备基本信息
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
|
||||
|
||||
-- 插入巡飞弹技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES (
|
||||
1, -- 终结者巡飞弹
|
||||
0.56,
|
||||
0.15,
|
||||
0.20,
|
||||
2.72,
|
||||
160.93,
|
||||
96.56,
|
||||
24,
|
||||
15,
|
||||
'破片杀伤战斗部',
|
||||
'凭自身动力起飞',
|
||||
560,
|
||||
150,
|
||||
200
|
||||
);
|
||||
|
||||
-- 插入火箭炮技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES (
|
||||
2, -- 胜利-2火箭炮
|
||||
10,
|
||||
2.5,
|
||||
3.34,
|
||||
15000,
|
||||
23
|
||||
);
|
||||
|
||||
-- 插入成本数据(示例数据)
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(1, 1000000), -- 终结者巡飞弹成本
|
||||
(2, 5000000); -- 胜利-2火箭炮成本
|
||||
|
||||
-- 插入更多巡飞弹变体数据用于训练
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
|
||||
|
||||
-- 插入变体技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES
|
||||
-- 终结者-A(稍大型号)
|
||||
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
|
||||
-- 终结者-B(稍小型号)
|
||||
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
|
||||
-- 终结者-C(标准型号的改进版)
|
||||
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
|
||||
|
||||
-- 插入变体成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(3, 1100000), -- 终结者-A成本(较高)
|
||||
(4, 900000), -- 终结者-B成本(较低)
|
||||
(5, 1050000); -- 终结者-C成本(中等)
|
||||
|
||||
-- 添加更多巡飞弹数据
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
|
||||
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
|
||||
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
|
||||
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
|
||||
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
|
||||
|
||||
-- 添加它们的技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES
|
||||
-- 哈比
|
||||
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
|
||||
-- 游荡者
|
||||
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
|
||||
-- 凤凰
|
||||
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
|
||||
-- 弹簧刀
|
||||
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
|
||||
-- 彩虹-4
|
||||
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
|
||||
|
||||
-- 添加成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(6, 800000), -- 哈比
|
||||
(7, 500000), -- 游荡者
|
||||
(8, 450000), -- 凤凰
|
||||
(9, 480000), -- 弹簧刀
|
||||
(10, 1500000); -- 彩虹-4
|
||||
|
||||
-- 火箭炮数据
|
||||
INSERT INTO equipment (name, type, manufacturer) VALUES
|
||||
('BM-21', '火箭炮', '俄罗斯'),
|
||||
('SR5', '火箭炮', '中国'),
|
||||
('HIMARS', '火箭炮', '美国'),
|
||||
('LAR-160', '火箭炮', '以色列'),
|
||||
('T-122', '火箭炮', '土耳其'),
|
||||
('RM-70', '火箭炮', '捷克'),
|
||||
('ASTROS II', '火箭炮', '巴西');
|
||||
|
||||
-- 火箭炮通用参数
|
||||
INSERT INTO common_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES
|
||||
-- BM-21
|
||||
(1, 7.35, 2.4, 3.1, 13700, 20.4),
|
||||
-- SR5
|
||||
(2, 10.2, 2.8, 3.2, 28500, 70),
|
||||
-- HIMARS
|
||||
(3, 7.0, 2.4, 3.2, 16250, 70),
|
||||
-- LAR-160
|
||||
(4, 6.7, 2.5, 2.8, 15000, 45),
|
||||
-- T-122
|
||||
(5, 7.2, 2.5, 2.9, 18000, 40),
|
||||
-- RM-70
|
||||
(6, 7.5, 2.5, 3.0, 17200, 20.3),
|
||||
-- ASTROS II
|
||||
(7, 8.0, 2.7, 3.1, 24500, 90);
|
||||
|
||||
-- 火箭炮特有参数
|
||||
INSERT INTO rocket_artillery_params (
|
||||
equipment_id,
|
||||
firing_angle_horizontal,
|
||||
firing_angle_vertical,
|
||||
rocket_length_m,
|
||||
rocket_diameter_mm,
|
||||
rocket_weight_kg,
|
||||
rate_of_fire,
|
||||
combat_weight_kg,
|
||||
speed_kmh,
|
||||
min_range_km,
|
||||
mobility_type,
|
||||
structure_layout,
|
||||
engine_model,
|
||||
engine_params,
|
||||
power_hp,
|
||||
travel_range_km
|
||||
) VALUES
|
||||
-- BM-21
|
||||
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
|
||||
-- SR5
|
||||
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
|
||||
-- HIMARS
|
||||
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
|
||||
-- LAR-160
|
||||
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
|
||||
-- T-122
|
||||
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
|
||||
-- RM-70
|
||||
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
|
||||
-- ASTROS II
|
||||
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
|
||||
|
||||
-- 巡飞弹数据
|
||||
INSERT INTO equipment (name, type, manufacturer) VALUES
|
||||
('Hero-120', '巡飞弹', '以色列'),
|
||||
('Switchblade 600', '巡飞弹', '美国'),
|
||||
('Warmate', '巡飞弹', '波兰'),
|
||||
('CH-901', '巡飞弹', '中国'),
|
||||
('HAROP', '巡飞弹', '以色列'),
|
||||
('Coyote', '巡飞弹', '美国'),
|
||||
('WS-43', '巡飞弹', '中国');
|
||||
|
||||
-- 巡飞弹通用参数
|
||||
INSERT INTO common_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES
|
||||
-- Hero-120
|
||||
(8, 1.3, 0.23, 0.23, 12.5, 40),
|
||||
-- Switchblade 600
|
||||
(9, 1.3, 0.22, 0.22, 15.0, 40),
|
||||
-- Warmate
|
||||
(10, 1.1, 0.15, 0.15, 5.7, 15),
|
||||
-- CH-901
|
||||
(11, 1.2, 0.18, 0.18, 9.0, 20),
|
||||
-- HAROP
|
||||
(12, 2.5, 0.43, 0.43, 135, 1000),
|
||||
-- Coyote
|
||||
(13, 0.9, 0.12, 0.12, 5.9, 20),
|
||||
-- WS-43
|
||||
(14, 1.8, 0.35, 0.35, 20, 60);
|
||||
|
||||
-- 巡飞弹特有参数
|
||||
INSERT INTO loitering_munition_params (
|
||||
equipment_id,
|
||||
wingspan_m,
|
||||
warhead_weight_kg,
|
||||
max_speed_ms,
|
||||
cruise_speed_kmh,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm,
|
||||
power_system,
|
||||
guidance_system
|
||||
) VALUES
|
||||
-- Hero-120
|
||||
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
|
||||
-- Switchblade 600
|
||||
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
|
||||
-- Warmate
|
||||
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
|
||||
-- CH-901
|
||||
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
|
||||
-- HAROP
|
||||
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
-- Coyote
|
||||
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
|
||||
-- WS-43
|
||||
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
|
||||
|
||||
-- 插入成本数据(示例成本)
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
-- 火箭炮
|
||||
(1, 800000), -- BM-21
|
||||
(2, 4500000), -- SR5
|
||||
(3, 5500000), -- HIMARS
|
||||
(4, 3500000), -- LAR-160
|
||||
(5, 2800000), -- T-122
|
||||
(6, 1500000), -- RM-70
|
||||
(7, 4800000), -- ASTROS II
|
||||
-- 巡飞弹
|
||||
(8, 150000), -- Hero-120
|
||||
(9, 180000), -- Switchblade 600
|
||||
(10, 80000), -- Warmate
|
||||
(11, 100000), -- CH-901
|
||||
(12, 850000), -- HAROP
|
||||
(13, 75000), -- Coyote
|
||||
(14, 120000); -- WS-43
|
||||
|
||||
-- 创建初始数据集
|
||||
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
|
||||
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
|
||||
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
|
||||
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
|
||||
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
|
||||
|
||||
-- 关联装备到数据集
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
-- 火箭炮训练集
|
||||
(1, 1), (1, 2), (1, 3), (1, 4),
|
||||
-- 巡飞弹训练集
|
||||
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
|
||||
-- 火箭炮验证集
|
||||
(3, 5), (3, 6), (3, 7),
|
||||
-- 巡飞弹验证集
|
||||
(4, 13), (4, 14);
|
||||
@ -26,15 +26,15 @@
|
||||
*/
|
||||
|
||||
-- 插入装备基本信息
|
||||
INSERT INTO equipment (
|
||||
INSERT INTO equipments (
|
||||
id, -- 装备ID
|
||||
name, -- 装备名称
|
||||
type, -- 装备类型
|
||||
manufacturer -- 制造商
|
||||
) VALUES
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列'),
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列 IAI'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'),
|
||||
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
|
||||
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
|
||||
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
|
||||
@ -65,11 +65,11 @@ INSERT INTO equipment (
|
||||
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
|
||||
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
|
||||
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'),
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗'),
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'),
|
||||
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
|
||||
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
|
||||
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
|
||||
@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params (
|
||||
(24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
|
||||
(28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
|
||||
(29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
|
||||
(30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
|
||||
@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(2, '巡飞弹验证集 2024', '包含20个巡飞弹型号,用于验证模型性能', '巡飞弹', '验证');
|
||||
|
||||
-- 训练集(80个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 以色列系列(8/10)
|
||||
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
|
||||
(1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列
|
||||
@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
(1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列
|
||||
|
||||
-- 验证集(20个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 以色列系列(2/10)
|
||||
(2, 9), -- Hero-900
|
||||
(2, 48), -- Rotem L
|
||||
@ -574,8 +574,8 @@ SET description = '包含20个巡飞弹型号,覆盖所有主要制造国,
|
||||
WHERE id = 2;
|
||||
|
||||
-- 更新巡飞弹特征工程字段
|
||||
UPDATE loitering_munition_params l
|
||||
JOIN common_params c ON l.equipment_id = c.equipment_id
|
||||
-- 第一步:更新基于common_params的特征
|
||||
UPDATE loitering_munition_params l, common_params c, equipments e
|
||||
SET
|
||||
-- 长宽比(反映气动布局特点)
|
||||
l.length_width_ratio = c.length_m / NULLIF(c.width_m, 0),
|
||||
@ -586,61 +586,6 @@ SET
|
||||
-- 速度重量比(反映动力性能,m/s/kg)
|
||||
l.speed_weight_ratio = l.max_speed_ms / NULLIF(c.weight_kg, 0),
|
||||
|
||||
-- 制导系统评分(1-10)
|
||||
l.guidance_system_score =
|
||||
CASE
|
||||
WHEN l.guidance_system LIKE '%卫通%' THEN 10
|
||||
WHEN l.guidance_system LIKE '%AI辅助%' AND l.guidance_system LIKE '%红外%' THEN 9
|
||||
WHEN l.guidance_system LIKE '%AI辅助%' THEN 8
|
||||
WHEN l.guidance_system LIKE '%数据链%' AND l.guidance_system LIKE '%光电%' THEN 7
|
||||
WHEN l.guidance_system LIKE '%数据链%' THEN 6
|
||||
WHEN l.guidance_system LIKE '%光电%' THEN 5
|
||||
WHEN l.guidance_system LIKE '%GPS/INS%' THEN 4
|
||||
ELSE 3
|
||||
END,
|
||||
|
||||
-- 战斗部威力评分(1-10)
|
||||
l.warhead_power_score =
|
||||
CASE
|
||||
-- 大型战斗部(>30kg)
|
||||
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%模块化%' THEN 10
|
||||
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%破甲%' THEN 9
|
||||
WHEN l.warhead_weight_kg > 30 AND l.warhead_type LIKE '%破片%' THEN 8
|
||||
|
||||
-- 中型战斗部(10-30kg)
|
||||
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%模块化%' THEN 8
|
||||
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%破甲%' THEN 7
|
||||
WHEN l.warhead_weight_kg > 10 AND l.warhead_type LIKE '%破片%' THEN 6
|
||||
|
||||
-- 小型战斗部(3-10kg)
|
||||
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%模块化%' THEN 6
|
||||
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%破甲%' THEN 5
|
||||
WHEN l.warhead_weight_kg > 3 AND l.warhead_type LIKE '%破片%' THEN 4
|
||||
|
||||
-- 微型战斗部(<3kg)
|
||||
WHEN l.warhead_type LIKE '%破甲%' THEN 3
|
||||
WHEN l.warhead_type LIKE '%破片%' THEN 2
|
||||
ELSE 1
|
||||
END,
|
||||
|
||||
-- 发动机功率(kW,根据重量估算)
|
||||
l.engine_power_kw =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
|
||||
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
|
||||
WHEN l.power_system = '涡轮喷气' THEN c.weight_kg * 0.35
|
||||
ELSE c.weight_kg * 0.2
|
||||
END,
|
||||
|
||||
-- 发动机推力(N,根据重量估算)
|
||||
l.engine_thrust_n =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN c.weight_kg * 9.8 * 0.3
|
||||
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 9.8 * 0.4
|
||||
WHEN l.power_system = '涡轮喷气' THEN c.weight_kg * 9.8 * 0.5
|
||||
ELSE c.weight_kg * 9.8 * 0.35
|
||||
END,
|
||||
|
||||
-- 最小作战高度(m,根据体型和任务类型估算)
|
||||
l.min_altitude_m =
|
||||
CASE
|
||||
@ -660,10 +605,133 @@ SET
|
||||
ELSE 30
|
||||
END,
|
||||
|
||||
-- 最大作战高度(m,根据航程估算)
|
||||
l.max_altitude_m =
|
||||
-- 发动机功率(kW,根据重量估算)
|
||||
l.engine_power_kw =
|
||||
CASE
|
||||
WHEN l.max_range_km > 500 THEN 5000
|
||||
WHEN l.max_range_km > 100 THEN 3000
|
||||
WHEN power_system = '电动机' THEN c.weight_kg * 0.15
|
||||
WHEN power_system = '活塞发动机' THEN c.weight_kg * 0.25
|
||||
WHEN power_system = '涡轮喷气' THEN c.weight_kg * 0.35
|
||||
ELSE c.weight_kg * 0.2
|
||||
END,
|
||||
|
||||
-- 发动机推力(N,根据重量估算)
|
||||
l.engine_thrust_n =
|
||||
CASE
|
||||
WHEN power_system = '电动机' THEN c.weight_kg * 9.8 * 0.3
|
||||
WHEN power_system = '活塞发动机' THEN c.weight_kg * 9.8 * 0.4
|
||||
WHEN power_system = '涡轮喷气' THEN c.weight_kg * 9.8 * 0.5
|
||||
ELSE c.weight_kg * 9.8 * 0.35
|
||||
END
|
||||
|
||||
WHERE
|
||||
l.equipment_id = c.equipment_id
|
||||
AND l.equipment_id = e.id
|
||||
AND e.type = '巡飞弹';
|
||||
|
||||
-- 第二步:更新基于自身参数的特征
|
||||
UPDATE loitering_munition_params
|
||||
SET
|
||||
-- 制导系统评分(1-10)
|
||||
guidance_system_score =
|
||||
CASE
|
||||
WHEN guidance_system LIKE '%卫通%' THEN 10
|
||||
WHEN guidance_system LIKE '%AI辅助%' AND guidance_system LIKE '%红外%' THEN 9
|
||||
WHEN guidance_system LIKE '%AI辅助%' THEN 8
|
||||
WHEN guidance_system LIKE '%数据链%' AND guidance_system LIKE '%光电%' THEN 7
|
||||
WHEN guidance_system LIKE '%数据链%' THEN 6
|
||||
WHEN guidance_system LIKE '%光电%' THEN 5
|
||||
WHEN guidance_system LIKE '%GPS/INS%' THEN 4
|
||||
ELSE 3
|
||||
END,
|
||||
|
||||
-- 战斗部威力评分(1-10)
|
||||
warhead_power_score =
|
||||
CASE
|
||||
-- 大型战斗部(>30kg)
|
||||
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%模块化%' THEN 10
|
||||
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%破甲%' THEN 9
|
||||
WHEN warhead_weight_kg > 30 AND warhead_type LIKE '%破片%' THEN 8
|
||||
|
||||
-- 中型战斗部(10-30kg)
|
||||
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%模块化%' THEN 8
|
||||
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%破甲%' THEN 7
|
||||
WHEN warhead_weight_kg > 10 AND warhead_type LIKE '%破片%' THEN 6
|
||||
|
||||
-- 小型战斗部(3-10kg)
|
||||
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%模块化%' THEN 6
|
||||
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%破甲%' THEN 5
|
||||
WHEN warhead_weight_kg > 3 AND warhead_type LIKE '%破片%' THEN 4
|
||||
|
||||
-- 微型战斗部(<3kg)
|
||||
WHEN warhead_type LIKE '%破甲%' THEN 3
|
||||
WHEN warhead_type LIKE '%破片%' THEN 2
|
||||
ELSE 1
|
||||
END,
|
||||
|
||||
-- 数据链范围(km)
|
||||
datalink_range_km =
|
||||
CASE
|
||||
-- 大型巡飞弹(通常具有卫星通信能力)
|
||||
WHEN guidance_system LIKE '%卫通%' THEN max_range_km
|
||||
|
||||
-- 中大型巡飞弹(具有较强数据链能力)
|
||||
WHEN guidance_system LIKE '%数据链%' AND max_range_km > 100 THEN LEAST(max_range_km, 200)
|
||||
|
||||
-- 中型巡飞弹
|
||||
WHEN guidance_system LIKE '%数据链%' AND max_range_km > 50 THEN LEAST(max_range_km, 100)
|
||||
|
||||
-- 小型巡飞弹
|
||||
WHEN guidance_system LIKE '%数据链%' THEN LEAST(max_range_km, 50)
|
||||
|
||||
-- 无数据链的情况(使用光电或其他制导方式)
|
||||
ELSE LEAST(max_range_km * 0.5, 30)
|
||||
END,
|
||||
|
||||
-- 最大作战高度(m,根据航程估算)
|
||||
max_altitude_m =
|
||||
CASE
|
||||
WHEN max_range_km > 500 THEN 5000
|
||||
WHEN max_range_km > 100 THEN 3000
|
||||
ELSE 1500
|
||||
END;
|
||||
END
|
||||
|
||||
WHERE equipment_id IN (
|
||||
SELECT id FROM equipments WHERE type = '巡飞弹'
|
||||
);
|
||||
|
||||
-- 更新巡飞弹的制导精度
|
||||
UPDATE loitering_munition_params
|
||||
SET guidance_accuracy_m =
|
||||
CASE
|
||||
-- 基础精度(根据制导系统类型)
|
||||
WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0
|
||||
WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0
|
||||
WHEN guidance_system LIKE '%激光制导%' THEN 1.0
|
||||
WHEN guidance_system LIKE '%红外制导%' THEN 2.0
|
||||
WHEN guidance_system LIKE '%卫星制导%' THEN 2.5
|
||||
ELSE 5.0
|
||||
END *
|
||||
-- 速度影响因子(速度越快,精度略微降低)
|
||||
CASE
|
||||
WHEN max_speed_ms > 200 THEN 1.2
|
||||
WHEN max_speed_ms > 150 THEN 1.1
|
||||
WHEN max_speed_ms > 100 THEN 1.0
|
||||
ELSE 0.9
|
||||
END *
|
||||
-- 重量影响因子(重量越大,精度略微降低)
|
||||
CASE
|
||||
WHEN warhead_weight_kg > 100 THEN 1.2
|
||||
WHEN warhead_weight_kg > 50 THEN 1.1
|
||||
WHEN warhead_weight_kg > 20 THEN 1.0
|
||||
ELSE 0.9
|
||||
END *
|
||||
-- 飞行高度影响因子(高度越高,精度略微降低)
|
||||
CASE
|
||||
WHEN ceiling_altitude_m > 5000 THEN 1.2
|
||||
WHEN ceiling_altitude_m > 3000 THEN 1.1
|
||||
WHEN ceiling_altitude_m > 1000 THEN 1.0
|
||||
ELSE 0.9
|
||||
END
|
||||
WHERE equipment_id IN (
|
||||
SELECT id FROM equipments WHERE type = '巡飞弹'
|
||||
);
|
||||
80
src/manufacturer_data.sql
Normal file
80
src/manufacturer_data.sql
Normal file
@ -0,0 +1,80 @@
|
||||
-- 插入供应商数据
|
||||
INSERT INTO manufacturers (
|
||||
name, -- 供应商名称
|
||||
country, -- 所属国家
|
||||
tech_level, -- 技术水平评分(1-10)
|
||||
scale_level, -- 规模评分(1-10)
|
||||
supply_chain_level -- 供应链成熟度评分(1-10)
|
||||
) VALUES
|
||||
-- 美国供应商
|
||||
('美国洛克希德·马丁', '美国', 10, 10, 10), -- 全球最大军工企业
|
||||
('美国 AeroVironment', '美国', 9, 8, 9), -- 无人机和导弹领域领先
|
||||
('美国 Raytheon', '美国', 9, 9, 9), -- 导弹技术领先
|
||||
('美国 AEVEX', '美国', 8, 7, 8), -- 新兴军工企业
|
||||
('美国 AREA-I', '美国', 8, 7, 8), -- 专注无人机系统
|
||||
('美国 Northrop Grumman', '美国', 9, 9, 9), -- 大型军工企业
|
||||
|
||||
-- 欧洲供应商
|
||||
('英国 BAE Systems', '英国', 8, 9, 9), -- 欧洲最大军工企业
|
||||
('英国 MBDA', '英国', 8, 8, 8), -- 导弹系统专家
|
||||
('德国 KMW', '德国', 9, 8, 9), -- 陆军装备主要供应商
|
||||
('德国 MBDA', '德国', 8, 8, 8), -- 导弹系统制造商
|
||||
('德国 Rheinmetall', '德国', 8, 8, 8), -- 综合军工企业
|
||||
('法国 Nexter', '法国', 8, 8, 8), -- 陆军装备制造商
|
||||
('法国 MBDA', '法国', 8, 8, 8), -- 导弹系统制造商
|
||||
('法国 Safran', '法国', 8, 8, 8), -- 航空航天企业
|
||||
('意大利 Leonardo', '意大利', 7, 7, 7), -- 综合军工企业
|
||||
('意大利 OTO Melara', '意大利', 7, 7, 7), -- 火炮系统制造商
|
||||
|
||||
-- 以色列供应商
|
||||
('以色列军事工业', '以色列', 9, 7, 7), -- 技术先进
|
||||
('以色列 IAI', '以色列', 9, 7, 7), -- 航空航天领先
|
||||
('以色列 UVision', '以色列', 8, 6, 7), -- 无人机专家
|
||||
|
||||
-- 中国供应商
|
||||
('中国兵器工业集团', '中国', 8, 9, 8), -- 陆军装备制造商
|
||||
('中国航天科工', '中国', 8, 9, 8), -- 导弹制造商
|
||||
|
||||
-- 亚洲供应商
|
||||
('韩国韩华防务', '韩国', 7, 7, 7), -- 韩国主要军工企业
|
||||
('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业
|
||||
|
||||
-- 俄罗斯供应商
|
||||
('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
|
||||
('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商
|
||||
('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商
|
||||
|
||||
-- 其他欧洲供应商
|
||||
('波兰 WB Electronics', '波兰', 6, 6, 6), -- 电子系统制造商
|
||||
('波兰 WB Group', '波兰', 6, 6, 6), -- 军工集团
|
||||
('波兰胡塔斯塔洛瓦', '波兰', 6, 6, 6), -- 装备制造商
|
||||
('瑞典 UMS Skeldar', '瑞典', 7, 6, 7), -- 无人机系统
|
||||
('瑞典 Saab', '瑞典', 7, 7, 7), -- 综合军工企业
|
||||
('捷克 RETIA', '捷克', 6, 5, 6), -- 电子系统制造商
|
||||
('斯洛伐克 ZTS', '斯洛伐克', 5, 5, 5), -- 装备制造商
|
||||
('捷克 Excalibur Army', '捷克', 6, 5, 6), -- 陆军装备制造商
|
||||
('克罗地亚 RH ALAN', '克罗地亚', 5, 4, 5), -- 军工企业
|
||||
('塞尔维亚 Yugoimport', '塞尔维亚', 5, 4, 5), -- 军工出口企业
|
||||
('芬兰 Patria', '芬兰', 7, 6, 7), -- 装甲车辆制造商
|
||||
('奥地利 Hirtenberger', '奥地利', 7, 6, 7), -- 火炮系统制造商
|
||||
|
||||
-- 其他供应商
|
||||
('土耳其洛克特桑', '土耳其', 6, 6, 6), -- 新兴军工企业
|
||||
('土耳其 STM', '土耳其', 6, 6, 6), -- 防务技术公司
|
||||
('巴西航空工业', '巴西', 6, 6, 5), -- 南美最大军工企业
|
||||
('印度DRDO', '印度', 5, 5, 5), -- 国防研究机构
|
||||
('伊朗国防工业', '伊朗', 4, 4, 4), -- 受制裁影响
|
||||
('埃及 AOI', '埃及', 4, 4, 4), -- 军工企业
|
||||
('罗马尼亚 ROMARM', '罗马尼亚', 5, 4, 5), -- 国营军工企业
|
||||
('乌克兰尤日马什', '乌克兰', 6, 5, 5), -- 航天企业
|
||||
('白俄罗斯国家军工委员会', '白俄罗斯', 5, 5, 5), -- 国家军工管理机构
|
||||
('阿联酋国际金龙', '阿联酋', 6, 6, 6), -- 新兴军工企业
|
||||
('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业
|
||||
|
||||
-- 更新装备表中的供应商ID
|
||||
UPDATE equipments e
|
||||
SET manufacturer_id = (
|
||||
SELECT id
|
||||
FROM manufacturers m
|
||||
WHERE m.name = e.manufacturer
|
||||
);
|
||||
1242
src/model_trainer.py
1242
src/model_trainer.py
File diff suppressed because it is too large
Load Diff
@ -1,485 +0,0 @@
|
||||
-- 清空现有数据
|
||||
SET FOREIGN_KEY_CHECKS=0;
|
||||
TRUNCATE TABLE dataset_equipment;
|
||||
TRUNCATE TABLE datasets;
|
||||
TRUNCATE TABLE cost_data;
|
||||
TRUNCATE TABLE loitering_munition_params;
|
||||
TRUNCATE TABLE common_params;
|
||||
TRUNCATE TABLE equipment;
|
||||
SET FOREIGN_KEY_CHECKS=1;
|
||||
|
||||
-- 按系列插入装备数据,确保ID连续
|
||||
-- 1. HAROP/Harpy 系列 (ID: 1-3)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
|
||||
|
||||
-- 2. Hero 系列 (ID: 4-9)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
|
||||
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
|
||||
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
|
||||
(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
|
||||
(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
|
||||
(9, 'Hero-900', '巡飞弹', '以色列 UVision');
|
||||
|
||||
-- 3. Switchblade 系列 (ID: 10-13)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
|
||||
(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
|
||||
(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
|
||||
(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
|
||||
|
||||
-- 4. Warmate 系列 (ID: 14-18)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
|
||||
(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
|
||||
(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
|
||||
(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
|
||||
(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
|
||||
|
||||
-- 5. CH-901/902 系列 (ID: 19-23)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(19, 'CH-901', '巡飞弹', '中国航天科工'),
|
||||
(20, 'CH-901A', '巡飞弹', '中国航天科工'),
|
||||
(21, 'CH-901H', '巡飞弹', '中国航天科工'),
|
||||
(22, 'CH-902', '巡飞弹', '中国航天科工'),
|
||||
(23, 'CH-902A', '巡飞弹', '中国航天科工');
|
||||
|
||||
-- 6. WS-43/61 系列 (ID: 24-28)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(24, 'WS-43', '巡飞弹', '中国航天科工'),
|
||||
(25, 'WS-43A', '巡飞弹', '中国航天科工'),
|
||||
(26, 'WS-43B', '巡飞弹', '中国航天科工'),
|
||||
(27, 'WS-61', '巡飞弹', '中国航天科工'),
|
||||
(28, 'WS-61A', '巡飞弹', '中国航天科工');
|
||||
|
||||
-- 7. Kargu/Alpagu 系列 (ID: 29-33)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(29, 'Kargu', '巡飞弹', '土耳其 STM'),
|
||||
(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
|
||||
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
|
||||
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
|
||||
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
|
||||
|
||||
-- 8. Shahed 系列 (ID: 34-38)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗');
|
||||
|
||||
-- 9. Green Dragon 系列 (ID: 39-43)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
|
||||
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
|
||||
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
|
||||
(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
|
||||
(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
|
||||
|
||||
-- 10. Phoenix Ghost 系列 (ID: 44-48)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
|
||||
|
||||
-- 11. ZALA Lancet 系列 (ID: 49-52)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
|
||||
|
||||
-- 12. Rotem L 系列 (ID: 53-56)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
|
||||
(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
|
||||
(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
|
||||
(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
|
||||
|
||||
-- 13. KUB-BLA 系列 (ID: 57-60)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
|
||||
|
||||
-- 插入通用参数
|
||||
INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
|
||||
(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
|
||||
(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
|
||||
(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
|
||||
(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
|
||||
(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
|
||||
(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
|
||||
(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
|
||||
(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
|
||||
(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
|
||||
(10, 0.58, 0.12, 0.12, 2.5, 10),
|
||||
(11, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
|
||||
(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
|
||||
(14, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(15, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(16, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(17, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(18, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(19, 1.2, 0.18, 0.18, 9.0, 20),
|
||||
(20, 1.2, 0.18, 0.18, 9.3, 25),
|
||||
(21, 1.2, 0.18, 0.18, 9.5, 20),
|
||||
(22, 1.4, 0.22, 0.22, 15.0, 30),
|
||||
(23, 1.4, 0.22, 0.22, 15.5, 35),
|
||||
(24, 1.8, 0.35, 0.35, 20, 60),
|
||||
(25, 1.8, 0.35, 0.35, 21, 70),
|
||||
(26, 1.9, 0.35, 0.35, 22, 80),
|
||||
(27, 2.2, 0.40, 0.40, 35, 100),
|
||||
(28, 2.2, 0.40, 0.40, 37, 120),
|
||||
(29, 0.6, 0.35, 0.35, 7.0, 10),
|
||||
(30, 0.6, 0.35, 0.35, 7.2, 15),
|
||||
(31, 1.0, 0.23, 0.23, 3.7, 5),
|
||||
(32, 1.0, 0.23, 0.23, 3.9, 8),
|
||||
(33, 0.6, 0.35, 0.35, 7.5, 15),
|
||||
(34, 2.6, 0.34, 0.34, 135, 900),
|
||||
(35, 2.6, 0.34, 0.34, 140, 1000),
|
||||
(36, 3.5, 0.42, 0.42, 200, 2000),
|
||||
(37, 3.5, 0.42, 0.42, 210, 2200),
|
||||
(38, 3.5, 0.42, 0.42, 215, 2500),
|
||||
(39, 1.5, 0.20, 0.20, 15, 40),
|
||||
(40, 1.6, 0.20, 0.20, 16, 50),
|
||||
(41, 1.5, 0.20, 0.20, 15.5, 45),
|
||||
(42, 1.5, 0.20, 0.20, 15.8, 40),
|
||||
(43, 1.2, 0.18, 0.18, 12, 30),
|
||||
(44, 1.5, 0.25, 0.25, 14.0, 30),
|
||||
(45, 1.5, 0.25, 0.25, 14.5, 35),
|
||||
(46, 1.6, 0.26, 0.26, 15.0, 40),
|
||||
(47, 1.5, 0.25, 0.25, 14.8, 30),
|
||||
(48, 1.7, 0.27, 0.27, 16.0, 50),
|
||||
(49, 1.0, 0.20, 0.20, 5.0, 40),
|
||||
(50, 1.65, 0.35, 0.35, 12.0, 70),
|
||||
(51, 1.65, 0.35, 0.35, 12.5, 80),
|
||||
(52, 1.80, 0.40, 0.40, 15.0, 100),
|
||||
(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
|
||||
(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
|
||||
(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
|
||||
(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
|
||||
(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
|
||||
(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
|
||||
(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
|
||||
(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
|
||||
|
||||
-- 插入特有参数
|
||||
INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
|
||||
endurance_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
power_system,
|
||||
guidance_system
|
||||
) VALUES
|
||||
-- HAROP/Harpy系列
|
||||
(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
|
||||
(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
|
||||
|
||||
-- Hero系列
|
||||
(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
|
||||
(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
|
||||
(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
|
||||
-- Switchblade系列
|
||||
(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
|
||||
(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
|
||||
-- Warmate系列
|
||||
(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
|
||||
(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
|
||||
-- CH-901/902系列
|
||||
(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
|
||||
(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
|
||||
(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
|
||||
(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
|
||||
(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
|
||||
(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
|
||||
(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
|
||||
(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
|
||||
(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
|
||||
(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
|
||||
|
||||
-- 插入成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(1, 800000), -- IAI Harop
|
||||
(2, 700000), -- IAI Harpy
|
||||
(3, 350000), -- IAI Mini Harpy
|
||||
(4, 70000), -- Hero-30
|
||||
(5, 120000), -- Hero-70
|
||||
(6, 150000), -- Hero-120
|
||||
(7, 300000), -- Hero-250
|
||||
(8, 400000), -- Hero-400EC
|
||||
(9, 650000), -- Hero-900
|
||||
(10, 60000), -- Switchblade 300
|
||||
(11, 180000), -- Switchblade 600
|
||||
(12, 75000), -- Switchblade 300 Block 10
|
||||
(13, 200000), -- Switchblade 600 Extended Range
|
||||
(14, 60000), -- Warmate 1.0
|
||||
(15, 180000), -- Warmate 2.0
|
||||
(16, 60000), -- Warmate-V
|
||||
(17, 180000), -- Warmate-L
|
||||
(18, 60000), -- Warmate 3.0
|
||||
(19, 100000), -- CH-901
|
||||
(20, 120000), -- CH-901A
|
||||
(21, 130000), -- CH-901H
|
||||
(22, 180000), -- CH-902
|
||||
(23, 200000), -- CH-902A
|
||||
(24, 120000), -- WS-43
|
||||
(25, 150000), -- WS-43A
|
||||
(26, 180000), -- WS-43B
|
||||
(27, 300000), -- WS-61
|
||||
(28, 350000), -- WS-61A
|
||||
(29, 70000), -- Kargu
|
||||
(30, 85000), -- Kargu-2
|
||||
(31, 45000), -- Alpagu
|
||||
(32, 55000), -- Alpagu Block-II
|
||||
(33, 95000), -- Kargu Autonomous
|
||||
(34, 20000), -- Shahed-131
|
||||
(35, 25000), -- Shahed-131B
|
||||
(36, 40000), -- Shahed-136
|
||||
(37, 45000), -- Shahed-136B
|
||||
(38, 50000), -- Shahed-136C
|
||||
(39, 160000), -- Green Dragon
|
||||
(40, 200000), -- Green Dragon Extended Range
|
||||
(41, 180000), -- Green Dragon Block 2
|
||||
(42, 190000), -- Green Dragon Maritime
|
||||
(43, 140000), -- Green Dragon-S
|
||||
(44, 150000), -- Phoenix Ghost
|
||||
(45, 180000), -- Phoenix Ghost Block I
|
||||
(46, 220000), -- Phoenix Ghost Block II
|
||||
(47, 190000), -- Phoenix Ghost Maritime
|
||||
(48, 250000), -- Phoenix Ghost-ER
|
||||
(49, 80000), -- Lancet-1
|
||||
(50, 150000), -- Lancet-3
|
||||
(51, 180000), -- Lancet-3M
|
||||
(52, 250000), -- Lancet-4
|
||||
(53, 65000), -- Rotem L
|
||||
(54, 85000), -- Rotem L-X
|
||||
(55, 75000), -- Rotem L-M
|
||||
(56, 95000), -- Rotem L-ER
|
||||
(57, 95000), -- KUB-BLA
|
||||
(58, 120000), -- KUB-BLA-E
|
||||
(59, 110000), -- KUB-BLA-M
|
||||
(60, 150000); -- KUB-BLA-ER
|
||||
|
||||
-- 创建数据集
|
||||
INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
|
||||
(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
|
||||
|
||||
-- 关联装备到数据集(按照制造商和型号分配)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
-- 训练集(约80%的数据,48个型号)
|
||||
-- 以色列系列
|
||||
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
|
||||
(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
|
||||
(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
|
||||
(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
|
||||
|
||||
-- 美国系列
|
||||
(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
|
||||
(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
|
||||
|
||||
-- 中国系列
|
||||
(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
|
||||
(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
|
||||
|
||||
-- 波兰和土耳其系列
|
||||
(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
|
||||
(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
|
||||
|
||||
-- 俄罗斯系列
|
||||
(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
|
||||
|
||||
-- 验证集(约20%的数据,12个型号)
|
||||
-- 混合系列
|
||||
(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
|
||||
(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
|
||||
(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
|
||||
|
||||
-- 添加分类特征编码
|
||||
INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
|
||||
-- 战斗部类型编码
|
||||
('warhead_type', '破片杀伤战斗部', 1),
|
||||
('warhead_type', '破甲战斗部', 2),
|
||||
('warhead_type', '高爆战斗部', 3),
|
||||
('warhead_type', '破片杀伤/破甲双用战斗部', 4),
|
||||
('warhead_type', '模块化战斗部', 5),
|
||||
|
||||
-- 发射方式编码
|
||||
('launch_mode', '箱式发射', 1),
|
||||
('launch_mode', '弹射式发射', 2),
|
||||
('launch_mode', '垂直起降', 3),
|
||||
('launch_mode', '单兵发射管', 4),
|
||||
('launch_mode', '箱式发射/弹射式', 5),
|
||||
('launch_mode', '箱式发射/空中发射', 6),
|
||||
|
||||
-- 动力装置编码(按复杂度递增)
|
||||
('power_system', '电动机', 1),
|
||||
('power_system', '活塞发动机', 2),
|
||||
|
||||
-- 制导系统编码(按复杂度递增)
|
||||
('guidance_system', 'GPS/INS', 1),
|
||||
('guidance_system', 'GPS/INS/光电', 2),
|
||||
('guidance_system', 'GPS/INS/光电/数据链', 3),
|
||||
('guidance_system', 'GPS/INS/光电/AI识别', 4),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
|
||||
|
||||
-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
|
||||
UPDATE loitering_munition_params l
|
||||
JOIN common_params c ON l.equipment_id = c.equipment_id
|
||||
SET
|
||||
-- 新增关键参数
|
||||
l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
|
||||
l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
|
||||
l.engine_power_kw =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
|
||||
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
|
||||
END,
|
||||
l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
|
||||
l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
|
||||
l.guidance_accuracy_m =
|
||||
CASE
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
|
||||
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
|
||||
ELSE 3.0
|
||||
END,
|
||||
l.min_altitude_m = -- 最小作战高度
|
||||
CASE
|
||||
-- 大型巡飞弹(体型大、重量大)
|
||||
WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
|
||||
|
||||
-- 中型巡飞弹
|
||||
WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
|
||||
|
||||
-- 中小型巡飞弹
|
||||
WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
|
||||
|
||||
-- 小型巡飞弹
|
||||
WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
|
||||
|
||||
-- 超小型巡飞弹
|
||||
WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
|
||||
|
||||
-- 其他型号使用默认值
|
||||
ELSE 50
|
||||
END,
|
||||
l.max_altitude_m =
|
||||
CASE
|
||||
WHEN c.max_range_km > 500 THEN 5000
|
||||
WHEN c.max_range_km > 100 THEN 3000
|
||||
ELSE 1500
|
||||
END,
|
||||
|
||||
-- 特征工程字段
|
||||
l.length_width_ratio = c.length_m / c.width_m,
|
||||
l.weight_range_ratio = c.weight_kg / c.max_range_km,
|
||||
l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
|
||||
l.guidance_system_score =
|
||||
CASE
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
|
||||
WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
|
||||
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
|
||||
ELSE 2
|
||||
END,
|
||||
l.warhead_power_score =
|
||||
CASE
|
||||
WHEN l.warhead_type = '模块化战斗部' THEN 10
|
||||
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
|
||||
WHEN l.warhead_type = '高爆战斗部' THEN 7
|
||||
WHEN l.warhead_type = '破甲战斗部' THEN 6
|
||||
WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
|
||||
ELSE 4
|
||||
END,
|
||||
|
||||
-- 分类特征编码
|
||||
l.warhead_type_code =
|
||||
CASE
|
||||
WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
|
||||
WHEN l.warhead_type = '破甲战斗部' THEN 2
|
||||
WHEN l.warhead_type = '高爆战斗部' THEN 3
|
||||
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
|
||||
WHEN l.warhead_type = '模块化战斗部' THEN 5
|
||||
ELSE 0
|
||||
END,
|
||||
l.launch_mode_code =
|
||||
CASE
|
||||
WHEN l.launch_mode = '箱式发射' THEN 1
|
||||
WHEN l.launch_mode = '弹射式发射' THEN 2
|
||||
WHEN l.launch_mode = '垂直起降' THEN 3
|
||||
WHEN l.launch_mode = '单兵发射管' THEN 4
|
||||
WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
|
||||
WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
|
||||
ELSE 0
|
||||
END,
|
||||
l.power_system_code =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN 1
|
||||
WHEN l.power_system = '活塞发动机' THEN 2
|
||||
ELSE 0
|
||||
END,
|
||||
l.guidance_system_code =
|
||||
CASE
|
||||
WHEN l.guidance_system = 'GPS/INS' THEN 1
|
||||
WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
|
||||
ELSE 0
|
||||
END;
|
||||
@ -29,7 +29,7 @@
|
||||
*/
|
||||
|
||||
-- 中国系列火箭炮数据
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
INSERT INTO equipments (id, name, type, manufacturer) VALUES
|
||||
(1001, 'PCL-191', '火箭炮', '中国兵器工业集团'),
|
||||
(1002, 'PHL-03', '火箭炮', '中国兵器工业集团'),
|
||||
(1003, 'AR-3', '火箭炮', '中国航天科工'),
|
||||
@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1007, 'WS-2', '火箭炮', '中国航天科工'),
|
||||
(1008, 'WS-3', '火箭炮', '中国航天科工'),
|
||||
(1009, 'Type 63', '火箭炮', '中国兵器工业集团'),
|
||||
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'),
|
||||
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'),
|
||||
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'),
|
||||
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'),
|
||||
(1014, 'TOS-1A', '火箭炮', '俄罗斯'),
|
||||
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'),
|
||||
(1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'),
|
||||
(1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'),
|
||||
@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'),
|
||||
(1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'),
|
||||
(1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'),
|
||||
(1033, 'Pinaka', '火箭炮', '印度DRDO'),
|
||||
(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'),
|
||||
(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'),
|
||||
(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'),
|
||||
(1033, 'Pinaka', '火箭炮', '印度 DRDO'),
|
||||
(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'),
|
||||
(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'),
|
||||
(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'),
|
||||
(1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
(1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
(1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(4, '火箭炮验证集 2024', '包含19个火箭炮型号,用于验证模型性能', '火箭炮', '验证');
|
||||
|
||||
-- 训练集(约80%的数据,77个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 中国系列(7/9)
|
||||
(3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007),
|
||||
|
||||
@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
(3, 1094), (3, 1095);
|
||||
|
||||
-- 验证集(约20%的数据,19个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 中国系列(2/9)
|
||||
(4, 1008), (4, 1009),
|
||||
|
||||
|
||||
1020
src/routes.py
1020
src/routes.py
File diff suppressed because it is too large
Load Diff
@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci;
|
||||
USE equipment_cost_db;
|
||||
|
||||
-- 装备基本信息表
|
||||
CREATE TABLE equipment (
|
||||
CREATE TABLE equipments (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(100), -- 名称
|
||||
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
|
||||
manufacturer VARCHAR(100), -- 制造商
|
||||
manufacturer_id INT, -- 制造商ID
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
@ -26,8 +27,7 @@ CREATE TABLE common_params (
|
||||
width_m FLOAT, -- 宽度(m)
|
||||
height_m FLOAT, -- 高度(m)
|
||||
weight_kg FLOAT, -- 重量(kg)
|
||||
max_range_km FLOAT, -- 最大射程(km)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 火箭炮特有参数表
|
||||
@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params (
|
||||
deployment_score INT, -- 部署评分(1-10)
|
||||
terrain_adaptability_score INT, -- 地形适应性评分(1-10)
|
||||
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 巡飞弹特有参数表
|
||||
@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params (
|
||||
power_system_code INT, -- 动力装置编码
|
||||
guidance_system_code INT, -- 制导系统编码
|
||||
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 分类特征编码表
|
||||
@ -122,7 +122,7 @@ CREATE TABLE cost_data (
|
||||
actual_cost DECIMAL(15,2), -- 实际成本(元)
|
||||
predicted_cost DECIMAL(15,2), -- 预测成本(元)
|
||||
prediction_date TIMESTAMP, -- 预测日期
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 特殊参数表
|
||||
@ -133,12 +133,12 @@ CREATE TABLE custom_params (
|
||||
param_value VARCHAR(255), -- 参数值
|
||||
param_unit VARCHAR(50), -- 参数单位
|
||||
description TEXT, -- 参数说明
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_equipment_type ON equipment(type);
|
||||
CREATE INDEX idx_equipment_name ON equipment(name);
|
||||
CREATE INDEX idx_equipment_type ON equipments(type);
|
||||
CREATE INDEX idx_equipment_name ON equipments(name);
|
||||
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
|
||||
|
||||
-- 数据集表
|
||||
@ -153,12 +153,12 @@ CREATE TABLE datasets (
|
||||
);
|
||||
|
||||
-- 数据集-装备关联表
|
||||
CREATE TABLE dataset_equipment (
|
||||
CREATE TABLE dataset_equipments (
|
||||
dataset_id INT NOT NULL,
|
||||
equipment_id INT NOT NULL,
|
||||
PRIMARY KEY (dataset_id, equipment_id),
|
||||
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
-- 训练模型表
|
||||
@ -175,10 +175,34 @@ CREATE TABLE trained_models (
|
||||
feature_importance JSON, -- 特征重要性
|
||||
training_data_size INT, -- 训练数据量
|
||||
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
|
||||
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
|
||||
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
|
||||
created_by VARCHAR(50) -- 创建者
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
|
||||
CREATE INDEX idx_model_active ON trained_models(is_active);
|
||||
CREATE INDEX idx_model_active ON trained_models(is_active);
|
||||
|
||||
-- 生产商表
|
||||
CREATE TABLE manufacturers (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL, -- 生产商名称
|
||||
country VARCHAR(50) NOT NULL, -- 所属国家
|
||||
tech_level INT NOT NULL, -- 技术水平评分(1-10)
|
||||
scale_level INT NOT NULL, -- 规模评分(1-10)
|
||||
supply_chain_level INT NOT NULL, -- 供应链成熟度评分(1-10)
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY unique_name (name)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加生产商外键
|
||||
ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id);
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_manufacturer_country ON manufacturers(country);
|
||||
CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level);
|
||||
CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level);
|
||||
CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
|
||||
CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id);
|
||||
|
||||
|
||||
5
src/start.bat
Normal file
5
src/start.bat
Normal file
@ -0,0 +1,5 @@
|
||||
@echo off
|
||||
set FLASK_DEBUG=false
|
||||
echo Starting Cost Prediction System...
|
||||
start /B run.exe
|
||||
start http://localhost:5001
|
||||
118
src/test_api.py
118
src/test_api.py
@ -147,14 +147,124 @@ def test_api_endpoints():
|
||||
response = requests.get(f'{base_url}/models/巡飞弹/latest')
|
||||
print_response(response, "获取最新模型")
|
||||
|
||||
# 8. 测试多模型预测接口
|
||||
logger.info("\n8. 测试多模型预测接口")
|
||||
# 8. 测试预测接口
|
||||
logger.info("\n8. 测试预测接口")
|
||||
|
||||
# 8.1 测试普通预测接口
|
||||
logger.info("8.1 测试普通预测接口")
|
||||
predict_data = {
|
||||
"type": "巡飞弹",
|
||||
"length_m": 1.3,
|
||||
"width_m": 0.23,
|
||||
"height_m": 0.23,
|
||||
"weight_kg": 12.5,
|
||||
"max_range_km": 40,
|
||||
"max_speed_ms": 50,
|
||||
"cruise_speed_kmh": 100,
|
||||
"flight_time_min": 60,
|
||||
"folded_length_mm": 1300,
|
||||
"folded_width_mm": 230,
|
||||
"folded_height_mm": 230,
|
||||
"warhead_type": "破片杀伤战斗部",
|
||||
"launch_mode": "凭自身动力起飞"
|
||||
}
|
||||
response = requests.post(
|
||||
f'{base_url}/predict/all',
|
||||
f'{base_url}/predict',
|
||||
json=predict_data
|
||||
)
|
||||
print_response(response, "多模型预测")
|
||||
print_response(response, "普通预测")
|
||||
|
||||
# 8.2 测试 PLS 预测接口
|
||||
logger.info("8.2 测试 PLS 预测接口")
|
||||
response = requests.post(
|
||||
f'{base_url}/pls/predict',
|
||||
json=predict_data
|
||||
)
|
||||
print_response(response, "PLS 预测")
|
||||
|
||||
# 9. 测试生产商分析接口
|
||||
logger.info("\n9. 测试生产商分析接口")
|
||||
manufacturer_data = {
|
||||
"dataset_id": 1 # 使用已存在的数据集ID
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f'{base_url}/analyze-manufacturers',
|
||||
json=manufacturer_data
|
||||
)
|
||||
print_response(response, "生产商分析")
|
||||
|
||||
# 10. 测试模型激活接口
|
||||
logger.info("\n10. 测试模型激活接口")
|
||||
# 假设存在ID为1的模型
|
||||
response = requests.post(f'{base_url}/models/1/activate')
|
||||
print_response(response, "模型激活")
|
||||
|
||||
# 11. 测试获取最新模型接口
|
||||
logger.info("\n11. 测试获取最新模型接口")
|
||||
response = requests.get(f'{base_url}/models/巡飞弹/latest')
|
||||
print_response(response, "获取最新模型")
|
||||
|
||||
# 12. 测试数据集详情接口
|
||||
logger.info("\n12. 测试数据集详情接口")
|
||||
response = requests.get(f'{base_url}/datasets/1') # 假设存在ID为1的数据集
|
||||
print_response(response, "数据集详情")
|
||||
|
||||
# 13. 测试更新数据集接口
|
||||
logger.info("\n13. 测试更新数据集接口")
|
||||
if available_equipment_ids:
|
||||
update_dataset_data = {
|
||||
"name": "更新后的测试数据集",
|
||||
"description": "用于测试的更新数据集",
|
||||
"equipment_type": "巡飞弹",
|
||||
"purpose": "测试",
|
||||
"equipment_ids": available_equipment_ids[:2] # 使用前两个可用的装备ID
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f'{base_url}/datasets/1', # 假设更新ID为1的数据集
|
||||
json=update_dataset_data
|
||||
)
|
||||
print_response(response, "更新数据集")
|
||||
else:
|
||||
logger.warning("没有可用的装备ID,跳过数据集更新测试")
|
||||
|
||||
# 14. 测试装备详情接口
|
||||
logger.info("\n14. 测试装备详情接口")
|
||||
if available_equipment_ids:
|
||||
response = requests.get(f'{base_url}/data/details/{available_equipment_ids[0]}')
|
||||
print_response(response, "装备详情")
|
||||
|
||||
# 15. 测试更新装备接口
|
||||
logger.info("\n15. 测试更新装备接口")
|
||||
if available_equipment_ids:
|
||||
equipment_update_data = {
|
||||
"equipment_id": available_equipment_ids[0],
|
||||
"name": "更新后的装备名称",
|
||||
"type": "巡飞弹",
|
||||
"manufacturer": "测试厂商",
|
||||
"length_m": 1.5,
|
||||
"width_m": 0.3,
|
||||
"height_m": 0.3,
|
||||
"weight_kg": 15.0,
|
||||
"wingspan_m": 0.8,
|
||||
"warhead_weight_kg": 5.0,
|
||||
"max_speed_ms": 60,
|
||||
"cruise_speed_kmh": 120,
|
||||
"endurance_min": 45,
|
||||
"max_range_km": 50,
|
||||
"warhead_type": "高爆战斗部",
|
||||
"launch_mode": "弹射起飞",
|
||||
"power_system": "涡轮发动机",
|
||||
"guidance_system": "GPS/INS组合导航"
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f'{base_url}/data/{available_equipment_ids[0]}',
|
||||
json=equipment_update_data
|
||||
)
|
||||
print_response(response, "更新装备")
|
||||
|
||||
logger.info("所有测试完成")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
|
||||
40
tests/test_demo_routes.py
Normal file
40
tests/test_demo_routes.py
Normal file
@ -0,0 +1,40 @@
|
||||
from src import create_app
|
||||
|
||||
|
||||
def test_demo_algorithms_route_returns_available_models():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.get("/api/demo/algorithms")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert any(item["key"] == "random_forest" for item in payload["algorithms"])
|
||||
|
||||
|
||||
def test_demo_dataset_route_returns_local_file_summary():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.get("/api/demo/dataset")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["source"] == "local-file"
|
||||
assert payload["row_count"] >= 20
|
||||
|
||||
|
||||
def test_demo_run_route_returns_metrics_without_mysql():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/demo/run",
|
||||
json={"algorithms": ["linear", "random_forest"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["source"] == "local-file"
|
||||
assert set(payload["metrics"]) == {"linear", "random_forest"}
|
||||
assert payload["prediction_points"]
|
||||
49
tests/test_demo_service.py
Normal file
49
tests/test_demo_service.py
Normal file
@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.demo_service import DemoModelService
|
||||
|
||||
|
||||
def test_demo_service_loads_local_dataset():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
summary = service.get_dataset_summary()
|
||||
|
||||
assert summary["row_count"] >= 20
|
||||
assert "actual_cost" in summary["columns"]
|
||||
assert summary["target"] == "actual_cost"
|
||||
assert summary["preview"][0]["name"]
|
||||
assert summary["preview"][0]["type"] in {"巡飞弹", "火箭炮"}
|
||||
|
||||
|
||||
def test_demo_service_returns_chinese_algorithm_names_with_english_notes():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
algorithms = service.get_algorithms()
|
||||
|
||||
linear = next(item for item in algorithms if item["key"] == "linear")
|
||||
assert linear["name"] == "线性回归"
|
||||
assert linear["english_name"] == "Linear Regression"
|
||||
assert linear["family"] == "线性模型"
|
||||
|
||||
|
||||
def test_demo_service_runs_multiple_algorithms():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
result = service.run_demo(["linear", "random_forest", "gradient_boosting"])
|
||||
|
||||
assert result["source"] == "local-file"
|
||||
assert result["best_model"] in result["metrics"]
|
||||
assert len(result["metrics"]) == 3
|
||||
assert len(result["prediction_points"]) > 0
|
||||
assert len(result["sample_prediction"]["predictions"]) == 3
|
||||
for metrics in result["metrics"].values():
|
||||
assert {"r2", "mae", "rmse"}.issubset(metrics)
|
||||
|
||||
|
||||
def test_demo_service_ignores_unavailable_algorithms():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
result = service.run_demo(["linear", "does_not_exist"])
|
||||
|
||||
assert list(result["metrics"].keys()) == ["linear"]
|
||||
assert result["warnings"]
|
||||
Loading…
Reference in New Issue
Block a user