将 TensorFlow 替换为 PyTorch,并更新依赖,增加了生产商的数据和特征分析。

This commit is contained in:
Tian jianyong 2024-11-25 19:58:39 +08:00
parent 9421512677
commit dba9f2fcc9
26 changed files with 1378 additions and 2092 deletions

View File

@ -1,25 +0,0 @@
# 数据库配置
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=your_password_here
MYSQL_DATABASE=equipment_cost_db
# 服务配置
PORT=5001
DEBUG=False
# 日志配置
LOG_LEVEL=INFO
LOG_DIR=logs
# 模型配置
MODEL_DIR=models
DATA_DIR=data
# 安全配置
SECRET_KEY=your_secret_key_here
ALLOWED_HOSTS=localhost,127.0.0.1
# 其他配置
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
ALLOWED_FILE_TYPES=.xlsx,.xls

39
.gitignore vendored
View File

@ -1,4 +1,42 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
.env
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
node_modules
/dist
/models
@ -9,6 +47,7 @@ node_modules
# local env files
.env.local
.env.*.local
.venv
# Log files
npm-debug.log*

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.11.8

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
# MIT License
Copyright (c) 2024 Your Name or Your Organization
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -21,3 +21,57 @@ collation-server=utf8mb4_unicode_ci
[client]
default-character-set=utf8mb4
```
## 环境配置
本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。
### 使用脚本自动配置(推荐)
Unix/macOS:
```bash
chmod +x scripts/setup_env.sh
./scripts/setup_env.sh
```
Windows (PowerShell):
```powershell
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
.\scripts\setup_env.ps1
```
### 手动配置
1. 安装 pyenv
2. 安装 Python 3.11.8:
```bash
pyenv install 3.11.8
```
3. 设置本地 Python 版本:
```bash
pyenv local 3.11.8
```
4. 创建虚拟环境:
```bash
python -m venv .venv
```
5. 激活虚拟环境:
```bash
source .venv/bin/activate # Unix
.venv\Scripts\activate # Windows
```
6. 安装依赖:
```bash
pip install -e ".[dev]"
```

129
config.py
View File

@ -1,32 +1,103 @@
import os
import secrets
# 数据库配置
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
class Config:
    """Application settings plus a helper that wires them into a Flask app.

    Deployment-specific values (database credentials, bind address, debug
    flag and signing keys) are read from environment variables named after
    the attribute. When a variable is unset the development default shown
    below is used, so existing setups keep working unchanged.
    """

    # Database settings (override through the environment; never commit real
    # credentials here).
    MYSQL_HOST = os.environ.get('MYSQL_HOST', 'localhost')
    MYSQL_USER = os.environ.get('MYSQL_USER', 'root')
    MYSQL_PASSWORD = os.environ.get('MYSQL_PASSWORD', '123456')
    MYSQL_DB = os.environ.get('MYSQL_DB', 'equipment_cost_db')

    # Flask server settings
    FLASK_HOST = os.environ.get('FLASK_HOST', '0.0.0.0')
    FLASK_PORT = int(os.environ.get('FLASK_PORT', '5001'))
    # Debug stays on by default (the previous hard-coded value); set
    # FLASK_DEBUG=0/false/no/off to disable it outside development.
    FLASK_DEBUG = os.environ.get('FLASK_DEBUG', 'true').strip().lower() in (
        '1', 'true', 'yes', 'on'
    )

    # Directory layout (relative to the working directory)
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # File upload settings
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB

    # API settings
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'

    # Logging settings
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # 10MB
    LOG_BACKUP_COUNT = 5

    # PyTorch settings
    DEVICE = 'cpu'  # or 'cuda' to use a GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # Model training settings
    TRAIN_TEST_SPLIT = 0.2
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # Cache settings
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300

    # Security settings. A fixed placeholder string must never be used as a
    # signing key, so when no key is supplied a random one is generated for
    # this process. Set SECRET_KEY / JWT_SECRET_KEY explicitly when values
    # signed by one process must be accepted by another or after a restart.
    SECRET_KEY = os.environ.get('SECRET_KEY') or os.urandom(32).hex()
    JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or os.urandom(32).hex()
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour

    # CORS settings
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # Data validation settings
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Apply directory, logging, upload and CORS settings to ``app``.

        Args:
            app: The Flask application to configure.

        Returns:
            The same ``app`` object, after configuration.
        """
        # Create the directories the application writes to. LOG_DIR must exist
        # before the rotating file handler below opens LOG_FILE.
        for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
                          cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]:
            os.makedirs(directory, exist_ok=True)

        # Configure logging (imported locally, as in the original block).
        import logging
        from logging.handlers import RotatingFileHandler

        formatter = logging.Formatter(cls.LOG_FORMAT)
        file_handler = RotatingFileHandler(
            cls.LOG_FILE,
            maxBytes=cls.LOG_MAX_SIZE,
            backupCount=cls.LOG_BACKUP_COUNT
        )
        file_handler.setFormatter(formatter)
        file_handler.setLevel(cls.LOG_LEVEL)
        app.logger.addHandler(file_handler)
        app.logger.setLevel(cls.LOG_LEVEL)

        # Configure uploads
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

        # Configure CORS for the API routes only
        from flask_cors import CORS
        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })
        return app
# 安全密钥配置(自动生成随机密钥)
SECRET_KEY = secrets.token_hex(16)
# 环境配置
DEBUG = False
ENV = 'production'
# 文件上传配置
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
# API配置
API_VERSION = 'v1'
API_PREFIX = f'/api/{API_VERSION}'
# 跨域配置
CORS_ORIGINS = [
"http://localhost:8080",
"http://127.0.0.1:8080",
]
# 日志配置
LOG_LEVEL = 'DEBUG'
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
# 创建配置实例
config = Config()

View File

@ -103,7 +103,31 @@
<div class="chart-container">
<div ref="engineChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 制导性能分析 -->
<h3>制导性能分析</h3>
<div class="chart-container">
<div ref="guidanceChartRef" style="width: 100%; height: 600px"></div>
</div>
</template>
<!-- 生产商分析 -->
<h3>生产商分析</h3>
<div class="chart-container">
<div ref="manufacturerChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 生产商地区分布 -->
<h3>生产商地区分布</h3>
<div class="chart-container">
<div ref="regionChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 生产商综合评分 -->
<h3>生产商综合评分</h3>
<div class="chart-container">
<div ref="scoreChartRef" style="width: 100%; height: 600px"></div>
</div>
</div>
</el-card>
</div>
@ -131,6 +155,10 @@ const newFeatureChartRef = ref(null)
const engineChartRef = ref(null)
const fireChartRef = ref(null)
const mobilityChartRef = ref(null)
const manufacturerChartRef = ref(null)
const regionChartRef = ref(null)
const scoreChartRef = ref(null)
const guidanceChartRef = ref(null)
//
const importanceChart = ref(null)
@ -139,6 +167,10 @@ const newFeatureChart = ref(null)
const engineChart = ref(null)
const fireChart = ref(null)
const mobilityChart = ref(null)
const manufacturerChart = ref(null)
const regionChart = ref(null)
const scoreChart = ref(null)
const guidanceChart = ref(null)
//
watch(() => analysisResult.value, async (newResult) => {
@ -236,69 +268,28 @@ const startAnalysis = async () => {
analyzing.value = true
try {
//
console.log('Analysis request params:', {
dataset_id: analysisForm.value.dataset_id,
equipment_type: analysisForm.value.equipment_type
})
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
//
const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
//
console.log('Raw API response:', response)
console.log('Response data type:', typeof response.data)
console.log('Response data:', response.data)
//
if (!response.data) {
throw new Error('API返回的数据为空')
}
//
analysisResult.value = response.data
//
console.log('Analysis result after assignment:', {
value: analysisResult.value,
important_features: analysisResult.value?.important_features,
correlation_analysis: analysisResult.value?.correlation_analysis,
equipment_names: analysisResult.value?.equipment_names,
length_width_ratio: analysisResult.value?.length_width_ratio
//
const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, {
dataset_id: analysisForm.value.dataset_id
})
//
if (analysisForm.value.equipment_type === '巡飞弹') {
const missileData = {
equipment_names: analysisResult.value?.equipment_names || [],
length_width_ratio: analysisResult.value?.length_width_ratio || [],
engine_power_kw: analysisResult.value?.engine_power_kw || [],
guidance_system_score: analysisResult.value?.guidance_system_score || [],
warhead_power_score: analysisResult.value?.warhead_power_score || []
}
console.log('Missile specific data:', missileData)
//
const missingFields = Object.entries(missileData)
.filter(([key, value]) => !Array.isArray(value) || value.length === 0)
.map(([key]) => key)
if (missingFields.length > 0) {
console.warn('Missing or empty missile data fields:', missingFields)
ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
}
//
analysisResult.value = {
...featureResponse.data,
...manufacturerResponse.data
}
//
console.log('Combined analysis result:', analysisResult.value)
} catch (error) {
console.error('Analysis error:', error)
console.error('Error details:', {
message: error.message,
response: error.response?.data,
status: error.response?.status
})
ElMessage.error(error.message || '特征析失败')
ElMessage.error(error.message || '分析失败')
} finally {
analyzing.value = false
}
@ -329,6 +320,18 @@ const createResizeHandler = () => {
if (mobilityChart.value && !mobilityChart.value.isDisposed()) {
mobilityChart.value.resize()
}
if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) {
manufacturerChart.value.resize()
}
if (regionChart.value && !regionChart.value.isDisposed()) {
regionChart.value.resize()
}
if (scoreChart.value && !scoreChart.value.isDisposed()) {
scoreChart.value.resize()
}
if (guidanceChart.value && !guidanceChart.value.isDisposed()) {
guidanceChart.value.resize()
}
} catch (error) {
console.error('Error in resize handler:', error)
}
@ -360,7 +363,7 @@ onUnmounted(() => {
//
[importanceChart, correlationChart, newFeatureChart, engineChart,
fireChart, mobilityChart].forEach(chart => {
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
try {
chart.value.dispose()
@ -384,7 +387,7 @@ const renderCharts = () => {
try {
//
[importanceChart, correlationChart, newFeatureChart, engineChart,
fireChart, mobilityChart].forEach(chart => {
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
chart.value.dispose()
chart.value = null
@ -899,6 +902,156 @@ const renderCharts = () => {
mobilityChart.value.setOption(mobilityOption, { notMerge: true })
}
//
if (manufacturerChartRef.value) {
manufacturerChart.value = echarts.init(manufacturerChartRef.value)
const manufacturerOption = {
title: { text: '生产商特征影响分析' },
tooltip: {
trigger: 'axis',
axisPointer: { type: 'shadow' }
},
legend: {
data: ['技术水平', '规模水平', '供应链水平', '综合得分']
},
xAxis: {
type: 'category',
data: analysisResult.value.manufacturer_names || []
},
yAxis: {
type: 'value',
name: '评分',
min: 0,
max: 10
},
series: [
{
name: '技术水平',
type: 'bar',
data: analysisResult.value.manufacturer_tech_levels || []
},
{
name: '规模水平',
type: 'bar',
data: analysisResult.value.manufacturer_scale_levels || []
},
{
name: '供应链水平',
type: 'bar',
data: analysisResult.value.manufacturer_supply_chain_levels || []
},
{
name: '综合得分',
type: 'line',
data: analysisResult.value.manufacturer_composite_scores || []
}
]
}
manufacturerChart.value.setOption(manufacturerOption)
}
//
if (regionChartRef.value) {
regionChart.value = echarts.init(regionChartRef.value)
const regionOption = {
title: { text: '生产商地区分布' },
tooltip: {
trigger: 'item',
formatter: '{b}: {c} ({d}%)'
},
series: [
{
type: 'pie',
radius: '65%',
data: analysisResult.value.region_distribution || [],
emphasis: {
itemStyle: {
shadowBlur: 10,
shadowOffsetX: 0,
shadowColor: 'rgba(0, 0, 0, 0.5)'
}
}
}
]
}
regionChart.value.setOption(regionOption)
}
//
if (scoreChartRef.value) {
scoreChart.value = echarts.init(scoreChartRef.value)
const scoreOption = {
title: { text: '生产商综合评分雷达图' },
tooltip: {},
radar: {
indicator: [
{ name: '技术水平', max: 10 },
{ name: '规模水平', max: 10 },
{ name: '供应链水平', max: 10 },
{ name: '区域系数', max: 1.5 },
{ name: '综合得分', max: 10 }
]
},
series: [
{
type: 'radar',
data: analysisResult.value.manufacturer_scores || []
}
]
}
scoreChart.value.setOption(scoreOption)
}
//
if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') {
guidanceChart.value = echarts.init(guidanceChartRef.value)
const guidanceOption = {
title: { text: '制导性能分析' },
tooltip: {
trigger: 'axis',
axisPointer: { type: 'cross' }
},
legend: {
data: ['制导精度(m)', '数据链距离(km)', '制导系统评分']
},
xAxis: {
type: 'category',
data: analysisResult.value.equipment_names || []
},
yAxis: [
{
type: 'value',
name: '制导精度(m)',
position: 'left'
},
{
type: 'value',
name: '距离(km)',
position: 'right'
}
],
series: [
{
name: '制导精度(m)',
type: 'bar',
data: analysisResult.value.guidance_accuracy_m || []
},
{
name: '数据链距离(km)',
type: 'line',
yAxisIndex: 1,
data: analysisResult.value.datalink_range_km || []
},
{
name: '制导系统评分',
type: 'line',
data: analysisResult.value.guidance_system_score || []
}
]
}
guidanceChart.value.setOption(guidanceOption)
}
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error in chart rendering:', error)

59
pyproject.toml Normal file
View File

@ -0,0 +1,59 @@
[project]
name = "cost-prediction"
version = "0.1.0"
description = "装备成本预测系统"
requires-python = ">=3.9,<3.12"
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
# Web框架
"flask>=3.1.0",
"flask-cors>=5.0.0",
# 数据库
"sqlalchemy>=2.0.36",
"pymysql>=1.1.1",
"cryptography>=43.0.0",
"mysql-connector-python>=8.0.0",
# 数据处理
"numpy>=1.26.0,<2.0.0",
"pandas>=2.2.0",
# 机器学习
"scikit-learn>=1.5.2",
"torch==2.5.1",
"torchvision==0.20.1",
"torchaudio==2.5.1",
# 工具
"openpyxl>=3.1.5", # Excel支持
"python-dotenv>=1.0.0", # 环境变量
]
[project.optional-dependencies]
dev = [
# 测试工具
"pytest>=7.0",
"black>=22.0", # 代码格式化
"mypy>=1.0", # 类型检查
]
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
[tool.black]
line-length = 88
target-version = ["py39", "py310", "py311"]
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true

View File

@ -3,12 +3,11 @@ flask-cors>=5.0.0
sqlalchemy>=2.0.36
pymysql>=1.1.1
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
numpy>=2.0.2
pandas>=2.2.3
urllib3>=2.2.3
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
xlrd>=2.0.1 # 用于读取 .xls 文件
mysql-connector-python>=8.0.0 # 添加这行
numpy>=1.26.0,<2.0.0
pandas>=2.2.0
scikit-learn>=1.5.2
tensorflow>=2.18.0
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
python-dotenv>=1.0.0 # 环境变量

40
run.py
View File

@ -1,13 +1,33 @@
from src.app import create_app
import logging
from src import create_app
from src.logger import setup_logger
from config import config
import os
# 创建应用实例
app = create_app()
logger = setup_logger(__name__)
def main():
    """Create the runtime directories, build the Flask app and serve it.

    Any exception raised during startup is logged (message first, then the
    traceback) and re-raised to the caller.
    """
    try:
        # Make sure the model, log and data directories exist before startup.
        for required_dir in (config.MODEL_DIR, config.LOG_DIR, config.DATA_DIR):
            os.makedirs(required_dir, exist_ok=True)

        # Build the application, report how it will be served, then run it.
        app = create_app()
        logger.info(f"Starting server in {'debug' if config.FLASK_DEBUG else 'production'} mode")
        logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
        serve_options = {
            'host': config.FLASK_HOST,
            'port': config.FLASK_PORT,
            'debug': config.FLASK_DEBUG,
        }
        app.run(**serve_options)
    except Exception as exc:
        logger.error(f"Error starting application: {str(exc)}")
        logger.error("Detailed traceback:", exc_info=True)
        raise
if __name__ == '__main__':
# 设置日志
logging.basicConfig(level=logging.INFO)
logging.info('=== Server Starting ===')
logging.info('Initializing directories...')
app.run(host='0.0.0.0', port=5001, debug=True)
main()

121
scripts/setup_env.ps1 Normal file
View File

@ -0,0 +1,121 @@
# Set the error action preference to "Stop" for the rest of the script.
$ErrorActionPreference = "Stop"
# Check whether the current user holds the built-in Administrator role.
$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
# Not being an administrator only produces a warning; the script continues after a 3 second pause.
if (-not $isAdmin) {
Write-Warning "建议使用管理员权限运行此脚本"
Start-Sleep -Seconds 3
}
# Install pyenv-win when the `pyenv` command is not available.
if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) {
    Write-Host "pyenv not found. Installing..."
    try {
        # Download and run the pyenv-win installer script.
        Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1"
        & ./install-pyenv-win.ps1
        # Reload PATH from the machine and user environment first, so any
        # entries the installer persisted become visible in this session.
        $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
        # Then prepend the pyenv-win directories. Doing this after the reload
        # (the original did it before, so the reload discarded it) keeps them
        # on PATH for the rest of the script regardless of what was persisted.
        $env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win"
        $env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path"
    }
    catch {
        Write-Error "Failed to install pyenv: $_"
        exit 1
    }
}
# Install Python, create the virtual environment, install dependencies and
# verify the result. Any thrown error is reported by the catch block, which
# exits with code 1; the finally block removes the downloaded installer.
try {
# Install the pinned Python version through pyenv.
Write-Host "Installing Python 3.11.8..."
pyenv install 3.11.8
# pyenv is an external command, so failure is detected through its exit code.
if ($LASTEXITCODE -ne 0) {
throw "Failed to install Python 3.11.8"
}
# Select that version for this directory.
Write-Host "Setting local Python version..."
pyenv local 3.11.8
if ($LASTEXITCODE -ne 0) {
throw "Failed to set local Python version"
}
# Verify that `python` now reports the expected version.
$pythonVersion = python -V
if (-not $pythonVersion.Contains("3.11.8")) {
throw "Wrong Python version: $pythonVersion"
}
Write-Host "Using Python version: $pythonVersion"
# Create the virtual environment in .venv.
Write-Host "Creating virtual environment..."
python -m venv .venv
# Activate the virtual environment for the remaining commands.
Write-Host "Activating virtual environment..."
.\.venv\Scripts\Activate.ps1
# Upgrade pip and the build tools.
Write-Host "Upgrading pip and build tools..."
python -m pip install --upgrade pip setuptools wheel
# Install dependencies step by step with pinned versions.
Write-Host "Installing database dependencies..."
pip install mysql-connector-python==8.0.33
Write-Host "Installing PyTorch and related packages..."
# The torch packages are fetched from the PyTorch CPU wheel index.
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
Write-Host "Installing basic dependencies..."
pip install numpy==1.26.4 pandas==2.2.1
Write-Host "Installing machine learning packages..."
pip install scikit-learn==1.5.2
# Install the project in editable mode with the dev extras; fall back to the
# core package when that fails.
Write-Host "Installing development dependencies..."
pip install -e ".[dev]"
if ($LASTEXITCODE -ne 0) {
Write-Warning "Failed to install development dependencies. Installing core package..."
pip install -e .
}
# Verify the installation by importing each key package and printing its version.
Write-Host "Verifying installations..."
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
python -c "import pandas; print(f'Pandas version: {pandas.__version__}')"
python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')"
Write-Host "Environment setup complete!" -ForegroundColor Green
}
catch {
Write-Error "An error occurred: $_"
exit 1
}
finally {
# Remove the downloaded pyenv-win installer script if it is present.
if (Test-Path "./install-pyenv-win.ps1") {
Remove-Item "./install-pyenv-win.ps1"
}
}
# Print usage instructions for the new environment.
Write-Host @"
环境设置完成使用说明
1. 虚拟环境已激活命令提示符前应该显示 (.venv)
2. 要退出虚拟环境运行: deactivate
3. 要重新激活虚拟环境运行: .\.venv\Scripts\Activate.ps1
4. 项目依赖已安装可以开始开发了
如果遇到问题请检查
- Python 版本: python -V
- PyTorch 安装: python -c "import torch; print(torch.__version__)"
- 虚拟环境状态: 确保看到 (.venv) 前缀
"@ -ForegroundColor Cyan

66
scripts/setup_env.sh Executable file
View File

@ -0,0 +1,66 @@
#!/bin/bash
#
# Bootstrap the development environment: make sure pyenv and the pinned
# Python version are available, create and activate .venv, then install the
# project dependencies step by step.

# Stop at the first failing step (including a failed download in a pipeline)
# instead of continuing with a half-configured environment.
set -eo pipefail

PYTHON_VERSION="3.11.8"

# Install pyenv when it is missing.
if ! command -v pyenv &> /dev/null; then
    echo "pyenv not found. Installing..."
    if [[ "$OSTYPE" == "darwin"* ]]; then
        brew install pyenv
    else
        curl -fsSL https://pyenv.run | bash
        # The installer does not modify the PATH of this running shell, so
        # expose the freshly installed pyenv for the commands below.
        export PYENV_ROOT="${PYENV_ROOT:-$HOME/.pyenv}"
        export PATH="$PYENV_ROOT/bin:$PATH"
    fi
fi

# Install the pinned Python version. --skip-existing makes a rerun a no-op
# while still failing on a real build error (the previous `|| true` hid both).
pyenv install --skip-existing "$PYTHON_VERSION"

# Select that version for this directory.
pyenv local "$PYTHON_VERSION"

# Enable pyenv shell integration so the pinned version is the active one.
eval "$(pyenv init -)"
pyenv shell "$PYTHON_VERSION"

# Verify the interpreter version before building anything on top of it.
python_version=$(python -V 2>&1)
if [[ $python_version != *"$PYTHON_VERSION"* ]]; then
    echo "Error: Wrong Python version: $python_version"
    echo "Please ensure pyenv is properly configured in your shell"
    exit 1
fi

# Create and activate the virtual environment.
python -m venv .venv
source .venv/bin/activate

# Upgrade pip and the build tools.
python -m pip install --upgrade pip setuptools wheel

# Install dependencies step by step with pinned versions.
echo "Installing database dependencies..."
pip install mysql-connector-python==8.0.33
echo "Installing PyTorch and related packages..."
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
echo "Installing basic dependencies..."
pip install numpy==1.26.4 pandas==2.2.1
echo "Installing machine learning packages..."
pip install scikit-learn==1.5.2

# Install the project with the dev extras; fall back to the core package.
if ! pip install -e ".[dev]"; then
    echo "Warning: Failed to install development dependencies. Installing core package..."
    pip install -e .
fi

# Verify the installation.
echo "Verifying Python version..."
python --version
echo "Verifying PyTorch installation..."
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
echo "Environment setup complete!"

View File

@ -1 +1,3 @@
# 这个文件可以为空,但必须存在
from .app import create_app
__all__ = ['create_app']

View File

@ -2,49 +2,35 @@ from flask import Flask
from flask_cors import CORS
from .routes import api_bp
from .logger import setup_logger
from config import config
import os
# 获取logger
logger = setup_logger(__name__)
def create_app():
"""
创建并配置Flask应用
"""
"""创建并配置 Flask 应用"""
try:
# 创建必要的目录
os.makedirs('logs', exist_ok=True)
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
logger.info("=== Server Starting ===")
logger.info("Initializing directories...")
# 创建Flask应用
app = Flask(__name__)
# 配置CORS
CORS(app)
logger.info("CORS enabled")
# 注册API蓝图
# 配置数据库连接
app.config['MYSQL_HOST'] = config.MYSQL_HOST
app.config['MYSQL_USER'] = config.MYSQL_USER
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
app.config['MYSQL_DB'] = config.MYSQL_DB
# 注册路由
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")
# 配置数据库连接
app.config['MYSQL_HOST'] = 'localhost'
app.config['MYSQL_USER'] = 'root'
app.config['MYSQL_PASSWORD'] = '123456'
app.config['MYSQL_DB'] = 'equipment_cost_db'
logger.info("Starting server...")
# 记录配置信息
logger.info(f"Database: {app.config['MYSQL_DB']} on {app.config['MYSQL_HOST']}")
logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
logger.info(f"Debug mode: {config.FLASK_DEBUG}")
return app
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
raise
if __name__ == '__main__':
app = create_app()
app.run(host='localhost', port=5001)
logger.error(f"Error creating application: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise

View File

@ -1,15 +1,12 @@
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import stats
import joblib
import os
import pandas as pd
from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection
from src.feature_analysis import FeatureAnalysis
from .logger import setup_logger
logger = setup_logger(__name__)
@ -21,37 +18,18 @@ class CostPredictor:
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
# 添加 TensorFlow 配置
tf.config.run_functions_eagerly(False) # 启用图执行模式
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.load_model()
def load_model(self):
"""
加载预训练型和标准化器
加载预训练模型和标准化器
"""
try:
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 创建默认模型
self._create_default_model()
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
@ -60,28 +38,24 @@ class CostPredictor:
"""
创建默认模型并进行初始化训练
"""
# 创建输入层
inputs = tf.keras.Input(shape=(11,))
import torch.nn as nn
# 创建隐藏层
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
x = tf.keras.layers.Dense(32, activation='relu')(x)
# 创建输出层
outputs = tf.keras.layers.Dense(1)(x)
# 创建模型
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 编译模型
self.model.compile(
optimizer='adam',
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_absolute_error]
)
class DefaultModel(nn.Module):
    """Fully connected regressor: input_size -> 64 -> 32 -> 1 with ReLU in between."""

    def __init__(self, input_size):
        super().__init__()
        wide, narrow = 64, 32
        # Built in the same order and stored under the same attribute name as
        # before, so parameter names (layers.0, layers.2, layers.4) are unchanged.
        stack = [
            nn.Linear(input_size, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
            nn.Linear(narrow, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the output."""
        output = self.layers(x)
        return output
# 创建示例数据
example_data = pd.DataFrame({
example_features = {
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
@ -93,59 +67,23 @@ class CostPredictor:
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
}
# 转换为 tensor
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
# 训练标准化器
self.scaler_X.fit(example_data)
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
self.scaler_X.fit(X.numpy())
self.scaler_y.fit(y.numpy())
# 设置默认装备类型
self.equipment_type = '火箭炮'
def _create_example_data(self):
"""
创建示例数据来训练标准化器
"""
# 火箭炮示例数据
rocket_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 巡飞弹示例数据
missile_data = pd.DataFrame({
'length_m': [1.3, 2.5],
'width_m': [0.23, 0.6],
'height_m': [0.23, 0.6],
'weight_kg': [12.5, 135],
'max_range_km': [40, 250],
'max_speed_kmh': [180, 185],
'cruise_speed_kmh': [100, 110],
'flight_time_min': [60, 120],
'folded_length_mm': [1300, 2500],
'folded_width_mm': [230, 600],
'folded_height_mm': [230, 600]
})
# 训练标准化器
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
# 设置默认装备类型
# 创建模型
self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮'
def predict(self, data):
"""
使用训练好的最优模型进行预测
使用训练好的模型进行预测
"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
@ -158,20 +96,31 @@ class CostPredictor:
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
X = []
for feature in features:
value = data.get(feature, 0.0)
X.append(float(value))
# 转换为 tensor
X = torch.tensor([X], dtype=torch.float32).to(self.device)
# 预测
y_pred = trainer.predict(X)
with torch.no_grad():
trainer.model.eval() # 设置为评估模式
y_pred = trainer.model(X)
# 转回 numpy
y_pred = y_pred.cpu().numpy()
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
confidence_interval = self._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
return {
'predicted_cost': float(y_pred[0]),
'model_type': model_type, # 返回使用的模型类型
'model_type': model_type,
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
@ -187,11 +136,10 @@ class CostPredictor:
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
# 使用预测值的20%作为标准差
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
@ -213,130 +161,15 @@ class CostPredictor:
"""
模型评估
"""
# 确保输入是 numpy 数组
if torch.is_tensor(y_true):
y_true = y_true.cpu().numpy()
if torch.is_tensor(y_pred):
y_pred = y_pred.cpu().numpy()
return {
'mae': float(mean_absolute_error(y_true, y_pred)),
'mse': float(mean_squared_error(y_true, y_pred)),
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
'r2': float(r2_score(y_true, y_pred))
}
def predict_pls(self, data):
"""
使用 PLS 型预测成本
"""
try:
logger.info(f"Starting PLS prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载 PLS 模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
raise ValueError(f"No trained PLS model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
return {
'predicted_cost': float(y_pred[0]),
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"PLS prediction error: {str(e)}")
raise
def predict_all(self, data):
"""
使用所有可用模型进行预测
"""
try:
logger.info(f"Starting multi-model prediction for {data.get('type')}")
equipment_type = data.get('type')
results = {}
# 1. 获取所有激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT id, model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
""", (equipment_type,))
active_models = cursor.fetchall()
if not active_models:
raise ValueError(f"No active models found for {equipment_type}")
# 2. 使用每个模型进行预测
trainer = ModelTrainer()
for model_info in active_models:
try:
# 加载特定模型
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
logger.warning(f"Failed to load model: {model_info['model_name']}")
continue
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 保存结果
results[model_info['model_type']] = {
'predicted_cost': float(y_pred[0]),
'model_info': {
'name': model_info['model_name'],
'type': model_info['model_type'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
continue
if not results:
raise ValueError("No successful predictions from any model")
# 3. 计算综合预测结果
all_predictions = [result['predicted_cost'] for result in results.values()]
ensemble_prediction = float(np.mean(all_predictions))
prediction_std = float(np.std(all_predictions))
# 4. 返回所有结果
return {
'individual_predictions': results,
'ensemble_prediction': {
'predicted_cost': ensemble_prediction,
'standard_deviation': prediction_std,
'confidence_interval': {
'lower': float(ensemble_prediction - 1.96 * prediction_std),
'upper': float(ensemble_prediction + 1.96 * prediction_std)
}
}
}
except Exception as e:
logger.error(f"Error in multi-model prediction: {str(e)}")
raise
}

View File

@ -1,29 +1,35 @@
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import os
import joblib
import pandas as pd
import numpy as np
from src.feature_analysis import FeatureAnalysis
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.model_selection import cross_val_score, LeaveOneOut
import json
import torch
from torch.utils.data import Dataset, DataLoader
import logging
from src.database.db_connection import get_db_connection
from sklearn.metrics import mean_absolute_error, mean_squared_error
from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
from .logger import setup_logger
logger = setup_logger(__name__)
class EquipmentDataset(Dataset):
    """Dataset of equipment feature rows with optional target values.

    Args:
        features: Array-like or tensor of feature rows.
        targets: Optional array-like or tensor of targets aligned with
            ``features``. When omitted, indexing yields features only.
    """

    def __init__(self, features, targets=None):
        self.features = self._to_float_tensor(features)
        self.targets = None if targets is None else self._to_float_tensor(targets)

    @staticmethod
    def _to_float_tensor(values):
        """Return an independent float32 tensor holding ``values``."""
        # The legacy torch.FloatTensor constructor is discouraged in favour of
        # torch.tensor; an explicit dtype gives float32 storage whatever the
        # input dtype is.
        if torch.is_tensor(values):
            return values.detach().clone().to(torch.float32)
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # With targets: (features, target) pair; without: features only.
        if self.targets is not None:
            return self.features[idx], self.targets[idx]
        return self.features[idx]
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
self.target_scaler = StandardScaler() # 添加目标值标准化器
self.target_scaler = StandardScaler()
def prepare_training_data(self, equipment_data, equipment_type):
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
"""
准备训练数据
"""
@ -31,19 +37,24 @@ class DataPreparation:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,直接返回
# 如果输入已经是 numpy 数组,转换为 torch.Tensor
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
logger.info(f"Input is numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 转换为 PyTorch 数据集
dataset = EquipmentDataset(X)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return {
'X': X,
'dataloader': dataloader,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
'target_scaler': self.target_scaler,
'raw_shape': X.shape
}
# 从原始数据中提取特征和目标值
@ -51,27 +62,28 @@ class DataPreparation:
features = []
targets = []
for item in equipment_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 获取数据库连接
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 提取目标值(成本)
try:
cost = float(item['actual_cost'])
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError, KeyError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
for item in equipment_data:
# 获取该装备的生产商数据
manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor)
# 计算生产商特征
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data)
# 合并装备特征和生产商特征
feature_values = []
for name in feature_names:
if name in manufacturer_features:
value = manufacturer_features[name]
else:
value = item.get(name)
feature_values.append(float(value) if value is not None else 0.0)
features.append(feature_values)
targets.append(float(item['actual_cost']))
# 转换为numpy数组
X = np.array(features, dtype=float)
@ -85,25 +97,16 @@ class DataPreparation:
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 记录标准化后的数据范围
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 记录标准化器参数
logger.info("Feature scaler params:")
logger.info(f"Mean: {self.feature_scaler.mean_}")
logger.info(f"Scale: {self.feature_scaler.scale_}")
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
# 创建 PyTorch 数据集和数据加载器
dataset = EquipmentDataset(X_scaled, y_scaled)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return {
'X': X_scaled,
'y': y_scaled,
'dataloader': dataloader,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
'target_scaler': self.target_scaler,
'raw_shape': X.shape
}
except Exception as e:

View File

@ -15,9 +15,9 @@ class FeatureAnalysis:
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'max_range_km': '最大射程(km)',
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
@ -39,6 +39,7 @@ class FeatureAnalysis:
'terrain_adaptability_score': '地形适应性评分',
# 巡飞弹特有参数
'max_range_km': '最大射程(km)',
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
@ -57,7 +58,14 @@ class FeatureAnalysis:
'weight_range_ratio': '重量射程比',
'speed_weight_ratio': '速度重量比',
'guidance_system_score': '制导系统评分',
'warhead_power_score': '战斗部威力评分'
'warhead_power_score': '战斗部威力评分',
# 添加生产商特征映射
'manufacturer_tech_level': '生产商技术水平',
'manufacturer_scale_level': '生产商规模水平',
'manufacturer_supply_chain_level': '生产商供应链水平',
'manufacturer_composite_score': '生产商综合得分',
'manufacturer_region_factor': '生产商区域系数'
}
def get_equipment_specific_features(self, equipment_type):
@ -121,6 +129,17 @@ class FeatureAnalysis:
'guidance_system_score',
'warhead_power_score'
])
# 添加生产商特征
manufacturer_features = [
'manufacturer_tech_level',
'manufacturer_scale_level',
'manufacturer_supply_chain_level',
'manufacturer_composite_score',
'manufacturer_region_factor'
]
numeric_features.extend(manufacturer_features)
return numeric_features
def analyze_features(self, features, target, feature_names):
@ -234,4 +253,63 @@ class FeatureAnalysis:
except Exception as e:
logger.error(f"Error in analyze_features: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
raise
def calculate_manufacturer_features(self, manufacturer_data):
    """Compute the manufacturer-related features for one manufacturer record.

    Args:
        manufacturer_data: Mapping that may contain ``tech_level``,
            ``scale_level``, ``supply_chain_level`` and ``country``.
            ``None`` is treated as an empty record.

    Returns:
        dict with the keys ``manufacturer_tech_level``,
        ``manufacturer_scale_level``, ``manufacturer_supply_chain_level``,
        ``manufacturer_composite_score`` and ``manufacturer_region_factor``.
        Levels that are missing, ``None`` or not convertible to ``float``
        count as ``0.0``; a country that is not in the region table gets a
        factor of ``1.0``.

    This method never raises: on an unexpected error it logs the problem and
    returns the neutral defaults so the surrounding analysis can continue.
    """
    default_features = {
        'manufacturer_tech_level': 0.0,
        'manufacturer_scale_level': 0.0,
        'manufacturer_supply_chain_level': 0.0,
        'manufacturer_composite_score': 0.0,
        'manufacturer_region_factor': 1.0
    }

    def _as_level(value):
        # Coerce each field on its own: ``dict.get(key, 0)`` does not protect
        # against a key that is present with value None, and one bad field
        # must not wipe out the other levels or the region factor.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    try:
        record = manufacturer_data if manufacturer_data is not None else {}

        tech_level = _as_level(record.get('tech_level', 0))
        scale_level = _as_level(record.get('scale_level', 0))
        supply_chain_level = _as_level(record.get('supply_chain_level', 0))
        country = record.get('country', '未知')

        # Weighted composite score: technology 0.4, scale 0.3, supply chain 0.3.
        composite_score = (
            tech_level * 0.4 +
            scale_level * 0.3 +
            supply_chain_level * 0.3
        )

        # Per-country multiplier; countries not listed fall back to 1.0.
        region_factors = {
            '美国': 1.2,
            '英国': 1.15,
            '德国': 1.15,
            '法国': 1.15,
            '以色列': 1.1,
            '中国': 0.8,
            '俄罗斯': 0.85,
            '韩国': 0.9,
            '日本': 1.1
        }
        region_factor = region_factors.get(country, 1.0)

        # Log the inputs and derived values for traceability.
        logger.info("Manufacturer features calculation:")
        logger.info(f"Tech level: {tech_level}")
        logger.info(f"Scale level: {scale_level}")
        logger.info(f"Supply chain level: {supply_chain_level}")
        logger.info(f"Country: {country}")
        logger.info(f"Composite score: {composite_score}")
        logger.info(f"Region factor: {region_factor}")

        return {
            'manufacturer_tech_level': tech_level,
            'manufacturer_scale_level': scale_level,
            'manufacturer_supply_chain_level': supply_chain_level,
            'manufacturer_composite_score': composite_score,
            'manufacturer_region_factor': region_factor
        }
    except Exception as e:
        # Deliberate best-effort: return defaults instead of raising so the
        # analysis pipeline keeps running (e.g. record is not a mapping).
        logger.error(f"Error calculating manufacturer features: {str(e)}")
        return default_features

View File

@ -26,7 +26,7 @@ def import_training_data(excel_file):
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
SELECT id FROM equipments
WHERE name = %s AND type = '火箭炮'
""", (row['名称'],))
@ -37,7 +37,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (row['名称'], '火箭炮', row['制造商']))
@ -116,7 +116,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (
row['名称'],
@ -192,7 +192,7 @@ def import_training_data(excel_file):
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipment WHERE name = %s
SELECT id FROM equipments WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()

View File

@ -1,319 +0,0 @@
/*
使
1.
2.
3.
*/
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-- 插入巡飞弹技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES (
1, -- 终结者巡飞弹
0.56,
0.15,
0.20,
2.72,
160.93,
96.56,
24,
15,
'破片杀伤战斗部',
'凭自身动力起飞',
560,
150,
200
);
-- 插入火箭炮技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES (
2, -- 胜利-2火箭炮
10,
2.5,
3.34,
15000,
23
);
-- 插入成本数据(示例数据)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 1000000), -- 终结者巡飞弹成本
(2, 5000000); -- 胜利-2火箭炮成本
-- 插入更多巡飞弹变体数据用于训练
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-- 插入变体技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 终结者-A稍大型号
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
-- 终结者-B稍小型号
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
-- 终结者-C标准型号的改进版
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-- 插入变体成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(3, 1100000), -- 终结者-A成本较高
(4, 900000), -- 终结者-B成本较低
(5, 1050000); -- 终结者-C成本中等
-- 添加更多巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-- 添加它们的技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 哈比
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
-- 游荡者
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
-- 凤凰
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
-- 弹簧刀
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
-- 彩虹-4
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-- 添加成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(6, 800000), -- 哈比
(7, 500000), -- 游荡者
(8, 450000), -- 凤凰
(9, 480000), -- 弹簧刀
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('BM-21', '火箭炮', '俄罗斯'),
('SR5', '火箭炮', '中国'),
('HIMARS', '火箭炮', '美国'),
('LAR-160', '火箭炮', '以色列'),
('T-122', '火箭炮', '土耳其'),
('RM-70', '火箭炮', '捷克'),
('ASTROS II', '火箭炮', '巴西');
-- 火箭炮通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- BM-21
(1, 7.35, 2.4, 3.1, 13700, 20.4),
-- SR5
(2, 10.2, 2.8, 3.2, 28500, 70),
-- HIMARS
(3, 7.0, 2.4, 3.2, 16250, 70),
-- LAR-160
(4, 6.7, 2.5, 2.8, 15000, 45),
-- T-122
(5, 7.2, 2.5, 2.9, 18000, 40),
-- RM-70
(6, 7.5, 2.5, 3.0, 17200, 20.3),
-- ASTROS II
(7, 8.0, 2.7, 3.1, 24500, 90);
-- 火箭炮特有参数
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
firing_angle_vertical,
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire,
combat_weight_kg,
speed_kmh,
min_range_km,
mobility_type,
structure_layout,
engine_model,
engine_params,
power_hp,
travel_range_km
) VALUES
-- BM-21
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
-- SR5
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
-- HIMARS
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
-- LAR-160
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
-- T-122
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
-- RM-70
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
-- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-- 巡飞弹数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('Hero-120', '巡飞弹', '以色列'),
('Switchblade 600', '巡飞弹', '美国'),
('Warmate', '巡飞弹', '波兰'),
('CH-901', '巡飞弹', '中国'),
('HAROP', '巡飞弹', '以色列'),
('Coyote', '巡飞弹', '美国'),
('WS-43', '巡飞弹', '中国');
-- 巡飞弹通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- Hero-120
(8, 1.3, 0.23, 0.23, 12.5, 40),
-- Switchblade 600
(9, 1.3, 0.22, 0.22, 15.0, 40),
-- Warmate
(10, 1.1, 0.15, 0.15, 5.7, 15),
-- CH-901
(11, 1.2, 0.18, 0.18, 9.0, 20),
-- HAROP
(12, 2.5, 0.43, 0.43, 135, 1000),
-- Coyote
(13, 0.9, 0.12, 0.12, 5.9, 20),
-- WS-43
(14, 1.8, 0.35, 0.35, 20, 60);
-- 巡飞弹特有参数
INSERT INTO loitering_munition_params (
equipment_id,
wingspan_m,
warhead_weight_kg,
max_speed_ms,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm,
power_system,
guidance_system
) VALUES
-- Hero-120
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
-- Switchblade 600
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
-- Warmate
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
-- CH-901
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
-- HAROP
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
-- Coyote
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
-- WS-43
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-- 火箭炮
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000), -- ASTROS II
-- 巡飞弹
(8, 150000), -- Hero-120
(9, 180000), -- Switchblade 600
(10, 80000), -- Warmate
(11, 100000), -- CH-901
(12, 850000), -- HAROP
(13, 75000), -- Coyote
(14, 120000); -- WS-43
-- 创建初始数据集
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-- 关联装备到数据集
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 火箭炮训练集
(1, 1), (1, 2), (1, 3), (1, 4),
-- 巡飞弹训练集
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
-- 火箭炮验证集
(3, 5), (3, 6), (3, 7),
-- 巡飞弹验证集
(4, 13), (4, 14);

View File

@ -26,15 +26,15 @@
*/
-- 插入装备基本信息
INSERT INTO equipment (
INSERT INTO equipments (
id, -- 装备ID
name, -- 装备名称
type, -- 装备类型
manufacturer -- 制造商
) VALUES
(1, 'IAI Harop', '巡飞弹', '以色列'),
(2, 'IAI Harpy', '巡飞弹', '以色列'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列'),
(1, 'IAI Harop', '巡飞弹', '以色列 IAI'),
(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'),
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
@ -65,11 +65,11 @@ INSERT INTO equipment (
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'),
(34, 'Shahed-131', '巡飞弹', '伊朗'),
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
(36, 'Shahed-136', '巡飞弹', '伊朗'),
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
(38, 'Shahed-136C', '巡飞弹', '伊朗'),
(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'),
(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'),
(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'),
(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'),
(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'),
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params (
(24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红'),
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
(29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
(30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
@ -338,7 +338,7 @@ INSERT INTO loitering_munition_params (
(77, 2.8, 40.0, 250, 200, 120, 180, 50.0, 5500, 90, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Light
(78, 3.2, 80.0, 230, 180, 150, 200, 100.0, 6000, 100, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Heavy
(79, 1.5, 3.5, 160, 140, 60, 50, 5.0, 3500, 25, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助'), -- Taifun
(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '动机', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
(81, 1.5, 3.0, 120, 100, 60, 40, 4.0, 3000, 20, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ES
(82, 2.0, 5.0, 140, 120, 90, 60, 6.0, 4000, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ER
(83, 2.5, 8.0, 160, 140, 120, 80, 10.0, 5000, 40, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -- HERO-XL
@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(2, '巡飞弹验证集 2024', '包含20个巡飞弹型号用于验证模型性能', '巡飞弹', '验证');
-- 训练集80个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列8/10
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列
@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列
-- 验证集20个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 以色列系列2/10
(2, 9), -- Hero-900
(2, 48), -- Rotem L
@ -666,4 +666,43 @@ SET
WHEN l.max_range_km > 500 THEN 5000
WHEN l.max_range_km > 100 THEN 3000
ELSE 1500
END;
END;
-- Update the guidance accuracy of loitering munitions.
-- guidance_accuracy_m is set to the product of four CASE factors:
-- base accuracy x speed factor x warhead-weight factor x ceiling-altitude factor.
-- Each CASE takes its first matching WHEN, so the order of the branches matters.
UPDATE loitering_munition_params l
SET guidance_accuracy_m =
CASE
-- Base accuracy (by guidance system type).
-- The GPS/INS branches come first, so the later branches only apply to
-- guidance_system values that do not contain 'GPS/INS'.
WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0
WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0
WHEN guidance_system LIKE '%激光制导%' THEN 1.0
WHEN guidance_system LIKE '%红外制导%' THEN 2.0
WHEN guidance_system LIKE '%卫星制导%' THEN 2.5
ELSE 5.0
END *
-- Speed factor (the faster the munition, the slightly worse the accuracy)
CASE
WHEN max_speed_ms > 200 THEN 1.2
WHEN max_speed_ms > 150 THEN 1.1
WHEN max_speed_ms > 100 THEN 1.0
ELSE 0.9
END *
-- Weight factor (the heavier the warhead, the slightly worse the accuracy)
CASE
WHEN warhead_weight_kg > 100 THEN 1.2
WHEN warhead_weight_kg > 50 THEN 1.1
WHEN warhead_weight_kg > 20 THEN 1.0
ELSE 0.9
END *
-- Flight-altitude factor (the higher the ceiling, the slightly worse the accuracy)
CASE
WHEN ceiling_altitude_m > 5000 THEN 1.2
WHEN ceiling_altitude_m > 3000 THEN 1.1
WHEN ceiling_altitude_m > 1000 THEN 1.0
ELSE 0.9
END
-- Only rows whose equipment is of type '巡飞弹' (loitering munition) are updated.
WHERE equipment_id IN (
SELECT id FROM equipments WHERE type = '巡飞弹'
);

View File

@ -40,7 +40,7 @@ INSERT INTO manufacturers (
('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业
-- 俄罗斯供应商
('俄罗斯', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商
('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商
@ -72,7 +72,9 @@ INSERT INTO manufacturers (
('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业
-- 更新装备表中的供应商ID
UPDATE equipment e
SET manufacturer_id = m.id
FROM manufacturers m
WHERE e.manufacturer = m.name;
UPDATE equipments e
SET manufacturer_id = (
SELECT id
FROM manufacturers m
WHERE m.name = e.manufacturer
);

View File

@ -1,282 +1,112 @@
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
import xgboost as xgb
import lightgbm as lgb
from sklearn.model_selection import train_test_split
import logging
import joblib
import os
from src.feature_analysis import FeatureAnalysis
from datetime import datetime
import json
from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
from src.data_preparation import DataPreparation
from sklearn.cross_decomposition import PLSRegression
from src.data_preparation import DataPreparation, EquipmentDataset
from .logger import setup_logger
logger = setup_logger(__name__)
class CostPredictionModel(nn.Module):
    """Feed-forward regression network producing one output per input row.

    With the default arguments the network is
    ``input_size -> 128 -> 64 -> 32 -> 1`` with a ReLU after every hidden
    layer and dropout of 0.3 and 0.2 after the first and second hidden
    layers. That is the original fixed architecture, so the module indices
    and ``state_dict`` keys are unchanged.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Width of each hidden layer, in order.
        dropout_rates: Dropout probability applied after each hidden layer's
            ReLU; a rate of ``0`` adds no dropout module for that layer.

    Raises:
        ValueError: If ``hidden_sizes`` and ``dropout_rates`` differ in length.
    """

    def __init__(self, input_size, hidden_sizes=(128, 64, 32),
                 dropout_rates=(0.3, 0.2, 0.0)):
        super().__init__()
        if len(hidden_sizes) != len(dropout_rates):
            raise ValueError(
                "hidden_sizes and dropout_rates must have the same length, "
                f"got {len(hidden_sizes)} and {len(dropout_rates)}"
            )

        modules = []
        in_features = input_size
        for width, rate in zip(hidden_sizes, dropout_rates):
            modules.append(nn.Linear(in_features, width))
            modules.append(nn.ReLU())
            # Skip zero-rate dropout so the default stack keeps the exact
            # module indices (0, 3, 6, 8 for the Linear layers).
            if rate > 0:
                modules.append(nn.Dropout(rate))
            in_features = width
        modules.append(nn.Linear(in_features, 1))

        # Must stay an nn.Sequential named ``layers`` whose first module is
        # the input Linear layer, so ``layers[0].in_features`` keeps working.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
class ModelTrainer:
def __init__(self):
"""
初始化 ModelTrainer
"""
self.models = {
'xgboost': self._create_xgboost_model(),
'lightgbm': self._create_lightgbm_model(),
'gbm': self._create_gbm_model(),
'rf': self._create_rf_model(),
'pls': self._create_pls_model()
}
self.best_model = None
self.imputer = SimpleImputer(strategy='mean')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None
self.feature_scaler = None
self.target_scaler = None
self.equipment_type = None
self.feature_analyzer = FeatureAnalysis()
def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
"""
训练模型并返回评估结果
"""
def train_model(self, dataloader, epochs=100, learning_rate=0.001):
"""训练模型"""
try:
self.equipment_type = equipment_type
logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
# 获取输入特征维度
sample_features, _ = next(iter(dataloader))
input_size = sample_features.shape[1]
results = {}
best_score = -float('inf')
best_model_info = None
# 创建模型
self.model = CostPredictionModel(input_size).to(self.device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
# 首先训练 PLS 模型
logger.info("Training pls...")
pls_model = self.models['pls']
pls_model.fit(X_train, y_train)
pls_metrics = self._calculate_metrics(
pls_model,
X_train, y_train,
X_val, y_val
)
results['pls'] = pls_metrics
# 训练其他机器学习模型
for model_name in model_names:
if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
continue
# 训练循环
for epoch in range(epochs):
self.model.train()
total_loss = 0
for batch_features, batch_targets in dataloader:
# 移动数据到设备
batch_features = batch_features.to(self.device)
batch_targets = batch_targets.to(self.device)
if model_name not in self.models:
logger.warning(f"Unknown model: {model_name}")
continue
# 前向传播
outputs = self.model(batch_features)
loss = criterion(outputs, batch_targets.view(-1, 1))
logger.info(f"Training {model_name}...")
model = self.models[model_name]
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
# 训练模型
model.fit(X_train, y_train)
# 计算评估指标
metrics = self._calculate_metrics(
model,
X_train, y_train,
X_val, y_val
)
results[model_name] = metrics
# 更新最佳模型(只在机器学习模型中比较)
if metrics['validation']['r2'] > best_score:
best_score = metrics['validation']['r2']
best_model_info = {
'type': model_name,
'r2': metrics['validation']['r2'],
'mae': metrics['validation']['mae'],
'rmse': metrics['validation']['rmse']
}
self.best_model = model
# 记录训练进度
if (epoch + 1) % 10 == 0:
avg_loss = total_loss / len(dataloader)
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
# 保存最佳模型和 PLS 模型
if equipment_type and best_model_info:
self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
return {
'metrics': results,
'best_model': best_model_info
}
return True
except Exception as e:
logger.error(f"Error in model training: {str(e)}")
raise
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
"""
计算模型评估指标
"""
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
# 训练集评估
train_pred = model.predict(X_train)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
else:
y_train_orig = y_train
# 记录预测范围
logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
train_metrics = {
'r2': r2_score(y_train_orig, train_pred),
'mae': mean_absolute_error(y_train_orig, train_pred),
'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
}
# 验证集评估
if X_val is not None and y_val is not None:
val_pred = model.predict(X_val)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
else:
y_val_orig = y_val
# 记录预测范围
logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
val_metrics = {
'r2': r2_score(y_val_orig, val_pred),
'mae': mean_absolute_error(y_val_orig, val_pred),
'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
}
else:
# 使用交叉验证
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
val_metrics = {
'r2': cv_scores.mean(),
'mae': None,
'rmse': None
}
return {
'train': train_metrics,
'validation': val_metrics
}
def _create_xgboost_model(self):
"""
创建 XGBoost 模型增强正则化
"""
return xgb.XGBRegressor(
n_estimators=50, # 减少树的数量
learning_rate=0.05, # 学习率
max_depth=3, # 减小树的深
min_child_weight=3, # 增加节点权重
subsample=0.7, # 减小样本采样比例
colsample_bytree=0.7, # 减小特征采样比例
reg_alpha=0.1, # L1 正则化
reg_lambda=1, # L2 正则化
random_state=42
)
def _create_lightgbm_model(self):
"""
创建 LightGBM 模型增强正则化
"""
return lgb.LGBMRegressor(
n_estimators=50,
learning_rate=0.05,
max_depth=3,
num_leaves=7,
min_data_in_leaf=3,
min_sum_hessian_in_leaf=1e-3,
subsample=0.7,
colsample_bytree=0.7,
reg_alpha=0.1,
reg_lambda=1,
random_state=42,
verbose=-1
)
def _create_gbm_model(self):
"""
创建 GBM 模型增强正则化以减轻过拟合
"""
return GradientBoostingRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42,
subsample=0.8,
min_samples_split=3,
min_samples_leaf=2
)
def _create_rf_model(self):
"""
创建随机森林模型针对小样本数据调整参数
"""
return RandomForestRegressor(
n_estimators=100,
max_depth=3,
random_state=42,
min_samples_split=3,
min_samples_leaf=2
)
def _create_pls_model(self):
"""
创建 PLS 模型优化参数配置
"""
return PLSRegression(
n_components=2, # 减少主成分数量从5减到2
scale=True, # 保持数据标准化
max_iter=500, # 减少最大迭代次数,避免过拟合
tol=1e-6 # 降低收敛精度,避免过拟合
)
def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
"""
保存最佳模型和 PLS 模型
"""
def save_model(self, equipment_type):
"""保存模型"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 1. 保存最佳机器学习模型
model_path = f'{model_dir}/{equipment_type}_{timestamp}'
if isinstance(self.best_model, xgb.XGBRegressor):
self.best_model.save_model(f'{model_path}.json')
model_format = 'json'
else:
joblib.dump(self.best_model, f'{model_path}.joblib')
model_format = 'joblib'
# 2. 保存 PLS 模型
pls_model = self.models['pls']
pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
joblib.dump(pls_model, pls_path)
# 保存模型
model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth'
torch.save({
'model_state_dict': self.model.state_dict(),
'input_size': self.model.layers[0].in_features
}, model_path)
# 3. 保存标准化器
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
joblib.dump({
# 保存标准化器
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth'
torch.save({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
logger.info(f"Saved best model to {model_path}.{model_format}")
logger.info(f"Saved PLS model to {pls_path}")
logger.info(f"Saved scalers to {scaler_path}")
# 4. 更新数据库中的模型记录
# 更新数据库
with get_db_connection() as conn:
cursor = conn.cursor()
@ -287,420 +117,73 @@ class ModelTrainer:
WHERE equipment_type = %s
""", (equipment_type,))
# 获取 PLS 模型的评估指标
pls_metrics = self._calculate_metrics(
self.models['pls'],
X_train,
y_train,
X_val,
y_val
)
# 保存最佳机器学习模型记录
self.best_model.equipment_type = equipment_type # 设置装备类型
ml_feature_importance = self._get_feature_importance(self.best_model)
# 保存新模型记录
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}", # model_name
best_model_info['type'], # model_type
equipment_type, # equipment_type
f"{model_path}.{model_format}", # model_path
scaler_path, # scaler_path
best_model_info['r2'], # r2_score
best_model_info['mae'], # mae
best_model_info['rmse'], # rmse
json.dumps(ml_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
))
# 保存 PLS 模型记录
pls_model.equipment_type = equipment_type # 设置装备类型
pls_feature_importance = self._get_feature_importance(pls_model)
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}_pls", # model_name
'pls', # model_type
equipment_type, # equipment_type
pls_path, # model_path
scaler_path, # scaler_path
float(pls_metrics['validation']['r2']), # r2_score
float(pls_metrics['validation']['mae']), # mae
float(pls_metrics['validation']['rmse']), # rmse
json.dumps(pls_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
f"{equipment_type}_{timestamp}",
'pytorch',
equipment_type,
model_path,
scaler_path,
'system'
))
conn.commit()
except Exception as e:
logger.error(f"Error saving models: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
def load_model(self, equipment_type, model_type='ml'):
"""
加载已训练的模型
"""
try:
logger.info(f"Loading {model_type} model for {equipment_type}")
# 从数据库获取激活的模型
return True
except Exception as e:
logger.error(f"Error saving model: {str(e)}")
return False
def load_model(self, equipment_type):
"""
Load the active trained model and its scalers for an equipment type.

Looks up an active row in `trained_models` for `equipment_type`, checks
that the recorded model and scaler files exist, loads them onto the
instance and returns True. Returns False when no record is found or when
any step raises (the error is logged together with a traceback).
"""
try:
# Fetch the most recent active model from the database
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# Build the query
# NOTE(review): `model_type` is not a parameter of this definition -
# confirm where it is expected to come from.
if model_type == 'pls':
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type = 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
else:
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type != 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
# Log the query being run
logger.info(f"Executing query: {query}")
logger.info(f"Query params: {params}")
cursor.execute(query, params)
# NOTE(review): a second query is executed right after the one above,
# before anything is fetched - confirm which query is intended.
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
model_record = cursor.fetchone()
# Log the query result
if model_record:
logger.info(f"Found model record: {model_record}")
else:
logger.warning(f"No active model found for type {model_type}")
if not model_record:
return False
# Check that the files exist
logger.info(f"Checking model file: {model_record['model_path']}")
logger.info(f"Checking scaler file: {model_record['scaler_path']}")
if not os.path.exists(model_record['model_path']):
logger.error(f"Model file not found: {model_record['model_path']}")
raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
if not os.path.exists(model_record['scaler_path']):
logger.error(f"Scaler file not found: {model_record['scaler_path']}")
raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
# Load the model file
logger.info(f"Loading model from {model_record['model_path']}")
if model_type == 'pls':
self.best_model = joblib.load(model_record['model_path'])
logger.info("Loaded PLS model")
else:
if model_record['model_type'] == 'xgboost':
self.best_model = xgb.XGBRegressor()
self.best_model.load_model(model_record['model_path'])
logger.info("Loaded XGBoost model")
else:
self.best_model = joblib.load(model_record['model_path'])
logger.info(f"Loaded {model_record['model_type']} model")
# Load the model
# NOTE(review): the branches above store the estimator in self.best_model,
# while the lines below store a network in self.model - confirm which
# attribute callers rely on.
checkpoint = torch.load(model_record['model_path'])
input_size = checkpoint['input_size']
self.model = CostPredictionModel(input_size).to(self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
# Load the scalers
logger.info(f"Loading scalers from {model_record['scaler_path']}")
scalers = joblib.load(model_record['scaler_path'])
# NOTE(review): this rebinding replaces the joblib-loaded object from the
# previous line - confirm which loader matches how the scaler file was saved.
scalers = torch.load(model_record['scaler_path'])
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
logger.info("Loaded scalers successfully")
return True
except Exception as e:
# Any failure (including the FileNotFoundError raised above) is logged
# and turned into a False return value.
logger.error(f"Error loading model: {str(e)}")
logger.error(f"Detailed traceback:", exc_info=True)
return False
def get_missile_features(self):
    """Return the ordered feature names used for loitering munitions ('巡飞弹').

    The list is assembled from named groups; the resulting order is the
    concatenation of the groups in the order they appear below.
    """
    # Basic dimensions
    basic = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']
    # Performance parameters
    performance = [
        'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
        'endurance_min', 'max_payload_kg', 'min_combat_radius_km',
    ]
    # Propulsion parameters
    propulsion = ['engine_power_kw', 'engine_thrust_n']
    # Guidance and control parameters
    guidance = [
        'datalink_range_km', 'guidance_accuracy_m',
        'min_altitude_m', 'max_altitude_m',
    ]
    # Engineered ratios and scores
    engineered = [
        'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
        'guidance_system_score', 'warhead_power_score',
    ]
    return [*basic, *performance, *propulsion, *guidance, *engineered]
def get_rocket_features(self):
    """Return the ordered feature names used for rocket launchers ('火箭炮').

    The list is the concatenation of the groups below, in order.
    """
    # Basic dimensions
    basic = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']
    # Rocket-launcher specific parameters
    # NOTE(review): 'max_range_km' is repeated here (it is already in the
    # basic group), so it appears twice in the returned list - confirm
    # whether downstream consumers depend on the current length.
    launcher = [
        'firing_angle_horizontal', 'firing_angle_vertical',
        'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg',
        'rate_of_fire', 'combat_weight_kg', 'speed_kmh',
        'min_range_km', 'max_range_km', 'mobility_type', 'structure_layout',
        'engine_model', 'power_hp', 'travel_range_km',
    ]
    # Engineered ratios and scores
    engineered = [
        'fire_density', 'range_ratio', 'mobility_score',
        'combat_readiness_score', 'rocket_power_ratio',
        'platform_efficiency', 'deployment_score',
        'terrain_adaptability_score',
    ]
    return basic + launcher + engineered
def predict(self, features):
"""
Run a prediction for `features` with the model loaded on this instance.

As written, the first part validates that a model and both scalers are
present, orders/fills/scales the features, predicts with
`self.best_model` and returns the inverse-scaled result flattened with
`ravel()`. Any exception is logged and re-raised.
"""
# NOTE(review): the next line is a second bare string statement, not the docstring.
"""使用模型进行预测"""
try:
if not self.best_model:
raise ValueError("No model loaded")
if not self.feature_scaler:
raise ValueError("Feature scaler not loaded")
if not self.target_scaler:
raise ValueError("Target scaler not loaded")
logger.info("Starting prediction")
logger.info(f"Input features shape: {features.shape}")
logger.info(f"Input features: \n{features}")
# Get the matching feature list
if self.equipment_type == '巡飞弹':
feature_list = self.get_missile_features()
logger.info(f"Using missile features: {feature_list}")
# Make sure the feature order is consistent
# NOTE(review): `feature_name in features` / `features[feature_name]`
# presumably expect a column-labelled container such as a DataFrame -
# confirm against the caller.
features_ordered = np.zeros((features.shape[0], len(feature_list)))
for i, feature_name in enumerate(feature_list):
if feature_name in features:
features_ordered[:, i] = features[feature_name]
features = features_ordered
# Handle missing values
features_filled = np.array(features, dtype=float)
features_filled[np.isnan(features_filled)] = 0
features_filled = np.nan_to_num(features_filled, 0)
logger.info(f"Filled features: \n{features_filled}")
# Scale the features
X = self.feature_scaler.transform(features_filled)
logger.info(f"Transformed features shape: {X.shape}")
logger.info(f"Transformed features: \n{X}")
# Predict
y_pred_scaled = self.best_model.predict(X)
logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
logger.info(f"Scaled prediction: {y_pred_scaled}")
# Undo the target scaling
y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
logger.info(f"Final prediction shape: {y_pred.shape}")
logger.info(f"Final prediction: {y_pred}")
return y_pred.ravel()
# NOTE(review): the statements below follow a `return`, so they look
# unreachable as written; they also use self.model rather than
# self.best_model - confirm which prediction path is intended.
self.model.eval()
with torch.no_grad():
# Convert to a tensor and move it to the right device
features_tensor = torch.FloatTensor(features).to(self.device)
# Run the prediction
predictions = self.model(features_tensor)
# Move back to the CPU and convert to a numpy array
return predictions.cpu().numpy()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
def _get_feature_importance(self, model):
    """
    Map a model's per-feature importances onto feature names.

    Importances are read from `model.feature_importances_` when present,
    otherwise from the absolute values of `model.coef_`.

    Args:
        model: A fitted estimator exposing `feature_importances_` or `coef_`.

    Returns:
        dict: ``{feature_name: importance}`` sorted by importance in
        descending order, with non-positive entries dropped. An empty dict
        is returned when `model` is falsy, exposes neither attribute, or
        any step raises (the error is logged).
    """
    try:
        if not model:
            return {}

        # Resolve the feature names
        if self.equipment_type == '巡飞弹':
            # NOTE(review): this hard-coded list differs from
            # get_missile_features() (e.g. 'ceiling_altitude_m',
            # 'combat_radius_km') - confirm which one matches training.
            feature_names = [
                # Basic parameters
                'length_m', 'width_m', 'height_m', 'weight_kg',
                # Performance parameters
                'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
                'endurance_min', 'max_range_km', 'max_payload_kg', 'ceiling_altitude_m',
                'combat_radius_km',
                # Propulsion parameters
                'engine_power_kw', 'engine_thrust_n',
                # Guidance and control parameters
                'datalink_range_km', 'guidance_accuracy_m',
                'min_altitude_m', 'max_altitude_m',
                # Feature-engineering parameters
                'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
                'guidance_system_score', 'warhead_power_score'
            ]
        else:
            # Other equipment types use the shared feature lookup
            feature_analyzer = FeatureAnalysis()
            feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)

        # Resolve the raw importances
        if hasattr(model, 'feature_importances_'):
            importances = model.feature_importances_
        elif hasattr(model, 'coef_'):
            coef = model.coef_
            if len(coef.shape) > 1:
                if len(coef.shape) == 2 and coef.shape[0] > 1 and coef.shape[1] == 1:
                    # Column vector (n_features, 1): taking row 0 would keep
                    # a single coefficient, so flatten the column instead.
                    importances = np.abs(coef[:, 0])
                else:
                    # (n_targets, n_features) layout: use the first target row
                    importances = np.abs(coef[0])
            else:
                importances = np.abs(coef)
        else:
            return {}

        # Pair names with importances as plain Python floats
        importance_dict = {
            name: float(importance)
            for name, importance in zip(feature_names, importances)
        }

        # Sort by importance, highest first, and drop non-positive entries
        ranked = sorted(importance_dict.items(), key=lambda item: item[1], reverse=True)
        return {name: score for name, score in ranked if score > 0}

    except Exception as e:
        logger.error(f"Error getting feature importance: {str(e)}")
        return {}
def _calculate_confidence_interval(self, prediction, confidence=0.95):
    """
    Compute a heuristic confidence interval around a predicted cost.

    The spread is modelled as a normal distribution centred on `prediction`
    with a standard deviation of 20% of its magnitude. The lower bound is
    floored at 1000 and the upper bound is at least ``prediction * 1.2``.
    The upper bound is additionally clamped so it is never below the lower
    bound (previously a small prediction such as 500 yielded the inverted
    interval ``[1000, 696]``).

    Args:
        prediction: Predicted value.
        confidence: Confidence level passed to ``scipy.stats.norm.interval``.

    Returns:
        list: ``[lower, upper]`` with ``lower <= upper``. If the computation
        fails, a simple +/-20% interval (with the same floor and ordering
        guarantee) is returned instead.
    """
    try:
        # Use 20% of the prediction as the standard deviation
        std = abs(prediction) * 0.2

        # Local import, as in the original block
        from scipy import stats
        interval = stats.norm.interval(confidence, loc=prediction, scale=std)

        # Keep the bounds positive and plausible
        lower = max(1000, interval[0])  # floor of 1000
        upper = max(prediction * 1.2, interval[1])  # at least 20% above the prediction
        # The floor on `lower` can overtake `upper` for small predictions
        upper = max(upper, lower)

        logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
        return [lower, upper]

    except Exception as e:
        logger.error(f"Error calculating confidence interval: {str(e)}")
        # Fall back to a simple +/-20% interval
        lower = max(1000, prediction * 0.8)
        upper = max(prediction * 1.2, lower)
        return [lower, upper]
def get_model_type(self):
    """
    Return a short type label for `self.best_model`.

    Returns:
        str: 'xgboost', 'lightgbm', 'gbm' or 'rf' depending on the class of
        `self.best_model`; 'unknown' when none of them match.
    """
    current = self.best_model

    # Checked in a fixed order; the first matching class wins.
    if isinstance(current, xgb.XGBRegressor):
        return 'xgboost'
    if isinstance(current, lgb.LGBMRegressor):
        return 'lightgbm'
    if isinstance(current, GradientBoostingRegressor):
        return 'gbm'
    if isinstance(current, RandomForestRegressor):
        return 'rf'
    return 'unknown'
def _get_pls_feature_importance(self):
    """
    Derive feature importances from the PLS model stored in `self.models`.

    The absolute values of the flattened `coef_` array are paired with the
    feature names reported by `FeatureAnalysis` for `self.equipment_type`.

    Returns:
        dict: ``{feature_name: importance}`` in descending order of
        importance, without non-positive entries. Empty when there is no
        PLS model, it has no `coef_`, or an error occurs (logged with a
        traceback).
    """
    try:
        pls_model = self.models['pls']
        if not pls_model:
            return {}

        # Look up the feature names for this equipment type
        analyzer = FeatureAnalysis()
        names = analyzer.get_equipment_specific_features(self.equipment_type)

        if not hasattr(pls_model, 'coef_'):
            return {}

        # Absolute coefficients, flattened, act as the importance measure
        weights = np.abs(pls_model.coef_.ravel())

        # Plain Python floats keyed by feature name
        scores = {name: float(weight) for name, weight in zip(names, weights)}

        # Highest importance first, zero/negative entries removed
        ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        return {name: score for name, score in ranked if score > 0}

    except Exception as err:
        logger.error(f"Error getting PLS feature importance: {str(err)}")
        logger.error("Detailed traceback:", exc_info=True)
        return {}
def _preprocess_data(self, data):
    """
    Convert `data` to a float array with missing values replaced by 0.

    The feature list for the current equipment type is looked up and
    logged; it is not otherwise used here.

    Args:
        data: Array-like input convertible to a float numpy array.

    Returns:
        numpy.ndarray: Float array with NaN entries set to 0 and then
        passed through ``np.nan_to_num``.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # Pick the feature list that matches the equipment type
        is_missile = self.equipment_type == '巡飞弹'
        feature_list = self.get_missile_features() if is_missile else self.get_rocket_features()
        logger.info(f"Using features: {feature_list}")

        # Replace missing values
        cleaned = np.array(data, dtype=float)
        nan_mask = np.isnan(cleaned)
        cleaned[nan_mask] = 0
        cleaned = np.nan_to_num(cleaned, 0)
        logger.info(f"Filled features: \n{cleaned}")
        return cleaned

    except Exception as err:
        logger.error(f"Error in data preprocessing: {str(err)}")
        raise

View File

@ -1,485 +0,0 @@
-- 清空现有数据
SET FOREIGN_KEY_CHECKS=0;
TRUNCATE TABLE dataset_equipment;
TRUNCATE TABLE datasets;
TRUNCATE TABLE cost_data;
TRUNCATE TABLE loitering_munition_params;
TRUNCATE TABLE common_params;
TRUNCATE TABLE equipment;
SET FOREIGN_KEY_CHECKS=1;
-- 按系列插入装备数据确保ID连续
-- 1. HAROP/Harpy 系列 (ID: 1-3)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1, 'IAI Harop', '巡飞弹', '以色列'),
(2, 'IAI Harpy', '巡飞弹', '以色列'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
-- 2. Hero 系列 (ID: 4-9)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
(9, 'Hero-900', '巡飞弹', '以色列 UVision');
-- 3. Switchblade 系列 (ID: 10-13)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
-- 4. Warmate 系列 (ID: 14-18)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
-- 5. CH-901/902 系列 (ID: 19-23)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(19, 'CH-901', '巡飞弹', '中国航天科工'),
(20, 'CH-901A', '巡飞弹', '中国航天科工'),
(21, 'CH-901H', '巡飞弹', '中国航天科工'),
(22, 'CH-902', '巡飞弹', '中国航天科工'),
(23, 'CH-902A', '巡飞弹', '中国航天科工');
-- 6. WS-43/61 系列 (ID: 24-28)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(24, 'WS-43', '巡飞弹', '中国航天科工'),
(25, 'WS-43A', '巡飞弹', '中国航天科工'),
(26, 'WS-43B', '巡飞弹', '中国航天科工'),
(27, 'WS-61', '巡飞弹', '中国航天科工'),
(28, 'WS-61A', '巡飞弹', '中国航天科工');
-- 7. Kargu/Alpagu 系列 (ID: 29-33)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(29, 'Kargu', '巡飞弹', '土耳其 STM'),
(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
-- 8. Shahed 系列 (ID: 34-38)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(34, 'Shahed-131', '巡飞弹', '伊朗'),
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
(36, 'Shahed-136', '巡飞弹', '伊朗'),
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
(38, 'Shahed-136C', '巡飞弹', '伊朗');
-- 9. Green Dragon 系列 (ID: 39-43)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
-- 10. Phoenix Ghost 系列 (ID: 44-48)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
-- 11. ZALA Lancet 系列 (ID: 49-52)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
-- 12. Rotem L 系列 (ID: 53-56)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
-- 13. KUB-BLA 系列 (ID: 57-60)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
-- 插入通用参数
INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
(10, 0.58, 0.12, 0.12, 2.5, 10),
(11, 1.30, 0.22, 0.22, 15.0, 40),
(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
(14, 0.68, 0.12, 0.12, 2.5, 10),
(15, 1.30, 0.22, 0.22, 15.0, 40),
(16, 0.68, 0.12, 0.12, 2.5, 10),
(17, 1.30, 0.22, 0.22, 15.0, 40),
(18, 0.68, 0.12, 0.12, 2.5, 10),
(19, 1.2, 0.18, 0.18, 9.0, 20),
(20, 1.2, 0.18, 0.18, 9.3, 25),
(21, 1.2, 0.18, 0.18, 9.5, 20),
(22, 1.4, 0.22, 0.22, 15.0, 30),
(23, 1.4, 0.22, 0.22, 15.5, 35),
(24, 1.8, 0.35, 0.35, 20, 60),
(25, 1.8, 0.35, 0.35, 21, 70),
(26, 1.9, 0.35, 0.35, 22, 80),
(27, 2.2, 0.40, 0.40, 35, 100),
(28, 2.2, 0.40, 0.40, 37, 120),
(29, 0.6, 0.35, 0.35, 7.0, 10),
(30, 0.6, 0.35, 0.35, 7.2, 15),
(31, 1.0, 0.23, 0.23, 3.7, 5),
(32, 1.0, 0.23, 0.23, 3.9, 8),
(33, 0.6, 0.35, 0.35, 7.5, 15),
(34, 2.6, 0.34, 0.34, 135, 900),
(35, 2.6, 0.34, 0.34, 140, 1000),
(36, 3.5, 0.42, 0.42, 200, 2000),
(37, 3.5, 0.42, 0.42, 210, 2200),
(38, 3.5, 0.42, 0.42, 215, 2500),
(39, 1.5, 0.20, 0.20, 15, 40),
(40, 1.6, 0.20, 0.20, 16, 50),
(41, 1.5, 0.20, 0.20, 15.5, 45),
(42, 1.5, 0.20, 0.20, 15.8, 40),
(43, 1.2, 0.18, 0.18, 12, 30),
(44, 1.5, 0.25, 0.25, 14.0, 30),
(45, 1.5, 0.25, 0.25, 14.5, 35),
(46, 1.6, 0.26, 0.26, 15.0, 40),
(47, 1.5, 0.25, 0.25, 14.8, 30),
(48, 1.7, 0.27, 0.27, 16.0, 50),
(49, 1.0, 0.20, 0.20, 5.0, 40),
(50, 1.65, 0.35, 0.35, 12.0, 70),
(51, 1.65, 0.35, 0.35, 12.5, 80),
(52, 1.80, 0.40, 0.40, 15.0, 100),
(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
-- 插入特有参数
INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
endurance_min,
warhead_type,
launch_mode,
power_system,
guidance_system
) VALUES
-- HAROP/Harpy系列
(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
-- Hero系列
(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
-- Switchblade系列
(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-- Warmate系列
(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-- CH-901/902系列
(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
-- 插入成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 800000), -- IAI Harop
(2, 700000), -- IAI Harpy
(3, 350000), -- IAI Mini Harpy
(4, 70000), -- Hero-30
(5, 120000), -- Hero-70
(6, 150000), -- Hero-120
(7, 300000), -- Hero-250
(8, 400000), -- Hero-400EC
(9, 650000), -- Hero-900
(10, 60000), -- Switchblade 300
(11, 180000), -- Switchblade 600
(12, 75000), -- Switchblade 300 Block 10
(13, 200000), -- Switchblade 600 Extended Range
(14, 60000), -- Warmate 1.0
(15, 180000), -- Warmate 2.0
(16, 60000), -- Warmate-V
(17, 180000), -- Warmate-L
(18, 60000), -- Warmate 3.0
(19, 100000), -- CH-901
(20, 120000), -- CH-901A
(21, 130000), -- CH-901H
(22, 180000), -- CH-902
(23, 200000), -- CH-902A
(24, 120000), -- WS-43
(25, 150000), -- WS-43A
(26, 180000), -- WS-43B
(27, 300000), -- WS-61
(28, 350000), -- WS-61A
(29, 70000), -- Kargu
(30, 85000), -- Kargu-2
(31, 45000), -- Alpagu
(32, 55000), -- Alpagu Block-II
(33, 95000), -- Kargu Autonomous
(34, 20000), -- Shahed-131
(35, 25000), -- Shahed-131B
(36, 40000), -- Shahed-136
(37, 45000), -- Shahed-136B
(38, 50000), -- Shahed-136C
(39, 160000), -- Green Dragon
(40, 200000), -- Green Dragon Extended Range
(41, 180000), -- Green Dragon Block 2
(42, 190000), -- Green Dragon Maritime
(43, 140000), -- Green Dragon-S
(44, 150000), -- Phoenix Ghost
(45, 180000), -- Phoenix Ghost Block I
(46, 220000), -- Phoenix Ghost Block II
(47, 190000), -- Phoenix Ghost Maritime
(48, 250000), -- Phoenix Ghost-ER
(49, 80000), -- Lancet-1
(50, 150000), -- Lancet-3
(51, 180000), -- Lancet-3M
(52, 250000), -- Lancet-4
(53, 65000), -- Rotem L
(54, 85000), -- Rotem L-X
(55, 75000), -- Rotem L-M
(56, 95000), -- Rotem L-ER
(57, 95000), -- KUB-BLA
(58, 120000), -- KUB-BLA-E
(59, 110000), -- KUB-BLA-M
(60, 150000); -- KUB-BLA-ER
-- 创建数据集
INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
-- 关联装备到数据集(按照制造商和型号分配)
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 训练集约80%的数据48个型号
-- 以色列系列
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
-- 美国系列
(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
-- 中国系列
(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
-- 波兰和土耳其系列
(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
-- 俄罗斯系列
(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
-- 验证集约20%的数据12个型号
-- 混合系列
(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
-- 添加分类特征编码
INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
-- 战斗部类型编码
('warhead_type', '破片杀伤战斗部', 1),
('warhead_type', '破甲战斗部', 2),
('warhead_type', '高爆战斗部', 3),
('warhead_type', '破片杀伤/破甲双用战斗部', 4),
('warhead_type', '模块化战斗部', 5),
-- 发射方式编码
('launch_mode', '箱式发射', 1),
('launch_mode', '弹射式发射', 2),
('launch_mode', '垂直起降', 3),
('launch_mode', '单兵发射管', 4),
('launch_mode', '箱式发射/弹射式', 5),
('launch_mode', '箱式发射/空中发射', 6),
-- 动力装置编码(按复杂度递增)
('power_system', '电动机', 1),
('power_system', '活塞发动机', 2),
-- 制导系统编码(按复杂度递增)
('guidance_system', 'GPS/INS', 1),
('guidance_system', 'GPS/INS/光电', 2),
('guidance_system', 'GPS/INS/光电/数据链', 3),
('guidance_system', 'GPS/INS/光电/AI识别', 4),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
-- Populate loitering_munition_params with derived key parameters,
-- feature-engineering fields and categorical codes, using the matching
-- common_params row of each equipment.
-- Every column reference is qualified with its table alias: both joined
-- tables expose equipment_id, so an unqualified reference is ambiguous.
UPDATE loitering_munition_params l
JOIN common_params c ON l.equipment_id = c.equipment_id
SET
    -- Derived key parameters
    l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- payload taken as 20% heavier than the warhead
    l.min_combat_radius_km = c.max_range_km * 0.1, -- minimum combat radius taken as 10% of max range
    l.engine_power_kw =
        CASE
            WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
            WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
        END,
    l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- thrust taken as 30% of weight
    l.datalink_range_km = c.max_range_km * 0.8, -- datalink range taken as 80% of max range
    l.guidance_accuracy_m =
        CASE
            WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
            WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
            ELSE 3.0
        END,
    l.min_altitude_m = -- minimum operating altitude
        CASE
            -- Large loitering munitions: HAROP/Harpy series and Shahed series
            WHEN l.equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150
            -- Medium: Mini Harpy, high-end Hero series, WS-61 series
            WHEN l.equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100
            -- Small-to-medium: Hero-120, Switchblade 600 series, etc.
            WHEN l.equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80
            -- Small: Hero-30/70, Switchblade 300 series, etc.
            WHEN l.equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50
            -- Very small: Kargu/Alpagu series, Rotem series, KUB-BLA series
            WHEN l.equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30
            -- Default for every other model
            ELSE 50
        END,
    l.max_altitude_m =
        CASE
            WHEN c.max_range_km > 500 THEN 5000
            WHEN c.max_range_km > 100 THEN 3000
            ELSE 1500
        END,
    -- Feature-engineering fields
    l.length_width_ratio = c.length_m / c.width_m,
    l.weight_range_ratio = c.weight_kg / c.max_range_km,
    l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
    l.guidance_system_score =
        CASE
            WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
            WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
            WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
            WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
            ELSE 2
        END,
    l.warhead_power_score =
        CASE
            WHEN l.warhead_type = '模块化战斗部' THEN 10
            WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
            WHEN l.warhead_type = '高爆战斗部' THEN 7
            WHEN l.warhead_type = '破甲战斗部' THEN 6
            WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
            ELSE 4
        END,
    -- Categorical feature codes (0 = value not in the code table)
    l.warhead_type_code =
        CASE
            WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
            WHEN l.warhead_type = '破甲战斗部' THEN 2
            WHEN l.warhead_type = '高爆战斗部' THEN 3
            WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
            WHEN l.warhead_type = '模块化战斗部' THEN 5
            ELSE 0
        END,
    l.launch_mode_code =
        CASE
            WHEN l.launch_mode = '箱式发射' THEN 1
            WHEN l.launch_mode = '弹射式发射' THEN 2
            WHEN l.launch_mode = '垂直起降' THEN 3
            WHEN l.launch_mode = '单兵发射管' THEN 4
            WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
            WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
            ELSE 0
        END,
    l.power_system_code =
        CASE
            WHEN l.power_system = '电动机' THEN 1
            WHEN l.power_system = '活塞发动机' THEN 2
            ELSE 0
        END,
    l.guidance_system_code =
        CASE
            WHEN l.guidance_system = 'GPS/INS' THEN 1
            WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
            WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
            WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
            WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
            WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
            WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
            ELSE 0
        END;

View File

@ -29,7 +29,7 @@
*/
-- 中国系列火箭炮数据
INSERT INTO equipment (id, name, type, manufacturer) VALUES
INSERT INTO equipments (id, name, type, manufacturer) VALUES
(1001, 'PCL-191', '火箭炮', '中国兵器工业集团'),
(1002, 'PHL-03', '火箭炮', '中国兵器工业集团'),
(1003, 'AR-3', '火箭炮', '中国航天科工'),
@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1007, 'WS-2', '火箭炮', '中国航天科工'),
(1008, 'WS-3', '火箭炮', '中国航天科工'),
(1009, 'Type 63', '火箭炮', '中国兵器工业集团'),
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'),
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'),
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'),
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'),
(1014, 'TOS-1A', '火箭炮', '俄罗斯'),
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'),
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'),
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'),
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'),
(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'),
(1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'),
(1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'),
(1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'),
@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'),
(1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'),
(1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'),
(1033, 'Pinaka', '火箭炮', '印度DRDO'),
(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'),
(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'),
(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'),
(1033, 'Pinaka', '火箭炮', '印度 DRDO'),
(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'),
(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'),
(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'),
(1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'),
(1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'),
@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(4, '火箭炮验证集 2024', '包含19个火箭炮型号用于验证模型性能', '火箭炮', '验证');
-- 训练集约80%的数据77个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列7/9
(3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007),
@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
(3, 1094), (3, 1095);
-- 验证集约20%的数据19个型号
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
-- 中国系列2/9
(4, 1008), (4, 1009),

View File

@ -48,6 +48,11 @@ def index():
'url': '/api/evaluate',
'method': 'POST',
'description': '模型评估'
},
'analyze-manufacturers': {
'url': '/api/analyze-manufacturers',
'method': 'POST',
'description': '供应商分析'
}
}
})
@ -114,193 +119,149 @@ def analyze_features():
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 获取数据集信息
# 首先获取数据集的装备类型
cursor.execute("""
SELECT d.*,
e.type as equipment_type
FROM datasets d
JOIN dataset_equipment de ON d.id = de.dataset_id
JOIN equipment e ON de.equipment_id = e.id
WHERE d.id = %s
SELECT DISTINCT e.type
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
WHERE de.dataset_id = %s
LIMIT 1
""", (dataset_id,))
dataset = cursor.fetchone()
if not dataset:
logger.warning(f"Dataset {dataset_id} not found")
return jsonify({'error': '数据集不存在'}), 404
equipment_type = cursor.fetchone()['type']
logger.info(f"Equipment type: {equipment_type}")
logger.info(f"Dataset info: {dataset}")
# 创建特征分析实例
analyzer = FeatureAnalysis()
# 获取特征列表
feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
logger.info(f"Feature names: {feature_names}")
# 获取数据集中的装备数据
if dataset['equipment_type'] == '火箭炮':
# 根据装备类型选择查询
if equipment_type == '火箭炮':
cursor.execute("""
SELECT
e.name,
e.id,
cp.length_m,
cp.width_m,
cp.height_m,
cp.weight_kg,
rap.max_range_km,
rap.firing_angle_horizontal,
rap.firing_angle_vertical,
rap.rocket_length_m,
rap.rocket_diameter_mm,
rap.rocket_weight_kg,
rap.rate_of_fire,
rap.combat_weight_kg,
rap.speed_kmh,
rap.min_range_km,
rap.mobility_type,
rap.structure_layout,
rap.engine_model,
rap.power_hp,
rap.travel_range_km,
rap.fire_density,
rap.range_ratio,
rap.mobility_score,
rap.combat_readiness_score,
rap.rocket_power_ratio,
rap.platform_efficiency,
rap.deployment_score,
rap.terrain_adaptability_score,
cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
m.tech_level, m.scale_level, m.supply_chain_level, m.country,
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
cd.actual_cost,
rap.max_range_km, rap.firing_angle_horizontal, rap.firing_angle_vertical,
rap.rocket_length_m, rap.rocket_diameter_mm, rap.rocket_weight_kg,
rap.rate_of_fire, rap.combat_weight_kg, rap.speed_kmh,
rap.min_range_km, rap.power_hp, rap.travel_range_km,
rap.fire_density, rap.range_ratio, rap.mobility_score,
rap.combat_readiness_score, rap.deployment_score, rap.terrain_adaptability_score,
rap.rocket_power_ratio, rap.platform_efficiency
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
ORDER BY e.id
""", (dataset_id,))
else:
cursor.execute("""
SELECT
e.name,
e.id,
cp.length_m,
cp.width_m,
cp.height_m,
cp.weight_kg,
lmp.max_range_km,
lmp.wingspan_m,
lmp.warhead_weight_kg,
lmp.max_speed_ms,
lmp.cruise_speed_kmh,
lmp.endurance_min,
lmp.max_payload_kg,
lmp.ceiling_altitude_m,
lmp.combat_radius_km,
lmp.engine_power_kw,
lmp.engine_thrust_n,
lmp.datalink_range_km,
lmp.guidance_accuracy_m,
lmp.min_altitude_m,
lmp.max_altitude_m,
lmp.length_width_ratio,
lmp.weight_range_ratio,
lmp.speed_weight_ratio,
lmp.guidance_system_score,
lmp.warhead_power_score,
cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
m.tech_level, m.scale_level, m.supply_chain_level, m.country,
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
cd.actual_cost,
lmp.max_range_km, lmp.wingspan_m, lmp.warhead_weight_kg,
lmp.max_speed_ms, lmp.cruise_speed_kmh, lmp.endurance_min,
lmp.length_width_ratio, lmp.weight_range_ratio,
lmp.speed_weight_ratio, lmp.ceiling_altitude_m,
lmp.guidance_system_score, lmp.warhead_power_score,
lmp.engine_power_kw, lmp.engine_thrust_n,
lmp.min_altitude_m, lmp.max_altitude_m,
lmp.max_payload_kg, lmp.combat_radius_km,
lmp.datalink_range_km, lmp.guidance_accuracy_m
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
ORDER BY e.id
""", (dataset_id,))
equipment_data = cursor.fetchall()
logger.info(f"Found {len(equipment_data)} equipment records")
# 提取装备名称列表
equipment_names = [item['name'] for item in equipment_data]
# 添加数据检查日志
logger.info(f"Total records found: {len(equipment_data)}")
if equipment_data:
# 检查第一条记录的所有字段
first_record = equipment_data[0]
logger.info("First record details:")
for key, value in first_record.items():
logger.info(f"{key}: {value}")
# 检查所有记录的 max_range_km 字段
logger.info("Checking max_range_km for all records:")
for item in equipment_data:
logger.info(f"Equipment: {item['name']}")
logger.info(f" max_range_km: {item.get('max_range_km')}")
logger.info(f" type: {item['type']}")
if item['type'] == '火箭炮':
logger.info(f" rocket_artillery_params fields:")
for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']:
logger.info(f" {key}: {item.get(key)}")
# 提取特征数据和目标值
# 提取特征和目标值
analyzer = FeatureAnalysis()
feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type'])
features = []
targets = []
for item in equipment_data:
# 计算生产商特征
manufacturer_features = analyzer.calculate_manufacturer_features({
'tech_level': item['tech_level'],
'scale_level': item['scale_level'],
'supply_chain_level': item['supply_chain_level'],
'country': item['country']
})
# 获取装备特征
feature_values = []
for feature in feature_names:
value = item.get(feature)
feature_values.append(float(value) if value is not None else 0)
for name in feature_names:
if name in manufacturer_features:
value = manufacturer_features[name]
else:
value = item.get(name)
feature_values.append(float(value) if value is not None else 0.0)
features.append(feature_values)
targets.append(float(item['actual_cost']))
# 进行特征分析
result = analyzer.analyze_features(features, targets, feature_names)
# 进行特征分析
analysis_result = analyzer.analyze_features(features, targets, feature_names)
# 添加装备名称列表到结果中
result['equipment_names'] = equipment_names
# 添加装备名称列表
analysis_result['equipment_names'] = [item['name'] for item in equipment_data]
# 如果是火箭炮,添加额外的分析数据
if dataset['equipment_type'] == '火箭炮':
# 添加装备特有的分析数据
if equipment_data[0]['type'] == '火箭炮':
rocket_data = {
'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
'rate_of_fire': [float(item['rate_of_fire']) if item['rate_of_fire'] is not None else 0 for item in equipment_data],
'max_range_km': [float(item['max_range_km']) if item['max_range_km'] is not None else 0 for item in equipment_data],
'rocket_weight_kg': [float(item['rocket_weight_kg']) if item['rocket_weight_kg'] is not None else 0 for item in equipment_data],
'rocket_diameter_mm': [float(item['rocket_diameter_mm']) if item['rocket_diameter_mm'] is not None else 0 for item in equipment_data],
'rocket_length_m': [float(item['rocket_length_m']) if item['rocket_length_m'] is not None else 0 for item in equipment_data],
'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data],
'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
'speed_kmh': [float(item['speed_kmh']) if item['speed_kmh'] is not None else 0 for item in equipment_data],
'power_hp': [float(item['power_hp']) if item['power_hp'] is not None else 0 for item in equipment_data],
'travel_range_km': [float(item['travel_range_km']) if item['travel_range_km'] is not None else 0 for item in equipment_data]
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
}
result.update(rocket_data)
# 如果是巡飞弹,添加额外的分析数据
if dataset['equipment_type'] == '巡飞弹':
analysis_result.update(rocket_data)
else:
missile_data = {
'equipment_names': equipment_names,
# 特征工程参数
'length_width_ratio': [float(item['length_width_ratio']) if item.get('length_width_ratio') is not None else 0 for item in equipment_data],
'weight_range_ratio': [float(item['weight_range_ratio']) if item.get('weight_range_ratio') is not None else 0 for item in equipment_data],
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item.get('speed_weight_ratio') is not None else 0 for item in equipment_data],
'guidance_system_score': [float(item['guidance_system_score']) if item.get('guidance_system_score') is not None else 0 for item in equipment_data],
'warhead_power_score': [float(item['warhead_power_score']) if item.get('warhead_power_score') is not None else 0 for item in equipment_data],
# 动力系统参数
'engine_power_kw': [float(item['engine_power_kw']) if item.get('engine_power_kw') is not None else 0 for item in equipment_data],
'engine_thrust_n': [float(item['engine_thrust_n']) if item.get('engine_thrust_n') is not None else 0 for item in equipment_data],
# 作战参数
'min_altitude_m': [float(item['min_altitude_m']) if item.get('min_altitude_m') is not None else 0 for item in equipment_data],
'max_altitude_m': [float(item['max_altitude_m']) if item.get('max_altitude_m') is not None else 0 for item in equipment_data],
'max_range_km': [float(item['max_range_km']) if item.get('max_range_km') is not None else 0 for item in equipment_data],
'max_payload_kg': [float(item['max_payload_kg']) if item.get('max_payload_kg') is not None else 0 for item in equipment_data],
'combat_radius_km': [float(item['combat_radius_km']) if item.get('combat_radius_km') is not None else 0 for item in equipment_data],
'datalink_range_km': [float(item['datalink_range_km']) if item.get('datalink_range_km') is not None else 0 for item in equipment_data],
'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item.get('guidance_accuracy_m') is not None else 0 for item in equipment_data]
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
}
# 验证数据完整性
for key, value in missile_data.items():
logger.info(f"{key} data length: {len(value)}")
logger.info(f"{key} sample data: {value[:3]}")
if not any(value): # 检查是否所有值都为0
logger.warning(f"All values are 0 for {key}")
# 更新结果
result.update(missile_data)
analysis_result.update(missile_data)
return jsonify(result)
return jsonify(analysis_result)
except Exception as e:
logger.error(f"Error analyzing features: {str(e)}")
@ -332,8 +293,8 @@ def train_model():
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@ -343,8 +304,8 @@ def train_model():
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@ -360,8 +321,8 @@ def train_model():
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@ -371,8 +332,8 @@ def train_model():
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
@ -500,7 +461,7 @@ def get_equipment_data():
# 获取所有装备数据(使用equipment_id替代id)
cursor.execute("""
SELECT e.id as equipment_id, e.name, e.type,
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
cd.actual_cost, cd.predicted_cost,
CASE
@ -546,7 +507,7 @@ def get_equipment_data():
WHERE equipment_id = e.id
)
END as specific_params
FROM equipment e
FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
ORDER BY e.id
@ -576,7 +537,7 @@ def delete_equipment(id):
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
db.commit()
cursor.close()
@ -737,7 +698,7 @@ def update_equipment(id):
# 更新装备基本信息
cursor.execute("""
UPDATE equipment
UPDATE equipments
SET name = %s, manufacturer = %s
WHERE id = %s
""", (data['name'], data['manufacturer'], equipment_id))
@ -845,7 +806,7 @@ def get_equipment_details(id):
# 获取装备基本信息类型
cursor.execute("""
SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost
FROM equipment e
FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.id = %s
@ -854,7 +815,7 @@ def get_equipment_details(id):
result = cursor.fetchone()
if not result:
logger.warning(f"Equipment with ID {id} not found")
return jsonify({'error': '装备不'}), 404
return jsonify({'error': '装备不'}), 404
logger.info(f"Equipment type: {result['type']}")
logger.info(f"Found equipment details: {result['name']}")
@ -901,8 +862,8 @@ def get_datasets():
COUNT(de.equipment_id) as equipment_count,
GROUP_CONCAT(e.name) as equipment_names
FROM datasets d
LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
LEFT JOIN equipment e ON de.equipment_id = e.id
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
LEFT JOIN equipments e ON de.equipment_id = e.id
GROUP BY d.id
""")
datasets = cursor.fetchall()
@ -932,7 +893,7 @@ def get_dataset(id):
SELECT d.*,
COUNT(de.equipment_id) as equipment_count
FROM datasets d
LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
WHERE d.id = %s
GROUP BY d.id
""", (id,))
@ -945,14 +906,14 @@ def get_dataset(id):
cursor.execute("""
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
""", (id,))
equipment = cursor.fetchall()
# 算统计信
# 计算统计信息
if equipment:
total_cost = sum(item['actual_cost'] or 0 for item in equipment)
avg_cost = total_cost / len(equipment)
@ -989,7 +950,7 @@ def create_dataset():
# 直接从 equipments 表查询,不需要 JOIN
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
cursor.execute(f"""
SELECT DISTINCT id FROM equipment
SELECT DISTINCT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
@ -1015,7 +976,7 @@ def create_dataset():
if 'equipment_ids' in data and data['equipment_ids']:
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
INSERT INTO dataset_equipment (dataset_id, equipment_id)
INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
""", values)
logger.info(f"Added {len(values)} equipment associations")
@ -1041,7 +1002,7 @@ def update_dataset(id):
if 'equipment_ids' in data:
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
cursor.execute(f"""
SELECT id FROM equipment
SELECT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
@ -1064,13 +1025,13 @@ def update_dataset(id):
# 3. 更新装备关联
if 'equipment_ids' in data:
# 先删除旧的关联
cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
# 添加新的关联
if valid_ids: # 确保有有效的ID才执行插入
values = [(id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
INSERT INTO dataset_equipment (dataset_id, equipment_id)
INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
""", values)
logger.info(f"Updated {len(values)} equipment associations")
@ -1092,7 +1053,7 @@ def delete_dataset(id):
cursor = conn.cursor()
# 删除装备关联
cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
# 删除数据集
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
@ -1139,7 +1100,7 @@ def get_models():
models = cursor.fetchall()
# 确保数值型字段是 float
# 确保数值型字段是 float
for model in models:
if model['r2_score'] is not None:
model['r2_score'] = float(model['r2_score'])
@ -1161,7 +1122,7 @@ def get_models():
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
def activate_model(id):
"""
激活定的模型
激活指定的模型
"""
try:
with get_db_connection() as conn:
@ -1250,4 +1211,104 @@ def predict_all():
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
return jsonify({'error': str(e)}), 500
@api_bp.route('/analyze-manufacturers', methods=['POST'])
def analyze_manufacturers():
    """Analyze the manufacturers behind the equipment of one dataset.

    Expects a JSON object body containing ``dataset_id``. Every distinct
    manufacturer linked (through ``equipments.manufacturer_id``) to equipment
    in that dataset is loaded and summarised.

    Returns:
        A JSON response with these keys, all lists aligned by manufacturer
        unless noted otherwise:

        * ``manufacturer_names``
        * ``manufacturer_tech_levels``
        * ``manufacturer_scale_levels``
        * ``manufacturer_supply_chain_levels``
        * ``manufacturer_composite_scores`` -- weighted score
          ``0.4 * tech + 0.3 * scale + 0.3 * supply_chain``
        * ``region_distribution`` -- ``[{'name': country, 'value': count}]``,
          one entry per country
        * ``manufacturer_scores`` -- radar-chart entries
          ``{'name': ..., 'value': [tech, scale, supply_chain,
          region_factor, composite_score]}``

        Error responses are ``{'error': message}`` with status 400 when no
        ``dataset_id`` is supplied, 404 when the dataset has no manufacturer
        rows, and 500 on any unexpected failure.
    """
    # Per-country coefficient used as the fourth radar-chart axis; countries
    # not listed here fall back to 1.0. Built once instead of on every row.
    region_factors = {
        '美国': 1.2, '英国': 1.15, '德国': 1.15,
        '法国': 1.15, '以色列': 1.1, '中国': 0.8,
        '俄罗斯': 0.85, '韩国': 0.9, '日本': 1.1
    }

    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so the request is answered with the 400 below rather than
        # being turned into a 500 by the blanket handler at the bottom.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        dataset_id = data.get('dataset_id')

        logger.info(f"Starting manufacturer analysis for dataset {dataset_id}")

        if not dataset_id:
            logger.warning("No dataset_id provided")
            return jsonify({'error': '请选择数据集'}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Manufacturers of the equipment that belongs to the dataset.
            cursor.execute("""
                SELECT DISTINCT m.*, e.type as equipment_type
                FROM manufacturers m
                JOIN equipments e ON e.manufacturer_id = m.id
                JOIN dataset_equipments de ON e.id = de.equipment_id
                WHERE de.dataset_id = %s
            """, (dataset_id,))

            manufacturers = cursor.fetchall()

            if not manufacturers:
                return jsonify({'error': '数据集中没有生产商数据'}), 404

            # Accumulators for the analysis payload.
            manufacturer_names = []
            tech_levels = []
            scale_levels = []
            supply_chain_levels = []
            composite_scores = []
            region_count = {}
            manufacturer_scores = []

            for manu in manufacturers:
                manufacturer_names.append(manu['name'])
                tech_levels.append(manu['tech_level'])
                scale_levels.append(manu['scale_level'])
                supply_chain_levels.append(manu['supply_chain_level'])

                # Weighted composite score.
                composite_score = (
                    manu['tech_level'] * 0.4 +
                    manu['scale_level'] * 0.3 +
                    manu['supply_chain_level'] * 0.3
                )
                composite_scores.append(composite_score)

                # Tally manufacturers per country.
                region_count[manu['country']] = region_count.get(manu['country'], 0) + 1

                region_factor = region_factors.get(manu['country'], 1.0)

                # Radar-chart entry for this manufacturer.
                manufacturer_scores.append({
                    'name': manu['name'],
                    'value': [
                        manu['tech_level'],
                        manu['scale_level'],
                        manu['supply_chain_level'],
                        region_factor,
                        composite_score
                    ]
                })

            # Country distribution in name/value form.
            region_distribution = [
                {'name': country, 'value': count}
                for country, count in region_count.items()
            ]

            result = {
                'manufacturer_names': manufacturer_names,
                'manufacturer_tech_levels': tech_levels,
                'manufacturer_scale_levels': scale_levels,
                'manufacturer_supply_chain_levels': supply_chain_levels,
                'manufacturer_composite_scores': composite_scores,
                'region_distribution': region_distribution,
                'manufacturer_scores': manufacturer_scores
            }

            return jsonify(result)

    except Exception as e:
        logger.error(f"Error analyzing manufacturers: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500

View File

@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci;
USE equipment_cost_db;
-- 装备基本信息表
CREATE TABLE equipment (
CREATE TABLE equipments (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- 名称
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
manufacturer VARCHAR(100), -- 制造商
manufacturer_id INT, -- 制造商ID
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
@ -26,8 +27,7 @@ CREATE TABLE common_params (
width_m FLOAT, -- 宽度(m)
height_m FLOAT, -- 高度(m)
weight_kg FLOAT, -- 重量(kg)
max_range_km FLOAT, -- 最大射程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 火箭炮特有参数表
@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params (
deployment_score INT, -- 部署评分(1-10)
terrain_adaptability_score INT, -- 地形适应性评分(1-10)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 巡飞弹特有参数表
@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params (
power_system_code INT, -- 动力装置编码
guidance_system_code INT, -- 制导系统编码
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 分类特征编码表
@ -122,7 +122,7 @@ CREATE TABLE cost_data (
actual_cost DECIMAL(15,2), -- 实际成本(元)
predicted_cost DECIMAL(15,2), -- 预测成本(元)
prediction_date TIMESTAMP, -- 预测日期
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 特殊参数表
@ -133,12 +133,12 @@ CREATE TABLE custom_params (
param_value VARCHAR(255), -- 参数值
param_unit VARCHAR(50), -- 参数单位
description TEXT, -- 参数说明
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_equipment_type ON equipments(type);
CREATE INDEX idx_equipment_name ON equipments(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- 数据集表
@ -153,12 +153,12 @@ CREATE TABLE datasets (
);
-- 数据集-装备关联表
CREATE TABLE dataset_equipment (
CREATE TABLE dataset_equipments (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
-- 训练模型表
@ -175,10 +175,34 @@ CREATE TABLE trained_models (
feature_importance JSON, -- 特征重要性
training_data_size INT, -- 训练数据量
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
created_by VARCHAR(50) -- 创建者
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);
CREATE INDEX idx_model_active ON trained_models(is_active);
-- Manufacturers table: one row per equipment manufacturer, referenced by
-- equipments.manufacturer_id.
CREATE TABLE manufacturers (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL, -- manufacturer name
country VARCHAR(50) NOT NULL, -- country the manufacturer belongs to
tech_level INT NOT NULL, -- technology level score (1-10)
scale_level INT NOT NULL, -- scale score (1-10)
supply_chain_level INT NOT NULL, -- supply-chain maturity score (1-10)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_name (name)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Add the manufacturer foreign key to equipments
ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id);
-- Add indexes
CREATE INDEX idx_manufacturer_country ON manufacturers(country);
CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level);
CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level);
CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id);