重新搜集整理巡飞弹的数据和装备信息

This commit is contained in:
Tian jianyong 2024-11-12 10:59:45 +08:00
parent 30d4b58cdf
commit f8c1ed7560
54 changed files with 1288 additions and 8282 deletions

Binary file not shown.

View File

@ -1,25 +0,0 @@
# 数据库配置
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=your_password_here
MYSQL_DATABASE=equipment_cost_db
# 服务配置
PORT=5001
DEBUG=False
# 日志配置
LOG_LEVEL=INFO
LOG_DIR=logs
# 模型配置
MODEL_DIR=models
DATA_DIR=data
# 安全配置
SECRET_KEY=your_secret_key_here
ALLOWED_HOSTS=localhost,127.0.0.1
# 其他配置
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
ALLOWED_FILE_TYPES=.xlsx,.xls

View File

@ -1,663 +0,0 @@
# 装备成本估算系统 API 文档
这个 API 文档提供了完整的接口说明,包括:
- 每个端点的详细描述
- 请求和响应的具体示例
- 清晰的参数格式要求
- 统一的错误处理说明
- 重要的注意事项
文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
## 基本信息
- 基础URL: `http://localhost:5001/api`
- 版本: 1.0.0
- 响应格式: JSON
## API 端点列表
### 1. 获取 API 信息
获取 API 版本信息和可用端点列表。
- **URL**: `/`
- **方法**: `GET`
- **响应示例**:
json
{
"name": "装备成本估算系统 API",
"version": "1.0.0",
"endpoints": {
"predict": {
"url": "/api/predict",
"method": "POST",
"description": "成本预测"
},
"analyze-features": {
"url": "/api/analyze-features",
"method": "POST",
"description": "特征分析"
},
"train": {
"url": "/api/train",
"method": "POST",
"description": "模型训练"
},
"evaluate": {
"url": "/api/evaluate",
"method": "POST",
"description": "模型评估"
}
}
}
### 2. 单模型预测
使用当前激活的最优模型进行成本预测。
- **URL**: `/predict`
- **方法**: `POST`
- **请求体示例** (巡飞弹):
```json
{
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
```
- **响应示例**:
```json
{
"predicted_cost": 150000.0,
"model_info": {
"type": "xgboost",
"name": "巡飞弹_20241111_model",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
}
```
### 3. PLS 模型预测
使用 PLS 回归模型进行预测。
- **URL**: `/pls/predict`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"predicted_cost": 148000.0,
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
```
### 4. 多模型预测
使用所有激活的模型进行预测并返回综合结果。
- **URL**: `/predict/all`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"individual_predictions": {
"xgboost": {
"predicted_cost": 150000.0,
"model_info": {
"name": "巡飞弹_xgboost_model",
"type": "xgboost",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
},
"pls": {
"predicted_cost": 148000.0,
"model_info": {
"name": "巡飞弹_pls_model",
"type": "pls",
"r2_score": 0.92,
"mae": 5500.0,
"rmse": 8000.0
},
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
},
"ensemble_prediction": {
"predicted_cost": 149000.0,
"standard_deviation": 1414.21,
"confidence_interval": {
"lower": 146228.15,
"upper": 151771.85
}
}
}
```
### 5. 特征分析
分析数据集中特征的重要性和相关性。
- **URL**: `/analyze-features`
- **方法**: `POST`
- **请求体示例**:
```json
{
"dataset_id": 1,
"equipment_type": "巡飞弹"
}
```
- **响应示例**:
```json
{
"important_features": [
{
"name": "最大射程(km)",
"importance": 0.35
},
{
"name": "重量(kg)",
"importance": 0.25
}
],
"correlation_analysis": {
"features": ["最大射程(km)", "重量(kg)"],
"matrix": [[1.0, 0.8], [0.8, 1.0]]
}
}
```
### 6. 模型训练
训练新的模型。
- **URL**: `/train`
- **方法**: `POST`
- **请求体示例**:
```json
{
"type": "巡飞弹",
"train_dataset_id": 1,
"validation_dataset_id": 2,
"models": ["xgboost", "lightgbm", "rf"]
}
```
- **响应示例**:
```json
{
"metrics": {
"xgboost": {
"train": {
"r2": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"validation": {
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
},
"best_model": {
"type": "xgboost",
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
```
### 7. 数据集管理
#### 7.1 获取数据集列表
- **URL**: `/datasets`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"equipment_count": 10,
"equipment_names": ["设备1", "设备2"],
"purpose": "训练",
"created_at": "2024-11-11T10:00:00"
}
]
```
#### 7.2 获取数据集详情
- **URL**: `/datasets/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"created_at": "2024-11-11T10:00:00",
"equipment": [
{
"id": 1,
"name": "设备1",
"type": "巡飞弹",
"manufacturer": "制造商1",
"actual_cost": 150000
}
],
"statistics": {
"equipment_count": 10,
"total_cost": 1500000,
"average_cost": 150000
}
}
```
#### 7.3 创建数据集
- **URL**: `/datasets`
- **方法**: `POST`
- **请求体示例**:
```json
{
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": [1, 2, 3]
}
```
- **响应示例**:
```json
{
"id": 2,
"message": "数据集创建成功"
}
```
#### 7.4 更新数据集
- **URL**: `/datasets/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "更新后的数据集名称",
"description": "更新后的描述",
"equipment_type": "巡飞弹",
"purpose": "验证",
"equipment_ids": [1, 2, 3, 4]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "数据集更新成功"
}
```
#### 7.5 删除数据集
- **URL**: `/datasets/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定的数据集及其关联关系
- **响应示例**:
```json
{
"success": true,
"message": "数据集删除成功"
}
```
注意事项:
1. 数据集删除后不会删除关联的装备数据
2. 不能删除正在被模型使用的数据集
3. 更新数据集时会重新计算统计信息
4. 数据集的装备类型一旦创建后不能更改
### 8. 模型管理
#### 8.1 获取模型列表
- **URL**: `/models`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00"
}
]
```
#### 8.2 获取最新模型
- **URL**: `/models/{equipment_type}/latest`
- **方法**: `GET`
- **响应示例**: 与模型列表的单个模型格式相同
#### 8.3 获取模型详情
- **URL**: `/models/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00",
"feature_importance": {
"max_range_km": 0.35,
"weight_kg": 0.25,
"length_m": 0.20
},
"training_data_size": 100,
"created_by": "system"
}
```
#### 8.4 激活模型
- **URL**: `/models/{id}/activate`
- **方法**: `POST`
- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
- **响应示例**:
```json
{
"success": true,
"message": "模型已激活"
}
```
#### 8.5 删除模型
- **URL**: `/models/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定模型,包括模型文件和数据库记录
- **响应示例**:
```json
{
"success": true,
"message": "模型已删除"
}
```
注意事项:
1. 删除模型时会同时删除相关的文件和数据库记录
2. 不能删除当前正在使用(已激活)的模型
3. 激活模型时会自动取消同类型其他模型的激活状态
4. 模型详情包含了更多的训练相关信息,如特征重要性等
### 9. 数据管理
#### 9.1 获取装备数据列表
- **URL**: `/data`
- **方法**: `GET`
- **响应示例**:
```json
{
"rocket_artillery": [
{
"id": 1,
"name": "BM-21",
"type": "火箭炮",
"manufacturer": "俄罗斯",
"length_m": 7.35,
"width_m": 2.4,
"height_m": 3.1,
"weight_kg": 13700,
"max_range_km": 20.4,
"actual_cost": 800000
}
],
"loitering_munition": [
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"actual_cost": 150000
}
]
}
```
#### 9.2 获取装备详情
- **URL**: `/data/details/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"common_params": {
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40
},
"specific_params": {
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "箱式发射",
"power_system": "电动机",
"guidance_system": "GPS/INS"
},
"cost_data": {
"actual_cost": 150000,
"prediction_date": "2024-11-11T10:00:00",
"predicted_cost": 148000
},
"custom_params": [
{
"id": 1,
"param_name": "续航时间",
"param_value": "2小时",
"param_unit": "小时",
"description": "最大续航时间"
}
]
}
```
#### 9.3 更新装备数据
- **URL**: `/data/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "Hero-120",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"actual_cost": 150000,
"custom_params": [
{
"id": 1,
"param_value": "2.5小时"
}
]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "装备数据更新成功"
}
```
#### 9.4 删除装备数据
- **URL**: `/data/{id}`
- **方法**: `DELETE`
- **响应示例**:
```json
{
"success": true,
"message": "装备数据删除成功"
}
```
#### 9.5 下载数据模板
- **URL**: `/data/template`
- **方法**: `GET`
- **描述**: 下载Excel格式的数据导入模板
- **响应**: Excel文件下载
#### 9.6 导入数据
- **URL**: `/data/import`
- **方法**: `POST`
- **请求体**:
- Content-Type: multipart/form-data
- 参数名: file
- 文件类型: .xlsx 或 .xls
- **响应示例**:
```json
{
"success": true,
"message": "数据导入成功",
"imported_count": {
"rocket_artillery": 3,
"loitering_munition": 5
}
}
```
注意事项:
1. 导入数据时必须使用系统提供的模板
2. 更新装备数据时会同时更新关联的参数表
3. 删除装备数据会同时删除相关的参数和成本数据
4. 导入的Excel文件大小不应超过10MB
5. 所有数值字段必须符合指定的单位和范围要求
6. 特殊参数的值必须包含单位信息
## 错误响应
所有接口在发生错误时都会返回以下格式的响应:
```json
{
"error": "错误描述信息"
}
```
## 注意事项
1. 所有数值参数必须大于0
2. 所有单位必须按照参数名称中指定的单位提供
3. 预测结果中的成本单位为元
4. 置信区间表示预测结果的95%置信水平范围
5. 所有请求和响应的编码均为 UTF-8

View File

@ -1,120 +0,0 @@
# 装备成本估算系统部署指南
## 一、系统要求
### 1. 基础软件
- Linux 操作系统 (推荐 Ubuntu 20.04+)
- Python 3.8+ 及相关组件
```bash
sudo apt update
sudo apt install python3 python3-pip python3-venv
sudo apt install python3-dev build-essential
```
- Node.js 14+ 及 npm
```bash
# 使用 nvm 安装 Node.js
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
source ~/.bashrc
nvm install 14
nvm use 14
```
### 2. 数据库
- MySQL 8.0+
```bash
sudo apt install mysql-server mysql-client
sudo apt install libmysqlclient-dev
```
### 3. Python包依赖
```bash
# 科学计算相关
sudo apt install libatlas-base-dev # numpy依赖
sudo apt install libopenblas-dev # 线性代数库
sudo apt install liblapack-dev # 线性代数包
sudo apt install gfortran # Fortran编译器(scipy依赖)
# XML处理相关(用于Excel文件处理)
sudo apt install libxml2-dev
sudo apt install libxslt1-dev
```
## 二、部署运行
### 1. 安装服务
```bash
sh scripts/install.sh
```
### 2. 启动服务
```bash
sh scripts/start.sh
```
### 3. 停止服务
```bash
sh scripts/stop.sh
```
## 三、维护说明
### 1. 日志管理
```bash
# 后端日志
tail -f logs/api.log
# 数据库日志
tail -f /var/log/mysql/error.log
```
## 四、安全建议
1. 系统安全
- 使用防火墙限制端口访问
- 定期更新系统和依赖包
2. 数据安全
- 定期备份数据库
- 加密敏感信息
- 限制数据库远程访问
3. 访问控制
- 使用强密码
- 配置适当的文件权限
- 使用非root用户运行服务
## 五、监控方案
### 1. 系统监控
```bash
# 资源使用
top -b -n 1
df -h
free -m
# 服务状态
ps aux | grep gunicorn
ps aux | grep node
```
### 2. 应用监控
```bash
# API 响应时间
curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
# 错误日志
grep "ERROR" logs/api.log
```

View File

@ -1,5 +0,0 @@
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}

View File

@ -1,61 +0,0 @@
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"engines": {
"node": ">=16",
"npm": ">=8"
},
"scripts": {
"serve": "vue-cli-service serve",
"build": "vue-cli-service build",
"lint": "vue-cli-service lint"
},
"dependencies": {
"axios": "^1.6.0",
"core-js": "^3.8.3",
"echarts": "^5.4.3",
"element-plus": "^2.4.2",
"vue": "^3.2.13",
"vue-router": "^4.0.3",
"vuex": "^4.0.0"
},
"devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@element-plus/icons-vue": "^2.3.1",
"@vue/cli-plugin-babel": "~5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-plugin-router": "~5.0.0",
"@vue/cli-plugin-vuex": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"@vue/compiler-sfc": "^3.2.13",
"eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3",
"sass": "^1.32.7",
"sass-loader": "^12.0.0"
},
"eslintConfig": {
"root": true,
"env": {
"node": true
},
"extends": [
"plugin:vue/vue3-essential",
"eslint:recommended"
],
"parserOptions": {
"parser": "@babel/eslint-parser"
},
"rules": {
"vue/multi-word-component-names": "off",
"no-unused-vars": "warn"
}
},
"browserslist": [
"> 1%",
"last 2 versions",
"not dead",
"not ie 11"
]
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 KiB

View File

@ -1,17 +0,0 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title><%= htmlWebpackPlugin.options.title %></title>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>

View File

@ -1,42 +0,0 @@
<template>
<el-container>
<el-header>
<el-menu
:router="true"
mode="horizontal"
:default-active="$route.path"
>
<el-menu-item index="/">首页</el-menu-item>
<el-menu-item index="/predict">成本预测</el-menu-item>
<el-menu-item index="/analysis">特征分析</el-menu-item>
<el-menu-item index="/training">模型训练</el-menu-item>
<el-menu-item index="/models">模型管理</el-menu-item>
<el-menu-item index="/datasets">数据集管理</el-menu-item>
<el-menu-item index="/data">数据管理</el-menu-item>
</el-menu>
</el-header>
<el-main>
<router-view v-slot="{ Component }">
<keep-alive>
<component :is="Component" :key="$route.fullPath" />
</keep-alive>
</router-view>
</el-main>
</el-container>
</template>
<style lang="scss" scoped>
.el-header {
padding: 0;
.el-menu {
border-bottom: none;
}
}
.el-main {
background-color: #f5f7fa;
min-height: calc(100vh - 60px);
}
</style>

View File

@ -1,43 +0,0 @@
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const api = axios.create({
baseURL: API_BASE_URL,
timeout: 10000
})
export const predict = (data) => {
return api.post('/predict', data)
}
export const analyzeFeatures = (data) => {
return api.post('/analyze-features', data)
}
export const trainModel = (data) => {
return api.post('/train', data)
}
export const evaluateModel = (data) => {
return api.post('/evaluate', data)
}
export const importData = (formData) => {
return api.post('/data/import', formData, {
headers: {
'Content-Type': 'multipart/form-data'
}
})
}
export const getEquipmentData = () => {
return api.get('/data')
}
export const updateEquipment = (id, data) => {
return api.put(`/data/${id}`, data)
}
export const deleteEquipment = (id) => {
return api.delete(`/data/${id}`)
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.7 KiB

View File

@ -1,39 +0,0 @@
/* 禁用 ResizeObserver 警告 */
iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] {
display: none !important;
}
.el-overlay {
overflow: hidden !important;
}
/* 添加全局样式 */
body {
margin: 0;
padding: 0;
overflow-x: hidden;
}
/* 修复 Element Plus 的一些已知问题 */
.el-dialog__wrapper {
overflow: hidden !important;
}
.el-select-dropdown {
overflow: hidden !important;
}
/* 禁用 ResizeObserver 相关的警告样式 */
.resize-observer-warning {
display: none !important;
}
/* 优化滚动行为 */
* {
scroll-behavior: smooth;
}
/* 防止页面抖动 */
.el-main {
overflow-x: hidden;
}

View File

@ -1,8 +0,0 @@
export const API_BASE_URL = 'http://localhost:5001/api';
export const DB_CONFIG = {
host: 'localhost',
user: 'root',
password: '123456',
database: 'equipment_cost_db'
};

View File

@ -1,55 +0,0 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import store from './store'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import './assets/styles/global.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
// 创建应用实例
const app = createApp(App)
// 注册插件
app.use(ElementPlus, {
size: 'default'
})
app.use(router)
app.use(store)
// 注册图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
app.component(key, component)
}
// 全局错误处理
app.config.errorHandler = (err) => {
if (err.message && err.message.includes('ResizeObserver')) {
return
}
console.error(err)
}
// 全局警告处理
app.config.warnHandler = (msg, trace) => {
if (msg.includes('ResizeObserver')) {
return
}
console.warn(msg, trace)
}
// 挂载应用
app.mount('#app')
// 处理 ResizeObserver 错误
const _ResizeObserver = window.ResizeObserver
window.ResizeObserver = class ResizeObserver extends _ResizeObserver {
constructor(callback) {
super((entries, observer) => {
requestAnimationFrame(() => {
if (!Array.isArray(entries)) return
callback(entries, observer)
})
})
}
}

View File

@ -1,52 +0,0 @@
import { createRouter, createWebHistory } from 'vue-router'
import HomePage from '@/views/HomePage.vue'
import DataPage from '@/views/DataPage.vue'
import DatasetPage from '@/views/DatasetPage.vue'
import PredictPage from '@/views/PredictPage.vue'
import AnalysisPage from '@/views/AnalysisPage.vue'
import TrainingPage from '@/views/TrainingPage.vue'
const routes = [
{
path: '/',
name: 'Home',
component: HomePage
},
{
path: '/data',
name: 'Data',
component: DataPage
},
{
path: '/datasets',
name: 'Datasets',
component: DatasetPage
},
{
path: '/predict',
name: 'Predict',
component: PredictPage
},
{
path: '/analysis',
name: 'Analysis',
component: AnalysisPage
},
{
path: '/training',
name: 'Training',
component: TrainingPage
},
{
path: '/models',
name: 'Models',
component: () => import('../views/ModelPage.vue')
}
]
const router = createRouter({
history: createWebHistory(),
routes
})
export default router

View File

@ -1,14 +0,0 @@
import { createStore } from 'vuex'
export default createStore({
state: {
},
getters: {
},
mutations: {
},
actions: {
},
modules: {
}
})

View File

@ -1,73 +0,0 @@
// 处理 ResizeObserver 错误
const resizeHandler = () => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
// 添加全局错误处理
const handler = (event) => {
if (event && event.message && resizeObserverErrors.includes(event.message)) {
event.stopPropagation()
event.preventDefault()
event.stopImmediatePropagation()
return false
}
}
// 添加多个事件监听器
window.addEventListener('error', handler, true)
window.addEventListener('unhandledrejection', handler, true)
// 添加 ResizeObserver 错误处理
if (window.ResizeObserver) {
const resizeObserverPrototype = ResizeObserver.prototype
const originalObserve = resizeObserverPrototype.observe
resizeObserverPrototype.observe = function (...args) {
try {
return originalObserve.apply(this, args)
} catch (e) {
if (resizeObserverErrors.includes(e.message)) {
return null
}
throw e
}
}
}
}
export default {
install: (app) => {
resizeHandler()
// 添加全局错误处理器
app.config.errorHandler = (err, vm, info) => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
if (err && err.message && resizeObserverErrors.includes(err.message)) {
return
}
console.error('Vue Error:', err, info)
}
// 添加全局警告处理器
app.config.warnHandler = (msg, vm, trace) => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
if (resizeObserverErrors.includes(msg)) {
return
}
console.warn('Vue Warning:', msg, trace)
}
}
}

View File

@ -1,304 +0,0 @@
<template>
<div class="analysis-page">
<el-card class="analysis-card">
<template #header>
<div class="header-content">
<h2>特征分析</h2>
</div>
</template>
<!-- 数据集选择 -->
<div class="dataset-section">
<el-form :model="analysisForm" label-width="120px">
<el-form-item label="装备类型" required>
<el-select v-model="analysisForm.equipment_type" @change="handleEquipmentTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<el-form-item label="选择数据集" required>
<el-select v-model="analysisForm.dataset_id" @change="handleDatasetChange">
<el-option
v-for="dataset in availableDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
></el-option>
</el-select>
</el-form-item>
</el-form>
<!-- 数据集信息 -->
<el-descriptions v-if="selectedDataset" :column="2" border>
<el-descriptions-item label="数据集名称">{{ selectedDataset.name }}</el-descriptions-item>
<el-descriptions-item label="装备数量">{{ selectedDataset.equipment_count }}</el-descriptions-item>
<el-descriptions-item label="描述" :span="2">{{ selectedDataset.description }}</el-descriptions-item>
</el-descriptions>
</div>
<!-- 分析按钮 -->
<div class="action-section">
<el-button type="primary" @click="startAnalysis" :loading="analyzing" :disabled="!analysisForm.dataset_id">
{{ analyzing ? '分析中...' : '开始分析' }}
</el-button>
</div>
<!-- 分析结果 -->
<div v-if="analysisResult" class="result-section">
<el-divider content-position="left">分析结果</el-divider>
<!-- 特征重要性 -->
<h3>特征重要性</h3>
<div class="chart-container">
<div ref="importanceChartRef" style="width: 100%; height: 400px"></div>
</div>
<!-- 相关性分析 -->
<h3>相关性分析</h3>
<div class="chart-container">
<div ref="correlationChartRef" style="width: 100%; height: 500px"></div>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, onMounted, watch, nextTick, onUnmounted } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
import * as echarts from 'echarts'
//
const analysisForm = ref({
equipment_type: '',
dataset_id: null
})
const availableDatasets = ref([])
const selectedDataset = ref(null)
const analyzing = ref(false)
const analysisResult = ref(null)
const importanceChartRef = ref(null)
const correlationChartRef = ref(null)
//
const importanceChart = ref(null)
const correlationChart = ref(null)
//
watch(() => analysisResult.value, async (newResult) => {
if (newResult) {
console.log('Analysis result updated:', newResult)
// tick DOM
await nextTick()
//
setTimeout(() => {
renderCharts()
}, 100)
}
}, { deep: true })
//
const loadDatasets = async (type) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets`, {
params: { equipment_type: type, purpose: '训练' }
})
availableDatasets.value = response.data
} catch (error) {
ElMessage.error('获取数据集列表失败')
}
}
//
const handleEquipmentTypeChange = () => {
analysisForm.value.dataset_id = null
selectedDataset.value = null
analysisResult.value = null
loadDatasets(analysisForm.value.equipment_type)
}
//
const handleDatasetChange = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${analysisForm.value.dataset_id}`)
selectedDataset.value = response.data
analysisResult.value = null
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const startAnalysis = async () => {
if (!analysisForm.value.dataset_id) {
ElMessage.warning('请先选择数据集')
return
}
analyzing.value = true
try {
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
analysisResult.value = response.data
console.log('Analysis completed, result:', analysisResult.value)
} catch (error) {
ElMessage.error('特征分析失败')
console.error('Analysis error:', error)
} finally {
analyzing.value = false
}
}
//
const renderCharts = () => {
console.log('Starting to render charts')
//
if (importanceChart.value) {
importanceChart.value.dispose()
}
if (correlationChart.value) {
correlationChart.value.dispose()
}
// DOM
if (!importanceChartRef.value || !correlationChartRef.value) {
console.log('Chart DOM elements not ready')
return
}
try {
//
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
//
importanceChart.value.setOption({
title: { text: '特征重要性排序' },
tooltip: {},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance)
}]
})
correlationChart.value.setOption({
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
formatter: function (params) {
const value = params.data[2].toFixed(2)
const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
return `${feature1}${feature2} 的相关性: ${value}`
}
},
grid: {
height: '50%',
top: '10%'
},
xAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true },
axisLabel: {
interval: 0,
rotate: 45
}
},
yAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true }
},
visualMap: {
min: -1,
max: 1,
calculable: true,
orient: 'horizontal',
left: 'center',
bottom: '15%',
color: ['#cc3333', '#eeeeee', '#00007f']
},
series: [{
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
label: {
show: true,
formatter: function(params) {
return params.data[2].toFixed(2)
}
}
}]
})
//
window.addEventListener('resize', () => {
importanceChart.value?.resize()
correlationChart.value?.resize()
})
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error rendering charts:', error)
}
}
//
onMounted(() => {
//
})
//
onUnmounted(() => {
importanceChart.value?.dispose()
correlationChart.value?.dispose()
})
</script>
<style lang="scss" scoped>
.analysis-page {
padding: 20px;
.analysis-card {
.header-content {
h2 {
margin: 0;
}
}
.dataset-section {
margin-bottom: 20px;
}
.action-section {
margin: 20px 0;
text-align: center;
}
.result-section {
h3 {
margin: 20px 0;
}
.chart-container {
margin: 20px 0;
border: 1px solid #ebeef5;
border-radius: 4px;
}
}
}
}
</style>

