将 tensor 改为 torch,并更新依赖,增加了生产商的数据和特征分析。
This commit is contained in:
parent
9421512677
commit
dba9f2fcc9
25
.env.example
25
.env.example
@ -1,25 +0,0 @@
|
||||
# 数据库配置
|
||||
MYSQL_HOST=localhost
|
||||
MYSQL_USER=root
|
||||
MYSQL_PASSWORD=your_password_here
|
||||
MYSQL_DATABASE=equipment_cost_db
|
||||
|
||||
# 服务配置
|
||||
PORT=5001
|
||||
DEBUG=False
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL=INFO
|
||||
LOG_DIR=logs
|
||||
|
||||
# 模型配置
|
||||
MODEL_DIR=models
|
||||
DATA_DIR=data
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY=your_secret_key_here
|
||||
ALLOWED_HOSTS=localhost,127.0.0.1
|
||||
|
||||
# 其他配置
|
||||
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
|
||||
ALLOWED_FILE_TYPES=.xlsx,.xls
|
||||
39
.gitignore
vendored
39
.gitignore
vendored
@ -1,4 +1,42 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual Environment
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
node_modules
|
||||
/dist
|
||||
/models
|
||||
@ -9,6 +47,7 @@ node_modules
|
||||
# local env files
|
||||
.env.local
|
||||
.env.*.local
|
||||
.venv
|
||||
|
||||
# Log files
|
||||
npm-debug.log*
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.11.8
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
# MIT License
|
||||
|
||||
Copyright (c) 2024 Your Name or Your Organization
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
54
README.md
54
README.md
@ -21,3 +21,57 @@ collation-server=utf8mb4_unicode_ci
|
||||
[client]
|
||||
default-character-set=utf8mb4
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。
|
||||
|
||||
### 使用脚本自动配置(推荐)
|
||||
|
||||
Unix/macOS:
|
||||
|
||||
```bash
|
||||
chmod +x scripts/setup_env.sh
|
||||
./scripts/setup_env.sh
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
||||
.\scripts\setup_env.ps1
|
||||
```
|
||||
|
||||
### 手动配置
|
||||
|
||||
1. 安装 pyenv
|
||||
2. 安装 Python 3.11.8:
|
||||
|
||||
```bash
|
||||
pyenv install 3.11.8
|
||||
```
|
||||
|
||||
3. 设置本地 Python 版本:
|
||||
|
||||
```bash
|
||||
pyenv local 3.11.8
|
||||
```
|
||||
|
||||
4. 创建虚拟环境:
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
```
|
||||
|
||||
5. 激活虚拟环境:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # Unix
|
||||
.venv\Scripts\activate # Windows
|
||||
```
|
||||
|
||||
6. 安装依赖:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
129
config.py
129
config.py
@ -1,32 +1,103 @@
|
||||
import os
|
||||
import secrets
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
|
||||
class Config:
|
||||
"""配置类"""
|
||||
# 数据库配置
|
||||
MYSQL_HOST = 'localhost'
|
||||
MYSQL_USER = 'root'
|
||||
MYSQL_PASSWORD = '123456'
|
||||
MYSQL_DB = 'equipment_cost_db'
|
||||
|
||||
# Flask配置
|
||||
FLASK_HOST = '0.0.0.0'
|
||||
FLASK_PORT = 5001
|
||||
FLASK_DEBUG = True
|
||||
|
||||
# 目录配置
|
||||
MODEL_DIR = 'models'
|
||||
DATA_DIR = 'data'
|
||||
LOG_DIR = 'logs'
|
||||
UPLOAD_DIR = 'uploads'
|
||||
TEMPLATE_DIR = 'templates'
|
||||
|
||||
# 文件上传配置
|
||||
ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
|
||||
|
||||
# API配置
|
||||
API_VERSION = 'v1'
|
||||
API_PREFIX = f'/api/{API_VERSION}'
|
||||
|
||||
# 日志配置
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOG_LEVEL = 'INFO'
|
||||
LOG_FILE = os.path.join(LOG_DIR, 'app.log')
|
||||
LOG_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
LOG_BACKUP_COUNT = 5
|
||||
|
||||
# PyTorch配置
|
||||
DEVICE = 'cpu' # 或 'cuda' 如果要使用 GPU
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 0.001
|
||||
NUM_EPOCHS = 100
|
||||
|
||||
# 模型训练配置
|
||||
TRAIN_TEST_SPLIT = 0.2
|
||||
RANDOM_SEED = 42
|
||||
EARLY_STOPPING_PATIENCE = 10
|
||||
MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')
|
||||
|
||||
# 缓存配置
|
||||
CACHE_TYPE = 'simple'
|
||||
CACHE_DEFAULT_TIMEOUT = 300
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY = 'your-secret-key-here'
|
||||
JWT_SECRET_KEY = 'your-jwt-secret-key-here'
|
||||
JWT_ACCESS_TOKEN_EXPIRES = 3600 # 1小时
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']
|
||||
|
||||
# 数据验证配置
|
||||
MAX_EQUIPMENT_NAME_LENGTH = 100
|
||||
MAX_MANUFACTURER_NAME_LENGTH = 100
|
||||
|
||||
@classmethod
|
||||
def init_app(cls, app):
|
||||
"""初始化应用配置"""
|
||||
# 创建必要的目录
|
||||
for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
|
||||
cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
# 配置日志
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
formatter = logging.Formatter(cls.LOG_FORMAT)
|
||||
file_handler = RotatingFileHandler(
|
||||
cls.LOG_FILE,
|
||||
maxBytes=cls.LOG_MAX_SIZE,
|
||||
backupCount=cls.LOG_BACKUP_COUNT
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(cls.LOG_LEVEL)
|
||||
|
||||
app.logger.addHandler(file_handler)
|
||||
app.logger.setLevel(cls.LOG_LEVEL)
|
||||
|
||||
# 配置上传目录
|
||||
app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
|
||||
app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH
|
||||
|
||||
# 配置跨域
|
||||
from flask_cors import CORS
|
||||
CORS(app, resources={
|
||||
r"/api/*": {"origins": cls.CORS_ORIGINS}
|
||||
})
|
||||
|
||||
return app
|
||||
|
||||
# 安全密钥配置(自动生成随机密钥)
|
||||
SECRET_KEY = secrets.token_hex(16)
|
||||
|
||||
# 环境配置
|
||||
DEBUG = False
|
||||
ENV = 'production'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
|
||||
|
||||
# API配置
|
||||
API_VERSION = 'v1'
|
||||
API_PREFIX = f'/api/{API_VERSION}'
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS = [
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
]
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL = 'DEBUG'
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
|
||||
# 创建配置实例
|
||||
config = Config()
|
||||
@ -103,7 +103,31 @@
|
||||
<div class="chart-container">
|
||||
<div ref="engineChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 制导性能分析 -->
|
||||
<h3>制导性能分析</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="guidanceChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- 生产商分析 -->
|
||||
<h3>生产商分析</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="manufacturerChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 生产商地区分布 -->
|
||||
<h3>生产商地区分布</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="regionChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
|
||||
<!-- 生产商综合评分 -->
|
||||
<h3>生产商综合评分</h3>
|
||||
<div class="chart-container">
|
||||
<div ref="scoreChartRef" style="width: 100%; height: 600px"></div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</div>
|
||||
@ -131,6 +155,10 @@ const newFeatureChartRef = ref(null)
|
||||
const engineChartRef = ref(null)
|
||||
const fireChartRef = ref(null)
|
||||
const mobilityChartRef = ref(null)
|
||||
const manufacturerChartRef = ref(null)
|
||||
const regionChartRef = ref(null)
|
||||
const scoreChartRef = ref(null)
|
||||
const guidanceChartRef = ref(null)
|
||||
|
||||
// 图表实例引用
|
||||
const importanceChart = ref(null)
|
||||
@ -139,6 +167,10 @@ const newFeatureChart = ref(null)
|
||||
const engineChart = ref(null)
|
||||
const fireChart = ref(null)
|
||||
const mobilityChart = ref(null)
|
||||
const manufacturerChart = ref(null)
|
||||
const regionChart = ref(null)
|
||||
const scoreChart = ref(null)
|
||||
const guidanceChart = ref(null)
|
||||
|
||||
// 监听分析结果变化
|
||||
watch(() => analysisResult.value, async (newResult) => {
|
||||
@ -236,69 +268,28 @@ const startAnalysis = async () => {
|
||||
|
||||
analyzing.value = true
|
||||
try {
|
||||
// 打印请求参数
|
||||
console.log('Analysis request params:', {
|
||||
dataset_id: analysisForm.value.dataset_id,
|
||||
equipment_type: analysisForm.value.equipment_type
|
||||
})
|
||||
|
||||
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
|
||||
// 调用特征分析接口
|
||||
const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, {
|
||||
dataset_id: analysisForm.value.dataset_id
|
||||
})
|
||||
|
||||
// 打印原始响应数据
|
||||
console.log('Raw API response:', response)
|
||||
console.log('Response data type:', typeof response.data)
|
||||
console.log('Response data:', response.data)
|
||||
|
||||
// 检查响应数据的结构
|
||||
if (!response.data) {
|
||||
throw new Error('API返回的数据为空')
|
||||
}
|
||||
|
||||
// 确保数据正确赋值
|
||||
analysisResult.value = response.data
|
||||
|
||||
// 验证数赋值是否成功
|
||||
console.log('Analysis result after assignment:', {
|
||||
value: analysisResult.value,
|
||||
important_features: analysisResult.value?.important_features,
|
||||
correlation_analysis: analysisResult.value?.correlation_analysis,
|
||||
equipment_names: analysisResult.value?.equipment_names,
|
||||
length_width_ratio: analysisResult.value?.length_width_ratio
|
||||
// 调用生产商分析接口
|
||||
const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, {
|
||||
dataset_id: analysisForm.value.dataset_id
|
||||
})
|
||||
|
||||
// 如果是巡飞弹类型,检查特定数据
|
||||
if (analysisForm.value.equipment_type === '巡飞弹') {
|
||||
const missileData = {
|
||||
equipment_names: analysisResult.value?.equipment_names || [],
|
||||
length_width_ratio: analysisResult.value?.length_width_ratio || [],
|
||||
engine_power_kw: analysisResult.value?.engine_power_kw || [],
|
||||
guidance_system_score: analysisResult.value?.guidance_system_score || [],
|
||||
warhead_power_score: analysisResult.value?.warhead_power_score || []
|
||||
}
|
||||
|
||||
console.log('Missile specific data:', missileData)
|
||||
|
||||
// 验证数据完整性
|
||||
const missingFields = Object.entries(missileData)
|
||||
.filter(([key, value]) => !Array.isArray(value) || value.length === 0)
|
||||
.map(([key]) => key)
|
||||
|
||||
if (missingFields.length > 0) {
|
||||
console.warn('Missing or empty missile data fields:', missingFields)
|
||||
ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
|
||||
}
|
||||
|
||||
// 合并两个接口的结果
|
||||
analysisResult.value = {
|
||||
...featureResponse.data,
|
||||
...manufacturerResponse.data
|
||||
}
|
||||
|
||||
|
||||
// 验证数据
|
||||
console.log('Combined analysis result:', analysisResult.value)
|
||||
|
||||
} catch (error) {
|
||||
console.error('Analysis error:', error)
|
||||
console.error('Error details:', {
|
||||
message: error.message,
|
||||
response: error.response?.data,
|
||||
status: error.response?.status
|
||||
})
|
||||
ElMessage.error(error.message || '特征析失败')
|
||||
ElMessage.error(error.message || '分析失败')
|
||||
} finally {
|
||||
analyzing.value = false
|
||||
}
|
||||
@ -329,6 +320,18 @@ const createResizeHandler = () => {
|
||||
if (mobilityChart.value && !mobilityChart.value.isDisposed()) {
|
||||
mobilityChart.value.resize()
|
||||
}
|
||||
if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) {
|
||||
manufacturerChart.value.resize()
|
||||
}
|
||||
if (regionChart.value && !regionChart.value.isDisposed()) {
|
||||
regionChart.value.resize()
|
||||
}
|
||||
if (scoreChart.value && !scoreChart.value.isDisposed()) {
|
||||
scoreChart.value.resize()
|
||||
}
|
||||
if (guidanceChart.value && !guidanceChart.value.isDisposed()) {
|
||||
guidanceChart.value.resize()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in resize handler:', error)
|
||||
}
|
||||
@ -360,7 +363,7 @@ onUnmounted(() => {
|
||||
|
||||
// 销毁所有图表实例
|
||||
[importanceChart, correlationChart, newFeatureChart, engineChart,
|
||||
fireChart, mobilityChart].forEach(chart => {
|
||||
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
|
||||
if (chart.value && !chart.value.isDisposed()) {
|
||||
try {
|
||||
chart.value.dispose()
|
||||
@ -384,7 +387,7 @@ const renderCharts = () => {
|
||||
try {
|
||||
// 先销毁所有现有的图表实例
|
||||
[importanceChart, correlationChart, newFeatureChart, engineChart,
|
||||
fireChart, mobilityChart].forEach(chart => {
|
||||
fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => {
|
||||
if (chart.value && !chart.value.isDisposed()) {
|
||||
chart.value.dispose()
|
||||
chart.value = null
|
||||
@ -899,6 +902,156 @@ const renderCharts = () => {
|
||||
mobilityChart.value.setOption(mobilityOption, { notMerge: true })
|
||||
}
|
||||
|
||||
// 渲染生产商分析图表
|
||||
if (manufacturerChartRef.value) {
|
||||
manufacturerChart.value = echarts.init(manufacturerChartRef.value)
|
||||
const manufacturerOption = {
|
||||
title: { text: '生产商特征影响分析' },
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: { type: 'shadow' }
|
||||
},
|
||||
legend: {
|
||||
data: ['技术水平', '规模水平', '供应链水平', '综合得分']
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: analysisResult.value.manufacturer_names || []
|
||||
},
|
||||
yAxis: {
|
||||
type: 'value',
|
||||
name: '评分',
|
||||
min: 0,
|
||||
max: 10
|
||||
},
|
||||
series: [
|
||||
{
|
||||
name: '技术水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_tech_levels || []
|
||||
},
|
||||
{
|
||||
name: '规模水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_scale_levels || []
|
||||
},
|
||||
{
|
||||
name: '供应链水平',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.manufacturer_supply_chain_levels || []
|
||||
},
|
||||
{
|
||||
name: '综合得分',
|
||||
type: 'line',
|
||||
data: analysisResult.value.manufacturer_composite_scores || []
|
||||
}
|
||||
]
|
||||
}
|
||||
manufacturerChart.value.setOption(manufacturerOption)
|
||||
}
|
||||
|
||||
// 渲染地区分布图表
|
||||
if (regionChartRef.value) {
|
||||
regionChart.value = echarts.init(regionChartRef.value)
|
||||
const regionOption = {
|
||||
title: { text: '生产商地区分布' },
|
||||
tooltip: {
|
||||
trigger: 'item',
|
||||
formatter: '{b}: {c} ({d}%)'
|
||||
},
|
||||
series: [
|
||||
{
|
||||
type: 'pie',
|
||||
radius: '65%',
|
||||
data: analysisResult.value.region_distribution || [],
|
||||
emphasis: {
|
||||
itemStyle: {
|
||||
shadowBlur: 10,
|
||||
shadowOffsetX: 0,
|
||||
shadowColor: 'rgba(0, 0, 0, 0.5)'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
regionChart.value.setOption(regionOption)
|
||||
}
|
||||
|
||||
// 渲染综合评分图表
|
||||
if (scoreChartRef.value) {
|
||||
scoreChart.value = echarts.init(scoreChartRef.value)
|
||||
const scoreOption = {
|
||||
title: { text: '生产商综合评分雷达图' },
|
||||
tooltip: {},
|
||||
radar: {
|
||||
indicator: [
|
||||
{ name: '技术水平', max: 10 },
|
||||
{ name: '规模水平', max: 10 },
|
||||
{ name: '供应链水平', max: 10 },
|
||||
{ name: '区域系数', max: 1.5 },
|
||||
{ name: '综合得分', max: 10 }
|
||||
]
|
||||
},
|
||||
series: [
|
||||
{
|
||||
type: 'radar',
|
||||
data: analysisResult.value.manufacturer_scores || []
|
||||
}
|
||||
]
|
||||
}
|
||||
scoreChart.value.setOption(scoreOption)
|
||||
}
|
||||
|
||||
// 渲染制导性能分析图表
|
||||
if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') {
|
||||
guidanceChart.value = echarts.init(guidanceChartRef.value)
|
||||
const guidanceOption = {
|
||||
title: { text: '制导性能分析' },
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: { type: 'cross' }
|
||||
},
|
||||
legend: {
|
||||
data: ['制导精度(m)', '数据链距离(km)', '制导系统评分']
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: analysisResult.value.equipment_names || []
|
||||
},
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value',
|
||||
name: '制导精度(m)',
|
||||
position: 'left'
|
||||
},
|
||||
{
|
||||
type: 'value',
|
||||
name: '距离(km)',
|
||||
position: 'right'
|
||||
}
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '制导精度(m)',
|
||||
type: 'bar',
|
||||
data: analysisResult.value.guidance_accuracy_m || []
|
||||
},
|
||||
{
|
||||
name: '数据链距离(km)',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: analysisResult.value.datalink_range_km || []
|
||||
},
|
||||
{
|
||||
name: '制导系统评分',
|
||||
type: 'line',
|
||||
data: analysisResult.value.guidance_system_score || []
|
||||
}
|
||||
]
|
||||
}
|
||||
guidanceChart.value.setOption(guidanceOption)
|
||||
}
|
||||
|
||||
console.log('Charts rendered successfully')
|
||||
} catch (error) {
|
||||
console.error('Error in chart rendering:', error)
|
||||
|
||||
59
pyproject.toml
Normal file
59
pyproject.toml
Normal file
@ -0,0 +1,59 @@
|
||||
[project]
|
||||
name = "cost-prediction"
|
||||
version = "0.1.0"
|
||||
description = "装备成本预测系统"
|
||||
requires-python = ">=3.9,<3.12"
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
|
||||
dependencies = [
|
||||
# Web框架
|
||||
"flask>=3.1.0",
|
||||
"flask-cors>=5.0.0",
|
||||
|
||||
# 数据库
|
||||
"sqlalchemy>=2.0.36",
|
||||
"pymysql>=1.1.1",
|
||||
"cryptography>=43.0.0",
|
||||
"mysql-connector-python>=8.0.0",
|
||||
|
||||
# 数据处理
|
||||
"numpy>=1.26.0,<2.0.0",
|
||||
"pandas>=2.2.0",
|
||||
|
||||
# 机器学习
|
||||
"scikit-learn>=1.5.2",
|
||||
"torch==2.5.1",
|
||||
"torchvision==0.20.1",
|
||||
"torchaudio==2.5.1",
|
||||
|
||||
# 工具
|
||||
"openpyxl>=3.1.5", # Excel支持
|
||||
"python-dotenv>=1.0.0", # 环境变量
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
# 测试工具
|
||||
"pytest>=7.0",
|
||||
"black>=22.0", # 代码格式化
|
||||
"mypy>=1.0", # 类型检查
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py39", "py310", "py311"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
|
||||
@ -3,12 +3,11 @@ flask-cors>=5.0.0
|
||||
sqlalchemy>=2.0.36
|
||||
pymysql>=1.1.1
|
||||
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
|
||||
numpy>=2.0.2
|
||||
pandas>=2.2.3
|
||||
|
||||
urllib3>=2.2.3
|
||||
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
|
||||
xlrd>=2.0.1 # 用于读取 .xls 文件
|
||||
mysql-connector-python>=8.0.0 # 添加这行
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
|
||||
scikit-learn>=1.5.2
|
||||
tensorflow>=2.18.0
|
||||
|
||||
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
|
||||
python-dotenv>=1.0.0 # 环境变量
|
||||
40
run.py
40
run.py
@ -1,13 +1,33 @@
|
||||
from src.app import create_app
|
||||
import logging
|
||||
from src import create_app
|
||||
from src.logger import setup_logger
|
||||
from config import config
|
||||
import os
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
def main():
|
||||
try:
|
||||
# 创建必要的目录
|
||||
os.makedirs(config.MODEL_DIR, exist_ok=True)
|
||||
os.makedirs(config.LOG_DIR, exist_ok=True)
|
||||
os.makedirs(config.DATA_DIR, exist_ok=True)
|
||||
|
||||
# 创建并运行应用
|
||||
app = create_app()
|
||||
|
||||
logger.info(f"Starting server in {'debug' if config.FLASK_DEBUG else 'production'} mode")
|
||||
logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
|
||||
|
||||
app.run(
|
||||
host=config.FLASK_HOST,
|
||||
port=config.FLASK_PORT,
|
||||
debug=config.FLASK_DEBUG
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting application: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('=== Server Starting ===')
|
||||
logging.info('Initializing directories...')
|
||||
|
||||
app.run(host='0.0.0.0', port=5001, debug=True)
|
||||
main()
|
||||
121
scripts/setup_env.ps1
Normal file
121
scripts/setup_env.ps1
Normal file
@ -0,0 +1,121 @@
|
||||
# 设置错误操作首选项
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# 检查管理员权限
|
||||
$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
|
||||
if (-not $isAdmin) {
|
||||
Write-Warning "建议使用管理员权限运行此脚本"
|
||||
Start-Sleep -Seconds 3
|
||||
}
|
||||
|
||||
# 检查 pyenv-win 是否安装
|
||||
if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) {
|
||||
Write-Host "pyenv not found. Installing..."
|
||||
try {
|
||||
# 下载并安装 pyenv-win
|
||||
Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1"
|
||||
& ./install-pyenv-win.ps1
|
||||
|
||||
# 添加环境变量
|
||||
$env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win"
|
||||
$env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path"
|
||||
|
||||
# 刷新环境变量
|
||||
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
}
|
||||
catch {
|
||||
Write-Error "Failed to install pyenv: $_"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
# 安装指定版本的 Python
|
||||
Write-Host "Installing Python 3.11.8..."
|
||||
pyenv install 3.11.8
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "Failed to install Python 3.11.8"
|
||||
}
|
||||
|
||||
# 设置本地 Python 版本
|
||||
Write-Host "Setting local Python version..."
|
||||
pyenv local 3.11.8
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "Failed to set local Python version"
|
||||
}
|
||||
|
||||
# 验证 Python 版本
|
||||
$pythonVersion = python -V
|
||||
if (-not $pythonVersion.Contains("3.11.8")) {
|
||||
throw "Wrong Python version: $pythonVersion"
|
||||
}
|
||||
Write-Host "Using Python version: $pythonVersion"
|
||||
|
||||
# 创建虚拟环境
|
||||
Write-Host "Creating virtual environment..."
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
Write-Host "Activating virtual environment..."
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
|
||||
# 升级 pip 和构建工具
|
||||
Write-Host "Upgrading pip and build tools..."
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
# 分步安装依赖以确保正确的顺序和版本
|
||||
Write-Host "Installing database dependencies..."
|
||||
pip install mysql-connector-python==8.0.33
|
||||
|
||||
Write-Host "Installing PyTorch and related packages..."
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
Write-Host "Installing basic dependencies..."
|
||||
pip install numpy==1.26.4 pandas==2.2.1
|
||||
|
||||
Write-Host "Installing machine learning packages..."
|
||||
pip install scikit-learn==1.5.2
|
||||
|
||||
# 安装开发依赖
|
||||
Write-Host "Installing development dependencies..."
|
||||
pip install -e ".[dev]"
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Warning "Failed to install development dependencies. Installing core package..."
|
||||
pip install -e .
|
||||
}
|
||||
|
||||
# 验证安装
|
||||
Write-Host "Verifying installations..."
|
||||
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
|
||||
python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
|
||||
python -c "import pandas; print(f'Pandas version: {pandas.__version__}')"
|
||||
python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')"
|
||||
|
||||
Write-Host "Environment setup complete!" -ForegroundColor Green
|
||||
}
|
||||
catch {
|
||||
Write-Error "An error occurred: $_"
|
||||
exit 1
|
||||
}
|
||||
finally {
|
||||
# 清理临时文件
|
||||
if (Test-Path "./install-pyenv-win.ps1") {
|
||||
Remove-Item "./install-pyenv-win.ps1"
|
||||
}
|
||||
}
|
||||
|
||||
# 显示使用说明
|
||||
Write-Host @"
|
||||
|
||||
环境设置完成!使用说明:
|
||||
1. 虚拟环境已激活,命令提示符前应该显示 (.venv)
|
||||
2. 要退出虚拟环境,运行: deactivate
|
||||
3. 要重新激活虚拟环境,运行: .\.venv\Scripts\Activate.ps1
|
||||
4. 项目依赖已安装,可以开始开发了
|
||||
|
||||
如果遇到问题,请检查:
|
||||
- Python 版本: python -V
|
||||
- PyTorch 安装: python -c "import torch; print(torch.__version__)"
|
||||
- 虚拟环境状态: 确保看到 (.venv) 前缀
|
||||
|
||||
"@ -ForegroundColor Cyan
|
||||
66
scripts/setup_env.sh
Executable file
66
scripts/setup_env.sh
Executable file
@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 检查 pyenv 是否安装
|
||||
if ! command -v pyenv &> /dev/null; then
|
||||
echo "pyenv not found. Installing..."
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
brew install pyenv
|
||||
else
|
||||
curl https://pyenv.run | bash
|
||||
fi
|
||||
fi
|
||||
|
||||
# 安装指定版本的 Python
|
||||
pyenv install 3.11.8 || true
|
||||
|
||||
# 设置本地 Python 版本
|
||||
pyenv local 3.11.8
|
||||
|
||||
# 确保使用正确的 Python 版本
|
||||
eval "$(pyenv init -)"
|
||||
pyenv shell 3.11.8
|
||||
|
||||
# 验证 Python 版本
|
||||
python_version=$(python -V 2>&1)
|
||||
if [[ $python_version != *"3.11.8"* ]]; then
|
||||
echo "Error: Wrong Python version: $python_version"
|
||||
echo "Please ensure pyenv is properly configured in your shell"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 创建虚拟环境
|
||||
python -m venv .venv
|
||||
|
||||
# 激活虚拟环境
|
||||
source .venv/bin/activate
|
||||
|
||||
# 升级 pip 和构建工具
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
# 分步安装依赖以确保正确的顺序和版本
|
||||
echo "Installing database dependencies..."
|
||||
pip install mysql-connector-python==8.0.33
|
||||
|
||||
echo "Installing PyTorch and related packages..."
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
echo "Installing basic dependencies..."
|
||||
pip install numpy==1.26.4 pandas==2.2.1
|
||||
|
||||
echo "Installing machine learning packages..."
|
||||
pip install scikit-learn==1.5.2
|
||||
|
||||
# 安装开发依赖
|
||||
if ! pip install -e ".[dev]"; then
|
||||
echo "Warning: Failed to install development dependencies. Installing core package..."
|
||||
pip install -e .
|
||||
fi
|
||||
|
||||
# 验证安装
|
||||
echo "Verifying Python version..."
|
||||
python --version
|
||||
|
||||
echo "Verifying PyTorch installation..."
|
||||
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
|
||||
|
||||
echo "Environment setup complete!"
|
||||
@ -1 +1,3 @@
|
||||
# 这个文件可以为空,但必须存在
|
||||
from .app import create_app
|
||||
|
||||
__all__ = ['create_app']
|
||||
|
||||
46
src/app.py
46
src/app.py
@ -2,49 +2,35 @@ from flask import Flask
|
||||
from flask_cors import CORS
|
||||
from .routes import api_bp
|
||||
from .logger import setup_logger
|
||||
from config import config
|
||||
import os
|
||||
|
||||
# 获取logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
def create_app():
|
||||
"""
|
||||
创建并配置Flask应用
|
||||
"""
|
||||
"""创建并配置 Flask 应用"""
|
||||
try:
|
||||
# 创建必要的目录
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
os.makedirs('data', exist_ok=True)
|
||||
os.makedirs('models', exist_ok=True)
|
||||
|
||||
logger.info("=== Server Starting ===")
|
||||
logger.info("Initializing directories...")
|
||||
|
||||
# 创建Flask应用
|
||||
app = Flask(__name__)
|
||||
|
||||
# 配置CORS
|
||||
CORS(app)
|
||||
logger.info("CORS enabled")
|
||||
|
||||
# 注册API蓝图
|
||||
# 配置数据库连接
|
||||
app.config['MYSQL_HOST'] = config.MYSQL_HOST
|
||||
app.config['MYSQL_USER'] = config.MYSQL_USER
|
||||
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
|
||||
app.config['MYSQL_DB'] = config.MYSQL_DB
|
||||
|
||||
# 注册路由
|
||||
app.register_blueprint(api_bp, url_prefix='/api')
|
||||
logger.info("API blueprint registered")
|
||||
|
||||
# 配置数据库连接
|
||||
app.config['MYSQL_HOST'] = 'localhost'
|
||||
app.config['MYSQL_USER'] = 'root'
|
||||
app.config['MYSQL_PASSWORD'] = '123456'
|
||||
app.config['MYSQL_DB'] = 'equipment_cost_db'
|
||||
|
||||
logger.info("Starting server...")
|
||||
# 记录配置信息
|
||||
logger.info(f"Database: {app.config['MYSQL_DB']} on {app.config['MYSQL_HOST']}")
|
||||
logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}")
|
||||
logger.info(f"Debug mode: {config.FLASK_DEBUG}")
|
||||
|
||||
return app
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
app.run(host='localhost', port=5001)
|
||||
logger.error(f"Error creating application: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
@ -1,15 +1,12 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import tensorflow as tf
|
||||
from scipy import stats
|
||||
import joblib
|
||||
import os
|
||||
import pandas as pd
|
||||
from .feature_analysis import FeatureAnalysis
|
||||
import logging
|
||||
from src.model_trainer import ModelTrainer
|
||||
from src.database import get_db_connection
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from .logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
@ -21,37 +18,18 @@ class CostPredictor:
|
||||
self.model = None
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
self.equipment_type = None
|
||||
|
||||
# 添加 TensorFlow 配置
|
||||
tf.config.run_functions_eagerly(False) # 启用图执行模式
|
||||
|
||||
# 创建预测函数
|
||||
@tf.function(reduce_retracing=True, jit_compile=True)
|
||||
def predict_fn(x):
|
||||
return self.model(x, training=False)
|
||||
|
||||
self._predict_fn = predict_fn
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
加载预训练型和标准化器
|
||||
加载预训练模型和标准化器
|
||||
"""
|
||||
try:
|
||||
model_dir = 'models'
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 创建默认模型
|
||||
self._create_default_model()
|
||||
|
||||
# 创建预测函数
|
||||
@tf.function(reduce_retracing=True, jit_compile=True)
|
||||
def predict_fn(x):
|
||||
return self.model(x, training=False)
|
||||
|
||||
self._predict_fn = predict_fn
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading model: {str(e)}")
|
||||
self._create_default_model()
|
||||
@ -60,28 +38,24 @@ class CostPredictor:
|
||||
"""
|
||||
创建默认模型并进行初始化训练
|
||||
"""
|
||||
# 创建输入层
|
||||
inputs = tf.keras.Input(shape=(11,))
|
||||
import torch.nn as nn
|
||||
|
||||
# 创建隐藏层
|
||||
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
|
||||
x = tf.keras.layers.Dense(32, activation='relu')(x)
|
||||
|
||||
# 创建输出层
|
||||
outputs = tf.keras.layers.Dense(1)(x)
|
||||
|
||||
# 创建模型
|
||||
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
# 编译模型
|
||||
self.model.compile(
|
||||
optimizer='adam',
|
||||
loss=tf.keras.losses.mean_squared_error,
|
||||
metrics=[tf.keras.metrics.mean_absolute_error]
|
||||
)
|
||||
class DefaultModel(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(input_size, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
# 创建示例数据
|
||||
example_data = pd.DataFrame({
|
||||
example_features = {
|
||||
'length_m': [7.35, 10.2],
|
||||
'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2],
|
||||
@ -93,59 +67,23 @@ class CostPredictor:
|
||||
'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150],
|
||||
'rate_of_fire': [40, 60]
|
||||
})
|
||||
}
|
||||
|
||||
# 转换为 tensor
|
||||
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
|
||||
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
|
||||
|
||||
# 训练标准化器
|
||||
self.scaler_X.fit(example_data)
|
||||
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
|
||||
self.scaler_X.fit(X.numpy())
|
||||
self.scaler_y.fit(y.numpy())
|
||||
|
||||
# 设置默认装备类型
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
def _create_example_data(self):
|
||||
"""
|
||||
创建示例数据来训练标准化器
|
||||
"""
|
||||
# 火箭炮示例数据
|
||||
rocket_data = pd.DataFrame({
|
||||
'length_m': [7.35, 10.2],
|
||||
'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2],
|
||||
'weight_kg': [13700, 28500],
|
||||
'max_range_km': [20.4, 70],
|
||||
'firing_angle_horizontal': [102, 110],
|
||||
'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1],
|
||||
'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150],
|
||||
'rate_of_fire': [40, 60]
|
||||
})
|
||||
|
||||
# 巡飞弹示例数据
|
||||
missile_data = pd.DataFrame({
|
||||
'length_m': [1.3, 2.5],
|
||||
'width_m': [0.23, 0.6],
|
||||
'height_m': [0.23, 0.6],
|
||||
'weight_kg': [12.5, 135],
|
||||
'max_range_km': [40, 250],
|
||||
'max_speed_kmh': [180, 185],
|
||||
'cruise_speed_kmh': [100, 110],
|
||||
'flight_time_min': [60, 120],
|
||||
'folded_length_mm': [1300, 2500],
|
||||
'folded_width_mm': [230, 600],
|
||||
'folded_height_mm': [230, 600]
|
||||
})
|
||||
|
||||
# 训练标准化器
|
||||
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
|
||||
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
|
||||
|
||||
# 设置默认装备类型
|
||||
# 创建模型
|
||||
self.model = DefaultModel(X.shape[1]).to(self.device)
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
def predict(self, data):
|
||||
"""
|
||||
使用训练好的最优模型进行预测
|
||||
使用训练好的模型进行预测
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting prediction for {data.get('type')}")
|
||||
@ -158,20 +96,31 @@ class CostPredictor:
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
X = []
|
||||
for feature in features:
|
||||
value = data.get(feature, 0.0)
|
||||
X.append(float(value))
|
||||
|
||||
# 转换为 tensor
|
||||
X = torch.tensor([X], dtype=torch.float32).to(self.device)
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
with torch.no_grad():
|
||||
trainer.model.eval() # 设置为评估模式
|
||||
y_pred = trainer.model(X)
|
||||
|
||||
# 转回 numpy
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
confidence_interval = self._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
# 获取模型类型
|
||||
model_type = trainer.get_model_type()
|
||||
|
||||
return {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'model_type': model_type, # 返回使用的模型类型
|
||||
'model_type': model_type,
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
@ -187,11 +136,10 @@ class CostPredictor:
|
||||
计算预测值的置信区间
|
||||
"""
|
||||
try:
|
||||
# 使用预测值的20%作为标准差(增加不确定性)
|
||||
# 使用预测值的20%作为标准差
|
||||
std = abs(prediction) * 0.2
|
||||
|
||||
# 计算置信区间
|
||||
from scipy import stats
|
||||
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
|
||||
|
||||
# 确保区间值为正数且合理
|
||||
@ -213,130 +161,15 @@ class CostPredictor:
|
||||
"""
|
||||
模型评估
|
||||
"""
|
||||
# 确保输入是 numpy 数组
|
||||
if torch.is_tensor(y_true):
|
||||
y_true = y_true.cpu().numpy()
|
||||
if torch.is_tensor(y_pred):
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
|
||||
return {
|
||||
'mae': float(mean_absolute_error(y_true, y_pred)),
|
||||
'mse': float(mean_squared_error(y_true, y_pred)),
|
||||
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
|
||||
'r2': float(r2_score(y_true, y_pred))
|
||||
}
|
||||
|
||||
def predict_pls(self, data):
|
||||
"""
|
||||
使用 PLS 型预测成本
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting PLS prediction for {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
|
||||
# 加载 PLS 模型
|
||||
trainer = ModelTrainer()
|
||||
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
|
||||
raise ValueError(f"No trained PLS model found for {equipment_type}")
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
return {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PLS prediction error: {str(e)}")
|
||||
raise
|
||||
|
||||
def predict_all(self, data):
|
||||
"""
|
||||
使用所有可用模型进行预测
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting multi-model prediction for {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
results = {}
|
||||
|
||||
# 1. 获取所有激活的模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("""
|
||||
SELECT id, model_type, model_name, r2_score, mae, rmse
|
||||
FROM trained_models
|
||||
WHERE equipment_type = %s AND is_active = TRUE
|
||||
""", (equipment_type,))
|
||||
active_models = cursor.fetchall()
|
||||
|
||||
if not active_models:
|
||||
raise ValueError(f"No active models found for {equipment_type}")
|
||||
|
||||
# 2. 使用每个模型进行预测
|
||||
trainer = ModelTrainer()
|
||||
for model_info in active_models:
|
||||
try:
|
||||
# 加载特定模型
|
||||
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
|
||||
logger.warning(f"Failed to load model: {model_info['model_name']}")
|
||||
continue
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
|
||||
# 预测
|
||||
y_pred = trainer.predict(X)
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
# 保存结果
|
||||
results[model_info['model_type']] = {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'model_info': {
|
||||
'name': model_info['model_name'],
|
||||
'type': model_info['model_type'],
|
||||
'r2_score': float(model_info['r2_score']),
|
||||
'mae': float(model_info['mae']),
|
||||
'rmse': float(model_info['rmse'])
|
||||
},
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not results:
|
||||
raise ValueError("No successful predictions from any model")
|
||||
|
||||
# 3. 计算综合预测结果
|
||||
all_predictions = [result['predicted_cost'] for result in results.values()]
|
||||
ensemble_prediction = float(np.mean(all_predictions))
|
||||
prediction_std = float(np.std(all_predictions))
|
||||
|
||||
# 4. 返回所有结果
|
||||
return {
|
||||
'individual_predictions': results,
|
||||
'ensemble_prediction': {
|
||||
'predicted_cost': ensemble_prediction,
|
||||
'standard_deviation': prediction_std,
|
||||
'confidence_interval': {
|
||||
'lower': float(ensemble_prediction - 1.96 * prediction_std),
|
||||
'upper': float(ensemble_prediction + 1.96 * prediction_std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in multi-model prediction: {str(e)}")
|
||||
raise
|
||||
}
|
||||
@ -1,29 +1,35 @@
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from datetime import datetime
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
from xgboost import XGBRegressor
|
||||
from lightgbm import LGBMRegressor
|
||||
from sklearn.model_selection import cross_val_score, LeaveOneOut
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import logging
|
||||
from src.database.db_connection import get_db_connection
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
from .logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class EquipmentDataset(Dataset):
|
||||
"""装备数据集类"""
|
||||
def __init__(self, features, targets=None):
|
||||
self.features = torch.FloatTensor(features)
|
||||
self.targets = torch.FloatTensor(targets) if targets is not None else None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.targets is not None:
|
||||
return self.features[idx], self.targets[idx]
|
||||
return self.features[idx]
|
||||
|
||||
class DataPreparation:
|
||||
def __init__(self):
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
self.feature_scaler = StandardScaler()
|
||||
self.target_scaler = StandardScaler() # 添加目标值标准化器
|
||||
self.target_scaler = StandardScaler()
|
||||
|
||||
def prepare_training_data(self, equipment_data, equipment_type):
|
||||
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
|
||||
"""
|
||||
准备训练数据
|
||||
"""
|
||||
@ -31,19 +37,24 @@ class DataPreparation:
|
||||
logger.info(f"Preparing training data for {equipment_type}")
|
||||
logger.info(f"Raw data size: {len(equipment_data)}")
|
||||
|
||||
# 如果输入已经是 numpy 数组,直接返回
|
||||
# 如果输入已经是 numpy 数组,转换为 torch.Tensor
|
||||
if isinstance(equipment_data, np.ndarray):
|
||||
X = equipment_data
|
||||
logger.info(f"Input is already numpy array with shape: {X.shape}")
|
||||
logger.info(f"Input is numpy array with shape: {X.shape}")
|
||||
|
||||
# 处理无效值
|
||||
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
# 转换为 PyTorch 数据集
|
||||
dataset = EquipmentDataset(X)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
return {
|
||||
'X': X,
|
||||
'dataloader': dataloader,
|
||||
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
'target_scaler': self.target_scaler,
|
||||
'raw_shape': X.shape
|
||||
}
|
||||
|
||||
# 从原始数据中提取特征和目标值
|
||||
@ -51,27 +62,28 @@ class DataPreparation:
|
||||
features = []
|
||||
targets = []
|
||||
|
||||
for item in equipment_data:
|
||||
# 提取特征值
|
||||
feature_values = []
|
||||
for name in feature_names:
|
||||
value = item.get(name)
|
||||
try:
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
except (ValueError, TypeError):
|
||||
feature_values.append(0.0)
|
||||
features.append(feature_values)
|
||||
# 获取数据库连接
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 提取目标值(成本)
|
||||
try:
|
||||
cost = float(item['actual_cost'])
|
||||
if cost > 0: # 只使用正数成本值
|
||||
targets.append(cost)
|
||||
else:
|
||||
logger.warning(f"Skipping non-positive cost value: {cost}")
|
||||
except (ValueError, TypeError, KeyError):
|
||||
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
|
||||
continue
|
||||
for item in equipment_data:
|
||||
# 获取该装备的生产商数据
|
||||
manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor)
|
||||
|
||||
# 计算生产商特征
|
||||
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data)
|
||||
|
||||
# 合并装备特征和生产商特征
|
||||
feature_values = []
|
||||
for name in feature_names:
|
||||
if name in manufacturer_features:
|
||||
value = manufacturer_features[name]
|
||||
else:
|
||||
value = item.get(name)
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
|
||||
features.append(feature_values)
|
||||
targets.append(float(item['actual_cost']))
|
||||
|
||||
# 转换为numpy数组
|
||||
X = np.array(features, dtype=float)
|
||||
@ -85,25 +97,16 @@ class DataPreparation:
|
||||
X_scaled = self.feature_scaler.fit_transform(X)
|
||||
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
|
||||
|
||||
# 记录标准化后的数据范围
|
||||
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
|
||||
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
|
||||
|
||||
# 记录标准化器参数
|
||||
logger.info("Feature scaler params:")
|
||||
logger.info(f"Mean: {self.feature_scaler.mean_}")
|
||||
logger.info(f"Scale: {self.feature_scaler.scale_}")
|
||||
|
||||
logger.info("Target scaler params:")
|
||||
logger.info(f"Mean: {self.target_scaler.mean_}")
|
||||
logger.info(f"Scale: {self.target_scaler.scale_}")
|
||||
# 创建 PyTorch 数据集和数据加载器
|
||||
dataset = EquipmentDataset(X_scaled, y_scaled)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
return {
|
||||
'X': X_scaled,
|
||||
'y': y_scaled,
|
||||
'dataloader': dataloader,
|
||||
'feature_names': feature_names,
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
'target_scaler': self.target_scaler,
|
||||
'raw_shape': X.shape
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -15,9 +15,9 @@ class FeatureAnalysis:
|
||||
'width_m': '宽度(m)',
|
||||
'height_m': '高度(m)',
|
||||
'weight_kg': '重量(kg)',
|
||||
'max_range_km': '最大射程(km)',
|
||||
|
||||
# 火箭炮特有参数
|
||||
'max_range_km': '最大射程(km)',
|
||||
'firing_angle_horizontal': '方向射界(度)',
|
||||
'firing_angle_vertical': '高低射界(度)',
|
||||
'rocket_length_m': '火箭弹长度(m)',
|
||||
@ -39,6 +39,7 @@ class FeatureAnalysis:
|
||||
'terrain_adaptability_score': '地形适应性评分',
|
||||
|
||||
# 巡飞弹特有参数
|
||||
'max_range_km': '最大射程(km)',
|
||||
'wingspan_m': '翼展(m)',
|
||||
'warhead_weight_kg': '战斗部重量(kg)',
|
||||
'max_speed_ms': '最大速度(m/s)',
|
||||
@ -57,7 +58,14 @@ class FeatureAnalysis:
|
||||
'weight_range_ratio': '重量射程比',
|
||||
'speed_weight_ratio': '速度重量比',
|
||||
'guidance_system_score': '制导系统评分',
|
||||
'warhead_power_score': '战斗部威力评分'
|
||||
'warhead_power_score': '战斗部威力评分',
|
||||
|
||||
# 添加生产商特征映射
|
||||
'manufacturer_tech_level': '生产商技术水平',
|
||||
'manufacturer_scale_level': '生产商规模水平',
|
||||
'manufacturer_supply_chain_level': '生产商供应链水平',
|
||||
'manufacturer_composite_score': '生产商综合得分',
|
||||
'manufacturer_region_factor': '生产商区域系数'
|
||||
}
|
||||
|
||||
def get_equipment_specific_features(self, equipment_type):
|
||||
@ -121,6 +129,17 @@ class FeatureAnalysis:
|
||||
'guidance_system_score',
|
||||
'warhead_power_score'
|
||||
])
|
||||
|
||||
# 添加生产商特征
|
||||
manufacturer_features = [
|
||||
'manufacturer_tech_level',
|
||||
'manufacturer_scale_level',
|
||||
'manufacturer_supply_chain_level',
|
||||
'manufacturer_composite_score',
|
||||
'manufacturer_region_factor'
|
||||
]
|
||||
|
||||
numeric_features.extend(manufacturer_features)
|
||||
return numeric_features
|
||||
|
||||
def analyze_features(self, features, target, feature_names):
|
||||
@ -234,4 +253,63 @@ class FeatureAnalysis:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in analyze_features: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
raise
|
||||
|
||||
def calculate_manufacturer_features(self, manufacturer_data):
|
||||
"""计算生产商相关的特征"""
|
||||
try:
|
||||
# 确保所有必要的字段都存在,使用默认值处理缺失数据
|
||||
tech_level = float(manufacturer_data.get('tech_level', 0))
|
||||
scale_level = float(manufacturer_data.get('scale_level', 0))
|
||||
supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0))
|
||||
country = manufacturer_data.get('country', '未知')
|
||||
|
||||
# 计算综合得分
|
||||
composite_score = (
|
||||
tech_level * 0.4 + # 技术水平权重最高
|
||||
scale_level * 0.3 + # 规模水平次之
|
||||
supply_chain_level * 0.3 # 供应链水平
|
||||
)
|
||||
|
||||
# 计算区域系数(基于不同地区的成本差异)
|
||||
region_factors = {
|
||||
'美国': 1.2,
|
||||
'英国': 1.15,
|
||||
'德国': 1.15,
|
||||
'法国': 1.15,
|
||||
'以色列': 1.1,
|
||||
'中国': 0.8,
|
||||
'俄罗斯': 0.85,
|
||||
'韩国': 0.9,
|
||||
'日本': 1.1
|
||||
}
|
||||
|
||||
region_factor = region_factors.get(country, 1.0)
|
||||
|
||||
# 记录计算过程
|
||||
logger.info(f"Manufacturer features calculation:")
|
||||
logger.info(f"Tech level: {tech_level}")
|
||||
logger.info(f"Scale level: {scale_level}")
|
||||
logger.info(f"Supply chain level: {supply_chain_level}")
|
||||
logger.info(f"Country: {country}")
|
||||
logger.info(f"Composite score: {composite_score}")
|
||||
logger.info(f"Region factor: {region_factor}")
|
||||
|
||||
return {
|
||||
'manufacturer_tech_level': tech_level,
|
||||
'manufacturer_scale_level': scale_level,
|
||||
'manufacturer_supply_chain_level': supply_chain_level,
|
||||
'manufacturer_composite_score': composite_score,
|
||||
'manufacturer_region_factor': region_factor
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating manufacturer features: {str(e)}")
|
||||
# 返回默认值而不是抛出异常,确保分析过程可以继续
|
||||
return {
|
||||
'manufacturer_tech_level': 0,
|
||||
'manufacturer_scale_level': 0,
|
||||
'manufacturer_supply_chain_level': 0,
|
||||
'manufacturer_composite_score': 0,
|
||||
'manufacturer_region_factor': 1.0
|
||||
}
|
||||
@ -26,7 +26,7 @@ def import_training_data(excel_file):
|
||||
equipment_names.add(row['名称'])
|
||||
# 检查是否已存在相同名称的装备
|
||||
cursor.execute("""
|
||||
SELECT id FROM equipment
|
||||
SELECT id FROM equipments
|
||||
WHERE name = %s AND type = '火箭炮'
|
||||
""", (row['名称'],))
|
||||
|
||||
@ -37,7 +37,7 @@ def import_training_data(excel_file):
|
||||
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipment (name, type, manufacturer)
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
""", (row['名称'], '火箭炮', row['制造商']))
|
||||
|
||||
@ -116,7 +116,7 @@ def import_training_data(excel_file):
|
||||
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipment (name, type, manufacturer)
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
""", (
|
||||
row['名称'],
|
||||
@ -192,7 +192,7 @@ def import_training_data(excel_file):
|
||||
logger.debug(f"查询装备ID: {equipment_name}")
|
||||
with conn.cursor() as id_cursor:
|
||||
id_cursor.execute("""
|
||||
SELECT id FROM equipment WHERE name = %s
|
||||
SELECT id FROM equipments WHERE name = %s
|
||||
""", (equipment_name,))
|
||||
result = id_cursor.fetchone()
|
||||
|
||||
|
||||
@ -1,319 +0,0 @@
|
||||
/*
|
||||
这是用于开发和测试环境的示例数据。
|
||||
生产环境请使用系统的数据导入功能添加实际数据。
|
||||
|
||||
主要用途:
|
||||
1. 提供开发测试数据
|
||||
2. 作为数据格式参考
|
||||
3. 用于系统功能验证
|
||||
*/
|
||||
|
||||
-- 插入装备基本信息
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
|
||||
|
||||
-- 插入巡飞弹技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES (
|
||||
1, -- 终结者巡飞弹
|
||||
0.56,
|
||||
0.15,
|
||||
0.20,
|
||||
2.72,
|
||||
160.93,
|
||||
96.56,
|
||||
24,
|
||||
15,
|
||||
'破片杀伤战斗部',
|
||||
'凭自身动力起飞',
|
||||
560,
|
||||
150,
|
||||
200
|
||||
);
|
||||
|
||||
-- 插入火箭炮技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES (
|
||||
2, -- 胜利-2火箭炮
|
||||
10,
|
||||
2.5,
|
||||
3.34,
|
||||
15000,
|
||||
23
|
||||
);
|
||||
|
||||
-- 插入成本数据(示例数据)
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(1, 1000000), -- 终结者巡飞弹成本
|
||||
(2, 5000000); -- 胜利-2火箭炮成本
|
||||
|
||||
-- 插入更多巡飞弹变体数据用于训练
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
|
||||
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
|
||||
|
||||
-- 插入变体技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES
|
||||
-- 终结者-A(稍大型号)
|
||||
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
|
||||
-- 终结者-B(稍小型号)
|
||||
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
|
||||
-- 终结者-C(标准型号的改进版)
|
||||
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
|
||||
|
||||
-- 插入变体成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(3, 1100000), -- 终结者-A成本(较高)
|
||||
(4, 900000), -- 终结者-B成本(较低)
|
||||
(5, 1050000); -- 终结者-C成本(中等)
|
||||
|
||||
-- 添加更多巡飞弹数据
|
||||
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
|
||||
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
|
||||
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
|
||||
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
|
||||
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
|
||||
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
|
||||
|
||||
-- 添加它们的技术参数
|
||||
INSERT INTO technical_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_speed_kmh,
|
||||
cruise_speed_kmh,
|
||||
max_range_km,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm
|
||||
) VALUES
|
||||
-- 哈比
|
||||
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
|
||||
-- 游荡者
|
||||
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
|
||||
-- 凤凰
|
||||
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
|
||||
-- 弹簧刀
|
||||
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
|
||||
-- 彩虹-4
|
||||
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
|
||||
|
||||
-- 添加成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(6, 800000), -- 哈比
|
||||
(7, 500000), -- 游荡者
|
||||
(8, 450000), -- 凤凰
|
||||
(9, 480000), -- 弹簧刀
|
||||
(10, 1500000); -- 彩虹-4
|
||||
|
||||
-- 火箭炮数据
|
||||
INSERT INTO equipment (name, type, manufacturer) VALUES
|
||||
('BM-21', '火箭炮', '俄罗斯'),
|
||||
('SR5', '火箭炮', '中国'),
|
||||
('HIMARS', '火箭炮', '美国'),
|
||||
('LAR-160', '火箭炮', '以色列'),
|
||||
('T-122', '火箭炮', '土耳其'),
|
||||
('RM-70', '火箭炮', '捷克'),
|
||||
('ASTROS II', '火箭炮', '巴西');
|
||||
|
||||
-- 火箭炮通用参数
|
||||
INSERT INTO common_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES
|
||||
-- BM-21
|
||||
(1, 7.35, 2.4, 3.1, 13700, 20.4),
|
||||
-- SR5
|
||||
(2, 10.2, 2.8, 3.2, 28500, 70),
|
||||
-- HIMARS
|
||||
(3, 7.0, 2.4, 3.2, 16250, 70),
|
||||
-- LAR-160
|
||||
(4, 6.7, 2.5, 2.8, 15000, 45),
|
||||
-- T-122
|
||||
(5, 7.2, 2.5, 2.9, 18000, 40),
|
||||
-- RM-70
|
||||
(6, 7.5, 2.5, 3.0, 17200, 20.3),
|
||||
-- ASTROS II
|
||||
(7, 8.0, 2.7, 3.1, 24500, 90);
|
||||
|
||||
-- 火箭炮特有参数
|
||||
INSERT INTO rocket_artillery_params (
|
||||
equipment_id,
|
||||
firing_angle_horizontal,
|
||||
firing_angle_vertical,
|
||||
rocket_length_m,
|
||||
rocket_diameter_mm,
|
||||
rocket_weight_kg,
|
||||
rate_of_fire,
|
||||
combat_weight_kg,
|
||||
speed_kmh,
|
||||
min_range_km,
|
||||
mobility_type,
|
||||
structure_layout,
|
||||
engine_model,
|
||||
engine_params,
|
||||
power_hp,
|
||||
travel_range_km
|
||||
) VALUES
|
||||
-- BM-21
|
||||
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
|
||||
-- SR5
|
||||
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
|
||||
-- HIMARS
|
||||
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
|
||||
-- LAR-160
|
||||
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
|
||||
-- T-122
|
||||
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
|
||||
-- RM-70
|
||||
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
|
||||
-- ASTROS II
|
||||
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
|
||||
|
||||
-- 巡飞弹数据
|
||||
INSERT INTO equipment (name, type, manufacturer) VALUES
|
||||
('Hero-120', '巡飞弹', '以色列'),
|
||||
('Switchblade 600', '巡飞弹', '美国'),
|
||||
('Warmate', '巡飞弹', '波兰'),
|
||||
('CH-901', '巡飞弹', '中国'),
|
||||
('HAROP', '巡飞弹', '以色列'),
|
||||
('Coyote', '巡飞弹', '美国'),
|
||||
('WS-43', '巡飞弹', '中国');
|
||||
|
||||
-- 巡飞弹通用参数
|
||||
INSERT INTO common_params (
|
||||
equipment_id,
|
||||
length_m,
|
||||
width_m,
|
||||
height_m,
|
||||
weight_kg,
|
||||
max_range_km
|
||||
) VALUES
|
||||
-- Hero-120
|
||||
(8, 1.3, 0.23, 0.23, 12.5, 40),
|
||||
-- Switchblade 600
|
||||
(9, 1.3, 0.22, 0.22, 15.0, 40),
|
||||
-- Warmate
|
||||
(10, 1.1, 0.15, 0.15, 5.7, 15),
|
||||
-- CH-901
|
||||
(11, 1.2, 0.18, 0.18, 9.0, 20),
|
||||
-- HAROP
|
||||
(12, 2.5, 0.43, 0.43, 135, 1000),
|
||||
-- Coyote
|
||||
(13, 0.9, 0.12, 0.12, 5.9, 20),
|
||||
-- WS-43
|
||||
(14, 1.8, 0.35, 0.35, 20, 60);
|
||||
|
||||
-- 巡飞弹特有参数
|
||||
INSERT INTO loitering_munition_params (
|
||||
equipment_id,
|
||||
wingspan_m,
|
||||
warhead_weight_kg,
|
||||
max_speed_ms,
|
||||
cruise_speed_kmh,
|
||||
flight_time_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
folded_length_mm,
|
||||
folded_width_mm,
|
||||
folded_height_mm,
|
||||
power_system,
|
||||
guidance_system
|
||||
) VALUES
|
||||
-- Hero-120
|
||||
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
|
||||
-- Switchblade 600
|
||||
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
|
||||
-- Warmate
|
||||
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
|
||||
-- CH-901
|
||||
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
|
||||
-- HAROP
|
||||
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
-- Coyote
|
||||
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
|
||||
-- WS-43
|
||||
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
|
||||
|
||||
-- 插入成本数据(示例成本)
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
-- 火箭炮
|
||||
(1, 800000), -- BM-21
|
||||
(2, 4500000), -- SR5
|
||||
(3, 5500000), -- HIMARS
|
||||
(4, 3500000), -- LAR-160
|
||||
(5, 2800000), -- T-122
|
||||
(6, 1500000), -- RM-70
|
||||
(7, 4800000), -- ASTROS II
|
||||
-- 巡飞弹
|
||||
(8, 150000), -- Hero-120
|
||||
(9, 180000), -- Switchblade 600
|
||||
(10, 80000), -- Warmate
|
||||
(11, 100000), -- CH-901
|
||||
(12, 850000), -- HAROP
|
||||
(13, 75000), -- Coyote
|
||||
(14, 120000); -- WS-43
|
||||
|
||||
-- 创建初始数据集
|
||||
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
|
||||
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
|
||||
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
|
||||
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
|
||||
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
|
||||
|
||||
-- 关联装备到数据集
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
-- 火箭炮训练集
|
||||
(1, 1), (1, 2), (1, 3), (1, 4),
|
||||
-- 巡飞弹训练集
|
||||
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
|
||||
-- 火箭炮验证集
|
||||
(3, 5), (3, 6), (3, 7),
|
||||
-- 巡飞弹验证集
|
||||
(4, 13), (4, 14);
|
||||
@ -26,15 +26,15 @@
|
||||
*/
|
||||
|
||||
-- 插入装备基本信息
|
||||
INSERT INTO equipment (
|
||||
INSERT INTO equipments (
|
||||
id, -- 装备ID
|
||||
name, -- 装备名称
|
||||
type, -- 装备类型
|
||||
manufacturer -- 制造商
|
||||
) VALUES
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列'),
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列 IAI'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'),
|
||||
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
|
||||
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
|
||||
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
|
||||
@ -65,11 +65,11 @@ INSERT INTO equipment (
|
||||
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
|
||||
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
|
||||
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'),
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗'),
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'),
|
||||
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
|
||||
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
|
||||
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
|
||||
@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params (
|
||||
(24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红<EFBFBD><EFBFBD><EFBFBD>'),
|
||||
(28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'),
|
||||
(29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
|
||||
(30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'),
|
||||
@ -338,7 +338,7 @@ INSERT INTO loitering_munition_params (
|
||||
(77, 2.8, 40.0, 250, 200, 120, 180, 50.0, 5500, 90, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Light
|
||||
(78, 3.2, 80.0, 230, 180, 150, 200, 100.0, 6000, 100, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Heavy
|
||||
(79, 1.5, 3.5, 160, 140, 60, 50, 5.0, 3500, 25, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助'), -- Taifun
|
||||
(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
|
||||
(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K
|
||||
(81, 1.5, 3.0, 120, 100, 60, 40, 4.0, 3000, 20, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ES
|
||||
(82, 2.0, 5.0, 140, 120, 90, 60, 6.0, 4000, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ER
|
||||
(83, 2.5, 8.0, 160, 140, 120, 80, 10.0, 5000, 40, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -- HERO-XL
|
||||
@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(2, '巡飞弹验证集 2024', '包含20个巡飞弹型号,用于验证模型性能', '巡飞弹', '验证');
|
||||
|
||||
-- 训练集(80个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 以色列系列(8/10)
|
||||
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
|
||||
(1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列
|
||||
@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
(1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列
|
||||
|
||||
-- 验证集(20个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 以色列系列(2/10)
|
||||
(2, 9), -- Hero-900
|
||||
(2, 48), -- Rotem L
|
||||
@ -666,4 +666,43 @@ SET
|
||||
WHEN l.max_range_km > 500 THEN 5000
|
||||
WHEN l.max_range_km > 100 THEN 3000
|
||||
ELSE 1500
|
||||
END;
|
||||
END;
|
||||
|
||||
|
||||
|
||||
-- 更新巡飞弹的制导精度
|
||||
UPDATE loitering_munition_params l
|
||||
SET guidance_accuracy_m =
|
||||
CASE
|
||||
-- 基础精度(根据制导系统类型)
|
||||
WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0
|
||||
WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0
|
||||
WHEN guidance_system LIKE '%激光制导%' THEN 1.0
|
||||
WHEN guidance_system LIKE '%红外制导%' THEN 2.0
|
||||
WHEN guidance_system LIKE '%卫星制导%' THEN 2.5
|
||||
ELSE 5.0
|
||||
END *
|
||||
-- 速度影响因子(速度越快,精度略微降低)
|
||||
CASE
|
||||
WHEN max_speed_ms > 200 THEN 1.2
|
||||
WHEN max_speed_ms > 150 THEN 1.1
|
||||
WHEN max_speed_ms > 100 THEN 1.0
|
||||
ELSE 0.9
|
||||
END *
|
||||
-- 重量影响因子(重量越大,精度略微降低)
|
||||
CASE
|
||||
WHEN warhead_weight_kg > 100 THEN 1.2
|
||||
WHEN warhead_weight_kg > 50 THEN 1.1
|
||||
WHEN warhead_weight_kg > 20 THEN 1.0
|
||||
ELSE 0.9
|
||||
END *
|
||||
-- 飞行高度影响因子(高度越高,精度略微降低)
|
||||
CASE
|
||||
WHEN ceiling_altitude_m > 5000 THEN 1.2
|
||||
WHEN ceiling_altitude_m > 3000 THEN 1.1
|
||||
WHEN ceiling_altitude_m > 1000 THEN 1.0
|
||||
ELSE 0.9
|
||||
END
|
||||
WHERE equipment_id IN (
|
||||
SELECT id FROM equipments WHERE type = '巡飞弹'
|
||||
);
|
||||
@ -40,7 +40,7 @@ INSERT INTO manufacturers (
|
||||
('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业
|
||||
|
||||
-- 俄罗斯供应商
|
||||
('俄罗斯', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
|
||||
('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限
|
||||
('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商
|
||||
('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商
|
||||
|
||||
@ -72,7 +72,9 @@ INSERT INTO manufacturers (
|
||||
('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业
|
||||
|
||||
-- 更新装备表中的供应商ID
|
||||
UPDATE equipment e
|
||||
SET manufacturer_id = m.id
|
||||
FROM manufacturers m
|
||||
WHERE e.manufacturer = m.name;
|
||||
UPDATE equipments e
|
||||
SET manufacturer_id = (
|
||||
SELECT id
|
||||
FROM manufacturers m
|
||||
WHERE m.name = e.manufacturer
|
||||
);
|
||||
@ -1,282 +1,112 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.model_selection import cross_val_score
|
||||
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
||||
from sklearn.impute import SimpleImputer
|
||||
import xgboost as xgb
|
||||
import lightgbm as lgb
|
||||
from sklearn.model_selection import train_test_split
|
||||
import logging
|
||||
import joblib
|
||||
import os
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from datetime import datetime
|
||||
import json
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
from src.data_preparation import DataPreparation
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from src.data_preparation import DataPreparation, EquipmentDataset
|
||||
from .logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class CostPredictionModel(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(input_size, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class ModelTrainer:
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化 ModelTrainer
|
||||
"""
|
||||
self.models = {
|
||||
'xgboost': self._create_xgboost_model(),
|
||||
'lightgbm': self._create_lightgbm_model(),
|
||||
'gbm': self._create_gbm_model(),
|
||||
'rf': self._create_rf_model(),
|
||||
'pls': self._create_pls_model()
|
||||
}
|
||||
self.best_model = None
|
||||
self.imputer = SimpleImputer(strategy='mean')
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model = None
|
||||
self.feature_scaler = None
|
||||
self.target_scaler = None
|
||||
self.equipment_type = None
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
|
||||
def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
|
||||
"""
|
||||
训练模型并返回评估结果
|
||||
"""
|
||||
def train_model(self, dataloader, epochs=100, learning_rate=0.001):
|
||||
"""训练模型"""
|
||||
try:
|
||||
self.equipment_type = equipment_type
|
||||
logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
|
||||
logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
|
||||
# 获取输入特征维度
|
||||
sample_features, _ = next(iter(dataloader))
|
||||
input_size = sample_features.shape[1]
|
||||
|
||||
results = {}
|
||||
best_score = -float('inf')
|
||||
best_model_info = None
|
||||
# 创建模型
|
||||
self.model = CostPredictionModel(input_size).to(self.device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
|
||||
# 首先训练 PLS 模型
|
||||
logger.info("Training pls...")
|
||||
pls_model = self.models['pls']
|
||||
pls_model.fit(X_train, y_train)
|
||||
pls_metrics = self._calculate_metrics(
|
||||
pls_model,
|
||||
X_train, y_train,
|
||||
X_val, y_val
|
||||
)
|
||||
results['pls'] = pls_metrics
|
||||
|
||||
# 训练其他机器学习模型
|
||||
for model_name in model_names:
|
||||
if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
|
||||
continue
|
||||
# 训练循环
|
||||
for epoch in range(epochs):
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
for batch_features, batch_targets in dataloader:
|
||||
# 移动数据到设备
|
||||
batch_features = batch_features.to(self.device)
|
||||
batch_targets = batch_targets.to(self.device)
|
||||
|
||||
if model_name not in self.models:
|
||||
logger.warning(f"Unknown model: {model_name}")
|
||||
continue
|
||||
# 前向传播
|
||||
outputs = self.model(batch_features)
|
||||
loss = criterion(outputs, batch_targets.view(-1, 1))
|
||||
|
||||
logger.info(f"Training {model_name}...")
|
||||
model = self.models[model_name]
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# 训练模型
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# 计算评估指标
|
||||
metrics = self._calculate_metrics(
|
||||
model,
|
||||
X_train, y_train,
|
||||
X_val, y_val
|
||||
)
|
||||
|
||||
results[model_name] = metrics
|
||||
|
||||
# 更新最佳模型(只在机器学习模型中比较)
|
||||
if metrics['validation']['r2'] > best_score:
|
||||
best_score = metrics['validation']['r2']
|
||||
best_model_info = {
|
||||
'type': model_name,
|
||||
'r2': metrics['validation']['r2'],
|
||||
'mae': metrics['validation']['mae'],
|
||||
'rmse': metrics['validation']['rmse']
|
||||
}
|
||||
self.best_model = model
|
||||
# 记录训练进度
|
||||
if (epoch + 1) % 10 == 0:
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
|
||||
|
||||
# 保存最佳模型和 PLS 模型
|
||||
if equipment_type and best_model_info:
|
||||
self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
|
||||
|
||||
return {
|
||||
'metrics': results,
|
||||
'best_model': best_model_info
|
||||
}
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in model training: {str(e)}")
|
||||
raise
|
||||
|
||||
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
|
||||
"""
|
||||
计算模型评估指标
|
||||
"""
|
||||
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
|
||||
|
||||
# 训练集评估
|
||||
train_pred = model.predict(X_train)
|
||||
|
||||
# 如果使用了标准化,需要转换回原始范围
|
||||
if hasattr(self, 'target_scaler'):
|
||||
train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
|
||||
y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
|
||||
else:
|
||||
y_train_orig = y_train
|
||||
|
||||
# 记录预测范围
|
||||
logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
|
||||
logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
|
||||
|
||||
train_metrics = {
|
||||
'r2': r2_score(y_train_orig, train_pred),
|
||||
'mae': mean_absolute_error(y_train_orig, train_pred),
|
||||
'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
|
||||
}
|
||||
|
||||
# 验证集评估
|
||||
if X_val is not None and y_val is not None:
|
||||
val_pred = model.predict(X_val)
|
||||
|
||||
# 如果使用了标准化,需要转换回原始范围
|
||||
if hasattr(self, 'target_scaler'):
|
||||
val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
|
||||
y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
|
||||
else:
|
||||
y_val_orig = y_val
|
||||
|
||||
# 记录预测范围
|
||||
logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
|
||||
logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
|
||||
|
||||
val_metrics = {
|
||||
'r2': r2_score(y_val_orig, val_pred),
|
||||
'mae': mean_absolute_error(y_val_orig, val_pred),
|
||||
'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
|
||||
}
|
||||
else:
|
||||
# 使用交叉验证
|
||||
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
|
||||
val_metrics = {
|
||||
'r2': cv_scores.mean(),
|
||||
'mae': None,
|
||||
'rmse': None
|
||||
}
|
||||
|
||||
return {
|
||||
'train': train_metrics,
|
||||
'validation': val_metrics
|
||||
}
|
||||
|
||||
def _create_xgboost_model(self):
|
||||
"""
|
||||
创建 XGBoost 模型,增强正则化
|
||||
"""
|
||||
return xgb.XGBRegressor(
|
||||
n_estimators=50, # 减少树的数量
|
||||
learning_rate=0.05, # 学习率
|
||||
max_depth=3, # 减小树的深
|
||||
min_child_weight=3, # 增加节点权重
|
||||
subsample=0.7, # 减小样本采样比例
|
||||
colsample_bytree=0.7, # 减小特征采样比例
|
||||
reg_alpha=0.1, # L1 正则化
|
||||
reg_lambda=1, # L2 正则化
|
||||
random_state=42
|
||||
)
|
||||
|
||||
def _create_lightgbm_model(self):
|
||||
"""
|
||||
创建 LightGBM 模型,增强正则化
|
||||
"""
|
||||
return lgb.LGBMRegressor(
|
||||
n_estimators=50,
|
||||
learning_rate=0.05,
|
||||
max_depth=3,
|
||||
num_leaves=7,
|
||||
min_data_in_leaf=3,
|
||||
min_sum_hessian_in_leaf=1e-3,
|
||||
subsample=0.7,
|
||||
colsample_bytree=0.7,
|
||||
reg_alpha=0.1,
|
||||
reg_lambda=1,
|
||||
random_state=42,
|
||||
verbose=-1
|
||||
)
|
||||
|
||||
def _create_gbm_model(self):
|
||||
"""
|
||||
创建 GBM 模型,增强正则化以减轻过拟合
|
||||
"""
|
||||
return GradientBoostingRegressor(
|
||||
n_estimators=100,
|
||||
learning_rate=0.1,
|
||||
max_depth=3,
|
||||
random_state=42,
|
||||
subsample=0.8,
|
||||
min_samples_split=3,
|
||||
min_samples_leaf=2
|
||||
)
|
||||
|
||||
def _create_rf_model(self):
|
||||
"""
|
||||
创建随机森林模型,针对小样本数据调整参数
|
||||
"""
|
||||
return RandomForestRegressor(
|
||||
n_estimators=100,
|
||||
max_depth=3,
|
||||
random_state=42,
|
||||
min_samples_split=3,
|
||||
min_samples_leaf=2
|
||||
)
|
||||
|
||||
def _create_pls_model(self):
|
||||
"""
|
||||
创建 PLS 模型,优化参数配置
|
||||
"""
|
||||
return PLSRegression(
|
||||
n_components=2, # 减少主成分数量,从5减到2
|
||||
scale=True, # 保持数据标准化
|
||||
max_iter=500, # 减少最大迭代次数,避免过拟合
|
||||
tol=1e-6 # 降低收敛精度,避免过拟合
|
||||
)
|
||||
|
||||
def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
|
||||
"""
|
||||
保存最佳模型和 PLS 模型
|
||||
"""
|
||||
def save_model(self, equipment_type):
|
||||
"""保存模型"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_dir = 'models'
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 1. 保存最佳机器学习模型
|
||||
model_path = f'{model_dir}/{equipment_type}_{timestamp}'
|
||||
if isinstance(self.best_model, xgb.XGBRegressor):
|
||||
self.best_model.save_model(f'{model_path}.json')
|
||||
model_format = 'json'
|
||||
else:
|
||||
joblib.dump(self.best_model, f'{model_path}.joblib')
|
||||
model_format = 'joblib'
|
||||
|
||||
# 2. 保存 PLS 模型
|
||||
pls_model = self.models['pls']
|
||||
pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
|
||||
joblib.dump(pls_model, pls_path)
|
||||
# 保存模型
|
||||
model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth'
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'input_size': self.model.layers[0].in_features
|
||||
}, model_path)
|
||||
|
||||
# 3. 保存标准化器
|
||||
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
|
||||
joblib.dump({
|
||||
# 保存标准化器
|
||||
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth'
|
||||
torch.save({
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
}, scaler_path)
|
||||
|
||||
logger.info(f"Saved best model to {model_path}.{model_format}")
|
||||
logger.info(f"Saved PLS model to {pls_path}")
|
||||
logger.info(f"Saved scalers to {scaler_path}")
|
||||
|
||||
# 4. 更新数据库中的模型记录
|
||||
# 更新数据库
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
@ -287,420 +117,73 @@ class ModelTrainer:
|
||||
WHERE equipment_type = %s
|
||||
""", (equipment_type,))
|
||||
|
||||
# 获取 PLS 模型的评估指标
|
||||
pls_metrics = self._calculate_metrics(
|
||||
self.models['pls'],
|
||||
X_train,
|
||||
y_train,
|
||||
X_val,
|
||||
y_val
|
||||
)
|
||||
|
||||
# 保存最佳机器学习模型记录
|
||||
self.best_model.equipment_type = equipment_type # 设置装备类型
|
||||
ml_feature_importance = self._get_feature_importance(self.best_model)
|
||||
|
||||
# 保存新模型记录
|
||||
cursor.execute("""
|
||||
INSERT INTO trained_models (
|
||||
model_name, model_type, equipment_type, model_path, scaler_path,
|
||||
r2_score, mae, rmse, feature_importance, training_data_size,
|
||||
training_date, is_active, created_by
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
|
||||
model_name, model_type, equipment_type, model_path,
|
||||
scaler_path, training_date, is_active, created_by
|
||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s)
|
||||
""", (
|
||||
f"{equipment_type}_{timestamp}", # model_name
|
||||
best_model_info['type'], # model_type
|
||||
equipment_type, # equipment_type
|
||||
f"{model_path}.{model_format}", # model_path
|
||||
scaler_path, # scaler_path
|
||||
best_model_info['r2'], # r2_score
|
||||
best_model_info['mae'], # mae
|
||||
best_model_info['rmse'], # rmse
|
||||
json.dumps(ml_feature_importance), # feature_importance
|
||||
len(X_train), # training_data_size
|
||||
'system' # created_by
|
||||
))
|
||||
|
||||
# 保存 PLS 模型记录
|
||||
pls_model.equipment_type = equipment_type # 设置装备类型
|
||||
pls_feature_importance = self._get_feature_importance(pls_model)
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO trained_models (
|
||||
model_name, model_type, equipment_type, model_path, scaler_path,
|
||||
r2_score, mae, rmse, feature_importance, training_data_size,
|
||||
training_date, is_active, created_by
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
|
||||
""", (
|
||||
f"{equipment_type}_{timestamp}_pls", # model_name
|
||||
'pls', # model_type
|
||||
equipment_type, # equipment_type
|
||||
pls_path, # model_path
|
||||
scaler_path, # scaler_path
|
||||
float(pls_metrics['validation']['r2']), # r2_score
|
||||
float(pls_metrics['validation']['mae']), # mae
|
||||
float(pls_metrics['validation']['rmse']), # rmse
|
||||
json.dumps(pls_feature_importance), # feature_importance
|
||||
len(X_train), # training_data_size
|
||||
'system' # created_by
|
||||
f"{equipment_type}_{timestamp}",
|
||||
'pytorch',
|
||||
equipment_type,
|
||||
model_path,
|
||||
scaler_path,
|
||||
'system'
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving models: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
raise
|
||||
|
||||
def load_model(self, equipment_type, model_type='ml'):
|
||||
"""
|
||||
加载已训练的模型
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading {model_type} model for {equipment_type}")
|
||||
|
||||
# 从数据库获取激活的模型
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {str(e)}")
|
||||
return False
|
||||
|
||||
def load_model(self, equipment_type):
|
||||
"""加载模型"""
|
||||
try:
|
||||
# 从数据库获取最新的激活模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建查询语句
|
||||
if model_type == 'pls':
|
||||
query = """
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type = 'pls'
|
||||
AND is_active = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
params = (equipment_type,)
|
||||
else:
|
||||
query = """
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type != 'pls'
|
||||
AND is_active = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
params = (equipment_type,)
|
||||
|
||||
# 记录查询信息
|
||||
logger.info(f"Executing query: {query}")
|
||||
logger.info(f"Query params: {params}")
|
||||
|
||||
cursor.execute(query, params)
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
model_record = cursor.fetchone()
|
||||
|
||||
# 记录查询结果
|
||||
if model_record:
|
||||
logger.info(f"Found model record: {model_record}")
|
||||
else:
|
||||
logger.warning(f"No active model found for type {model_type}")
|
||||
if not model_record:
|
||||
return False
|
||||
|
||||
# 检查文件是否存在
|
||||
logger.info(f"Checking model file: {model_record['model_path']}")
|
||||
logger.info(f"Checking scaler file: {model_record['scaler_path']}")
|
||||
|
||||
if not os.path.exists(model_record['model_path']):
|
||||
logger.error(f"Model file not found: {model_record['model_path']}")
|
||||
raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
|
||||
|
||||
if not os.path.exists(model_record['scaler_path']):
|
||||
logger.error(f"Scaler file not found: {model_record['scaler_path']}")
|
||||
raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
|
||||
|
||||
# 加载模型文件
|
||||
logger.info(f"Loading model from {model_record['model_path']}")
|
||||
if model_type == 'pls':
|
||||
self.best_model = joblib.load(model_record['model_path'])
|
||||
logger.info("Loaded PLS model")
|
||||
else:
|
||||
if model_record['model_type'] == 'xgboost':
|
||||
self.best_model = xgb.XGBRegressor()
|
||||
self.best_model.load_model(model_record['model_path'])
|
||||
logger.info("Loaded XGBoost model")
|
||||
else:
|
||||
self.best_model = joblib.load(model_record['model_path'])
|
||||
logger.info(f"Loaded {model_record['model_type']} model")
|
||||
# 加载模型
|
||||
checkpoint = torch.load(model_record['model_path'])
|
||||
input_size = checkpoint['input_size']
|
||||
self.model = CostPredictionModel(input_size).to(self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
# 加载标准化器
|
||||
logger.info(f"Loading scalers from {model_record['scaler_path']}")
|
||||
scalers = joblib.load(model_record['scaler_path'])
|
||||
scalers = torch.load(model_record['scaler_path'])
|
||||
self.feature_scaler = scalers['feature_scaler']
|
||||
self.target_scaler = scalers['target_scaler']
|
||||
logger.info("Loaded scalers successfully")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {str(e)}")
|
||||
logger.error(f"Detailed traceback:", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_missile_features(self):
|
||||
"""获取巡飞弹的特征列表"""
|
||||
return [
|
||||
# 基本参数
|
||||
'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km',
|
||||
|
||||
# 性能参数 - 新增和修改的参数
|
||||
'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
|
||||
'endurance_min', 'max_payload_kg', 'min_combat_radius_km',
|
||||
|
||||
# 动力系统参数 - 新增参数
|
||||
'engine_power_kw', 'engine_thrust_n',
|
||||
|
||||
# 制导与控制参数 - 新增参数
|
||||
'datalink_range_km', 'guidance_accuracy_m',
|
||||
'min_altitude_m', 'max_altitude_m',
|
||||
|
||||
# 特征工程参数 - 新增评分指标
|
||||
'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
|
||||
'guidance_system_score', 'warhead_power_score'
|
||||
]
|
||||
|
||||
def get_rocket_features(self):
|
||||
"""获取火箭炮的特征列表"""
|
||||
return [
|
||||
# 基本参数
|
||||
'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km',
|
||||
|
||||
# 火箭炮特有参数 - 新增和修改的参数
|
||||
'firing_angle_horizontal', 'firing_angle_vertical',
|
||||
'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg',
|
||||
'rate_of_fire', 'combat_weight_kg', 'speed_kmh',
|
||||
'min_range_km', 'max_range_km', 'mobility_type', 'structure_layout',
|
||||
'engine_model', 'power_hp', 'travel_range_km',
|
||||
|
||||
# 特征工程参数 - 新增评分指标
|
||||
'fire_density', 'range_ratio', 'mobility_score',
|
||||
'combat_readiness_score', 'rocket_power_ratio',
|
||||
'platform_efficiency', 'deployment_score',
|
||||
'terrain_adaptability_score'
|
||||
]
|
||||
|
||||
|
||||
def predict(self, features):
|
||||
"""使用加载的模型进行预测"""
|
||||
"""使用模型进行预测"""
|
||||
try:
|
||||
if not self.best_model:
|
||||
raise ValueError("No model loaded")
|
||||
|
||||
if not self.feature_scaler:
|
||||
raise ValueError("Feature scaler not loaded")
|
||||
|
||||
if not self.target_scaler:
|
||||
raise ValueError("Target scaler not loaded")
|
||||
|
||||
logger.info("Starting prediction")
|
||||
logger.info(f"Input features shape: {features.shape}")
|
||||
logger.info(f"Input features: \n{features}")
|
||||
|
||||
# 获取正确的特征列表
|
||||
if self.equipment_type == '巡飞弹':
|
||||
feature_list = self.get_missile_features()
|
||||
logger.info(f"Using missile features: {feature_list}")
|
||||
|
||||
# 确保特征顺序一致
|
||||
features_ordered = np.zeros((features.shape[0], len(feature_list)))
|
||||
for i, feature_name in enumerate(feature_list):
|
||||
if feature_name in features:
|
||||
features_ordered[:, i] = features[feature_name]
|
||||
features = features_ordered
|
||||
|
||||
# 处理缺失值
|
||||
features_filled = np.array(features, dtype=float)
|
||||
features_filled[np.isnan(features_filled)] = 0
|
||||
features_filled = np.nan_to_num(features_filled, 0)
|
||||
|
||||
logger.info(f"Filled features: \n{features_filled}")
|
||||
|
||||
# 标准化特征
|
||||
X = self.feature_scaler.transform(features_filled)
|
||||
logger.info(f"Transformed features shape: {X.shape}")
|
||||
logger.info(f"Transformed features: \n{X}")
|
||||
|
||||
# 预测
|
||||
y_pred_scaled = self.best_model.predict(X)
|
||||
logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
|
||||
logger.info(f"Scaled prediction: {y_pred_scaled}")
|
||||
|
||||
# 反标准化
|
||||
y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
|
||||
logger.info(f"Final prediction shape: {y_pred.shape}")
|
||||
logger.info(f"Final prediction: {y_pred}")
|
||||
|
||||
return y_pred.ravel()
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# 转换为tensor并移动到正确的设备
|
||||
features_tensor = torch.FloatTensor(features).to(self.device)
|
||||
# 进行预测
|
||||
predictions = self.model(features_tensor)
|
||||
# 移回CPU并转换为numpy数组
|
||||
return predictions.cpu().numpy()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction: {str(e)}")
|
||||
raise
|
||||
|
||||
def _get_feature_importance(self, model):
|
||||
"""
|
||||
获取特征重要性
|
||||
"""
|
||||
try:
|
||||
if not model:
|
||||
return {}
|
||||
|
||||
# 获取征名称
|
||||
if self.equipment_type == '巡飞弹':
|
||||
feature_names = [
|
||||
# 基本参数
|
||||
'length_m', 'width_m', 'height_m', 'weight_kg',
|
||||
|
||||
# 性能参数
|
||||
'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
|
||||
'endurance_min', 'max_range_km','max_payload_kg', 'ceiling_altitude_m',
|
||||
'combat_radius_km',
|
||||
|
||||
# 动力系统参数
|
||||
'engine_power_kw', 'engine_thrust_n',
|
||||
|
||||
# 制导与控制参数
|
||||
'datalink_range_km', 'guidance_accuracy_m',
|
||||
'min_altitude_m', 'max_altitude_m',
|
||||
|
||||
# 特征工程参数
|
||||
'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio',
|
||||
'guidance_system_score', 'warhead_power_score'
|
||||
]
|
||||
else:
|
||||
# 其他装备类型使用原有的特征获取逻辑
|
||||
feature_analyzer = FeatureAnalysis()
|
||||
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
|
||||
|
||||
# 获取特征重要性
|
||||
if hasattr(model, 'feature_importances_'):
|
||||
importances = model.feature_importances_
|
||||
elif hasattr(model, 'coef_'):
|
||||
if len(model.coef_.shape) > 1: # 如果是二维数组
|
||||
importances = np.abs(model.coef_[0]) # 取第一行
|
||||
else:
|
||||
importances = np.abs(model.coef_)
|
||||
else:
|
||||
return {}
|
||||
|
||||
# 创建特征重要性字典
|
||||
importance_dict = {}
|
||||
for name, importance in zip(feature_names, importances):
|
||||
importance_dict[name] = float(importance) # 确保转换为 Python 标量
|
||||
|
||||
# 按重要性降序排序
|
||||
sorted_dict = dict(sorted(
|
||||
importance_dict.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
))
|
||||
|
||||
# 过滤掉重要性为0的特征
|
||||
return {k: v for k, v in sorted_dict.items() if v > 0}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting feature importance: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _calculate_confidence_interval(self, prediction, confidence=0.95):
|
||||
"""
|
||||
计算预测值的置信区间
|
||||
"""
|
||||
try:
|
||||
# 使用预测值的20%作为标准差(增加不确定性)
|
||||
std = abs(prediction) * 0.2
|
||||
|
||||
# 计算置信区间
|
||||
from scipy import stats
|
||||
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
|
||||
|
||||
# 确保区间值为正数且合理
|
||||
lower = max(1000, interval[0]) # 最小值设为1000元
|
||||
upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
|
||||
|
||||
logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
|
||||
|
||||
return [lower, upper]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating confidence interval: {str(e)}")
|
||||
# 如果计算失败,返回基于20%的简单区间
|
||||
lower = max(1000, prediction * 0.8)
|
||||
upper = prediction * 1.2
|
||||
return [lower, upper]
|
||||
|
||||
def get_model_type(self):
|
||||
"""
|
||||
获取当前模型的类型
|
||||
"""
|
||||
if isinstance(self.best_model, xgb.XGBRegressor):
|
||||
return 'xgboost'
|
||||
elif isinstance(self.best_model, lgb.LGBMRegressor):
|
||||
return 'lightgbm'
|
||||
elif isinstance(self.best_model, GradientBoostingRegressor):
|
||||
return 'gbm'
|
||||
elif isinstance(self.best_model, RandomForestRegressor):
|
||||
return 'rf'
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
def _get_pls_feature_importance(self):
|
||||
"""
|
||||
获取 PLS 模型的特征重要性
|
||||
"""
|
||||
try:
|
||||
if not self.models['pls']:
|
||||
return {}
|
||||
|
||||
# 获取特征名称
|
||||
feature_analyzer = FeatureAnalysis()
|
||||
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
|
||||
|
||||
# 获取 PLS 模型的系数作为特征重要性
|
||||
pls_model = self.models['pls']
|
||||
if hasattr(pls_model, 'coef_'):
|
||||
# 使用绝对值作为重要性指标
|
||||
importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
|
||||
else:
|
||||
return {}
|
||||
|
||||
# 创建特征重要性字典
|
||||
importance_dict = {}
|
||||
for name, importance in zip(feature_names, importances):
|
||||
importance_dict[name] = float(importance) # 确保转换为 Python 标量
|
||||
|
||||
# 按重要性降序排序
|
||||
sorted_dict = dict(sorted(
|
||||
importance_dict.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
))
|
||||
|
||||
# 过滤掉重要性为0的特征
|
||||
return {k: v for k, v in sorted_dict.items() if v > 0}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting PLS feature importance: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return {}
|
||||
|
||||
def _preprocess_data(self, data):
|
||||
"""数据预处理"""
|
||||
try:
|
||||
# 获取正确的特征列表
|
||||
if self.equipment_type == '巡飞弹':
|
||||
feature_list = self.get_missile_features()
|
||||
else:
|
||||
feature_list = self.get_rocket_features()
|
||||
|
||||
logger.info(f"Using features: {feature_list}")
|
||||
|
||||
# 处理缺失值
|
||||
features_filled = np.array(data, dtype=float)
|
||||
features_filled[np.isnan(features_filled)] = 0
|
||||
features_filled = np.nan_to_num(features_filled, 0)
|
||||
|
||||
logger.info(f"Filled features: \n{features_filled}")
|
||||
|
||||
return features_filled
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data preprocessing: {str(e)}")
|
||||
raise
|
||||
@ -1,485 +0,0 @@
|
||||
-- 清空现有数据
|
||||
SET FOREIGN_KEY_CHECKS=0;
|
||||
TRUNCATE TABLE dataset_equipment;
|
||||
TRUNCATE TABLE datasets;
|
||||
TRUNCATE TABLE cost_data;
|
||||
TRUNCATE TABLE loitering_munition_params;
|
||||
TRUNCATE TABLE common_params;
|
||||
TRUNCATE TABLE equipment;
|
||||
SET FOREIGN_KEY_CHECKS=1;
|
||||
|
||||
-- 按系列插入装备数据,确保ID连续
|
||||
-- 1. HAROP/Harpy 系列 (ID: 1-3)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1, 'IAI Harop', '巡飞弹', '以色列'),
|
||||
(2, 'IAI Harpy', '巡飞弹', '以色列'),
|
||||
(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
|
||||
|
||||
-- 2. Hero 系列 (ID: 4-9)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
|
||||
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
|
||||
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
|
||||
(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
|
||||
(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
|
||||
(9, 'Hero-900', '巡飞弹', '以色列 UVision');
|
||||
|
||||
-- 3. Switchblade 系列 (ID: 10-13)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
|
||||
(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
|
||||
(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
|
||||
(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
|
||||
|
||||
-- 4. Warmate 系列 (ID: 14-18)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
|
||||
(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
|
||||
(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
|
||||
(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
|
||||
(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
|
||||
|
||||
-- 5. CH-901/902 系列 (ID: 19-23)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(19, 'CH-901', '巡飞弹', '中国航天科工'),
|
||||
(20, 'CH-901A', '巡飞弹', '中国航天科工'),
|
||||
(21, 'CH-901H', '巡飞弹', '中国航天科工'),
|
||||
(22, 'CH-902', '巡飞弹', '中国航天科工'),
|
||||
(23, 'CH-902A', '巡飞弹', '中国航天科工');
|
||||
|
||||
-- 6. WS-43/61 系列 (ID: 24-28)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(24, 'WS-43', '巡飞弹', '中国航天科工'),
|
||||
(25, 'WS-43A', '巡飞弹', '中国航天科工'),
|
||||
(26, 'WS-43B', '巡飞弹', '中国航天科工'),
|
||||
(27, 'WS-61', '巡飞弹', '中国航天科工'),
|
||||
(28, 'WS-61A', '巡飞弹', '中国航天科工');
|
||||
|
||||
-- 7. Kargu/Alpagu 系列 (ID: 29-33)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(29, 'Kargu', '巡飞弹', '土耳其 STM'),
|
||||
(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
|
||||
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
|
||||
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
|
||||
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
|
||||
|
||||
-- 8. Shahed 系列 (ID: 34-38)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(34, 'Shahed-131', '巡飞弹', '伊朗'),
|
||||
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
|
||||
(36, 'Shahed-136', '巡飞弹', '伊朗'),
|
||||
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
|
||||
(38, 'Shahed-136C', '巡飞弹', '伊朗');
|
||||
|
||||
-- 9. Green Dragon 系列 (ID: 39-43)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
|
||||
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
|
||||
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
|
||||
(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
|
||||
(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
|
||||
|
||||
-- 10. Phoenix Ghost 系列 (ID: 44-48)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
|
||||
(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
|
||||
|
||||
-- 11. ZALA Lancet 系列 (ID: 49-52)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
|
||||
|
||||
-- 12. Rotem L 系列 (ID: 53-56)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
|
||||
(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
|
||||
(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
|
||||
(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
|
||||
|
||||
-- 13. KUB-BLA 系列 (ID: 57-60)
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
|
||||
(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
|
||||
|
||||
-- 插入通用参数
|
||||
INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
|
||||
(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
|
||||
(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
|
||||
(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
|
||||
(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
|
||||
(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
|
||||
(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
|
||||
(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
|
||||
(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
|
||||
(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
|
||||
(10, 0.58, 0.12, 0.12, 2.5, 10),
|
||||
(11, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
|
||||
(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
|
||||
(14, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(15, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(16, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(17, 1.30, 0.22, 0.22, 15.0, 40),
|
||||
(18, 0.68, 0.12, 0.12, 2.5, 10),
|
||||
(19, 1.2, 0.18, 0.18, 9.0, 20),
|
||||
(20, 1.2, 0.18, 0.18, 9.3, 25),
|
||||
(21, 1.2, 0.18, 0.18, 9.5, 20),
|
||||
(22, 1.4, 0.22, 0.22, 15.0, 30),
|
||||
(23, 1.4, 0.22, 0.22, 15.5, 35),
|
||||
(24, 1.8, 0.35, 0.35, 20, 60),
|
||||
(25, 1.8, 0.35, 0.35, 21, 70),
|
||||
(26, 1.9, 0.35, 0.35, 22, 80),
|
||||
(27, 2.2, 0.40, 0.40, 35, 100),
|
||||
(28, 2.2, 0.40, 0.40, 37, 120),
|
||||
(29, 0.6, 0.35, 0.35, 7.0, 10),
|
||||
(30, 0.6, 0.35, 0.35, 7.2, 15),
|
||||
(31, 1.0, 0.23, 0.23, 3.7, 5),
|
||||
(32, 1.0, 0.23, 0.23, 3.9, 8),
|
||||
(33, 0.6, 0.35, 0.35, 7.5, 15),
|
||||
(34, 2.6, 0.34, 0.34, 135, 900),
|
||||
(35, 2.6, 0.34, 0.34, 140, 1000),
|
||||
(36, 3.5, 0.42, 0.42, 200, 2000),
|
||||
(37, 3.5, 0.42, 0.42, 210, 2200),
|
||||
(38, 3.5, 0.42, 0.42, 215, 2500),
|
||||
(39, 1.5, 0.20, 0.20, 15, 40),
|
||||
(40, 1.6, 0.20, 0.20, 16, 50),
|
||||
(41, 1.5, 0.20, 0.20, 15.5, 45),
|
||||
(42, 1.5, 0.20, 0.20, 15.8, 40),
|
||||
(43, 1.2, 0.18, 0.18, 12, 30),
|
||||
(44, 1.5, 0.25, 0.25, 14.0, 30),
|
||||
(45, 1.5, 0.25, 0.25, 14.5, 35),
|
||||
(46, 1.6, 0.26, 0.26, 15.0, 40),
|
||||
(47, 1.5, 0.25, 0.25, 14.8, 30),
|
||||
(48, 1.7, 0.27, 0.27, 16.0, 50),
|
||||
(49, 1.0, 0.20, 0.20, 5.0, 40),
|
||||
(50, 1.65, 0.35, 0.35, 12.0, 70),
|
||||
(51, 1.65, 0.35, 0.35, 12.5, 80),
|
||||
(52, 1.80, 0.40, 0.40, 15.0, 100),
|
||||
(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
|
||||
(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
|
||||
(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
|
||||
(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
|
||||
(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
|
||||
(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
|
||||
(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
|
||||
(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
|
||||
|
||||
-- 插入特有参数
|
||||
INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
|
||||
endurance_min,
|
||||
warhead_type,
|
||||
launch_mode,
|
||||
power_system,
|
||||
guidance_system
|
||||
) VALUES
|
||||
-- HAROP/Harpy系列
|
||||
(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
|
||||
(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
|
||||
|
||||
-- Hero系列
|
||||
(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
|
||||
(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
|
||||
(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
|
||||
-- Switchblade系列
|
||||
(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
|
||||
(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
|
||||
-- Warmate系列
|
||||
(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
|
||||
(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
|
||||
|
||||
-- CH-901/902系列
|
||||
(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
|
||||
(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
|
||||
(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
|
||||
(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
|
||||
(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
|
||||
(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
|
||||
(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
|
||||
(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
|
||||
(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
|
||||
(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
|
||||
(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
|
||||
(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
|
||||
(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
|
||||
(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
|
||||
(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
|
||||
(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
|
||||
(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
|
||||
|
||||
-- 插入成本数据
|
||||
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
|
||||
(1, 800000), -- IAI Harop
|
||||
(2, 700000), -- IAI Harpy
|
||||
(3, 350000), -- IAI Mini Harpy
|
||||
(4, 70000), -- Hero-30
|
||||
(5, 120000), -- Hero-70
|
||||
(6, 150000), -- Hero-120
|
||||
(7, 300000), -- Hero-250
|
||||
(8, 400000), -- Hero-400EC
|
||||
(9, 650000), -- Hero-900
|
||||
(10, 60000), -- Switchblade 300
|
||||
(11, 180000), -- Switchblade 600
|
||||
(12, 75000), -- Switchblade 300 Block 10
|
||||
(13, 200000), -- Switchblade 600 Extended Range
|
||||
(14, 60000), -- Warmate 1.0
|
||||
(15, 180000), -- Warmate 2.0
|
||||
(16, 60000), -- Warmate-V
|
||||
(17, 180000), -- Warmate-L
|
||||
(18, 60000), -- Warmate 3.0
|
||||
(19, 100000), -- CH-901
|
||||
(20, 120000), -- CH-901A
|
||||
(21, 130000), -- CH-901H
|
||||
(22, 180000), -- CH-902
|
||||
(23, 200000), -- CH-902A
|
||||
(24, 120000), -- WS-43
|
||||
(25, 150000), -- WS-43A
|
||||
(26, 180000), -- WS-43B
|
||||
(27, 300000), -- WS-61
|
||||
(28, 350000), -- WS-61A
|
||||
(29, 70000), -- Kargu
|
||||
(30, 85000), -- Kargu-2
|
||||
(31, 45000), -- Alpagu
|
||||
(32, 55000), -- Alpagu Block-II
|
||||
(33, 95000), -- Kargu Autonomous
|
||||
(34, 20000), -- Shahed-131
|
||||
(35, 25000), -- Shahed-131B
|
||||
(36, 40000), -- Shahed-136
|
||||
(37, 45000), -- Shahed-136B
|
||||
(38, 50000), -- Shahed-136C
|
||||
(39, 160000), -- Green Dragon
|
||||
(40, 200000), -- Green Dragon Extended Range
|
||||
(41, 180000), -- Green Dragon Block 2
|
||||
(42, 190000), -- Green Dragon Maritime
|
||||
(43, 140000), -- Green Dragon-S
|
||||
(44, 150000), -- Phoenix Ghost
|
||||
(45, 180000), -- Phoenix Ghost Block I
|
||||
(46, 220000), -- Phoenix Ghost Block II
|
||||
(47, 190000), -- Phoenix Ghost Maritime
|
||||
(48, 250000), -- Phoenix Ghost-ER
|
||||
(49, 80000), -- Lancet-1
|
||||
(50, 150000), -- Lancet-3
|
||||
(51, 180000), -- Lancet-3M
|
||||
(52, 250000), -- Lancet-4
|
||||
(53, 65000), -- Rotem L
|
||||
(54, 85000), -- Rotem L-X
|
||||
(55, 75000), -- Rotem L-M
|
||||
(56, 95000), -- Rotem L-ER
|
||||
(57, 95000), -- KUB-BLA
|
||||
(58, 120000), -- KUB-BLA-E
|
||||
(59, 110000), -- KUB-BLA-M
|
||||
(60, 150000); -- KUB-BLA-ER
|
||||
|
||||
-- 创建数据集
|
||||
INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
|
||||
(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
|
||||
|
||||
-- 关联装备到数据集(按照制造商和型号分配)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
-- 训练集(约80%的数据,48个型号)
|
||||
-- 以色列系列
|
||||
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
|
||||
(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
|
||||
(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
|
||||
(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
|
||||
|
||||
-- 美国系列
|
||||
(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
|
||||
(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
|
||||
|
||||
-- 中国系列
|
||||
(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
|
||||
(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
|
||||
|
||||
-- 波兰和土耳其系列
|
||||
(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
|
||||
(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
|
||||
|
||||
-- 俄罗斯系列
|
||||
(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
|
||||
|
||||
-- 验证集(约20%的数据,12个型号)
|
||||
-- 混合系列
|
||||
(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
|
||||
(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
|
||||
(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
|
||||
|
||||
-- 添加分类特征编码
|
||||
INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
|
||||
-- 战斗部类型编码
|
||||
('warhead_type', '破片杀伤战斗部', 1),
|
||||
('warhead_type', '破甲战斗部', 2),
|
||||
('warhead_type', '高爆战斗部', 3),
|
||||
('warhead_type', '破片杀伤/破甲双用战斗部', 4),
|
||||
('warhead_type', '模块化战斗部', 5),
|
||||
|
||||
-- 发射方式编码
|
||||
('launch_mode', '箱式发射', 1),
|
||||
('launch_mode', '弹射式发射', 2),
|
||||
('launch_mode', '垂直起降', 3),
|
||||
('launch_mode', '单兵发射管', 4),
|
||||
('launch_mode', '箱式发射/弹射式', 5),
|
||||
('launch_mode', '箱式发射/空中发射', 6),
|
||||
|
||||
-- 动力装置编码(按复杂度递增)
|
||||
('power_system', '电动机', 1),
|
||||
('power_system', '活塞发动机', 2),
|
||||
|
||||
-- 制导系统编码(按复杂度递增)
|
||||
('guidance_system', 'GPS/INS', 1),
|
||||
('guidance_system', 'GPS/INS/光电', 2),
|
||||
('guidance_system', 'GPS/INS/光电/数据链', 3),
|
||||
('guidance_system', 'GPS/INS/光电/AI识别', 4),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
|
||||
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
|
||||
|
||||
-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
|
||||
UPDATE loitering_munition_params l
|
||||
JOIN common_params c ON l.equipment_id = c.equipment_id
|
||||
SET
|
||||
-- 新增关键参数
|
||||
l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
|
||||
l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
|
||||
l.engine_power_kw =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
|
||||
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
|
||||
END,
|
||||
l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
|
||||
l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
|
||||
l.guidance_accuracy_m =
|
||||
CASE
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
|
||||
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
|
||||
ELSE 3.0
|
||||
END,
|
||||
l.min_altitude_m = -- 最小作战高度
|
||||
CASE
|
||||
-- 大型巡飞弹(体型大、重量大)
|
||||
WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
|
||||
|
||||
-- 中型巡飞弹
|
||||
WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
|
||||
|
||||
-- 中小型巡飞弹
|
||||
WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
|
||||
|
||||
-- 小型巡飞弹
|
||||
WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
|
||||
|
||||
-- 超小型巡飞弹
|
||||
WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
|
||||
|
||||
-- 其他型号使用默认值
|
||||
ELSE 50
|
||||
END,
|
||||
l.max_altitude_m =
|
||||
CASE
|
||||
WHEN c.max_range_km > 500 THEN 5000
|
||||
WHEN c.max_range_km > 100 THEN 3000
|
||||
ELSE 1500
|
||||
END,
|
||||
|
||||
-- 特征工程字段
|
||||
l.length_width_ratio = c.length_m / c.width_m,
|
||||
l.weight_range_ratio = c.weight_kg / c.max_range_km,
|
||||
l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
|
||||
l.guidance_system_score =
|
||||
CASE
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
|
||||
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
|
||||
WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
|
||||
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
|
||||
ELSE 2
|
||||
END,
|
||||
l.warhead_power_score =
|
||||
CASE
|
||||
WHEN l.warhead_type = '模块化战斗部' THEN 10
|
||||
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
|
||||
WHEN l.warhead_type = '高爆战斗部' THEN 7
|
||||
WHEN l.warhead_type = '破甲战斗部' THEN 6
|
||||
WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
|
||||
ELSE 4
|
||||
END,
|
||||
|
||||
-- 分类特征编码
|
||||
l.warhead_type_code =
|
||||
CASE
|
||||
WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
|
||||
WHEN l.warhead_type = '破甲战斗部' THEN 2
|
||||
WHEN l.warhead_type = '高爆战斗部' THEN 3
|
||||
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
|
||||
WHEN l.warhead_type = '模块化战斗部' THEN 5
|
||||
ELSE 0
|
||||
END,
|
||||
l.launch_mode_code =
|
||||
CASE
|
||||
WHEN l.launch_mode = '箱式发射' THEN 1
|
||||
WHEN l.launch_mode = '弹射式发射' THEN 2
|
||||
WHEN l.launch_mode = '垂直起降' THEN 3
|
||||
WHEN l.launch_mode = '单兵发射管' THEN 4
|
||||
WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
|
||||
WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
|
||||
ELSE 0
|
||||
END,
|
||||
l.power_system_code =
|
||||
CASE
|
||||
WHEN l.power_system = '电动机' THEN 1
|
||||
WHEN l.power_system = '活塞发动机' THEN 2
|
||||
ELSE 0
|
||||
END,
|
||||
l.guidance_system_code =
|
||||
CASE
|
||||
WHEN l.guidance_system = 'GPS/INS' THEN 1
|
||||
WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
|
||||
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
|
||||
ELSE 0
|
||||
END;
|
||||
@ -29,7 +29,7 @@
|
||||
*/
|
||||
|
||||
-- 中国系列火箭炮数据
|
||||
INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
INSERT INTO equipments (id, name, type, manufacturer) VALUES
|
||||
(1001, 'PCL-191', '火箭炮', '中国兵器工业集团'),
|
||||
(1002, 'PHL-03', '火箭炮', '中国兵器工业集团'),
|
||||
(1003, 'AR-3', '火箭炮', '中国航天科工'),
|
||||
@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1007, 'WS-2', '火箭炮', '中国航天科工'),
|
||||
(1008, 'WS-3', '火箭炮', '中国航天科工'),
|
||||
(1009, 'Type 63', '火箭炮', '中国兵器工业集团'),
|
||||
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'),
|
||||
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'),
|
||||
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'),
|
||||
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'),
|
||||
(1014, 'TOS-1A', '火箭炮', '俄罗斯'),
|
||||
(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'),
|
||||
(1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'),
|
||||
(1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'),
|
||||
(1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'),
|
||||
@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES
|
||||
(1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'),
|
||||
(1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'),
|
||||
(1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'),
|
||||
(1033, 'Pinaka', '火箭炮', '印度DRDO'),
|
||||
(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'),
|
||||
(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'),
|
||||
(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'),
|
||||
(1033, 'Pinaka', '火箭炮', '印度 DRDO'),
|
||||
(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'),
|
||||
(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'),
|
||||
(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'),
|
||||
(1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
(1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
(1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'),
|
||||
@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
|
||||
(4, '火箭炮验证集 2024', '包含19个火箭炮型号,用于验证模型性能', '火箭炮', '验证');
|
||||
|
||||
-- 训练集(约80%的数据,77个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 中国系列(7/9)
|
||||
(3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007),
|
||||
|
||||
@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
(3, 1094), (3, 1095);
|
||||
|
||||
-- 验证集(约20%的数据,19个型号)
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES
|
||||
-- 中国系列(2/9)
|
||||
(4, 1008), (4, 1009),
|
||||
|
||||
|
||||
415
src/routes.py
415
src/routes.py
@ -48,6 +48,11 @@ def index():
|
||||
'url': '/api/evaluate',
|
||||
'method': 'POST',
|
||||
'description': '模型评估'
|
||||
},
|
||||
'analyze-manufacturers': {
|
||||
'url': '/api/analyze-manufacturers',
|
||||
'method': 'POST',
|
||||
'description': '供应商分析'
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -114,193 +119,149 @@ def analyze_features():
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 获取数据集信息
|
||||
# 首先获取数据集的装备类型
|
||||
cursor.execute("""
|
||||
SELECT d.*,
|
||||
e.type as equipment_type
|
||||
FROM datasets d
|
||||
JOIN dataset_equipment de ON d.id = de.dataset_id
|
||||
JOIN equipment e ON de.equipment_id = e.id
|
||||
WHERE d.id = %s
|
||||
SELECT DISTINCT e.type
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
LIMIT 1
|
||||
""", (dataset_id,))
|
||||
dataset = cursor.fetchone()
|
||||
|
||||
if not dataset:
|
||||
logger.warning(f"Dataset {dataset_id} not found")
|
||||
return jsonify({'error': '数据集不存在'}), 404
|
||||
equipment_type = cursor.fetchone()['type']
|
||||
logger.info(f"Equipment type: {equipment_type}")
|
||||
|
||||
logger.info(f"Dataset info: {dataset}")
|
||||
|
||||
# 创建特征分析实例
|
||||
analyzer = FeatureAnalysis()
|
||||
|
||||
# 获取特征列表
|
||||
feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
|
||||
logger.info(f"Feature names: {feature_names}")
|
||||
|
||||
# 获取数据集中的装备数据
|
||||
if dataset['equipment_type'] == '火箭炮':
|
||||
# 根据装备类型选择查询
|
||||
if equipment_type == '火箭炮':
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
e.name,
|
||||
e.id,
|
||||
cp.length_m,
|
||||
cp.width_m,
|
||||
cp.height_m,
|
||||
cp.weight_kg,
|
||||
rap.max_range_km,
|
||||
rap.firing_angle_horizontal,
|
||||
rap.firing_angle_vertical,
|
||||
rap.rocket_length_m,
|
||||
rap.rocket_diameter_mm,
|
||||
rap.rocket_weight_kg,
|
||||
rap.rate_of_fire,
|
||||
rap.combat_weight_kg,
|
||||
rap.speed_kmh,
|
||||
rap.min_range_km,
|
||||
rap.mobility_type,
|
||||
rap.structure_layout,
|
||||
rap.engine_model,
|
||||
rap.power_hp,
|
||||
rap.travel_range_km,
|
||||
rap.fire_density,
|
||||
rap.range_ratio,
|
||||
rap.mobility_score,
|
||||
rap.combat_readiness_score,
|
||||
rap.rocket_power_ratio,
|
||||
rap.platform_efficiency,
|
||||
rap.deployment_score,
|
||||
rap.terrain_adaptability_score,
|
||||
cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level, m.country,
|
||||
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
|
||||
cd.actual_cost,
|
||||
rap.max_range_km, rap.firing_angle_horizontal, rap.firing_angle_vertical,
|
||||
rap.rocket_length_m, rap.rocket_diameter_mm, rap.rocket_weight_kg,
|
||||
rap.rate_of_fire, rap.combat_weight_kg, rap.speed_kmh,
|
||||
rap.min_range_km, rap.power_hp, rap.travel_range_km,
|
||||
rap.fire_density, rap.range_ratio, rap.mobility_score,
|
||||
rap.combat_readiness_score, rap.deployment_score, rap.terrain_adaptability_score,
|
||||
rap.rocket_power_ratio, rap.platform_efficiency
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
ORDER BY e.id
|
||||
""", (dataset_id,))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
e.name,
|
||||
e.id,
|
||||
cp.length_m,
|
||||
cp.width_m,
|
||||
cp.height_m,
|
||||
cp.weight_kg,
|
||||
lmp.max_range_km,
|
||||
lmp.wingspan_m,
|
||||
lmp.warhead_weight_kg,
|
||||
lmp.max_speed_ms,
|
||||
lmp.cruise_speed_kmh,
|
||||
lmp.endurance_min,
|
||||
lmp.max_payload_kg,
|
||||
lmp.ceiling_altitude_m,
|
||||
lmp.combat_radius_km,
|
||||
lmp.engine_power_kw,
|
||||
lmp.engine_thrust_n,
|
||||
lmp.datalink_range_km,
|
||||
lmp.guidance_accuracy_m,
|
||||
lmp.min_altitude_m,
|
||||
lmp.max_altitude_m,
|
||||
lmp.length_width_ratio,
|
||||
lmp.weight_range_ratio,
|
||||
lmp.speed_weight_ratio,
|
||||
lmp.guidance_system_score,
|
||||
lmp.warhead_power_score,
|
||||
cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level, m.country,
|
||||
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
|
||||
cd.actual_cost,
|
||||
lmp.max_range_km, lmp.wingspan_m, lmp.warhead_weight_kg,
|
||||
lmp.max_speed_ms, lmp.cruise_speed_kmh, lmp.endurance_min,
|
||||
lmp.length_width_ratio, lmp.weight_range_ratio,
|
||||
lmp.speed_weight_ratio, lmp.ceiling_altitude_m,
|
||||
lmp.guidance_system_score, lmp.warhead_power_score,
|
||||
lmp.engine_power_kw, lmp.engine_thrust_n,
|
||||
lmp.min_altitude_m, lmp.max_altitude_m,
|
||||
lmp.max_payload_kg, lmp.combat_radius_km,
|
||||
lmp.datalink_range_km, lmp.guidance_accuracy_m
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
ORDER BY e.id
|
||||
""", (dataset_id,))
|
||||
|
||||
equipment_data = cursor.fetchall()
|
||||
logger.info(f"Found {len(equipment_data)} equipment records")
|
||||
|
||||
# 提取装备名称列表
|
||||
equipment_names = [item['name'] for item in equipment_data]
|
||||
# 添加数据检查日志
|
||||
logger.info(f"Total records found: {len(equipment_data)}")
|
||||
if equipment_data:
|
||||
# 检查第一条记录的所有字段
|
||||
first_record = equipment_data[0]
|
||||
logger.info("First record details:")
|
||||
for key, value in first_record.items():
|
||||
logger.info(f"{key}: {value}")
|
||||
|
||||
# 检查所有记录的 max_range_km 字段
|
||||
logger.info("Checking max_range_km for all records:")
|
||||
for item in equipment_data:
|
||||
logger.info(f"Equipment: {item['name']}")
|
||||
logger.info(f" max_range_km: {item.get('max_range_km')}")
|
||||
logger.info(f" type: {item['type']}")
|
||||
if item['type'] == '火箭炮':
|
||||
logger.info(f" rocket_artillery_params fields:")
|
||||
for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']:
|
||||
logger.info(f" {key}: {item.get(key)}")
|
||||
|
||||
# 提取特征数据和目标值
|
||||
# 提取特征和目标值
|
||||
analyzer = FeatureAnalysis()
|
||||
feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type'])
|
||||
features = []
|
||||
targets = []
|
||||
|
||||
for item in equipment_data:
|
||||
# 计算生产商特征
|
||||
manufacturer_features = analyzer.calculate_manufacturer_features({
|
||||
'tech_level': item['tech_level'],
|
||||
'scale_level': item['scale_level'],
|
||||
'supply_chain_level': item['supply_chain_level'],
|
||||
'country': item['country']
|
||||
})
|
||||
|
||||
# 获取装备特征
|
||||
feature_values = []
|
||||
for feature in feature_names:
|
||||
value = item.get(feature)
|
||||
feature_values.append(float(value) if value is not None else 0)
|
||||
for name in feature_names:
|
||||
if name in manufacturer_features:
|
||||
value = manufacturer_features[name]
|
||||
else:
|
||||
value = item.get(name)
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
|
||||
features.append(feature_values)
|
||||
targets.append(float(item['actual_cost']))
|
||||
|
||||
# 进行特征分析
|
||||
result = analyzer.analyze_features(features, targets, feature_names)
|
||||
# 执行特征分析
|
||||
analysis_result = analyzer.analyze_features(features, targets, feature_names)
|
||||
|
||||
# 添加装备名称列表到结果中
|
||||
result['equipment_names'] = equipment_names
|
||||
# 添加装备名称列表
|
||||
analysis_result['equipment_names'] = [item['name'] for item in equipment_data]
|
||||
|
||||
# 如果是火箭炮,添加额外的分析数据
|
||||
if dataset['equipment_type'] == '火箭炮':
|
||||
# 添加装备特有的分析数据
|
||||
if equipment_data[0]['type'] == '火箭炮':
|
||||
rocket_data = {
|
||||
'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
|
||||
'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
|
||||
'rate_of_fire': [float(item['rate_of_fire']) if item['rate_of_fire'] is not None else 0 for item in equipment_data],
|
||||
'max_range_km': [float(item['max_range_km']) if item['max_range_km'] is not None else 0 for item in equipment_data],
|
||||
'rocket_weight_kg': [float(item['rocket_weight_kg']) if item['rocket_weight_kg'] is not None else 0 for item in equipment_data],
|
||||
'rocket_diameter_mm': [float(item['rocket_diameter_mm']) if item['rocket_diameter_mm'] is not None else 0 for item in equipment_data],
|
||||
'rocket_length_m': [float(item['rocket_length_m']) if item['rocket_length_m'] is not None else 0 for item in equipment_data],
|
||||
'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
|
||||
'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
|
||||
'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data],
|
||||
'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
|
||||
'speed_kmh': [float(item['speed_kmh']) if item['speed_kmh'] is not None else 0 for item in equipment_data],
|
||||
'power_hp': [float(item['power_hp']) if item['power_hp'] is not None else 0 for item in equipment_data],
|
||||
'travel_range_km': [float(item['travel_range_km']) if item['travel_range_km'] is not None else 0 for item in equipment_data]
|
||||
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
|
||||
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
|
||||
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
|
||||
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
|
||||
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
|
||||
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
|
||||
}
|
||||
result.update(rocket_data)
|
||||
|
||||
# 如果是巡飞弹,添加额外的分析数据
|
||||
if dataset['equipment_type'] == '巡飞弹':
|
||||
analysis_result.update(rocket_data)
|
||||
else:
|
||||
missile_data = {
|
||||
'equipment_names': equipment_names,
|
||||
# 特征工程参数
|
||||
'length_width_ratio': [float(item['length_width_ratio']) if item.get('length_width_ratio') is not None else 0 for item in equipment_data],
|
||||
'weight_range_ratio': [float(item['weight_range_ratio']) if item.get('weight_range_ratio') is not None else 0 for item in equipment_data],
|
||||
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item.get('speed_weight_ratio') is not None else 0 for item in equipment_data],
|
||||
'guidance_system_score': [float(item['guidance_system_score']) if item.get('guidance_system_score') is not None else 0 for item in equipment_data],
|
||||
'warhead_power_score': [float(item['warhead_power_score']) if item.get('warhead_power_score') is not None else 0 for item in equipment_data],
|
||||
|
||||
# 动力系统参数
|
||||
'engine_power_kw': [float(item['engine_power_kw']) if item.get('engine_power_kw') is not None else 0 for item in equipment_data],
|
||||
'engine_thrust_n': [float(item['engine_thrust_n']) if item.get('engine_thrust_n') is not None else 0 for item in equipment_data],
|
||||
|
||||
# 作战参数
|
||||
'min_altitude_m': [float(item['min_altitude_m']) if item.get('min_altitude_m') is not None else 0 for item in equipment_data],
|
||||
'max_altitude_m': [float(item['max_altitude_m']) if item.get('max_altitude_m') is not None else 0 for item in equipment_data],
|
||||
'max_range_km': [float(item['max_range_km']) if item.get('max_range_km') is not None else 0 for item in equipment_data],
|
||||
'max_payload_kg': [float(item['max_payload_kg']) if item.get('max_payload_kg') is not None else 0 for item in equipment_data],
|
||||
'combat_radius_km': [float(item['combat_radius_km']) if item.get('combat_radius_km') is not None else 0 for item in equipment_data],
|
||||
'datalink_range_km': [float(item['datalink_range_km']) if item.get('datalink_range_km') is not None else 0 for item in equipment_data],
|
||||
'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item.get('guidance_accuracy_m') is not None else 0 for item in equipment_data]
|
||||
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
|
||||
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
|
||||
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
|
||||
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
|
||||
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
|
||||
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
|
||||
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
|
||||
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
|
||||
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
|
||||
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
|
||||
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
|
||||
}
|
||||
|
||||
# 验证数据完整性
|
||||
for key, value in missile_data.items():
|
||||
logger.info(f"{key} data length: {len(value)}")
|
||||
logger.info(f"{key} sample data: {value[:3]}")
|
||||
if not any(value): # 检查是否所有值都为0
|
||||
logger.warning(f"All values are 0 for {key}")
|
||||
|
||||
# 更新结果
|
||||
result.update(missile_data)
|
||||
analysis_result.update(missile_data)
|
||||
|
||||
return jsonify(result)
|
||||
return jsonify(analysis_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing features: {str(e)}")
|
||||
@ -332,8 +293,8 @@ def train_model():
|
||||
if equipment_type == '火箭炮':
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
@ -343,8 +304,8 @@ def train_model():
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
@ -360,8 +321,8 @@ def train_model():
|
||||
if equipment_type == '火箭炮':
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
@ -371,8 +332,8 @@ def train_model():
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
@ -500,7 +461,7 @@ def get_equipment_data():
|
||||
|
||||
# 获取所有装备数据(使用equipment_id替代id)
|
||||
cursor.execute("""
|
||||
SELECT e.id as equipment_id, e.name, e.type,
|
||||
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
|
||||
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
|
||||
cd.actual_cost, cd.predicted_cost,
|
||||
CASE
|
||||
@ -546,7 +507,7 @@ def get_equipment_data():
|
||||
WHERE equipment_id = e.id
|
||||
)
|
||||
END as specific_params
|
||||
FROM equipment e
|
||||
FROM equipments e
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
ORDER BY e.id
|
||||
@ -576,7 +537,7 @@ def delete_equipment(id):
|
||||
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
|
||||
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
|
||||
|
||||
db.commit()
|
||||
cursor.close()
|
||||
@ -737,7 +698,7 @@ def update_equipment(id):
|
||||
|
||||
# 更新装备基本信息
|
||||
cursor.execute("""
|
||||
UPDATE equipment
|
||||
UPDATE equipments
|
||||
SET name = %s, manufacturer = %s
|
||||
WHERE id = %s
|
||||
""", (data['name'], data['manufacturer'], equipment_id))
|
||||
@ -845,7 +806,7 @@ def get_equipment_details(id):
|
||||
# 获取装备基本信息类型
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost
|
||||
FROM equipment e
|
||||
FROM equipments e
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE e.id = %s
|
||||
@ -854,7 +815,7 @@ def get_equipment_details(id):
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
logger.warning(f"Equipment with ID {id} not found")
|
||||
return jsonify({'error': '装备不存在'}), 404
|
||||
return jsonify({'error': '装备不在'}), 404
|
||||
|
||||
logger.info(f"Equipment type: {result['type']}")
|
||||
logger.info(f"Found equipment details: {result['name']}")
|
||||
@ -901,8 +862,8 @@ def get_datasets():
|
||||
COUNT(de.equipment_id) as equipment_count,
|
||||
GROUP_CONCAT(e.name) as equipment_names
|
||||
FROM datasets d
|
||||
LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
|
||||
LEFT JOIN equipment e ON de.equipment_id = e.id
|
||||
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
|
||||
LEFT JOIN equipments e ON de.equipment_id = e.id
|
||||
GROUP BY d.id
|
||||
""")
|
||||
datasets = cursor.fetchall()
|
||||
@ -932,7 +893,7 @@ def get_dataset(id):
|
||||
SELECT d.*,
|
||||
COUNT(de.equipment_id) as equipment_count
|
||||
FROM datasets d
|
||||
LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
|
||||
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
|
||||
WHERE d.id = %s
|
||||
GROUP BY d.id
|
||||
""", (id,))
|
||||
@ -945,14 +906,14 @@ def get_dataset(id):
|
||||
cursor.execute("""
|
||||
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
|
||||
cd.actual_cost
|
||||
FROM equipment e
|
||||
JOIN dataset_equipment de ON e.id = de.equipment_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
""", (id,))
|
||||
equipment = cursor.fetchall()
|
||||
|
||||
# 计算统计信息
|
||||
# 算统计信
|
||||
if equipment:
|
||||
total_cost = sum(item['actual_cost'] or 0 for item in equipment)
|
||||
avg_cost = total_cost / len(equipment)
|
||||
@ -989,7 +950,7 @@ def create_dataset():
|
||||
# 直接从 equipment 表查询,不需要 JOIN
|
||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
||||
cursor.execute(f"""
|
||||
SELECT DISTINCT id FROM equipment
|
||||
SELECT DISTINCT id FROM equipments
|
||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
||||
""", (data['equipment_type'],))
|
||||
|
||||
@ -1015,7 +976,7 @@ def create_dataset():
|
||||
if 'equipment_ids' in data and data['equipment_ids']:
|
||||
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
|
||||
cursor.executemany("""
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id)
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||
VALUES (%s, %s)
|
||||
""", values)
|
||||
logger.info(f"Added {len(values)} equipment associations")
|
||||
@ -1041,7 +1002,7 @@ def update_dataset(id):
|
||||
if 'equipment_ids' in data:
|
||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
||||
cursor.execute(f"""
|
||||
SELECT id FROM equipment
|
||||
SELECT id FROM equipments
|
||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
||||
""", (data['equipment_type'],))
|
||||
|
||||
@ -1064,13 +1025,13 @@ def update_dataset(id):
|
||||
# 3. 更新装备关联
|
||||
if 'equipment_ids' in data:
|
||||
# 先删除旧的关联
|
||||
cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
||||
|
||||
# 添加新的关联
|
||||
if valid_ids: # 确保有有效的ID才执行插入
|
||||
values = [(id, equipment_id) for equipment_id in valid_ids]
|
||||
cursor.executemany("""
|
||||
INSERT INTO dataset_equipment (dataset_id, equipment_id)
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||
VALUES (%s, %s)
|
||||
""", values)
|
||||
logger.info(f"Updated {len(values)} equipment associations")
|
||||
@ -1092,7 +1053,7 @@ def delete_dataset(id):
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 删除装备关联
|
||||
cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
||||
|
||||
# 删除数据集
|
||||
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
|
||||
@ -1139,7 +1100,7 @@ def get_models():
|
||||
|
||||
models = cursor.fetchall()
|
||||
|
||||
# 确保数值类型字段是 float
|
||||
# 确保数值型字段是 float
|
||||
for model in models:
|
||||
if model['r2_score'] is not None:
|
||||
model['r2_score'] = float(model['r2_score'])
|
||||
@ -1161,7 +1122,7 @@ def get_models():
|
||||
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
|
||||
def activate_model(id):
|
||||
"""
|
||||
激活指定的模型
|
||||
激活定的模型
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
@ -1250,4 +1211,104 @@ def predict_all():
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction: {str(e)}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/analyze-manufacturers', methods=['POST'])
|
||||
def analyze_manufacturers():
|
||||
"""分析生产商数据"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
dataset_id = data.get('dataset_id')
|
||||
|
||||
logger.info(f"Starting manufacturer analysis for dataset {dataset_id}")
|
||||
|
||||
if not dataset_id:
|
||||
logger.warning("No dataset_id provided")
|
||||
return jsonify({'error': '请选择数据集'}), 400
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 获取数据集中的装备和生产商数据
|
||||
cursor.execute("""
|
||||
SELECT DISTINCT m.*, e.type as equipment_type
|
||||
FROM manufacturers m
|
||||
JOIN equipments e ON e.manufacturer_id = m.id
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
""", (dataset_id,))
|
||||
|
||||
manufacturers = cursor.fetchall()
|
||||
|
||||
if not manufacturers:
|
||||
return jsonify({'error': '数据集中没有生产商数据'}), 404
|
||||
|
||||
# 准备分析数据
|
||||
manufacturer_names = []
|
||||
tech_levels = []
|
||||
scale_levels = []
|
||||
supply_chain_levels = []
|
||||
composite_scores = []
|
||||
region_count = {}
|
||||
manufacturer_scores = []
|
||||
|
||||
for manu in manufacturers:
|
||||
manufacturer_names.append(manu['name'])
|
||||
tech_levels.append(manu['tech_level'])
|
||||
scale_levels.append(manu['scale_level'])
|
||||
supply_chain_levels.append(manu['supply_chain_level'])
|
||||
|
||||
# 计算综合得分
|
||||
composite_score = (
|
||||
manu['tech_level'] * 0.4 +
|
||||
manu['scale_level'] * 0.3 +
|
||||
manu['supply_chain_level'] * 0.3
|
||||
)
|
||||
composite_scores.append(composite_score)
|
||||
|
||||
# 统计地区分布
|
||||
region_count[manu['country']] = region_count.get(manu['country'], 0) + 1
|
||||
|
||||
# 计算区域系数
|
||||
region_factors = {
|
||||
'美国': 1.2, '英国': 1.15, '德国': 1.15,
|
||||
'法国': 1.15, '以色列': 1.1, '中国': 0.8,
|
||||
'俄罗斯': 0.85, '韩国': 0.9, '日本': 1.1
|
||||
}
|
||||
region_factor = region_factors.get(manu['country'], 1.0)
|
||||
|
||||
# 添加雷达图数据
|
||||
manufacturer_scores.append({
|
||||
'name': manu['name'],
|
||||
'value': [
|
||||
manu['tech_level'],
|
||||
manu['scale_level'],
|
||||
manu['supply_chain_level'],
|
||||
region_factor,
|
||||
composite_score
|
||||
]
|
||||
})
|
||||
|
||||
# 准备地区分布数据
|
||||
region_distribution = [
|
||||
{'name': country, 'value': count}
|
||||
for country, count in region_count.items()
|
||||
]
|
||||
|
||||
# 返回分析结果
|
||||
result = {
|
||||
'manufacturer_names': manufacturer_names,
|
||||
'manufacturer_tech_levels': tech_levels,
|
||||
'manufacturer_scale_levels': scale_levels,
|
||||
'manufacturer_supply_chain_levels': supply_chain_levels,
|
||||
'manufacturer_composite_scores': composite_scores,
|
||||
'region_distribution': region_distribution,
|
||||
'manufacturer_scores': manufacturer_scores
|
||||
}
|
||||
|
||||
return jsonify(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing manufacturers: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci;
|
||||
USE equipment_cost_db;
|
||||
|
||||
-- 装备基本信息表
|
||||
CREATE TABLE equipment (
|
||||
CREATE TABLE equipments (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(100), -- 名称
|
||||
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
|
||||
manufacturer VARCHAR(100), -- 制造商
|
||||
manufacturer_id INT, -- 制造商ID
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
@ -26,8 +27,7 @@ CREATE TABLE common_params (
|
||||
width_m FLOAT, -- 宽度(m)
|
||||
height_m FLOAT, -- 高度(m)
|
||||
weight_kg FLOAT, -- 重量(kg)
|
||||
max_range_km FLOAT, -- 最大射程(km)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 火箭炮特有参数表
|
||||
@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params (
|
||||
deployment_score INT, -- 部署评分(1-10)
|
||||
terrain_adaptability_score INT, -- 地形适应性评分(1-10)
|
||||
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 巡飞弹特有参数表
|
||||
@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params (
|
||||
power_system_code INT, -- 动力装置编码
|
||||
guidance_system_code INT, -- 制导系统编码
|
||||
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 分类特征编码表
|
||||
@ -122,7 +122,7 @@ CREATE TABLE cost_data (
|
||||
actual_cost DECIMAL(15,2), -- 实际成本(元)
|
||||
predicted_cost DECIMAL(15,2), -- 预测成本(元)
|
||||
prediction_date TIMESTAMP, -- 预测日期
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 特殊参数表
|
||||
@ -133,12 +133,12 @@ CREATE TABLE custom_params (
|
||||
param_value VARCHAR(255), -- 参数值
|
||||
param_unit VARCHAR(50), -- 参数单位
|
||||
description TEXT, -- 参数说明
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_equipment_type ON equipment(type);
|
||||
CREATE INDEX idx_equipment_name ON equipment(name);
|
||||
CREATE INDEX idx_equipment_type ON equipments(type);
|
||||
CREATE INDEX idx_equipment_name ON equipments(name);
|
||||
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
|
||||
|
||||
-- 数据集表
|
||||
@ -153,12 +153,12 @@ CREATE TABLE datasets (
|
||||
);
|
||||
|
||||
-- 数据集-装备关联表
|
||||
CREATE TABLE dataset_equipment (
|
||||
CREATE TABLE dataset_equipments (
|
||||
dataset_id INT NOT NULL,
|
||||
equipment_id INT NOT NULL,
|
||||
PRIMARY KEY (dataset_id, equipment_id),
|
||||
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
-- 训练模型表
|
||||
@ -175,10 +175,34 @@ CREATE TABLE trained_models (
|
||||
feature_importance JSON, -- 特征重要性
|
||||
training_data_size INT, -- 训练数据量
|
||||
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
|
||||
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
|
||||
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型
|
||||
created_by VARCHAR(50) -- 创建者
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
|
||||
CREATE INDEX idx_model_active ON trained_models(is_active);
|
||||
CREATE INDEX idx_model_active ON trained_models(is_active);
|
||||
|
||||
-- 生产商表
|
||||
CREATE TABLE manufacturers (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL, -- 生产商名称
|
||||
country VARCHAR(50) NOT NULL, -- 所属国家
|
||||
tech_level INT NOT NULL, -- 技术水平评分(1-10)
|
||||
scale_level INT NOT NULL, -- 规模评分(1-10)
|
||||
supply_chain_level INT NOT NULL, -- 供应链成熟度评分(1-10)
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY unique_name (name)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- 添加生产商外键
|
||||
ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id);
|
||||
|
||||
-- 添加索引
|
||||
CREATE INDEX idx_manufacturer_country ON manufacturers(country);
|
||||
CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level);
|
||||
CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level);
|
||||
CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
|
||||
CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id);
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user