Add some files

Tian jianyong 2024-11-08 23:43:57 +08:00
commit 865c93c811
39 changed files with 5609 additions and 0 deletions

.cursorrules Normal file

@ -0,0 +1,82 @@
# Development Workflow
First ensure basic functionality works
Implement core functionality using the simplest direct approach
Ensure data flow is working correctly
Verify results are accurate
Then gradually add additional features
Add error handling
Add data validation
Add format conversion
Add logging
Improve user experience
Avoid premature optimization
Don't do complex data validation at the start
Don't worry about performance optimization early
Don't over-engineer
This development flow:
Quickly validates if core functionality works
Identifies and fixes fundamental issues early
Avoids wasting time on unnecessary optimizations
Makes code easier to maintain and debug
These principles should guide all code responses, focusing on getting the basics working first before adding complexity.
# Code Modification Best Practices
1. Before modifying
- Check related files and dependencies
- Ensure naming consistency
- Add necessary logging
- Prepare a rollback plan
2. While modifying
- Follow a consistent code style
- Add appropriate error handling
- Keep the code readable
- Avoid duplicated code
3. Verification after modifying
- Verify the main functionality
- Test boundary conditions
- Check error handling
- Verify the performance impact
4. Documentation updates
- Update related documents
- Add explanatory comments
- Record important changes
- Update debugging information
5. Code review checklist
- Check naming conventions
- Verify error handling
- Confirm logging completeness
- Assess code quality
6. Debugging tips
- Add detailed logs
- Use breakpoint debugging
- Verify the data flow
- Check state changes
7. Performance considerations
- Avoid premature optimization
- Focus on the critical path
- Use caching judiciously
- Optimize database queries
8. Security checks
- Validate input data
- Handle exceptional cases
- Protect sensitive information
- Add access control
These practices help maintain code quality and reduce potential issues.

.gitignore vendored Normal file

@ -0,0 +1,27 @@
.DS_Store
node_modules
/dist
/models
/logs
/uploads
/data
# local env files
.env.local
.env.*.local
# Log files
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
# Editor directories and files
.idea
.vscode
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
/frontend/node_modules

README.md Normal file

@ -0,0 +1,23 @@
# Database Configuration
This system uses MySQL 8.0+ as its database. After installing MySQL, you need to:
1. Create a database user
```sql
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password';
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost';
FLUSH PRIVILEGES;
```
2. Configure the database character set
Make sure the MySQL configuration file (my.cnf or my.ini) contains the following settings:
```ini
[mysqld]
character-set-server=utf8mb4
collation-server=utf8mb4_unicode_ci
[client]
default-character-set=utf8mb4
```
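On the application side, the same charset should be requested on the connection string. A minimal sketch of how the backend's SQLAlchemy/PyMySQL URI could be assembled — the helper function name is illustrative, not part of the codebase:

```python
from urllib.parse import quote_plus

def build_database_uri(user: str, password: str, database: str,
                       host: str = "localhost", port: int = 3306) -> str:
    """Build a SQLAlchemy/PyMySQL URI that requests the utf8mb4 charset.

    Passing charset=utf8mb4 keeps the client connection in sync with the
    server-side settings shown above.
    """
    # quote_plus guards against special characters in the password
    return (f"mysql+pymysql://{user}:{quote_plus(password)}@{host}:{port}"
            f"/{database}?charset=utf8mb4")

uri = build_database_uri("equipment_user", "your_password", "equipment_cost_db")
print(uri)
```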

app.py Normal file

@ -0,0 +1,8 @@
import logging
# Configure logging
logging.basicConfig(
filename='logs/api.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)

config.py Normal file

@ -0,0 +1,32 @@
import os
import secrets
# Database configuration
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
# Security key (a random key is generated automatically)
SECRET_KEY = secrets.token_hex(16)
# Environment configuration
DEBUG = True
ENV = 'development'
# File upload configuration
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB maximum upload size
# API configuration
API_VERSION = 'v1'
API_PREFIX = f'/api/{API_VERSION}'
# CORS configuration
CORS_ORIGINS = [
    "http://localhost:8080",
    "http://127.0.0.1:8080",
]
# Logging configuration
LOG_LEVEL = 'DEBUG'
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')

data/.DS_Store vendored Normal file

Binary file not shown.

docs/debug.md Normal file

@ -0,0 +1,616 @@
# Debugging Notes
## Special Parameter Display Issue
### Problem
In the data management page, the special-parameters section of the equipment detail dialog shows empty rows or nothing at all.
### Debugging Steps
1. Backend data query
```sql
-- Test the special-parameter query
SELECT equipment_id, param_name, param_value, param_unit
FROM custom_params
WHERE param_name IS NOT NULL
AND param_value IS NOT NULL
LIMIT 5
```
2. Logging
```python
logging.info(f"Getting details for equipment ID: {id}")
logging.info(f"Equipment type: {equipment_type}")
logging.info(f"Found equipment details: {result['name']}")
logging.info(f"Custom params: {result.get('custom_params')}")
```
3. Frontend debugging
```javascript
console.log('Requesting details for row:', row)
console.log('Details response:', response.data)
console.log('Custom params:', response.data.custom_params)
console.log('Selected data:', selectedData.value)
```
### Key Findings
1. Database query
- The special-parameters table does contain data
- The format returned by JSON_ARRAYAGG needs post-processing
- NULL values must be filtered out
2. Data format
- The backend returns the special parameters as a JSON string
- The frontend must parse it into an array
- Make sure the array is not empty
3. Frontend rendering
- The rendering condition must be stricter
- Data types must be correct
- Display values need proper formatting
### Solution
1. Optimize the backend query
```sql
(
SELECT JSON_ARRAYAGG(
JSON_OBJECT(
'id', csp.id,
'param_name', csp.param_name,
'param_value', csp.param_value,
'param_unit', csp.param_unit,
'description', csp.description
)
)
FROM custom_params csp
WHERE csp.equipment_id = e.id
AND csp.param_name IS NOT NULL
AND csp.param_value IS NOT NULL
) as custom_params
```
2. Frontend data handling
```javascript
// Make sure custom_params is an array
if (typeof response.data.custom_params === 'string') {
response.data.custom_params = JSON.parse(response.data.custom_params)
}
```
3. Tighten the rendering condition
```vue
<template v-if="selectedData?.custom_params && Array.isArray(selectedData.custom_params) && selectedData.custom_params.length > 0">
```
### Best Practices
1. Database queries
- Fetch special parameters with a subquery rather than a JOIN
- Keep the returned format uniform
- Filter out invalid data
2. Data handling
- Normalize the data format
- Handle null values and exceptions
- Keep types consistent
3. Frontend display
- Strict rendering conditions
- Type checks
- Formatted output
4. Debugging approach
- Use logs to trace the data flow
- Check data formats and types
- Verify the data at every step
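On the backend side, the JSON string produced by the JSON_ARRAYAGG subquery can also be normalized before it is sent to the frontend. A minimal Python sketch — the function name and row shape are assumptions, not the project's actual code:

```python
import json

def normalize_custom_params(raw):
    """Return custom_params as a list of dicts, filtering out null entries.

    The JSON_ARRAYAGG subquery yields either NULL or a JSON string, so both
    cases are normalized to a (possibly empty) Python list.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        raw = json.loads(raw)
    # Drop entries whose name or value is missing
    return [p for p in raw
            if p.get('param_name') is not None and p.get('param_value') is not None]

params = normalize_custom_params(
    '[{"param_name": "口径", "param_value": "240", "param_unit": "mm"},'
    ' {"param_name": null, "param_value": null}]'
)
print(len(params))
```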
## Edit Dialog Issue
### Problem
In the data management page, the cost-information and special-parameters sections of the edit dialog render incorrectly.
### Debugging Steps
1. Check the data flow
```javascript
console.log('Editing row:', row)
console.log('Edit data response:', response.data)
console.log('Parsed custom params:', data.custom_params)
console.log('Edit form data:', editForm.value)
```
2. Check the template structure
```vue
<!-- Incorrect nesting -->
<el-form>
<template>
<el-divider>成本信息</el-divider>
</template>
</el-form>
<!-- Correct structure -->
<el-divider>成本信息</el-divider>
<el-form>
<!-- form items -->
</el-form>
```
### Key Findings
1. Template structure
- el-divider should not be nested inside a template element
- Each section needs its own el-form
- Avoid unnecessary template nesting
2. Data types
- The backend returns numeric values as strings
- el-input-number requires number values
- Types must be converted on the frontend
3. Conditional rendering
- An overly strict v-if can hide content entirely
- Some fields should always be shown
- Some fields should only be shown when they have a value
### Solution
1. Fix the template structure
```vue
<!-- Cost information -->
<el-divider content-position="left">成本信息</el-divider>
<el-form :model="editForm" label-width="120px">
<el-form-item label="实际成本(元)">
<el-input-number v-model="editForm.actual_cost"></el-input-number>
</el-form-item>
</el-form>
```
2. Convert data types
```javascript
// Convert all numeric fields
Object.keys(data).forEach(key => {
if (isNumberInput(key) && data[key] !== null && data[key] !== undefined) {
data[key] = Number(data[key])
}
})
```
3. Refine conditional rendering
```vue
<!-- Always show required fields -->
<el-form-item label="实际成本(元)">
<el-input-number v-model="editForm.actual_cost"></el-input-number>
</el-form-item>
<!-- Show optional fields only when they have a value -->
<el-form-item label="预测成本(元)" v-if="editForm.predicted_cost">
<el-input-number v-model="editForm.predicted_cost" disabled></el-input-number>
</el-form-item>
```
### Best Practices
1. Template structure
- Keep a clear section layout
- Avoid unnecessary nesting
- Use an appropriate component hierarchy
2. Data handling
- Convert types immediately after fetching data
- Match data types to component requirements
- Handle null and undefined values properly
3. Conditional rendering
- Use v-if and v-show appropriately
- Always show required fields
- Show optional fields conditionally
4. Debugging approach
- Trace the data flow with console.log
- Check the components' prop requirements
- Verify data types and structure
## Feature Analysis Issue
### Problem
On the feature analysis page, the first click of the analyze button renders only the title bar without the chart; the chart appears only on the second click.
### Debugging Steps
1. Check the data flow
```javascript
console.log('Analysis result:', analysisResult.value)
console.log('Charts not ready:', {
importanceChartRef: !!importanceChartRef.value,
correlationChartRef: !!correlationChartRef.value,
analysisResult: !!analysisResult.value
})
```
2. Check render timing
```javascript
// Watch the analysis result for changes
watch(() => analysisResult.value, async (newResult) => {
if (newResult) {
await nextTick()
setTimeout(() => {
renderCharts()
}, 100)
}
}, { deep: true })
```
3. Check chart instance management
```javascript
// Dispose of the old chart instances
if (importanceChart.value) {
importanceChart.value.dispose()
}
if (correlationChart.value) {
correlationChart.value.dispose()
}
```
### Key Findings
1. Render timing
- The DOM element may not be ready yet
- After a data update, wait for the DOM to update
- Chart instances must be managed correctly
2. Chart instance management
- Keep references to the chart instances
- Dispose of the old instance before re-rendering
- Clean up instances when the component unmounts
3. Data format
- Feature names need a Chinese-name mapping
- Correlation values should be rounded to 2 decimal places
- Missing values must be handled correctly
### Solution
1. Refine the rendering logic
```javascript
// Use nextTick plus a short delay to ensure the DOM has updated
await nextTick()
setTimeout(() => {
renderCharts()
}, 100)
```
2. Improve chart instance management
```javascript
// Keep references to the chart instances
const importanceChart = ref(null)
const correlationChart = ref(null)
// Clean up when the component unmounts
onUnmounted(() => {
importanceChart.value?.dispose()
correlationChart.value?.dispose()
})
```
3. Improve data handling
```python
# Use Chinese feature names
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
# Keep 2 decimal places
correlation_data.append([
    i, j,
    round(correlation_matrix[i][j], 2)
])
```
### Best Practices
1. Render control
- Watch for data changes
- Use nextTick to wait for DOM updates
- Add a short delay to ensure rendering
2. Instance management
- Keep chart instance references
- Dispose of old instances promptly
- Clean up on component unmount
3. Data handling
- Use Chinese feature names consistently
- Control numeric precision
- Handle missing values
4. Debugging approach
- Add detailed logging
- Check DOM element state
- Verify data formats
## Page State Persistence Issue
### Problem
After switching from the feature analysis page to another page and back, the page state (analysis results and charts) is lost and the analysis must be re-run.
### Debugging Steps
1. Check the router configuration
```vue
<!-- Incorrect configuration -->
<keep-alive include="AnalysisPage">
<router-view></router-view>
</keep-alive>
<!-- Correct configuration -->
<router-view v-slot="{ Component }">
<keep-alive>
<component :is="Component" :key="$route.fullPath" />
</keep-alive>
</router-view>
```
2. Check the component definition
```javascript
// Incorrect component name definition
<script setup name="AnalysisPage">
// Correct component name definition
const __name = 'AnalysisPage'
```
### Key Findings
1. keep-alive configuration
- The v-slot API is required
- A dynamic component is required
- A key attribute must be added
2. Component definition
- script setup does not support a name attribute directly
- Use __name or defineOptions instead
3. Cache scope
- The include attribute is unnecessary
- Caching all route components is simpler
- Avoids component-name mismatch problems
### Solution
1. Update the router-view configuration
```vue
<router-view v-slot="{ Component }">
<keep-alive>
<component :is="Component" :key="$route.fullPath" />
</keep-alive>
</router-view>
```
2. Update the component definition
```javascript
const __name = 'AnalysisPage'
```
### Best Practices
1. Router configuration
- Use the Vue 3 APIs
- Keep the configuration simple and clear
- Avoid unnecessary restrictions
2. Component definition
- Define component names the recommended way
- Avoid deprecated syntax
- Keep the code consistent
3. State management
- Use keep-alive judiciously
- Handle component lifecycles correctly
- Remember cleanup work
4. Debugging approach
- Check whether the component is cached
- Verify that state persists
- Confirm lifecycle hooks fire as expected
## Code Modification Best Practices
### 1. Before Modifying
1. Check related files:
```
When changing a frontend component, check:
- Related route configuration
- Parent/child component relationships
- Shared components and functions
- API calls
When changing a backend endpoint, check:
- Route definitions
- Database queries
- Related utility classes and functions
- Error handling
```
2. Keep naming consistent:
```
- Class name: ModelTrainer, not ModelTraining
- File name: train_model.py corresponds to ModelTrainer
- Variable names: keep frontend and backend naming conventions consistent
```
3. Add logging:
```python
# Log at key points
logging.info(f"Starting model training for {equipment_type}")
logging.info(f"Training dataset: {train_dataset_id}")
logging.error(f"Error in model training: {str(e)}")
```
### 2. While Modifying
1. Error handling:
```python
try:
    # main logic
    ...
except Exception as e:
    logging.error(f"Error: {str(e)}")
    logging.error("Detailed traceback:", exc_info=True)
    return jsonify({'error': str(e)}), 500
```
2. Data validation:
```javascript
// Validate input (frontend)
if (!formData.value.type) throw new Error('请选择装备类型')
if (!formData.value.train_dataset_id) throw new Error('请选择训练数据集')
```
3. State management:
```javascript
// Reset state
formData.value.train_dataset_id = null
formData.value.validation_dataset_id = null
trainingResult.value = null
```
### 3. Verification After Modifying
1. Functional testing:
```
- Verify the main functionality
- Test boundary conditions
- Check error handling
```
2. Performance checks:
```
- Check database query performance
- Verify frontend rendering performance
- Confirm memory usage
```
3. Code quality:
```
- Check code style
- Ensure comments are complete
- Verify type definitions
```
### 4. Documentation Updates
1. Update the debugging notes:
```
- Record the cause of the problem
- Describe the solution
- Add best practices
```
2. Update the design document:
```
- Update interface definitions
- Revise data structures
- Document new features
```
3. Update comments:
```
- Add function descriptions
- Explain parameter usage
- Explain complex logic
```
## Model Training Results
From the latest training run:
1. XGBoost performs best:
- Training R² = 0.4346, no overfitting
- Validation R² = 0.3625, the most stable
- MAE = 0.60, RMSE = 0.61, small prediction error
2. LightGBM is second:
- Training R² = 0.5277, slight overfitting
- Validation R² = 0.1101, mediocre generalization
- MAE = 0.55, RMSE = 0.72, moderate prediction error
3. Random Forest:
- Training R² = 0.7756, overfitting present
- Validation R² = 0.3189, acceptable generalization
- MAE = 0.47, RMSE = 0.63, small prediction error
4. GBDT overfits severely:
- Training R² = 0.9700, severe overfitting
- Validation R² = -1.3133, very poor generalization
- MAE = 0.96, RMSE = 1.17, large prediction error
### Recommendations
1. Use XGBoost as the primary model
2. Consider ensembling XGBoost and Random Forest
3. Keep tuning LightGBM's parameters
4. Do not use GBDT for now
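The pattern above — training R² far above validation R² — can be flagged mechanically. A small sketch over the reported scores (the 0.3 gap threshold is an arbitrary assumption):

```python
def overfit_gap(train_r2: float, val_r2: float) -> float:
    """Gap between training and validation R²; larger means more overfitting."""
    return train_r2 - val_r2

# (training R², validation R²) from the run above
scores = {
    "XGBoost":       (0.4346, 0.3625),
    "LightGBM":      (0.5277, 0.1101),
    "Random Forest": (0.7756, 0.3189),
    "GBDT":          (0.9700, -1.3133),
}

# Rank models by validation R², flagging a large train/validation gap
for name, (train_r2, val_r2) in sorted(scores.items(), key=lambda kv: -kv[1][1]):
    flag = " (overfit?)" if overfit_gap(train_r2, val_r2) > 0.3 else ""
    print(f"{name}: val R2 = {val_r2:.4f}{flag}")
```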

docs/design.md Normal file

@ -0,0 +1,433 @@
# Equipment Cost Estimation System Design
## I. Overview
The system estimates equipment cost from technical parameters using machine learning methods. It uses a separated frontend/backend architecture and consists mainly of data preprocessing, feature analysis, model training, and cost prediction modules.
## II. System Architecture
### 1. Technical stack
- Frontend: Vue.js + Element UI
- Backend: Python Flask
- Database: MySQL
- Machine learning frameworks: TensorFlow, Scikit-learn
### 2. System modules
```mermaid
graph TD
A[Equipment Cost Estimation System] --> B[Data Preprocessing Module]
A --> C[Feature Analysis Module]
A --> D[Model Training Module]
A --> E[Cost Prediction Module]
A --> F[System Administration Module]
```
## III. Data Model Design
### 1. Database table structure
#### Equipment base table (equipment)
- id: primary key
- name: equipment name
- type: equipment type
- manufacturer: manufacturer
- created_at: creation time
#### Technical parameters table (technical_params)
##### Dimension parameters
- length_m: overall length (m)
- width_m: width (m)
- height_m: height (m)
- weight_standard_kg: standard weight (kg)
- weight_combat_kg: combat weight (kg)
##### Firepower parameters
- firing_angle_horizontal: traverse (degrees)
- firing_angle_vertical: elevation (degrees)
- rocket_length_m: rocket length (m)
- rocket_diameter_mm: rocket body diameter (mm)
- rocket_weight_kg: rocket weight (kg)
##### Performance parameters
- max_speed_ms: maximum speed (m/s)
- max_range_km: maximum range (km)
- warhead_weight_kg: warhead weight (kg)
##### Other parameters
- mobility_type: mobility type
- wheel_arrangement: wheel arrangement
- amphibious: amphibious capability
#### Cost data table (cost_data)
- equipment_id: equipment ID
- actual_cost: actual cost
- predicted_cost: predicted cost
- prediction_date: prediction date
## IV. Core Module Design
### 1. Feature engineering module
- Data preprocessing
- Data cleaning
- Missing value handling
- Outlier detection
- Data standardization
- Feature derivation
- Power-to-weight ratio (power_weight_ratio)
- Volume (volume)
- Range-to-speed ratio (range_speed_ratio)
- Warhead-to-rocket mass ratio (warhead_rocket_ratio)
- Feature selection
- Analysis of variance
- Correlation analysis
- Mutual information analysis
- Consistency analysis
- rwg: within-group agreement
- ICC: intraclass correlation coefficient
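For the consistency analysis, ICC(1) is commonly derived from a one-way ANOVA decomposition. A minimal numpy sketch — the equal-sized group layout and function name are assumptions, and the project's actual implementation may differ:

```python
import numpy as np

def icc1(groups):
    """ICC(1) from a one-way ANOVA: (MSB - MSW) / (MSB + (k - 1) * MSW),
    where each of the n groups holds k ratings."""
    data = np.asarray(groups, dtype=float)
    n, k = data.shape
    grand_mean = data.mean()
    group_means = data.mean(axis=1)
    # Between-group and within-group mean squares
    msb = k * ((group_means - grand_mean) ** 2).sum() / (n - 1)
    msw = ((data - group_means[:, None]) ** 2).sum() / (n * (k - 1))
    return (msb - msw) / (msb + (k - 1) * msw)

# Ratings that agree tightly within each group give ICC(1) close to 1
tight = [[10.0, 10.1, 9.9], [20.0, 20.2, 19.8], [30.1, 29.9, 30.0]]
print(round(icc1(tight), 4))
```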
### 2. Model training module
#### 2.1 Ensemble models
1. XGBoost (eXtreme Gradient Boosting)
- Characteristics:
- Optimizes with second-order derivatives, converging faster
- Supports custom loss functions
- Built-in regularization to prevent overfitting
- Supports feature-importance evaluation
- Configuration:
```python
XGBRegressor(
n_estimators=100, # a moderate number of trees to avoid overfitting
learning_rate=0.1, # relatively large learning rate for faster convergence
max_depth=3, # shallow trees to prevent overfitting
min_child_weight=3, # controls overfitting
subsample=0.8, # row subsampling for robustness
colsample_bytree=0.8, # feature subsampling against overfitting
objective='reg:squarederror' # regression objective
)
```
2. LightGBM (Light Gradient Boosting Machine)
- Characteristics:
- Histogram-based decision tree algorithm with fast training
- Accepts categorical features directly
- Leaf-wise growth strategy
- Small memory footprint
- Configuration:
```python
LGBMRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
num_leaves=8, # controls tree complexity
subsample=0.8,
colsample_bytree=0.8,
objective='regression'
)
```
3. GBDT (Gradient Boosting Decision Tree)
- Characteristics:
- The classic gradient boosting algorithm
- Relatively good interpretability
- Insensitive to outliers
- Stable predictions
- Configuration:
```python
GradientBoostingRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
subsample=0.8
)
```
4. Random Forest
- Characteristics:
- Bagging ensemble method
- Training can be parallelized
- Not prone to overfitting
- Insensitive to feature scale
- Configuration:
```python
RandomForestRegressor(
n_estimators=100,
max_depth=3,
min_samples_split=3,
min_samples_leaf=2
)
```
#### 2.2 Model selection strategy
1. Cross-validation
- Small datasets: leave-one-out (LOO) cross-validation
- Large datasets: 5-fold cross-validation
2. Evaluation metrics
- R² score: goodness of fit
- Standard deviation: prediction stability
- MAE: prediction error
- MSE: penalizes large errors
3. Automatic model selection
- Cross-validate each model
- Pick the model with the highest R² score
- Save the model training history
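The selection strategy above can be sketched end to end. The two stand-in models below (a mean predictor and ordinary least squares) are illustrative only — the system's real candidates are the tree ensembles listed in 2.1:

```python
import numpy as np

def r2_score(y_true, y_pred):
    ss_res = ((y_true - y_pred) ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return 1 - ss_res / ss_tot

def cross_val_r2(fit, X, y, small_threshold=30):
    """Pooled out-of-fold R²: leave-one-out below the threshold, else 5-fold."""
    n = len(y)
    n_folds = n if n < small_threshold else 5
    preds = np.empty(n)
    for fold in np.array_split(np.arange(n), n_folds):
        train = np.setdiff1d(np.arange(n), fold)
        predict = fit(X[train], y[train])
        preds[fold] = predict(X[fold])
    return r2_score(y, preds)

def fit_mean(X_train, y_train):
    """Baseline: always predict the training mean."""
    m = y_train.mean()
    return lambda X: np.full(len(X), m)

def fit_linear(X_train, y_train):
    """Ordinary least squares with an intercept."""
    A = np.c_[X_train, np.ones(len(X_train))]
    coef, *_ = np.linalg.lstsq(A, y_train, rcond=None)
    return lambda X: np.c_[X, np.ones(len(X))] @ coef

rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(20, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + rng.normal(0, 0.1, size=20)

# Cross-validate every candidate and keep the highest R²
models = {"mean": fit_mean, "linear": fit_linear}
best = max(models, key=lambda name: cross_val_r2(models[name], X, y))
print(best)
```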
#### 2.3 Feature engineering and model tuning
1. Feature selection
- Use the model's feature-importance scores
- Mutual-information-based filtering
- Correlation analysis to drop redundant features
2. Parameter tuning
- For small samples:
- Reduce model complexity (tree depth and count)
- Increase regularization strength
- Use a larger learning rate
- For large samples:
- Increase model complexity
- Reduce the learning rate
- Use feature subsampling
3. Ensemble strategies
- Model voting: combine predictions from multiple models
- Confidence intervals: estimate prediction uncertainty with the bootstrap
- Anomaly detection: identify and handle abnormal predictions
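The bootstrap confidence-interval idea can be sketched by resampling the votes of an ensemble. Everything below — the function name and the three stand-in models — is illustrative, not the project's implementation:

```python
import numpy as np

def bootstrap_interval(predict_fns, x, confidence=0.95, n_boot=1000, seed=0):
    """Bootstrap a prediction interval from an ensemble of fitted models.

    Resampling which models contribute to the vote approximates the
    uncertainty of the ensemble's mean prediction.
    """
    rng = np.random.default_rng(seed)
    point_preds = np.array([fn(x) for fn in predict_fns])
    means = [rng.choice(point_preds, size=len(point_preds), replace=True).mean()
             for _ in range(n_boot)]
    lo, hi = np.quantile(means, [(1 - confidence) / 2, (1 + confidence) / 2])
    return point_preds.mean(), (lo, hi)

# Three stand-in "models" that disagree slightly on the same input
fns = [lambda x: 100.0, lambda x: 110.0, lambda x: 95.0]
pred, (lo, hi) = bootstrap_interval(fns, x=None)
print(pred, lo, hi)
```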
### 3. Cost prediction module
- Data standardization
- Model prediction flow
- Confidence interval calculation
- Prediction metrics
- MAE (mean absolute error)
- MSE (mean squared error)
- RMSE (root mean squared error)
- R² (coefficient of determination)
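These four metrics can be computed directly with numpy, consistent with scikit-learn's definitions:

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """MAE, MSE, RMSE and R² for a set of predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    mse = (err ** 2).mean()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return {
        "mae": np.abs(err).mean(),
        "mse": mse,
        "rmse": np.sqrt(mse),
        "r2": 1 - (err ** 2).sum() / ss_tot,
    }

m = regression_metrics([100, 200, 300], [110, 190, 310])
print(m)
```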
### 4. Data management module
#### 4.1 Query optimization
1. Special-parameter query strategy
- Use subqueries instead of JOINs to avoid duplicate rows and null issues
- Use COALESCE so an empty array is returned instead of NULL
- Add WHERE conditions so only valid parameters are fetched
- Simplify the query; remove the unnecessary GROUP BY clause
2. Query performance
- Clear data structure: one record per equipment item
- Special parameters organized in a JSON array
- No null or invalid data in the output
- Better performance without complex JOIN and GROUP BY operations
#### 4.2 Data display
1. Basic information
- Tabular display of basic equipment parameters
- Search and filter support
- Data grouped by equipment type
2. Special parameters
- Equipment-specific parameters shown on the detail page
- Formatted display of special parameters
- Units added automatically by parameter type
- Numeric and text parameters distinguished
3. Editing
- Basic and special parameters are both editable
- Edit controls chosen by parameter type
- Precision control for numeric parameters
- Data validated on save
#### 4.3 Data maintenance
1. Import/export
- Excel template import
- Template download
- Validation and error messages on import
2. Deletion
- Cascade delete of related data
- Confirmation before deletion
- List refreshed after deletion
3. Updates
- Single-record updates
- Integrity validation on save
- Display refreshed after update
#### 4.4 Extensibility
1. Parameter management
- Special parameters can be added dynamically
- Parameter definitions include name, unit, description, etc.
- Parameter value types can be defined
2. Data structure
- Flexible data model design
- Supports per-type parameter differences
- Easy to extend
3. API design
- RESTful API design
- Uniform response format
- Comprehensive error handling
## V. API Design
### 1. Cost prediction endpoint
- Path: POST /api/predict
- Required parameters:
- length_m: overall length
- width_m: width
- height_m: height
- weight_standard_kg: standard weight
- weight_combat_kg: combat weight
- max_range_km: maximum range
- max_speed_ms: maximum speed
- Response:
- predicted_cost: predicted cost
- confidence_interval: confidence interval (lower/upper bounds)
### 2. Feature analysis endpoint
- Path: POST /api/analyze-features
- Response:
- important_features: list of important features
- correlation_analysis: correlation analysis results
## VI. Deployment Requirements
### 1. Environment
- Python 3.8+
- MySQL 8.0+
- Node.js 14+
### 2. Dependencies
- Flask
- NumPy
- Pandas
- Scikit-learn
- TensorFlow
- SQLAlchemy
### 3. Deployment steps
1. Database initialization
- Create the database
- Run the schema script
- Import base data
2. Backend deployment
- Install dependencies
- Configure environment variables
- Start the service
3. Frontend deployment
- Install dependencies
- Build for production
- Configure Nginx
## VII. Future Improvements
### 1. Feature engineering
- Add more domain features
- Firepower density metric
- Composite mobility metric
- Combat effectiveness metric
- Improve the feature selection algorithm
- Strengthen data cleaning
### 2. Models
- Introduce ensemble methods
- Random Forest
- XGBoost
- LightGBM
- Automatic hyperparameter tuning
- Online learning capability
### 3. System
- Batch processing
- Model version management
- Better prediction explainability
- Performance optimization
- Database index optimization
- Caching strategy
- API performance optimization
### 4. Security
- Encrypted data storage
- API authentication
- Operation audit logging
- Data backup strategy
## VIII. Project Schedule
### Phase 1: Core features (4 weeks)
1. Database design and implementation
2. Feature engineering module
3. Basic APIs
### Phase 2: Modeling (4 weeks)
1. Data preprocessing
2. Model training and tuning
3. Prediction features
### Phase 3: Integration (3 weeks)
1. Frontend development
2. Integration testing
3. Performance tuning
### Phase 4: Testing and deployment (2 weeks)
1. System testing
2. Deployment
3. Documentation

docs/nodejs_install.md Normal file

@ -0,0 +1,136 @@
# Node.js Installation Guide
## Windows
### 1. Installer
1. Visit the Node.js website <https://nodejs.org/>
2. Download the 14.x LTS installer
3. Run the installer and follow the prompts
4. Verify the installation:
```bash
node --version
npm --version
```
### 2. nvm-windows (recommended)
1. Download nvm-windows: <https://github.com/coreybutler/nvm-windows/releases>
2. Install nvm-windows
3. Install Node.js:
```bash
nvm install 14.21.3
nvm use 14.21.3
```
## Linux
### 1. nvm (recommended)
```bash
# Install nvm
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
# Reload the shell configuration
source ~/.bashrc
# Install Node.js 14
nvm install 14
nvm use 14
```
### 2. Package managers
#### Ubuntu/Debian
```bash
# Add the NodeSource repository
curl -fsSL https://deb.nodesource.com/setup_14.x | sudo -E bash -
# Install Node.js
sudo apt-get install -y nodejs
```
#### CentOS/RHEL
```bash
# Add the NodeSource repository
curl -fsSL https://rpm.nodesource.com/setup_14.x | sudo bash -
# Install Node.js
sudo yum install -y nodejs
```
## macOS
### 1. Homebrew (recommended)
```bash
# Install Homebrew (if not already installed)
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# Install Node.js 14
brew install node@14
# Add to PATH
echo 'export PATH="/usr/local/opt/node@14/bin:$PATH"' >> ~/.zshrc
source ~/.zshrc
```
### 2. nvm
```bash
# Install nvm
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
# Reload the shell configuration
source ~/.zshrc
# Install Node.js 14
nvm install 14
nvm use 14
```
## Verifying the Installation
After installation, verify with:
```bash
# Check the Node.js version
node --version  # should print v14.x.x
# Check the npm version
npm --version  # should print 6.x.x or higher
```
## Common Issues
### 1. Permissions
If you hit permission errors:
```bash
# Linux/macOS
sudo chown -R $USER /usr/local/lib/node_modules
```
### 2. Switching versions
To switch between versions:
```bash
# With nvm
nvm list  # list installed versions
nvm use 14  # switch to 14.x
```
### 3. npm configuration
Consider configuring a mirror registry in China:
```bash
# Use the npmmirror registry
npm config set registry https://registry.npmmirror.com
```

docs/requirements.md Normal file

@ -0,0 +1,39 @@
# Project Requirements
## Project Name
- Equipment Cost Estimation System
## System Features
1. Estimate equipment cost from technical parameters
2. Train a neural network model on existing data, then estimate costs for new equipment
## Technical Requirements
1. Correlation analysis:
(1) Analysis of variance: partition the features into populations by label category, then test whether the population means are equal (i.e., whether there is a significant difference).
(2) Linear correlation analysis: for regression problems where both feature and label are continuous, the most direct test of correlation is the correlation coefficient rxy. In essence this builds a covariance matrix to analyze the type and strength of the relationship between the data and cost, and screens the influential features.
(3) Mutual information: used for feature selection; it can be interpreted from two angles: (1) KL divergence and (2) information gain.
2. Consistency analysis: stratify and group the feature data and compute within-group agreement, with the goal of selecting a suitable group of data from which to derive a virtual quantity for cost estimation and analysis. Most studies report three statistics, rwg, ICC(1), and ICC(2), which should satisfy rwg > 0.7, ICC(1) > 0.05, and ICC(2) > 0.5:
- rwg: rating agreement
- ICC(1): within-group consistency
- ICC(2): between-group consistency
3. Regression model: partial least squares (PLS) regression
4. Neural network model: BP (backpropagation) network
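PLS regression is available off the shelf as sklearn.cross_decomposition.PLSRegression. As a sketch of the underlying idea, a minimal NIPALS implementation of single-response PLS (PLS1) might look like this — illustrative only, not the project's code:

```python
import numpy as np

def pls1(X, y, n_components=2):
    """Single-response PLS regression fitted with the NIPALS algorithm.

    Returns (b, x_mean, y_mean) so that predictions are
    (X - x_mean) @ b + y_mean.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean
    W, P, Q = [], [], []
    for _ in range(n_components):
        w = Xc.T @ yc                      # weight: covariance direction
        w = w / np.linalg.norm(w)
        t = Xc @ w                         # scores
        tt = t @ t
        p = Xc.T @ t / tt                  # X loadings
        q = (yc @ t) / tt                  # y loading
        Xc = Xc - np.outer(t, p)           # deflate X
        yc = yc - q * t                    # deflate y
        W.append(w); P.append(p); Q.append(q)
    W, P = np.array(W).T, np.array(P).T
    b = W @ np.linalg.solve(P.T @ W, np.array(Q))
    return b, x_mean, y_mean

def pls1_predict(model, X):
    b, x_mean, y_mean = model
    return (np.asarray(X, dtype=float) - x_mean) @ b + y_mean

# With as many components as features, PLS1 reproduces the least-squares fit
X = np.random.default_rng(1).normal(size=(50, 3))
y = X @ np.array([2.0, -1.0, 0.5])
pred = pls1_predict(pls1(X, y, n_components=3), X)
print(float(np.abs(pred - y).max()))
```

In practice fewer components than features are kept, which is what makes PLS useful for the small, collinear samples described above.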
### Data Preparation
Suggested priority for supplementing data (rocket artillery):
1. First priority (firing performance):
- rate_of_fire (rate of fire)
- rocket_weight_kg (rocket weight)
- max_range_km (maximum range)
2. Second priority (firepower):
- firing_angle_horizontal (traverse)
- firing_angle_vertical (elevation)
- rocket_length_m (rocket length)
3. Third priority (mobility):
- min_range_km (minimum range)
- power_hp (power)

docs/run.md Normal file

@ -0,0 +1,163 @@
# Running the System
## I. Environment Setup
### 1. Install required software
```bash
# Install Python 3.8+
# Install MySQL 8.0+
# Install Node.js 14+
```
### 2. Install Python dependencies
```bash
pip install -r requirements.txt
```
### 3. Install frontend dependencies
```bash
cd frontend
npm install
```
## II. Database Setup
### 1. Create the database
```sql
CREATE DATABASE equipment_cost_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
```
### 2. Initialize the schema
```bash
# Run the schema initialization script
mysql -u username -p equipment_cost_db < src/schema.sql
# Import sample data
mysql -u username -p equipment_cost_db < src/init_data.sql
# Import real data
mysql -u username -p equipment_cost_db < src/real_data.sql
```
## III. Configuration
### 1. Backend configuration
Create a `config.py` file:
```python
# config.py
DATABASE_URI = "mysql+pymysql://username:password@localhost:3306/equipment_cost_db"
SECRET_KEY = "your-secret-key"
DEBUG = True
```
### 2. Frontend configuration
Edit `frontend/src/config.js`:
```javascript
export const API_BASE_URL = 'http://localhost:5001/api';
```
## IV. Starting the System
### 1. Start the backend
```bash
# Development
python run.py  # serves at http://localhost:5001
# Production
gunicorn -w 4 -b 0.0.0.0:5001 run:app
```
### 2. Start the frontend
```bash
# Development
cd frontend
npm run serve  # serves at http://localhost:8080
# Production
npm run build
```
## V. Accessing the System
- Backend API: <http://localhost:5001/api>
- Frontend: <http://localhost:8080>
## VI. Common Issues
### 1. Database connection problems
- Check that the MySQL service is running
- Verify the database username and password
- Confirm the database port
### 2. Model training
```bash
# Train the models
python src/train_model.py
# Watch the training log
tail -f logs/training.log
```
### 3. System monitoring
```bash
# View the system log
tail -f logs/app.log
# Monitor API requests
tail -f logs/access.log
```
## VII. Development and Debugging
### 1. Backend debugging
```bash
# Start in debug mode
python run.py --debug
# Run tests
python -m pytest tests/
```
### 2. Frontend debugging
```bash
# Start the dev server
npm run serve
# Run tests
npm run test
```
## VIII. Deployment Suggestions
### 1. Docker
```bash
# Build images
docker-compose build
# Start services
docker-compose up -d
```
### 2. Production configuration
- Use Nginx as a reverse proxy
- Configure SSL certificates
- Set appropriate firewall rules
- Enable database backups

frontend Submodule

@ -0,0 +1 @@
Subproject commit 96445d75411a5f9ace114085af0872cfbc116515

loiteringmunitions.md Normal file

@ -0,0 +1,13 @@
# Loitering Munition Technical Parameters (Example)
## US "终结者" (Terminator) individual loitering munition
Targets: stationary and moving personnel and lightly armored vehicles
Dimensions: 560 mm × 150 mm × 200 mm (stowed)
Munition weight: < 2.72 kg
Range: > 24 km
Loiter time: 15 min
Loiter speed: 96.56 km/h
Maximum flight speed: > 160.93 km/h
Warhead types: fragmentation, smoke, and thermobaric
Launch method: takes off under its own power

package.json Normal file

@ -0,0 +1,12 @@
{
"name": "frontend",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"description": ""
}

requirements.txt Normal file

@ -0,0 +1,12 @@
flask==2.0.1
flask-cors==3.0.10
sqlalchemy==1.4.23
pymysql==1.0.2
cryptography==3.4.7  # required for MySQL 8.0+ authentication
numpy==1.21.2
pandas==1.3.3
scikit-learn==0.24.2
tensorflow==2.6.0
urllib3<2.0.0  # pin urllib3 below 2.0
openpyxl==3.1.2  # for reading .xlsx files
xlrd==2.0.1  # for reading .xls files

rocketparameters.md Normal file

@ -0,0 +1,29 @@
# Rocket Artillery System Technical Parameters (Example)
## Iranian "胜利"-2 (Victory-2) 240 mm 12-tube rocket artillery system
Product category: multiple rocket launcher
Model: "胜利"-2 240 mm multiple rocket launcher
Dimensions and weight
Overall length: 10 m (393.7 in)
Width (travel): 2.5 m (98.4 in)
Height (travel): 3.34 m (131.5 in)
Standard weight: 15,000 kg (33,069 lb) (15.0 t)
Combat weight: 19,900 kg (43,871 lb) (19.9 t)
Mobility
Running gear: wheeled
Arrangement: 6×6
Amphibious: no
Firepower
Traverse: 100° (1,778 mils) (left/90° right)
Elevation (over the front): 57° (1,013 mils)
Model: "胜利"-2 rocket
Dimensions and weight
Overall length: 3.550 m (11 ft)
Body diameter: 512 mm (20.16 in) (fins deployed)
Launch weight: 275 kg (606 lb)
Performance
Speed (maximum): 1,302 kt (2,412 km/h; 1,499 mph; 670 m/s)
Maximum range: 12.4 n miles (23 km; 14.3 miles)
Weapon payload
Warhead: 85 kg (187 lb)

run.py Normal file

@ -0,0 +1,61 @@
import os
import logging
from src.app import app

def ensure_directories():
    """
    Make sure all required directories exist
    """
    directories = [
        'logs',
        'data',
        'models',
        'uploads'
    ]
    for directory in directories:
        os.makedirs(directory, exist_ok=True)

def setup_logging():
    """
    Configure the logging system
    """
    logging.basicConfig(
        filename='logs/server.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    # Also log to the console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logging.getLogger('').addHandler(console_handler)

if __name__ == "__main__":
    try:
        # Initialize directories
        ensure_directories()
        # Set up logging
        setup_logging()
        # Record startup information
        logging.info("=== Server Starting ===")
        logging.info("Initializing directories...")
        logging.info("Setting up logging system...")
        # Start the server
        app.run(
            host='localhost',
            port=5001,
            debug=True,
            use_reloader=False  # disable the reloader to avoid loading the model twice
        )
    except Exception as e:
        logging.error(f"Server failed to start: {str(e)}")
        raise

src/__init__.py Normal file

@ -0,0 +1 @@
# This file can be empty, but it must exist

src/api.py Normal file

@ -0,0 +1,68 @@
from flask import Flask, request, jsonify
from .model_training import ModelTrainer
from .cost_prediction import CostPredictor
from .feature_analysis import FeatureAnalysis
import pandas as pd

app = Flask(__name__)

@app.route('/api/predict', methods=['POST'])
def predict_cost():
    """
    Cost prediction endpoint
    """
    try:
        data = request.get_json()
        # Validate required parameters
        required_params = [
            'length_m', 'width_m', 'height_m', 'weight_standard_kg',
            'weight_combat_kg', 'max_range_km', 'max_speed_ms'
        ]
        for param in required_params:
            if param not in data:
                return jsonify({'error': f'Missing parameter: {param}'}), 400
        predictor = CostPredictor()
        result = predictor.predict(data)
        return jsonify({
            'predicted_cost': float(result['predicted_cost']),
            'confidence_interval': {
                'lower': float(result['confidence_interval']['lower']),
                'upper': float(result['confidence_interval']['upper'])
            }
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/api/analyze-features', methods=['POST'])
def analyze_features():
    """
    Feature analysis endpoint
    """
    try:
        data = request.get_json()
        analyzer = FeatureAnalysis()
        # Preprocess the data
        processed_data = analyzer.preprocess_features(pd.DataFrame(data))
        # Feature importance analysis
        important_features = analyzer.select_features(
            processed_data,
            data['cost']
        )
        return jsonify({
            'important_features': important_features,
            'correlation_analysis': analyzer.correlation_analysis(
                processed_data,
                data['cost']
            ).to_dict()
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

src/app.py Normal file

@ -0,0 +1,68 @@
from flask import Flask
from flask_cors import CORS
import logging
import os
from .routes import api_bp

def create_app():
    """
    Create and configure the Flask application
    """
    app = Flask(__name__)
    # Configure CORS
    CORS(app)
    # Configure logging
    setup_logging()
    # Register blueprints
    app.register_blueprint(api_bp, url_prefix='/api')

    # Error handlers
    @app.errorhandler(404)
    def not_found_error(error):
        logging.error(f"404 error: {error}")
        return {'error': 'Resource not found'}, 404

    @app.errorhandler(500)
    def internal_error(error):
        logging.error(f"500 error: {error}")
        return {'error': 'Internal server error'}, 500

    @app.errorhandler(Exception)
    def handle_exception(error):
        logging.error(f"Unhandled exception: {error}", exc_info=True)
        return {'error': str(error)}, 500

    return app

def setup_logging():
    """
    Configure the logging system
    """
    # Make sure the log directory exists
    os.makedirs('logs', exist_ok=True)
    # Configure the log format
    logging.basicConfig(
        filename='logs/api.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    # Also log to the console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logging.getLogger('').addHandler(console_handler)

app = create_app()

@app.route('/health')
def health_check():
    """
    Health check endpoint
    """
    return {'status': 'ok'}

src/cost_prediction.py Normal file

@ -0,0 +1,218 @@
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import stats
import joblib
import os
import pandas as pd
from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection

class CostPredictor:
    def __init__(self):
        self.scaler_X = StandardScaler()
        self.scaler_y = StandardScaler()
        self.model = None
        self.feature_analyzer = FeatureAnalysis()
        self.equipment_type = None
        # TensorFlow configuration
        tf.config.run_functions_eagerly(False)  # keep tf.function in graph mode

        # Build the prediction function
        @tf.function(reduce_retracing=True, jit_compile=True)
        def predict_fn(x):
            return self.model(x, training=False)
        self._predict_fn = predict_fn
        self.load_model()

    def load_model(self):
        """
        Load the pretrained model and scalers
        """
        try:
            model_dir = 'models'
            os.makedirs(model_dir, exist_ok=True)
            # Create the default model
            self._create_default_model()

            # Build the prediction function
            @tf.function(reduce_retracing=True, jit_compile=True)
            def predict_fn(x):
                return self.model(x, training=False)
            self._predict_fn = predict_fn
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            self._create_default_model()

    def _create_default_model(self):
        """
        Create and initialize a default model
        """
        # Input layer
        inputs = tf.keras.Input(shape=(11,))
        # Hidden layers
        x = tf.keras.layers.Dense(64, activation='relu')(inputs)
        x = tf.keras.layers.Dense(32, activation='relu')(x)
        # Output layer
        outputs = tf.keras.layers.Dense(1)(x)
        # Build the model
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
        # Compile the model
        self.model.compile(
            optimizer='adam',
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_absolute_error]
        )
        # Example data
        example_data = pd.DataFrame({
            'length_m': [7.35, 10.2],
            'width_m': [2.4, 2.8],
            'height_m': [3.1, 3.2],
            'weight_kg': [13700, 28500],
            'max_range_km': [20.4, 70],
            'firing_angle_horizontal': [102, 110],
            'firing_angle_vertical': [55, 60],
            'rocket_length_m': [2.87, 4.1],
            'rocket_diameter_mm': [122, 220],
            'rocket_weight_kg': [66.6, 150],
            'rate_of_fire': [40, 60]
        })
        # Fit the scalers
        self.scaler_X.fit(example_data)
        self.scaler_y.fit(np.array([[800000], [4500000]]))  # positive cost range
        # Default equipment type
        self.equipment_type = '火箭炮'

    def _create_example_data(self):
        """
        Create example data to fit the scalers
        """
        # Rocket artillery example data
        rocket_data = pd.DataFrame({
            'length_m': [7.35, 10.2],
            'width_m': [2.4, 2.8],
            'height_m': [3.1, 3.2],
            'weight_kg': [13700, 28500],
            'max_range_km': [20.4, 70],
            'firing_angle_horizontal': [102, 110],
            'firing_angle_vertical': [55, 60],
            'rocket_length_m': [2.87, 4.1],
            'rocket_diameter_mm': [122, 220],
            'rocket_weight_kg': [66.6, 150],
            'rate_of_fire': [40, 60]
        })
        # Loitering munition example data
        missile_data = pd.DataFrame({
            'length_m': [1.3, 2.5],
            'width_m': [0.23, 0.6],
            'height_m': [0.23, 0.6],
            'weight_kg': [12.5, 135],
            'max_range_km': [40, 250],
            'max_speed_kmh': [180, 185],
            'cruise_speed_kmh': [100, 110],
            'flight_time_min': [60, 120],
            'folded_length_mm': [1300, 2500],
            'folded_width_mm': [230, 600],
            'folded_height_mm': [230, 600]
        })
        # Fit the scalers
        self.scaler_X.fit(rocket_data)  # use the rocket artillery data
        self.scaler_y.fit(np.array([[800000], [4500000]]))  # example cost data
        # Default equipment type
        self.equipment_type = '火箭炮'

    def predict(self, data):
        """
        Predict the cost
        """
        try:
            equipment_type = data.get('type')
            # Load the model
            trainer = ModelTrainer()
            if not trainer.load_model(equipment_type):
                raise ValueError(f"No trained model found for {equipment_type}")
            # Prepare the feature data
            features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
            X = np.array([[data.get(feature) for feature in features]])
            # Predict
            y_pred = trainer.predict(X)
            # Confidence interval
            confidence_interval = self._calculate_confidence_interval(y_pred[0])
            # Keep the prediction and interval positive and in a sensible range
            predicted_cost = max(1000, float(y_pred[0]))  # floor of 1,000 yuan
            lower_bound = max(1000, float(confidence_interval[0]))
            upper_bound = max(predicted_cost * 1.2, float(confidence_interval[1]))  # at least 20% above the prediction
            return {
                'predicted_cost': predicted_cost,
                'confidence_interval': {
                    'lower': lower_bound,
                    'upper': upper_bound
                }
            }
        except Exception as e:
            logging.error(f"Prediction error: {str(e)}")
            raise

    def _calculate_confidence_interval(self, prediction, confidence=0.95):
        """
        Compute a confidence interval for a prediction
        """
        try:
            # Use 20% of the prediction as the standard deviation (added uncertainty)
            std = abs(prediction) * 0.2
            # Compute the interval
            from scipy import stats
            interval = stats.norm.interval(confidence, loc=prediction, scale=std)
            # Keep the interval positive and sensible
            lower = max(1000, interval[0])  # floor of 1,000 yuan
            upper = max(prediction * 1.2, interval[1])  # at least 20% above the prediction
            logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
            return [lower, upper]
        except Exception as e:
            logging.error(f"Error calculating confidence interval: {str(e)}")
            # Fall back to a simple ±20% interval
            lower = max(1000, prediction * 0.8)
            upper = prediction * 1.2
            return [lower, upper]

    def evaluate(self, y_true, y_pred):
        """
        Evaluate the model
        """
        return {
            'mae': float(mean_absolute_error(y_true, y_pred)),
            'mse': float(mean_squared_error(y_true, y_pred)),
            'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
            'r2': float(r2_score(y_true, y_pred))
        }

src/create_template.py Normal file

@ -0,0 +1,152 @@
import pandas as pd
import openpyxl
from openpyxl.styles import PatternFill, Font, Alignment
from openpyxl.worksheet.datavalidation import DataValidation
import os

def create_excel_template():
    """
    Create the data template
    """
    try:
        # Make sure the data directory exists
        data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        os.makedirs(data_dir, exist_ok=True)
        # Build the full file path
        template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')
        # Create the Excel writer
        writer = pd.ExcelWriter(template_path, engine='openpyxl')
        # Basic parameter sheet for rocket artillery
        rocket_artillery_columns = [
            '名称', '类型', '制造商', '口径_mm',
            '发射管数量', '乘员数', '总长_m',
            '宽度_m', '高度_m', '重量_kg',
            '战斗重_kg', '速度_km/h', '最大射程_km',
            '最小射程_km', '方向射界_度', '高低射界_度',
            '火箭弹长度_m', '火箭弹重量_kg',
            '火箭弹最大速度_m/s', '射速_发',
            '战斗部重量_kg', '行走方式',
            '结构布局', '发动机型号', '发动机参数',
            '功率_hp', '行程_km', '成本_元'
        ]
        # Basic parameter sheet for loitering munitions
        loitering_munition_columns = [
            '名称', '类型', '制造商', '目标类型',
            '弹长_m', '弹径_mm', '翼展_m',
            '重量_kg', '战斗部重量_kg',
            '最大射程_km', '最大速度_m/s',
            '巡航速度_kmh', '巡飞时间_min',
            '战斗部类型', '发射方式',
            '折叠长度_mm', '折叠宽度_mm',
            '折叠高度_mm', '动力装置',
            '制导体制', '成本_元'
        ]
        # Special parameters sheet
        special_params_columns = [
            '装备名称',  # join key
            '参数名称',
            '参数值',
            '参数单位',
            '参数说明'
        ]
        # Create the sheets
        pd.DataFrame(columns=rocket_artillery_columns).to_excel(
            writer, sheet_name='火箭炮', index=False
        )
        pd.DataFrame(columns=loitering_munition_columns).to_excel(
            writer, sheet_name='巡飞弹', index=False
        )
        pd.DataFrame(columns=special_params_columns).to_excel(
            writer, sheet_name='特殊参数', index=False
        )
        # Get the workbook
        workbook = writer.book
        # Format the rocket artillery sheet
        rocket_sheet = workbook['火箭炮']
        for col in range(1, len(rocket_artillery_columns) + 1):
            cell = rocket_sheet.cell(row=1, column=col)
            cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
            cell.font = Font(bold=True)
            cell.alignment = Alignment(horizontal='center')
        # Format the loitering munition sheet
        missile_sheet = workbook['巡飞弹']
        for col in range(1, len(loitering_munition_columns) + 1):
            cell = missile_sheet.cell(row=1, column=col)
            cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
            cell.font = Font(bold=True)
            cell.alignment = Alignment(horizontal='center')
        # Format the special parameters sheet
        special_sheet = workbook['特殊参数']
        for col in range(1, len(special_params_columns) + 1):
            cell = special_sheet.cell(row=1, column=col)
            cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
            cell.font = Font(bold=True)
            cell.alignment = Alignment(horizontal='center')
        # Add data validation
        for sheet in [rocket_sheet, missile_sheet]:
            # Numeric validation
            number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
            number_validation.error = "请输入大于0的数值"
            number_validation.errorTitle = "输入错误"
            # Apply it to the relevant columns
            for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']:
                number_validation.add(f"{col}2:{col}1000")
            sheet.add_data_validation(number_validation)
        # Add instructions
        rocket_sheet['AD1'] = "填写说明:"
        rocket_sheet['AD2'] = "1. 所有数值必须大于0"
        rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
rocket_sheet['AD4'] = "3. 成本单位为元"
missile_sheet['V1'] = "填写说明:"
missile_sheet['V2'] = "1. 所有数值必须大于0"
missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
missile_sheet['V4'] = "3. 成本单位为元"
special_sheet['G1'] = "填写说明:"
special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
special_sheet['G3'] = "2. 参数值必须包含单位"
special_sheet['G4'] = "3. 参数说明应简明扼要"
# 调整列宽
for sheet in [rocket_sheet, missile_sheet, special_sheet]:
for col in sheet.columns:
max_length = 0
column = col[0].column_letter
for cell in col:
if cell.value is not None:
max_length = max(max_length, len(str(cell.value)))
adjusted_width = (max_length + 2)
sheet.column_dimensions[column].width = adjusted_width
# 保存文件
writer.close()
return template_path
except Exception as e:
raise Exception(f"创建模板文件失败: {str(e)}")
if __name__ == "__main__":
try:
template_path = create_excel_template()
print(f"模板文件已创建: {template_path}")
except Exception as e:
print(f"错误: {str(e)}")
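The column-width loop above can be factored into a small helper. This stand-alone sketch (the name is illustrative, not part of the script) shows the same heuristic: longest rendered cell text plus two characters of padding.

```python
def column_width(values, padding=2):
    """Width heuristic used by create_excel_template: longest cell text + padding."""
    longest = max((len(str(v)) for v in values if v is not None), default=0)
    return longest + padding

# header plus one data cell; None (empty) cells are ignored
width = column_width(['最大射程_km', 70, None])  # 7 chars + 2 padding = 9
```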

src/data_preparation.py Normal file
@@ -0,0 +1,199 @@
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import os
import joblib
import pandas as pd
import numpy as np
from src.feature_analysis import FeatureAnalysis
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.model_selection import cross_val_score, LeaveOneOut
import json
import logging
from src.database.db_connection import get_db_connection
from sklearn.metrics import mean_absolute_error, mean_squared_error
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
self.target_scaler = StandardScaler() # 添加目标值标准化器
def prepare_training_data(self, equipment_data, equipment_type):
"""
准备训练数据
"""
try:
logging.info(f"Preparing training data for {equipment_type}")
logging.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,直接返回
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logging.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
return {
'X': X,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
# 从原始数据中提取特征和目标值
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in equipment_data:
# 先提取目标值(成本),无效或非正值则整行跳过,保证特征与目标一一对应
try:
cost = float(item['actual_cost'])
except (ValueError, TypeError, KeyError):
logging.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
if cost <= 0:
logging.warning(f"Skipping non-positive cost value: {cost}")
continue
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
targets.append(cost)
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录原始数据范围
logging.info(f"Original X range: min={X.min()}, max={X.max()}")
logging.info(f"Original y range: min={y.min()}, max={y.max()}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 标准化特征和目标值
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 记录标准化后的数据范围
logging.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
logging.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
return {
'X': X_scaled,
'y': y_scaled,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
except Exception as e:
logging.error(f"Error in data preparation: {str(e)}")
raise Exception(f"Training error: {str(e)}")
def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
"""
准备验证数据
"""
try:
logging.info(f"Preparing validation data for {equipment_type}")
logging.info(f"Raw validation data size: {len(validation_data)}")
# 如果输入已经是 numpy 数组,直接使用
if isinstance(validation_data, np.ndarray):
X = validation_data
logging.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
else:
# 如果没有提供标准化器,直接返回处理后的数组
X_scaled = X
logging.info(f"Preprocessed data shape: {X_scaled.shape}")
logging.info(f"Validation features shape: {X_scaled.shape}")
logging.info(f"Validation features type: {X_scaled.dtype}")
return {
'X': X_scaled,
'y': None # 验证数据可能没有标签
}
# 否则,从原始数据中提取特征
if not feature_names:
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
# 提取特征和目标值
features = []
targets = []
for item in validation_data:
# 先提取目标值,无效则整行跳过,避免特征与目标数量不一致
try:
cost = float(item['actual_cost'])
except (ValueError, TypeError, KeyError):
logging.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)  # 使用0替代NaN
features.append(feature_values)
targets.append(cost)
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
else:
# 如果没有提供标准化器,直接返回处理后的数组
X_scaled = X
logging.info(f"Preprocessed data shape: {X_scaled.shape}")
logging.info(f"Validation features shape: {X_scaled.shape}")
logging.info(f"Validation features type: {X_scaled.dtype}")
return {
'X': X_scaled,
'y': y # 返回原始成本值
}
except Exception as e:
logging.error(f"Error in validation data preparation: {str(e)}")
logging.error(f"Feature names: {feature_names}")
logging.error(f"Equipment type: {equipment_type}")
raise Exception(f"Validation error: {str(e)}")
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
return self.feature_analyzer.calculate_derived_features(data, equipment_type)
except Exception as e:
logging.error(f"Error calculating derived features: {str(e)}")
raise Exception(f"Feature calculation error: {str(e)}")
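Target values are standardized with a second `StandardScaler`, so predictions must be mapped back with `inverse_transform`. A minimal NumPy-only stand-in (illustrative, not the sklearn class) makes the round trip explicit:

```python
import numpy as np

class MiniStandardScaler:
    """Per-column z-score, mirroring sklearn StandardScaler's interface."""
    def fit_transform(self, X):
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)
        return (X - self.mean_) / self.scale_
    def inverse_transform(self, X):
        return X * self.scale_ + self.mean_

y = np.array([[800000.0], [1000000.0], [1200000.0]])  # costs in yuan
scaler = MiniStandardScaler()
y_scaled = scaler.fit_transform(y)          # zero mean, unit variance
y_back = scaler.inverse_transform(y_scaled)  # recovers the original costs
```

Forgetting the `inverse_transform` step is why a model trained on `y_scaled` appears to predict costs near zero.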

src/database/__init__.py Normal file
@@ -0,0 +1 @@
from .db_connection import get_db_connection

src/database/db_connection.py Normal file
@@ -0,0 +1,28 @@
import mysql.connector
from mysql.connector import Error
import logging
from contextlib import contextmanager
import os
# 数据库配置(敏感信息优先从环境变量读取,避免硬编码密码)
DB_CONFIG = {
'host': os.environ.get('DB_HOST', 'localhost'),
'user': os.environ.get('DB_USER', 'root'),
'password': os.environ.get('DB_PASSWORD', '123456'),
'database': os.environ.get('DB_NAME', 'equipment_cost_db')
}
@contextmanager
def get_db_connection():
"""
数据库连接上下文管理器
"""
conn = None
try:
conn = mysql.connector.connect(**DB_CONFIG)
yield conn
except Error as e:
logging.error(f"Error connecting to MySQL: {str(e)}")
raise
finally:
if conn and conn.is_connected():
conn.close()
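The same `@contextmanager` pattern, with the MySQL specifics replaced by a dummy connection (purely illustrative), shows why the `finally` block matters: the connection is closed even when the body raises.

```python
from contextlib import contextmanager

class FakeConnection:
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

@contextmanager
def get_connection():
    conn = FakeConnection()
    try:
        yield conn
    finally:
        conn.close()  # runs on success and on exception alike

captured = []
try:
    with get_connection() as conn:
        captured.append(conn)
        raise RuntimeError("simulated query failure")
except RuntimeError:
    pass
# captured[0].closed is now True despite the exception
```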

src/feature_analysis.py Normal file
@@ -0,0 +1,266 @@
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
import logging
class FeatureAnalysis:
def __init__(self):
self.scaler = StandardScaler()
self.important_features = []
# 添加特征名称映射
self.feature_names_map = {
# 通用参数
'length_m': '总长(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
# 火箭炮衍生特征
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
# 巡飞弹特有参数
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
# 巡飞弹衍生特征
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比'  # volume_density 已在火箭炮衍生特征中定义,此处不再重复
}
def get_equipment_specific_features(self, equipment_type):
"""
获取特定装备类型的特征列表
"""
# 通用参数
common_features = [
'length_m', # 总长(m)
'width_m', # 宽度(m)
'height_m', # 高度(m)
'weight_kg', # 重量(kg)
'max_range_km' # 最大射程(km)
]
if equipment_type == '火箭炮':
# 火箭炮特有参数
specific_features = [
'firing_angle_horizontal', # 方向射界(度)
'firing_angle_vertical', # 高低射界(度)
'rocket_length_m', # 火箭弹长度(m)
'rocket_diameter_mm', # 口径(mm)
'rocket_weight_kg', # 火箭弹重量(kg)
'rate_of_fire', # 射速(发/分)
'combat_weight_kg', # 战斗重量(kg)
'speed_kmh', # 速度(km/h)
'min_range_km', # 最小射程(km)
'power_hp' # 功率(hp)
]
# 火箭炮衍生特征
derived_features = [
'fire_density', # 火力密度 = 射速 * 火箭弹重量
'mobility_index', # 机动性指标 = 速度 / 战斗重量
'range_ratio', # 射程比 = 最大射程 / 最小射程
'power_weight_ratio', # 功重比 = 功率 / 战斗重量
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
else: # 巡飞弹
# 巡飞弹特有参数
specific_features = [
'wingspan_m', # 翼展(m)
'warhead_weight_kg', # 战斗部重量(kg)
'max_speed_ms', # 最大速度(m/s)
'cruise_speed_kmh', # 巡航速度(km/h)
'flight_time_min', # 巡飞时间(min)
'folded_length_mm', # 折叠长度(mm)
'folded_width_mm', # 折叠宽度(mm)
'folded_height_mm' # 折叠高度(mm)
]
# 巡飞弹衍生特征
derived_features = [
'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
'speed_ratio', # 速度比 = 巡航速度 / 最大速度
'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
'aspect_ratio', # 展弦比 = 翼展^2 / 弹长(以弹长近似参考面积)
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
if equipment_type == '火箭炮':
# 火箭炮衍生特征计算
if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
else:
data['fire_density'] = 0 # 或者其他默认值
if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
else:
data['mobility_index'] = 0
if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
data['range_ratio'] = data['max_range_km'] / data['min_range_km']
else:
data['range_ratio'] = 0
if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
else:
data['power_weight_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
else: # 巡飞弹
# 巡飞弹衍生特征计算
if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
else:
data['warhead_ratio'] = 0
if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
else:
data['speed_ratio'] = 0
if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
else:
data['range_time_ratio'] = 0
if 'wingspan_m' in data.columns and 'length_m' in data.columns:
data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
else:
data['aspect_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
return data
except Exception as e:
logging.error(f"Error calculating derived features: {str(e)}")
raise
def analyze_features(self, features, target, feature_names):
"""
分析特征重要性和相关性
"""
try:
# 转换为numpy数组
X = np.array(features)
y = np.array(target)
# 数据标准化
X_scaled = self.scaler.fit_transform(X)
# 特征重要性分析
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_scaled, y)
importances = rf.feature_importances_
# 按重要性排序,使用中文特征名
importance_indices = np.argsort(importances)[::-1]
important_features = [
{
'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
'importance': float(importances[i])
}
for i in importance_indices
]
# 相关性分析
df = pd.DataFrame(X_scaled, columns=feature_names)
correlation_matrix = df.corr().values
# 生成相关性分析数据保留2位小数
correlation_data = []
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
for i in range(len(feature_names)):
for j in range(len(feature_names)):
correlation_data.append([
i, j,
round(correlation_matrix[i][j], 2) # 修改为保留2位小数
])
return {
'important_features': important_features,
'correlation_analysis': {
'features': chinese_feature_names, # 使用中文特征名
'matrix': correlation_data
}
}
except Exception as e:
print(f"Error in feature analysis: {str(e)}")
raise
def preprocess_features(self, equipment_data, equipment_type):
"""
预处理特征数据
"""
try:
# 转换为 DataFrame
df = pd.DataFrame(equipment_data)
# 计算衍生特征
df = self.calculate_derived_features(df, equipment_type)
# 处理缺失值
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
# 转换为数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
# 使用新的方式填充缺失值
mean_value = df[col].mean()
df[col] = df[col].fillna(mean_value)
logging.info(f"Preprocessed data shape: {df.shape}")
return df
except Exception as e:
logging.error(f"Error preprocessing features: {str(e)}")
raise Exception(f"Feature preprocessing error: {str(e)}")
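A quick check of two of the derived-feature formulas on a single hand-made row (the values are illustrative, roughly HIMARS-sized, not taken from the dataset):

```python
import pandas as pd

df = pd.DataFrame({
    'power_hp': [400.0],
    'combat_weight_kg': [16000.0],
    'max_range_km': [70.0],
    'min_range_km': [7.0],
})
# same formulas as calculate_derived_features for the 火箭炮 branch
df['power_weight_ratio'] = df['power_hp'] / df['combat_weight_kg']  # 0.025 hp/kg
df['range_ratio'] = df['max_range_km'] / df['min_range_km']         # 10.0
```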

src/import_data.py Normal file
@@ -0,0 +1,259 @@
import os
import pandas as pd
import logging
from src.database.db_connection import get_db_connection
# 确保日志目录存在,否则 FileHandler 会抛出 FileNotFoundError
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
filename='logs/import.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def import_training_data(excel_file):
"""
从Excel导入训练数据到数据库
"""
try:
# 读取所有sheet
rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
special_df = pd.read_excel(excel_file, sheet_name='特殊参数')
# 记录所有装备名称,用于后续检查
equipment_names = set()
with get_db_connection() as conn:
cursor = conn.cursor()
# 1. 先导入火箭炮数据
logging.info("开始导入火箭炮数据...")
for _, row in rocket_df.iterrows():
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '火箭炮'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logging.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None,
row['宽度_m'] if pd.notna(row['宽度_m']) else None,
row['高度_m'] if pd.notna(row['高度_m']) else None,
row['重量_kg'] if pd.notna(row['重量_kg']) else None,
row['最大射程_km'] if pd.notna(row['最大射程_km']) else None
))
# 插入火箭炮特有参数
cursor.execute("""
INSERT INTO rocket_artillery_params
(equipment_id, firing_angle_horizontal, firing_angle_vertical,
rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp,
travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
row['高低射界_度'] if pd.notna(row['高低射界_度']) else None,
row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None,
row['口径_mm'] if pd.notna(row['口径_mm']) else None,
row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None,
row['射速_发'] if pd.notna(row['射速_发']) else None,
row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None,
row['速度_km/h'] if pd.notna(row['速度_km/h']) else None,
row['最小射程_km'] if pd.notna(row['最小射程_km']) else None,
row['行走方式'] if pd.notna(row['行走方式']) else None,
row['结构布局'] if pd.notna(row['结构布局']) else None,
row['发动机型号'] if pd.notna(row['发动机型号']) else None,
row['发动机参数'] if pd.notna(row['发动机参数']) else None,
row['功率_hp'] if pd.notna(row['功率_hp']) else None,
row['行程_km'] if pd.notna(row['行程_km']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, row['成本_元']))
logging.info("火箭炮数据导入完成")
# 2. 导入巡飞弹数据
logging.info("开始导入巡飞弹数据...")
for index, row in missile_df.iterrows():
# 记录每行数据的空值情况
null_values = row[row.isna()].index.tolist()
if null_values:
logging.info(f"{index + 2} 中的空值字段: {null_values}")
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '巡飞弹'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logging.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (
row['名称'],
'巡飞弹',
row['制造商'] if pd.notna(row['制造商']) else None
))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['重量_kg']) if pd.notna(row['重量_kg']) else None,
float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None
))
# 插入巡飞弹特有参数
cursor.execute("""
INSERT INTO loitering_munition_params
(equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None,
float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None,
float(row['巡航速度_kmh']) if pd.notna(row['巡航速度_kmh']) else None, # 列名与模板 '巡航速度_kmh' 保持一致
float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None,
str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None,
str(row['发射方式']) if pd.notna(row['发射方式']) else None,
float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None,
float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None,
float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None,
str(row['动力装置']) if pd.notna(row['动力装置']) else None,
str(row['制导体制']) if pd.notna(row['制导体制']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, float(row['成本_元'])))
logging.info("巡飞弹数据导入完成")
# 提交之前的更改并关闭原有游标
cursor.close()
conn.commit()
# 3. 导入特殊参数
logging.info("开始导入特殊参数...")
for index, row in special_df.iterrows():
equipment_name = row['装备名称']
param_name = row['参数名称']
logging.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")
if equipment_name not in equipment_names:
logging.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
continue
# 获取装备ID - 使用新的游标
logging.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipment WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
if not result:
logging.warning(f"未找到装备: {equipment_name}")
continue
equipment_id = result[0]
logging.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标
logging.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor:
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
if exists:
logging.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
continue
# 插入新的参数 - 使用新的游标
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logging.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor:
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s)
""", (
equipment_id,
param_name,
param_value,
param_unit,
param_desc
))
logging.debug(f"成功插入参数记录")
# 最终提交
conn.commit()
logging.info("特殊参数导入完成")
logging.info("所有数据导入成功")
return True
except Exception as e:
logging.error(f"Error importing data: {str(e)}")
raise
if __name__ == "__main__":
try:
excel_file = 'data/equipment_data_20241108.xlsx'
import_training_data(excel_file)
logging.info("All data imported successfully")
except Exception as e:
logging.error(f"Import failed: {str(e)}")
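The `row[col] if pd.notna(row[col]) else None` idiom repeated throughout the importer converts pandas NaN cells into SQL NULLs. Isolated as a helper (the name is illustrative):

```python
import pandas as pd

def cell_or_null(value):
    """Map pandas NaN/NaT to None so the DB driver writes NULL."""
    return value if pd.notna(value) else None

row = pd.Series({'功率_hp': 400.0, '行程_km': float('nan')})
params = (cell_or_null(row['功率_hp']), cell_or_null(row['行程_km']))  # (400.0, None)
```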

src/init_data.sql Normal file
@@ -0,0 +1,277 @@
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-- 插入巡飞弹技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES (
1, -- 终结者巡飞弹
0.56,
0.15,
0.20,
2.72,
160.93,
96.56,
24,
15,
'破片杀伤战斗部',
'凭自身动力起飞',
560,
150,
200
);
-- 插入火箭炮技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES (
2, -- 胜利-2火箭炮
10,
2.5,
3.34,
15000,
23
);
-- 插入成本数据(示例数据)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 1000000), -- 终结者巡飞弹成本
(2, 5000000); -- 胜利-2火箭炮成本
-- 插入更多巡飞弹变体数据用于训练
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-- 插入变体技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 终结者-A稍大型号
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
-- 终结者-B稍小型号
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
-- 终结者-C标准型号的改进版
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-- 插入变体成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(3, 1100000), -- 终结者-A成本较高
(4, 900000), -- 终结者-B成本较低
(5, 1050000); -- 终结者-C成本中等
-- 添加更多巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-- 添加它们的技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 哈比
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
-- 游荡者
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
-- 凤凰
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
-- 弹簧刀
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
-- 彩虹-4
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-- 添加成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(6, 800000), -- 哈比
(7, 500000), -- 游荡者
(8, 450000), -- 凤凰
(9, 480000), -- 弹簧刀
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('BM-21', '火箭炮', '俄罗斯', '面目标'),
('SR5', '火箭炮', '中国', '面目标和点目标'),
('HIMARS', '火箭炮', '美国', '战术目标'),
('LAR-160', '火箭炮', '以色列', '面目标'),
('T-122', '火箭炮', '土耳其', '面目标'),
('RM-70', '火箭炮', '捷克', '面目标'),
('ASTROS II', '火箭炮', '巴西', '面目标和点目标');
-- 插入火箭炮参数equipment.id 由 AUTO_INCREMENT 连续分配,上文已占用 1-10
-- 此处 BM-21 至 ASTROS II 对应 id 11-17
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- BM-21
(11, 7.35, 2.4, 3.1, 13700, 20.4),
-- SR5
(12, 10.2, 2.8, 3.2, 28500, 70),
-- HIMARS
(13, 7.0, 2.4, 3.2, 16250, 70),
-- LAR-160
(14, 6.7, 2.5, 2.8, 15000, 45),
-- T-122
(15, 7.2, 2.5, 2.9, 18000, 40),
-- RM-70
(16, 7.5, 2.5, 3.0, 17200, 20.3),
-- ASTROS II
(17, 8.0, 2.7, 3.1, 24500, 90);
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
firing_angle_vertical,
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire
) VALUES
-- BM-21
(11, 102, 55, 2.87, 122, 66.6, 40),
-- SR5
(12, 110, 60, 4.1, 220, 150, 60),
-- HIMARS
(13, 90, 65, 3.94, 227, 301, 6),
-- LAR-160
(14, 100, 58, 3.3, 160, 110, 18),
-- T-122
(15, 110, 65, 2.95, 122, 65.5, 40),
-- RM-70
(16, 100, 50, 2.87, 122, 66.6, 40),
-- ASTROS II
(17, 90, 65, 4.3, 300, 550, 30);
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(11, 800000),   -- BM-21
(12, 4500000),  -- SR5
(13, 5500000),  -- HIMARS
(14, 3500000),  -- LAR-160
(15, 2800000),  -- T-122
(16, 1500000),  -- RM-70
(17, 4800000);  -- ASTROS II
-- 巡飞弹数据Hero-120 至 WS-43 对应 id 18-24
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('Hero-120', '巡飞弹', '以色列', '装甲目标'),
('Switchblade 600', '巡飞弹', '美国', '装甲车辆'),
('Warmate', '巡飞弹', '波兰', '轻型装甲目标'),
('CH-901', '巡飞弹', '中国', '人员和轻型装甲车辆'),
('HAROP', '巡飞弹', '以色列', '防空系统'),
('Coyote', '巡飞弹', '美国', '无人机和巡飞弹'),
('WS-43', '巡飞弹', '中国', '固定目标和装甲目标');
-- 插入巡飞弹参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- Hero-120
(18, 1.3, 0.23, 0.23, 12.5, 40),
-- Switchblade 600
(19, 1.3, 0.22, 0.22, 15.0, 40),
-- Warmate
(20, 1.1, 0.15, 0.15, 5.7, 15),
-- CH-901
(21, 1.2, 0.18, 0.18, 9.0, 20),
-- HAROP
(22, 2.5, 0.43, 0.43, 135, 1000),
-- Coyote
(23, 0.9, 0.12, 0.12, 5.9, 20),
-- WS-43
(24, 1.8, 0.35, 0.35, 20, 60);
INSERT INTO loitering_munition_params (
equipment_id,
max_speed_kmh,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- Hero-120
(18, 180, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230),
-- Switchblade 600
(19, 185, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220),
-- Warmate
(20, 150, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150),
-- CH-901
(21, 160, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180),
-- HAROP
(22, 185, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430),
-- Coyote
(23, 150, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120),
-- WS-43
(24, 170, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350);
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(18, 150000),   -- Hero-120
(19, 180000),   -- Switchblade 600
(20, 80000),    -- Warmate
(21, 100000),   -- CH-901
(22, 850000),   -- HAROP
(23, 75000),    -- Coyote
(24, 120000);   -- WS-43

src/model_trainer.py Normal file
@@ -0,0 +1,398 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
import xgboost as xgb
import lightgbm as lgb
import logging
import joblib
import os
from src.feature_analysis import FeatureAnalysis
from datetime import datetime
import json
from src.database import get_db_connection
from src.data_preparation import DataPreparation
class ModelTrainer:
def __init__(self):
self.models = {
'xgboost': self._create_xgboost_model(),
'lightgbm': self._create_lightgbm_model(),
'gbdt': self._create_gbdt_model(),
'rf': self._create_rf_model()
}
self.best_model = None
self.imputer = SimpleImputer(strategy='mean')
self.feature_scaler = None
self.target_scaler = None
def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
"""
训练模型并返回评估结果
"""
try:
# 记录数据范围
logging.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
logging.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
results = {}
best_score = -float('inf')
best_model_info = None
for model_name in model_names:
if model_name not in self.models:
logging.warning(f"Unknown model: {model_name}")
continue
logging.info(f"Training {model_name}...")
model = self.models[model_name]
# 训练模型
model.fit(X_train, y_train)
# 计算评估指标
metrics = self._calculate_metrics(
model,
X_train, y_train,
X_val, y_val
)
# 更新最佳模型
if metrics['validation']['r2'] > best_score:
best_score = metrics['validation']['r2']
self.best_model = model
best_model_info = {
'type': model_name,
'r2': float(metrics['validation']['r2']),
'mae': float(metrics['validation']['mae']) if metrics['validation']['mae'] is not None else None,
'rmse': float(metrics['validation']['rmse']) if metrics['validation']['rmse'] is not None else None
}
# 转换 numpy 数据类型为 Python 原生类型
results[model_name] = {
'train': {
'r2': float(metrics['train']['r2']),
'mae': float(metrics['train']['mae']),
'rmse': float(metrics['train']['rmse'])
},
'validation': {
'r2': float(metrics['validation']['r2']),
'mae': float(metrics['validation']['mae']) if metrics['validation']['mae'] is not None else None,
'rmse': float(metrics['validation']['rmse']) if metrics['validation']['rmse'] is not None else None
}
}
# 保存最佳模型
if equipment_type and best_model_info:
self._save_best_model(equipment_type, best_model_info, X_train)
# 转换特征重要性为列表
feature_importance = None
if self.best_model and hasattr(self.best_model, 'feature_importances_'):
feature_importance = [float(x) for x in self.best_model.feature_importances_]
return {
'metrics': results,
'best_model': best_model_info,
'feature_importance': feature_importance
}
except Exception as e:
logging.error(f"Error in model training: {str(e)}")
raise
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
"""
计算模型评估指标
"""
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
# Training-set metrics
train_pred = model.predict(X_train)
train_metrics = {
'r2': r2_score(y_train, train_pred),
'mae': mean_absolute_error(y_train, train_pred),
'rmse': np.sqrt(mean_squared_error(y_train, train_pred))
}
# Validation-set metrics
if X_val is not None and y_val is not None:
val_pred = model.predict(X_val)
val_metrics = {
'r2': r2_score(y_val, val_pred),
'mae': mean_absolute_error(y_val, val_pred),
'rmse': np.sqrt(mean_squared_error(y_val, val_pred))
}
else:
# No validation set: fall back to 5-fold cross-validation (R² scoring)
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
val_metrics = {
'r2': cv_scores.mean(),
'mae': None,
'rmse': None
}
return {
'train': train_metrics,
'validation': val_metrics
}
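For reference, the same metric logic can be exercised standalone on synthetic data, with `LinearRegression` purely as a stand-in model:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(42)
X = rng.normal(size=(30, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=30)

model = LinearRegression().fit(X, y)
pred = model.predict(X)
train_metrics = {
    'r2': r2_score(y, pred),
    'mae': mean_absolute_error(y, pred),
    'rmse': np.sqrt(mean_squared_error(y, pred)),
}
# With no validation set, fall back to 5-fold CV; the default scoring
# for regressors is R², matching the method above
cv_r2 = cross_val_score(model, X, y, cv=5).mean()
print(train_metrics, cv_r2)
```

Note the CV fallback reports only R² (MAE/RMSE stay `None`), which is why the callers guard those fields with `is not None` checks.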
def _create_xgboost_model(self):
"""
Create an XGBoost model with stronger regularization
"""
return xgb.XGBRegressor(
n_estimators=50, # fewer trees
learning_rate=0.05, # smaller learning rate
max_depth=3, # shallower trees
min_child_weight=3, # larger minimum child weight
subsample=0.7, # smaller row subsample ratio
colsample_bytree=0.7, # smaller feature subsample ratio
reg_alpha=0.1, # L1 regularization
reg_lambda=1, # L2 regularization
random_state=42
)
def _create_lightgbm_model(self):
"""
创建 LightGBM 模型增强正则化
"""
return lgb.LGBMRegressor(
n_estimators=50,
learning_rate=0.05,
max_depth=3,
num_leaves=7,
min_data_in_leaf=3,
min_sum_hessian_in_leaf=1e-3,
subsample=0.7,
colsample_bytree=0.7,
reg_alpha=0.1,
reg_lambda=1,
random_state=42,
verbose=-1
)
def _create_gbdt_model(self):
"""
Create a GBDT model with stronger regularization to reduce overfitting
"""
return GradientBoostingRegressor(
n_estimators=20, # fewer trees
learning_rate=0.01, # smaller learning rate
max_depth=2, # shallower trees
min_samples_split=4, # more samples required to split
min_samples_leaf=3, # larger minimum leaf size
subsample=0.5, # smaller row subsample ratio
random_state=42,
validation_fraction=0.2 # hold out part of the training data for validation
)
def _create_rf_model(self):
"""
Create a random forest model with parameters tuned for small samples
"""
return RandomForestRegressor(
n_estimators=100, # more trees
max_depth=4, # limit tree depth
min_samples_split=2, # smaller minimum samples required to split
min_samples_leaf=1, # smaller minimum leaf size
max_features='sqrt', # feature subsampling
bootstrap=True, # use bootstrap sampling
oob_score=True, # compute the out-of-bag score
random_state=42
)
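The conservative hyperparameters in these factory methods are meant to curb overfitting on small samples. A quick standalone check of that effect on synthetic data: the regularized settings cannot memorize the training set the way the defaults do, so the gap between training R² and cross-validated R² should typically shrink.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(25, 5))   # small sample, like the equipment data
y = X[:, 0] * 3 + rng.normal(scale=1.0, size=25)

def train_vs_cv(model):
    model.fit(X, y)
    return model.score(X, y), cross_val_score(model, X, y, cv=5).mean()

default_train, default_cv = train_vs_cv(GradientBoostingRegressor(random_state=42))
reg_train, reg_cv = train_vs_cv(GradientBoostingRegressor(
    n_estimators=20, learning_rate=0.01, max_depth=2,
    min_samples_split=4, min_samples_leaf=3, subsample=0.5,
    random_state=42))

# Default settings nearly memorize the 25 points (train R² close to 1)
# while generalizing worse; the regularized settings trade training fit
# for a smaller train/validation gap.
print(default_train - default_cv, reg_train - reg_cv)
```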
def _save_best_model(self, equipment_type, best_model_info, X_train):
"""
保存最佳模型
"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# Save the model file
model_path = f'{model_dir}/{equipment_type}_{timestamp}'
if isinstance(self.best_model, xgb.XGBRegressor):
self.best_model.save_model(f'{model_path}.json')
model_format = 'json'
else:
joblib.dump(self.best_model, f'{model_path}.joblib')
model_format = 'joblib'
# Validate the scalers
if not isinstance(self.feature_scaler, StandardScaler):
raise ValueError("Invalid feature scaler")
if not isinstance(self.target_scaler, StandardScaler):
raise ValueError("Invalid target scaler")
# Save the scalers
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
joblib.dump({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
logging.info(f"Saved model to {model_path}.{model_format}")
logging.info(f"Saved scalers to {scaler_path}")
# Update the model records in the database
with get_db_connection() as conn:
cursor = conn.cursor()
# Deactivate the previously active model
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s
""", (equipment_type,))
# Insert the new model record
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path,
scaler_path, r2_score, mae, rmse, feature_importance,
training_data_size, training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, 'system')
""", (
f'{best_model_info["type"]}_{timestamp}',
best_model_info["type"],
equipment_type,
f'{model_path}.{model_format}',
scaler_path,
best_model_info["r2"],
best_model_info["mae"],
best_model_info["rmse"],
json.dumps(self.feature_importance) if hasattr(self, 'feature_importance') else None,
len(X_train)
))
conn.commit()
logging.info(f"Best model saved: {model_path}")
return True
except Exception as e:
logging.error(f"Error saving best model: {str(e)}")
return False
def load_model(self, equipment_type):
"""
加载已训练的模型
"""
try:
logging.info(f"Loading model for {equipment_type}")
# Fetch the latest active model from the database
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
model_record = cursor.fetchone()
if not model_record:
raise ValueError(f"No active model found for {equipment_type}")
logging.info(f"Found model: {model_record['model_name']}")
logging.info(f"Model path: {model_record['model_path']}")
logging.info(f"Scaler path: {model_record['scaler_path']}")
# Check that the model and scaler files exist
if not os.path.exists(model_record['model_path']):
raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
if not os.path.exists(model_record['scaler_path']):
raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
# Load the model file
if model_record['model_type'] == 'xgboost':
self.best_model = xgb.XGBRegressor()
self.best_model.load_model(model_record['model_path'])
else:
self.best_model = joblib.load(model_record['model_path'])
# Load the scalers
try:
scalers = joblib.load(model_record['scaler_path'])
logging.info(f"Loaded scalers: {scalers.keys()}")
if 'feature_scaler' not in scalers or 'target_scaler' not in scalers:
raise ValueError("Missing scalers in saved file")
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
# Validate the scalers
if not hasattr(self.feature_scaler, 'transform') or not hasattr(self.target_scaler, 'transform'):
raise ValueError("Invalid scaler objects")
logging.info("Model and scalers loaded successfully")
logging.info(f"Feature scaler type: {type(self.feature_scaler)}")
logging.info(f"Target scaler type: {type(self.target_scaler)}")
except Exception as e:
logging.error(f"Error loading scalers: {str(e)}")
logging.error(f"Scaler file content: {scalers if 'scalers' in locals() else 'Not loaded'}")
raise ValueError(f"Failed to load scalers: {str(e)}")
return True
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
logging.error("Detailed traceback:", exc_info=True)
return False
def predict(self, features):
"""
使用加载的模型进行预测
"""
try:
if not self.best_model:
raise ValueError("No model loaded")
if not self.feature_scaler:
raise ValueError("Feature scaler not loaded")
if not self.target_scaler:
raise ValueError("Target scaler not loaded")
logging.info("Starting prediction")
logging.info(f"Input features shape: {features.shape}")
logging.info(f"Input features: \n{features}")
# Fill missing values with zero
features_filled = np.nan_to_num(np.array(features, dtype=float), nan=0.0)
logging.info(f"Filled features: \n{features_filled}")
# Standardize the features
X = self.feature_scaler.transform(features_filled)
logging.info(f"Transformed features shape: {X.shape}")
logging.info(f"Transformed features: \n{X}")
# Predict
y_pred_scaled = self.best_model.predict(X)
logging.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
logging.info(f"Scaled prediction: {y_pred_scaled}")
# Invert the target scaling
y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
logging.info(f"Final prediction shape: {y_pred.shape}")
logging.info(f"Final prediction: {y_pred}")
# Log the target scaler parameters
logging.info("Target scaler params:")
logging.info(f"Mean: {self.target_scaler.mean_}")
logging.info(f"Scale: {self.target_scaler.scale_}")
return y_pred.ravel()
except Exception as e:
logging.error(f"Error in prediction: {str(e)}")
raise
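The prediction path above (fill NaNs, scale the features, predict, then inverse-transform the target) can be sketched end to end as a standalone snippet; the numbers are synthetic and `RandomForestRegressor` stands in for whichever model was selected as best:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler

# Tiny synthetic training set: [length_m, weight_kg, max_range_km] -> cost
X_train = np.array([[0.56, 2.72, 24], [0.58, 2.85, 26], [0.54, 2.60, 22]])
y_train = np.array([1_000_000.0, 1_100_000.0, 900_000.0])

feature_scaler = StandardScaler().fit(X_train)
target_scaler = StandardScaler().fit(y_train.reshape(-1, 1))

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(feature_scaler.transform(X_train),
          target_scaler.transform(y_train.reshape(-1, 1)).ravel())

# Mirrors the predict method above: fill NaNs, scale, predict, unscale
features = np.array([[0.57, np.nan, 25]])
features = np.nan_to_num(features, nan=0.0)
y_scaled = model.predict(feature_scaler.transform(features))
y_pred = target_scaler.inverse_transform(y_scaled.reshape(-1, 1)).ravel()
print(y_pred)
```

Because the model is trained on z-scored targets, its raw output must pass back through `target_scaler.inverse_transform` before it is a cost in currency units, which is exactly what the method does.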

313
src/pls_regression.py Normal file

@ -0,0 +1,313 @@
# -*- coding: utf-8 -*-
from sklearn.cross_decomposition import PLSRegression
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import logging
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import LeaveOneOut
import os
from datetime import datetime
import joblib
from src.database.db_connection import get_db_connection
class PLSPredictor:
def __init__(self, n_components=2):
"""
初始化PLS回归模型
"""
self.model = PLSRegression(
n_components=n_components,
scale=True,
max_iter=500,
tol=1e-6
)
self.scaler_X = StandardScaler()
self.scaler_y = StandardScaler()
self.feature_names = None
self.model_path = None
# Try to load a previously trained model; only fit the scalers on
# example data when no saved model was restored (fitting them
# unconditionally would overwrite the scalers loaded with the model)
if not self.load_model():
self._initialize_scalers()
def _initialize_scalers(self):
"""
使用示例数据初始化标准化器
"""
# 创建示例数据
example_data = pd.DataFrame({
'length_m': [0.56, 0.58, 0.54],
'width_m': [0.15, 0.16, 0.14],
'height_m': [0.20, 0.21, 0.19],
'weight_kg': [2.72, 2.85, 2.60],
'max_range_km': [24, 26, 22],
'max_speed_kmh': [160.93, 170, 155],
'cruise_speed_kmh': [96.56, 100, 93],
'flight_time_min': [15, 16, 14],
'folded_length_mm': [560, 580, 540],
'folded_width_mm': [150, 160, 140],
'folded_height_mm': [200, 210, 190]
})
# Fit the feature scaler
self.scaler_X.fit(example_data)
# Fit the target scaler
example_costs = np.array([[1000000], [1100000], [900000]])
self.scaler_y.fit(example_costs)
def predict(self, features):
"""
使用PLS模型进行预测
"""
try:
# Convert the input to a DataFrame
if not isinstance(features, pd.DataFrame):
features = pd.DataFrame([features])
# Select numeric features
numeric_features = features.select_dtypes(include=[np.number]).columns
X = features[numeric_features]
# Standardize the features
X_scaled = self.scaler_X.transform(X)
# Predict
y_pred_scaled = self.model.predict(X_scaled)
y_pred = self.scaler_y.inverse_transform(y_pred_scaled)
# Compute the confidence interval
ci = self._calculate_confidence_intervals(y_pred)
return {
'predicted_cost': float(abs(y_pred[0][0])),
'confidence_interval': {
'lower': float(abs(ci['lower'])),
'upper': float(abs(ci['upper']))
}
}
except Exception as e:
logging.error(f"Error in PLS prediction: {str(e)}")
raise Exception(f"PLS prediction error: {str(e)}")
def fit(self, X, y):
"""
训练PLS模型
"""
try:
logging.info("=== PLS Training Start ===")
# 1. Inspect the input data
logging.info(f"Input X type: {type(X)}, shape: {X.shape if hasattr(X, 'shape') else 'no shape'}")
logging.info(f"Input y type: {type(y)}, shape: {y.shape if hasattr(y, 'shape') else 'no shape'}")
logging.info(f"X data:\n{X}")
logging.info(f"y data:\n{y}")
# 2. Convert to numpy arrays
if isinstance(X, pd.DataFrame):
# Save the feature names
self.feature_names = X.columns.tolist()
X = X.values
X = np.array(X, dtype=float)
y = np.array(y, dtype=float)
# 3. Standardize the data
logging.info("Standardizing data...")
X_scaled = self.scaler_X.fit_transform(X)
y_scaled = self.scaler_y.fit_transform(y.reshape(-1, 1))
logging.info(f"X_scaled shape: {X_scaled.shape}")
logging.info(f"y_scaled shape: {y_scaled.shape}")
# 4. Train the model
logging.info("Training PLS model...")
self.model.fit(X_scaled, y_scaled.ravel())
logging.info("PLS model training completed")
# 5. Compute the R² score
logging.info("Calculating R² score...")
y_pred = self.model.predict(X_scaled)
y_pred = self.scaler_y.inverse_transform(y_pred.reshape(-1, 1))
r2 = r2_score(y.reshape(-1, 1), y_pred)
self.r2_score_ = float(r2) # kept for save_model's database record
logging.info(f"R² score: {r2}")
result = {
'r2_score': float(r2),
'n_components': int(self.model.n_components),
'feature_importance': None
}
logging.info(f"Final result: {result}")
logging.info("=== PLS Training End ===")
# Save the trained model
equipment_type = 'missile' # TODO: take this from a parameter
self.save_model(equipment_type)
return result
except Exception as e:
logging.error(f"Error in PLS training: {str(e)}")
logging.error(f"Error traceback:", exc_info=True)
raise Exception(f"PLS training error: {str(e)}")
def _calculate_confidence_intervals(self, predictions, confidence=0.95):
"""
计算预测值的置信区间
"""
try:
# Estimate the interval by noise resampling (bootstrap-style)
n_predictions = 1000
bootstrap_predictions = []
for _ in range(n_predictions):
# Add random noise
noise = np.random.normal(0, predictions.mean() * 0.05, predictions.shape)
noisy_pred = predictions + noise
bootstrap_predictions.append(noisy_pred)
bootstrap_predictions = np.array(bootstrap_predictions).flatten()
# Take the percentile bounds
lower = np.percentile(bootstrap_predictions, ((1 - confidence) / 2) * 100)
upper = np.percentile(bootstrap_predictions, (1 - (1 - confidence) / 2) * 100)
return {
'lower': float(lower),
'upper': float(upper)
}
except Exception as e:
logging.error(f"Error calculating confidence intervals: {str(e)}")
# On failure, fall back to a ±10% interval around the mean
mean_pred = np.mean(predictions)
return {
'lower': float(mean_pred * 0.9),
'upper': float(mean_pred * 1.1)
}
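Note that the interval produced above is a noise-resampling approximation rather than a residual bootstrap: it perturbs the point prediction with Gaussian noise whose standard deviation is 5% of the mean prediction, then reads off percentiles. The core of that calculation, standalone:

```python
import numpy as np

rng = np.random.default_rng(42)
prediction = np.array([1_000_000.0])

# 1000 noisy copies, noise std = 5% of the mean prediction
samples = prediction + rng.normal(0, prediction.mean() * 0.05,
                                  size=(1000, prediction.size))
samples = samples.flatten()

# 95% interval: take the 2.5th and 97.5th percentiles
lower = np.percentile(samples, 2.5)
upper = np.percentile(samples, 97.5)
print(lower, upper)
```

For N(1,000,000, 50,000) noise this lands near 1,000,000 ± 1.96 × 50,000, i.e. roughly [902k, 1,098k]; the width is driven entirely by the assumed 5% noise level, not by the model's actual error distribution.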
def _get_feature_importance(self):
"""
计算特征重要性
"""
try:
if not hasattr(self.model, 'x_weights_'):
return {}
# Compute VIP (Variable Importance in Projection) scores
t = self.model.x_scores_
w = self.model.x_weights_
q = self.model.y_loadings_
# VIP score per feature; w has shape (n_features, n_components)
p, h = w.shape
vips = np.zeros((p,))
s = np.diag(t.T @ t @ q.T @ q)
total_s = np.sum(s)
for i in range(p):
weight = np.array([(w[i, j] / np.linalg.norm(w[:, j])) ** 2 for j in range(h)])
vips[i] = np.sqrt(p * (s @ weight) / total_s)
# Build the feature-importance dict
feature_importance = {}
for i, score in enumerate(vips):
feature_name = f"feature_{i}" if self.feature_names is None else self.feature_names[i]
feature_importance[feature_name] = float(score)
# Sort by importance (descending)
return dict(sorted(feature_importance.items(), key=lambda x: x[1], reverse=True))
except Exception as e:
logging.error(f"Error calculating feature importance: {str(e)}")
return {}
def save_model(self, equipment_type):
"""
保存模型和标准化器
"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# Save the model file
model_path = f'{model_dir}/pls_{equipment_type}_{timestamp}'
joblib.dump({
'model': self.model,
'scaler_X': self.scaler_X,
'scaler_y': self.scaler_y,
'feature_names': self.feature_names
}, f'{model_path}.joblib')
# Update the model records in the database
with get_db_connection() as conn:
cursor = conn.cursor()
# Deactivate the previously active model
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type = 'pls'
""", (equipment_type,))
# Insert the new model record
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path,
r2_score, training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, 'system')
""", (
f'PLS_{timestamp}',
'pls',
equipment_type,
f'{model_path}.joblib',
getattr(self, 'r2_score_', None) # set by fit(); stored as NULL if unavailable
))
conn.commit()
self.model_path = f'{model_path}.joblib'
logging.info(f"Model saved to {self.model_path}")
except Exception as e:
logging.error(f"Error saving model: {str(e)}")
raise Exception(f"Failed to save model: {str(e)}")
def load_model(self):
"""
加载最新的激活模型
"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# Fetch the latest active model
cursor.execute("""
SELECT * FROM trained_models
WHERE model_type = 'pls' AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""")
model_record = cursor.fetchone()
if model_record and os.path.exists(model_record['model_path']):
# Load the model file
saved_data = joblib.load(model_record['model_path'])
self.model = saved_data['model']
self.scaler_X = saved_data['scaler_X']
self.scaler_y = saved_data['scaler_y']
self.feature_names = saved_data['feature_names']
self.model_path = model_record['model_path']
logging.info(f"Loaded model from {self.model_path}")
return True
return False
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
return False

9
src/real_data.sql Normal file

@ -0,0 +1,9 @@
-- Rocket artillery data (13 systems)
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('BM-21', '火箭炮', '俄罗斯', '面目标'),
-- ... the other 12 rocket artillery systems
-- Loitering munition data (18 systems)
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('Hero-120', '巡飞弹', '以色列', '装甲目标'),
-- ... the other 17 loitering munition systems

1298
src/routes.py Normal file

File diff suppressed because it is too large

28
src/run.py Normal file

@ -0,0 +1,28 @@
import os
import logging
from src.app import app
# Ensure the required directories exist
os.makedirs('logs', exist_ok=True)
os.makedirs('models', exist_ok=True)
os.makedirs('data', exist_ok=True)
# Configure logging
logging.basicConfig(
filename='logs/server.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
if __name__ == "__main__":
try:
logging.info("Starting server...")
app.run(
host='localhost',
port=5001,
debug=True, # enable debug mode
use_reloader=True # enable auto-reload
)
except Exception as e:
logging.error(f"Server failed to start: {str(e)}")
raise

129
src/schema.sql Normal file

@ -0,0 +1,129 @@
-- Equipment base table
CREATE TABLE equipment (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- Name
type VARCHAR(50), -- Type (rocket artillery / loitering munition)
manufacturer VARCHAR(100), -- Manufacturer
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Common parameters table
CREATE TABLE common_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
length_m FLOAT, -- Overall length (m)
width_m FLOAT, -- Width (m)
height_m FLOAT, -- Height (m)
weight_kg FLOAT, -- Weight (kg)
max_range_km FLOAT, -- Maximum range (km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Rocket-artillery-specific parameters table
CREATE TABLE rocket_artillery_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
firing_angle_horizontal FLOAT, -- Horizontal firing arc (degrees)
firing_angle_vertical FLOAT, -- Vertical firing arc (degrees)
rocket_length_m FLOAT, -- Rocket length (m)
rocket_diameter_mm FLOAT, -- Rocket body diameter (mm)
rocket_weight_kg FLOAT, -- Rocket weight (kg)
rate_of_fire FLOAT, -- Rate of fire (rounds/min)
combat_weight_kg FLOAT, -- Combat weight (kg)
speed_kmh FLOAT, -- Speed (km/h)
min_range_km FLOAT, -- Minimum range (km)
mobility_type VARCHAR(50), -- Mobility type
structure_layout VARCHAR(100), -- Structural layout
engine_model VARCHAR(100), -- Engine model
engine_params TEXT, -- Engine parameters
power_hp FLOAT, -- Power (hp)
travel_range_km FLOAT, -- Travel range (km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Loitering-munition-specific parameters table
CREATE TABLE loitering_munition_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
wingspan_m FLOAT, -- Wingspan (m)
warhead_weight_kg FLOAT, -- Warhead weight (kg)
max_speed_ms FLOAT, -- Maximum speed (m/s)
cruise_speed_kmh FLOAT, -- Cruise speed (km/h)
flight_time_min FLOAT, -- Loiter time (min)
warhead_type VARCHAR(50), -- Warhead type
launch_mode VARCHAR(50), -- Launch mode
folded_length_mm FLOAT, -- Folded length (mm)
folded_width_mm FLOAT, -- Folded width (mm)
folded_height_mm FLOAT, -- Folded height (mm)
power_system VARCHAR(100), -- Power system
guidance_system VARCHAR(100), -- Guidance system
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Cost data table
CREATE TABLE cost_data (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
actual_cost DECIMAL(15,2), -- Actual cost (CNY)
predicted_cost DECIMAL(15,2), -- Predicted cost (CNY)
prediction_date TIMESTAMP, -- Prediction date
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Custom parameters table
CREATE TABLE custom_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
param_name VARCHAR(100), -- Parameter name
param_value VARCHAR(255), -- Parameter value
param_unit VARCHAR(50), -- Parameter unit
description TEXT, -- Parameter description
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Indexes
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- Datasets table
CREATE TABLE datasets (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL, -- Dataset name
description TEXT, -- Dataset description
equipment_type VARCHAR(50) NOT NULL, -- Equipment type
purpose VARCHAR(50) NOT NULL, -- Purpose (training / validation)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Dataset-equipment junction table
CREATE TABLE dataset_equipment (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Trained models table
CREATE TABLE trained_models (
id INT AUTO_INCREMENT PRIMARY KEY,
model_name VARCHAR(100) NOT NULL, -- Model name
model_type VARCHAR(50) NOT NULL, -- Model type
equipment_type VARCHAR(50) NOT NULL, -- Equipment type
model_path VARCHAR(255) NOT NULL, -- Model file path
scaler_path VARCHAR(255), -- Scaler file path (NULL when scalers are bundled with the model, as for PLS)
r2_score FLOAT, -- R² score
mae FLOAT, -- Mean absolute error
rmse FLOAT, -- Root mean squared error
feature_importance JSON, -- Feature importance
training_data_size INT, -- Training data size
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- Training time
is_active BOOLEAN DEFAULT FALSE, -- Whether this is the currently active model
created_by VARCHAR(50) -- Creator
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Indexes
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);
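The `trained_models` table is kept to one active model per equipment type by the deactivate-then-insert pattern used in both save paths above. A minimal sketch of that pattern, using SQLite purely for illustration (the production code uses MySQL via `mysql-connector`, where placeholders are `%s` rather than `?`):

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("""
    CREATE TABLE trained_models (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        model_name TEXT, equipment_type TEXT,
        is_active BOOLEAN DEFAULT 0
    )""")

def save_active_model(conn, name, equipment_type):
    # Deactivate the previous model, then insert the new one as active
    conn.execute("UPDATE trained_models SET is_active = 0 "
                 "WHERE equipment_type = ?", (equipment_type,))
    conn.execute("INSERT INTO trained_models "
                 "(model_name, equipment_type, is_active) VALUES (?, ?, 1)",
                 (name, equipment_type))
    conn.commit()

save_active_model(conn, 'xgboost_v1', '巡飞弹')
save_active_model(conn, 'rf_v2', '巡飞弹')

active = conn.execute("SELECT model_name FROM trained_models "
                      "WHERE equipment_type = ? AND is_active = 1",
                      ('巡飞弹',)).fetchall()
print(active)
```

Both statements run inside one transaction and are committed together, so a reader never observes zero or two active rows for a type.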

192
src/test_api.py Normal file

@ -0,0 +1,192 @@
import requests
import json
def test_api_endpoints():
"""
测试API各个端点
"""
base_url = 'http://localhost:5001/api'
# 1. Test the API root route
print("\n1. Testing the API root route")
response = requests.get(f'{base_url}/')
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 2. Test the machine-learning prediction endpoint
print("\n2. Testing the ML prediction endpoint")
predict_data = {
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict',
json=predict_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 3. Test the PLS prediction endpoint
print("\n3. Testing the PLS prediction endpoint")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 4. Test the feature-analysis endpoint
print("\n4. Testing the feature-analysis endpoint")
analysis_data = {
"data": [{
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200
}],
"cost": [1000000]
}
response = requests.post(
f'{base_url}/analyze-features',
json=analysis_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 5. Test the machine-learning training endpoint
print("\n5. Testing the ML training endpoint")
training_data = {
"training_data": [
{
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200,
"cost": 1000000
},
{
"type": "巡飞弹",
"length_m": 0.58,
"width_m": 0.16,
"height_m": 0.21,
"weight_kg": 2.85,
"max_range_km": 26,
"max_speed_kmh": 170,
"cruise_speed_kmh": 100,
"flight_time_min": 16,
"folded_length_mm": 580,
"folded_width_mm": 160,
"folded_height_mm": 210,
"cost": 1100000
},
{
"type": "巡飞弹",
"length_m": 0.54,
"width_m": 0.14,
"height_m": 0.19,
"weight_kg": 2.60,
"max_range_km": 22,
"max_speed_kmh": 155,
"cruise_speed_kmh": 93,
"flight_time_min": 14,
"folded_length_mm": 540,
"folded_width_mm": 140,
"folded_height_mm": 190,
"cost": 900000
}
],
"equipment_type": "巡飞弹"
}
response = requests.post(
f'{base_url}/train',
json=training_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 6. Test the PLS training endpoint
print("\n6. Testing the PLS training endpoint")
# Use realistic training data
training_data = {
"training_data": [
{
"length_m": 1.3, # 哈比
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_kmh": 180,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230
},
{
"length_m": 2.5, # HAROP
"width_m": 0.43,
"height_m": 0.43,
"weight_kg": 135,
"max_range_km": 1000,
"max_speed_kmh": 185,
"cruise_speed_kmh": 110,
"flight_time_min": 360,
"folded_length_mm": 2500,
"folded_width_mm": 430,
"folded_height_mm": 430
},
{
"length_m": 1.1, # Warmate
"width_m": 0.15,
"height_m": 0.15,
"weight_kg": 5.7,
"max_range_km": 15,
"max_speed_kmh": 150,
"cruise_speed_kmh": 90,
"flight_time_min": 30,
"folded_length_mm": 1100,
"folded_width_mm": 150,
"folded_height_mm": 150
}
],
"actual_costs": [150000, 850000, 80000] # 对应的实际成本
}
response = requests.post(
f'{base_url}/pls/train',
json=training_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
if __name__ == "__main__":
try:
test_api_endpoints()
except Exception as e:
print(f"测试过程中出现错误: {str(e)}")

18
vite.config.js Normal file

@ -0,0 +1,18 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import path from 'path'
// https://vitejs.dev/config/
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': path.resolve(__dirname, 'src'),
}
},
define: {
__VUE_OPTIONS_API__: true,
__VUE_PROD_DEVTOOLS__: false,
__VUE_PROD_HYDRATION_MISMATCH_DETAILS__: false
}
})