From dba9f2fcc902293476ef608d0f270a29039c73c4 Mon Sep 17 00:00:00 2001
From: Tian jianyong <11429339@qq.com>
Date: Mon, 25 Nov 2024 19:58:39 +0800
Subject: [PATCH] =?UTF-8?q?=E5=B0=86=20tensor=20=E6=94=B9=E4=B8=BA=20torch?=
=?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=9B=B4=E6=96=B0=E4=BE=9D=E8=B5=96=EF=BC=8C?=
=?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E7=94=9F=E4=BA=A7=E5=95=86=E7=9A=84?=
=?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=92=8C=E7=89=B9=E5=BE=81=E5=88=86=E6=9E=90?=
=?UTF-8?q?=E3=80=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.env.example | 25 -
.gitignore | 39 ++
.python-version | 1 +
LICENSE | 21 +
README.md | 54 ++
config.py | 129 +++--
frontend/src/views/AnalysisPage.vue | 269 +++++++---
pyproject.toml | 59 +++
requirements.txt | 13 +-
run.py | 40 +-
scripts/setup_env.ps1 | 121 +++++
scripts/setup_env.sh | 66 +++
src/__init__.py | 4 +-
src/app.py | 46 +-
src/cost_prediction.py | 271 ++--------
src/data_preparation.py | 109 ++--
src/feature_analysis.py | 84 +++-
src/import_data.py | 8 +-
src/init_data.sql | 319 ------------
src/loitering_munition_data.sql | 67 ++-
src/manufacturer_data.sql | 12 +-
src/model_trainer.py | 739 +++++-----------------------
src/real_data.sql | 485 ------------------
src/rocket_artillery_data.sql | 24 +-
src/routes.py | 415 +++++++++-------
src/schema.sql | 50 +-
26 files changed, 1378 insertions(+), 2092 deletions(-)
delete mode 100644 .env.example
create mode 100644 .python-version
create mode 100644 LICENSE
create mode 100644 pyproject.toml
create mode 100644 scripts/setup_env.ps1
create mode 100755 scripts/setup_env.sh
delete mode 100644 src/init_data.sql
delete mode 100644 src/real_data.sql
diff --git a/.env.example b/.env.example
deleted file mode 100644
index 49204df..0000000
--- a/.env.example
+++ /dev/null
@@ -1,25 +0,0 @@
-# 数据库配置
-MYSQL_HOST=localhost
-MYSQL_USER=root
-MYSQL_PASSWORD=your_password_here
-MYSQL_DATABASE=equipment_cost_db
-
-# 服务配置
-PORT=5001
-DEBUG=False
-
-# 日志配置
-LOG_LEVEL=INFO
-LOG_DIR=logs
-
-# 模型配置
-MODEL_DIR=models
-DATA_DIR=data
-
-# 安全配置
-SECRET_KEY=your_secret_key_here
-ALLOWED_HOSTS=localhost,127.0.0.1
-
-# 其他配置
-UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
-ALLOWED_FILE_TYPES=.xlsx,.xls
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6c98c42..468c727 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,42 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
.DS_Store
+Thumbs.db
+
node_modules
/dist
/models
@@ -9,6 +47,7 @@ node_modules
# local env files
.env.local
.env.*.local
+.venv
# Log files
npm-debug.log*
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..b6d8b76
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.11.8
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..1def3fa
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+# MIT License
+
+Copyright (c) 2024 Your Name or Your Organization
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index b04ef07..1ecc6d2 100644
--- a/README.md
+++ b/README.md
@@ -21,3 +21,57 @@ collation-server=utf8mb4_unicode_ci
[client]
default-character-set=utf8mb4
```
+
+## 环境配置
+
+本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。
+
+### 使用脚本自动配置(推荐)
+
+Unix/macOS:
+
+```bash
+chmod +x scripts/setup_env.sh
+./scripts/setup_env.sh
+```
+
+Windows (PowerShell):
+
+```powershell
+Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
+.\scripts\setup_env.ps1
+```
+
+### 手动配置
+
+1. 安装 pyenv
+2. 安装 Python 3.11.8:
+
+ ```bash
+ pyenv install 3.11.8
+ ```
+
+3. 设置本地 Python 版本:
+
+ ```bash
+ pyenv local 3.11.8
+ ```
+
+4. 创建虚拟环境:
+
+ ```bash
+ python -m venv .venv
+ ```
+
+5. 激活虚拟环境:
+
+ ```bash
+ source .venv/bin/activate # Unix
+ .venv\Scripts\activate # Windows
+ ```
+
+6. 安装依赖:
+
+ ```bash
+ pip install -e ".[dev]"
+ ```
diff --git a/config.py b/config.py
index ec92cd1..fa09fcf 100644
--- a/config.py
+++ b/config.py
@@ -1,32 +1,103 @@
import os
-import secrets
-# 数据库配置
-DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
+class Config:
+ """配置类"""
+ # 数据库配置
+ MYSQL_HOST = 'localhost'
+ MYSQL_USER = 'root'
+ MYSQL_PASSWORD = '123456'
+ MYSQL_DB = 'equipment_cost_db'
+
+ # Flask配置
+ FLASK_HOST = '0.0.0.0'
+ FLASK_PORT = 5001
+ FLASK_DEBUG = True
+
+ # 目录配置
+ MODEL_DIR = 'models'
+ DATA_DIR = 'data'
+ LOG_DIR = 'logs'
+ UPLOAD_DIR = 'uploads'
+ TEMPLATE_DIR = 'templates'
+
+ # 文件上传配置
+ ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
+ MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
+
+ # API配置
+ API_VERSION = 'v1'
+ API_PREFIX = f'/api/{API_VERSION}'
+
+ # 日志配置
+ LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ LOG_LEVEL = 'INFO'
+ LOG_FILE = os.path.join(LOG_DIR, 'app.log')
+ LOG_MAX_SIZE = 10 * 1024 * 1024 # 10MB
+ LOG_BACKUP_COUNT = 5
+
+ # PyTorch配置
+ DEVICE = 'cpu' # 或 'cuda' 如果要使用 GPU
+ BATCH_SIZE = 32
+ LEARNING_RATE = 0.001
+ NUM_EPOCHS = 100
+
+ # 模型训练配置
+ TRAIN_TEST_SPLIT = 0.2
+ RANDOM_SEED = 42
+ EARLY_STOPPING_PATIENCE = 10
+ MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')
+
+ # 缓存配置
+ CACHE_TYPE = 'simple'
+ CACHE_DEFAULT_TIMEOUT = 300
+
+ # 安全配置
+ SECRET_KEY = 'your-secret-key-here'
+ JWT_SECRET_KEY = 'your-jwt-secret-key-here'
+ JWT_ACCESS_TOKEN_EXPIRES = 3600 # 1小时
+
+ # 跨域配置
+ CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']
+
+ # 数据验证配置
+ MAX_EQUIPMENT_NAME_LENGTH = 100
+ MAX_MANUFACTURER_NAME_LENGTH = 100
+
+ @classmethod
+ def init_app(cls, app):
+ """初始化应用配置"""
+ # 创建必要的目录
+ for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
+ cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]:
+ os.makedirs(directory, exist_ok=True)
+
+ # 配置日志
+ import logging
+ from logging.handlers import RotatingFileHandler
+
+ formatter = logging.Formatter(cls.LOG_FORMAT)
+ file_handler = RotatingFileHandler(
+ cls.LOG_FILE,
+ maxBytes=cls.LOG_MAX_SIZE,
+ backupCount=cls.LOG_BACKUP_COUNT
+ )
+ file_handler.setFormatter(formatter)
+ file_handler.setLevel(cls.LOG_LEVEL)
+
+ app.logger.addHandler(file_handler)
+ app.logger.setLevel(cls.LOG_LEVEL)
+
+ # 配置上传目录
+ app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
+ app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH
+
+ # 配置跨域
+ from flask_cors import CORS
+ CORS(app, resources={
+ r"/api/*": {"origins": cls.CORS_ORIGINS}
+ })
+
+ return app
-# 安全密钥配置(自动生成随机密钥)
-SECRET_KEY = secrets.token_hex(16)
-
-# 环境配置
-DEBUG = False
-ENV = 'production'
-
-# 文件上传配置
-UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
-ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
-MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
-
-# API配置
-API_VERSION = 'v1'
-API_PREFIX = f'/api/{API_VERSION}'
-
-# 跨域配置
-CORS_ORIGINS = [
- "http://localhost:8080",
- "http://127.0.0.1:8080",
-]
-
-# 日志配置
-LOG_LEVEL = 'DEBUG'
-LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
\ No newline at end of file
+# 创建配置实例
+config = Config()
\ No newline at end of file
diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue
index 9a77831..ea7c201 100644
--- a/frontend/src/views/AnalysisPage.vue
+++ b/frontend/src/views/AnalysisPage.vue
@@ -103,7 +103,31 @@
+
+
+ 制导性能分析
+
+
+
+ 生产商分析
+
+
+
+ 生产商地区分布
+
+
+
+ 生产商综合评分
+
@@ -131,6 +155,10 @@ const newFeatureChartRef = ref(null)
const engineChartRef = ref(null)
const fireChartRef = ref(null)
const mobilityChartRef = ref(null)
+const manufacturerChartRef = ref(null)
+const regionChartRef = ref(null)
+const scoreChartRef = ref(null)
+const guidanceChartRef = ref(null)
// 图表实例引用
const importanceChart = ref(null)
@@ -139,6 +167,10 @@ const newFeatureChart = ref(null)
const engineChart = ref(null)
const fireChart = ref(null)
const mobilityChart = ref(null)
+const manufacturerChart = ref(null)
+const regionChart = ref(null)
+const scoreChart = ref(null)
+const guidanceChart = ref(null)
// 监听分析结果变化
watch(() => analysisResult.value, async (newResult) => {
@@ -236,69 +268,28 @@ const startAnalysis = async () => {
analyzing.value = true
try {
- // 打印请求参数
- console.log('Analysis request params:', {
- dataset_id: analysisForm.value.dataset_id,
- equipment_type: analysisForm.value.equipment_type
- })
-
- const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
+ // 调用特征分析接口
+ const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
- // 打印原始响应数据
- console.log('Raw API response:', response)
- console.log('Response data type:', typeof response.data)
- console.log('Response data:', response.data)
-
- // 检查响应数据的结构
- if (!response.data) {
- throw new Error('API返回的数据为空')
- }
-
- // 确保数据正确赋值
- analysisResult.value = response.data
-
- // 验证数赋值是否成功
- console.log('Analysis result after assignment:', {
- value: analysisResult.value,
- important_features: analysisResult.value?.important_features,
- correlation_analysis: analysisResult.value?.correlation_analysis,
- equipment_names: analysisResult.value?.equipment_names,
- length_width_ratio: analysisResult.value?.length_width_ratio
+ // 调用生产商分析接口
+ const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, {
+ dataset_id: analysisForm.value.dataset_id
})
-
- // 如果是巡飞弹类型,检查特定数据
- if (analysisForm.value.equipment_type === '巡飞弹') {
- const missileData = {
- equipment_names: analysisResult.value?.equipment_names || [],
- length_width_ratio: analysisResult.value?.length_width_ratio || [],
- engine_power_kw: analysisResult.value?.engine_power_kw || [],
- guidance_system_score: analysisResult.value?.guidance_system_score || [],
- warhead_power_score: analysisResult.value?.warhead_power_score || []
- }
-
- console.log('Missile specific data:', missileData)
-
- // 验证数据完整性
- const missingFields = Object.entries(missileData)
- .filter(([key, value]) => !Array.isArray(value) || value.length === 0)
- .map(([key]) => key)
-
- if (missingFields.length > 0) {
- console.warn('Missing or empty missile data fields:', missingFields)
- ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
- }
+
+ // 合并两个接口的结果
+ analysisResult.value = {
+ ...featureResponse.data,
+ ...manufacturerResponse.data
}
-
+
+ // 验证数据
+ console.log('Combined analysis result:', analysisResult.value)
+
} catch (error) {
console.error('Analysis error:', error)
- console.error('Error details:', {
- message: error.message,
- response: error.response?.data,
- status: error.response?.status
- })
- ElMessage.error(error.message || '特征析失败')
+ ElMessage.error(error.message || '分析失败')
} finally {
analyzing.value = false
}
@@ -329,6 +320,18 @@ const createResizeHandler = () => {
if (mobilityChart.value && !mobilityChart.value.isDisposed()) {
mobilityChart.value.resize()
}
+ if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) {
+ manufacturerChart.value.resize()
+ }
+ if (regionChart.value && !regionChart.value.isDisposed()) {
+ regionChart.value.resize()
+ }
+ if (scoreChart.value && !scoreChart.value.isDisposed()) {
+ scoreChart.value.resize()
+ }
+ if (guidanceChart.value && !guidanceChart.value.isDisposed()) {
+ guidanceChart.value.resize()
+ }
} catch (error) {
console.error('Error in resize handler:', error)
}
@@ -360,7 +363,7 @@ onUnmounted(() => {
// 销毁所有图表实例
[importanceChart, correlationChart, newFeatureChart, engineChart,
- fireChart, mobilityChart].forEach(chart => {
+ fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
try {
chart.value.dispose()
@@ -384,7 +387,7 @@ const renderCharts = () => {
try {
// 先销毁所有现有的图表实例
[importanceChart, correlationChart, newFeatureChart, engineChart,
- fireChart, mobilityChart].forEach(chart => {
+ fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
chart.value.dispose()
chart.value = null
@@ -899,6 +902,156 @@ const renderCharts = () => {
mobilityChart.value.setOption(mobilityOption, { notMerge: true })
}
+ // 渲染生产商分析图表
+ if (manufacturerChartRef.value) {
+ manufacturerChart.value = echarts.init(manufacturerChartRef.value)
+ const manufacturerOption = {
+ title: { text: '生产商特征影响分析' },
+ tooltip: {
+ trigger: 'axis',
+ axisPointer: { type: 'shadow' }
+ },
+ legend: {
+ data: ['技术水平', '规模水平', '供应链水平', '综合得分']
+ },
+ xAxis: {
+ type: 'category',
+ data: analysisResult.value.manufacturer_names || []
+ },
+ yAxis: {
+ type: 'value',
+ name: '评分',
+ min: 0,
+ max: 10
+ },
+ series: [
+ {
+ name: '技术水平',
+ type: 'bar',
+ data: analysisResult.value.manufacturer_tech_levels || []
+ },
+ {
+ name: '规模水平',
+ type: 'bar',
+ data: analysisResult.value.manufacturer_scale_levels || []
+ },
+ {
+ name: '供应链水平',
+ type: 'bar',
+ data: analysisResult.value.manufacturer_supply_chain_levels || []
+ },
+ {
+ name: '综合得分',
+ type: 'line',
+ data: analysisResult.value.manufacturer_composite_scores || []
+ }
+ ]
+ }
+ manufacturerChart.value.setOption(manufacturerOption)
+ }
+
+ // 渲染地区分布图表
+ if (regionChartRef.value) {
+ regionChart.value = echarts.init(regionChartRef.value)
+ const regionOption = {
+ title: { text: '生产商地区分布' },
+ tooltip: {
+ trigger: 'item',
+ formatter: '{b}: {c} ({d}%)'
+ },
+ series: [
+ {
+ type: 'pie',
+ radius: '65%',
+ data: analysisResult.value.region_distribution || [],
+ emphasis: {
+ itemStyle: {
+ shadowBlur: 10,
+ shadowOffsetX: 0,
+ shadowColor: 'rgba(0, 0, 0, 0.5)'
+ }
+ }
+ }
+ ]
+ }
+ regionChart.value.setOption(regionOption)
+ }
+
+ // 渲染综合评分图表
+ if (scoreChartRef.value) {
+ scoreChart.value = echarts.init(scoreChartRef.value)
+ const scoreOption = {
+ title: { text: '生产商综合评分雷达图' },
+ tooltip: {},
+ radar: {
+ indicator: [
+ { name: '技术水平', max: 10 },
+ { name: '规模水平', max: 10 },
+ { name: '供应链水平', max: 10 },
+ { name: '区域系数', max: 1.5 },
+ { name: '综合得分', max: 10 }
+ ]
+ },
+ series: [
+ {
+ type: 'radar',
+ data: analysisResult.value.manufacturer_scores || []
+ }
+ ]
+ }
+ scoreChart.value.setOption(scoreOption)
+ }
+
+ // 渲染制导性能分析图表
+ if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') {
+ guidanceChart.value = echarts.init(guidanceChartRef.value)
+ const guidanceOption = {
+ title: { text: '制导性能分析' },
+ tooltip: {
+ trigger: 'axis',
+ axisPointer: { type: 'cross' }
+ },
+ legend: {
+ data: ['制导精度(m)', '数据链距离(km)', '制导系统评分']
+ },
+ xAxis: {
+ type: 'category',
+ data: analysisResult.value.equipment_names || []
+ },
+ yAxis: [
+ {
+ type: 'value',
+ name: '制导精度(m)',
+ position: 'left'
+ },
+ {
+ type: 'value',
+ name: '距离(km)',
+ position: 'right'
+ }
+ ],
+ series: [
+ {
+ name: '制导精度(m)',
+ type: 'bar',
+ data: analysisResult.value.guidance_accuracy_m || []
+ },
+ {
+ name: '数据链距离(km)',
+ type: 'line',
+ yAxisIndex: 1,
+ data: analysisResult.value.datalink_range_km || []
+ },
+ {
+ name: '制导系统评分',
+ type: 'line',
+ data: analysisResult.value.guidance_system_score || []
+ }
+ ]
+ }
+ guidanceChart.value.setOption(guidanceOption)
+ }
+
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error in chart rendering:', error)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..7e6e325
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,59 @@
+[project]
+name = "cost-prediction"
+version = "0.1.0"
+description = "装备成本预测系统"
+requires-python = ">=3.9,<3.12"
+readme = "README.md"
+license = {file = "LICENSE"}
+
+dependencies = [
+ # Web框架
+ "flask>=3.1.0",
+ "flask-cors>=5.0.0",
+
+ # 数据库
+ "sqlalchemy>=2.0.36",
+ "pymysql>=1.1.1",
+ "cryptography>=43.0.0",
+ "mysql-connector-python>=8.0.0",
+
+ # 数据处理
+ "numpy>=1.26.0,<2.0.0",
+ "pandas>=2.2.0",
+
+ # 机器学习
+ "scikit-learn>=1.5.2",
+ "torch==2.5.1",
+ "torchvision==0.20.1",
+ "torchaudio==2.5.1",
+
+ # 工具
+ "openpyxl>=3.1.5", # Excel支持
+ "python-dotenv>=1.0.0", # 环境变量
+]
+
+[project.optional-dependencies]
+dev = [
+ # 测试工具
+ "pytest>=7.0",
+ "black>=22.0", # 代码格式化
+ "mypy>=1.0", # 类型检查
+]
+
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+
+[tool.black]
+line-length = 88
+target-version = ["py39", "py310", "py311"]
+
+[tool.mypy]
+python_version = "3.11"
+warn_return_any = true
+warn_unused_configs = true
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index a3b4ab7..f168f07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,12 +3,11 @@ flask-cors>=5.0.0
sqlalchemy>=2.0.36
pymysql>=1.1.1
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
-numpy>=2.0.2
-pandas>=2.2.3
-
-urllib3>=2.2.3
-openpyxl>=3.1.5 # 用于读取 .xlsx 文件
-xlrd>=2.0.1 # 用于读取 .xls 文件
+mysql-connector-python>=8.0.0 # 添加这行
+numpy>=1.26.0,<2.0.0
+pandas>=2.2.0
scikit-learn>=1.5.2
-tensorflow>=2.18.0
\ No newline at end of file
+
+openpyxl>=3.1.5 # 用于读取 .xlsx 文件
+python-dotenv>=1.0.0 # 环境变量
\ No newline at end of file
diff --git a/run.py b/run.py
index 5bca7f6..4a56ece 100644
--- a/run.py
+++ b/run.py
@@ -1,13 +1,33 @@
-from src.app import create_app
-import logging
+from src import create_app
+from src.logger import setup_logger
+from config import config
+import os
-# 创建应用实例
-app = create_app()
+logger = setup_logger(__name__)
+
+def main():
+ try:
+ # 创建必要的目录
+ os.makedirs(config.MODEL_DIR, exist_ok=True)
+ os.makedirs(config.LOG_DIR, exist_ok=True)
+ os.makedirs(config.DATA_DIR, exist_ok=True)
+
+ # 创建并运行应用
+ app = create_app()
+
+ logger.info(f"Starting server in {'debug' if config.FLASK_DEBUG else 'production'} mode")
+ logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
+
+ app.run(
+ host=config.FLASK_HOST,
+ port=config.FLASK_PORT,
+ debug=config.FLASK_DEBUG
+ )
+
+ except Exception as e:
+ logger.error(f"Error starting application: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ raise
if __name__ == '__main__':
- # 设置日志
- logging.basicConfig(level=logging.INFO)
- logging.info('=== Server Starting ===')
- logging.info('Initializing directories...')
-
- app.run(host='0.0.0.0', port=5001, debug=True)
\ No newline at end of file
+ main()
\ No newline at end of file
diff --git a/scripts/setup_env.ps1 b/scripts/setup_env.ps1
new file mode 100644
index 0000000..2807d0f
--- /dev/null
+++ b/scripts/setup_env.ps1
@@ -0,0 +1,121 @@
+# 设置错误操作首选项
+$ErrorActionPreference = "Stop"
+
+# 检查管理员权限
+$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
+if (-not $isAdmin) {
+ Write-Warning "建议使用管理员权限运行此脚本"
+ Start-Sleep -Seconds 3
+}
+
+# 检查 pyenv-win 是否安装
+if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) {
+ Write-Host "pyenv not found. Installing..."
+ try {
+ # 下载并安装 pyenv-win
+ Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1"
+ & ./install-pyenv-win.ps1
+
+ # 添加环境变量
+ $env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win"
+ $env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path"
+
+ # 刷新环境变量
+ $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
+ }
+ catch {
+ Write-Error "Failed to install pyenv: $_"
+ exit 1
+ }
+}
+
+try {
+ # 安装指定版本的 Python
+ Write-Host "Installing Python 3.11.8..."
+ pyenv install 3.11.8
+ if ($LASTEXITCODE -ne 0) {
+ throw "Failed to install Python 3.11.8"
+ }
+
+ # 设置本地 Python 版本
+ Write-Host "Setting local Python version..."
+ pyenv local 3.11.8
+ if ($LASTEXITCODE -ne 0) {
+ throw "Failed to set local Python version"
+ }
+
+ # 验证 Python 版本
+ $pythonVersion = python -V
+ if (-not $pythonVersion.Contains("3.11.8")) {
+ throw "Wrong Python version: $pythonVersion"
+ }
+ Write-Host "Using Python version: $pythonVersion"
+
+ # 创建虚拟环境
+ Write-Host "Creating virtual environment..."
+ python -m venv .venv
+
+ # 激活虚拟环境
+ Write-Host "Activating virtual environment..."
+ .\.venv\Scripts\Activate.ps1
+
+ # 升级 pip 和构建工具
+ Write-Host "Upgrading pip and build tools..."
+ python -m pip install --upgrade pip setuptools wheel
+
+ # 分步安装依赖以确保正确的顺序和版本
+ Write-Host "Installing database dependencies..."
+ pip install mysql-connector-python==8.0.33
+
+ Write-Host "Installing PyTorch and related packages..."
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
+
+ Write-Host "Installing basic dependencies..."
+ pip install numpy==1.26.4 pandas==2.2.1
+
+ Write-Host "Installing machine learning packages..."
+ pip install scikit-learn==1.5.2
+
+ # 安装开发依赖
+ Write-Host "Installing development dependencies..."
+ pip install -e ".[dev]"
+ if ($LASTEXITCODE -ne 0) {
+ Write-Warning "Failed to install development dependencies. Installing core package..."
+ pip install -e .
+ }
+
+ # 验证安装
+ Write-Host "Verifying installations..."
+ python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
+ python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
+ python -c "import pandas; print(f'Pandas version: {pandas.__version__}')"
+ python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')"
+
+ Write-Host "Environment setup complete!" -ForegroundColor Green
+}
+catch {
+ Write-Error "An error occurred: $_"
+ exit 1
+}
+finally {
+ # 清理临时文件
+ if (Test-Path "./install-pyenv-win.ps1") {
+ Remove-Item "./install-pyenv-win.ps1"
+ }
+}
+
+# 显示使用说明
+Write-Host @"
+
+环境设置完成!使用说明:
+1. 虚拟环境已激活,命令提示符前应该显示 (.venv)
+2. 要退出虚拟环境,运行: deactivate
+3. 要重新激活虚拟环境,运行: .\.venv\Scripts\Activate.ps1
+4. 项目依赖已安装,可以开始开发了
+
+如果遇到问题,请检查:
+- Python 版本: python -V
+- PyTorch 安装: python -c "import torch; print(torch.__version__)"
+- 虚拟环境状态: 确保看到 (.venv) 前缀
+
+"@ -ForegroundColor Cyan
\ No newline at end of file
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh
new file mode 100755
index 0000000..159a2db
--- /dev/null
+++ b/scripts/setup_env.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# 检查 pyenv 是否安装
+if ! command -v pyenv &> /dev/null; then
+ echo "pyenv not found. Installing..."
+ if [[ "$OSTYPE" == "darwin"* ]]; then
+ brew install pyenv
+ else
+ curl https://pyenv.run | bash
+ fi
+fi
+
+# 安装指定版本的 Python
+pyenv install 3.11.8 || true
+
+# 设置本地 Python 版本
+pyenv local 3.11.8
+
+# 确保使用正确的 Python 版本
+eval "$(pyenv init -)"
+pyenv shell 3.11.8
+
+# 验证 Python 版本
+python_version=$(python -V 2>&1)
+if [[ $python_version != *"3.11.8"* ]]; then
+ echo "Error: Wrong Python version: $python_version"
+ echo "Please ensure pyenv is properly configured in your shell"
+ exit 1
+fi
+
+# 创建虚拟环境
+python -m venv .venv
+
+# 激活虚拟环境
+source .venv/bin/activate
+
+# 升级 pip 和构建工具
+python -m pip install --upgrade pip setuptools wheel
+
+# 分步安装依赖以确保正确的顺序和版本
+echo "Installing database dependencies..."
+pip install mysql-connector-python==8.0.33
+
+echo "Installing PyTorch and related packages..."
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
+
+echo "Installing basic dependencies..."
+pip install numpy==1.26.4 pandas==2.2.1
+
+echo "Installing machine learning packages..."
+pip install scikit-learn==1.5.2
+
+# 安装开发依赖
+if ! pip install -e ".[dev]"; then
+ echo "Warning: Failed to install development dependencies. Installing core package..."
+ pip install -e .
+fi
+
+# 验证安装
+echo "Verifying Python version..."
+python --version
+
+echo "Verifying PyTorch installation..."
+python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
+
+echo "Environment setup complete!"
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
index 497b4a4..4ea24f5 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1 +1,3 @@
-# 这个文件可以为空,但必须存在
+from .app import create_app
+
+__all__ = ['create_app']
diff --git a/src/app.py b/src/app.py
index 037fb79..e9fcbea 100644
--- a/src/app.py
+++ b/src/app.py
@@ -2,49 +2,35 @@ from flask import Flask
from flask_cors import CORS
from .routes import api_bp
from .logger import setup_logger
+from config import config
import os
-# 获取logger
logger = setup_logger(__name__)
def create_app():
- """
- 创建并配置Flask应用
- """
+ """创建并配置 Flask 应用"""
try:
- # 创建必要的目录
- os.makedirs('logs', exist_ok=True)
- os.makedirs('data', exist_ok=True)
- os.makedirs('models', exist_ok=True)
-
- logger.info("=== Server Starting ===")
- logger.info("Initializing directories...")
-
- # 创建Flask应用
app = Flask(__name__)
-
- # 配置CORS
CORS(app)
- logger.info("CORS enabled")
- # 注册API蓝图
+ # 配置数据库连接
+ app.config['MYSQL_HOST'] = config.MYSQL_HOST
+ app.config['MYSQL_USER'] = config.MYSQL_USER
+ app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
+ app.config['MYSQL_DB'] = config.MYSQL_DB
+
+ # 注册路由
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")
- # 配置数据库连接
- app.config['MYSQL_HOST'] = 'localhost'
- app.config['MYSQL_USER'] = 'root'
- app.config['MYSQL_PASSWORD'] = '123456'
- app.config['MYSQL_DB'] = 'equipment_cost_db'
-
- logger.info("Starting server...")
+ # 记录配置信息
+ logger.info(f"Database: {app.config['MYSQL_DB']} on {app.config['MYSQL_HOST']}")
+ logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
+ logger.info(f"Debug mode: {config.FLASK_DEBUG}")
return app
except Exception as e:
- logger.error(f"Error creating app: {str(e)}")
- raise
-
-if __name__ == '__main__':
- app = create_app()
- app.run(host='localhost', port=5001)
\ No newline at end of file
+ logger.error(f"Error creating application: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ raise
\ No newline at end of file
diff --git a/src/cost_prediction.py b/src/cost_prediction.py
index 2c68e26..c4f00fc 100644
--- a/src/cost_prediction.py
+++ b/src/cost_prediction.py
@@ -1,15 +1,12 @@
import numpy as np
+import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
-import tensorflow as tf
from scipy import stats
-import joblib
-import os
-import pandas as pd
-from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection
+from src.feature_analysis import FeatureAnalysis
from .logger import setup_logger
logger = setup_logger(__name__)
@@ -21,37 +18,18 @@ class CostPredictor:
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
-
- # 添加 TensorFlow 配置
- tf.config.run_functions_eagerly(False) # 启用图执行模式
-
- # 创建预测函数
- @tf.function(reduce_retracing=True, jit_compile=True)
- def predict_fn(x):
- return self.model(x, training=False)
-
- self._predict_fn = predict_fn
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.load_model()
def load_model(self):
"""
- 加载预训练型和标准化器
+ 加载预训练模型和标准化器
"""
try:
- model_dir = 'models'
- os.makedirs(model_dir, exist_ok=True)
-
# 创建默认模型
self._create_default_model()
- # 创建预测函数
- @tf.function(reduce_retracing=True, jit_compile=True)
- def predict_fn(x):
- return self.model(x, training=False)
-
- self._predict_fn = predict_fn
-
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
@@ -60,28 +38,24 @@ class CostPredictor:
"""
创建默认模型并进行初始化训练
"""
- # 创建输入层
- inputs = tf.keras.Input(shape=(11,))
+ import torch.nn as nn
- # 创建隐藏层
- x = tf.keras.layers.Dense(64, activation='relu')(inputs)
- x = tf.keras.layers.Dense(32, activation='relu')(x)
-
- # 创建输出层
- outputs = tf.keras.layers.Dense(1)(x)
-
- # 创建模型
- self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
-
- # 编译模型
- self.model.compile(
- optimizer='adam',
- loss=tf.keras.losses.mean_squared_error,
- metrics=[tf.keras.metrics.mean_absolute_error]
- )
+ class DefaultModel(nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(input_size, 64),
+ nn.ReLU(),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, 1)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
# 创建示例数据
- example_data = pd.DataFrame({
+ example_features = {
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
@@ -93,59 +67,23 @@ class CostPredictor:
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
- })
+ }
+
+ # 转换为 tensor
+ X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
+ y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
# 训练标准化器
- self.scaler_X.fit(example_data)
- self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
+ self.scaler_X.fit(X.numpy())
+ self.scaler_y.fit(y.numpy())
- # 设置默认装备类型
- self.equipment_type = '火箭炮'
-
- def _create_example_data(self):
- """
- 创建示例数据来训练标准化器
- """
- # 火箭炮示例数据
- rocket_data = pd.DataFrame({
- 'length_m': [7.35, 10.2],
- 'width_m': [2.4, 2.8],
- 'height_m': [3.1, 3.2],
- 'weight_kg': [13700, 28500],
- 'max_range_km': [20.4, 70],
- 'firing_angle_horizontal': [102, 110],
- 'firing_angle_vertical': [55, 60],
- 'rocket_length_m': [2.87, 4.1],
- 'rocket_diameter_mm': [122, 220],
- 'rocket_weight_kg': [66.6, 150],
- 'rate_of_fire': [40, 60]
- })
-
- # 巡飞弹示例数据
- missile_data = pd.DataFrame({
- 'length_m': [1.3, 2.5],
- 'width_m': [0.23, 0.6],
- 'height_m': [0.23, 0.6],
- 'weight_kg': [12.5, 135],
- 'max_range_km': [40, 250],
- 'max_speed_kmh': [180, 185],
- 'cruise_speed_kmh': [100, 110],
- 'flight_time_min': [60, 120],
- 'folded_length_mm': [1300, 2500],
- 'folded_width_mm': [230, 600],
- 'folded_height_mm': [230, 600]
- })
-
- # 训练标准化器
- self.scaler_X.fit(rocket_data) # 使用火箭炮数据
- self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
-
- # 设置默认装备类型
+ # 创建模型
+ self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮'
def predict(self, data):
"""
- 使用训练好的最优模型进行预测
+ 使用训练好的模型进行预测
"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
@@ -158,20 +96,31 @@ class CostPredictor:
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
+ X = []
+ for feature in features:
+ value = data.get(feature, 0.0)
+ X.append(float(value))
+
+ # 转换为 tensor
+ X = torch.tensor([X], dtype=torch.float32).to(self.device)
# 预测
- y_pred = trainer.predict(X)
+ with torch.no_grad():
+ trainer.model.eval() # 设置为评估模式
+ y_pred = trainer.model(X)
+
+ # 转回 numpy
+ y_pred = y_pred.cpu().numpy()
# 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
+ confidence_interval = self._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
return {
'predicted_cost': float(y_pred[0]),
- 'model_type': model_type, # 返回使用的模型类型
+ 'model_type': model_type,
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
@@ -187,11 +136,10 @@ class CostPredictor:
计算预测值的置信区间
"""
try:
- # 使用预测值的20%作为标准差(增加不确定性)
+ # 使用预测值的20%作为标准差
std = abs(prediction) * 0.2
# 计算置信区间
- from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
@@ -213,130 +161,15 @@ class CostPredictor:
"""
模型评估
"""
+ # 确保输入是 numpy 数组
+ if torch.is_tensor(y_true):
+ y_true = y_true.cpu().numpy()
+ if torch.is_tensor(y_pred):
+ y_pred = y_pred.cpu().numpy()
+
return {
'mae': float(mean_absolute_error(y_true, y_pred)),
'mse': float(mean_squared_error(y_true, y_pred)),
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
'r2': float(r2_score(y_true, y_pred))
- }
-
- def predict_pls(self, data):
- """
- 使用 PLS 型预测成本
- """
- try:
- logger.info(f"Starting PLS prediction for {data.get('type')}")
- equipment_type = data.get('type')
-
- # 加载 PLS 模型
- trainer = ModelTrainer()
- if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
- raise ValueError(f"No trained PLS model found for {equipment_type}")
-
- # 准备特征数据
- features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- y_pred = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
-
- return {
- 'predicted_cost': float(y_pred[0]),
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- except Exception as e:
- logger.error(f"PLS prediction error: {str(e)}")
- raise
-
- def predict_all(self, data):
- """
- 使用所有可用模型进行预测
- """
- try:
- logger.info(f"Starting multi-model prediction for {data.get('type')}")
- equipment_type = data.get('type')
- results = {}
-
- # 1. 获取所有激活的模型
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT id, model_type, model_name, r2_score, mae, rmse
- FROM trained_models
- WHERE equipment_type = %s AND is_active = TRUE
- """, (equipment_type,))
- active_models = cursor.fetchall()
-
- if not active_models:
- raise ValueError(f"No active models found for {equipment_type}")
-
- # 2. 使用每个模型进行预测
- trainer = ModelTrainer()
- for model_info in active_models:
- try:
- # 加载特定模型
- if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
- logger.warning(f"Failed to load model: {model_info['model_name']}")
- continue
-
- # 准备特征数据
- features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- y_pred = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
-
- # 保存结果
- results[model_info['model_type']] = {
- 'predicted_cost': float(y_pred[0]),
- 'model_info': {
- 'name': model_info['model_name'],
- 'type': model_info['model_type'],
- 'r2_score': float(model_info['r2_score']),
- 'mae': float(model_info['mae']),
- 'rmse': float(model_info['rmse'])
- },
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- except Exception as e:
- logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
- continue
-
- if not results:
- raise ValueError("No successful predictions from any model")
-
- # 3. 计算综合预测结果
- all_predictions = [result['predicted_cost'] for result in results.values()]
- ensemble_prediction = float(np.mean(all_predictions))
- prediction_std = float(np.std(all_predictions))
-
- # 4. 返回所有结果
- return {
- 'individual_predictions': results,
- 'ensemble_prediction': {
- 'predicted_cost': ensemble_prediction,
- 'standard_deviation': prediction_std,
- 'confidence_interval': {
- 'lower': float(ensemble_prediction - 1.96 * prediction_std),
- 'upper': float(ensemble_prediction + 1.96 * prediction_std)
- }
- }
- }
-
- except Exception as e:
- logger.error(f"Error in multi-model prediction: {str(e)}")
- raise
\ No newline at end of file
+ }
\ No newline at end of file
diff --git a/src/data_preparation.py b/src/data_preparation.py
index 6380647..8b83330 100644
--- a/src/data_preparation.py
+++ b/src/data_preparation.py
@@ -1,29 +1,35 @@
from sklearn.preprocessing import StandardScaler
-from datetime import datetime
-import os
-import joblib
-import pandas as pd
import numpy as np
-from src.feature_analysis import FeatureAnalysis
-from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
-from xgboost import XGBRegressor
-from lightgbm import LGBMRegressor
-from sklearn.model_selection import cross_val_score, LeaveOneOut
-import json
+import torch
+from torch.utils.data import Dataset, DataLoader
import logging
-from src.database.db_connection import get_db_connection
-from sklearn.metrics import mean_absolute_error, mean_squared_error
+from src.feature_analysis import FeatureAnalysis
+from src.database import get_db_connection
from .logger import setup_logger
logger = setup_logger(__name__)
+class EquipmentDataset(Dataset):
+ """装备数据集类"""
+ def __init__(self, features, targets=None):
+ self.features = torch.FloatTensor(features)
+ self.targets = torch.FloatTensor(targets) if targets is not None else None
+
+ def __len__(self):
+ return len(self.features)
+
+ def __getitem__(self, idx):
+ if self.targets is not None:
+ return self.features[idx], self.targets[idx]
+ return self.features[idx]
+
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
- self.target_scaler = StandardScaler() # 添加目标值标准化器
+ self.target_scaler = StandardScaler()
- def prepare_training_data(self, equipment_data, equipment_type):
+ def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
"""
准备训练数据
"""
@@ -31,19 +37,24 @@ class DataPreparation:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
- # 如果输入已经是 numpy 数组,直接返回
+ # 如果输入已经是 numpy 数组,转换为 torch.Tensor
if isinstance(equipment_data, np.ndarray):
X = equipment_data
- logger.info(f"Input is already numpy array with shape: {X.shape}")
+ logger.info(f"Input is numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
+ # 转换为 PyTorch 数据集
+ dataset = EquipmentDataset(X)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
return {
- 'X': X,
+ 'dataloader': dataloader,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
- 'target_scaler': self.target_scaler
+ 'target_scaler': self.target_scaler,
+ 'raw_shape': X.shape
}
# 从原始数据中提取特征和目标值
@@ -51,27 +62,28 @@ class DataPreparation:
features = []
targets = []
- for item in equipment_data:
- # 提取特征值
- feature_values = []
- for name in feature_names:
- value = item.get(name)
- try:
- feature_values.append(float(value) if value is not None else 0.0)
- except (ValueError, TypeError):
- feature_values.append(0.0)
- features.append(feature_values)
+ # 获取数据库连接
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
- # 提取目标值(成本)
- try:
- cost = float(item['actual_cost'])
- if cost > 0: # 只使用正数成本值
- targets.append(cost)
- else:
- logger.warning(f"Skipping non-positive cost value: {cost}")
- except (ValueError, TypeError, KeyError):
- logger.error(f"Invalid cost value: {item.get('actual_cost')}")
- continue
+ for item in equipment_data:
+ # 获取该装备的生产商数据
+ manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor)
+
+ # 计算生产商特征
+ manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data)
+
+ # 合并装备特征和生产商特征
+ feature_values = []
+ for name in feature_names:
+ if name in manufacturer_features:
+ value = manufacturer_features[name]
+ else:
+ value = item.get(name)
+ feature_values.append(float(value) if value is not None else 0.0)
+
+ features.append(feature_values)
+ targets.append(float(item['actual_cost']))
# 转换为numpy数组
X = np.array(features, dtype=float)
@@ -85,25 +97,16 @@ class DataPreparation:
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
- # 记录标准化后的数据范围
- logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
- logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
-
- # 记录标准化器参数
- logger.info("Feature scaler params:")
- logger.info(f"Mean: {self.feature_scaler.mean_}")
- logger.info(f"Scale: {self.feature_scaler.scale_}")
-
- logger.info("Target scaler params:")
- logger.info(f"Mean: {self.target_scaler.mean_}")
- logger.info(f"Scale: {self.target_scaler.scale_}")
+ # 创建 PyTorch 数据集和数据加载器
+ dataset = EquipmentDataset(X_scaled, y_scaled)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return {
- 'X': X_scaled,
- 'y': y_scaled,
+ 'dataloader': dataloader,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
- 'target_scaler': self.target_scaler
+ 'target_scaler': self.target_scaler,
+ 'raw_shape': X.shape
}
except Exception as e:
diff --git a/src/feature_analysis.py b/src/feature_analysis.py
index 1b7f8e9..4a45ea7 100644
--- a/src/feature_analysis.py
+++ b/src/feature_analysis.py
@@ -15,9 +15,9 @@ class FeatureAnalysis:
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
- 'max_range_km': '最大射程(km)',
# 火箭炮特有参数
+ 'max_range_km': '最大射程(km)',
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
@@ -39,6 +39,7 @@ class FeatureAnalysis:
'terrain_adaptability_score': '地形适应性评分',
# 巡飞弹特有参数
+ 'max_range_km': '最大射程(km)',
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
@@ -57,7 +58,14 @@ class FeatureAnalysis:
'weight_range_ratio': '重量射程比',
'speed_weight_ratio': '速度重量比',
'guidance_system_score': '制导系统评分',
- 'warhead_power_score': '战斗部威力评分'
+ 'warhead_power_score': '战斗部威力评分',
+
+ # 添加生产商特征映射
+ 'manufacturer_tech_level': '生产商技术水平',
+ 'manufacturer_scale_level': '生产商规模水平',
+ 'manufacturer_supply_chain_level': '生产商供应链水平',
+ 'manufacturer_composite_score': '生产商综合得分',
+ 'manufacturer_region_factor': '生产商区域系数'
}
def get_equipment_specific_features(self, equipment_type):
@@ -121,6 +129,17 @@ class FeatureAnalysis:
'guidance_system_score',
'warhead_power_score'
])
+
+ # 添加生产商特征
+ manufacturer_features = [
+ 'manufacturer_tech_level',
+ 'manufacturer_scale_level',
+ 'manufacturer_supply_chain_level',
+ 'manufacturer_composite_score',
+ 'manufacturer_region_factor'
+ ]
+
+ numeric_features.extend(manufacturer_features)
return numeric_features
def analyze_features(self, features, target, feature_names):
@@ -234,4 +253,63 @@ class FeatureAnalysis:
except Exception as e:
logger.error(f"Error in analyze_features: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
- raise
\ No newline at end of file
+ raise
+
+ def calculate_manufacturer_features(self, manufacturer_data):
+ """计算生产商相关的特征"""
+ try:
+ # 确保所有必要的字段都存在,使用默认值处理缺失数据
+ tech_level = float(manufacturer_data.get('tech_level', 0))
+ scale_level = float(manufacturer_data.get('scale_level', 0))
+ supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0))
+ country = manufacturer_data.get('country', '未知')
+
+ # 计算综合得分
+ composite_score = (
+ tech_level * 0.4 + # 技术水平权重最高
+ scale_level * 0.3 + # 规模水平次之
+ supply_chain_level * 0.3 # 供应链水平
+ )
+
+ # 计算区域系数(基于不同地区的成本差异)
+ region_factors = {
+ '美国': 1.2,
+ '英国': 1.15,
+ '德国': 1.15,
+ '法国': 1.15,
+ '以色列': 1.1,
+ '中国': 0.8,
+ '俄罗斯': 0.85,
+ '韩国': 0.9,
+ '日本': 1.1
+ }
+
+ region_factor = region_factors.get(country, 1.0)
+
+ # 记录计算过程
+ logger.info(f"Manufacturer features calculation:")
+ logger.info(f"Tech level: {tech_level}")
+ logger.info(f"Scale level: {scale_level}")
+ logger.info(f"Supply chain level: {supply_chain_level}")
+ logger.info(f"Country: {country}")
+ logger.info(f"Composite score: {composite_score}")
+ logger.info(f"Region factor: {region_factor}")
+
+ return {
+ 'manufacturer_tech_level': tech_level,
+ 'manufacturer_scale_level': scale_level,
+ 'manufacturer_supply_chain_level': supply_chain_level,
+ 'manufacturer_composite_score': composite_score,
+ 'manufacturer_region_factor': region_factor
+ }
+
+ except Exception as e:
+ logger.error(f"Error calculating manufacturer features: {str(e)}")
+ # 返回默认值而不是抛出异常,确保分析过程可以继续
+ return {
+ 'manufacturer_tech_level': 0,
+ 'manufacturer_scale_level': 0,
+ 'manufacturer_supply_chain_level': 0,
+ 'manufacturer_composite_score': 0,
+ 'manufacturer_region_factor': 1.0
+ }
\ No newline at end of file
diff --git a/src/import_data.py b/src/import_data.py
index 03dfa36..57286c1 100644
--- a/src/import_data.py
+++ b/src/import_data.py
@@ -26,7 +26,7 @@ def import_training_data(excel_file):
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
- SELECT id FROM equipment
+ SELECT id FROM equipments
WHERE name = %s AND type = '火箭炮'
""", (row['名称'],))
@@ -37,7 +37,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
- INSERT INTO equipment (name, type, manufacturer)
+ INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (row['名称'], '火箭炮', row['制造商']))
@@ -116,7 +116,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
- INSERT INTO equipment (name, type, manufacturer)
+ INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (
row['名称'],
@@ -192,7 +192,7 @@ def import_training_data(excel_file):
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
- SELECT id FROM equipment WHERE name = %s
+ SELECT id FROM equipments WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
diff --git a/src/init_data.sql b/src/init_data.sql
deleted file mode 100644
index ee99db3..0000000
--- a/src/init_data.sql
+++ /dev/null
@@ -1,319 +0,0 @@
-/*
-这是用于开发和测试环境的示例数据。
-生产环境请使用系统的数据导入功能添加实际数据。
-
-主要用途:
-1. 提供开发测试数据
-2. 作为数据格式参考
-3. 用于系统功能验证
-*/
-
--- 插入装备基本信息
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-
--- 插入巡飞弹技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES (
- 1, -- 终结者巡飞弹
- 0.56,
- 0.15,
- 0.20,
- 2.72,
- 160.93,
- 96.56,
- 24,
- 15,
- '破片杀伤战斗部',
- '凭自身动力起飞',
- 560,
- 150,
- 200
-);
-
--- 插入火箭炮技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES (
- 2, -- 胜利-2火箭炮
- 10,
- 2.5,
- 3.34,
- 15000,
- 23
-);
-
--- 插入成本数据(示例数据)
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(1, 1000000), -- 终结者巡飞弹成本
-(2, 5000000); -- 胜利-2火箭炮成本
-
--- 插入更多巡飞弹变体数据用于训练
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-
--- 插入变体技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES
--- 终结者-A(稍大型号)
-(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
--- 终结者-B(稍小型号)
-(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
--- 终结者-C(标准型号的改进版)
-(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-
--- 插入变体成本数据
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(3, 1100000), -- 终结者-A成本(较高)
-(4, 900000), -- 终结者-B成本(较低)
-(5, 1050000); -- 终结者-C成本(中等)
-
--- 添加更多巡飞弹数据
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
-('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
-('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
-('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
-('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-
--- 添加它们的技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES
--- 哈比
-(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
--- 游荡者
-(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
--- 凤凰
-(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
--- 弹簧刀
-(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
--- 彩虹-4
-(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-
--- 添加成本数据
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(6, 800000), -- 哈比
-(7, 500000), -- 游荡者
-(8, 450000), -- 凤凰
-(9, 480000), -- 弹簧刀
-(10, 1500000); -- 彩虹-4
-
--- 火箭炮数据
-INSERT INTO equipment (name, type, manufacturer) VALUES
-('BM-21', '火箭炮', '俄罗斯'),
-('SR5', '火箭炮', '中国'),
-('HIMARS', '火箭炮', '美国'),
-('LAR-160', '火箭炮', '以色列'),
-('T-122', '火箭炮', '土耳其'),
-('RM-70', '火箭炮', '捷克'),
-('ASTROS II', '火箭炮', '巴西');
-
--- 火箭炮通用参数
-INSERT INTO common_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES
--- BM-21
-(1, 7.35, 2.4, 3.1, 13700, 20.4),
--- SR5
-(2, 10.2, 2.8, 3.2, 28500, 70),
--- HIMARS
-(3, 7.0, 2.4, 3.2, 16250, 70),
--- LAR-160
-(4, 6.7, 2.5, 2.8, 15000, 45),
--- T-122
-(5, 7.2, 2.5, 2.9, 18000, 40),
--- RM-70
-(6, 7.5, 2.5, 3.0, 17200, 20.3),
--- ASTROS II
-(7, 8.0, 2.7, 3.1, 24500, 90);
-
--- 火箭炮特有参数
-INSERT INTO rocket_artillery_params (
- equipment_id,
- firing_angle_horizontal,
- firing_angle_vertical,
- rocket_length_m,
- rocket_diameter_mm,
- rocket_weight_kg,
- rate_of_fire,
- combat_weight_kg,
- speed_kmh,
- min_range_km,
- mobility_type,
- structure_layout,
- engine_model,
- engine_params,
- power_hp,
- travel_range_km
-) VALUES
--- BM-21
-(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
--- SR5
-(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
--- HIMARS
-(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
--- LAR-160
-(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
--- T-122
-(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
--- RM-70
-(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
--- ASTROS II
-(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-
--- 巡飞弹数据
-INSERT INTO equipment (name, type, manufacturer) VALUES
-('Hero-120', '巡飞弹', '以色列'),
-('Switchblade 600', '巡飞弹', '美国'),
-('Warmate', '巡飞弹', '波兰'),
-('CH-901', '巡飞弹', '中国'),
-('HAROP', '巡飞弹', '以色列'),
-('Coyote', '巡飞弹', '美国'),
-('WS-43', '巡飞弹', '中国');
-
--- 巡飞弹通用参数
-INSERT INTO common_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES
--- Hero-120
-(8, 1.3, 0.23, 0.23, 12.5, 40),
--- Switchblade 600
-(9, 1.3, 0.22, 0.22, 15.0, 40),
--- Warmate
-(10, 1.1, 0.15, 0.15, 5.7, 15),
--- CH-901
-(11, 1.2, 0.18, 0.18, 9.0, 20),
--- HAROP
-(12, 2.5, 0.43, 0.43, 135, 1000),
--- Coyote
-(13, 0.9, 0.12, 0.12, 5.9, 20),
--- WS-43
-(14, 1.8, 0.35, 0.35, 20, 60);
-
--- 巡飞弹特有参数
-INSERT INTO loitering_munition_params (
- equipment_id,
- wingspan_m,
- warhead_weight_kg,
- max_speed_ms,
- cruise_speed_kmh,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm,
- power_system,
- guidance_system
-) VALUES
--- Hero-120
-(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
--- Switchblade 600
-(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
--- Warmate
-(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
--- CH-901
-(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
--- HAROP
-(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
--- Coyote
-(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
--- WS-43
-(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-
--- 插入成本数据(示例成本)
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
--- 火箭炮
-(1, 800000), -- BM-21
-(2, 4500000), -- SR5
-(3, 5500000), -- HIMARS
-(4, 3500000), -- LAR-160
-(5, 2800000), -- T-122
-(6, 1500000), -- RM-70
-(7, 4800000), -- ASTROS II
--- 巡飞弹
-(8, 150000), -- Hero-120
-(9, 180000), -- Switchblade 600
-(10, 80000), -- Warmate
-(11, 100000), -- CH-901
-(12, 850000), -- HAROP
-(13, 75000), -- Coyote
-(14, 120000); -- WS-43
-
--- 创建初始数据集
-INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
-('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
-('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
-('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
-('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-
--- 关联装备到数据集
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
--- 火箭炮训练集
-(1, 1), (1, 2), (1, 3), (1, 4),
--- 巡飞弹训练集
-(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
--- 火箭炮验证集
-(3, 5), (3, 6), (3, 7),
--- 巡飞弹验证集
-(4, 13), (4, 14);
\ No newline at end of file
diff --git a/src/loitering_munition_data.sql b/src/loitering_munition_data.sql
index 350dc0a..e14ca4e 100644
--- a/src/loitering_munition_data.sql
+++ b/src/loitering_munition_data.sql
@@ -26,15 +26,15 @@
*/
-- 插入装备基本信息
-INSERT INTO equipment (
+INSERT INTO equipments (
id, -- 装备ID
name, -- 装备名称
type, -- 装备类型
manufacturer -- 制造商
) VALUES
-(1, 'IAI Harop', '巡飞弹', '以色列'),
-(2, 'IAI Harpy', '巡飞弹', '以色列'),
-(3, 'IAI Mini Harpy', '巡飞弹', '以色列'),
+(1, 'IAI Harop', '巡飞弹', '以色列 IAI'),
+(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'),
+(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'),
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
@@ -65,11 +65,11 @@ INSERT INTO equipment (
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'),
-(34, 'Shahed-131', '巡飞弹', '伊朗'),
-(35, 'Shahed-131B', '巡飞弹', '伊朗'),
-(36, 'Shahed-136', '巡飞弹', '伊朗'),
-(37, 'Shahed-136B', '巡飞弹', '伊朗'),
-(38, 'Shahed-136C', '巡飞弹', '伊朗'),
+(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'),
+(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'),
+(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'),
+(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'),
+(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'),
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
@@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params (
(24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
-(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
+(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红���'),
(28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
(29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
(30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
@@ -338,7 +338,7 @@ INSERT INTO loitering_munition_params (
(77, 2.8, 40.0, 250, 200, 120, 180, 50.0, 5500, 90, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Light
(78, 3.2, 80.0, 230, 180, 150, 200, 100.0, 6000, 100, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Heavy
(79, 1.5, 3.5, 160, 140, 60, 50, 5.0, 3500, 25, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助'), -- Taifun
-(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
+(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电������', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
(81, 1.5, 3.0, 120, 100, 60, 40, 4.0, 3000, 20, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ES
(82, 2.0, 5.0, 140, 120, 90, 60, 6.0, 4000, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ER
(83, 2.5, 8.0, 160, 140, 120, 80, 10.0, 5000, 40, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -- HERO-XL
@@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(2, '巡飞弹验证集 2024', '包含20个巡飞弹型号,用于验证模型性能', '巡飞弹', '验证');
-- 训练集(80个型号)
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列(8/10)
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列
@@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列
-- 验证集(20个型号)
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列(2/10)
(2, 9), -- Hero-900
(2, 48), -- Rotem L
@@ -666,4 +666,43 @@ SET
WHEN l.max_range_km > 500 THEN 5000
WHEN l.max_range_km > 100 THEN 3000
ELSE 1500
- END;
\ No newline at end of file
+ END;
+
+
+
+-- 更新巡飞弹的制导精度
+UPDATE loitering_munition_params l
+SET guidance_accuracy_m =
+ CASE
+ -- 基础精度(根据制导系统类型)
+ WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0
+ WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0
+ WHEN guidance_system LIKE '%激光制导%' THEN 1.0
+ WHEN guidance_system LIKE '%红外制导%' THEN 2.0
+ WHEN guidance_system LIKE '%卫星制导%' THEN 2.5
+ ELSE 5.0
+ END *
+ -- 速度影响因子(速度越快,精度略微降低)
+ CASE
+ WHEN max_speed_ms > 200 THEN 1.2
+ WHEN max_speed_ms > 150 THEN 1.1
+ WHEN max_speed_ms > 100 THEN 1.0
+ ELSE 0.9
+ END *
+ -- 重量影响因子(重量越大,精度略微降低)
+ CASE
+ WHEN warhead_weight_kg > 100 THEN 1.2
+ WHEN warhead_weight_kg > 50 THEN 1.1
+ WHEN warhead_weight_kg > 20 THEN 1.0
+ ELSE 0.9
+ END *
+ -- 飞行高度影响因子(高度越高,精度略微降低)
+ CASE
+ WHEN ceiling_altitude_m > 5000 THEN 1.2
+ WHEN ceiling_altitude_m > 3000 THEN 1.1
+ WHEN ceiling_altitude_m > 1000 THEN 1.0
+ ELSE 0.9
+ END
+WHERE equipment_id IN (
+ SELECT id FROM equipments WHERE type = '巡飞弹'
+);
\ No newline at end of file
diff --git a/src/manufacturer_data.sql b/src/manufacturer_data.sql
index 01cf09b..f80321a 100644
--- a/src/manufacturer_data.sql
+++ b/src/manufacturer_data.sql
@@ -40,7 +40,7 @@ INSERT INTO manufacturers (
('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业
-- 俄罗斯供应商
-('俄罗斯', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
+('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商
('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商
@@ -72,7 +72,9 @@ INSERT INTO manufacturers (
('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业
-- 更新装备表中的供应商ID
-UPDATE equipment e
-SET manufacturer_id = m.id
-FROM manufacturers m
-WHERE e.manufacturer = m.name;
\ No newline at end of file
+UPDATE equipments e
+SET manufacturer_id = (
+ SELECT id
+ FROM manufacturers m
+ WHERE m.name = e.manufacturer
+);
\ No newline at end of file
diff --git a/src/model_trainer.py b/src/model_trainer.py
index 8792f90..9da7c39 100644
--- a/src/model_trainer.py
+++ b/src/model_trainer.py
@@ -1,282 +1,112 @@
import numpy as np
-import pandas as pd
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
-from sklearn.model_selection import cross_val_score
-from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
-from sklearn.impute import SimpleImputer
-import xgboost as xgb
-import lightgbm as lgb
+from sklearn.model_selection import train_test_split
import logging
-import joblib
import os
-from src.feature_analysis import FeatureAnalysis
from datetime import datetime
import json
+from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
-from src.data_preparation import DataPreparation
-from sklearn.cross_decomposition import PLSRegression
+from src.data_preparation import DataPreparation, EquipmentDataset
from .logger import setup_logger
logger = setup_logger(__name__)
+class CostPredictionModel(nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(input_size, 128),
+ nn.ReLU(),
+ nn.Dropout(0.3),
+ nn.Linear(128, 64),
+ nn.ReLU(),
+ nn.Dropout(0.2),
+ nn.Linear(64, 32),
+ nn.ReLU(),
+ nn.Linear(32, 1)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
class ModelTrainer:
def __init__(self):
- """
- 初始化 ModelTrainer
- """
- self.models = {
- 'xgboost': self._create_xgboost_model(),
- 'lightgbm': self._create_lightgbm_model(),
- 'gbm': self._create_gbm_model(),
- 'rf': self._create_rf_model(),
- 'pls': self._create_pls_model()
- }
- self.best_model = None
- self.imputer = SimpleImputer(strategy='mean')
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.model = None
self.feature_scaler = None
self.target_scaler = None
self.equipment_type = None
self.feature_analyzer = FeatureAnalysis()
- def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
- """
- 训练模型并返回评估结果
- """
+ def train_model(self, dataloader, epochs=100, learning_rate=0.001):
+ """训练模型"""
try:
- self.equipment_type = equipment_type
- logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
- logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
+ # 获取输入特征维度
+ sample_features, _ = next(iter(dataloader))
+ input_size = sample_features.shape[1]
- results = {}
- best_score = -float('inf')
- best_model_info = None
+ # 创建模型
+ self.model = CostPredictionModel(input_size).to(self.device)
+ criterion = nn.MSELoss()
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
- # 首先训练 PLS 模型
- logger.info("Training pls...")
- pls_model = self.models['pls']
- pls_model.fit(X_train, y_train)
- pls_metrics = self._calculate_metrics(
- pls_model,
- X_train, y_train,
- X_val, y_val
- )
- results['pls'] = pls_metrics
-
- # 训练其他机器学习模型
- for model_name in model_names:
- if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
- continue
+ # 训练循环
+ for epoch in range(epochs):
+ self.model.train()
+ total_loss = 0
+ for batch_features, batch_targets in dataloader:
+ # 移动数据到设备
+ batch_features = batch_features.to(self.device)
+ batch_targets = batch_targets.to(self.device)
- if model_name not in self.models:
- logger.warning(f"Unknown model: {model_name}")
- continue
+ # 前向传播
+ outputs = self.model(batch_features)
+ loss = criterion(outputs, batch_targets.view(-1, 1))
- logger.info(f"Training {model_name}...")
- model = self.models[model_name]
+ # 反向传播
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item()
- # 训练模型
- model.fit(X_train, y_train)
-
- # 计算评估指标
- metrics = self._calculate_metrics(
- model,
- X_train, y_train,
- X_val, y_val
- )
-
- results[model_name] = metrics
-
- # 更新最佳模型(只在机器学习模型中比较)
- if metrics['validation']['r2'] > best_score:
- best_score = metrics['validation']['r2']
- best_model_info = {
- 'type': model_name,
- 'r2': metrics['validation']['r2'],
- 'mae': metrics['validation']['mae'],
- 'rmse': metrics['validation']['rmse']
- }
- self.best_model = model
+ # 记录训练进度
+ if (epoch + 1) % 10 == 0:
+ avg_loss = total_loss / len(dataloader)
+ logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
- # 保存最佳模型和 PLS 模型
- if equipment_type and best_model_info:
- self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
-
- return {
- 'metrics': results,
- 'best_model': best_model_info
- }
+ return True
except Exception as e:
logger.error(f"Error in model training: {str(e)}")
raise
-
- def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
- """
- 计算模型评估指标
- """
- from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
-
- # 训练集评估
- train_pred = model.predict(X_train)
-
- # 如果使用了标准化,需要转换回原始范围
- if hasattr(self, 'target_scaler'):
- train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
- y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
- else:
- y_train_orig = y_train
-
- # 记录预测范围
- logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
- logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
-
- train_metrics = {
- 'r2': r2_score(y_train_orig, train_pred),
- 'mae': mean_absolute_error(y_train_orig, train_pred),
- 'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
- }
-
- # 验证集评估
- if X_val is not None and y_val is not None:
- val_pred = model.predict(X_val)
-
- # 如果使用了标准化,需要转换回原始范围
- if hasattr(self, 'target_scaler'):
- val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
- y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
- else:
- y_val_orig = y_val
-
- # 记录预测范围
- logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
- logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
-
- val_metrics = {
- 'r2': r2_score(y_val_orig, val_pred),
- 'mae': mean_absolute_error(y_val_orig, val_pred),
- 'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
- }
- else:
- # 使用交叉验证
- cv_scores = cross_val_score(model, X_train, y_train, cv=5)
- val_metrics = {
- 'r2': cv_scores.mean(),
- 'mae': None,
- 'rmse': None
- }
-
- return {
- 'train': train_metrics,
- 'validation': val_metrics
- }
- def _create_xgboost_model(self):
- """
- 创建 XGBoost 模型,增强正则化
- """
- return xgb.XGBRegressor(
- n_estimators=50, # 减少树的数量
- learning_rate=0.05, # 学习率
- max_depth=3, # 减小树的深
- min_child_weight=3, # 增加节点权重
- subsample=0.7, # 减小样本采样比例
- colsample_bytree=0.7, # 减小特征采样比例
- reg_alpha=0.1, # L1 正则化
- reg_lambda=1, # L2 正则化
- random_state=42
- )
-
- def _create_lightgbm_model(self):
- """
- 创建 LightGBM 模型,增强正则化
- """
- return lgb.LGBMRegressor(
- n_estimators=50,
- learning_rate=0.05,
- max_depth=3,
- num_leaves=7,
- min_data_in_leaf=3,
- min_sum_hessian_in_leaf=1e-3,
- subsample=0.7,
- colsample_bytree=0.7,
- reg_alpha=0.1,
- reg_lambda=1,
- random_state=42,
- verbose=-1
- )
-
- def _create_gbm_model(self):
- """
- 创建 GBM 模型,增强正则化以减轻过拟合
- """
- return GradientBoostingRegressor(
- n_estimators=100,
- learning_rate=0.1,
- max_depth=3,
- random_state=42,
- subsample=0.8,
- min_samples_split=3,
- min_samples_leaf=2
- )
-
- def _create_rf_model(self):
- """
- 创建随机森林模型,针对小样本数据调整参数
- """
- return RandomForestRegressor(
- n_estimators=100,
- max_depth=3,
- random_state=42,
- min_samples_split=3,
- min_samples_leaf=2
- )
-
- def _create_pls_model(self):
- """
- 创建 PLS 模型,优化参数配置
- """
- return PLSRegression(
- n_components=2, # 减少主成分数量,从5减到2
- scale=True, # 保持数据标准化
- max_iter=500, # 减少最大迭代次数,避免过拟合
- tol=1e-6 # 降低收敛精度,避免过拟合
- )
-
- def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
- """
- 保存最佳模型和 PLS 模型
- """
+ def save_model(self, equipment_type):
+ """保存模型"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
- # 1. 保存最佳机器学习模型
- model_path = f'{model_dir}/{equipment_type}_{timestamp}'
- if isinstance(self.best_model, xgb.XGBRegressor):
- self.best_model.save_model(f'{model_path}.json')
- model_format = 'json'
- else:
- joblib.dump(self.best_model, f'{model_path}.joblib')
- model_format = 'joblib'
-
- # 2. 保存 PLS 模型
- pls_model = self.models['pls']
- pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
- joblib.dump(pls_model, pls_path)
+ # 保存模型
+ model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth'
+ torch.save({
+ 'model_state_dict': self.model.state_dict(),
+ 'input_size': self.model.layers[0].in_features
+ }, model_path)
- # 3. 保存标准化器
- scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
- joblib.dump({
+ # 保存标准化器
+ scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth'
+ torch.save({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
- logger.info(f"Saved best model to {model_path}.{model_format}")
- logger.info(f"Saved PLS model to {pls_path}")
- logger.info(f"Saved scalers to {scaler_path}")
-
- # 4. 更新数据库中的模型记录
+ # 更新数据库
with get_db_connection() as conn:
cursor = conn.cursor()
@@ -287,420 +117,73 @@ class ModelTrainer:
WHERE equipment_type = %s
""", (equipment_type,))
- # 获取 PLS 模型的评估指标
- pls_metrics = self._calculate_metrics(
- self.models['pls'],
- X_train,
- y_train,
- X_val,
- y_val
- )
-
- # 保存最佳机器学习模型记录
- self.best_model.equipment_type = equipment_type # 设置装备类型
- ml_feature_importance = self._get_feature_importance(self.best_model)
-
+ # 保存新模型记录
cursor.execute("""
INSERT INTO trained_models (
- model_name, model_type, equipment_type, model_path, scaler_path,
- r2_score, mae, rmse, feature_importance, training_data_size,
- training_date, is_active, created_by
- ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
+ model_name, model_type, equipment_type, model_path,
+ scaler_path, training_date, is_active, created_by
+ ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
- f"{equipment_type}_{timestamp}", # model_name
- best_model_info['type'], # model_type
- equipment_type, # equipment_type
- f"{model_path}.{model_format}", # model_path
- scaler_path, # scaler_path
- best_model_info['r2'], # r2_score
- best_model_info['mae'], # mae
- best_model_info['rmse'], # rmse
- json.dumps(ml_feature_importance), # feature_importance
- len(X_train), # training_data_size
- 'system' # created_by
- ))
-
- # 保存 PLS 模型记录
- pls_model.equipment_type = equipment_type # 设置装备类型
- pls_feature_importance = self._get_feature_importance(pls_model)
-
- cursor.execute("""
- INSERT INTO trained_models (
- model_name, model_type, equipment_type, model_path, scaler_path,
- r2_score, mae, rmse, feature_importance, training_data_size,
- training_date, is_active, created_by
- ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
- """, (
- f"{equipment_type}_{timestamp}_pls", # model_name
- 'pls', # model_type
- equipment_type, # equipment_type
- pls_path, # model_path
- scaler_path, # scaler_path
- float(pls_metrics['validation']['r2']), # r2_score
- float(pls_metrics['validation']['mae']), # mae
- float(pls_metrics['validation']['rmse']), # rmse
- json.dumps(pls_feature_importance), # feature_importance
- len(X_train), # training_data_size
- 'system' # created_by
+ f"{equipment_type}_{timestamp}",
+ 'pytorch',
+ equipment_type,
+ model_path,
+ scaler_path,
+ 'system'
))
conn.commit()
-
- except Exception as e:
- logger.error(f"Error saving models: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- raise
-
- def load_model(self, equipment_type, model_type='ml'):
- """
- 加载已训练的模型
- """
- try:
- logger.info(f"Loading {model_type} model for {equipment_type}")
- # 从数据库获取激活的模型
+ return True
+
+ except Exception as e:
+ logger.error(f"Error saving model: {str(e)}")
+ return False
+
+ def load_model(self, equipment_type):
+ """加载模型"""
+ try:
+ # 从数据库获取最新的激活模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
-
- # 构建查询语句
- if model_type == 'pls':
- query = """
- SELECT * FROM trained_models
- WHERE equipment_type = %s
- AND model_type = 'pls'
- AND is_active = TRUE
- LIMIT 1
- """
- params = (equipment_type,)
- else:
- query = """
- SELECT * FROM trained_models
- WHERE equipment_type = %s
- AND model_type != 'pls'
- AND is_active = TRUE
- LIMIT 1
- """
- params = (equipment_type,)
-
- # 记录查询信息
- logger.info(f"Executing query: {query}")
- logger.info(f"Query params: {params}")
-
- cursor.execute(query, params)
+ cursor.execute("""
+ SELECT * FROM trained_models
+ WHERE equipment_type = %s AND is_active = TRUE
+ ORDER BY training_date DESC LIMIT 1
+ """, (equipment_type,))
model_record = cursor.fetchone()
- # 记录查询结果
- if model_record:
- logger.info(f"Found model record: {model_record}")
- else:
- logger.warning(f"No active model found for type {model_type}")
+ if not model_record:
return False
- # 检查文件是否存在
- logger.info(f"Checking model file: {model_record['model_path']}")
- logger.info(f"Checking scaler file: {model_record['scaler_path']}")
-
- if not os.path.exists(model_record['model_path']):
- logger.error(f"Model file not found: {model_record['model_path']}")
- raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
-
- if not os.path.exists(model_record['scaler_path']):
- logger.error(f"Scaler file not found: {model_record['scaler_path']}")
- raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
-
- # 加载模型文件
- logger.info(f"Loading model from {model_record['model_path']}")
- if model_type == 'pls':
- self.best_model = joblib.load(model_record['model_path'])
- logger.info("Loaded PLS model")
- else:
- if model_record['model_type'] == 'xgboost':
- self.best_model = xgb.XGBRegressor()
- self.best_model.load_model(model_record['model_path'])
- logger.info("Loaded XGBoost model")
- else:
- self.best_model = joblib.load(model_record['model_path'])
- logger.info(f"Loaded {model_record['model_type']} model")
+ # 加载模型
+ checkpoint = torch.load(model_record['model_path'])
+ input_size = checkpoint['input_size']
+ self.model = CostPredictionModel(input_size).to(self.device)
+ self.model.load_state_dict(checkpoint['model_state_dict'])
# 加载标准化器
- logger.info(f"Loading scalers from {model_record['scaler_path']}")
- scalers = joblib.load(model_record['scaler_path'])
+ scalers = torch.load(model_record['scaler_path'])
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
- logger.info("Loaded scalers successfully")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
- logger.error(f"Detailed traceback:", exc_info=True)
return False
-
- def get_missile_features(self):
- """获取巡飞弹的特征列表"""
- return [
- # 基本参数
- 'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km',
-
- # 性能参数 - 新增和修改的参数
- 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
- 'endurance_min', 'max_payload_kg', 'min_combat_radius_km',
-
- # 动力系统参数 - 新增参数
- 'engine_power_kw', 'engine_thrust_n',
-
- # 制导与控制参数 - 新增参数
- 'datalink_range_km', 'guidance_accuracy_m',
- 'min_altitude_m', 'max_altitude_m',
-
- # 特征工程参数 - 新增评分指标
- 'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
- 'guidance_system_score', 'warhead_power_score'
- ]
-
- def get_rocket_features(self):
- """获取火箭炮的特征列表"""
- return [
- # 基本参数
- 'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km',
-
- # 火箭炮特有参数 - 新增和修改的参数
- 'firing_angle_horizontal', 'firing_angle_vertical',
- 'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg',
- 'rate_of_fire', 'combat_weight_kg', 'speed_kmh',
- 'min_range_km', 'max_range_km', 'mobility_type', 'structure_layout',
- 'engine_model', 'power_hp', 'travel_range_km',
-
- # 特征工程参数 - 新增评分指标
- 'fire_density', 'range_ratio', 'mobility_score',
- 'combat_readiness_score', 'rocket_power_ratio',
- 'platform_efficiency', 'deployment_score',
- 'terrain_adaptability_score'
- ]
-
+
def predict(self, features):
- """使用加载的模型进行预测"""
+ """使用模型进行预测"""
try:
- if not self.best_model:
- raise ValueError("No model loaded")
-
- if not self.feature_scaler:
- raise ValueError("Feature scaler not loaded")
-
- if not self.target_scaler:
- raise ValueError("Target scaler not loaded")
-
- logger.info("Starting prediction")
- logger.info(f"Input features shape: {features.shape}")
- logger.info(f"Input features: \n{features}")
-
- # 获取正确的特征列表
- if self.equipment_type == '巡飞弹':
- feature_list = self.get_missile_features()
- logger.info(f"Using missile features: {feature_list}")
-
- # 确保特征顺序一致
- features_ordered = np.zeros((features.shape[0], len(feature_list)))
- for i, feature_name in enumerate(feature_list):
- if feature_name in features:
- features_ordered[:, i] = features[feature_name]
- features = features_ordered
-
- # 处理缺失值
- features_filled = np.array(features, dtype=float)
- features_filled[np.isnan(features_filled)] = 0
- features_filled = np.nan_to_num(features_filled, 0)
-
- logger.info(f"Filled features: \n{features_filled}")
-
- # 标准化特征
- X = self.feature_scaler.transform(features_filled)
- logger.info(f"Transformed features shape: {X.shape}")
- logger.info(f"Transformed features: \n{X}")
-
- # 预测
- y_pred_scaled = self.best_model.predict(X)
- logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
- logger.info(f"Scaled prediction: {y_pred_scaled}")
-
- # 反标准化
- y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
- logger.info(f"Final prediction shape: {y_pred.shape}")
- logger.info(f"Final prediction: {y_pred}")
-
- return y_pred.ravel()
-
+ self.model.eval()
+ with torch.no_grad():
+ # 转换为tensor并移动到正确的设备
+ features_tensor = torch.FloatTensor(features).to(self.device)
+ # 进行预测
+ predictions = self.model(features_tensor)
+ # 移回CPU并转换为numpy数组
+ return predictions.cpu().numpy()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
- raise
-
- def _get_feature_importance(self, model):
- """
- 获取特征重要性
- """
- try:
- if not model:
- return {}
-
- # 获取征名称
- if self.equipment_type == '巡飞弹':
- feature_names = [
- # 基本参数
- 'length_m', 'width_m', 'height_m', 'weight_kg',
-
- # 性能参数
- 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
- 'endurance_min', 'max_range_km','max_payload_kg', 'ceiling_altitude_m',
- 'combat_radius_km',
-
- # 动力系统参数
- 'engine_power_kw', 'engine_thrust_n',
-
- # 制导与控制参数
- 'datalink_range_km', 'guidance_accuracy_m',
- 'min_altitude_m', 'max_altitude_m',
-
- # 特征工程参数
- 'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
- 'guidance_system_score', 'warhead_power_score'
- ]
- else:
- # 其他装备类型使用原有的特征获取逻辑
- feature_analyzer = FeatureAnalysis()
- feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
-
- # 获取特征重要性
- if hasattr(model, 'feature_importances_'):
- importances = model.feature_importances_
- elif hasattr(model, 'coef_'):
- if len(model.coef_.shape) > 1: # 如果是二维数组
- importances = np.abs(model.coef_[0]) # 取第一行
- else:
- importances = np.abs(model.coef_)
- else:
- return {}
-
- # 创建特征重要性字典
- importance_dict = {}
- for name, importance in zip(feature_names, importances):
- importance_dict[name] = float(importance) # 确保转换为 Python 标量
-
- # 按重要性降序排序
- sorted_dict = dict(sorted(
- importance_dict.items(),
- key=lambda x: x[1],
- reverse=True
- ))
-
- # 过滤掉重要性为0的特征
- return {k: v for k, v in sorted_dict.items() if v > 0}
-
- except Exception as e:
- logger.error(f"Error getting feature importance: {str(e)}")
- return {}
-
- def _calculate_confidence_interval(self, prediction, confidence=0.95):
- """
- 计算预测值的置信区间
- """
- try:
- # 使用预测值的20%作为标准差(增加不确定性)
- std = abs(prediction) * 0.2
-
- # 计算置信区间
- from scipy import stats
- interval = stats.norm.interval(confidence, loc=prediction, scale=std)
-
- # 确保区间值为正数且合理
- lower = max(1000, interval[0]) # 最小值设为1000元
- upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
-
- logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
-
- return [lower, upper]
-
- except Exception as e:
- logger.error(f"Error calculating confidence interval: {str(e)}")
- # 如果计算失败,返回基于20%的简单区间
- lower = max(1000, prediction * 0.8)
- upper = prediction * 1.2
- return [lower, upper]
-
- def get_model_type(self):
- """
- 获取当前模型的类型
- """
- if isinstance(self.best_model, xgb.XGBRegressor):
- return 'xgboost'
- elif isinstance(self.best_model, lgb.LGBMRegressor):
- return 'lightgbm'
- elif isinstance(self.best_model, GradientBoostingRegressor):
- return 'gbm'
- elif isinstance(self.best_model, RandomForestRegressor):
- return 'rf'
- else:
- return 'unknown'
-
- def _get_pls_feature_importance(self):
- """
- 获取 PLS 模型的特征重要性
- """
- try:
- if not self.models['pls']:
- return {}
-
- # 获取特征名称
- feature_analyzer = FeatureAnalysis()
- feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
-
- # 获取 PLS 模型的系数作为特征重要性
- pls_model = self.models['pls']
- if hasattr(pls_model, 'coef_'):
- # 使用绝对值作为重要性指标
- importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
- else:
- return {}
-
- # 创建特征重要性字典
- importance_dict = {}
- for name, importance in zip(feature_names, importances):
- importance_dict[name] = float(importance) # 确保转换为 Python 标量
-
- # 按重要性降序排序
- sorted_dict = dict(sorted(
- importance_dict.items(),
- key=lambda x: x[1],
- reverse=True
- ))
-
- # 过滤掉重要性为0的特征
- return {k: v for k, v in sorted_dict.items() if v > 0}
-
- except Exception as e:
- logger.error(f"Error getting PLS feature importance: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return {}
-
- def _preprocess_data(self, data):
- """数据预处理"""
- try:
- # 获取正确的特征列表
- if self.equipment_type == '巡飞弹':
- feature_list = self.get_missile_features()
- else:
- feature_list = self.get_rocket_features()
-
- logger.info(f"Using features: {feature_list}")
-
- # 处理缺失值
- features_filled = np.array(data, dtype=float)
- features_filled[np.isnan(features_filled)] = 0
- features_filled = np.nan_to_num(features_filled, 0)
-
- logger.info(f"Filled features: \n{features_filled}")
-
- return features_filled
-
- except Exception as e:
- logger.error(f"Error in data preprocessing: {str(e)}")
raise
\ No newline at end of file
diff --git a/src/real_data.sql b/src/real_data.sql
deleted file mode 100644
index d00ab59..0000000
--- a/src/real_data.sql
+++ /dev/null
@@ -1,485 +0,0 @@
--- 清空现有数据
-SET FOREIGN_KEY_CHECKS=0;
-TRUNCATE TABLE dataset_equipment;
-TRUNCATE TABLE datasets;
-TRUNCATE TABLE cost_data;
-TRUNCATE TABLE loitering_munition_params;
-TRUNCATE TABLE common_params;
-TRUNCATE TABLE equipment;
-SET FOREIGN_KEY_CHECKS=1;
-
--- 按系列插入装备数据,确保ID连续
--- 1. HAROP/Harpy 系列 (ID: 1-3)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(1, 'IAI Harop', '巡飞弹', '以色列'),
-(2, 'IAI Harpy', '巡飞弹', '以色列'),
-(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
-
--- 2. Hero 系列 (ID: 4-9)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
-(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
-(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
-(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
-(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
-(9, 'Hero-900', '巡飞弹', '以色列 UVision');
-
--- 3. Switchblade 系列 (ID: 10-13)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
-(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
-(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
-(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
-
--- 4. Warmate 系列 (ID: 14-18)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
-(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
-(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
-(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
-(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
-
--- 5. CH-901/902 系列 (ID: 19-23)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(19, 'CH-901', '巡飞弹', '中国航天科工'),
-(20, 'CH-901A', '巡飞弹', '中国航天科工'),
-(21, 'CH-901H', '巡飞弹', '中国航天科工'),
-(22, 'CH-902', '巡飞弹', '中国航天科工'),
-(23, 'CH-902A', '巡飞弹', '中国航天科工');
-
--- 6. WS-43/61 系列 (ID: 24-28)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(24, 'WS-43', '巡飞弹', '中国航天科工'),
-(25, 'WS-43A', '巡飞弹', '中国航天科工'),
-(26, 'WS-43B', '巡飞弹', '中国航天科工'),
-(27, 'WS-61', '巡飞弹', '中国航天科工'),
-(28, 'WS-61A', '巡飞弹', '中国航天科工');
-
--- 7. Kargu/Alpagu 系列 (ID: 29-33)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(29, 'Kargu', '巡飞弹', '土耳其 STM'),
-(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
-(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
-(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
-(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
-
--- 8. Shahed 系列 (ID: 34-38)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(34, 'Shahed-131', '巡飞弹', '伊朗'),
-(35, 'Shahed-131B', '巡飞弹', '伊朗'),
-(36, 'Shahed-136', '巡飞弹', '伊朗'),
-(37, 'Shahed-136B', '巡飞弹', '伊朗'),
-(38, 'Shahed-136C', '巡飞弹', '伊朗');
-
--- 9. Green Dragon 系列 (ID: 39-43)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
-(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
-(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
-(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
-(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
-
--- 10. Phoenix Ghost 系列 (ID: 44-48)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
-(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
-(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
-(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
-(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
-
--- 11. ZALA Lancet 系列 (ID: 49-52)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
-(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
-(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
-(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
-
--- 12. Rotem L 系列 (ID: 53-56)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
-(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
-(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
-(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
-
--- 13. KUB-BLA 系列 (ID: 57-60)
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
-(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
-(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
-(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
-(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
-
--- 插入通用参数
-INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
-(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
-(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
-(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
-(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
-(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
-(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
-(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
-(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
-(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
-(10, 0.58, 0.12, 0.12, 2.5, 10),
-(11, 1.30, 0.22, 0.22, 15.0, 40),
-(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
-(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
-(14, 0.68, 0.12, 0.12, 2.5, 10),
-(15, 1.30, 0.22, 0.22, 15.0, 40),
-(16, 0.68, 0.12, 0.12, 2.5, 10),
-(17, 1.30, 0.22, 0.22, 15.0, 40),
-(18, 0.68, 0.12, 0.12, 2.5, 10),
-(19, 1.2, 0.18, 0.18, 9.0, 20),
-(20, 1.2, 0.18, 0.18, 9.3, 25),
-(21, 1.2, 0.18, 0.18, 9.5, 20),
-(22, 1.4, 0.22, 0.22, 15.0, 30),
-(23, 1.4, 0.22, 0.22, 15.5, 35),
-(24, 1.8, 0.35, 0.35, 20, 60),
-(25, 1.8, 0.35, 0.35, 21, 70),
-(26, 1.9, 0.35, 0.35, 22, 80),
-(27, 2.2, 0.40, 0.40, 35, 100),
-(28, 2.2, 0.40, 0.40, 37, 120),
-(29, 0.6, 0.35, 0.35, 7.0, 10),
-(30, 0.6, 0.35, 0.35, 7.2, 15),
-(31, 1.0, 0.23, 0.23, 3.7, 5),
-(32, 1.0, 0.23, 0.23, 3.9, 8),
-(33, 0.6, 0.35, 0.35, 7.5, 15),
-(34, 2.6, 0.34, 0.34, 135, 900),
-(35, 2.6, 0.34, 0.34, 140, 1000),
-(36, 3.5, 0.42, 0.42, 200, 2000),
-(37, 3.5, 0.42, 0.42, 210, 2200),
-(38, 3.5, 0.42, 0.42, 215, 2500),
-(39, 1.5, 0.20, 0.20, 15, 40),
-(40, 1.6, 0.20, 0.20, 16, 50),
-(41, 1.5, 0.20, 0.20, 15.5, 45),
-(42, 1.5, 0.20, 0.20, 15.8, 40),
-(43, 1.2, 0.18, 0.18, 12, 30),
-(44, 1.5, 0.25, 0.25, 14.0, 30),
-(45, 1.5, 0.25, 0.25, 14.5, 35),
-(46, 1.6, 0.26, 0.26, 15.0, 40),
-(47, 1.5, 0.25, 0.25, 14.8, 30),
-(48, 1.7, 0.27, 0.27, 16.0, 50),
-(49, 1.0, 0.20, 0.20, 5.0, 40),
-(50, 1.65, 0.35, 0.35, 12.0, 70),
-(51, 1.65, 0.35, 0.35, 12.5, 80),
-(52, 1.80, 0.40, 0.40, 15.0, 100),
-(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
-(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
-(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
-(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
-(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
-(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
-(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
-(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
-
--- 插入特有参数
-INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
- endurance_min,
- warhead_type,
- launch_mode,
- power_system,
- guidance_system
-) VALUES
--- HAROP/Harpy系列
-(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
-(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
-(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
-
--- Hero系列
-(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
-(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
-(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
-
--- Switchblade系列
-(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
-(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-
--- Warmate系列
-(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
-(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
-(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-
--- CH-901/902系列
-(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
-(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
-(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
-(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
-(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
-(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
-(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
-(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
-(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
-(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
-(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
-(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
-(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
-(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
-(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
-(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
-(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
-(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
-(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
-(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
-(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
-(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
-(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
-(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
-(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
-(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
-(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
-(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
-(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
-(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
-(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
-
--- 插入成本数据
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(1, 800000), -- IAI Harop
-(2, 700000), -- IAI Harpy
-(3, 350000), -- IAI Mini Harpy
-(4, 70000), -- Hero-30
-(5, 120000), -- Hero-70
-(6, 150000), -- Hero-120
-(7, 300000), -- Hero-250
-(8, 400000), -- Hero-400EC
-(9, 650000), -- Hero-900
-(10, 60000), -- Switchblade 300
-(11, 180000), -- Switchblade 600
-(12, 75000), -- Switchblade 300 Block 10
-(13, 200000), -- Switchblade 600 Extended Range
-(14, 60000), -- Warmate 1.0
-(15, 180000), -- Warmate 2.0
-(16, 60000), -- Warmate-V
-(17, 180000), -- Warmate-L
-(18, 60000), -- Warmate 3.0
-(19, 100000), -- CH-901
-(20, 120000), -- CH-901A
-(21, 130000), -- CH-901H
-(22, 180000), -- CH-902
-(23, 200000), -- CH-902A
-(24, 120000), -- WS-43
-(25, 150000), -- WS-43A
-(26, 180000), -- WS-43B
-(27, 300000), -- WS-61
-(28, 350000), -- WS-61A
-(29, 70000), -- Kargu
-(30, 85000), -- Kargu-2
-(31, 45000), -- Alpagu
-(32, 55000), -- Alpagu Block-II
-(33, 95000), -- Kargu Autonomous
-(34, 20000), -- Shahed-131
-(35, 25000), -- Shahed-131B
-(36, 40000), -- Shahed-136
-(37, 45000), -- Shahed-136B
-(38, 50000), -- Shahed-136C
-(39, 160000), -- Green Dragon
-(40, 200000), -- Green Dragon Extended Range
-(41, 180000), -- Green Dragon Block 2
-(42, 190000), -- Green Dragon Maritime
-(43, 140000), -- Green Dragon-S
-(44, 150000), -- Phoenix Ghost
-(45, 180000), -- Phoenix Ghost Block I
-(46, 220000), -- Phoenix Ghost Block II
-(47, 190000), -- Phoenix Ghost Maritime
-(48, 250000), -- Phoenix Ghost-ER
-(49, 80000), -- Lancet-1
-(50, 150000), -- Lancet-3
-(51, 180000), -- Lancet-3M
-(52, 250000), -- Lancet-4
-(53, 65000), -- Rotem L
-(54, 85000), -- Rotem L-X
-(55, 75000), -- Rotem L-M
-(56, 95000), -- Rotem L-ER
-(57, 95000), -- KUB-BLA
-(58, 120000), -- KUB-BLA-E
-(59, 110000), -- KUB-BLA-M
-(60, 150000); -- KUB-BLA-ER
-
--- 创建数据集
-INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
-(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
-(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
-
--- 关联装备到数据集(按照制造商和型号分配)
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
--- 训练集(约80%的数据,48个型号)
--- 以色列系列
-(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
-(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
-(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
-(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
-
--- 美国系列
-(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
-(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
-
--- 中国系列
-(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
-(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
-
--- 波兰和土耳其系列
-(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
-(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
-
--- 俄罗斯系列
-(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
-
--- 验证集(约20%的数据,12个型号)
--- 混合系列
-(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
-(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
-(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
-
--- 添加分类特征编码
-INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
--- 战斗部类型编码
-('warhead_type', '破片杀伤战斗部', 1),
-('warhead_type', '破甲战斗部', 2),
-('warhead_type', '高爆战斗部', 3),
-('warhead_type', '破片杀伤/破甲双用战斗部', 4),
-('warhead_type', '模块化战斗部', 5),
-
--- 发射方式编码
-('launch_mode', '箱式发射', 1),
-('launch_mode', '弹射式发射', 2),
-('launch_mode', '垂直起降', 3),
-('launch_mode', '单兵发射管', 4),
-('launch_mode', '箱式发射/弹射式', 5),
-('launch_mode', '箱式发射/空中发射', 6),
-
--- 动力装置编码(按复杂度递增)
-('power_system', '电动机', 1),
-('power_system', '活塞发动机', 2),
-
--- 制导系统编码(按复杂度递增)
-('guidance_system', 'GPS/INS', 1),
-('guidance_system', 'GPS/INS/光电', 2),
-('guidance_system', 'GPS/INS/光电/数据链', 3),
-('guidance_system', 'GPS/INS/光电/AI识别', 4),
-('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
-('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
-('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
-
--- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
-UPDATE loitering_munition_params l
-JOIN common_params c ON l.equipment_id = c.equipment_id
-SET
- -- 新增关键参数
- l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
- l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
- l.engine_power_kw =
- CASE
- WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
- WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
- END,
- l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
- l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
- l.guidance_accuracy_m =
- CASE
- WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
- WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
- ELSE 3.0
- END,
- l.min_altitude_m = -- 最小作战高度
- CASE
- -- 大型巡飞弹(体型大、重量大)
- WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
-
- -- 中型巡飞弹
- WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
-
- -- 中小型巡飞弹
- WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
-
- -- 小型巡飞弹
- WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
-
- -- 超小型巡飞弹
- WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
-
- -- 其他型号使用默认值
- ELSE 50
- END,
- l.max_altitude_m =
- CASE
- WHEN c.max_range_km > 500 THEN 5000
- WHEN c.max_range_km > 100 THEN 3000
- ELSE 1500
- END,
-
- -- 特征工程字段
- l.length_width_ratio = c.length_m / c.width_m,
- l.weight_range_ratio = c.weight_kg / c.max_range_km,
- l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
- l.guidance_system_score =
- CASE
- WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
- WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
- WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
- WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
- ELSE 2
- END,
- l.warhead_power_score =
- CASE
- WHEN l.warhead_type = '模块化战斗部' THEN 10
- WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
- WHEN l.warhead_type = '高爆战斗部' THEN 7
- WHEN l.warhead_type = '破甲战斗部' THEN 6
- WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
- ELSE 4
- END,
-
- -- 分类特征编码
- l.warhead_type_code =
- CASE
- WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
- WHEN l.warhead_type = '破甲战斗部' THEN 2
- WHEN l.warhead_type = '高爆战斗部' THEN 3
- WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
- WHEN l.warhead_type = '模块化战斗部' THEN 5
- ELSE 0
- END,
- l.launch_mode_code =
- CASE
- WHEN l.launch_mode = '箱式发射' THEN 1
- WHEN l.launch_mode = '弹射式发射' THEN 2
- WHEN l.launch_mode = '垂直起降' THEN 3
- WHEN l.launch_mode = '单兵发射管' THEN 4
- WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
- WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
- ELSE 0
- END,
- l.power_system_code =
- CASE
- WHEN l.power_system = '电动机' THEN 1
- WHEN l.power_system = '活塞发动机' THEN 2
- ELSE 0
- END,
- l.guidance_system_code =
- CASE
- WHEN l.guidance_system = 'GPS/INS' THEN 1
- WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
- WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
- WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
- WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
- WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
- WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
- ELSE 0
- END;
diff --git a/src/rocket_artillery_data.sql b/src/rocket_artillery_data.sql
index e81f8a1..ba84945 100644
--- a/src/rocket_artillery_data.sql
+++ b/src/rocket_artillery_data.sql
@@ -29,7 +29,7 @@
*/
-- 中国系列火箭炮数据
-INSERT INTO equipment (id, name, type, manufacturer) VALUES
+INSERT INTO equipments (id, name, type, manufacturer) VALUES
(1001, 'PCL-191', '火箭炮', '中国兵器工业集团'),
(1002, 'PHL-03', '火箭炮', '中国兵器工业集团'),
(1003, 'AR-3', '火箭炮', '中国航天科工'),
@@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1007, 'WS-2', '火箭炮', '中国航天科工'),
(1008, 'WS-3', '火箭炮', '中国航天科工'),
(1009, 'Type 63', '火箭炮', '中国兵器工业集团'),
-(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'),
-(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'),
-(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'),
-(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'),
-(1014, 'TOS-1A', '火箭炮', '俄罗斯'),
+(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'),
+(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'),
+(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'),
+(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'),
+(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'),
(1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'),
(1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'),
(1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'),
@@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'),
(1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'),
(1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'),
-(1033, 'Pinaka', '火箭炮', '印度DRDO'),
-(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'),
-(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'),
-(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'),
+(1033, 'Pinaka', '火箭炮', '印度 DRDO'),
+(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'),
+(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'),
+(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'),
(1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'),
@@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(4, '火箭炮验证集 2024', '包含19个火箭炮型号,用于验证模型性能', '火箭炮', '验证');
-- 训练集(约80%的数据,77个型号)
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列(7/9)
(3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007),
@@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(3, 1094), (3, 1095);
-- 验证集(约20%的数据,19个型号)
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列(2/9)
(4, 1008), (4, 1009),
diff --git a/src/routes.py b/src/routes.py
index 04f60b7..ffcf544 100644
--- a/src/routes.py
+++ b/src/routes.py
@@ -48,6 +48,11 @@ def index():
'url': '/api/evaluate',
'method': 'POST',
'description': '模型评估'
+ },
+ 'analyze-manufacturers': {
+ 'url': '/api/analyze-manufacturers',
+ 'method': 'POST',
+ 'description': '供应商分析'
}
}
})
@@ -114,193 +119,149 @@ def analyze_features():
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
- # 获取数据集信息
+ # 首先获取数据集的装备类型
cursor.execute("""
- SELECT d.*,
- e.type as equipment_type
- FROM datasets d
- JOIN dataset_equipment de ON d.id = de.dataset_id
- JOIN equipment e ON de.equipment_id = e.id
- WHERE d.id = %s
+ SELECT DISTINCT e.type
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
+ WHERE de.dataset_id = %s
LIMIT 1
""", (dataset_id,))
- dataset = cursor.fetchone()
- if not dataset:
- logger.warning(f"Dataset {dataset_id} not found")
- return jsonify({'error': '数据集不存在'}), 404
+ equipment_type = cursor.fetchone()['type']
+ logger.info(f"Equipment type: {equipment_type}")
- logger.info(f"Dataset info: {dataset}")
-
- # 创建特征分析实例
- analyzer = FeatureAnalysis()
-
- # 获取特征列表
- feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
- logger.info(f"Feature names: {feature_names}")
-
- # 获取数据集中的装备数据
- if dataset['equipment_type'] == '火箭炮':
+ # 根据装备类型选择查询
+ if equipment_type == '火箭炮':
cursor.execute("""
- SELECT
- e.name,
- e.id,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- rap.max_range_km,
- rap.firing_angle_horizontal,
- rap.firing_angle_vertical,
- rap.rocket_length_m,
- rap.rocket_diameter_mm,
- rap.rocket_weight_kg,
- rap.rate_of_fire,
- rap.combat_weight_kg,
- rap.speed_kmh,
- rap.min_range_km,
- rap.mobility_type,
- rap.structure_layout,
- rap.engine_model,
- rap.power_hp,
- rap.travel_range_km,
- rap.fire_density,
- rap.range_ratio,
- rap.mobility_score,
- rap.combat_readiness_score,
- rap.rocket_power_ratio,
- rap.platform_efficiency,
- rap.deployment_score,
- rap.terrain_adaptability_score,
- cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
+ m.tech_level, m.scale_level, m.supply_chain_level, m.country,
+ cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
+ cd.actual_cost,
+ rap.max_range_km, rap.firing_angle_horizontal, rap.firing_angle_vertical,
+ rap.rocket_length_m, rap.rocket_diameter_mm, rap.rocket_weight_kg,
+ rap.rate_of_fire, rap.combat_weight_kg, rap.speed_kmh,
+ rap.min_range_km, rap.power_hp, rap.travel_range_km,
+ rap.fire_density, rap.range_ratio, rap.mobility_score,
+ rap.combat_readiness_score, rap.deployment_score, rap.terrain_adaptability_score,
+ rap.rocket_power_ratio, rap.platform_efficiency
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
+ LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
- ORDER BY e.id
""", (dataset_id,))
else:
cursor.execute("""
- SELECT
- e.name,
- e.id,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- lmp.max_range_km,
- lmp.wingspan_m,
- lmp.warhead_weight_kg,
- lmp.max_speed_ms,
- lmp.cruise_speed_kmh,
- lmp.endurance_min,
- lmp.max_payload_kg,
- lmp.ceiling_altitude_m,
- lmp.combat_radius_km,
- lmp.engine_power_kw,
- lmp.engine_thrust_n,
- lmp.datalink_range_km,
- lmp.guidance_accuracy_m,
- lmp.min_altitude_m,
- lmp.max_altitude_m,
- lmp.length_width_ratio,
- lmp.weight_range_ratio,
- lmp.speed_weight_ratio,
- lmp.guidance_system_score,
- lmp.warhead_power_score,
- cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
+ m.tech_level, m.scale_level, m.supply_chain_level, m.country,
+ cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
+ cd.actual_cost,
+ lmp.max_range_km, lmp.wingspan_m, lmp.warhead_weight_kg,
+ lmp.max_speed_ms, lmp.cruise_speed_kmh, lmp.endurance_min,
+ lmp.length_width_ratio, lmp.weight_range_ratio,
+ lmp.speed_weight_ratio, lmp.ceiling_altitude_m,
+ lmp.guidance_system_score, lmp.warhead_power_score,
+ lmp.engine_power_kw, lmp.engine_thrust_n,
+ lmp.min_altitude_m, lmp.max_altitude_m,
+ lmp.max_payload_kg, lmp.combat_radius_km,
+ lmp.datalink_range_km, lmp.guidance_accuracy_m
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
+ LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
- ORDER BY e.id
""", (dataset_id,))
equipment_data = cursor.fetchall()
- logger.info(f"Found {len(equipment_data)} equipment records")
- # 提取装备名称列表
- equipment_names = [item['name'] for item in equipment_data]
+ # 添加数据检查日志
+ logger.info(f"Total records found: {len(equipment_data)}")
+ if equipment_data:
+ # 检查第一条记录的所有字段
+ first_record = equipment_data[0]
+ logger.info("First record details:")
+ for key, value in first_record.items():
+ logger.info(f"{key}: {value}")
+
+ # 检查所有记录的 max_range_km 字段
+ logger.info("Checking max_range_km for all records:")
+ for item in equipment_data:
+ logger.info(f"Equipment: {item['name']}")
+ logger.info(f" max_range_km: {item.get('max_range_km')}")
+ logger.info(f" type: {item['type']}")
+ if item['type'] == '火箭炮':
+ logger.info(f" rocket_artillery_params fields:")
+ for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']:
+ logger.info(f" {key}: {item.get(key)}")
- # 提取特征数据和目标值
+ # 提取特征和目标值
+ analyzer = FeatureAnalysis()
+ feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type'])
features = []
targets = []
+
for item in equipment_data:
+ # 计算生产商特征
+ manufacturer_features = analyzer.calculate_manufacturer_features({
+ 'tech_level': item['tech_level'],
+ 'scale_level': item['scale_level'],
+ 'supply_chain_level': item['supply_chain_level'],
+ 'country': item['country']
+ })
+
+ # 获取装备特征
feature_values = []
- for feature in feature_names:
- value = item.get(feature)
- feature_values.append(float(value) if value is not None else 0)
+ for name in feature_names:
+ if name in manufacturer_features:
+ value = manufacturer_features[name]
+ else:
+ value = item.get(name)
+ feature_values.append(float(value) if value is not None else 0.0)
+
features.append(feature_values)
targets.append(float(item['actual_cost']))
- # 进行特征分析
- result = analyzer.analyze_features(features, targets, feature_names)
+ # 执行特征分析
+ analysis_result = analyzer.analyze_features(features, targets, feature_names)
- # 添加装备名称列表到结果中
- result['equipment_names'] = equipment_names
+ # 添加装备名称列表
+ analysis_result['equipment_names'] = [item['name'] for item in equipment_data]
- # 如果是火箭炮,添加额外的分析数据
- if dataset['equipment_type'] == '火箭炮':
+ # 添加装备特有的分析数据
+ if equipment_data[0]['type'] == '火箭炮':
rocket_data = {
- 'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
- 'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
- 'rate_of_fire': [float(item['rate_of_fire']) if item['rate_of_fire'] is not None else 0 for item in equipment_data],
- 'max_range_km': [float(item['max_range_km']) if item['max_range_km'] is not None else 0 for item in equipment_data],
- 'rocket_weight_kg': [float(item['rocket_weight_kg']) if item['rocket_weight_kg'] is not None else 0 for item in equipment_data],
- 'rocket_diameter_mm': [float(item['rocket_diameter_mm']) if item['rocket_diameter_mm'] is not None else 0 for item in equipment_data],
- 'rocket_length_m': [float(item['rocket_length_m']) if item['rocket_length_m'] is not None else 0 for item in equipment_data],
- 'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
- 'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
- 'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data],
- 'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
- 'speed_kmh': [float(item['speed_kmh']) if item['speed_kmh'] is not None else 0 for item in equipment_data],
- 'power_hp': [float(item['power_hp']) if item['power_hp'] is not None else 0 for item in equipment_data],
- 'travel_range_km': [float(item['travel_range_km']) if item['travel_range_km'] is not None else 0 for item in equipment_data]
+ 'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
+ 'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
+ 'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
+ 'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
+ 'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
+ 'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
}
- result.update(rocket_data)
-
- # 如果是巡飞弹,添加额外的分析数据
- if dataset['equipment_type'] == '巡飞弹':
+ analysis_result.update(rocket_data)
+ else:
missile_data = {
- 'equipment_names': equipment_names,
- # 特征工程参数
- 'length_width_ratio': [float(item['length_width_ratio']) if item.get('length_width_ratio') is not None else 0 for item in equipment_data],
- 'weight_range_ratio': [float(item['weight_range_ratio']) if item.get('weight_range_ratio') is not None else 0 for item in equipment_data],
- 'speed_weight_ratio': [float(item['speed_weight_ratio']) if item.get('speed_weight_ratio') is not None else 0 for item in equipment_data],
- 'guidance_system_score': [float(item['guidance_system_score']) if item.get('guidance_system_score') is not None else 0 for item in equipment_data],
- 'warhead_power_score': [float(item['warhead_power_score']) if item.get('warhead_power_score') is not None else 0 for item in equipment_data],
-
- # 动力系统参数
- 'engine_power_kw': [float(item['engine_power_kw']) if item.get('engine_power_kw') is not None else 0 for item in equipment_data],
- 'engine_thrust_n': [float(item['engine_thrust_n']) if item.get('engine_thrust_n') is not None else 0 for item in equipment_data],
-
- # 作战参数
- 'min_altitude_m': [float(item['min_altitude_m']) if item.get('min_altitude_m') is not None else 0 for item in equipment_data],
- 'max_altitude_m': [float(item['max_altitude_m']) if item.get('max_altitude_m') is not None else 0 for item in equipment_data],
- 'max_range_km': [float(item['max_range_km']) if item.get('max_range_km') is not None else 0 for item in equipment_data],
- 'max_payload_kg': [float(item['max_payload_kg']) if item.get('max_payload_kg') is not None else 0 for item in equipment_data],
- 'combat_radius_km': [float(item['combat_radius_km']) if item.get('combat_radius_km') is not None else 0 for item in equipment_data],
- 'datalink_range_km': [float(item['datalink_range_km']) if item.get('datalink_range_km') is not None else 0 for item in equipment_data],
- 'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item.get('guidance_accuracy_m') is not None else 0 for item in equipment_data]
+ 'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
+ 'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
+ 'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
+ 'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
+ 'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
+ 'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
+ 'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
+ 'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
+ 'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
+ 'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
+ 'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
}
-
- # 验证数据完整性
- for key, value in missile_data.items():
- logger.info(f"{key} data length: {len(value)}")
- logger.info(f"{key} sample data: {value[:3]}")
- if not any(value): # 检查是否所有值都为0
- logger.warning(f"All values are 0 for {key}")
-
- # 更新结果
- result.update(missile_data)
+ analysis_result.update(missile_data)
- return jsonify(result)
+ return jsonify(analysis_result)
except Exception as e:
logger.error(f"Error analyzing features: {str(e)}")
@@ -332,8 +293,8 @@ def train_model():
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@@ -343,8 +304,8 @@ def train_model():
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@@ -360,8 +321,8 @@ def train_model():
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@@ -371,8 +332,8 @@ def train_model():
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@@ -500,7 +461,7 @@ def get_equipment_data():
# 获取所有装备数据(使用equipment_id替代id)
cursor.execute("""
- SELECT e.id as equipment_id, e.name, e.type,
+ SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
cd.actual_cost, cd.predicted_cost,
CASE
@@ -546,7 +507,7 @@ def get_equipment_data():
WHERE equipment_id = e.id
)
END as specific_params
- FROM equipment e
+ FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
ORDER BY e.id
@@ -576,7 +537,7 @@ def delete_equipment(id):
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
- cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
+ cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
db.commit()
cursor.close()
@@ -737,7 +698,7 @@ def update_equipment(id):
# 更新装备基本信息
cursor.execute("""
- UPDATE equipment
+ UPDATE equipments
SET name = %s, manufacturer = %s
WHERE id = %s
""", (data['name'], data['manufacturer'], equipment_id))
@@ -845,7 +806,7 @@ def get_equipment_details(id):
# 获取装备基本信息类型
cursor.execute("""
SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost
- FROM equipment e
+ FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.id = %s
@@ -854,7 +815,7 @@ def get_equipment_details(id):
result = cursor.fetchone()
if not result:
logger.warning(f"Equipment with ID {id} not found")
- return jsonify({'error': '装备不存在'}), 404
+ return jsonify({'error': '装备不在'}), 404
logger.info(f"Equipment type: {result['type']}")
logger.info(f"Found equipment details: {result['name']}")
@@ -901,8 +862,8 @@ def get_datasets():
COUNT(de.equipment_id) as equipment_count,
GROUP_CONCAT(e.name) as equipment_names
FROM datasets d
- LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
- LEFT JOIN equipment e ON de.equipment_id = e.id
+ LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
+ LEFT JOIN equipments e ON de.equipment_id = e.id
GROUP BY d.id
""")
datasets = cursor.fetchall()
@@ -932,7 +893,7 @@ def get_dataset(id):
SELECT d.*,
COUNT(de.equipment_id) as equipment_count
FROM datasets d
- LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
+ LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
WHERE d.id = %s
GROUP BY d.id
""", (id,))
@@ -945,14 +906,14 @@ def get_dataset(id):
cursor.execute("""
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
+ FROM equipments e
+ JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
""", (id,))
equipment = cursor.fetchall()
- # 计算统计信息
+ # 算统计信
if equipment:
total_cost = sum(item['actual_cost'] or 0 for item in equipment)
avg_cost = total_cost / len(equipment)
@@ -989,7 +950,7 @@ def create_dataset():
# 直接从 equipment 表查询,不需要 JOIN
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
cursor.execute(f"""
- SELECT DISTINCT id FROM equipment
+ SELECT DISTINCT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
@@ -1015,7 +976,7 @@ def create_dataset():
if 'equipment_ids' in data and data['equipment_ids']:
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
- INSERT INTO dataset_equipment (dataset_id, equipment_id)
+ INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
""", values)
logger.info(f"Added {len(values)} equipment associations")
@@ -1041,7 +1002,7 @@ def update_dataset(id):
if 'equipment_ids' in data:
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
cursor.execute(f"""
- SELECT id FROM equipment
+ SELECT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
@@ -1064,13 +1025,13 @@ def update_dataset(id):
# 3. 更新装备关联
if 'equipment_ids' in data:
# 先删除旧的关联
- cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
+ cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
# 添加新的关联
if valid_ids: # 确保有有效的ID才执行插入
values = [(id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
- INSERT INTO dataset_equipment (dataset_id, equipment_id)
+ INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
""", values)
logger.info(f"Updated {len(values)} equipment associations")
@@ -1092,7 +1053,7 @@ def delete_dataset(id):
cursor = conn.cursor()
# 删除装备关联
- cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
+ cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
# 删除数据集
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
@@ -1139,7 +1100,7 @@ def get_models():
models = cursor.fetchall()
- # 确保数值类型字段是 float
+ # 确保数值型字段是 float
for model in models:
if model['r2_score'] is not None:
model['r2_score'] = float(model['r2_score'])
@@ -1161,7 +1122,7 @@ def get_models():
@api_bp.route('/models//activate', methods=['POST'])
def activate_model(id):
"""
- 激活指定的模型
+ 激活定的模型
"""
try:
with get_db_connection() as conn:
@@ -1250,4 +1211,104 @@ def predict_all():
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/analyze-manufacturers', methods=['POST'])
+def analyze_manufacturers():
+ """分析生产商数据"""
+ try:
+ data = request.get_json()
+ dataset_id = data.get('dataset_id')
+
+ logger.info(f"Starting manufacturer analysis for dataset {dataset_id}")
+
+ if not dataset_id:
+ logger.warning("No dataset_id provided")
+ return jsonify({'error': '请选择数据集'}), 400
+
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+
+ # 获取数据集中的装备和生产商数据
+ cursor.execute("""
+ SELECT DISTINCT m.*, e.type as equipment_type
+ FROM manufacturers m
+ JOIN equipments e ON e.manufacturer_id = m.id
+ JOIN dataset_equipments de ON e.id = de.equipment_id
+ WHERE de.dataset_id = %s
+ """, (dataset_id,))
+
+ manufacturers = cursor.fetchall()
+
+ if not manufacturers:
+ return jsonify({'error': '数据集中没有生产商数据'}), 404
+
+ # 准备分析数据
+ manufacturer_names = []
+ tech_levels = []
+ scale_levels = []
+ supply_chain_levels = []
+ composite_scores = []
+ region_count = {}
+ manufacturer_scores = []
+
+ for manu in manufacturers:
+ manufacturer_names.append(manu['name'])
+ tech_levels.append(manu['tech_level'])
+ scale_levels.append(manu['scale_level'])
+ supply_chain_levels.append(manu['supply_chain_level'])
+
+ # 计算综合得分
+ composite_score = (
+ manu['tech_level'] * 0.4 +
+ manu['scale_level'] * 0.3 +
+ manu['supply_chain_level'] * 0.3
+ )
+ composite_scores.append(composite_score)
+
+ # 统计地区分布
+ region_count[manu['country']] = region_count.get(manu['country'], 0) + 1
+
+ # 计算区域系数
+ region_factors = {
+ '美国': 1.2, '英国': 1.15, '德国': 1.15,
+ '法国': 1.15, '以色列': 1.1, '中国': 0.8,
+ '俄罗斯': 0.85, '韩国': 0.9, '日本': 1.1
+ }
+ region_factor = region_factors.get(manu['country'], 1.0)
+
+ # 添加雷达图数据
+ manufacturer_scores.append({
+ 'name': manu['name'],
+ 'value': [
+ manu['tech_level'],
+ manu['scale_level'],
+ manu['supply_chain_level'],
+ region_factor,
+ composite_score
+ ]
+ })
+
+ # 准备地区分布数据
+ region_distribution = [
+ {'name': country, 'value': count}
+ for country, count in region_count.items()
+ ]
+
+ # 返回分析结果
+ result = {
+ 'manufacturer_names': manufacturer_names,
+ 'manufacturer_tech_levels': tech_levels,
+ 'manufacturer_scale_levels': scale_levels,
+ 'manufacturer_supply_chain_levels': supply_chain_levels,
+ 'manufacturer_composite_scores': composite_scores,
+ 'region_distribution': region_distribution,
+ 'manufacturer_scores': manufacturer_scores
+ }
+
+ return jsonify(result)
+
+ except Exception as e:
+ logger.error(f"Error analyzing manufacturers: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
return jsonify({'error': str(e)}), 500
\ No newline at end of file
diff --git a/src/schema.sql b/src/schema.sql
index e5e7cf6..3ff6dde 100644
--- a/src/schema.sql
+++ b/src/schema.sql
@@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci;
USE equipment_cost_db;
-- 装备基本信息表
-CREATE TABLE equipment (
+CREATE TABLE equipments (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- 名称
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
manufacturer VARCHAR(100), -- 制造商
+ manufacturer_id INT, -- 制造商ID
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
@@ -26,8 +27,7 @@ CREATE TABLE common_params (
width_m FLOAT, -- 宽度(m)
height_m FLOAT, -- 高度(m)
weight_kg FLOAT, -- 重量(kg)
- max_range_km FLOAT, -- 最大射程(km)
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 火箭炮特有参数表
@@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params (
deployment_score INT, -- 部署评分(1-10)
terrain_adaptability_score INT, -- 地形适应性评分(1-10)
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 巡飞弹特有参数表
@@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params (
power_system_code INT, -- 动力装置编码
guidance_system_code INT, -- 制导系统编码
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 分类特征编码表
@@ -122,7 +122,7 @@ CREATE TABLE cost_data (
actual_cost DECIMAL(15,2), -- 实际成本(元)
predicted_cost DECIMAL(15,2), -- 预测成本(元)
prediction_date TIMESTAMP, -- 预测日期
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 特殊参数表
@@ -133,12 +133,12 @@ CREATE TABLE custom_params (
param_value VARCHAR(255), -- 参数值
param_unit VARCHAR(50), -- 参数单位
description TEXT, -- 参数说明
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
-CREATE INDEX idx_equipment_type ON equipment(type);
-CREATE INDEX idx_equipment_name ON equipment(name);
+CREATE INDEX idx_equipment_type ON equipments(type);
+CREATE INDEX idx_equipment_name ON equipments(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- 数据集表
@@ -153,12 +153,12 @@ CREATE TABLE datasets (
);
-- 数据集-装备关联表
-CREATE TABLE dataset_equipment (
+CREATE TABLE dataset_equipments (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+ FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
-- 训练模型表
@@ -175,10 +175,34 @@ CREATE TABLE trained_models (
feature_importance JSON, -- 特征重要性
training_data_size INT, -- 训练数据量
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
- is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
+ is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
created_by VARCHAR(50) -- 创建者
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
-CREATE INDEX idx_model_active ON trained_models(is_active);
\ No newline at end of file
+CREATE INDEX idx_model_active ON trained_models(is_active);
+
+-- 生产商表
+CREATE TABLE manufacturers (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ name VARCHAR(100) NOT NULL, -- 生产商名称
+ country VARCHAR(50) NOT NULL, -- 所属国家
+ tech_level INT NOT NULL, -- 技术水平评分(1-10)
+ scale_level INT NOT NULL, -- 规模评分(1-10)
+ supply_chain_level INT NOT NULL, -- 供应链成熟度评分(1-10)
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ UNIQUE KEY unique_name (name)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 添加生产商外键
+ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id);
+
+-- 添加索引
+CREATE INDEX idx_manufacturer_country ON manufacturers(country);
+CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level);
+CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level);
+CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
+CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id);
+