View File

@ -1,725 +0,0 @@
<template>
<div class="data-page">
<el-card class="data-card">
<template #header>
<div class="header-content">
<h2>数据管理</h2>
<div class="header-buttons">
<el-upload
:action="null"
:auto-upload="false"
:show-file-list="false"
accept=".xls,.xlsx"
@change="handleFileChange"
>
<el-button type="primary">导入数据</el-button>
</el-upload>
<el-button type="primary" @click="downloadTemplate">下载模板</el-button>
</div>
</div>
</template>
<!-- 数据列表 -->
<el-tabs v-model="activeTab" @tab-click="handleTabClick">
<el-tab-pane label="火箭炮数据" name="rocket">
<!-- 搜索和过滤 -->
<div class="filter-section">
<el-input
v-model="searchQuery"
placeholder="搜索装备名称或制造商"
style="width: 200px; margin-right: 10px;"
/>
<el-select v-model="filterManufacturer" placeholder="制造商" clearable>
<el-option
v-for="item in manufacturers"
:key="item"
:label="item"
:value="item"
/>
</el-select>
</div>
<!-- 数据表格 -->
<el-table :data="filteredRocketData" border style="width: 100%">
<el-table-column prop="name" label="名称" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造商" sortable></el-table-column>
<el-table-column prop="length_m" label="总长(m)" sortable></el-table-column>
<el-table-column prop="weight_kg" label="重量(kg)" sortable></el-table-column>
<el-table-column prop="max_range_km" label="最大射程(km)" sortable></el-table-column>
<el-table-column prop="rocket_diameter_mm" label="口径(mm)" sortable></el-table-column>
<el-table-column prop="rate_of_fire" label="射速(发/分)" sortable></el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDetails(scope.row)">
详情
</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">
编辑
</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="巡飞弹数据" name="missile">
<!-- 搜索和过滤 -->
<div class="filter-section">
<el-input
v-model="searchQuery"
placeholder="搜索装备名称或制造商"
style="width: 200px; margin-right: 10px;"
/>
<el-select v-model="filterManufacturer" placeholder="制造商" clearable>
<el-option
v-for="item in manufacturers"
:key="item"
:label="item"
:value="item"
/>
</el-select>
</div>
<!-- 数据表格 -->
<el-table :data="filteredMissileData" border style="width: 100%">
<el-table-column prop="name" label="名称" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造" sortable></el-table-column>
<el-table-column prop="length_m" label="弹长(m)" sortable></el-table-column>
<el-table-column prop="weight_kg" label="重量(kg)" sortable></el-table-column>
<el-table-column prop="max_range_km" label="最大射程(km)" sortable></el-table-column>
<el-table-column prop="max_speed_ms" label="最大速度(m/s)" sortable></el-table-column>
<el-table-column prop="flight_time_min" label="巡飞时间(min)" sortable></el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDetails(scope.row)">
详情
</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">
编辑
</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
</el-tabs>
<!-- 详情对话框 -->
<el-dialog v-model="detailsVisible" title="装备详情" width="70%">
<el-descriptions :column="2" border>
<!-- 基本信息 -->
<template>
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">基本信息</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="名称">{{ selectedData?.name }}</el-descriptions-item>
<el-descriptions-item label="制造商">{{ selectedData?.manufacturer }}</el-descriptions-item>
<!-- 通用参数 -->
<el-descriptions-item label="总长(m)">{{ formatNumber(selectedData?.length_m) }}</el-descriptions-item>
<el-descriptions-item label="宽度(m)">{{ formatNumber(selectedData?.width_m) }}</el-descriptions-item>
<el-descriptions-item label="高度(m)">{{ formatNumber(selectedData?.height_m) }}</el-descriptions-item>
<el-descriptions-item label="重量(kg)">{{ formatNumber(selectedData?.weight_kg) }}</el-descriptions-item>
<el-descriptions-item label="最大射程(km)">{{ formatNumber(selectedData?.max_range_km) }}</el-descriptions-item>
</template>
<!-- 火箭炮特有参数 -->
<template v-if="selectedData?.type === '火箭炮'">
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">火箭炮参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="口径(mm)">{{ formatNumber(selectedData?.rocket_diameter_mm) }}</el-descriptions-item>
<el-descriptions-item label="射速(发/分)">{{ formatNumber(selectedData?.rate_of_fire) }}</el-descriptions-item>
<el-descriptions-item label="火箭弹长度(m)">{{ formatNumber(selectedData?.rocket_length_m) }}</el-descriptions-item>
<el-descriptions-item label="火箭弹重量(kg)">{{ formatNumber(selectedData?.rocket_weight_kg) }}</el-descriptions-item>
<el-descriptions-item label="方向射界(度)">{{ formatNumber(selectedData?.firing_angle_horizontal) }}</el-descriptions-item>
<el-descriptions-item label="高低射界(度)">{{ formatNumber(selectedData?.firing_angle_vertical) }}</el-descriptions-item>
<el-descriptions-item label="机动方式">{{ selectedData?.mobility_type }}</el-descriptions-item>
<el-descriptions-item label="结构布局">{{ selectedData?.structure_layout }}</el-descriptions-item>
<el-descriptions-item label="发动机型号">{{ selectedData?.engine_model }}</el-descriptions-item>
<el-descriptions-item label="发动机参数">{{ selectedData?.engine_params }}</el-descriptions-item>
<el-descriptions-item label="功率(hp)">{{ formatNumber(selectedData?.power_hp) }}</el-descriptions-item>
<el-descriptions-item label="行程(km)">{{ formatNumber(selectedData?.travel_range_km) }}</el-descriptions-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="selectedData?.type === '巡飞弹'">
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">巡飞弹参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="最大速度(m/s)">{{ formatNumber(selectedData?.max_speed_ms) }}</el-descriptions-item>
<el-descriptions-item label="巡航速度(km/h)">{{ formatNumber(selectedData?.cruise_speed_kmh) }}</el-descriptions-item>
<el-descriptions-item label="巡飞时间(min)">{{ formatNumber(selectedData?.flight_time_min) }}</el-descriptions-item>
<el-descriptions-item label="战斗部类型">{{ selectedData?.warhead_type }}</el-descriptions-item>
<el-descriptions-item label="发射方式">{{ selectedData?.launch_mode }}</el-descriptions-item>
<el-descriptions-item label="动力装置">{{ selectedData?.power_system }}</el-descriptions-item>
<el-descriptions-item label="制导体制">{{ selectedData?.guidance_system }}</el-descriptions-item>
</template>
<!-- 特殊参数 -->
<template v-if="selectedData?.custom_params?.length > 0">
{{ console.log('Rendering custom params:', selectedData.custom_params) }}
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">特殊参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item
v-for="param in selectedData.custom_params"
:key="param.id"
:label="param.param_name"
>
{{ formatCustomParamValue(param) }}
</el-descriptions-item>
</template>
<!-- 成本信息 -->
<template>
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">成本信息</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="实际成本(元)">
{{ formatMoney(selectedData?.actual_cost) }}
</el-descriptions-item>
<el-descriptions-item label="预测成本(元)">
{{ formatMoney(selectedData?.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="成本估算时间">
{{ selectedData?.cost_estimate_date || '-' }}
</el-descriptions-item>
</template>
</el-descriptions>
</el-dialog>
<!-- 编辑对话框 -->
<el-dialog v-model="editVisible" title="编辑装备数据" width="70%">
<el-form :model="editForm" label-width="120px">
<!-- 基本信息 -->
<template>
<el-divider content-position="left">基本信息</el-divider>
<el-form-item label="名称">
<el-input v-model="editForm.name"></el-input>
</el-form-item>
<el-form-item label="制造商">
<el-input v-model="editForm.manufacturer"></el-input>
</el-form-item>
<!-- 通用参数 -->
<el-form-item label="总长(m)">
<el-input-number v-model="editForm.length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="宽度(m)">
<el-input-number v-model="editForm.width_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="高度(m)">
<el-input-number v-model="editForm.height_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="重量(kg)">
<el-input-number v-model="editForm.weight_kg" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="最大射程(km)">
<el-input-number v-model="editForm.max_range_km" :precision="2"></el-input-number>
</el-form-item>
</template>
<!-- 火箭炮特有参数 -->
<template v-if="editForm.type === '火箭炮'">
<el-divider content-position="left">火箭炮参数</el-divider>
<el-form-item label="口径(mm)">
<el-input-number v-model="editForm.rocket_diameter_mm" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="射速(发/分)">
<el-input-number v-model="editForm.rate_of_fire" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹长度(m)">
<el-input-number v-model="editForm.rocket_length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹重量(kg)">
<el-input-number v-model="editForm.rocket_weight_kg" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="方向射界(度)">
<el-input-number v-model="editForm.firing_angle_horizontal" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="低射界(度)">
<el-input-number v-model="editForm.firing_angle_vertical" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="机动方式">
<el-select v-model="editForm.mobility_type">
<el-option v-for="option in getSelectOptions('mobility_type')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="结构布局">
<el-input v-model="editForm.structure_layout"></el-input>
</el-form-item>
<el-form-item label="发动型号">
<el-input v-model="editForm.engine_model"></el-input>
</el-form-item>
<el-form-item label="发动机参数">
<el-input v-model="editForm.engine_params"></el-input>
</el-form-item>
<el-form-item label="功率(hp)">
<el-input-number v-model="editForm.power_hp" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="行程(km)">
<el-input-number v-model="editForm.travel_range_km" :precision="2"></el-input-number>
</el-form-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="editForm.type === '巡飞弹'">
<el-divider content-position="left">巡飞弹参数</el-divider>
<el-form-item label="最大速度(m/s)">
<el-input-number v-model="editForm.max_speed_ms" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="editForm.cruise_speed_kmh" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="editForm.flight_time_min" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="editForm.warhead_type">
<el-option v-for="option in getSelectOptions('warhead_type')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="发射方式">
<el-select v-model="editForm.launch_mode">
<el-option v-for="option in getSelectOptions('launch_mode')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="动力装置">
<el-select v-model="editForm.power_system">
<el-option v-for="option in getSelectOptions('power_system')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="制导体制">
<el-select v-model="editForm.guidance_system">
<el-option v-for="option in getSelectOptions('guidance_system')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
</template>
<!-- 特殊参数 -->
<template v-if="editForm.custom_params?.length">
<el-divider content-position="left">特殊参数</el-divider>
<el-form-item
v-for="(param, index) in editForm.custom_params"
:key="param.id"
:label="param.param_name"
>
<el-input-number
v-if="isNumericParam(param)"
v-model="editForm.custom_params[index].param_value"
:precision="2"
:step="0.1"
>
<template #append v-if="param.param_unit">
{{ param.param_unit }}
</template>
</el-input-number>
<el-input
v-else
v-model="editForm.custom_params[index].param_value"
>
<template #append v-if="param.param_unit">
{{ param.param_unit }}
</template>
</el-input>
</el-form-item>
</template>
<!-- 成本信息 -->
<el-divider content-position="left">成本信息</el-divider>
<el-form-item label="实际成本(元)">
<el-input-number
v-model="editForm.actual_cost"
:precision="0"
:min="0"
:step="1000"
:controls="true"
style="width: 200px"
></el-input-number>
</el-form-item>
<el-form-item label="预测成本(元)">
<el-input-number
v-model="editForm.predicted_cost"
:precision="0"
disabled
style="width: 200px"
></el-input-number>
</el-form-item>
<el-form-item label="成本估算时间">
<el-date-picker
v-model="editForm.cost_estimate_date"
type="datetime"
placeholder="选择日期时间"
format="YYYY-MM-DD HH:mm:ss"
value-format="YYYY-MM-DD HH:mm:ss"
disabled
></el-date-picker>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" @click="saveEdit">保存</el-button>
</span>
</template>
</el-dialog>
</el-card>
</div>
</template>
<script setup>
import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const activeTab = ref('rocket')
const rocketData = ref([])
const missileData = ref([])
const detailsVisible = ref(false)
const editVisible = ref(false)
const selectedData = ref(null)
const editForm = ref(null)
const searchQuery = ref('')
const filterManufacturer = ref('')
//
const manufacturers = computed(() => {
const data = activeTab.value === 'rocket' ? rocketData.value : missileData.value
return [...new Set(data.map(item => item.manufacturer))]
})
//
const filteredRocketData = computed(() => {
return filterData(rocketData.value)
})
const filteredMissileData = computed(() => {
return filterData(missileData.value)
})
const filterData = (data) => {
return data.filter(item => {
const matchesSearch = searchQuery.value ?
(item.name.toLowerCase().includes(searchQuery.value.toLowerCase()) ||
item.manufacturer.toLowerCase().includes(searchQuery.value.toLowerCase())) : true
const matchesManufacturer = filterManufacturer.value ?
item.manufacturer === filterManufacturer.value : true
return matchesSearch && matchesManufacturer
})
}
//
const loadData = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/data`)
if (response.data.error) {
throw new Error(response.data.error)
}
//
rocketData.value = response.data.rocket_artillery.map(item => ({
...item,
custom_params: item.custom_params || []
}))
//
missileData.value = response.data.loitering_munition.map(item => ({
...item,
custom_params: item.custom_params || []
}))
} catch (error) {
ElMessage.error('加载数据失败')
console.error('Error loading data:', error)
}
}
//
const handleFileChange = async (file) => {
try {
const formData = new FormData()
formData.append('file', file.raw)
const response = await axios.post(
`${API_BASE_URL}/data/import`,
formData,
{
headers: {
'Content-Type': 'multipart/form-data'
}
}
)
if (response.data.success) {
ElMessage.success('数据导入成功')
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
ElMessage.error(error.message || '数据导入失败')
}
}
//
const downloadTemplate = () => {
window.open(`${API_BASE_URL}/data/template`, '_blank')
}
//
const viewDetails = async (row) => {
try {
console.log('Requesting details for row:', row)
const response = await axios.get(`${API_BASE_URL}/data/details/${row.id}`)
if (response.data.error) {
throw new Error(response.data.error)
}
console.log('Details response:', response.data)
// custom_params
if (typeof response.data.custom_params === 'string') {
response.data.custom_params = JSON.parse(response.data.custom_params)
}
console.log('Parsed custom params:', response.data.custom_params)
selectedData.value = response.data
console.log('Selected data:', selectedData.value)
detailsVisible.value = true
} catch (error) {
ElMessage.error('获取详情失败')
console.error('Error getting details:', error)
}
}
//
const editData = async (row) => {
try {
console.log('Editing row:', row)
const response = await axios.get(`${API_BASE_URL}/data/details/${row.id}`)
if (response.data.error) {
throw new Error(response.data.error)
}
console.log('Edit data response:', response.data)
// custom_params
if (typeof response.data.custom_params === 'string') {
response.data.custom_params = JSON.parse(response.data.custom_params)
}
//
const data = { ...response.data }
Object.keys(data).forEach(key => {
if (isNumberInput(key) && data[key] !== null && data[key] !== undefined) {
data[key] = Number(data[key])
}
})
//
if (data.custom_params) {
data.custom_params = data.custom_params.map(param => ({
...param,
param_value: !isNaN(param.param_value) ? Number(param.param_value) : param.param_value
}))
}
console.log('Parsed custom params:', data.custom_params)
editForm.value = data
console.log('Edit form data:', editForm.value)
editVisible.value = true
} catch (error) {
ElMessage.error('获取编辑数据失败')
console.error('Error getting edit data:', error)
}
}
//
const saveEdit = async () => {
try {
//
const saveData = {
...editForm.value,
custom_params: editForm.value.custom_params.map(param => ({
id: param.id,
param_name: param.param_name,
param_value: param.param_value,
param_unit: param.param_unit
}))
}
const response = await axios.put(
`${API_BASE_URL}/data/${editForm.value.id}`,
saveData
)
if (response.data.success) {
ElMessage.success('保存成功')
editVisible.value = false
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
ElMessage.error(error.message || '保存失败')
}
}
//
const deleteData = async (row) => {
try {
await ElMessageBox.confirm('确定要删除这条数据吗?', '警告', {
type: 'warning'
})
const response = await axios.delete(`${API_BASE_URL}/data/${row.id}`)
if (response.data.success) {
ElMessage.success('删除成功')
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
if (error !== 'cancel') {
ElMessage.error(error.message || '删除失败')
}
}
}
//
const formatNumber = (value) => {
if (value === null || value === undefined) return '-'
return Number(value).toFixed(2)
}
//
const formatMoney = (value) => {
if (value === null || value === undefined) return '-'
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
//
const formatCustomParamValue = (param) => {
if (!param || !param.param_value) return '-'
let value = param.param_value
//
if (!isNaN(value)) {
value = Number(value).toFixed(2)
}
//
if (param.param_unit) {
value = `${value} ${param.param_unit}`
}
return value
}
//
const isNumericParam = (param) => {
return !isNaN(param.param_value) && param.param_unit !== undefined
}
// handleTabClick
const handleTabClick = () => {
//
searchQuery.value = ''
filterManufacturer.value = ''
}
//
onMounted(() => {
loadData()
})
//
const isNumberInput = (key) => {
const numberFields = [
'length_m',
'width_m',
'height_m',
'weight_kg',
'max_range_km',
'firing_angle_horizontal',
'firing_angle_vertical',
'rocket_length_m',
'rocket_diameter_mm',
'rocket_weight_kg',
'rate_of_fire',
'combat_weight_kg',
'speed_kmh',
'min_range_km',
'power_hp',
'travel_range_km',
'max_speed_ms',
'cruise_speed_kmh',
'flight_time_min',
'folded_length_mm',
'folded_width_mm',
'folded_height_mm',
'actual_cost',
'predicted_cost'
]
return numberFields.includes(key)
}
</script>
<style lang="scss" scoped>
.data-page {
padding: 20px;
.data-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
.header-buttons {
display: flex;
gap: 10px;
}
}
}
.filter-section {
margin-bottom: 20px;
display: flex;
gap: 10px;
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -1,322 +0,0 @@
<template>
<div class="dataset-page">
<el-card class="dataset-card">
<template #header>
<div class="header-content">
<h2>数据集管理</h2>
<div class="header-buttons">
<el-button type="primary" @click="createDataset">创建数据集</el-button>
</div>
</div>
</template>
<!-- 数据集列表 -->
<el-table :data="datasets" border style="width: 100%">
<el-table-column prop="name" label="数据集名称"></el-table-column>
<el-table-column prop="equipment_type" label="装备类型"></el-table-column>
<el-table-column prop="purpose" label="用途"></el-table-column>
<el-table-column prop="description" label="描述"></el-table-column>
<el-table-column prop="equipment_count" label="装备数量"></el-table-column>
<el-table-column prop="created_at" label="创建时间">
<template #default="scope">
{{ formatDateTime(scope.row.created_at) }}
</template>
</el-table-column>
<el-table-column prop="updated_at" label="修改时间">
<template #default="scope">
{{ formatDateTime(scope.row.updated_at) }}
</template>
</el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDataset(scope.row)">查看</el-button>
<el-button size="small" type="primary" @click="editDataset(scope.row)">编辑</el-button>
<el-button size="small" type="danger" @click="deleteDataset(scope.row)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<!-- 数据集详情对话框 -->
<el-dialog v-model="detailsVisible" :title="selectedDataset?.name" width="70%">
<el-descriptions :column="2" border>
<el-descriptions-item label="装备类型">{{ selectedDataset?.equipment_type }}</el-descriptions-item>
<el-descriptions-item label="用途">{{ selectedDataset?.purpose }}</el-descriptions-item>
<el-descriptions-item label="描述" :span="2">{{ selectedDataset?.description }}</el-descriptions-item>
</el-descriptions>
<!-- 数据集统计信息 -->
<div style="margin-top: 20px">
<el-divider content-position="left">统计信息</el-divider>
<el-descriptions :column="2" border>
<el-descriptions-item label="装备数量">{{ selectedDataset?.statistics?.equipment_count || 0 }}</el-descriptions-item>
<el-descriptions-item label="平均成本">{{ formatMoney(selectedDataset?.statistics?.average_cost) }}</el-descriptions-item>
<el-descriptions-item label="总成本">{{ formatMoney(selectedDataset?.statistics?.total_cost) }}</el-descriptions-item>
</el-descriptions>
</div>
<!-- 包含的装备列表 -->
<div style="margin-top: 20px">
<el-divider content-position="left">包含装备</el-divider>
<el-table :data="selectedDataset?.equipment" border>
<el-table-column prop="name" label="名称"></el-table-column>
<el-table-column prop="manufacturer" label="制造商"></el-table-column>
<el-table-column prop="actual_cost" label="成本(元)">
<template #default="scope">
{{ formatMoney(scope.row.actual_cost) }}
</template>
</el-table-column>
<el-table-column label="操作" width="100">
<template #default="scope">
<el-button size="small" @click="viewEquipment(scope.row)">详情</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-dialog>
<!-- 创建/编辑数据集对话框 -->
<el-dialog v-model="editVisible" :title="datasetForm.id ? '编辑数据集' : '创建数据集'" width="70%">
<el-form :model="datasetForm" label-width="120px">
<el-form-item label="数据集名称" required>
<el-input v-model="datasetForm.name"></el-input>
</el-form-item>
<el-form-item label="装备类型" required>
<el-select v-model="datasetForm.equipment_type" @change="handleEquipmentTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<el-form-item label="用途" required>
<el-select v-model="datasetForm.purpose">
<el-option label="训练" value="训练"></el-option>
<el-option label="验证" value="验证"></el-option>
</el-select>
</el-form-item>
<el-form-item label="描述">
<el-input type="textarea" v-model="datasetForm.description"></el-input>
</el-form-item>
<!-- 选择装备数据 -->
<el-form-item label="选择装备" required>
<el-table
:data="availableEquipment"
border
@selection-change="handleSelectionChange"
:max-height="400"
>
<el-table-column type="selection" width="55"></el-table-column>
<el-table-column prop="name" label="名称"></el-table-column>
<el-table-column prop="manufacturer" label="制造商"></el-table-column>
<el-table-column prop="actual_cost" label="成本(元)">
<template #default="scope">
{{ formatMoney(scope.row.actual_cost) }}
</template>
</el-table-column>
</el-table>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" @click="saveDataset">保存</el-button>
</span>
</template>
</el-dialog>
</div>
</template>
<script setup>
import { ref, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
//
const datasets = ref([]) //
const selectedDataset = ref(null) //
const detailsVisible = ref(false) //
const editVisible = ref(false) //
const availableEquipment = ref([]) //
const selectedEquipment = ref([]) //
//
const datasetForm = ref({
id: null,
name: '',
equipment_type: '',
purpose: '',
description: '',
selected_equipment: []
})
//
const loadDatasets = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets`)
datasets.value = response.data
} catch (error) {
ElMessage.error('获取数据集列表失败')
}
}
//
const createDataset = () => {
datasetForm.value = {
id: null,
name: '',
equipment_type: '',
purpose: '',
description: '',
selected_equipment: []
}
editVisible.value = true
}
//
const editDataset = async (dataset) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
datasetForm.value = response.data
await loadAvailableEquipment()
editVisible.value = true
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const viewDataset = async (dataset) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
selectedDataset.value = response.data
detailsVisible.value = true
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const deleteDataset = async (dataset) => {
try {
await ElMessageBox.confirm('确定要删除这个数据集吗?', '警告', {
type: 'warning'
})
await axios.delete(`${API_BASE_URL}/datasets/${dataset.id}`)
ElMessage.success('删除成功')
loadDatasets()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('删除失败')
}
}
}
//
const loadAvailableEquipment = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/data`)
availableEquipment.value = datasetForm.value.equipment_type === '火箭炮'
? response.data.rocket_artillery
: response.data.loitering_munition
} catch (error) {
ElMessage.error('获取装备列表失败')
}
}
//
const handleEquipmentTypeChange = () => {
selectedEquipment.value = [] //
loadAvailableEquipment() //
}
//
const handleSelectionChange = (selection) => {
selectedEquipment.value = selection
}
//
const saveDataset = async () => {
try {
const data = {
...datasetForm.value,
equipment_ids: selectedEquipment.value.map(item => item.id)
}
if (data.id) {
await axios.put(`${API_BASE_URL}/datasets/${data.id}`, data)
} else {
await axios.post(`${API_BASE_URL}/datasets`, data)
}
ElMessage.success('保存成功')
editVisible.value = false
loadDatasets()
} catch (error) {
ElMessage.error('保存失败')
}
}
//
const viewEquipment = (equipment) => {
//
console.log('View equipment:', equipment)
}
//
const formatMoney = (value) => {
if (value === null || value === undefined) return '-'
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
//
const formatDateTime = (value) => {
if (!value) return '-'
const date = new Date(value)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
})
}
//
onMounted(() => {
loadDatasets()
})
</script>
<style lang="scss" scoped>
.dataset-page {
padding: 20px;
.dataset-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
.header-buttons {
display: flex;
gap: 10px;
}
}
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -1,101 +0,0 @@
<template>
<div class="home-page">
<el-card class="welcome-card">
<template #header>
<h2>装备成本估算系统</h2>
</template>
<el-row :gutter="20">
<el-col :span="8">
<el-card @click="$router.push('/predict')">
<el-icon><Money /></el-icon>
<h3>成本预测</h3>
<p>基于机器学习和 PLS 回归模型的成本预测</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/analysis')">
<el-icon><DataAnalysis /></el-icon>
<h3>特征分析</h3>
<p>分析参数对成本的影响</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/training')">
<el-icon><Monitor /></el-icon>
<h3>模型训练</h3>
<p>训练和优化预测模型</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/models')">
<el-icon><Management /></el-icon>
<h3>模型管理</h3>
<p>管理训练好的模型</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/datasets')">
<el-icon><Collection /></el-icon>
<h3>数据集管理</h3>
<p>管理训练和验证数据集</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/data')">
<el-icon><Management /></el-icon>
<h3>数据管理</h3>
<p>管理装备数据和成本数据</p>
</el-card>
</el-col>
</el-row>
</el-card>
</div>
</template>
<script setup>
import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-plus/icons-vue'
</script>
<style lang="scss" scoped>
.home-page {
padding: 20px;
.welcome-card {
max-width: 1200px;
margin: 0 auto;
h2 {
text-align: center;
margin: 0;
}
}
.el-card {
text-align: center;
cursor: pointer;
transition: all 0.3s;
margin-bottom: 20px;
&:hover {
transform: translateY(-5px);
box-shadow: 0 2px 12px 0 rgba(0,0,0,.1);
}
.el-icon {
font-size: 48px;
color: #409EFF;
margin: 20px 0;
}
h3 {
margin: 10px 0;
font-size: 18px;
}
p {
color: #909399;
font-size: 14px;
}
}
}
</style>

View File

@ -1,279 +0,0 @@
<template>
<div class="model-page">
<el-card class="model-card">
<template #header>
<div class="header-content">
<h2>模型管理</h2>
</div>
</template>
<!-- 模型列表 -->
<el-table :data="modelList" border style="width: 100%">
<el-table-column prop="model_type" label="模型类型">
<template #default="scope">
{{ getModelName(scope.row.model_type) }}
</template>
</el-table-column>
<el-table-column prop="model_name" label="模型名称"></el-table-column>
<el-table-column prop="equipment_type" label="装备类型"></el-table-column>
<el-table-column prop="r2_score" label="R²分数">
<template #default="scope">
{{ scope.row.r2_score.toFixed(4) }}
</template>
</el-table-column>
<el-table-column prop="mae" label="MAE (元)">
<template #default="scope">
{{ scope.row.mae.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="rmse" label="RMSE (元)">
<template #default="scope">
{{ scope.row.rmse.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="training_date" label="训练时间">
<template #default="scope">
{{ formatDateTime(scope.row.training_date) }}
</template>
</el-table-column>
<el-table-column prop="is_active" label="状态">
<template #default="scope">
<el-tag :type="scope.row.is_active ? 'success' : 'info'">
{{ scope.row.is_active ? '使用中' : '未使用' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button
size="small"
type="primary"
:disabled="scope.row.is_active"
@click="activateModel(scope.row)"
>
激活
</el-button>
<el-button
size="small"
@click="viewDetails(scope.row)"
>
详情
</el-button>
<el-button
size="small"
type="danger"
:disabled="scope.row.is_active"
@click="deleteModel(scope.row)"
>
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<!-- 模型详情对话框 -->
<el-dialog v-model="detailsVisible" title="模型详情" width="70%">
<el-descriptions :column="2" border>
<el-descriptions-item label="模型名称">{{ selectedModel?.model_name }}</el-descriptions-item>
<el-descriptions-item label="模型类型">{{ formatModelType(selectedModel?.model_type) }}</el-descriptions-item>
<el-descriptions-item label="装备类型">{{ selectedModel?.equipment_type }}</el-descriptions-item>
<el-descriptions-item label="训练数据量">{{ selectedModel?.training_data_size }}</el-descriptions-item>
<el-descriptions-item label="训练时间">{{ formatDateTime(selectedModel?.training_date) }}</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag :type="selectedModel?.is_active ? 'success' : 'info'">
{{ selectedModel?.is_active ? '使用中' : '未使用' }}
</el-tag>
</el-descriptions-item>
</el-descriptions>
<!-- 评估指标 -->
<div style="margin-top: 20px">
<el-divider content-position="left">评估指标</el-divider>
<el-descriptions :column="3" border>
<el-descriptions-item label="R²分数">{{ selectedModel?.r2_score.toFixed(4) }}</el-descriptions-item>
<el-descriptions-item label="MAE">{{ selectedModel?.mae.toFixed(2) }} </el-descriptions-item>
<el-descriptions-item label="RMSE">{{ selectedModel?.rmse.toFixed(2) }} </el-descriptions-item>
</el-descriptions>
</div>
<!-- 特征重要性 -->
<div v-if="selectedModel?.feature_importance" style="margin-top: 20px">
<el-divider content-position="left">特征重要性</el-divider>
<div ref="importanceChartRef" style="width: 100%; height: 400px"></div>
</div>
</el-dialog>
</div>
</template>
<script setup>
import { ref, onMounted, watch, nextTick, onUnmounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
import * as echarts from 'echarts'
//
const modelList = ref([])
const selectedModel = ref(null)
const detailsVisible = ref(false)
const importanceChartRef = ref(null)
const importanceChart = ref(null)
//
const loadModels = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/models`)
modelList.value = response.data
} catch (error) {
ElMessage.error('获取模型列表失败')
}
}
//
const activateModel = async (model) => {
try {
await ElMessageBox.confirm('确定要激活这个模型吗?', '提示', {
type: 'warning'
})
await axios.post(`${API_BASE_URL}/models/${model.id}/activate`)
ElMessage.success('模型激活成功')
loadModels()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('模型激活失败')
}
}
}
//
const deleteModel = async (model) => {
try {
await ElMessageBox.confirm('确定要删除这个模型吗?', '警告', {
type: 'warning'
})
await axios.delete(`${API_BASE_URL}/models/${model.id}`)
ElMessage.success('删除成功')
loadModels()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('删除失败')
}
}
}
//
const viewDetails = async (model) => {
selectedModel.value = model
detailsVisible.value = true
}
//
watch(() => detailsVisible.value, async (visible) => {
if (visible && selectedModel.value?.feature_importance) {
await nextTick()
renderImportanceChart()
}
})
//
const renderImportanceChart = () => {
if (importanceChart.value) {
importanceChart.value.dispose()
}
importanceChart.value = echarts.init(importanceChartRef.value)
const featureImportance = JSON.parse(selectedModel.value.feature_importance)
const features = Object.keys(featureImportance)
const values = Object.values(featureImportance)
importanceChart.value.setOption({
title: { text: '特征重要性' },
tooltip: {},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: features
},
series: [{
type: 'bar',
data: values
}]
})
}
//
const formatModelType = (type) => {
const typeMap = {
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbdt': 'GBDT',
'rf': 'Random Forest'
}
return typeMap[type] || type
}
//
const formatDateTime = (value) => {
if (!value) return '-'
const date = new Date(value)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
})
}
//
onUnmounted(() => {
importanceChart.value?.dispose()
})
//
onMounted(() => {
loadModels()
})
const getModelName = (modelType) => {
const modelNames = {
'pls': 'PLS回归',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
</script>
<style lang="scss" scoped>
.model-page {
padding: 20px;
.model-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
}
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -1,321 +0,0 @@
<template>
<div class="predict-page">
<el-card class="predict-card">
<template #header>
<h2>成本预测</h2>
</template>
<el-form :model="formData" label-width="120px">
<!-- 装备类型选择 -->
<el-form-item label="装备类型">
<el-select v-model="formData.type" @change="handleTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<!-- 通用参数 -->
<el-form-item label="总长(m)">
<el-input-number v-model="formData.length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="宽度(m)">
<el-input-number v-model="formData.width_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="高度(m)">
<el-input-number v-model="formData.height_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="重量(kg)">
<el-input-number v-model="formData.weight_kg"></el-input-number>
</el-form-item>
<el-form-item label="最大射程(km)">
<el-input-number v-model="formData.max_range_km"></el-input-number>
</el-form-item>
<!-- 火箭炮特有参数 -->
<template v-if="formData.type === '火箭炮'">
<el-form-item label="方向射界(度)">
<el-input-number v-model="formData.firing_angle_horizontal"></el-input-number>
</el-form-item>
<el-form-item label="高低射界(度)">
<el-input-number v-model="formData.firing_angle_vertical"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹长度(m)">
<el-input-number v-model="formData.rocket_length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="弹体直径(mm)">
<el-input-number v-model="formData.rocket_diameter_mm"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹重量(kg)">
<el-input-number v-model="formData.rocket_weight_kg"></el-input-number>
</el-form-item>
<el-form-item label="射速(发/分钟)">
<el-input-number v-model="formData.rate_of_fire"></el-input-number>
</el-form-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="formData.type === '巡飞弹'">
<el-form-item label="最大速度(km/h)">
<el-input-number v-model="formData.max_speed_kmh"></el-input-number>
</el-form-item>
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="formData.cruise_speed_kmh"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="formData.flight_time_min"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="formData.warhead_type">
<el-option label="破片杀伤战斗部" value="破片杀伤战斗部"></el-option>
<el-option label="破甲战斗部" value="破甲战斗部"></el-option>
<el-option label="高爆战斗部" value="高爆战斗部"></el-option>
</el-select>
</el-form-item>
<el-form-item label="发射方式">
<el-select v-model="formData.launch_mode">
<el-option label="箱式发射" value="箱式发射"></el-option>
<el-option label="凭自身动力起飞" value="凭自身动力起飞"></el-option>
</el-select>
</el-form-item>
<el-form-item label="折叠长度(mm)">
<el-input-number v-model="formData.folded_length_mm"></el-input-number>
</el-form-item>
<el-form-item label="折叠宽度(mm)">
<el-input-number v-model="formData.folded_width_mm"></el-input-number>
</el-form-item>
<el-form-item label="折叠高度(mm)">
<el-input-number v-model="formData.folded_height_mm"></el-input-number>
</el-form-item>
</template>
<el-form-item>
<el-button type="primary" @click="submitForm">预测成本</el-button>
<el-button @click="resetForm">重置</el-button>
</el-form-item>
</el-form>
<!-- 预测结果 -->
<div v-if="predictionResults" class="prediction-results">
<h3>预测结果</h3>
<!-- 机器学习模型预测结果 -->
<div class="ml-prediction">
<h4>机器学习模型预测</h4>
<el-descriptions border>
<el-descriptions-item label="模型类型">
{{ getModelName(mlPrediction.model_info.type) }}
</el-descriptions-item>
<el-descriptions-item label="模型名称">
{{ mlPrediction.model_info.name }}
</el-descriptions-item>
<el-descriptions-item label="预测成本">
{{ formatMoney(mlPrediction.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="置信区间">
{{ formatMoney(mlPrediction.confidence_interval.lower) }} ~
{{ formatMoney(mlPrediction.confidence_interval.upper) }}
</el-descriptions-item>
</el-descriptions>
</div>
<!-- PLS回归预测结果 -->
<div class="pls-prediction">
<h4>PLS回归预测</h4>
<el-descriptions border>
<el-descriptions-item label="模型类型">
{{ getModelName(plsPrediction.model_info.type) }}
</el-descriptions-item>
<el-descriptions-item label="模型名称">
{{ plsPrediction.model_info.name }}
</el-descriptions-item>
<el-descriptions-item label="预测成本">
{{ formatMoney(plsPrediction.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="置信区间">
{{ formatMoney(plsPrediction.confidence_interval.lower) }} ~
{{ formatMoney(plsPrediction.confidence_interval.upper) }}
</el-descriptions-item>
</el-descriptions>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, reactive } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const formData = reactive({
type: '',
length_m: null,
width_m: null,
height_m: null,
weight_kg: null,
max_range_km: null
})
const predictionResults = ref(null)
const mlPrediction = ref(null)
const plsPrediction = ref(null)
const handleTypeChange = () => {
//
if (formData.type === '火箭炮') {
formData.firing_angle_horizontal = null
formData.firing_angle_vertical = null
formData.rocket_length_m = null
formData.rocket_diameter_mm = null
formData.rocket_weight_kg = null
formData.rate_of_fire = null
} else if (formData.type === '巡飞弹') {
formData.max_speed_kmh = null
formData.cruise_speed_kmh = null
formData.flight_time_min = null
formData.warhead_type = ''
formData.launch_mode = ''
formData.folded_length_mm = null
formData.folded_width_mm = null
formData.folded_height_mm = null
}
}
const submitForm = async () => {
try {
//
if (!formData.type) {
throw new Error('请选择装备类型')
}
//
const commonFields = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']
for (const field of commonFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
//
if (formData.type === '火箭炮') {
const rocketFields = [
'firing_angle_horizontal', 'firing_angle_vertical',
'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg', 'rate_of_fire'
]
for (const field of rocketFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
} else if (formData.type === '巡飞弹') {
const missileFields = [
'max_speed_kmh', 'cruise_speed_kmh', 'flight_time_min',
'folded_length_mm', 'folded_width_mm', 'folded_height_mm'
]
for (const field of missileFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
}
//
const [mlResponse, plsResponse] = await Promise.all([
axios.post(`${API_BASE_URL}/predict`, formData),
axios.post(`${API_BASE_URL}/pls/predict`, formData)
])
mlPrediction.value = mlResponse.data
plsPrediction.value = plsResponse.data
predictionResults.value = true
} catch (error) {
ElMessage.error(error.message || '预测失败')
}
}
const resetForm = () => {
formData.type = ''
formData.length_m = null
formData.width_m = null
formData.height_m = null
formData.weight_kg = null
formData.max_range_km = null
predictionResults.value = null
mlPrediction.value = null
plsPrediction.value = null
}
const formatFieldName = (field) => {
const nameMap = {
'length_m': '总长',
'width_m': '宽度',
'height_m': '高度',
'weight_kg': '重量',
'max_range_km': '最大射程',
'firing_angle_horizontal': '方向射界',
'firing_angle_vertical': '高低射界',
'rocket_length_m': '火箭弹长度',
'rocket_diameter_mm': '弹体直径',
'rocket_weight_kg': '火箭弹重量',
'rate_of_fire': '射速',
'max_speed_kmh': '最大速度',
'cruise_speed_kmh': '巡航速度',
'flight_time_min': '巡飞时间',
'folded_length_mm': '折叠长度',
'folded_width_mm': '折叠宽度',
'folded_height_mm': '折叠高度'
}
return nameMap[field] || field
}
const formatMoney = (value) => {
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
const getModelName = (modelType) => {
const modelNames = {
'pls': 'PLS回归',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
</script>
<style scoped>
.predict-page {
padding: 20px;
}
.predict-card {
max-width: 800px;
margin: 0 auto;
}
.prediction-results {
margin-top: 20px;
.ml-prediction, .pls-prediction {
margin-top: 20px;
padding: 20px;
background-color: #f5f7fa;
border-radius: 4px;
}
h4 {
margin-top: 0;
margin-bottom: 15px;
}
}
.el-descriptions {
margin-top: 10px;
}
</style>

View File

@ -1,370 +0,0 @@
<template>
<div class="training-page">
<el-card class="training-card">
<template #header>
<h2>模型训练</h2>
</template>
<!-- 训练配置 -->
<el-form :model="trainingConfig" label-width="120px">
<el-form-item label="装备类型">
<el-select v-model="trainingConfig.type" placeholder="选择装备类型">
<el-option label="火箭炮" value="火箭炮" />
<el-option label="巡飞弹" value="巡飞弹" />
</el-select>
</el-form-item>
<el-form-item label="训练数据集">
<el-select v-model="trainingConfig.train_dataset_id" placeholder="选择训练数据集">
<el-option
v-for="dataset in trainingDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
/>
</el-select>
</el-form-item>
<el-form-item label="验证数据集">
<el-select v-model="trainingConfig.validation_dataset_id" placeholder="选择验证数据集">
<el-option
v-for="dataset in validationDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
/>
</el-select>
</el-form-item>
<el-form-item label="选择模型">
<el-checkbox-group v-model="trainingConfig.models">
<el-checkbox label="pls" disabled>PLS回归</el-checkbox>
<el-checkbox label="xgboost" checked>XGBoost</el-checkbox>
<el-checkbox label="lightgbm" checked>LightGBM</el-checkbox>
<el-checkbox label="gbm" checked>GBM</el-checkbox>
<el-checkbox label="rf" checked>Random Forest</el-checkbox>
</el-checkbox-group>
</el-form-item>
<el-form-item>
<el-button type="primary" @click="startTraining" :loading="isTraining">
开始训练
</el-button>
</el-form-item>
</el-form>
<!-- 训练结果 -->
<div v-if="trainingResult" class="training-result">
<h3>训练结果</h3>
<!-- 最佳模型信息 -->
<div class="best-model-info" v-if="trainingResult.best_model">
<h4>最佳模型: {{ getModelName(trainingResult.best_model.type) }}</h4>
<p>R²分数: {{ formatNumber(trainingResult.best_model.r2) }}</p>
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} </p>
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} </p>
</div>
<!-- 所有模型评估结果 -->
<el-table :data="modelResults" border style="width: 100%; margin-top: 20px;">
<el-table-column prop="model" label="模型" width="120">
<template #default="scope">
{{ getModelName(scope.row.model) }}
</template>
</el-table-column>
<!-- 训练集评估 -->
<el-table-column label="训练集评估">
<el-table-column prop="train.r2" label="R²分数" width="120">
<template #default="scope">
{{ formatNumber(scope.row.train.r2) }}
</template>
</el-table-column>
<el-table-column prop="train.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.mae) }}
</template>
</el-table-column>
<el-table-column prop="train.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.rmse) }}
</template>
</el-table-column>
</el-table-column>
<!-- 验证集评估 -->
<el-table-column label="验证集评估">
<el-table-column prop="validation.r2" label="R²分数" width="120">
<template #default="scope">
{{ formatNumber(scope.row.validation.r2) }}
</template>
</el-table-column>
<el-table-column prop="validation.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.mae) }}
</template>
</el-table-column>
<el-table-column prop="validation.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.rmse) }}
</template>
</el-table-column>
</el-table-column>
</el-table>
<!-- 特征重要性 -->
<div v-if="trainingResult.feature_importance" class="feature-importance">
<h4>特征重要性</h4>
<el-table
:data="featureImportanceData"
border
style="width: 100%; margin-top: 10px;"
>
<el-table-column prop="feature" label="特征" width="180" />
<el-table-column prop="importance" label="重要性" width="120">
<template #default="scope">
{{ formatNumber(scope.row.importance) }}
</template>
</el-table-column>
</el-table>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, computed, onMounted, watch } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
//
const trainingConfig = ref({
type: '',
train_dataset_id: null,
validation_dataset_id: null,
models: ['pls']
})
//
const trainingDatasets = ref([])
const validationDatasets = ref([])
//
const isTraining = ref(false)
const trainingResult = ref(null)
//
const loadDatasets = async () => {
try {
//
const trainResponse = await axios.get(
`${API_BASE_URL}/datasets`,
{ params: { equipment_type: trainingConfig.value.type, purpose: '训练' } }
)
trainingDatasets.value = trainResponse.data
//
const valResponse = await axios.get(
`${API_BASE_URL}/datasets`,
{ params: { equipment_type: trainingConfig.value.type, purpose: '验证' } }
)
validationDatasets.value = valResponse.data
} catch (error) {
ElMessage.error('加载数据集失败')
console.error('Error loading datasets:', error)
}
}
//
watch(() => trainingConfig.value.type, (newType) => {
if (newType) {
loadDatasets()
}
})
//
const startTraining = async () => {
try {
//
if (!trainingConfig.value.type) {
ElMessage.warning('请选择装备类型')
return
}
if (!trainingConfig.value.train_dataset_id) {
ElMessage.warning('请选择训练数据集')
return
}
if (!trainingConfig.value.validation_dataset_id) {
ElMessage.warning('请选择验证数据集')
return
}
if (trainingConfig.value.models.length === 0) {
ElMessage.warning('请至少选择一个模型')
return
}
isTraining.value = true
//
const response = await axios.post(`${API_BASE_URL}/train`, trainingConfig.value)
if (response.data.error) {
throw new Error(response.data.error)
}
trainingResult.value = response.data
ElMessage.success('训练完成')
} catch (error) {
ElMessage.error(error.message || '训练失败')
console.error('Training error:', error)
} finally {
isTraining.value = false
}
}
//
const formatNumber = (value) => {
if (value === null || value === undefined) return '-'
if (typeof value === 'number') {
if (Math.abs(value) >= 1000) {
return value.toLocaleString('zh-CN', { maximumFractionDigits: 2 })
}
return value.toFixed(4)
}
return value
}
//
const getModelName = (modelType) => {
const modelNames = {
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
//
const modelResults = computed(() => {
if (!trainingResult.value?.metrics) return []
return Object.entries(trainingResult.value.metrics).map(([model, metrics]) => ({
model,
train: metrics.train,
validation: metrics.validation
}))
})
//
const featureNameMap = {
//
'length_m': '总长(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
//
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
//
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
//
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
//
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比'
}
//
const featureImportanceData = computed(() => {
if (!trainingResult.value?.feature_importance || !trainingResult.value?.feature_names) return []
//
const data = trainingResult.value.feature_importance
.map((importance, index) => ({
feature: featureNameMap[trainingResult.value.feature_names[index]] || trainingResult.value.feature_names[index],
importance
}))
// 0
.filter(item => item.importance > 0)
//
.sort((a, b) => b.importance - a.importance)
return data
})
//
onMounted(() => {
if (trainingConfig.value.type) {
loadDatasets()
}
})
</script>
<style lang="scss" scoped>
.training-page {
padding: 20px;
.training-card {
.training-result {
margin-top: 20px;
.best-model-info {
background-color: #f5f7fa;
padding: 15px;
border-radius: 4px;
margin-bottom: 20px;
}
.feature-importance {
margin-top: 20px;
.importance-bar {
width: 100%;
background-color: #f5f7fa;
border-radius: 4px;
.importance-value {
background-color: #409eff;
color: white;
padding: 4px 8px;
border-radius: 4px;
text-align: right;
transition: width 0.3s ease;
}
}
}
}
}
}
</style>

View File

@ -1,5 +0,0 @@
module.exports = {
devServer: {
port: 8080
}
}

View File

@ -1,12 +0,0 @@
flask==2.0.1
flask-cors==3.0.10
sqlalchemy==1.4.23
pymysql==1.0.2
cryptography==3.4.7 # MySQL 8.0+ 认证需要
numpy==1.21.2
pandas==1.3.3
scikit-learn==0.24.2
tensorflow==2.6.0
urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
openpyxl==3.1.2 # 用于读取 .xlsx 文件
xlrd==2.0.1 # 用于读取 .xls 文件

View File

@ -1,60 +0,0 @@
#!/bin/bash
echo "开始安装装备成本估算系统..."
# 检查Python版本
python3 -V || {
echo "错误: 需要 Python 3.8+"
exit 1
}
# 检查Node.js版本
node -v || {
echo "错误: 需要 Node.js 14+"
exit 1
}
# 创建必要的目录
echo "创建系统目录..."
mkdir -p {logs,data,models}
# 安装后端依赖
echo "安装后端依赖..."
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
# 安装前端依赖
echo "安装前端依赖..."
cd frontend
npm install
npm run build
cd ..
# 配置文件
if [ ! -f config/.env ]; then
echo "创建配置文件..."
cp config/.env.template config/.env
echo "请修改 config/.env 中的配置"
fi
# 初始化数据库
echo "初始化数据库..."
read -p "请输入MySQL root密码: " mysqlpass
mysql -u root -p$mysqlpass < src/schema.sql
# 导入测试数据(可选)
read -p "是否导入测试数据?(y/n) " import_test_data
if [ "$import_test_data" = "y" ]; then
mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql
fi
# 设置权限
echo "设置文件权限..."
chmod +x scripts/*.sh
chmod 755 logs models data
chmod 600 config/.env
echo "安装完成!"
echo "请检查并修改 config/.env 中的配置"
echo "使用 ./scripts/start.sh 启动服务"

View File

@ -1,51 +0,0 @@
#!/bin/bash
echo "启动装备成本估算系统..."
# 检查配置文件
if [ ! -f config/.env ]; then
echo "错误: 配置文件不存在"
echo "请先运行 install.sh"
exit 1
fi
# 检查日志目录
if [ ! -d logs ]; then
mkdir -p logs
fi
# 激活虚拟环境
source venv/bin/activate
# 导出环境变量
export $(cat config/.env | xargs)
# 启动后端服务
echo "启动后端服务..."
gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \
--daemon \
--pid gunicorn.pid \
--access-logfile logs/access.log \
--error-logfile logs/error.log
# 等待后端服务启动
sleep 2
# 检查后端服务是否成功启动
if ! curl -s http://localhost:5001/api/ > /dev/null; then
echo "错误: 后端服务启动失败"
exit 1
fi
# 启动前端服务
echo "启动前端服务..."
cd frontend
npm run serve -- --port 8080 --host 0.0.0.0 &
echo $! > ../frontend.pid
cd ..
echo "服务已启动!"
echo "后端API: http://localhost:5001"
echo "前端界面: http://localhost:8080"
echo "查看后端日志: tail -f logs/access.log"
echo "查看前端日志: tail -f logs/frontend.log"

View File

@ -1,44 +0,0 @@
#!/bin/bash
echo "停止装备成本估算系统..."
# 停止后端服务
if [ -f gunicorn.pid ]; then
echo "停止后端服务..."
pid=$(cat gunicorn.pid)
kill -TERM $pid
rm gunicorn.pid
echo "后端服务已停止 (PID: $pid)"
else
echo "未找到后端服务PID文件"
# 尝试查找并停止所有gunicorn进程
pkill -f gunicorn
echo "已尝试停止所有gunicorn进程"
fi
# 停止前端服务
if [ -f frontend.pid ]; then
echo "停止前端服务..."
pid=$(cat frontend.pid)
kill -TERM $pid
rm frontend.pid
echo "前端服务已停止 (PID: $pid)"
else
echo "未找到前端服务PID文件"
# 尝试查找并停止前端服务进程
pkill -f "vite preview"
echo "已尝试停止所有前端服务进程"
fi
# 检查是否还有相关进程在运行
if pgrep -f gunicorn > /dev/null; then
echo "警告: 仍有gunicorn进程在运行"
ps aux | grep gunicorn | grep -v grep
fi
if pgrep -f "vite preview" > /dev/null; then
echo "警告: 仍有前端服务进程在运行"
ps aux | grep "vite preview" | grep -v grep
fi
echo "所有服务已停止"

View File

@ -1 +0,0 @@
# 这个文件可以为空,但必须存在

View File

@ -1,50 +0,0 @@
from flask import Flask
from flask_cors import CORS
from .routes import api_bp
from .logger import setup_logger
import os
# 获取logger
logger = setup_logger(__name__)
def create_app():
"""
创建并配置Flask应用
"""
try:
# 创建必要的目录
os.makedirs('logs', exist_ok=True)
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
logger.info("=== Server Starting ===")
logger.info("Initializing directories...")
# 创建Flask应用
app = Flask(__name__)
# 配置CORS
CORS(app)
logger.info("CORS enabled")
# 注册API蓝图
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")
# 配置数据库连接
app.config['MYSQL_HOST'] = 'localhost'
app.config['MYSQL_USER'] = 'root'
app.config['MYSQL_PASSWORD'] = '123456'
app.config['MYSQL_DB'] = 'equipment_cost_db'
logger.info("Starting server...")
return app
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
raise
if __name__ == '__main__':
app = create_app()
app.run(host='localhost', port=5001)

View File

@ -1,342 +0,0 @@
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import stats
import joblib
import os
import pandas as pd
from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection
from .logger import setup_logger
logger = setup_logger(__name__)
class CostPredictor:
def __init__(self):
self.scaler_X = StandardScaler()
self.scaler_y = StandardScaler()
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
# 添加 TensorFlow 配置
tf.config.run_functions_eagerly(False) # 启用图执行模式
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
self.load_model()
def load_model(self):
"""
加载预训练型和标准化器
"""
try:
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 创建默认模型
self._create_default_model()
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
def _create_default_model(self):
"""
创建默认模型并进行初始化训练
"""
# 创建输入层
inputs = tf.keras.Input(shape=(11,))
# 创建隐藏层
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
x = tf.keras.layers.Dense(32, activation='relu')(x)
# 创建输出层
outputs = tf.keras.layers.Dense(1)(x)
# 创建模型
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 编译模型
self.model.compile(
optimizer='adam',
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_absolute_error]
)
# 创建示例数据
example_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 训练标准化器
self.scaler_X.fit(example_data)
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
# 设置默认装备类型
self.equipment_type = '火箭炮'
def _create_example_data(self):
"""
创建示例数据来训练标准化器
"""
# 火箭炮示例数据
rocket_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 巡飞弹示例数据
missile_data = pd.DataFrame({
'length_m': [1.3, 2.5],
'width_m': [0.23, 0.6],
'height_m': [0.23, 0.6],
'weight_kg': [12.5, 135],
'max_range_km': [40, 250],
'max_speed_kmh': [180, 185],
'cruise_speed_kmh': [100, 110],
'flight_time_min': [60, 120],
'folded_length_mm': [1300, 2500],
'folded_width_mm': [230, 600],
'folded_height_mm': [230, 600]
})
# 训练标准化器
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
# 设置默认装备类型
self.equipment_type = '火箭炮'
def predict(self, data):
"""
使用训练好的最优模型进行预测
"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载已训练的最优模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type):
raise ValueError(f"No trained model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
return {
'predicted_cost': float(y_pred[0]),
'model_type': model_type, # 返回使用的模型类型
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Prediction error: {str(e)}")
raise
def _calculate_confidence_interval(self, prediction, confidence=0.95):
"""
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
lower = max(1000, interval[0]) # 最小值设为1000元
upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
return [lower, upper]
except Exception as e:
logging.error(f"Error calculating confidence interval: {str(e)}")
# 如果计算失败返回基于20%的简单区间
lower = max(1000, prediction * 0.8)
upper = prediction * 1.2
return [lower, upper]
def evaluate(self, y_true, y_pred):
"""
模型评估
"""
return {
'mae': float(mean_absolute_error(y_true, y_pred)),
'mse': float(mean_squared_error(y_true, y_pred)),
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
'r2': float(r2_score(y_true, y_pred))
}
def predict_pls(self, data):
"""
使用 PLS <EFBFBD><EFBFBD><EFBFBD>型预测成本
"""
try:
logger.info(f"Starting PLS prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载 PLS 模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
raise ValueError(f"No trained PLS model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
return {
'predicted_cost': float(y_pred[0]),
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"PLS prediction error: {str(e)}")
raise
def predict_all(self, data):
"""
使用所有可用模型进行预测
"""
try:
logger.info(f"Starting multi-model prediction for {data.get('type')}")
equipment_type = data.get('type')
results = {}
# 1. 获取所有激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT id, model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
""", (equipment_type,))
active_models = cursor.fetchall()
if not active_models:
raise ValueError(f"No active models found for {equipment_type}")
# 2. 使用每个模型进行预测
trainer = ModelTrainer()
for model_info in active_models:
try:
# 加载特定模型
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
logger.warning(f"Failed to load model: {model_info['model_name']}")
continue
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 保存结果
results[model_info['model_type']] = {
'predicted_cost': float(y_pred[0]),
'model_info': {
'name': model_info['model_name'],
'type': model_info['model_type'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
continue
if not results:
raise ValueError("No successful predictions from any model")
# 3. 计算综合预测结果
all_predictions = [result['predicted_cost'] for result in results.values()]
ensemble_prediction = float(np.mean(all_predictions))
prediction_std = float(np.std(all_predictions))
# 4. 返回所有结果
return {
'individual_predictions': results,
'ensemble_prediction': {
'predicted_cost': ensemble_prediction,
'standard_deviation': prediction_std,
'confidence_interval': {
'lower': float(ensemble_prediction - 1.96 * prediction_std),
'upper': float(ensemble_prediction + 1.96 * prediction_std)
}
}
}
except Exception as e:
logger.error(f"Error in multi-model prediction: {str(e)}")
raise

View File

@ -1,155 +0,0 @@
import pandas as pd
import openpyxl
from openpyxl.styles import PatternFill, Font, Alignment
from openpyxl.worksheet.datavalidation import DataValidation
import os
from .logger import setup_logger
logger = setup_logger(__name__)
def create_excel_template():
"""
创建数据模板
"""
try:
# 确保data目录存在
data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
os.makedirs(data_dir, exist_ok=True)
# 创建完整的文件路径
template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')
# 创建 Excel 写入器
writer = pd.ExcelWriter(template_path, engine='openpyxl')
# 火箭炮基本参数表
rocket_artillery_columns = [
'名称', '类型', '制造商', '口径_mm',
'反射管数量', '乘员数', '总长_m',
'宽度_m', '高度_m', '重量_kg',
'战斗重_kg', '速度_km/h', '最大射程_km',
'最小射程_km', '方向射界_度', '高低射界_度',
'火箭弹长度_m', '火箭弹重量_kg',
'火箭弹最大速度_m/s', '射速_发',
'战斗部重量_kg', '行走方式',
'结构布局', '发动机型号', '发动机参数',
'功率_hp', '行程_km', '成本_元'
]
# 巡飞弹基本参数表
loitering_munition_columns = [
'名称', '类型', '制造商', '目标类型',
'弹长_m', '弹径_mm', '翼展_m',
'重量_kg', '战斗部重量_kg',
'最大射程_km', '最大速度_m/s',
'巡航速度_kmh', '巡飞时间_min',
'战斗部类型', '发射方式',
'折叠长度_mm', '折叠宽度_mm',
'折叠高度_mm', '动力装置',
'制导体制', '成本_元'
]
# 特殊参数表
special_params_columns = [
'装备名称', # 关联字段
'参数名称',
'参数值',
'参数单位',
'参数说明'
]
# 创建工作表
pd.DataFrame(columns=rocket_artillery_columns).to_excel(
writer, sheet_name='火箭炮', index=False
)
pd.DataFrame(columns=loitering_munition_columns).to_excel(
writer, sheet_name='巡飞弹', index=False
)
pd.DataFrame(columns=special_params_columns).to_excel(
writer, sheet_name='特殊参数', index=False
)
# 获取工作簿
workbook = writer.book
# 设置火箭炮工作表格式
rocket_sheet = workbook['火箭炮']
for col in range(1, len(rocket_artillery_columns) + 1):
cell = rocket_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 设置巡飞弹工作表格式
missile_sheet = workbook['巡飞弹']
for col in range(1, len(loitering_munition_columns) + 1):
cell = missile_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 设置特殊参数工作表格式
special_sheet = workbook['特殊参数']
for col in range(1, len(special_params_columns) + 1):
cell = special_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 添加数据验证
for sheet in [rocket_sheet, missile_sheet]:
# 数值验证
number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
number_validation.error = "请输入大于0的数值"
number_validation.errorTitle = "输入错误"
# 应用到相应列
for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']:
number_validation.add(f"{col}2:{col}1000")
sheet.add_data_validation(number_validation)
# 添加说明
rocket_sheet['AD1'] = "填写说明:"
rocket_sheet['AD2'] = "1. 所有数值必须大于0"
rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
rocket_sheet['AD4'] = "3. 成本单位为元"
missile_sheet['V1'] = "填写说明:"
missile_sheet['V2'] = "1. 所有数值必须大于0"
missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
missile_sheet['V4'] = "3. 成本单位为元"
special_sheet['G1'] = "填写说明:"
special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
special_sheet['G3'] = "2. 参数值必须包含单位"
special_sheet['G4'] = "3. 参数说明应简明扼要"
# 调整列宽
for sheet in [rocket_sheet, missile_sheet, special_sheet]:
for col in sheet.columns:
max_length = 0
column = col[0].column_letter
for cell in col:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except:
pass
adjusted_width = (max_length + 2)
sheet.column_dimensions[column].width = adjusted_width
# 保存文件
writer.close()
return template_path
except Exception as e:
raise Exception(f"创建模板文件失败: {str(e)}")
if __name__ == "__main__":
try:
template_path = create_excel_template()
print(f"模板文件已创建: {template_path}")
except Exception as e:
print(f"错误: {str(e)}")

View File

@ -1,233 +0,0 @@
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import os
import joblib
import pandas as pd
import numpy as np
from src.feature_analysis import FeatureAnalysis
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.model_selection import cross_val_score, LeaveOneOut
import json
import logging
from src.database.db_connection import get_db_connection
from sklearn.metrics import mean_absolute_error, mean_squared_error
from .logger import setup_logger
logger = setup_logger(__name__)
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
self.target_scaler = StandardScaler() # 添加目标值标准化器
def prepare_training_data(self, equipment_data, equipment_type):
"""
准备训练数据
"""
try:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,直接返回
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
return {
'X': X,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
# 从原始数据中提取特征和目标值
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in equipment_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 提取目标值(成本)
try:
cost = float(item['actual_cost'])
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError, KeyError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录原始数据范围
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
# 标准化特征和目标值
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 记录标准化后的数据范围
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 记录标准化器参数
logger.info("Feature scaler params:")
logger.info(f"Mean: {self.feature_scaler.mean_}")
logger.info(f"Scale: {self.feature_scaler.scale_}")
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
return {
'X': X_scaled,
'y': y_scaled,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
except Exception as e:
logger.error(f"Error in data preparation: {str(e)}")
raise Exception(f"Training error: {str(e)}")
def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
"""
准备验证数据
"""
try:
logger.info(f"Preparing validation data for {equipment_type}")
logger.info(f"Raw validation data size: {len(validation_data)}")
# 如果输入已经是 numpy 数组,直接使用
if isinstance(validation_data, np.ndarray):
X = validation_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
else:
X_scaled = X
logger.info(f"Preprocessed data shape: {X_scaled.shape}")
logger.info(f"Validation features shape: {X_scaled.shape}")
logger.info(f"Validation features type: {X_scaled.dtype}")
return {
'X': X_scaled,
'y': None # 验证数据可能没有标签
}
# 从原始数据中提取特征和目标值
if not feature_names:
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in validation_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 提取目标值(成本)并验证范围
try:
cost = float(item['actual_cost'])
logger.info(f"Raw cost value: {cost}")
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录数据范围
logger.info(f"Features range: min={X.min()}, max={X.max()}")
logger.info(f"Targets range: min={y.min()}, max={y.max()}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
if 'target_scaler' in scalers:
y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel()
else:
y_scaled = y
else:
X_scaled = X
y_scaled = y
logger.info(f"Preprocessed data shape: {X_scaled.shape}")
logger.info(f"Validation features shape: {X_scaled.shape}")
logger.info(f"Validation features type: {X_scaled.dtype}")
# 记录标准化后的数据范围
logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 确保特征维度一致
if not feature_names:
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
logger.info(f"Expected features: {len(feature_names)}")
logger.info(f"Actual features: {X_scaled.shape[1]}")
if X_scaled.shape[1] != len(feature_names):
raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}")
return {
'X': X_scaled,
'y': y_scaled # 返回标准化后的成本值
}
except Exception as e:
logger.error(f"Error in validation data preparation: {str(e)}")
logger.error(f"Feature names: {feature_names}")
logger.error(f"Equipment type: {equipment_type}")
raise Exception(f"Validation error: {str(e)}")
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
return self.feature_analyzer.calculate_derived_features(data, equipment_type)
except Exception as e:
logger.error(f"Error calculating derived features: {str(e)}")
raise Exception(f"Feature calculation error: {str(e)}")

View File

@ -1 +0,0 @@
from .db_connection import get_db_connection

View File

@ -1,37 +0,0 @@
import mysql.connector
from mysql.connector import Error
from contextlib import contextmanager
import os
from dotenv import load_dotenv
from ..logger import setup_logger
# 获取logger
logger = setup_logger(__name__)
# 加载环境变量
load_dotenv()
@contextmanager
def get_db_connection():
"""
数据库连接上下文管理器
"""
connection = None
try:
connection = mysql.connector.connect(
host=os.getenv('MYSQL_HOST', 'localhost'),
user=os.getenv('MYSQL_USER', 'root'),
password=os.getenv('MYSQL_PASSWORD', '123456'),
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
)
logger.info("Database connection established")
yield connection
except Error as e:
logger.error(f"Error connecting to MySQL: {str(e)}")
raise
finally:
if connection and connection.is_connected():
connection.close()
logger.info("Database connection closed")

View File

@ -1,269 +0,0 @@
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
import logging
from .logger import setup_logger
logger = setup_logger(__name__)
class FeatureAnalysis:
def __init__(self):
self.scaler = StandardScaler()
self.important_features = []
# 添加特征名称映射
self.feature_names_map = {
# 通用参数
'length_m': '总长(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
# 火箭炮衍生特征
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
# 巡飞弹特有参数
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
# 巡飞弹衍生特征
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比',
'volume_density': '体积密度'
}
def get_equipment_specific_features(self, equipment_type):
"""
获取特定装备类型的特征列表
"""
# 通用参数
common_features = [
'length_m', # 总长(m)
'width_m', # 宽度(m)
'height_m', # 高度(m)
'weight_kg', # 重量(kg)
'max_range_km' # 最大射程(km)
]
if equipment_type == '火箭炮':
# 火箭炮特有参数
specific_features = [
'firing_angle_horizontal', # 方向射界(度)
'firing_angle_vertical', # 高低射界(度)
'rocket_length_m', # 火箭弹长度(m)
'rocket_diameter_mm', # 口径(mm)
'rocket_weight_kg', # 火箭弹重量(kg)
'rate_of_fire', # 射速(发/分)
'combat_weight_kg', # 战斗重量(kg)
'speed_kmh', # 速度(km/h)
'min_range_km', # 最小射程(km)
'power_hp' # 功率(hp)
]
# 火箭炮衍生特征
derived_features = [
'fire_density', # 火力密度 = 射速 * 火箭弹重量
'mobility_index', # 机动性指标 = 速度 / 战斗重量
'range_ratio', # 射程比 = 最大射程 / 最小射程
'power_weight_ratio', # 功重比 = 功率 / 战斗重量
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
else: # 巡飞弹
# 巡飞弹特有参数
specific_features = [
'wingspan_m', # 翼展(m)
'warhead_weight_kg', # 战斗部重量(kg)
'max_speed_ms', # 最大速度(m/s)
'cruise_speed_kmh', # 巡航速度(km/h)
'flight_time_min', # 巡飞时间(min)
'folded_length_mm', # 折叠长度(mm)
'folded_width_mm', # 折叠宽度(mm)
'folded_height_mm' # 折叠高度(mm)
]
# 巡飞弹衍生特征
derived_features = [
'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
'speed_ratio', # 速度比 = 巡航速度 / 最大速度
'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
if equipment_type == '火箭炮':
# 火箭炮衍生特征计算
if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
else:
data['fire_density'] = 0 # 或者其他默认值
if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
else:
data['mobility_index'] = 0
if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
data['range_ratio'] = data['max_range_km'] / data['min_range_km']
else:
data['range_ratio'] = 0
if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
else:
data['power_weight_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
else: # 巡飞弹
# 巡飞弹衍生特征计算
if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
else:
data['warhead_ratio'] = 0
if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
else:
data['speed_ratio'] = 0
if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
else:
data['range_time_ratio'] = 0
if 'wingspan_m' in data.columns and 'length_m' in data.columns:
data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
else:
data['aspect_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
return data
except Exception as e:
logger.error(f"Error calculating derived features: {str(e)}")
raise
def analyze_features(self, features, target, feature_names):
"""
分析特征重要性和相关性
"""
try:
# 转换为numpy数组
X = np.array(features)
y = np.array(target)
# 数据标准化
X_scaled = self.scaler.fit_transform(X)
# 特征重要性分析
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_scaled, y)
importances = rf.feature_importances_
# 按重要性排序,使用中文特征名
importance_indices = np.argsort(importances)[::-1]
important_features = [
{
'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
'importance': float(importances[i])
}
for i in importance_indices
]
# 相关性分析
df = pd.DataFrame(X_scaled, columns=feature_names)
correlation_matrix = df.corr().values
# 生成相关性分析数据保留2位小数
correlation_data = []
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
for i in range(len(feature_names)):
for j in range(len(feature_names)):
correlation_data.append([
i, j,
round(correlation_matrix[i][j], 2) # 修改为保留2位小数
])
return {
'important_features': important_features,
'correlation_analysis': {
'features': chinese_feature_names, # 使用中文特征名
'matrix': correlation_data
}
}
except Exception as e:
logger.error(f"Error in feature analysis: {str(e)}")
raise
def preprocess_features(self, equipment_data, equipment_type):
"""
预处理特征数据
"""
try:
# 转换为 DataFrame
df = pd.DataFrame(equipment_data)
# 计算衍生特征
df = self.calculate_derived_features(df, equipment_type)
# 处理缺失值
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
# 转换为数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
# 使用新的方式填充缺失值
mean_value = df[col].mean()
df[col] = df[col].fillna(mean_value)
logger.info(f"Preprocessed data shape: {df.shape}")
return df
except Exception as e:
logger.error(f"Error preprocessing features: {str(e)}")
raise Exception(f"Feature preprocessing error: {str(e)}")

View File

@ -1,255 +0,0 @@
import pandas as pd
from .logger import setup_logger
from src.database.db_connection import get_db_connection
logger = setup_logger(__name__)
def import_training_data(excel_file):
"""
从Excel导入训练数据到数据库
"""
try:
# 读取所有sheet
rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
special_df = pd.read_excel(excel_file, sheet_name='特殊参数')
# 记录所有装备名称,用于后续检查
equipment_names = set()
with get_db_connection() as conn:
cursor = conn.cursor()
# 1. 先导入火箭炮数据
logger.info("开始导入火箭炮数据...")
for _, row in rocket_df.iterrows():
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '火箭炮'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logger.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None,
row['宽度_m'] if pd.notna(row['宽度_m']) else None,
row['高度_m'] if pd.notna(row['高度_m']) else None,
row['重量_kg'] if pd.notna(row['重量_kg']) else None,
row['最大射程_km'] if pd.notna(row['最大射程_km']) else None
))
# 插入火箭炮特有参数
cursor.execute("""
INSERT INTO rocket_artillery_params
(equipment_id, firing_angle_horizontal, firing_angle_vertical,
rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp,
travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
row['高低射界_度'] if pd.notna(row['高低射界_度']) else None,
row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None,
row['口径_mm'] if pd.notna(row['口径_mm']) else None,
row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None,
row['射速_发'] if pd.notna(row['射速_发']) else None,
row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None,
row['速度_km/h'] if pd.notna(row['速度_km/h']) else None,
row['最小射程_km'] if pd.notna(row['最小射程_km']) else None,
row['行走方式'] if pd.notna(row['行走方式']) else None,
row['结构布局'] if pd.notna(row['结构布局']) else None,
row['发动机型号'] if pd.notna(row['发动机型号']) else None,
row['发动机参数'] if pd.notna(row['发动机参数']) else None,
row['功率_hp'] if pd.notna(row['功率_hp']) else None,
row['行程_km'] if pd.notna(row['行程_km']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, row['成本_元']))
logger.info("火箭炮数据导入完成")
# 2. 导入巡飞弹数据
logger.info("开始导入巡飞弹数据...")
for index, row in missile_df.iterrows():
# 记录每行数据的空值情况
null_values = row[row.isna()].index.tolist()
if null_values:
logger.info(f"{index + 2} 中的空值字段: {null_values}")
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '巡飞弹'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logger.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (
row['名称'],
'巡飞弹',
row['制造商'] if pd.notna(row['制造商']) else None
))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['重量_kg']) if pd.notna(row['重量_kg']) else None,
float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None
))
# 插入巡飞弹特有参数
cursor.execute("""
INSERT INTO loitering_munition_params
(equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None,
float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None,
float(row['巡航速度_km/h']) if pd.notna(row['巡航速度_km/h']) else None,
float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None,
str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None,
str(row['发射方式']) if pd.notna(row['发射方式']) else None,
float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None,
float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None,
float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None,
str(row['动力装置']) if pd.notna(row['动力装置']) else None,
str(row['制导体制']) if pd.notna(row['制导体制']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, float(row['成本_元'])))
logger.info("巡飞弹数据导入完成")
# 提交之前的更改并关闭原有游标
cursor.close()
conn.commit()
# 3. 导入特殊参数
logger.info("开始导入特殊参数...")
for index, row in special_df.iterrows():
equipment_name = row['装备名称']
param_name = row['参数名称']
logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")
if equipment_name not in equipment_names:
logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
continue
# 获取装备ID - 使用新的游标
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipment WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
if not result:
logger.warning(f"未找到装备: {equipment_name}")
continue
equipment_id = result[0]
logger.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor:
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
if exists:
logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
continue
# 插入新的参数 - 使用新的游标
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor:
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s)
""", (
equipment_id,
param_name,
param_value,
param_unit,
param_desc
))
logger.debug(f"成功插入参数记录")
# 最终提交
conn.commit()
logger.info("特殊参数导入完成")
logger.info("所有数据导入成功")
return True
except Exception as e:
logger.error(f"Error importing data: {str(e)}")
raise
if __name__ == "__main__":
try:
excel_file = 'data/equipment_data_20241108.xlsx'
import_training_data(excel_file)
logger.info("All data imported successfully")
except Exception as e:
logger.error(f"Import failed: {str(e)}")

View File

@ -1,319 +0,0 @@
/*
使
1.
2.
3.
*/
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-- 插入巡飞弹技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES (
1, -- 终结者巡飞弹
0.56,
0.15,
0.20,
2.72,
160.93,
96.56,
24,
15,
'破片杀伤战斗部',
'凭自身动力起飞',
560,
150,
200
);
-- 插入火箭炮技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES (
2, -- 胜利-2火箭炮
10,
2.5,
3.34,
15000,
23
);
-- 插入成本数据(示例数据)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 1000000), -- 终结者巡飞弹成本
(2, 5000000); -- 胜利-2火箭炮成本
-- 插入更多巡飞弹变体数据用于训练
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-- 插入变体技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 终结者-A稍大型号
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
-- 终结者-B稍小型号
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
-- 终结者-C标准型号的改进版
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-- 插入变体成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(3, 1100000), -- 终结者-A成本较高
(4, 900000), -- 终结者-B成本较低
(5, 1050000); -- 终结者-C成本中等
-- 添加更多巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-- 添加它们的技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 哈比
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
-- 游荡者
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
-- 凤凰
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
-- 弹簧刀
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
-- 彩虹-4
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-- 添加成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(6, 800000), -- 哈比
(7, 500000), -- 游荡者
(8, 450000), -- 凤凰
(9, 480000), -- 弹簧刀
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('BM-21', '火箭炮', '俄罗斯'),
('SR5', '火箭炮', '中国'),
('HIMARS', '火箭炮', '美国'),
('LAR-160', '火箭炮', '以色列'),
('T-122', '火箭炮', '土耳其'),
('RM-70', '火箭炮', '捷克'),
('ASTROS II', '火箭炮', '巴西');
-- 火箭炮通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- BM-21
(1, 7.35, 2.4, 3.1, 13700, 20.4),
-- SR5
(2, 10.2, 2.8, 3.2, 28500, 70),
-- HIMARS
(3, 7.0, 2.4, 3.2, 16250, 70),
-- LAR-160
(4, 6.7, 2.5, 2.8, 15000, 45),
-- T-122
(5, 7.2, 2.5, 2.9, 18000, 40),
-- RM-70
(6, 7.5, 2.5, 3.0, 17200, 20.3),
-- ASTROS II
(7, 8.0, 2.7, 3.1, 24500, 90);
-- 火箭炮特有参数
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
firing_angle_vertical,
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire,
combat_weight_kg,
speed_kmh,
min_range_km,
mobility_type,
structure_layout,
engine_model,
engine_params,
power_hp,
travel_range_km
) VALUES
-- BM-21
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
-- SR5
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
-- HIMARS
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
-- LAR-160
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
-- T-122
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
-- RM-70
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
-- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-- 巡飞弹数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('Hero-120', '巡飞弹', '以色列'),
('Switchblade 600', '巡飞弹', '美国'),
('Warmate', '巡飞弹', '波兰'),
('CH-901', '巡飞弹', '中国'),
('HAROP', '巡飞弹', '以色列'),
('Coyote', '巡飞弹', '美国'),
('WS-43', '巡飞弹', '中国');
-- 巡飞弹通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- Hero-120
(8, 1.3, 0.23, 0.23, 12.5, 40),
-- Switchblade 600
(9, 1.3, 0.22, 0.22, 15.0, 40),
-- Warmate
(10, 1.1, 0.15, 0.15, 5.7, 15),
-- CH-901
(11, 1.2, 0.18, 0.18, 9.0, 20),
-- HAROP
(12, 2.5, 0.43, 0.43, 135, 1000),
-- Coyote
(13, 0.9, 0.12, 0.12, 5.9, 20),
-- WS-43
(14, 1.8, 0.35, 0.35, 20, 60);
-- 巡飞弹特有参数
INSERT INTO loitering_munition_params (
equipment_id,
wingspan_m,
warhead_weight_kg,
max_speed_ms,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm,
power_system,
guidance_system
) VALUES
-- Hero-120
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
-- Switchblade 600
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
-- Warmate
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
-- CH-901
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
-- HAROP
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
-- Coyote
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
-- WS-43
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-- 火箭炮
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000), -- ASTROS II
-- 巡飞弹
(8, 150000), -- Hero-120
(9, 180000), -- Switchblade 600
(10, 80000), -- Warmate
(11, 100000), -- CH-901
(12, 850000), -- HAROP
(13, 75000), -- Coyote
(14, 120000); -- WS-43
-- 创建初始数据集
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-- 关联装备到数据集
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 火箭炮训练集
(1, 1), (1, 2), (1, 3), (1, 4),
-- 巡飞弹训练集
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
-- 火箭炮验证集
(3, 5), (3, 6), (3, 7),
-- 巡飞弹验证集
(4, 13), (4, 14);

View File

@ -1,33 +0,0 @@
import logging
import os
from datetime import datetime
def setup_logger(name):
"""
创建并配置logger
"""
# 创建logger
logger = logging.getLogger(name)
# 如果logger已经有处理器直接返回
if logger.handlers:
return logger
# 设置日志级别
logger.setLevel(logging.INFO)
# 确保日志目录存在
os.makedirs('logs', exist_ok=True)
# 创建文件处理器
file_handler = logging.FileHandler('logs/api.log')
file_handler.setLevel(logging.INFO)
# 创建格式化器
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# 添加处理器
logger.addHandler(file_handler)
return logger

View File

@ -1,612 +0,0 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
import xgboost as xgb
import lightgbm as lgb
import logging
import joblib
import os
from src.feature_analysis import FeatureAnalysis
from datetime import datetime
import json
from src.database import get_db_connection
from src.data_preparation import DataPreparation
from sklearn.cross_decomposition import PLSRegression
from .logger import setup_logger
logger = setup_logger(__name__)
class ModelTrainer:
def __init__(self):
"""
初始化 ModelTrainer
"""
self.models = {
'xgboost': self._create_xgboost_model(),
'lightgbm': self._create_lightgbm_model(),
'gbm': self._create_gbm_model(),
'rf': self._create_rf_model(),
'pls': self._create_pls_model()
}
self.best_model = None
self.imputer = SimpleImputer(strategy='mean')
self.feature_scaler = None
self.target_scaler = None
self.equipment_type = None
self.feature_analyzer = FeatureAnalysis()
def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
"""
训练模型并返回评估结果
"""
try:
self.equipment_type = equipment_type
logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
results = {}
best_score = -float('inf')
best_model_info = None
# 首先训练 PLS 模型
logger.info("Training pls...")
pls_model = self.models['pls']
pls_model.fit(X_train, y_train)
pls_metrics = self._calculate_metrics(
pls_model,
X_train, y_train,
X_val, y_val
)
results['pls'] = pls_metrics
# 训练其他机器学习模型
for model_name in model_names:
if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
continue
if model_name not in self.models:
logger.warning(f"Unknown model: {model_name}")
continue
logger.info(f"Training {model_name}...")
model = self.models[model_name]
# 训练模型
model.fit(X_train, y_train)
# 计算评估指标
metrics = self._calculate_metrics(
model,
X_train, y_train,
X_val, y_val
)
results[model_name] = metrics
# 更新最佳模型(只在机器学习模型中比较)
if metrics['validation']['r2'] > best_score:
best_score = metrics['validation']['r2']
best_model_info = {
'type': model_name,
'r2': metrics['validation']['r2'],
'mae': metrics['validation']['mae'],
'rmse': metrics['validation']['rmse']
}
self.best_model = model
# 保存最佳模型和 PLS 模型
if equipment_type and best_model_info:
self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
return {
'metrics': results,
'best_model': best_model_info
}
except Exception as e:
logger.error(f"Error in model training: {str(e)}")
raise
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
"""
计算模型评估指标
"""
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
# 训练集评估
train_pred = model.predict(X_train)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
else:
y_train_orig = y_train
# 记录预测范围
logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
train_metrics = {
'r2': r2_score(y_train_orig, train_pred),
'mae': mean_absolute_error(y_train_orig, train_pred),
'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
}
# 验证集评估
if X_val is not None and y_val is not None:
val_pred = model.predict(X_val)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
else:
y_val_orig = y_val
# 记录预测范围
logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
val_metrics = {
'r2': r2_score(y_val_orig, val_pred),
'mae': mean_absolute_error(y_val_orig, val_pred),
'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
}
else:
# 使用交叉验证
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
val_metrics = {
'r2': cv_scores.mean(),
'mae': None,
'rmse': None
}
return {
'train': train_metrics,
'validation': val_metrics
}
def _create_xgboost_model(self):
"""
创建 XGBoost 模型增强正则化
"""
return xgb.XGBRegressor(
n_estimators=50, # 减少树的数量
learning_rate=0.05, # 学习率
max_depth=3, # 减小树的深
min_child_weight=3, # 增加节点权重
subsample=0.7, # 减小样本采样比例
colsample_bytree=0.7, # 减小特征采样比例
reg_alpha=0.1, # L1 正则化
reg_lambda=1, # L2 正则化
random_state=42
)
def _create_lightgbm_model(self):
"""
创建 LightGBM 模型增强正则化
"""
return lgb.LGBMRegressor(
n_estimators=50,
learning_rate=0.05,
max_depth=3,
num_leaves=7,
min_data_in_leaf=3,
min_sum_hessian_in_leaf=1e-3,
subsample=0.7,
colsample_bytree=0.7,
reg_alpha=0.1,
reg_lambda=1,
random_state=42,
verbose=-1
)
def _create_gbm_model(self):
"""
创建 GBM 模型增强正则化以减轻过拟合
"""
return GradientBoostingRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42,
subsample=0.8,
min_samples_split=3,
min_samples_leaf=2
)
def _create_rf_model(self):
"""
创建随机森林模型针对小样本数据调整参数
"""
return RandomForestRegressor(
n_estimators=100,
max_depth=3,
random_state=42,
min_samples_split=3,
min_samples_leaf=2
)
def _create_pls_model(self):
"""
创建 PLS 模型优化参数配置
"""
return PLSRegression(
n_components=2, # 减少主成分数量从5减到2
scale=True, # 保持数据标准化
max_iter=500, # 减少最大迭代次数,避免过拟合
tol=1e-6 # 降低收敛精度,避免过拟合
)
def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
"""
保存最佳模型和 PLS 模型
"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 1. 保存最佳机器学习模型
model_path = f'{model_dir}/{equipment_type}_{timestamp}'
if isinstance(self.best_model, xgb.XGBRegressor):
self.best_model.save_model(f'{model_path}.json')
model_format = 'json'
else:
joblib.dump(self.best_model, f'{model_path}.joblib')
model_format = 'joblib'
# 2. 保存 PLS 模型
pls_model = self.models['pls']
pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
joblib.dump(pls_model, pls_path)
# 3. 保存标准化器
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
joblib.dump({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
logger.info(f"Saved best model to {model_path}.{model_format}")
logger.info(f"Saved PLS model to {pls_path}")
logger.info(f"Saved scalers to {scaler_path}")
# 4. 更新数据库中的模型记录
with get_db_connection() as conn:
cursor = conn.cursor()
# 将所有模型设置为非激活
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s
""", (equipment_type,))
# 获取 PLS 模型的评估指标
pls_metrics = self._calculate_metrics(
self.models['pls'],
X_train,
y_train,
X_val,
y_val
)
# 保存最佳机器学习模型记录
self.best_model.equipment_type = equipment_type # 设置装备类型
ml_feature_importance = self._get_feature_importance(self.best_model)
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}", # model_name
best_model_info['type'], # model_type
equipment_type, # equipment_type
f"{model_path}.{model_format}", # model_path
scaler_path, # scaler_path
best_model_info['r2'], # r2_score
best_model_info['mae'], # mae
best_model_info['rmse'], # rmse
json.dumps(ml_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
))
# 保存 PLS 模型记录
pls_model.equipment_type = equipment_type # 设置装备类型
pls_feature_importance = self._get_feature_importance(pls_model)
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}_pls", # model_name
'pls', # model_type
equipment_type, # equipment_type
pls_path, # model_path
scaler_path, # scaler_path
float(pls_metrics['validation']['r2']), # r2_score
float(pls_metrics['validation']['mae']), # mae
float(pls_metrics['validation']['rmse']), # rmse
json.dumps(pls_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
))
conn.commit()
except Exception as e:
logger.error(f"Error saving models: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
def load_model(self, equipment_type, model_type='ml'):
"""
加载已训练的模型
"""
try:
logger.info(f"Loading {model_type} model for {equipment_type}")
# 从数据库获取激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 构建查询语句
if model_type == 'pls':
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type = 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
else:
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type != 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
# 记录查询信息
logger.info(f"Executing query: {query}")
logger.info(f"Query params: {params}")
cursor.execute(query, params)
model_record = cursor.fetchone()
# 记录查询结果
if model_record:
logger.info(f"Found model record: {model_record}")
else:
logger.warning(f"No active model found for type {model_type}")
return False
# 检查文件是否存在
logger.info(f"Checking model file: {model_record['model_path']}")
logger.info(f"Checking scaler file: {model_record['scaler_path']}")
if not os.path.exists(model_record['model_path']):
logger.error(f"Model file not found: {model_record['model_path']}")
raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
if not os.path.exists(model_record['scaler_path']):
logger.error(f"Scaler file not found: {model_record['scaler_path']}")
raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
# 加载模型文件
logger.info(f"Loading model from {model_record['model_path']}")
if model_type == 'pls':
self.best_model = joblib.load(model_record['model_path'])
logger.info("Loaded PLS model")
else:
if model_record['model_type'] == 'xgboost':
self.best_model = xgb.XGBRegressor()
self.best_model.load_model(model_record['model_path'])
logger.info("Loaded XGBoost model")
else:
self.best_model = joblib.load(model_record['model_path'])
logger.info(f"Loaded {model_record['model_type']} model")
# 加载标准化器
logger.info(f"Loading scalers from {model_record['scaler_path']}")
scalers = joblib.load(model_record['scaler_path'])
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
logger.info("Loaded scalers successfully")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
logger.error(f"Detailed traceback:", exc_info=True)
return False
def predict(self, features):
"""
使用加载的模型进行预测
"""
try:
if not self.best_model:
raise ValueError("No model loaded")
if not self.feature_scaler:
raise ValueError("Feature scaler not loaded")
if not self.target_scaler:
raise ValueError("Target scaler not loaded")
logger.info("Starting prediction")
logger.info(f"Input features shape: {features.shape}")
logger.info(f"Input features: \n{features}")
# 处理缺失值
features_filled = np.array(features, dtype=float)
features_filled[np.isnan(features_filled)] = 0
features_filled = np.nan_to_num(features_filled, 0)
logger.info(f"Filled features: \n{features_filled}")
# 标准化特征
X = self.feature_scaler.transform(features_filled)
logger.info(f"Transformed features shape: {X.shape}")
logger.info(f"Transformed features: \n{X}")
# 预测
y_pred_scaled = self.best_model.predict(X)
logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
logger.info(f"Scaled prediction: {y_pred_scaled}")
# <20><>标准化
y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
logger.info(f"Final prediction shape: {y_pred.shape}")
logger.info(f"Final prediction: {y_pred}")
# 记录标准化器的参数
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
return y_pred.ravel()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
def _get_feature_importance(self, model):
"""
获取特征重要性
"""
try:
if not model:
return {}
# 获取特征名称
feature_analyzer = FeatureAnalysis()
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
# 获取特<E58F96><E789B9><EFBFBD>重要性
if hasattr(model, 'feature_importances_'):
importances = model.feature_importances_
elif hasattr(model, 'coef_'):
if len(model.coef_.shape) > 1: # 如果是二维数组
importances = np.abs(model.coef_[0]) # 取第一行
else:
importances = np.abs(model.coef_)
else:
return {}
# 创建特征重要性字典
importance_dict = {}
for name, importance in zip(feature_names, importances):
importance_dict[name] = float(importance) # 确保转换为 Python 标量
# 按重要性降序排序
sorted_dict = dict(sorted(
importance_dict.items(),
key=lambda x: x[1],
reverse=True
))
# 过滤掉重要性为0的特征
return {k: v for k, v in sorted_dict.items() if v > 0}
except Exception as e:
logger.error(f"Error getting feature importance: {str(e)}")
return {}
def _calculate_confidence_interval(self, prediction, confidence=0.95):
"""
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
lower = max(1000, interval[0]) # 最小值设为1000元
upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
return [lower, upper]
except Exception as e:
logger.error(f"Error calculating confidence interval: {str(e)}")
# 如果计算失败返回基于20%的简单区间
lower = max(1000, prediction * 0.8)
upper = prediction * 1.2
return [lower, upper]
def get_model_type(self):
"""
获取当前模型的类型
"""
if isinstance(self.best_model, xgb.XGBRegressor):
return 'xgboost'
elif isinstance(self.best_model, lgb.LGBMRegressor):
return 'lightgbm'
elif isinstance(self.best_model, GradientBoostingRegressor):
return 'gbm'
elif isinstance(self.best_model, RandomForestRegressor):
return 'rf'
else:
return 'unknown'
def _get_pls_feature_importance(self):
"""
获取 PLS 模型的特征重要性
"""
try:
if not self.models['pls']:
return {}
# 获取特征名称
feature_analyzer = FeatureAnalysis()
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
# 获取 PLS 模型的系数作为特征重要性
pls_model = self.models['pls']
if hasattr(pls_model, 'coef_'):
# 使用绝对值作为重要性指标
importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
else:
return {}
# 创建特征重要性字典
importance_dict = {}
for name, importance in zip(feature_names, importances):
importance_dict[name] = float(importance) # 确保转换为 Python 标量
# 按重要性降序排序
sorted_dict = dict(sorted(
importance_dict.items(),
key=lambda x: x[1],
reverse=True
))
# 过滤掉重要性为0的特征
return {k: v for k, v in sorted_dict.items() if v > 0}
except Exception as e:
logger.error(f"Error getting PLS feature importance: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
return {}

File diff suppressed because it is too large Load Diff

View File

@ -1,140 +0,0 @@
-- 如果数据库已存在则删除
DROP DATABASE IF EXISTS equipment_cost_db;
-- 创建数据库
CREATE DATABASE equipment_cost_db
DEFAULT CHARACTER SET utf8mb4
COLLATE utf8mb4_unicode_ci;
-- 使用数据库
USE equipment_cost_db;
-- 装备基本信息表
CREATE TABLE equipment (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- 名称
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
manufacturer VARCHAR(100), -- 制造商
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 通用参数表
CREATE TABLE common_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
length_m FLOAT, -- 总长(m)
width_m FLOAT, -- 宽度(m)
height_m FLOAT, -- 高度(m)
weight_kg FLOAT, -- 重量(kg)
max_range_km FLOAT, -- 最大射程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 火箭炮特有参数表
CREATE TABLE rocket_artillery_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
firing_angle_horizontal FLOAT, -- 方向射界(度)
firing_angle_vertical FLOAT, -- 高低射界(度)
rocket_length_m FLOAT, -- 火箭弹长度(m)
rocket_diameter_mm FLOAT, -- 弹体直径(mm)
rocket_weight_kg FLOAT, -- 火箭弹重量(kg)
rate_of_fire FLOAT, -- 射速(发/分钟)
combat_weight_kg FLOAT, -- 战斗重量(kg)
speed_kmh FLOAT, -- 速度(km/h)
min_range_km FLOAT, -- 最小射程(km)
mobility_type VARCHAR(50), -- 行走方式
structure_layout VARCHAR(100), -- 结构布局
engine_model VARCHAR(100), -- 发动机型号
engine_params TEXT, -- 发动机参数
power_hp FLOAT, -- 功率(hp)
travel_range_km FLOAT, -- 行程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 巡飞弹特有参数表
CREATE TABLE loitering_munition_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
wingspan_m FLOAT, -- 翼展(m)
warhead_weight_kg FLOAT, -- 战斗部重量(kg)
max_speed_ms FLOAT, -- 最大速度(m/s)
cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
flight_time_min FLOAT, -- 巡飞时间(min)
warhead_type VARCHAR(50), -- 战斗部类型
launch_mode VARCHAR(50), -- 发射方式
folded_length_mm FLOAT, -- 折叠长度(mm)
folded_width_mm FLOAT, -- 折叠宽度(mm)
folded_height_mm FLOAT, -- 折叠高度(mm)
power_system VARCHAR(100), -- 动力装置
guidance_system VARCHAR(100), -- 制导体制
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 成本数据表
CREATE TABLE cost_data (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
actual_cost DECIMAL(15,2), -- 实际成本(元)
predicted_cost DECIMAL(15,2), -- 预测成本(元)
prediction_date TIMESTAMP, -- 预测日期
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 特殊参数表
CREATE TABLE custom_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
param_name VARCHAR(100), -- 参数名称
param_value VARCHAR(255), -- 参数值
param_unit VARCHAR(50), -- 参数单位
description TEXT, -- 参数说明
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- 数据集表
CREATE TABLE datasets (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL, -- 数据集名称
description TEXT, -- 数据集描述
equipment_type VARCHAR(50) NOT NULL, -- 装备类型
purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
-- 数据集-装备关联表
CREATE TABLE dataset_equipment (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
);
-- 训练模型表
CREATE TABLE trained_models (
id INT AUTO_INCREMENT PRIMARY KEY,
model_name VARCHAR(100) NOT NULL, -- 模型名称
model_type VARCHAR(50) NOT NULL, -- 模型类型
equipment_type VARCHAR(50) NOT NULL, -- 装备类型
model_path VARCHAR(255) NOT NULL, -- 模型文件路径
scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径
r2_score FLOAT, -- R²分数
mae FLOAT, -- 平均绝对误差
rmse FLOAT, -- 均方根误差
feature_importance JSON, -- 特征重要性
training_data_size INT, -- 训练数据量
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
created_by VARCHAR(50) -- 创建者
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);

View File

@ -1,191 +0,0 @@
import requests
import json
import logging
from datetime import datetime
import os
import sys
# 添加项目根目录到 Python 路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 配置基本日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('logs/test_api.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def test_api_endpoints():
"""
测试 API 各个端点
"""
base_url = 'http://localhost:5001/api'
try:
# 1. 测试根路由
logger.info("\n1. 测试 API 根路由")
response = requests.get(f'{base_url}/')
print_response(response, "API 根路由")
# 2. 测试机器学习预测接口
logger.info("\n2. 测试机器学习预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict',
json=predict_data
)
print_response(response, "机器学习预测")
# 3. 测试 PLS 预测接口
logger.info("\n3. 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print_response(response, "PLS 预测")
# 4. 测试特征分析接口
logger.info("\n4. 测试特征分析接口")
analysis_data = {
"dataset_id": 1, # 假设数据集 ID 为 1
"equipment_type": "巡飞弹"
}
response = requests.post(
f'{base_url}/analyze-features',
json=analysis_data
)
print_response(response, "特征分析")
# 5. 测试机器学习模型训练接口
logger.info("\n5. 测试机器学习模型训练接口")
training_data = {
"type": "巡飞弹",
"train_dataset_id": 1, # 训练数据集 ID
"validation_dataset_id": 2, # 验证数据集 ID
"models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型
}
response = requests.post(
f'{base_url}/train',
json=training_data
)
print_response(response, "模型训练")
# 6. 测试数据集相关接口
logger.info("\n6. 测试数据集相关接口")
# 6.1 获取数据集列表
response = requests.get(f'{base_url}/datasets')
print_response(response, "获取数据集列表")
# 6.2 获取可用的装备列表
response = requests.get(f'{base_url}/data')
equipment_data = response.json()
# 获取巡飞弹类型的装备ID
available_equipment_ids = []
if 'loitering_munition' in equipment_data:
available_equipment_ids = [
item['id']
for item in equipment_data['loitering_munition']
if item['id'] is not None
][:3] # 取前3个可用的ID
if not available_equipment_ids:
logger.warning("没有找到可用的装备ID跳过创建数据集测试")
else:
# 6.3 创建新数据集
new_dataset = {
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": available_equipment_ids
}
logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}")
response = requests.post(
f'{base_url}/datasets',
json=new_dataset
)
print_response(response, "创建数据集")
# 7. 测试模型相关接口
logger.info("\n7. 测试模型相关接口")
# 7.1 获取模型列表
response = requests.get(f'{base_url}/models')
print_response(response, "获取模型列表")
# 7.2 获取最新模型
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 8. 测试多模型预测接口
logger.info("\n8. 测试多模型预测接口")
response = requests.post(
f'{base_url}/predict/all',
json=predict_data
)
print_response(response, "多模型预测")
logger.info("所有测试完成")
except requests.exceptions.RequestException as e:
logger.error(f"API 请求错误: {str(e)}")
except Exception as e:
logger.error(f"测试过程中出现错误: {str(e)}")
def print_response(response, endpoint_name):
"""
打印响应结果
"""
try:
logger.info(f"\n=== {endpoint_name} 测试结果 ===")
logger.info(f"状态码: {response.status_code}")
if response.status_code == 200:
result = response.json()
logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}")
else:
logger.error(f"错误响应:\n{response.text}")
except Exception as e:
logger.error(f"处理响应时出错: {str(e)}")
if __name__ == "__main__":
try:
# 确保日志目录存在
os.makedirs('logs', exist_ok=True)
logger.info(f"=== API 测试开始 - {datetime.now()} ===")
test_api_endpoints()
logger.info(f"=== API 测试结束 - {datetime.now()} ===")
except Exception as e:
logger.error(f"测试执行失败: {str(e)}")

View File

@ -1,34 +0,0 @@
# 安装说明
## 1. 系统要求
- Linux 服务器 (推荐 Ubuntu 20.04+)
- Python 3.8+
- MySQL 8.0+
## 2. 安装部署
```bash
# 解压部署包
tar -xzf equipment_cost_prediction.tar.gz
cd equipment_cost_prediction
# 运行安装脚本
bash scripts/install.sh
# 修改配置文件
vim config/.env
# 启动服务
bash scripts/start.sh
```
## 3. 验证部署
```bash
# 检查服务状态
curl http://localhost:5001/api/
# 检查日志
tail -f logs/api.log
```

View File

@ -643,3 +643,11 @@ trainingResult.value = null
- Feature max_speed_ms missing rate: 77.78%
- Feature cruise_speed_kmh missing rate: 61.11%
- Feature flight_time_min missing rate: 33.33%
## 前端特征分析页面未正确显示相关性分析数据(常见问题)
- 确保相关性分析数据的正确格式化和返回
- 添加了详细的日志记录
- 增加了数据验证步骤
- 处理了可能的NaN值
- 确保所有数值都转换为Python原生类型使用float()

View File

@ -11,14 +11,23 @@
<div class="dataset-section">
<el-form :model="analysisForm" label-width="120px">
<el-form-item label="装备类型" required>
<el-select v-model="analysisForm.equipment_type" @change="handleEquipmentTypeChange">
<el-select
v-model="analysisForm.equipment_type"
@change="handleEquipmentTypeChange"
placeholder="请选择装备类型"
>
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<el-form-item label="选择数据集" required>
<el-select v-model="analysisForm.dataset_id" @change="handleDatasetChange">
<el-select
v-model="analysisForm.dataset_id"
@change="handleDatasetChange"
placeholder="请选择数据集"
:disabled="!analysisForm.equipment_type"
>
<el-option
v-for="dataset in availableDatasets"
:key="dataset.id"
@ -51,14 +60,29 @@
<!-- 特征重要性 -->
<h3>特征重要性</h3>
<div class="chart-container">
<div ref="importanceChartRef" style="width: 100%; height: 400px"></div>
<div ref="importanceChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 相关性分析 -->
<h3>相关性分析</h3>
<div class="chart-container">
<div ref="correlationChartRef" style="width: 100%; height: 500px"></div>
<div ref="correlationChartRef" style="width: 100%; height: 800px"></div>
</div>
<!-- 巡飞弹特有的图表 -->
<template v-if="analysisForm.equipment_type === '巡飞弹'">
<!-- 特征工程参数分析 -->
<h3>特征工程参数分析</h3>
<div class="chart-container">
<div ref="newFeatureChartRef" style="width: 100%; height: 600px"></div>
</div>
<!-- 发动机性能分析 -->
<h3>发动机性能与作战参数分析</h3>
<div class="chart-container">
<div ref="engineChartRef" style="width: 100%; height: 600px"></div>
</div>
</template>
</div>
</el-card>
</div>
@ -82,6 +106,10 @@ const analyzing = ref(false)
const analysisResult = ref(null)
const importanceChartRef = ref(null)
const correlationChartRef = ref(null)
const newFeatureChartRef = ref(null)
const engineChartRef = ref(null)
const newFeatureChart = ref(null)
const engineChart = ref(null)
//
const importanceChart = ref(null)
@ -103,31 +131,74 @@ watch(() => analysisResult.value, async (newResult) => {
//
const loadDatasets = async (type) => {
try {
// type
if (!type) {
availableDatasets.value = []
return
}
console.log('Loading datasets for type:', type)
const response = await axios.get(`${API_BASE_URL}/datasets`, {
params: { equipment_type: type, purpose: '训练' }
})
//
if (!response.data) {
console.warn('No datasets returned from API')
availableDatasets.value = []
return
}
console.log('Datasets loaded:', response.data)
availableDatasets.value = response.data
} catch (error) {
console.error('Error loading datasets:', error)
ElMessage.error('获取数据集列表失败')
availableDatasets.value = []
}
}
//
const handleEquipmentTypeChange = () => {
console.log('Equipment type changed to:', analysisForm.value.equipment_type)
//
analysisForm.value.dataset_id = null
selectedDataset.value = null
analysisResult.value = null
loadDatasets(analysisForm.value.equipment_type)
//
if (analysisForm.value.equipment_type) {
loadDatasets(analysisForm.value.equipment_type)
} else {
availableDatasets.value = []
}
}
//
const handleDatasetChange = async () => {
try {
//
if (!analysisForm.value.dataset_id) {
selectedDataset.value = null
analysisResult.value = null
return
}
console.log('Dataset changed to:', analysisForm.value.dataset_id)
const response = await axios.get(`${API_BASE_URL}/datasets/${analysisForm.value.dataset_id}`)
//
if (!response.data) {
throw new Error('获取数据集详情失败:服务器返回空数据')
}
selectedDataset.value = response.data
analysisResult.value = null
} catch (error) {
ElMessage.error('获取数据集详情失败')
console.error('Error getting dataset details:', error)
ElMessage.error(error.message || '获取数据集详情失败')
selectedDataset.value = null
}
}
@ -140,14 +211,69 @@ const startAnalysis = async () => {
analyzing.value = true
try {
//
console.log('Analysis request params:', {
dataset_id: analysisForm.value.dataset_id,
equipment_type: analysisForm.value.equipment_type
})
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
//
console.log('Raw API response:', response)
console.log('Response data type:', typeof response.data)
console.log('Response data:', response.data)
//
if (!response.data) {
throw new Error('API返回的数据为空')
}
//
analysisResult.value = response.data
console.log('Analysis completed, result:', analysisResult.value)
//
console.log('Analysis result after assignment:', {
value: analysisResult.value,
important_features: analysisResult.value?.important_features,
correlation_analysis: analysisResult.value?.correlation_analysis,
equipment_names: analysisResult.value?.equipment_names,
length_width_ratio: analysisResult.value?.length_width_ratio
})
//
if (analysisForm.value.equipment_type === '巡飞弹') {
const missileData = {
equipment_names: analysisResult.value?.equipment_names || [],
length_width_ratio: analysisResult.value?.length_width_ratio || [],
engine_power_kw: analysisResult.value?.engine_power_kw || [],
guidance_system_score: analysisResult.value?.guidance_system_score || [],
warhead_power_score: analysisResult.value?.warhead_power_score || []
}
console.log('Missile specific data:', missileData)
//
const missingFields = Object.entries(missileData)
.filter(([key, value]) => !Array.isArray(value) || value.length === 0)
.map(([key]) => key)
if (missingFields.length > 0) {
console.warn('Missing or empty missile data fields:', missingFields)
ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
}
}
} catch (error) {
ElMessage.error('特征分析失败')
console.error('Analysis error:', error)
console.error('Error details:', {
message: error.message,
response: error.response?.data,
status: error.response?.status
})
ElMessage.error(error.message || '特征析失败')
} finally {
analyzing.value = false
}
@ -166,6 +292,12 @@ const createResizeHandler = () => {
if (correlationChart.value && !correlationChart.value.isDisposed()) {
correlationChart.value.resize()
}
if (newFeatureChart.value && !newFeatureChart.value.isDisposed()) {
newFeatureChart.value.resize()
}
if (engineChart.value && !engineChart.value.isDisposed()) {
engineChart.value.resize()
}
} catch (error) {
console.error('Error in resize handler:', error)
}
@ -180,6 +312,11 @@ onMounted(() => {
// resize
resizeHandler.value = createResizeHandler()
window.addEventListener('resize', resizeHandler.value)
//
if (analysisForm.value.equipment_type) {
loadDatasets(analysisForm.value.equipment_type)
}
})
//
@ -190,131 +327,353 @@ onUnmounted(() => {
resizeHandler.value = null
}
//
try {
if (importanceChart.value) {
importanceChart.value.dispose()
importanceChart.value = null
//
[importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
try {
chart.value.dispose()
} catch (e) {
console.error('Error disposing chart:', e)
}
chart.value = null
}
if (correlationChart.value) {
correlationChart.value.dispose()
correlationChart.value = null
}
} catch (error) {
console.error('Error disposing charts:', error)
}
})
})
//
const renderCharts = () => {
console.log('Starting to render charts')
//
if (!analysisResult.value ||
!analysisResult.value.important_features ||
!analysisResult.value.correlation_analysis) {
console.log('Analysis result not ready')
if (!analysisResult.value) {
console.error('No analysis result available')
return
}
// DOM
if (!importanceChartRef.value || !correlationChartRef.value) {
console.log('Chart DOM elements not ready')
return
}
try {
//
if (importanceChart.value) {
importanceChart.value.dispose()
importanceChart.value = null
}
if (correlationChart.value) {
correlationChart.value.dispose()
correlationChart.value = null
}
//
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
//
const importanceOption = {
title: { text: '特征重要性排序' },
tooltip: {},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
name: '重要性',
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance)
}]
}
const correlationOption = {
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
formatter: function (params) {
const value = params.data[2].toFixed(2)
const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
return `${feature1}${feature2} 的相关性: ${value}`
}
},
grid: {
height: '50%',
top: '10%'
},
xAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true },
axisLabel: {
interval: 0,
rotate: 45
}
},
yAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true }
},
visualMap: {
min: -1,
max: 1,
calculable: true,
orient: 'horizontal',
left: 'center',
bottom: '15%',
color: ['#cc3333', '#eeeeee', '#00007f']
},
series: [{
name: '相关性',
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
label: {
show: true,
formatter: function(params) {
return params.data[2].toFixed(2)
//
[importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => {
if (chart.value && !chart.value.isDisposed()) {
chart.value.dispose()
chart.value = null
}
})
// DOM
nextTick(() => {
try {
//
if (importanceChartRef.value && correlationChartRef.value) {
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
//
const importanceOption = {
title: { text: '特征重要性排序' },
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'shadow'
},
formatter: function(params) {
const data = params[0]
return `${data.name}: ${data.value.toFixed(4)}`
}
},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
name: '重要性',
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance),
itemStyle: {
color: '#3a5fcd'
}
}]
}
const correlationOption = {
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
trigger: 'item',
formatter: function (params) {
if (!params.data) return ''
const value = params.data[2].toFixed(2)
const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
return `${feature1}${feature2}<br/>相关性: ${value}`
}
},
grid: {
height: '75%',
top: '10%',
bottom: '15%',
left: '10%',
right: '10%',
containLabel: true
},
xAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true },
axisLabel: {
interval: 0,
rotate: 45,
margin: 15
}
},
yAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true },
axisLabel: {
interval: 0,
margin: 15
}
},
visualMap: {
min: -1,
max: 1,
calculable: true,
orient: 'horizontal',
left: 'center',
bottom: '5%',
inRange: {
color: ['#cc3333', '#eeeeee', '#00007f']
}
},
series: [{
name: '相关性',
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
emphasis: {
itemStyle: {
shadowBlur: 10,
shadowColor: 'rgba(0, 0, 0, 0.5)'
}
},
label: {
show: true,
formatter: function(params) {
return params.data[2].toFixed(2)
}
},
itemStyle: {
borderWidth: 1,
borderColor: '#fff'
}
}]
}
// 使 clear
importanceChart.value.clear()
correlationChart.value.clear()
//
importanceChart.value.setOption(importanceOption, { notMerge: true })
correlationChart.value.setOption(correlationOption, { notMerge: true })
}
}]
}
//
importanceChart.value.setOption(importanceOption)
correlationChart.value.setOption(correlationOption)
console.log('Charts rendered successfully')
//
if (analysisForm.value.equipment_type === '巡飞弹' &&
newFeatureChartRef.value &&
engineChartRef.value) {
//
newFeatureChart.value = echarts.init(newFeatureChartRef.value)
engineChart.value = echarts.init(engineChartRef.value)
//
const chartData = {
names: analysisResult.value.equipment_names || [],
lengthWidthRatio: analysisResult.value.length_width_ratio || [],
weightRangeRatio: analysisResult.value.weight_range_ratio || [],
speedWeightRatio: analysisResult.value.speed_weight_ratio || [],
guidanceSystemScore: analysisResult.value.guidance_system_score || [],
warheadPowerScore: analysisResult.value.warhead_power_score || [],
enginePowerKw: analysisResult.value.engine_power_kw || [],
engineThrustN: analysisResult.value.engine_thrust_n || [],
minAltitudeM: analysisResult.value.min_altitude_m || [],
maxAltitudeM: analysisResult.value.max_altitude_m || []
}
//
const newFeatureOption = {
animation: false, //
title: {
text: '特征工程参数分析',
left: 'center'
},
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'cross'
},
formatter: function(params) {
//
const equipmentName = chartData.names[params[0].dataIndex]
let result = `${equipmentName}<br/>`
//
params.forEach(param => {
result += `${param.seriesName}: ${param.value.toFixed(2)}<br/>`
})
return result
}
},
legend: {
top: 30,
data: ['长宽比', '重量射程比', '速度重量比', '制导系统评分', '战斗部威力评分']
},
grid: {
top: 80,
bottom: 50,
containLabel: true
},
xAxis: {
type: 'category',
data: chartData.names
},
yAxis: [
{
type: 'value',
name: '比率',
position: 'left'
},
{
type: 'value',
name: '评分',
position: 'right',
min: 0,
max: 10
}
],
series: [
{
name: '长宽比',
type: 'line',
data: chartData.lengthWidthRatio
},
{
name: '重量射程比',
type: 'line',
data: chartData.weightRangeRatio
},
{
name: '速度重量比',
type: 'line',
data: chartData.speedWeightRatio
},
{
name: '制导系统评分',
type: 'bar',
yAxisIndex: 1,
data: chartData.guidanceSystemScore
},
{
name: '战斗部威力评分',
type: 'bar',
yAxisIndex: 1,
data: chartData.warheadPowerScore
}
]
}
//
const engineOption = {
animation: false, //
title: {
text: '发动机性能与作战参数分析',
left: 'center'
},
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'cross'
},
formatter: function(params) {
//
const equipmentName = chartData.names[params[0].dataIndex]
let result = `${equipmentName}<br/>`
//
params.forEach(param => {
result += `${param.seriesName}: ${param.value.toFixed(2)}<br/>`
})
return result
}
},
legend: {
top: 30,
data: ['发动机功率(kw)', '发动机推力(N)', '最小作战高度(m)', '最大作战高度(m)']
},
grid: {
top: 80,
bottom: 50,
containLabel: true
},
xAxis: {
type: 'category',
data: chartData.names
},
yAxis: [
{
type: 'value',
name: '功率/推力',
position: 'left'
},
{
type: 'value',
name: '高度(m)',
position: 'right'
}
],
series: [
{
name: '发动机功率(kw)',
type: 'bar',
data: chartData.enginePowerKw
},
{
name: '发动机推力(N)',
type: 'bar',
data: chartData.engineThrustN
},
{
name: '最小作战高度(m)',
type: 'line',
yAxisIndex: 1,
data: chartData.minAltitudeM
},
{
name: '最大作战高度(m)',
type: 'line',
yAxisIndex: 1,
data: chartData.maxAltitudeM
}
]
}
//
newFeatureChart.value.clear()
engineChart.value.clear()
//
newFeatureChart.value.setOption(newFeatureOption, { notMerge: true })
engineChart.value.setOption(engineOption, { notMerge: true })
}
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error in chart rendering:', error)
}
})
} catch (error) {
console.error('Error rendering charts:', error)
console.error('Error in renderCharts:', error)
}
}

View File

@ -85,23 +85,24 @@
<!-- 数据表格 -->
<el-table :data="filteredMissileData" border style="width: 100%">
<el-table-column prop="name" label="名称" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造" sortable></el-table-column>
<el-table-column prop="length_m" label="弹长(m)" sortable></el-table-column>
<el-table-column prop="weight_kg" label="重量(kg)" sortable></el-table-column>
<el-table-column prop="max_range_km" label="最大射程(km)" sortable></el-table-column>
<el-table-column prop="max_speed_ms" label="最大速度(m/s)" sortable></el-table-column>
<el-table-column prop="flight_time_min" label="巡飞时间(min)" sortable></el-table-column>
<el-table-column prop="endurance_min" label="续航时间(min)" sortable></el-table-column>
<el-table-column prop="warhead_weight_kg" label="战斗部重量(kg)" sortable></el-table-column>
<el-table-column prop="guidance_system" label="制导系统"></el-table-column>
<el-table-column prop="actual_cost" label="成本(元)" sortable>
<template #default="scope">
{{ formatMoney(scope.row.actual_cost) }}
</template>
</el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDetails(scope.row)">
详情
</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">
编辑
</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">
删除
</el-button>
<el-button size="small" @click="viewDetails(scope.row)">详情</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">编辑</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">删除</el-button>
</template>
</el-table-column>
</el-table>
@ -162,7 +163,7 @@
<el-descriptions-item label="最大速度(m/s)">{{ formatNumber(selectedData?.max_speed_ms) }}</el-descriptions-item>
<el-descriptions-item label="巡航速度(km/h)">{{ formatNumber(selectedData?.cruise_speed_kmh) }}</el-descriptions-item>
<el-descriptions-item label="巡飞时间(min)">{{ formatNumber(selectedData?.flight_time_min) }}</el-descriptions-item>
<el-descriptions-item label="续航时间(min)">{{ formatNumber(selectedData?.endurance_min) }}</el-descriptions-item>
<el-descriptions-item label="战斗部类型">{{ selectedData?.warhead_type }}</el-descriptions-item>
<el-descriptions-item label="发射方式">{{ selectedData?.launch_mode }}</el-descriptions-item>
<el-descriptions-item label="动力装置">{{ selectedData?.power_system }}</el-descriptions-item>
@ -291,8 +292,8 @@
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="editForm.cruise_speed_kmh" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="editForm.flight_time_min" :precision="0"></el-input-number>
<el-form-item label="续航时间(min)">
<el-input-number v-model="editForm.endurance_min" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="editForm.warhead_type">

View File

@ -55,36 +55,56 @@
<!-- 巡飞弹特有参数 -->
<template v-if="formData.type === '巡飞弹'">
<el-form-item label="最大速度(km/h)">
<el-input-number v-model="formData.max_speed_kmh"></el-input-number>
<el-form-item label="翼展(m)">
<el-input-number v-model="formData.wingspan_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="战斗部重量(kg)">
<el-input-number v-model="formData.warhead_weight_kg" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="最大速度(m/s)">
<el-input-number v-model="formData.max_speed_ms" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="formData.cruise_speed_kmh"></el-input-number>
<el-input-number v-model="formData.cruise_speed_kmh" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="formData.flight_time_min"></el-input-number>
<el-form-item label="续航时间(min)">
<el-input-number v-model="formData.endurance_min" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="formData.warhead_type">
<el-option label="破片杀伤战斗部" value="破片杀伤战斗部"></el-option>
<el-option label="破甲战斗部" value="破甲战斗部"></el-option>
<el-option label="高爆战斗部" value="高爆战斗部"></el-option>
<el-option label="破片杀伤/破甲双用战斗部" value="破片杀伤/破甲双用战斗部"></el-option>
<el-option label="模块化战斗部" value="模块化战斗部"></el-option>
</el-select>
</el-form-item>
<el-form-item label="发射方式">
<el-select v-model="formData.launch_mode">
<el-option label="箱式发射" value="箱式发射"></el-option>
<el-option label="凭自身动力起飞" value="凭自身动力起飞"></el-option>
<el-option label="弹射式发射" value="弹射式发射"></el-option>
<el-option label="垂直起降" value="垂直起降"></el-option>
<el-option label="单兵发射管" value="单兵发射管"></el-option>
<el-option label="箱式发射/弹射式" value="箱式发射/弹射式"></el-option>
<el-option label="箱式发射/空中发射" value="箱式发射/空中发射"></el-option>
</el-select>
</el-form-item>
<el-form-item label="折叠长度(mm)">
<el-input-number v-model="formData.folded_length_mm"></el-input-number>
<el-form-item label="动力装置">
<el-select v-model="formData.power_system">
<el-option label="电动机" value="电动机"></el-option>
<el-option label="活塞发动机" value="活塞发动机"></el-option>
</el-select>
</el-form-item>
<el-form-item label="折叠宽度(mm)">
<el-input-number v-model="formData.folded_width_mm"></el-input-number>
</el-form-item>
<el-form-item label="折叠高度(mm)">
<el-input-number v-model="formData.folded_height_mm"></el-input-number>
<el-form-item label="制导系统">
<el-select v-model="formData.guidance_system">
<el-option label="GPS/INS" value="GPS/INS"></el-option>
<el-option label="GPS/INS/光电" value="GPS/INS/光电"></el-option>
<el-option label="GPS/INS/光电/数据链" value="GPS/INS/光电/数据链"></el-option>
<el-option label="GPS/INS/光电/AI识别" value="GPS/INS/光电/AI识别"></el-option>
<el-option label="GPS/INS/光电/数据链/AI辅助" value="GPS/INS/光电/数据链/AI辅助"></el-option>
<el-option label="GPS/INS/光电/数据链/AI辅助/红外" value="GPS/INS/光电/数据链/AI辅助/红外"></el-option>
<el-option label="GPS/INS/光电/数据链/AI辅助/卫通" value="GPS/INS/光电/数据链/AI辅助/卫通"></el-option>
</el-select>
</el-form-item>
</template>
@ -154,7 +174,16 @@ const formData = reactive({
width_m: null,
height_m: null,
weight_kg: null,
max_range_km: null
max_range_km: null,
wingspan_m: null,
warhead_weight_kg: null,
max_speed_ms: null,
cruise_speed_kmh: null,
endurance_min: null,
warhead_type: '',
launch_mode: '',
power_system: '',
guidance_system: ''
})
const predictionResults = ref(null)
@ -171,14 +200,15 @@ const handleTypeChange = () => {
formData.rocket_weight_kg = null
formData.rate_of_fire = null
} else if (formData.type === '巡飞弹') {
formData.max_speed_kmh = null
formData.wingspan_m = null
formData.warhead_weight_kg = null
formData.max_speed_ms = null
formData.cruise_speed_kmh = null
formData.flight_time_min = null
formData.endurance_min = null
formData.warhead_type = ''
formData.launch_mode = ''
formData.folded_length_mm = null
formData.folded_width_mm = null
formData.folded_height_mm = null
formData.power_system = ''
formData.guidance_system = ''
}
}
@ -210,8 +240,8 @@ const submitForm = async () => {
}
} else if (formData.type === '巡飞弹') {
const missileFields = [
'max_speed_kmh', 'cruise_speed_kmh', 'flight_time_min',
'folded_length_mm', 'folded_width_mm', 'folded_height_mm'
'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
'endurance_min', 'warhead_type', 'launch_mode', 'power_system', 'guidance_system'
]
for (const field of missileFields) {
if (!formData[field]) {

View File

@ -1,269 +1,192 @@
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
import logging
from .logger import setup_logger
logger = setup_logger(__name__)
logger = logging.getLogger(__name__)
class FeatureAnalysis:
def __init__(self):
self.scaler = StandardScaler()
self.important_features = []
# 添加特征名称映射
self.feature_names_map = {
# 通用参数
'length_m': '总长(m)',
self.feature_name_map = {
'length_m': '长度(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
# 火箭炮衍生特征
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
# 巡飞弹特有参数
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'warhead_weight_kg': '弹头重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
# 巡飞弹衍生特征
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比',
'volume_density': '体积密度'
'endurance_min': '续航时间(min)',
'payload_weight_kg': '载荷重量(kg)',
'min_combat_radius_km': '最小作战半径(km)',
'engine_power_kw': '发动机功率(kw)',
'engine_thrust_n': '发动机推力(N)',
'datalink_range_km': '数据链距离(km)',
'guidance_accuracy_m': '制导精度(m)',
'min_altitude_m': '最小飞行高度(m)',
'max_altitude_m': '最大飞行高度(m)',
'length_width_ratio': '长宽比',
'weight_range_ratio': '重量射程比',
'speed_weight_ratio': '速度重量比',
'guidance_system_score': '制导系统评分',
'warhead_power_score': '战斗部威力评分',
'firing_angle_horizontal': '水平射角(°)',
'firing_angle_vertical': '垂直射角(°)',
'rocket_length_m': '火箭长度(m)',
'rocket_diameter_mm': '火箭直径(mm)',
'rocket_weight_kg': '火箭重量(kg)'
}
def get_equipment_specific_features(self, equipment_type):
"""
获取特定装备类型的特征列表
"""
# 通用参数
"""获取特定装备类型的特征列表"""
common_features = [
'length_m', # 总长(m)
'width_m', # 宽度(m)
'height_m', # 高度(m)
'weight_kg', # 重量(kg)
'max_range_km' # 最大射程(km)
'length_m', 'width_m', 'height_m',
'weight_kg', 'max_range_km'
]
if equipment_type == '火箭炮':
# 火箭炮特有参数
specific_features = [
'firing_angle_horizontal', # 方向射界(度)
'firing_angle_vertical', # 高低射界(度)
'rocket_length_m', # 火箭弹长度(m)
'rocket_diameter_mm', # 口径(mm)
'rocket_weight_kg', # 火箭弹重量(kg)
'rate_of_fire', # 射速(发/分)
'combat_weight_kg', # 战斗重量(kg)
'speed_kmh', # 速度(km/h)
'min_range_km', # 最小射程(km)
'power_hp' # 功率(hp)
return common_features + [
'firing_angle_horizontal',
'firing_angle_vertical',
'rocket_length_m',
'rocket_diameter_mm',
'rocket_weight_kg'
]
# 火箭炮衍生特征
derived_features = [
'fire_density', # 火力密度 = 射速 * 火箭弹重量
'mobility_index', # 机动性指标 = 速度 / 战斗重量
'range_ratio', # 射程比 = 最大射程 / 最小射程
'power_weight_ratio', # 功重比 = 功率 / 战斗重量
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
elif equipment_type == '巡飞弹':
return common_features + [
'wingspan_m',
'warhead_weight_kg',
'max_speed_ms',
'cruise_speed_kmh',
'endurance_min',
'payload_weight_kg',
'min_combat_radius_km',
'engine_power_kw',
'engine_thrust_n',
'datalink_range_km',
'guidance_accuracy_m',
'min_altitude_m',
'max_altitude_m',
'length_width_ratio',
'weight_range_ratio',
'speed_weight_ratio',
'guidance_system_score',
'warhead_power_score'
]
return common_features + specific_features + derived_features
else: # 巡飞弹
# 巡飞弹特有参数
specific_features = [
'wingspan_m', # 翼展(m)
'warhead_weight_kg', # 战斗部重量(kg)
'max_speed_ms', # 最大速度(m/s)
'cruise_speed_kmh', # 巡航速度(km/h)
'flight_time_min', # 巡飞时间(min)
'folded_length_mm', # 折叠长度(mm)
'folded_width_mm', # 折叠宽度(mm)
'folded_height_mm' # 折叠高度(mm)
]
# 巡飞弹衍生特征
derived_features = [
'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
'speed_ratio', # 速度比 = 巡航速度 / 最大速度
'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
if equipment_type == '火箭炮':
# 火箭炮衍生特征计算
if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
else:
data['fire_density'] = 0 # 或者其他默认值
if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
else:
data['mobility_index'] = 0
if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
data['range_ratio'] = data['max_range_km'] / data['min_range_km']
else:
data['range_ratio'] = 0
if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
else:
data['power_weight_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
else: # 巡飞弹
# 巡飞弹衍生特征计算
if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
else:
data['warhead_ratio'] = 0
if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
else:
data['speed_ratio'] = 0
if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
else:
data['range_time_ratio'] = 0
if 'wingspan_m' in data.columns and 'length_m' in data.columns:
data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
else:
data['aspect_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
return data
except Exception as e:
logger.error(f"Error calculating derived features: {str(e)}")
raise
return common_features
def analyze_features(self, features, target, feature_names):
"""
分析特征重要性和相关性
"""
"""分析特征重要性和相关性"""
try:
# 转换为numpy数组
X = np.array(features)
y = np.array(target)
X = np.array(features, dtype=np.float64) # 明确指定数据类型
y = np.array(target, dtype=np.float64)
# 数据标准化
# 打印原始数据的统计信息
logger.info("Feature statistics before scaling:")
for i, name in enumerate(feature_names):
feature_data = X[:, i]
logger.info(f"{self.feature_name_map.get(name, name)}: "
f"min={np.min(feature_data)}, "
f"max={np.max(feature_data)}, "
f"mean={np.mean(feature_data)}, "
f"null_count={np.sum(np.isnan(feature_data))}")
# 处理可能的无穷大和NaN值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 标准化特征
X_scaled = self.scaler.fit_transform(X)
# 特征重要性分析
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_scaled, y)
importances = rf.feature_importances_
selector = SelectKBest(score_func=f_regression, k='all')
selector.fit(X_scaled, y)
importance_scores = selector.scores_
# 按重要性排序,使用中文特征名
importance_indices = np.argsort(importances)[::-1]
important_features = [
{
'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
'importance': float(importances[i])
}
for i in importance_indices
]
# 计算相关性矩阵前检查特征的方差
X = np.array(features, dtype=np.float64)
feature_std = np.std(X, axis=0)
constant_features = []
# 相关性分析
df = pd.DataFrame(X_scaled, columns=feature_names)
correlation_matrix = df.corr().values
# 记录标准差为0的特征
for i, (name, std) in enumerate(zip(feature_names, feature_std)):
if std == 0:
logger.warning(f"Feature '{self.feature_name_map.get(name, name)}' has zero standard deviation "
f"(constant value: {X[0, i]})")
constant_features.append(name)
# 生成相关性分析数据保留2位小数
# 计算相关性矩阵
correlation_matrix = np.corrcoef(X.T)
# 处理相关性矩阵中的无效值
correlation_data = []
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
chinese_feature_names = [self.feature_name_map.get(name, name) for name in feature_names]
for i in range(len(feature_names)):
for j in range(len(feature_names)):
correlation_data.append([
i, j,
round(correlation_matrix[i][j], 2) # 修改为保留2位小数
])
corr_value = correlation_matrix[i, j]
if np.isnan(corr_value):
# 如果是常量特征,设置相关系数
if feature_names[i] in constant_features or feature_names[j] in constant_features:
if i == j:
# 自身相关性设为1
corr_value = 1.0
else:
# 与其他特征的相关性设为0
corr_value = 0.0
logger.info(f"Setting correlation for constant feature: "
f"{chinese_feature_names[i]} vs {chinese_feature_names[j]} = {corr_value}")
correlation_data.append([i, j, float(corr_value)])
return {
# 记录数据形状
logger.info(f"Features shape: {X.shape}")
logger.info(f"Target shape: {y.shape}")
logger.info(f"Correlation matrix shape: {correlation_matrix.shape}")
# 创建特征重要性列表(使用中文名称)
important_features = []
for idx, (name, score) in enumerate(zip(feature_names, importance_scores)):
if not np.isnan(score):
important_features.append({
'name': self.feature_name_map.get(name, name), # 使用中文名称
'importance': float(score)
})
# 按重要性排序
important_features.sort(key=lambda x: x['importance'], reverse=True)
# 返回结果
result = {
'important_features': important_features,
'correlation_analysis': {
'features': chinese_feature_names, # 使用中文特征名
'features': chinese_feature_names, # 使用中文名称
'matrix': correlation_data
}
}
except Exception as e:
logger.error(f"Error in feature analysis: {str(e)}")
raise
def preprocess_features(self, equipment_data, equipment_type):
"""
预处理特征数据
"""
try:
# 转换为 DataFrame
df = pd.DataFrame(equipment_data)
# 添加数据验证
logger.info("Correlation data validation:")
expected_pairs = len(feature_names) * len(feature_names)
actual_pairs = len(correlation_data)
logger.info(f"Expected correlation pairs: {expected_pairs}")
logger.info(f"Actual correlation pairs: {actual_pairs}")
if expected_pairs != actual_pairs:
logger.warning("Missing correlation pairs detected!")
# 计算衍生特征
df = self.calculate_derived_features(df, equipment_type)
# 验证返回的数据
logger.info("Validation of return data:")
logger.info(f"Has important_features: {bool(result['important_features'])}")
logger.info(f"Important features count: {len(result['important_features'])}")
logger.info(f"Has correlation_analysis: {bool(result['correlation_analysis'])}")
logger.info(f"Correlation features count: {len(result['correlation_analysis']['features'])}")
logger.info(f"Correlation matrix size: {len(result['correlation_analysis']['matrix'])}")
# 处理缺失值
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
# 转换为数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
# 使用新的方式填充缺失值
mean_value = df[col].mean()
df[col] = df[col].fillna(mean_value)
logger.info(f"Preprocessed data shape: {df.shape}")
return df
return result
except Exception as e:
logger.error(f"Error preprocessing features: {str(e)}")
raise Exception(f"Feature preprocessing error: {str(e)}")
logger.error(f"Error in analyze_features: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise

485
src/real_data.sql Normal file
View File

@ -0,0 +1,485 @@
-- 清空现有数据
SET FOREIGN_KEY_CHECKS=0;
TRUNCATE TABLE dataset_equipment;
TRUNCATE TABLE datasets;
TRUNCATE TABLE cost_data;
TRUNCATE TABLE loitering_munition_params;
TRUNCATE TABLE common_params;
TRUNCATE TABLE equipment;
SET FOREIGN_KEY_CHECKS=1;
-- 按系列插入装备数据确保ID连续
-- 1. HAROP/Harpy 系列 (ID: 1-3)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(1, 'IAI Harop', '巡飞弹', '以色列'),
(2, 'IAI Harpy', '巡飞弹', '以色列'),
(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
-- 2. Hero 系列 (ID: 4-9)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
(9, 'Hero-900', '巡飞弹', '以色列 UVision');
-- 3. Switchblade 系列 (ID: 10-13)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
-- 4. Warmate 系列 (ID: 14-18)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
-- 5. CH-901/902 系列 (ID: 19-23)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(19, 'CH-901', '巡飞弹', '中国航天科工'),
(20, 'CH-901A', '巡飞弹', '中国航天科工'),
(21, 'CH-901H', '巡飞弹', '中国航天科工'),
(22, 'CH-902', '巡飞弹', '中国航天科工'),
(23, 'CH-902A', '巡飞弹', '中国航天科工');
-- 6. WS-43/61 系列 (ID: 24-28)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(24, 'WS-43', '巡飞弹', '中国航天科工'),
(25, 'WS-43A', '巡飞弹', '中国航天科工'),
(26, 'WS-43B', '巡飞弹', '中国航天科工'),
(27, 'WS-61', '巡飞弹', '中国航天科工'),
(28, 'WS-61A', '巡飞弹', '中国航天科工');
-- 7. Kargu/Alpagu 系列 (ID: 29-33)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(29, 'Kargu', '巡飞弹', '土耳其 STM'),
(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
-- 8. Shahed 系列 (ID: 34-38)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(34, 'Shahed-131', '巡飞弹', '伊朗'),
(35, 'Shahed-131B', '巡飞弹', '伊朗'),
(36, 'Shahed-136', '巡飞弹', '伊朗'),
(37, 'Shahed-136B', '巡飞弹', '伊朗'),
(38, 'Shahed-136C', '巡飞弹', '伊朗');
-- 9. Green Dragon 系列 (ID: 39-43)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
-- 10. Phoenix Ghost 系列 (ID: 44-48)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
-- 11. ZALA Lancet 系列 (ID: 49-52)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
-- 12. Rotem L 系列 (ID: 53-56)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
-- 13. KUB-BLA 系列 (ID: 57-60)
INSERT INTO equipment (id, name, type, manufacturer) VALUES
(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
-- 插入通用参数
INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
(10, 0.58, 0.12, 0.12, 2.5, 10),
(11, 1.30, 0.22, 0.22, 15.0, 40),
(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
(14, 0.68, 0.12, 0.12, 2.5, 10),
(15, 1.30, 0.22, 0.22, 15.0, 40),
(16, 0.68, 0.12, 0.12, 2.5, 10),
(17, 1.30, 0.22, 0.22, 15.0, 40),
(18, 0.68, 0.12, 0.12, 2.5, 10),
(19, 1.2, 0.18, 0.18, 9.0, 20),
(20, 1.2, 0.18, 0.18, 9.3, 25),
(21, 1.2, 0.18, 0.18, 9.5, 20),
(22, 1.4, 0.22, 0.22, 15.0, 30),
(23, 1.4, 0.22, 0.22, 15.5, 35),
(24, 1.8, 0.35, 0.35, 20, 60),
(25, 1.8, 0.35, 0.35, 21, 70),
(26, 1.9, 0.35, 0.35, 22, 80),
(27, 2.2, 0.40, 0.40, 35, 100),
(28, 2.2, 0.40, 0.40, 37, 120),
(29, 0.6, 0.35, 0.35, 7.0, 10),
(30, 0.6, 0.35, 0.35, 7.2, 15),
(31, 1.0, 0.23, 0.23, 3.7, 5),
(32, 1.0, 0.23, 0.23, 3.9, 8),
(33, 0.6, 0.35, 0.35, 7.5, 15),
(34, 2.6, 0.34, 0.34, 135, 900),
(35, 2.6, 0.34, 0.34, 140, 1000),
(36, 3.5, 0.42, 0.42, 200, 2000),
(37, 3.5, 0.42, 0.42, 210, 2200),
(38, 3.5, 0.42, 0.42, 215, 2500),
(39, 1.5, 0.20, 0.20, 15, 40),
(40, 1.6, 0.20, 0.20, 16, 50),
(41, 1.5, 0.20, 0.20, 15.5, 45),
(42, 1.5, 0.20, 0.20, 15.8, 40),
(43, 1.2, 0.18, 0.18, 12, 30),
(44, 1.5, 0.25, 0.25, 14.0, 30),
(45, 1.5, 0.25, 0.25, 14.5, 35),
(46, 1.6, 0.26, 0.26, 15.0, 40),
(47, 1.5, 0.25, 0.25, 14.8, 30),
(48, 1.7, 0.27, 0.27, 16.0, 50),
(49, 1.0, 0.20, 0.20, 5.0, 40),
(50, 1.65, 0.35, 0.35, 12.0, 70),
(51, 1.65, 0.35, 0.35, 12.5, 80),
(52, 1.80, 0.40, 0.40, 15.0, 100),
(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
-- 插入特有参数
INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
endurance_min,
warhead_type,
launch_mode,
power_system,
guidance_system
) VALUES
-- HAROP/Harpy系列
(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
-- Hero系列
(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
-- Switchblade系列
(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
-- Warmate系列
(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
-- CH-901/902系列
(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
-- 插入成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 800000), -- IAI Harop
(2, 700000), -- IAI Harpy
(3, 350000), -- IAI Mini Harpy
(4, 70000), -- Hero-30
(5, 120000), -- Hero-70
(6, 150000), -- Hero-120
(7, 300000), -- Hero-250
(8, 400000), -- Hero-400EC
(9, 650000), -- Hero-900
(10, 60000), -- Switchblade 300
(11, 180000), -- Switchblade 600
(12, 75000), -- Switchblade 300 Block 10
(13, 200000), -- Switchblade 600 Extended Range
(14, 60000), -- Warmate 1.0
(15, 180000), -- Warmate 2.0
(16, 60000), -- Warmate-V
(17, 180000), -- Warmate-L
(18, 60000), -- Warmate 3.0
(19, 100000), -- CH-901
(20, 120000), -- CH-901A
(21, 130000), -- CH-901H
(22, 180000), -- CH-902
(23, 200000), -- CH-902A
(24, 120000), -- WS-43
(25, 150000), -- WS-43A
(26, 180000), -- WS-43B
(27, 300000), -- WS-61
(28, 350000), -- WS-61A
(29, 70000), -- Kargu
(30, 85000), -- Kargu-2
(31, 45000), -- Alpagu
(32, 55000), -- Alpagu Block-II
(33, 95000), -- Kargu Autonomous
(34, 20000), -- Shahed-131
(35, 25000), -- Shahed-131B
(36, 40000), -- Shahed-136
(37, 45000), -- Shahed-136B
(38, 50000), -- Shahed-136C
(39, 160000), -- Green Dragon
(40, 200000), -- Green Dragon Extended Range
(41, 180000), -- Green Dragon Block 2
(42, 190000), -- Green Dragon Maritime
(43, 140000), -- Green Dragon-S
(44, 150000), -- Phoenix Ghost
(45, 180000), -- Phoenix Ghost Block I
(46, 220000), -- Phoenix Ghost Block II
(47, 190000), -- Phoenix Ghost Maritime
(48, 250000), -- Phoenix Ghost-ER
(49, 80000), -- Lancet-1
(50, 150000), -- Lancet-3
(51, 180000), -- Lancet-3M
(52, 250000), -- Lancet-4
(53, 65000), -- Rotem L
(54, 85000), -- Rotem L-X
(55, 75000), -- Rotem L-M
(56, 95000), -- Rotem L-ER
(57, 95000), -- KUB-BLA
(58, 120000), -- KUB-BLA-E
(59, 110000), -- KUB-BLA-M
(60, 150000); -- KUB-BLA-ER
-- 创建数据集
INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
-- 关联装备到数据集(按照制造商和型号分配)
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 训练集约80%的数据48个型号
-- 以色列系列
(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
-- 美国系列
(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
-- 中国系列
(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
-- 波兰和土耳其系列
(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
-- 俄罗斯系列
(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
-- 验证集约20%的数据12个型号
-- 混合系列
(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
-- 添加分类特征编码
INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
-- 战斗部类型编码
('warhead_type', '破片杀伤战斗部', 1),
('warhead_type', '破甲战斗部', 2),
('warhead_type', '高爆战斗部', 3),
('warhead_type', '破片杀伤/破甲双用战斗部', 4),
('warhead_type', '模块化战斗部', 5),
-- 发射方式编码
('launch_mode', '箱式发射', 1),
('launch_mode', '弹射式发射', 2),
('launch_mode', '垂直起降', 3),
('launch_mode', '单兵发射管', 4),
('launch_mode', '箱式发射/弹射式', 5),
('launch_mode', '箱式发射/空中发射', 6),
-- 动力装置编码(按复杂度递增)
('power_system', '电动机', 1),
('power_system', '活塞发动机', 2),
-- 制导系统编码(按复杂度递增)
('guidance_system', 'GPS/INS', 1),
('guidance_system', 'GPS/INS/光电', 2),
('guidance_system', 'GPS/INS/光电/数据链', 3),
('guidance_system', 'GPS/INS/光电/AI识别', 4),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
UPDATE loitering_munition_params l
JOIN common_params c ON l.equipment_id = c.equipment_id
SET
-- 新增关键参数
l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
l.engine_power_kw =
CASE
WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
END,
l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
l.guidance_accuracy_m =
CASE
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
ELSE 3.0
END,
l.min_altitude_m = -- 最小作战高度
CASE
-- 大型巡飞弹(体型大、重量大)
WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
-- 中型巡飞弹
WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
-- 中小型巡飞弹
WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
-- 小型巡飞弹
WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
-- 超小型巡飞弹
WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
-- 其他型号使用默认值
ELSE 50
END,
l.max_altitude_m =
CASE
WHEN c.max_range_km > 500 THEN 5000
WHEN c.max_range_km > 100 THEN 3000
ELSE 1500
END,
-- 特征工程字段
l.length_width_ratio = c.length_m / c.width_m,
l.weight_range_ratio = c.weight_kg / c.max_range_km,
l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
l.guidance_system_score =
CASE
WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
ELSE 2
END,
l.warhead_power_score =
CASE
WHEN l.warhead_type = '模块化战斗部' THEN 10
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
WHEN l.warhead_type = '高爆战斗部' THEN 7
WHEN l.warhead_type = '破甲战斗部' THEN 6
WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
ELSE 4
END,
-- 分类特征编码
l.warhead_type_code =
CASE
WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
WHEN l.warhead_type = '破甲战斗部' THEN 2
WHEN l.warhead_type = '高爆战斗部' THEN 3
WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
WHEN l.warhead_type = '模块化战斗部' THEN 5
ELSE 0
END,
l.launch_mode_code =
CASE
WHEN l.launch_mode = '箱式发射' THEN 1
WHEN l.launch_mode = '弹射式发射' THEN 2
WHEN l.launch_mode = '垂直起降' THEN 3
WHEN l.launch_mode = '单兵发射管' THEN 4
WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
ELSE 0
END,
l.power_system_code =
CASE
WHEN l.power_system = '电动机' THEN 1
WHEN l.power_system = '活塞发动机' THEN 2
ELSE 0
END,
l.guidance_system_code =
CASE
WHEN l.guidance_system = 'GPS/INS' THEN 1
WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
ELSE 0
END;

View File

@ -135,7 +135,6 @@ def analyze_features():
logger.info(f"Dataset info: {dataset}")
# 创建特征分析实例
from src.feature_analysis import FeatureAnalysis
analyzer = FeatureAnalysis()
# 获取特征列表
@ -143,20 +142,46 @@ def analyze_features():
logger.info(f"Feature names: {feature_names}")
# 获取数据集中的装备数据
if dataset['equipment_type'] == '火箭炮':
if dataset['equipment_type'] == '巡飞弹':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
SELECT
e.name,
e.*,
cp.*,
lmp.*,
cd.actual_cost,
lmp.length_width_ratio,
lmp.weight_range_ratio,
lmp.speed_weight_ratio,
lmp.guidance_system_score,
lmp.warhead_power_score,
lmp.engine_power_kw,
lmp.engine_thrust_n,
lmp.min_altitude_m,
lmp.max_altitude_m
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (dataset_id,))
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
SELECT e.name,
e.*, cp.*, lmp.*,
cp.max_range_km,
lmp.length_width_ratio,
lmp.weight_range_ratio,
lmp.speed_weight_ratio,
lmp.guidance_system_score,
lmp.warhead_power_score,
lmp.engine_power_kw,
lmp.engine_thrust_n,
lmp.min_altitude_m,
lmp.max_altitude_m,
cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
@ -173,61 +198,52 @@ def analyze_features():
logger.warning("No valid equipment data found in dataset")
return jsonify({'error': '数据集没有有效的成本数据'}), 400
# 统计每个特征的缺失率
missing_rates = {}
for name in feature_names:
missing_count = sum(1 for item in equipment_data if item.get(name) is None)
missing_rate = missing_count / len(equipment_data)
missing_rates[name] = missing_rate
logger.info(f"Feature {name} missing rate: {missing_rate:.2%}")
# 过滤掉缺失率过高的特征
valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
logger.info(f"Valid features after filtering: {valid_features}")
if len(valid_features) < 3: # 至少需要3个特征
return jsonify({'error': '有效特征数量不足'}), 400
# 计算每个特征的均值
feature_means = {}
for name in valid_features:
values = [float(item[name]) for item in equipment_data if item.get(name) is not None]
feature_means[name] = sum(values) / len(values) if values else 0
logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")
# 准备特征和目标值
features = []
target = []
equipment_names = [] # 新增:存储装备名称
# 提取特征和目标值,使用均值填充缺失值
# 提取特征和目标值
for item in equipment_data:
feature_values = []
for name in valid_features:
equipment_names.append(item['name']) # 保存装备名称
for name in feature_names:
value = item.get(name)
try:
# 确保数值类型转换正确
feature_values.append(float(value) if value is not None else feature_means[name])
feature_values.append(float(value) if value is not None else 0)
except (ValueError, TypeError) as e:
logger.error(f"Error converting value for feature {name}: {value}")
logger.error(f"Error details: {str(e)}")
return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
features.append(feature_values)
# 确保成本值是值类型
try:
target.append(float(item['actual_cost']))
except (ValueError, TypeError) as e:
logger.error(f"Error converting actual_cost: {item['actual_cost']}")
logger.error(f"Error details: {str(e)}")
return jsonify({'error': '成本值无法换为数值'}), 400
logger.info(f"Prepared {len(features)} feature vectors")
logger.info(f"First feature vector: {features[0] if features else None}")
logger.info(f"First target value: {target[0] if target else None}")
features.append(feature_values)
target.append(float(item['actual_cost']))
# 调用特征分析方法
result = analyzer.analyze_features(features, target, valid_features)
logger.info("Analysis completed successfully")
result = analyzer.analyze_features(features, target, feature_names)
# 如果是巡飞弹类型,添加额外的数据
if dataset['equipment_type'] == '巡飞弹':
missile_data = {
'equipment_names': equipment_names,
'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data],
'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data]
}
# 验证数据完整性
for key, value in missile_data.items():
logger.info(f"{key} data length: {len(value)}")
logger.info(f"{key} sample data: {value[:3]}")
# 更新结果
result.update(missile_data)
return jsonify(result)
@ -428,148 +444,39 @@ def get_equipment_data():
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute('SET SESSION group_concat_max_len = 1000000')
# 先测试特殊参数查询
cursor.execute("""
SELECT equipment_id, param_name, param_value, param_unit
FROM custom_params
WHERE param_name IS NOT NULL
AND param_value IS NOT NULL
LIMIT 5
""")
test_params = cursor.fetchall()
logger.info(f"Test custom params: {test_params}")
# 获取火箭炮数据
logger.info("Fetching rocket artillery data...")
cursor.execute("""
SELECT
e.id,
e.name,
e.type,
e.manufacturer,
e.created_at,
cp.length_m,
cp.width_m,
cp.height_m,
cp.weight_kg,
cp.max_range_km,
rap.firing_angle_horizontal,
rap.firing_angle_vertical,
rap.rocket_length_m,
rap.rocket_diameter_mm,
rap.rocket_weight_kg,
rap.rate_of_fire,
rap.combat_weight_kg,
rap.speed_kmh,
rap.min_range_km,
rap.mobility_type,
rap.structure_layout,
rap.engine_model,
rap.engine_params,
rap.power_hp,
rap.travel_range_km,
cd.actual_cost,
(
SELECT COALESCE(
JSON_ARRAYAGG(
JSON_OBJECT(
'id', csp.id,
'param_name', csp.param_name,
'param_value', csp.param_value,
'param_unit', csp.param_unit,
'description', csp.description
)
),
'[]'
)
FROM custom_params csp
WHERE csp.equipment_id = e.id
AND csp.param_name IS NOT NULL
AND csp.param_value IS NOT NULL
) as custom_params
SELECT e.*, cp.*, rap.*, cd.actual_cost, cd.predicted_cost
FROM equipment e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.type = '火箭炮'
""")
rocket_artillery = cursor.fetchall()
logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
if rocket_artillery:
logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")
rocket_data = cursor.fetchall()
logger.info(f"Found {len(rocket_data)} rocket artillery records")
# 获取巡飞弹数据
logger.info("Fetching missile data...")
cursor.execute("""
SELECT
e.id,
e.name,
e.type,
e.manufacturer,
e.created_at,
cp.length_m,
cp.width_m,
cp.height_m,
cp.weight_kg,
cp.max_range_km,
lmp.wingspan_m,
lmp.warhead_weight_kg,
lmp.max_speed_ms,
lmp.cruise_speed_kmh,
lmp.flight_time_min,
lmp.warhead_type,
lmp.launch_mode,
lmp.folded_length_mm,
lmp.folded_width_mm,
lmp.folded_height_mm,
lmp.power_system,
lmp.guidance_system,
cd.actual_cost,
(
SELECT COALESCE(
JSON_ARRAYAGG(
JSON_OBJECT(
'id', csp.id,
'param_name', csp.param_name,
'param_value', csp.param_value,
'param_unit', csp.param_unit,
'description', csp.description
)
),
'[]'
)
FROM custom_params csp
WHERE csp.equipment_id = e.id
AND csp.param_name IS NOT NULL
AND csp.param_value IS NOT NULL
) as custom_params
SELECT e.*, cp.*, lmp.*, cd.actual_cost, cd.predicted_cost,
lmp.wingspan_m, lmp.warhead_weight_kg, lmp.max_speed_ms,
lmp.cruise_speed_kmh, lmp.endurance_min, lmp.warhead_type,
lmp.launch_mode, lmp.power_system, lmp.guidance_system
FROM equipment e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.type = '巡飞弹'
""")
loitering_munition = cursor.fetchall()
logger.info(f"Found {len(loitering_munition)} missile records")
if loitering_munition:
logger.info(f"First missile: {loitering_munition[0]['name']}")
logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")
missile_data = cursor.fetchall()
logger.info(f"Found {len(missile_data)} missile records")
# 处理 custom_params保为 NULL
for item in rocket_artillery + loitering_munition:
if item['custom_params'] is None:
item['custom_params'] = []
logger.debug(f"Set empty custom_params for equipment {item['id']}")
else:
logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params")
logger.info("Data fetching completed")
return jsonify({
'rocket_artillery': rocket_artillery,
'loitering_munition': loitering_munition
'rocket_artillery': rocket_data,
'loitering_munition': missile_data
})
except Exception as e:

View File

@ -60,17 +60,47 @@ CREATE TABLE loitering_munition_params (
warhead_weight_kg FLOAT, -- 战斗部重量(kg)
max_speed_ms FLOAT, -- 最大速度(m/s)
cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
flight_time_min FLOAT, -- 巡飞时间(min)
endurance_min FLOAT, -- 续航时间(min)
warhead_type VARCHAR(50), -- 战斗部类型
launch_mode VARCHAR(50), -- 发射方式
folded_length_mm FLOAT, -- 折叠长度(mm)
folded_width_mm FLOAT, -- 折叠宽度(mm)
folded_height_mm FLOAT, -- 折叠高度(mm)
power_system VARCHAR(100), -- 动力装置
guidance_system VARCHAR(100), -- 制导体制
-- 新增关键参数
payload_weight_kg FLOAT, -- 有效载荷重量(kg)
min_combat_radius_km FLOAT, -- 最小作战半径(km)
engine_power_kw FLOAT, -- 发动机功率(kw)
engine_thrust_n FLOAT, -- 发动机推力(N)
datalink_range_km FLOAT, -- 通信链路距离(km)
guidance_accuracy_m FLOAT, -- 制导精度(m)
min_altitude_m FLOAT, -- 最小作战高度(m)
max_altitude_m FLOAT, -- 最大作战高度(m)
-- 特征工程字段
length_width_ratio FLOAT, -- 长宽比
weight_range_ratio FLOAT, -- 重量/射程比
speed_weight_ratio FLOAT, -- 速度/重量比
guidance_system_score INT, -- 制导系统复杂度评分(1-10)
warhead_power_score INT, -- 战斗部威力评分(1-10)
-- 分类特征编码
warhead_type_code INT, -- 战斗部类型编码
launch_mode_code INT, -- 发射方式编码
power_system_code INT, -- 动力装置编码
guidance_system_code INT, -- 制导系统编码
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 分类特征编码表
CREATE TABLE feature_encoding (
id INT AUTO_INCREMENT PRIMARY KEY,
feature_type VARCHAR(50), -- 特征类型(warhead_type/launch_mode/power_system/guidance_system)
feature_value VARCHAR(100), -- 特征值
code INT, -- 编码值
UNIQUE KEY unique_feature (feature_type, feature_value)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 成本数据表
CREATE TABLE cost_data (
id INT AUTO_INCREMENT PRIMARY KEY,