增加一些文件
This commit is contained in:
commit
865c93c811
82
.cursorrules
Normal file
82
.cursorrules
Normal file
@ -0,0 +1,82 @@
|
||||
# 开发流程
|
||||
|
||||
First ensure basic functionality works
|
||||
Implement core functionality using the simplest direct approach
|
||||
Ensure data flow is working correctly
|
||||
Verify results are accurate
|
||||
Then gradually add additional features
|
||||
Add error handling
|
||||
Add data validation
|
||||
Add format conversion
|
||||
Add logging
|
||||
Improve user experience
|
||||
Avoid premature optimization
|
||||
Don't do complex data validation at the start
|
||||
Don't worry about performance optimization early
|
||||
Don't over-engineer
|
||||
This development flow:
|
||||
Quickly validates if core functionality works
|
||||
Identifies and fixes fundamental issues early
|
||||
Avoids wasting time on unnecessary optimizations
|
||||
Makes code easier to maintain and debug
|
||||
These principles should guide all code responses, focusing on getting the basics working first before adding complexity.
|
||||
|
||||
# 代码修改最佳实践
|
||||
|
||||
1. 修改前的准备
|
||||
|
||||
- 检查相关文件和依赖关系
|
||||
- 确保命名一致性
|
||||
- 添加必要的日志记录
|
||||
- 准备回滚方案
|
||||
|
||||
2. 修改过程中
|
||||
|
||||
- 遵循统一的代码风格
|
||||
- 添加适当的错误处理
|
||||
- 保持代码的可读性
|
||||
- 避免重复代码
|
||||
|
||||
3. 修改后的验证
|
||||
|
||||
- 验证主要功能
|
||||
- 测试边界条件
|
||||
- 检查错误处理
|
||||
- 验证性能影响
|
||||
|
||||
4. 文档更新
|
||||
|
||||
- 更新相关文档
|
||||
- 添加注释说明
|
||||
- 记录重要修改
|
||||
- 更新调试信息
|
||||
|
||||
5. 代码审查要点
|
||||
|
||||
- 检查命名规范
|
||||
- 验证错误处理
|
||||
- 确认日志完整性
|
||||
- 评估代码质量
|
||||
|
||||
6. 调试建议
|
||||
|
||||
- 添加详细日志
|
||||
- 使用断点调试
|
||||
- 验证数据流
|
||||
- 检查状态变化
|
||||
|
||||
7. 性能考虑
|
||||
|
||||
- 避免过早优化
|
||||
- 关注关键路径
|
||||
- 合理使用缓存
|
||||
- 优化数据库查询
|
||||
|
||||
8. 安全性检查
|
||||
|
||||
- 验证输入数据
|
||||
- 处理异常情况
|
||||
- 保护敏感信息
|
||||
- 添加访问控制
|
||||
|
||||
These practices help maintain code quality and reduce potential issues.
|
||||
27
.gitignore
vendored
Normal file
27
.gitignore
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
.DS_Store
|
||||
node_modules
|
||||
/dist
|
||||
/models
|
||||
/logs
|
||||
/uploads
|
||||
/data
|
||||
|
||||
# local env files
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# Log files
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
|
||||
# Editor directories and files
|
||||
.idea
|
||||
.vscode
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
/frontend/node_modules
|
||||
23
README.md
Normal file
23
README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# 数据库配置说明
|
||||
|
||||
本系统使用 MySQL 8.0+ 作为数据库。在安装 MySQL 后,需要:
|
||||
|
||||
1. 创建数据库用户
|
||||
|
||||
```sql
|
||||
CREATE USER 'equipment_user'@'localhost' IDENTIFIED BY 'your_password';
|
||||
GRANT ALL PRIVILEGES ON equipment_cost_db.* TO 'equipment_user'@'localhost';
|
||||
FLUSH PRIVILEGES;
|
||||
```
|
||||
|
||||
2. 配置数据库字符集
|
||||
确保 MySQL 配置文件(my.cnf 或 my.ini)包含以下设置:
|
||||
|
||||
```ini
|
||||
[mysqld]
|
||||
character-set-server=utf8mb4
|
||||
collation-server=utf8mb4_unicode_ci
|
||||
|
||||
[client]
|
||||
default-character-set=utf8mb4
|
||||
```
|
||||
8
app.py
Normal file
8
app.py
Normal file
@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
filename='logs/api.log',
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
32
config.py
Normal file
32
config.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
import secrets
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
|
||||
|
||||
# 安全密钥配置(自动生成随机密钥)
|
||||
SECRET_KEY = secrets.token_hex(16)
|
||||
|
||||
# 环境配置
|
||||
DEBUG = True
|
||||
ENV = 'development'
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||
ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'}
|
||||
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制
|
||||
|
||||
# API配置
|
||||
API_VERSION = 'v1'
|
||||
API_PREFIX = f'/api/{API_VERSION}'
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS = [
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
]
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL = 'DEBUG'
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log')
|
||||
BIN
data/.DS_Store
vendored
Normal file
BIN
data/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
data/equipment_data_20241108.xlsx
Normal file
BIN
data/equipment_data_20241108.xlsx
Normal file
Binary file not shown.
BIN
data/equipment_data_20241108_training.xlsx
Normal file
BIN
data/equipment_data_20241108_training.xlsx
Normal file
Binary file not shown.
BIN
data/equipment_data_20241108_verify.xlsx
Normal file
BIN
data/equipment_data_20241108_verify.xlsx
Normal file
Binary file not shown.
616
docs/debug.md
Normal file
616
docs/debug.md
Normal file
@ -0,0 +1,616 @@
|
||||
# 调试记录
|
||||
|
||||
## 特殊参数显示问题
|
||||
|
||||
### 问题描述
|
||||
|
||||
在数据管理页面中,装备详情对话框的特殊参数部分显示为空行或不显示。
|
||||
|
||||
### 调试步骤
|
||||
|
||||
1. 后端数据查询
|
||||
|
||||
```sql
|
||||
# 测试特殊参数查询
|
||||
SELECT equipment_id, param_name, param_value, param_unit
|
||||
FROM custom_params
|
||||
WHERE param_name IS NOT NULL
|
||||
AND param_value IS NOT NULL
|
||||
LIMIT 5
|
||||
```
|
||||
|
||||
2. 日志记录
|
||||
|
||||
```python
|
||||
logging.info(f"Getting details for equipment ID: {id}")
|
||||
logging.info(f"Equipment type: {equipment_type}")
|
||||
logging.info(f"Found equipment details: {result['name']}")
|
||||
logging.info(f"Custom params: {result.get('custom_params')}")
|
||||
```
|
||||
|
||||
3. 前端调试
|
||||
|
||||
```javascript
|
||||
console.log('Requesting details for row:', row)
|
||||
console.log('Details response:', response.data)
|
||||
console.log('Custom params:', response.data.custom_params)
|
||||
console.log('Selected data:', selectedData.value)
|
||||
```
|
||||
|
||||
### 关键发现
|
||||
|
||||
1. 数据库查询
|
||||
|
||||
- 特殊参数表中有数据
|
||||
- JSON_ARRAYAGG 返回的格式需要处理
|
||||
- 需要过滤掉 NULL 值
|
||||
|
||||
2. 数据格式
|
||||
|
||||
- 后端返回的特殊参数是 JSON 字符串
|
||||
- 需要在前端解析为数组
|
||||
- 确保数组不为空
|
||||
|
||||
3. 前端渲染
|
||||
|
||||
- 条件判断需要更严格
|
||||
- 需要确保数据类型正确
|
||||
- 需要正确格式化显示值
|
||||
|
||||
### 解决方案
|
||||
|
||||
1. 后端查询优化
|
||||
|
||||
```sql
|
||||
(
|
||||
SELECT JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'id', csp.id,
|
||||
'param_name', csp.param_name,
|
||||
'param_value', csp.param_value,
|
||||
'param_unit', csp.param_unit,
|
||||
'description', csp.description
|
||||
)
|
||||
)
|
||||
FROM custom_params csp
|
||||
WHERE csp.equipment_id = e.id
|
||||
AND csp.param_name IS NOT NULL
|
||||
AND csp.param_value IS NOT NULL
|
||||
) as custom_params
|
||||
```
|
||||
|
||||
2. 前端数据处理
|
||||
|
||||
```javascript
|
||||
// 确保 custom_params 是数组
|
||||
if (typeof response.data.custom_params === 'string') {
|
||||
response.data.custom_params = JSON.parse(response.data.custom_params)
|
||||
}
|
||||
```
|
||||
|
||||
3. 渲染条件优化
|
||||
|
||||
```vue
|
||||
<template v-if="selectedData?.custom_params && Array.isArray(selectedData.custom_params) && selectedData.custom_params.length > 0">
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. 数据库查询
|
||||
|
||||
- 使用子查询而不是 JOIN 获取特殊参数
|
||||
- 确保返回格式统一
|
||||
- 过滤无效数据
|
||||
|
||||
2. 数据处理
|
||||
|
||||
- 统一数据格式
|
||||
- 处理空值和异常
|
||||
- 保持类型一致
|
||||
|
||||
3. 前端显示
|
||||
|
||||
- 严格的条件判断
|
||||
- 类型检查
|
||||
- 格式化显示
|
||||
|
||||
4. 调试方法
|
||||
|
||||
- 使用日志跟踪数据流
|
||||
- 检查数据格式和类型
|
||||
- 验证每个环节的数据
|
||||
|
||||
## 编辑对话框问题
|
||||
|
||||
### 问题描述
|
||||
|
||||
在数据管理页面中,编辑对话框的成本信息分区和特殊参数分区显示不正确。
|
||||
|
||||
### 调试步骤
|
||||
|
||||
1. 检查数据流
|
||||
|
||||
```javascript
|
||||
console.log('Editing row:', row)
|
||||
console.log('Edit data response:', response.data)
|
||||
console.log('Parsed custom params:', data.custom_params)
|
||||
console.log('Edit form data:', editForm.value)
|
||||
```
|
||||
|
||||
2. 检查模板结构
|
||||
|
||||
```vue
|
||||
<!-- 错误的嵌套结构 -->
|
||||
<el-form>
|
||||
<template>
|
||||
<el-divider>成本信息</el-divider>
|
||||
</template>
|
||||
</el-form>
|
||||
|
||||
<!-- 正确的结构 -->
|
||||
<el-divider>成本信息</el-divider>
|
||||
<el-form>
|
||||
<!-- 表单项 -->
|
||||
</el-form>
|
||||
```
|
||||
|
||||
### 关键发现
|
||||
|
||||
1. 模板结构问题
|
||||
|
||||
- el-divider 不应该嵌套在 template 中
|
||||
- 每个分区需要独立的 el-form
|
||||
- 避免不必要的 template 嵌套
|
||||
|
||||
2. 数据类型问题
|
||||
|
||||
- 后端返回的数值是字符串类型
|
||||
- el-input-number 组件需要数值类型
|
||||
- 需要在前端进行类型转换
|
||||
|
||||
3. 条件渲染问题
|
||||
|
||||
- v-if 条件过于严格可能导致内容不显示
|
||||
- 某些字段应该始终显示
|
||||
- 某些字段只在有值时显示
|
||||
|
||||
### 解决方案
|
||||
|
||||
1. 修改模板结构
|
||||
|
||||
```vue
|
||||
<!-- 成本信息 -->
|
||||
<el-divider content-position="left">成本信息</el-divider>
|
||||
<el-form :model="editForm" label-width="120px">
|
||||
<el-form-item label="实际成本(元)">
|
||||
<el-input-number v-model="editForm.actual_cost"></el-input-number>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
```
|
||||
|
||||
2. 数据类型转换
|
||||
|
||||
```javascript
|
||||
// 转换所有数值类型字段
|
||||
Object.keys(data).forEach(key => {
|
||||
if (isNumberInput(key) && data[key] !== null && data[key] !== undefined) {
|
||||
data[key] = Number(data[key])
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
3. 优化条件渲染
|
||||
|
||||
```vue
|
||||
<!-- 始终显示必要字段 -->
|
||||
<el-form-item label="实际成本(元)">
|
||||
<el-input-number v-model="editForm.actual_cost"></el-input-number>
|
||||
</el-form-item>
|
||||
|
||||
<!-- 只在有值时显示可选字段 -->
|
||||
<el-form-item label="预测成本(元)" v-if="editForm.predicted_cost">
|
||||
<el-input-number v-model="editForm.predicted_cost" disabled></el-input-number>
|
||||
</el-form-item>
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. 模板结构
|
||||
|
||||
- 保持清晰的分区结构
|
||||
- 避免不必要的嵌套
|
||||
- 使用合适的组件层级
|
||||
|
||||
2. 数据处理
|
||||
|
||||
- 在获取数据后立即进行类型转换
|
||||
- 确保数据类型与组件要求匹配
|
||||
- 处理好空值和未定义值
|
||||
|
||||
3. 条件渲染
|
||||
|
||||
- 合理使用 v-if 和 v-show
|
||||
- 必要字段始终显示
|
||||
- 可选字段根据条件显示
|
||||
|
||||
4. 调试方法
|
||||
|
||||
- 使用 console.log 跟踪数据流
|
||||
- 检查组件的属性要求
|
||||
- 验证数据类型和结构
|
||||
|
||||
## 特征分析功能问题
|
||||
|
||||
### 问题描述
|
||||
|
||||
特征分析页面中,第一次点击分析按钮时,图表不显示,只有标题栏。第二次点击才能正常显示图表。
|
||||
|
||||
### 调试步骤
|
||||
|
||||
1. 检查数据流
|
||||
|
||||
```javascript
|
||||
console.log('Analysis result:', analysisResult.value)
|
||||
console.log('Charts not ready:', {
|
||||
importanceChartRef: !!importanceChartRef.value,
|
||||
correlationChartRef: !!correlationChartRef.value,
|
||||
analysisResult: !!analysisResult.value
|
||||
})
|
||||
```
|
||||
|
||||
2. 检查渲染时机
|
||||
|
||||
```javascript
|
||||
// 使用 watch 监听分析结果变化
|
||||
watch(() => analysisResult.value, async (newResult) => {
|
||||
if (newResult) {
|
||||
await nextTick()
|
||||
setTimeout(() => {
|
||||
renderCharts()
|
||||
}, 100)
|
||||
}
|
||||
}, { deep: true })
|
||||
```
|
||||
|
||||
3. 检查图表实例管理
|
||||
|
||||
```javascript
|
||||
// 销毁旧的图表实例
|
||||
if (importanceChart.value) {
|
||||
importanceChart.value.dispose()
|
||||
}
|
||||
if (correlationChart.value) {
|
||||
correlationChart.value.dispose()
|
||||
}
|
||||
```
|
||||
|
||||
### 关键发现
|
||||
|
||||
1. 渲染时机问题
|
||||
|
||||
- DOM 元素可能还未准备好
|
||||
- 数据更新后需要等待 DOM 更新
|
||||
- 需要正确管理图表实例
|
||||
|
||||
2. 图表实例管理
|
||||
|
||||
- 需要保存图表实例的引用
|
||||
- 重新渲染前需要销毁旧实例
|
||||
- 组件卸载时需要清理实例
|
||||
|
||||
3. 数据格式问题
|
||||
|
||||
- 特征名称需要中文映射
|
||||
- 相关性数据需要保留2位小数
|
||||
- 需要正确处理缺失值
|
||||
|
||||
### 解决方案
|
||||
|
||||
1. 优化渲染逻辑
|
||||
|
||||
```javascript
|
||||
// 使用 nextTick 和延时确保 DOM 已更新
|
||||
await nextTick()
|
||||
setTimeout(() => {
|
||||
renderCharts()
|
||||
}, 100)
|
||||
```
|
||||
|
||||
2. 完善图表实例管理
|
||||
|
||||
```javascript
|
||||
// 保存图表实例的引用
|
||||
const importanceChart = ref(null)
|
||||
const correlationChart = ref(null)
|
||||
|
||||
// 组件卸载时清理
|
||||
onUnmounted(() => {
|
||||
importanceChart.value?.dispose()
|
||||
correlationChart.value?.dispose()
|
||||
})
|
||||
```
|
||||
|
||||
3. 优化数据处理
|
||||
|
||||
```python
|
||||
# 使用中文特征名
|
||||
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
|
||||
|
||||
# 保留2位小数
|
||||
correlation_data.append([
|
||||
i, j,
|
||||
round(correlation_matrix[i][j], 2)
|
||||
])
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. 渲染控制
|
||||
|
||||
- 使用 watch 监听数据变化
|
||||
- 使用 nextTick 等待 DOM 更新
|
||||
- 添加适当的延时确保渲染
|
||||
|
||||
2. 实例管理
|
||||
|
||||
- 保存图表实例引用
|
||||
- 及时销毁旧实例
|
||||
- 组件卸载时清理
|
||||
|
||||
3. 数据处理
|
||||
|
||||
- 统一使用中文特征名
|
||||
- 控制数值精度
|
||||
- 处理好缺失值
|
||||
|
||||
4. 调试方法
|
||||
|
||||
- 添加详细的日志记录
|
||||
- 检查 DOM 元素状态
|
||||
- 验证数据格式
|
||||
|
||||
## 页面状态保持问题
|
||||
|
||||
### 问题描述
|
||||
|
||||
特征分析页面切换到其他页面后再返回,页面状态(分析结果和图表)会丢失,需要重新分析。
|
||||
|
||||
### 调试步骤
|
||||
|
||||
1. 检查路由配置
|
||||
|
||||
```vue
|
||||
<!-- 错误的配置 -->
|
||||
<keep-alive include="AnalysisPage">
|
||||
<router-view></router-view>
|
||||
</keep-alive>
|
||||
|
||||
<!-- 正确的配置 -->
|
||||
<router-view v-slot="{ Component }">
|
||||
<keep-alive>
|
||||
<component :is="Component" :key="$route.fullPath" />
|
||||
</keep-alive>
|
||||
</router-view>
|
||||
```
|
||||
|
||||
2. 检查组件定义
|
||||
|
||||
```javascript
|
||||
// 错误的组件名称定义
|
||||
<script setup name="AnalysisPage">
|
||||
|
||||
// 正确的组件名称定义
|
||||
const __name = 'AnalysisPage'
|
||||
```
|
||||
|
||||
### 关键发现
|
||||
|
||||
1. keep-alive 配置问题
|
||||
|
||||
- 需要使用 v-slot API
|
||||
- 需要使用动态组件
|
||||
- 需要添加 key 属性
|
||||
|
||||
2. 组件定义问题
|
||||
|
||||
- setup 语法糖不支持直接添加 name
|
||||
- 需要使用 __name 或 defineOptions
|
||||
|
||||
3. 缓存范围问题
|
||||
|
||||
- 不需要指定 include 属性
|
||||
- 缓存所有路由组件更简单
|
||||
- 避免组件名称不匹配的问题
|
||||
|
||||
### 解决方案
|
||||
|
||||
1. 修改路由视图配置
|
||||
|
||||
```vue
|
||||
<router-view v-slot="{ Component }">
|
||||
<keep-alive>
|
||||
<component :is="Component" :key="$route.fullPath" />
|
||||
</keep-alive>
|
||||
</router-view>
|
||||
```
|
||||
|
||||
2. 修改组件定义
|
||||
|
||||
```javascript
|
||||
const __name = 'AnalysisPage'
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. 路由配置
|
||||
|
||||
- 使用 Vue3 的新 API
|
||||
- 保持配置简单清晰
|
||||
- 避免不必要的限制
|
||||
|
||||
2. 组件定义
|
||||
|
||||
- 使用推荐的方式定义组件名
|
||||
- 避免使用已废弃的语法
|
||||
- 保持代码一致性
|
||||
|
||||
3. 状态管理
|
||||
|
||||
- 合理使用 keep-alive
|
||||
- 正确处理组件生命周期
|
||||
- 注意清理工作
|
||||
|
||||
4. 调试方法
|
||||
|
||||
- 检查组件是否被缓存
|
||||
- 验证状态是否保持
|
||||
- 确认生命周期钩子的执行
|
||||
|
||||
## 代码修改最佳实践
|
||||
|
||||
### 1. 修改前的准备
|
||||
|
||||
1. 检查相关文件:
|
||||
|
||||
```
|
||||
前端组件修改时检查:
|
||||
- 相关的路由配置
|
||||
- 父子组件关系
|
||||
- 共用的组件和函数
|
||||
- API调用
|
||||
|
||||
后端接口修改时检查:
|
||||
- 路由定义
|
||||
- 数据库查询
|
||||
- 相关的工具类和函数
|
||||
- 错误处理
|
||||
```
|
||||
|
||||
2. 保持命名一致性:
|
||||
|
||||
```
|
||||
- 类名:ModelTrainer 而不是 ModelTraining
|
||||
- 文件名:train_model.py 对应 ModelTrainer
|
||||
- 变量名:保持前后端一致的命名规范
|
||||
```
|
||||
|
||||
3. 添加日志记录:
|
||||
|
||||
```python
|
||||
# 在关键节点添加日志
|
||||
logging.info(f"Starting model training for {equipment_type}")
|
||||
logging.info(f"Training dataset: {train_dataset_id}")
|
||||
logging.error(f"Error in model training: {str(e)}")
|
||||
```
|
||||
|
||||
### 2. 修改过程中
|
||||
|
||||
1. 错误处理:
|
||||
|
||||
```python
|
||||
try:
|
||||
# 主要逻辑
|
||||
except Exception as e:
|
||||
logging.error(f"Error: {str(e)}")
|
||||
logging.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
```
|
||||
|
||||
2. 数据验证:
|
||||
|
||||
```javascript
|
||||
// 验证输入
|
||||
if (!formData.value.type)
|
||||
  throw new Error('请选择装备类型')
|
||||
if (!formData.value.train_dataset_id)
|
||||
  throw new Error('请选择训练数据集')
|
||||
```
|
||||
|
||||
3. 状态管理:
|
||||
|
||||
```javascript
|
||||
// 重置状态
|
||||
formData.value.train_dataset_id = null
|
||||
formData.value.validation_dataset_id = null
|
||||
trainingResult.value = null
|
||||
```
|
||||
|
||||
### 3. 修改后的验证
|
||||
|
||||
1. 功能测试:
|
||||
|
||||
```
|
||||
- 验证主要功能
|
||||
- 测试边界条件
|
||||
- 检查错误处理
|
||||
```
|
||||
|
||||
2. 性能检查:
|
||||
|
||||
```
|
||||
- 检查数据库查询性能
|
||||
- 验证前端渲染性能
|
||||
- 确认内存使用情况
|
||||
```
|
||||
|
||||
3. 代码质量:
|
||||
|
||||
```
|
||||
- 检查代码风格
|
||||
- 确保注释完整
|
||||
- 验证类型定义
|
||||
```
|
||||
|
||||
### 4. 文档更新
|
||||
|
||||
1. 更新调试文档:
|
||||
|
||||
```
|
||||
- 记录问题原因
|
||||
- 描述解决方案
|
||||
- 添加最佳实践
|
||||
```
|
||||
|
||||
2. 更新设计文档:
|
||||
|
||||
```
|
||||
- 更新接口定义
|
||||
- 修改数据结构
|
||||
- 补充新功能说明
|
||||
```
|
||||
|
||||
3. 更新注释:
|
||||
|
||||
```
|
||||
- 添加函数说明
|
||||
- 说明参数用途
|
||||
- 解释复杂逻辑
|
||||
```
|
||||
|
||||
## 模型训练结果
|
||||
|
||||
从最新的训练结果来看:
|
||||
|
||||
1. XGBoost 表现最好:
|
||||
- 训练集 R² = 0.4346,没有过拟合
|
||||
- 验证集 R² = 0.3625,表现最稳定
|
||||
- MAE = 0.60,RMSE = 0.61,预测误差较小
|
||||
2. LightGBM 表现次之:
|
||||
- 训练集 R² = 0.5277,轻微过拟合
|
||||
- 验证集 R² = 0.1101,泛化能力一般
|
||||
- MAE = 0.55,RMSE = 0.72,预测误差适中
|
||||
3. Random Forest:
|
||||
- 训练集 R² = 0.7756,存在过拟合
|
||||
- 验证集 R² = 0.3189,泛化能力还可以
|
||||
- MAE = 0.47,RMSE = 0.63,预测误差较小
|
||||
4. GBDT 过拟合严重:
|
||||
- 训练集 R² = 0.9700,严重过拟合
|
||||
- 验证集 R² = -1.3133,泛化能力很差
|
||||
- MAE = 0.96,RMSE = 1.17,预测误差大
|
||||
|
||||
### 建议
|
||||
|
||||
1. 使用 XGBoost 作为主要模型
|
||||
2. 可以考虑集成 XGBoost 和 Random Forest
|
||||
3. 继续调整 LightGBM 的参数
|
||||
4. 暂时不使用 GBDT
|
||||
433
docs/design.md
Normal file
433
docs/design.md
Normal file
@ -0,0 +1,433 @@
|
||||
# 装备成本估算系统设计方案
|
||||
|
||||
## 一、系统概述
|
||||
|
||||
本系统旨在通过装备的技术参数,利用机器学习方法对装备成本进行估算。系统采用前后端分离架构,主要包含数据预处理、特征分析、模型训练和成本预测等模块。
|
||||
|
||||
## 二、系统架构
|
||||
|
||||
### 1. 技术架构
|
||||
|
||||
- 前端:Vue.js + Element UI
|
||||
- 后端:Python Flask
|
||||
- 数据库:MySQL
|
||||
- 机器学习框架:TensorFlow, Scikit-learn
|
||||
|
||||
### 2. 系统模块
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[装备成本估算系统] --> B[数据预处理模块]
|
||||
A --> C[特征分析模块]
|
||||
A --> D[模型训练模块]
|
||||
A --> E[成本预测模块]
|
||||
A --> F[系统管理模块]
|
||||
```
|
||||
|
||||
## 三、数据模型设计
|
||||
|
||||
### 1. 数据库表结构
|
||||
|
||||
#### 装备基本信息表(equipment)
|
||||
|
||||
- id: 主键
|
||||
- name: 装备名称
|
||||
- type: 装备类型
|
||||
- manufacturer: 制造商
|
||||
- created_at: 创建时间
|
||||
|
||||
#### 技术参数表(technical_params)
|
||||
|
||||
##### 尺寸参数
|
||||
|
||||
- length_m: 总长(m)
|
||||
- width_m: 宽度(m)
|
||||
- height_m: 高度(m)
|
||||
- weight_standard_kg: 标准重量(kg)
|
||||
- weight_combat_kg: 战斗重量(kg)
|
||||
|
||||
##### 火力参数
|
||||
|
||||
- firing_angle_horizontal: 方向射界(度)
|
||||
- firing_angle_vertical: 高低射界(度)
|
||||
- rocket_length_m: 火箭弹长度(m)
|
||||
- rocket_diameter_mm: 弹体直径(mm)
|
||||
- rocket_weight_kg: 火箭弹重量(kg)
|
||||
|
||||
##### 性能参数
|
||||
|
||||
- max_speed_ms: 最大速度(m/s)
|
||||
- max_range_km: 最大射程(km)
|
||||
- warhead_weight_kg: 战斗部重量(kg)
|
||||
|
||||
##### 其他参数
|
||||
|
||||
- mobility_type: 机动性类型
|
||||
- wheel_arrangement: 轮式布局
|
||||
- amphibious: 两栖能力
|
||||
|
||||
#### 成本数据表(cost_data)
|
||||
|
||||
- equipment_id: 装备ID
|
||||
- actual_cost: 实际成本
|
||||
- predicted_cost: 预测成本
|
||||
- prediction_date: 预测日期
|
||||
|
||||
## 四、核心功能模块设计
|
||||
|
||||
### 1. 特征工程模块
|
||||
|
||||
- 数据预处理
|
||||
- 数据清洗
|
||||
- 缺失值处理
|
||||
- 异常值检测
|
||||
- 数据标准化
|
||||
- 特征衍生
|
||||
- 功重比(power_weight_ratio)
|
||||
- 体积(volume)
|
||||
- 射程速度比(range_speed_ratio)
|
||||
- 战斗部比例(warhead_rocket_ratio)
|
||||
- 特征选择
|
||||
- 方差分析
|
||||
- 相关性分析
|
||||
- 互信息分析
|
||||
- 一致性分析
|
||||
- rwg值计算(组内一致性)
|
||||
- ICC值计算(组内相关系数)
|
||||
|
||||
### 2. 模型训练模块
|
||||
|
||||
#### 2.1 集成学习模型
|
||||
|
||||
1. XGBoost (eXtreme Gradient Boosting)
|
||||
|
||||
- 特点:
|
||||
- 使用二阶导数进行优化,收敛更快
|
||||
- 支持自定义损失函数
|
||||
- 内置正则化,防止过拟合
|
||||
- 支持特征重要性评估
|
||||
- 配置:
|
||||
|
||||
```python
|
||||
XGBRegressor(
|
||||
n_estimators=100, # 树的数量适中,避免过拟合
|
||||
learning_rate=0.1, # 较大的学习率,加快收敛
|
||||
max_depth=3, # 较小的树深度,防止过拟合
|
||||
min_child_weight=3, # 控制过拟合
|
||||
subsample=0.8, # 随机采样,增加模型鲁棒性
|
||||
colsample_bytree=0.8, # 特征采样,防止过拟合
|
||||
objective='reg:squarederror' # 回归任务
|
||||
)
|
||||
```
|
||||
|
||||
2. LightGBM (Light Gradient Boosting Machine)
|
||||
|
||||
- 特点:
|
||||
- 基于直方图的决策树算法,训练速度快
|
||||
- 支持类别特征的直接输入
|
||||
- 叶子优先的生长策略
|
||||
- 内存占用小
|
||||
- 配置:
|
||||
|
||||
```python
|
||||
LGBMRegressor(
|
||||
n_estimators=100,
|
||||
learning_rate=0.1,
|
||||
max_depth=3,
|
||||
num_leaves=8, # 控制树的复杂度
|
||||
subsample=0.8,
|
||||
colsample_bytree=0.8,
|
||||
objective='regression'
|
||||
)
|
||||
```
|
||||
|
||||
3. GBDT (Gradient Boosting Decision Tree)
|
||||
|
||||
- 特点:
|
||||
- 经典的梯度提升算法
|
||||
- 较好的可解释性
|
||||
- 对异常值不敏感
|
||||
- 预测稳定性好
|
||||
- 配置:
|
||||
|
||||
```python
|
||||
GradientBoostingRegressor(
|
||||
n_estimators=100,
|
||||
learning_rate=0.1,
|
||||
max_depth=3,
|
||||
subsample=0.8
|
||||
)
|
||||
```
|
||||
|
||||
4. Random Forest (随机森林)
|
||||
|
||||
- 特点:
|
||||
- Bagging集成方法
|
||||
- 训练过程可并行化
|
||||
- 不易过拟合
|
||||
- 对特征尺度不敏感
|
||||
- 配置:
|
||||
|
||||
```python
|
||||
RandomForestRegressor(
|
||||
n_estimators=100,
|
||||
max_depth=3,
|
||||
min_samples_split=3,
|
||||
min_samples_leaf=2
|
||||
)
|
||||
```
|
||||
|
||||
#### 2.2 模型选择策略
|
||||
|
||||
1. 交叉验证
|
||||
|
||||
- 小样本数据集使用留一法(LOO)交叉验证
|
||||
- 大样本数据集使用5折交叉验证
|
||||
|
||||
2. 模型评估指标
|
||||
|
||||
- R²分数:评估拟合优度
|
||||
- 标准差:评估预测稳定性
|
||||
- MAE:评估预测误差
|
||||
- MSE:惩罚较大误差
|
||||
|
||||
3. 自动模型选择
|
||||
|
||||
- 对每个模型进行交叉验证评估
|
||||
- 选择R²分数最高的模型
|
||||
- 保存模型训练历史记录
|
||||
|
||||
#### 2.3 特征工程与模型优化
|
||||
|
||||
1. 特征选择
|
||||
|
||||
- 使用模型的特征重要性评分
|
||||
- 基于互信息的特征筛选
|
||||
- 相关性分析去除冗余特征
|
||||
|
||||
2. 参数优化
|
||||
|
||||
- 针对小样本:
|
||||
- 减小模型复杂度(树的深度和数量)
|
||||
- 增加正则化强度
|
||||
- 使用较大的学习率
|
||||
- 针对大样本:
|
||||
- 增加模型复杂度
|
||||
- 减小学习率
|
||||
- 使用特征采样
|
||||
|
||||
3. 集成策略
|
||||
|
||||
- 模型投票:综合多个模型的预测结果
|
||||
- 置信区间:使用bootstrap方法估计预测不确定性
|
||||
- 异常检测:识别并处理异常预测值
|
||||
|
||||
### 3. 成本预测模块
|
||||
|
||||
- 数据标准化处理
|
||||
- 模型预测流程
|
||||
- 置信区间计算
|
||||
- 预测评估指标
|
||||
- MAE(平均绝对误差)
|
||||
- MSE(均方误差)
|
||||
- RMSE(均方根误差)
|
||||
- R²(决定系数)
|
||||
|
||||
### 4. 数据管理模块
|
||||
|
||||
#### 4.1 数据查询优化
|
||||
|
||||
1. 特殊参数查询策略
|
||||
|
||||
- 使用子查询替代 JOIN 操作,避免数据重复和空值问题
|
||||
- 使用 COALESCE 确保返回空数组而不是 NULL
|
||||
- 在 WHERE 子句中添加条件确保只获取有效参数
|
||||
- 简化查询结构,移除不必要的 GROUP BY 子句
|
||||
|
||||
2. 查询性能优化
|
||||
|
||||
- 数据结构清晰,每个装备保持单一记录
|
||||
- 特殊参数组织在 JSON 数组中
|
||||
- 避免空值和无效数据的显示
|
||||
- 优化查询性能,减少复杂的 JOIN 和 GROUP BY 操作
|
||||
|
||||
#### 4.2 数据展示
|
||||
|
||||
1. 基本信息展示
|
||||
|
||||
- 装备基本参数表格展示
|
||||
- 支持搜索和过滤功能
|
||||
- 分类显示不同类型装备数据
|
||||
|
||||
2. 特殊参数展示
|
||||
|
||||
- 在详情页面显示装备特有参数
|
||||
- 支持特殊参数的格式化显示
|
||||
- 根据参数类型自动添加单位
|
||||
- 区分数值类型和文本类型参数
|
||||
|
||||
3. 数据编辑功能
|
||||
|
||||
- 支持基本参数和特殊参数的编辑
|
||||
- 根据参数类型提供不同的编辑控件
|
||||
- 数值类型参数支持精度控制
|
||||
- 保存时进行数据验证
|
||||
|
||||
#### 4.3 数据维护
|
||||
|
||||
1. 数据导入导出
|
||||
|
||||
- 支持 Excel 模板导入
|
||||
- 提供数据模板下载
|
||||
- 导入时进行数据验证和错误提示
|
||||
|
||||
2. 数据删除
|
||||
|
||||
- 级联删除相关数据
|
||||
- 删除前进行确认
|
||||
- 删除后自动刷新数据列表
|
||||
|
||||
3. 数据更新
|
||||
|
||||
- 支持单条记录更新
|
||||
- 保存时进行数据完整性验证
|
||||
- 更新后实时刷新显示
|
||||
|
||||
#### 4.4 扩展性设计
|
||||
|
||||
1. 参数管理
|
||||
|
||||
- 支持动态添加特殊参数
|
||||
- 参数定义包含名称、单位、说明等
|
||||
- 支持参数值类型定义
|
||||
|
||||
2. 数据结构
|
||||
|
||||
- 采用灵活的数据模型设计
|
||||
- 支持不同类型装备的差异化参数
|
||||
- 便于后续功能扩展
|
||||
|
||||
3. 接口设计
|
||||
|
||||
- RESTful API 设计
|
||||
- 统一的响应格式
|
||||
- 完善的错误处理机制
|
||||
|
||||
## 五、API接口设计
|
||||
|
||||
### 1. 成本预测接口
|
||||
|
||||
- 路径:POST /api/predict
|
||||
- 必要参数:
|
||||
- length_m: 总长
|
||||
- width_m: 宽度
|
||||
- height_m: 高度
|
||||
- weight_standard_kg: 标准重量
|
||||
- weight_combat_kg: 战斗重量
|
||||
- max_range_km: 最大射程
|
||||
- max_speed_ms: 最大速度
|
||||
- 响应数据:
|
||||
- predicted_cost: 预测成本
|
||||
- confidence_interval: 置信区间(上下限)
|
||||
|
||||
### 2. 特征分析接口
|
||||
|
||||
- 路径:POST /api/analyze-features
|
||||
- 响应数据:
|
||||
- important_features: 重要特征列表
|
||||
- correlation_analysis: 相关性分析结果
|
||||
|
||||
## 六、部署要求
|
||||
|
||||
### 1. 环境要求
|
||||
|
||||
- Python 3.8+
|
||||
- MySQL 8.0+
|
||||
- Node.js 14+
|
||||
|
||||
### 2. 依赖包
|
||||
|
||||
- Flask
|
||||
- NumPy
|
||||
- Pandas
|
||||
- Scikit-learn
|
||||
- TensorFlow
|
||||
- SQLAlchemy
|
||||
|
||||
### 3. 部署步骤
|
||||
|
||||
1. 数据库初始化
|
||||
- 创建数据库
|
||||
- 执行表结构脚本
|
||||
- 导入基础数据
|
||||
2. 后端服务部署
|
||||
- 安装依赖包
|
||||
- 配置环境变量
|
||||
- 启动服务
|
||||
3. 前端部署
|
||||
- 安装依赖
|
||||
- 构建生产版本
|
||||
- 配置Nginx
|
||||
|
||||
## 七、后续优化建议
|
||||
|
||||
### 1. 特征工程优化
|
||||
|
||||
- 增加更多领域特征
|
||||
- 火力密度指标
|
||||
- 机动性综合指标
|
||||
- 作战效能指标
|
||||
- 优化特征选择算法
|
||||
- 增强数据清洗能力
|
||||
|
||||
### 2. 模型优化
|
||||
|
||||
- 引入集成学习方法
|
||||
- Random Forest
|
||||
- XGBoost
|
||||
- LightGBM
|
||||
- 实现模型自动调优
|
||||
- 增加在线学习能力
|
||||
|
||||
### 3. 系统优化
|
||||
|
||||
- 增加批量处理能力
|
||||
- 实现模型版本管理
|
||||
- 提升预测结果可解释性
|
||||
- 优化系统性能
|
||||
- 数据库索引优化
|
||||
- 缓存策略
|
||||
- API性能优化
|
||||
|
||||
### 4. 安全性优化
|
||||
|
||||
- 数据加密存储
|
||||
- API访问认证
|
||||
- 操作日志记录
|
||||
- 数据备份策略
|
||||
|
||||
## 八、项目进度规划
|
||||
|
||||
### 第一阶段:基础功能开发(4周)
|
||||
|
||||
1. 数据库设计和实现
|
||||
2. 特征工程模块开发
|
||||
3. 基础API实现
|
||||
|
||||
### 第二阶段:模型开发(4周)
|
||||
|
||||
1. 数据预处理
|
||||
2. 模型训练与优化
|
||||
3. 预测功能实现
|
||||
|
||||
### 第三阶段:系统集成(3周)
|
||||
|
||||
1. 前端开发
|
||||
2. 系统集成测试
|
||||
3. 性能优化
|
||||
|
||||
### 第四阶段:系统测试与部署(2周)
|
||||
|
||||
1. 系统测试
|
||||
2. 部署上线
|
||||
3. 文档完善
|
||||
136
docs/nodejs_install.md
Normal file
136
docs/nodejs_install.md
Normal file
@ -0,0 +1,136 @@
|
||||
# Node.js 安装指南
|
||||
|
||||
## Windows 安装方法
|
||||
|
||||
### 1. 使用安装包
|
||||
|
||||
1. 访问 Node.js 官网 <https://nodejs.org/>
|
||||
2. 下载 14.x LTS 版本安装包
|
||||
3. 运行安装包,按提示完成安装
|
||||
4. 验证安装:
|
||||
|
||||
```bash
|
||||
node --version
|
||||
npm --version
|
||||
```
|
||||
|
||||
### 2. 使用 nvm-windows(推荐)
|
||||
|
||||
1. 下载 nvm-windows:<https://github.com/coreybutler/nvm-windows/releases>
|
||||
2. 安装 nvm-windows
|
||||
3. 安装 Node.js:
|
||||
|
||||
```bash
|
||||
nvm install 14.21.3
|
||||
nvm use 14.21.3
|
||||
```
|
||||
|
||||
## Linux 安装方法
|
||||
|
||||
### 1. 使用 nvm(推荐)
|
||||
|
||||
```bash
|
||||
# 安装 nvm
|
||||
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
|
||||
|
||||
# 重新加载配置
|
||||
source ~/.bashrc
|
||||
|
||||
# 安装 Node.js 14
|
||||
nvm install 14
|
||||
nvm use 14
|
||||
```
|
||||
|
||||
### 2. 使用包管理器
|
||||
|
||||
#### Ubuntu/Debian
|
||||
|
||||
```bash
|
||||
# 添加 NodeSource 仓库
|
||||
curl -fsSL https://deb.nodesource.com/setup_14.x | sudo -E bash -
|
||||
|
||||
# 安装 Node.js
|
||||
sudo apt-get install -y nodejs
|
||||
```
|
||||
|
||||
#### CentOS/RHEL
|
||||
|
||||
```bash
|
||||
# 添加 NodeSource 仓库
|
||||
curl -fsSL https://rpm.nodesource.com/setup_14.x | sudo bash -
|
||||
|
||||
# 安装 Node.js
|
||||
sudo yum install -y nodejs
|
||||
```
|
||||
|
||||
## macOS 安装方法
|
||||
|
||||
### 1. 使用 Homebrew(推荐)
|
||||
|
||||
```bash
|
||||
# 安装 Homebrew(如果未安装)
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
|
||||
# 安装 Node.js 14
|
||||
brew install node@14
|
||||
|
||||
# 添加环境变量
|
||||
echo 'export PATH="/usr/local/opt/node@14/bin:$PATH"' >> ~/.zshrc
|
||||
source ~/.zshrc
|
||||
```
|
||||
|
||||
### 2. 使用 nvm
|
||||
|
||||
```bash
|
||||
# 安装 nvm
|
||||
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
|
||||
|
||||
# 重新加载配置
|
||||
source ~/.zshrc
|
||||
|
||||
# 安装 Node.js 14
|
||||
nvm install 14
|
||||
nvm use 14
|
||||
```
|
||||
|
||||
## 验证安装
|
||||
|
||||
安装完成后,运行以下命令验证:
|
||||
|
||||
```bash
|
||||
# 检查 Node.js 版本
|
||||
node --version # 应显示 v14.x.x
|
||||
|
||||
# 检查 npm 版本
|
||||
npm --version # 应显示 6.x.x 或更高
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### 1. 权限问题
|
||||
|
||||
如果遇到权限错误,可以:
|
||||
|
||||
```bash
|
||||
# Linux/macOS
|
||||
sudo chown -R $USER /usr/local/lib/node_modules
|
||||
```
|
||||
|
||||
### 2. 版本切换
|
||||
|
||||
如果需要在不同版本间切换:
|
||||
|
||||
```bash
|
||||
# 使用 nvm
|
||||
nvm list # 查看已安装版本
|
||||
nvm use 14 # 切换到 14.x 版本
|
||||
```
|
||||
|
||||
### 3. npm 配置
|
||||
|
||||
建议配置国内镜像源:
|
||||
|
||||
```bash
|
||||
# 使用淘宝镜像
|
||||
npm config set registry https://registry.npmmirror.com
|
||||
```
|
||||
39
docs/requirements.md
Normal file
39
docs/requirements.md
Normal file
@ -0,0 +1,39 @@
|
||||
# 项目需求
|
||||
|
||||
## 项目名称
|
||||
|
||||
- 装备成本估算系统
|
||||
|
||||
## 系统功能
|
||||
|
||||
1. 根据装备技术参数进行成本估算
|
||||
2. 采用神经网络模型,使用现有数据进行训练,然后对新装备进行成本估算
|
||||
|
||||
## 技术要求
|
||||
|
||||
1. 数据相关性分析:
|
||||
(1)方差分析:按照不同的标签类别将特征划分为不同的总体,然后判断各总体之间的均值是否相同(或者是否存在显著性差异)
|
||||
(2)线性相关分析:对于特征和标签皆为连续值的回归问题,要检测二者的相关性,最直接的做法就是求相关系数rxy,本质是建立协方差矩阵,分析数据和成本之间相关关系的类型和程度,筛选出影响特征
|
||||
(3)互信息(mutual information):用于特征选择,可以从两个角度进行解释:一是基于 KL 散度,二是基于信息增益。
|
||||
2. 数据一致性分析:对特征数据分层分组,计算组内一致性,目标是选择比较合适的一组数据,以此产生一个用于成本估算和分析的虚拟量。大部分研究会报告三个指标:rwg、ICC(1)、ICC(2),并要求同时满足三个条件:rwg>0.7、ICC(1)>0.05、ICC(2)>0.5。
|
||||
RWG值:打分一致性;
|
||||
ICC1:组内一致性;
|
||||
ICC2:组间一致性。
|
||||
3. 回归模型:偏最小二乘回归(Partial Least Squares,PLS)
|
||||
4. 神经网络模型:采用 BP 网络
|
||||
|
||||
### 数据准备
|
||||
|
||||
建议补充数据的优先级(火箭弹):
|
||||
|
||||
1. 第一优先级(射击性能相关):
|
||||
- rate_of_fire (射速)
|
||||
- rocket_weight_kg (火箭弹重量)
|
||||
- max_range_km (最大射程)
|
||||
2. 第二优先级(火力性能相关):
|
||||
- firing_angle_horizontal (方向射界)
|
||||
- firing_angle_vertical (高低射界)
|
||||
- rocket_length_m (火箭弹长度)
|
||||
3. 第三优先级(机动性能相关):
|
||||
- min_range_km (最小射程)
|
||||
- power_hp (功率)
|
||||
163
docs/run.md
Normal file
163
docs/run.md
Normal file
@ -0,0 +1,163 @@
|
||||
# 系统运行说明
|
||||
|
||||
## 一、环境准备
|
||||
|
||||
### 1. 安装必要软件
|
||||
|
||||
```bash
|
||||
# 安装 Python 3.8+
|
||||
# 安装 MySQL 8.0+
|
||||
# 安装 Node.js 14+
|
||||
```
|
||||
|
||||
### 2. 安装 Python 依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 3. 安装前端依赖
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
```
|
||||
|
||||
## 二、数据库配置
|
||||
|
||||
### 1. 创建数据库
|
||||
|
||||
```sql
|
||||
CREATE DATABASE equipment_cost_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
```
|
||||
|
||||
### 2. 初始化数据库结构
|
||||
|
||||
```bash
|
||||
# 执行数据库结构初始化脚本
|
||||
mysql -u username -p equipment_cost_db < src/schema.sql
|
||||
|
||||
# 导入示例数据
|
||||
mysql -u username -p equipment_cost_db < src/init_data.sql
|
||||
|
||||
# 导入真实数据
|
||||
mysql -u username -p equipment_cost_db < src/real_data.sql
|
||||
```
|
||||
|
||||
## 三、配置文件
|
||||
|
||||
### 1. 后端配置
|
||||
|
||||
创建 `config.py` 文件:
|
||||
|
||||
```python
|
||||
# config.py
|
||||
DATABASE_URI = "mysql+pymysql://username:password@localhost:3306/equipment_cost_db"
|
||||
SECRET_KEY = "your-secret-key"
|
||||
DEBUG = True
|
||||
```
|
||||
|
||||
### 2. 前端配置
|
||||
|
||||
修改 `frontend/src/config.js`:
|
||||
|
||||
```javascript
|
||||
export const API_BASE_URL = 'http://localhost:5001/api';
|
||||
```
|
||||
|
||||
## 四、启动系统
|
||||
|
||||
### 1. 启动后端服务
|
||||
|
||||
```bash
|
||||
# 开发环境
|
||||
python run.py # 服务将在 http://localhost:5001 启动
|
||||
|
||||
# 生产环境
|
||||
gunicorn -w 4 -b 0.0.0.0:5001 run:app
|
||||
```
|
||||
|
||||
### 2. 启动前端服务
|
||||
|
||||
```bash
|
||||
# 开发环境
|
||||
cd frontend
|
||||
npm run serve # 前端将在 http://localhost:8080 启动
|
||||
|
||||
# 生产环境
|
||||
npm run build
|
||||
```
|
||||
|
||||
## 五、访问系统
|
||||
|
||||
- 后端API:<http://localhost:5001/api>
|
||||
- 前端界面:<http://localhost:8080>
|
||||
|
||||
## 六、常见问题
|
||||
|
||||
### 1. 数据库连接问题
|
||||
|
||||
- 检查 MySQL 服务是否启动
|
||||
- 验证数据库用户名和密码
|
||||
- 确认数据库端口是否正确
|
||||
|
||||
### 2. 模型训练
|
||||
|
||||
```bash
|
||||
# 训练模型
|
||||
python src/train_model.py
|
||||
|
||||
# 查看训练日志
|
||||
tail -f logs/training.log
|
||||
```
|
||||
|
||||
### 3. 系统监控
|
||||
|
||||
```bash
|
||||
# 查看系统日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 监控API请求
|
||||
tail -f logs/access.log
|
||||
```
|
||||
|
||||
## 七、开发调试
|
||||
|
||||
### 1. 后端调试
|
||||
|
||||
```bash
|
||||
# 启动调试模式
|
||||
python run.py --debug
|
||||
|
||||
# 运行测试
|
||||
python -m pytest tests/
|
||||
```
|
||||
|
||||
### 2. 前端调试
|
||||
|
||||
```bash
|
||||
# 启动开发服务器
|
||||
npm run serve
|
||||
|
||||
# 运行测试
|
||||
npm run test
|
||||
```
|
||||
|
||||
## 八、部署建议
|
||||
|
||||
### 1. 使用 Docker 部署
|
||||
|
||||
```bash
|
||||
# 构建镜像
|
||||
docker-compose build
|
||||
|
||||
# 启动服务
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### 2. 生产环境配置
|
||||
|
||||
- 使用 Nginx 作为反向代理
|
||||
- 配置 SSL 证书
|
||||
- 设置适当的防火墙规则
|
||||
- 启用数据库备份
|
||||
1
frontend
Submodule
1
frontend
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 96445d75411a5f9ace114085af0872cfbc116515
|
||||
13
loiteringmunitions.md
Normal file
13
loiteringmunitions.md
Normal file
@ -0,0 +1,13 @@
|
||||
# 巡飞弹技术参数示例
|
||||
|
||||
## 美国“终结者”单兵巡飞弹
|
||||
|
||||
目标: 静止和移动的人员和轻型装甲车辆
|
||||
外形尺寸: 560mm×150mm×200mm(收起时)
|
||||
弹重: <2.72kg
|
||||
射程: >24km
|
||||
巡飞时间: 15min
|
||||
巡飞速度: 96.56km/h
|
||||
最大飞行速度: >160.93km/h
|
||||
战斗部类型: 破片杀伤战斗部、发烟战斗部、温压战斗部
|
||||
发射方式: 凭自身动力起飞
|
||||
12
package.json
Normal file
12
package.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"version": "1.0.0",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"test": "echo \"Error: no test specified\" && exit 1"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"description": ""
|
||||
}
|
||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
flask==2.0.1
|
||||
flask-cors==3.0.10
|
||||
sqlalchemy==1.4.23
|
||||
pymysql==1.0.2
|
||||
cryptography==3.4.7 # MySQL 8.0+ 认证需要
|
||||
numpy==1.21.2
|
||||
pandas==1.3.3
|
||||
scikit-learn==0.24.2
|
||||
tensorflow==2.6.0
|
||||
urllib3<2.0.0 # 限制 urllib3 版本低于 2.0
|
||||
openpyxl==3.1.2 # 用于读取 .xlsx 文件
|
||||
xlrd==2.0.1 # 用于读取 .xls 文件
|
||||
29
rocketparameters.md
Normal file
29
rocketparameters.md
Normal file
@ -0,0 +1,29 @@
|
||||
# 火箭炮系统技术参数示例
|
||||
|
||||
## 伊朗“胜利”-2 240mm (12管)火箭炮系统
|
||||
|
||||
产品类别: 多管火箭炮
|
||||
型号: “胜利”-2 240mm 多管火箭炮
|
||||
尺寸与重量
|
||||
总长: 10m(393.7in)
|
||||
宽(行军状态): 2.5m(98.4in)
|
||||
高(行军状态): 3.34m(131.5in)
|
||||
标准重: 15000kg(33069 lb)(15.0t)
|
||||
战斗重: 19900kg(43871 lb)(19.9t)
|
||||
机动性
|
||||
行走装置: 轮式
|
||||
布局: 6×6
|
||||
两栖: 无
|
||||
火力
|
||||
方向射界: 100°(1778mils)(左/90°右)
|
||||
高低射界(武器前方): 57°(1013mils)
|
||||
型号: “胜利”2火箭弹
|
||||
尺寸与重量
|
||||
总长: 3.550m(11ft)
|
||||
弹体直径: 512mm(20.16in)(尾翼展开)
|
||||
发射(重量): 275kg(606 lb)
|
||||
性能
|
||||
速度(最大速度): 1302kt(2412km/h;1499mph;670m/s)
|
||||
最大射程: 12.4n miles(23km;14.3miles)
|
||||
武器组成
|
||||
战斗部: 85kg(187 lb)
|
||||
61
run.py
Normal file
61
run.py
Normal file
@ -0,0 +1,61 @@
|
||||
import os
|
||||
import logging
|
||||
from src.app import app
|
||||
|
||||
# 确保必要的目录存在
|
||||
def ensure_directories():
    """Create the runtime directories the server relies on.

    ``logs``, ``data``, ``models`` and ``uploads`` are created relative to
    the current working directory. Directories that already exist are left
    untouched.
    """
    for folder in ('logs', 'data', 'models', 'uploads'):
        os.makedirs(folder, exist_ok=True)
|
||||
|
||||
# 配置日志
|
||||
def setup_logging():
    """Configure root logging to ``logs/server.log`` and to the console.

    The function is safe to call more than once: the console handler is
    tagged with the name ``console`` and is only attached when the root
    logger does not already carry a handler with that name. Without this
    guard every call stacks another ``StreamHandler`` and each record is
    printed once per handler.

    Note that ``logging.basicConfig`` does nothing when the root logger
    already has handlers, so the file handler is only installed on the
    first configuration of the process.
    """
    # basicConfig opens the log file immediately, so the directory must exist.
    os.makedirs('logs', exist_ok=True)

    logging.basicConfig(
        filename='logs/server.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    root_logger = logging.getLogger('')
    if any(handler.get_name() == 'console' for handler in root_logger.handlers):
        # A console handler was already installed; do not duplicate output.
        return

    # Mirror log records to the console as well.
    console_handler = logging.StreamHandler()
    console_handler.set_name('console')
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)
|
||||
|
||||
if __name__ == "__main__":
    try:
        # Filesystem first, then logging, so the log file has a home.
        ensure_directories()
        setup_logging()

        # Startup banner.
        for startup_message in (
            "=== Server Starting ===",
            "Initializing directories...",
            "Setting up logging system...",
        ):
            logging.info(startup_message)

        # The reloader is disabled so the model is not loaded twice.
        server_options = dict(
            host='localhost',
            port=5001,
            debug=True,
            use_reloader=False,
        )
        app.run(**server_options)

    except Exception as exc:
        logging.error(f"Server failed to start: {exc}")
        raise
|
||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# 这个文件可以为空,但必须存在
|
||||
68
src/api.py
Normal file
68
src/api.py
Normal file
@ -0,0 +1,68 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from .model_training import ModelTrainer
|
||||
from .cost_prediction import CostPredictor
|
||||
from .feature_analysis import FeatureAnalysis
|
||||
import pandas as pd
|
||||
|
||||
app = Flask(__name__)


@app.route('/api/predict', methods=['POST'])
def predict_cost():
    """Cost prediction endpoint.

    Expects a JSON object containing every name in ``required_params``.

    Returns:
        200 with ``predicted_cost`` and ``confidence_interval``
        (``lower`` / ``upper``) on success;
        400 when the body is not a JSON object or a required parameter
        is missing;
        500 with the error text when prediction fails.
    """
    try:
        # silent=True yields None for a missing/invalid body instead of
        # raising, so the problem is reported as a client error (400).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        # Validate required parameters.
        required_params = [
            'length_m', 'width_m', 'height_m', 'weight_standard_kg',
            'weight_combat_kg', 'max_range_km', 'max_speed_ms'
        ]

        for param in required_params:
            if param not in data:
                return jsonify({'error': f'Missing parameter: {param}'}), 400

        predictor = CostPredictor()
        result = predictor.predict(data)

        # CostPredictor.predict returns the interval under the key
        # 'confidence_interval' as {'lower': ..., 'upper': ...}. The older
        # sequence form under 'confidence_intervals' is still accepted.
        interval = result.get('confidence_interval')
        if interval is not None:
            lower, upper = interval['lower'], interval['upper']
        else:
            lower, upper = result['confidence_intervals'][:2]

        return jsonify({
            'predicted_cost': float(result['predicted_cost']),
            'confidence_interval': {
                'lower': float(lower),
                'upper': float(upper)
            }
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||||
|
||||
@app.route('/api/analyze-features', methods=['POST'])
def analyze_features():
    """Feature analysis endpoint.

    Expects a JSON object of column data that includes a ``cost`` entry.

    Returns:
        200 with ``important_features`` and ``correlation_analysis``;
        400 when the body is not a JSON object or ``cost`` is missing;
        500 with the error text when the analysis fails.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        if 'cost' not in data:
            return jsonify({'error': 'Missing parameter: cost'}), 400

        analyzer = FeatureAnalysis()

        # FeatureAnalysis.preprocess_features requires the equipment type as
        # its second argument; calling it with the frame alone raises
        # TypeError.
        # NOTE(review): the payload key is assumed to be 'type', mirroring
        # CostPredictor.predict - confirm against the frontend request.
        equipment_type = data.get('type')

        # Data preprocessing.
        processed_data = analyzer.preprocess_features(
            pd.DataFrame(data),
            equipment_type
        )

        # Feature importance analysis.
        important_features = analyzer.select_features(
            processed_data,
            data['cost']
        )

        return jsonify({
            'important_features': important_features,
            'correlation_analysis': analyzer.correlation_analysis(
                processed_data,
                data['cost']
            ).to_dict()
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||||
68
src/app.py
Normal file
68
src/app.py
Normal file
@ -0,0 +1,68 @@
|
||||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
import logging
|
||||
import os
|
||||
from .routes import api_bp
|
||||
|
||||
def create_app():
    """Create and configure the Flask application.

    Enables CORS, configures logging, mounts ``api_bp`` under ``/api`` and
    registers JSON error handlers.

    Returns:
        The configured ``Flask`` instance.
    """
    # Local import: werkzeug ships with Flask; importing here keeps the
    # module-level import block untouched.
    from werkzeug.exceptions import HTTPException

    app = Flask(__name__)

    # Cross-origin requests.
    CORS(app)

    # Logging.
    setup_logging()

    # Blueprints.
    app.register_blueprint(api_bp, url_prefix='/api')

    # Error handling.
    @app.errorhandler(404)
    def not_found_error(error):
        logging.error(f"404 error: {error}")
        return {'error': 'Resource not found'}, 404

    @app.errorhandler(500)
    def internal_error(error):
        logging.error(f"500 error: {error}")
        return {'error': 'Internal server error'}, 500

    @app.errorhandler(Exception)
    def handle_exception(error):
        # A handler registered for Exception also receives HTTP errors
        # (405, 400, ...). Those must keep their own status code instead of
        # being reported as a server failure.
        if isinstance(error, HTTPException):
            logging.warning(f"HTTP error {error.code}: {error}")
            return {'error': error.description}, error.code

        logging.error(f"Unhandled exception: {error}", exc_info=True)
        return {'error': str(error)}, 500

    return app
|
||||
|
||||
def setup_logging():
    """Configure root logging to ``logs/api.log`` and to the console.

    The function is safe to call more than once (for example when
    ``create_app`` is invoked repeatedly): the console handler is tagged
    with the name ``console`` and is only attached when the root logger
    does not already carry a handler with that name, so records are not
    printed multiple times.

    Note that ``logging.basicConfig`` does nothing when the root logger
    already has handlers, so the file handler is only installed on the
    first configuration of the process.
    """
    # Make sure the log directory exists before basicConfig opens the file.
    os.makedirs('logs', exist_ok=True)

    # Log format and file target.
    logging.basicConfig(
        filename='logs/api.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    root_logger = logging.getLogger('')
    if any(handler.get_name() == 'console' for handler in root_logger.handlers):
        # A console handler was already installed; do not duplicate output.
        return

    # Mirror log records to the console as well.
    console_handler = logging.StreamHandler()
    console_handler.set_name('console')
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)
|
||||
|
||||
app = create_app()


@app.route('/health')
def health_check():
    """Health-check endpoint; always answers with a fixed status payload."""
    payload = {'status': 'ok'}
    return payload
|
||||
218
src/cost_prediction.py
Normal file
218
src/cost_prediction.py
Normal file
@ -0,0 +1,218 @@
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import tensorflow as tf
|
||||
from scipy import stats
|
||||
import joblib
|
||||
import os
|
||||
import pandas as pd
|
||||
from .feature_analysis import FeatureAnalysis
|
||||
import logging
|
||||
from src.model_trainer import ModelTrainer
|
||||
from src.database import get_db_connection
|
||||
|
||||
class CostPredictor:
    """Predict equipment cost and attach a heuristic confidence interval.

    ``predict`` delegates the actual inference to a ``ModelTrainer`` loaded
    for the requested equipment type. The Keras model and scalers built in
    the constructor are a default fallback kept on the instance.
    """

    # Lowest cost (yuan) ever reported for a prediction or interval bound.
    _MIN_COST = 1000
    # The upper bound is kept at least this factor above the prediction.
    _UPPER_MARGIN = 1.2

    def __init__(self):
        """Create scalers, the feature analyzer and the default model."""
        self.scaler_X = StandardScaler()
        self.scaler_y = StandardScaler()
        self.model = None
        # Compiled inference wrapper around self.model; set by load_model().
        self._predict_fn = None
        self.feature_analyzer = FeatureAnalysis()
        self.equipment_type = None

        # Run tf.function bodies in graph mode rather than eagerly.
        tf.config.run_functions_eagerly(False)

        # load_model() builds both the model and the prediction function, so
        # the wrapper is no longer constructed a second time here.
        self.load_model()

    def _build_predict_fn(self):
        """Return a ``tf.function`` that runs ``self.model`` in inference mode."""
        # NOTE(review): `reduce_retracing` may not be accepted by the
        # TensorFlow release pinned in requirements.txt (2.6.0) - confirm.
        @tf.function(reduce_retracing=True, jit_compile=True)
        def predict_fn(x):
            return self.model(x, training=False)

        return predict_fn

    def load_model(self):
        """Build the default model, its scalers and the prediction function.

        No weights are read from disk here; the ``models`` directory is only
        created if missing. If any step fails the error is logged and the
        default model is rebuilt, leaving ``_predict_fn`` as it was.
        """
        try:
            model_dir = 'models'
            os.makedirs(model_dir, exist_ok=True)

            self._create_default_model()
            self._predict_fn = self._build_predict_fn()

        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            self._create_default_model()

    def _create_default_model(self):
        """Create the default 11-input dense network and fit the scalers."""
        inputs = tf.keras.Input(shape=(11,))

        # Hidden layers.
        x = tf.keras.layers.Dense(64, activation='relu')(inputs)
        x = tf.keras.layers.Dense(32, activation='relu')(x)

        # Single regression output.
        outputs = tf.keras.layers.Dense(1)(x)

        self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
        self.model.compile(
            optimizer='adam',
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_absolute_error]
        )

        # Scaler fitting and the default equipment type are shared with
        # _create_example_data() instead of being duplicated here.
        self._create_example_data()

    def _create_example_data(self):
        """Fit the scalers on built-in rocket-artillery sample rows.

        Only rocket-artillery samples are used; the default equipment type
        is set to '火箭炮' accordingly.
        """
        rocket_data = pd.DataFrame({
            'length_m': [7.35, 10.2],
            'width_m': [2.4, 2.8],
            'height_m': [3.1, 3.2],
            'weight_kg': [13700, 28500],
            'max_range_km': [20.4, 70],
            'firing_angle_horizontal': [102, 110],
            'firing_angle_vertical': [55, 60],
            'rocket_length_m': [2.87, 4.1],
            'rocket_diameter_mm': [122, 220],
            'rocket_weight_kg': [66.6, 150],
            'rate_of_fire': [40, 60]
        })

        self.scaler_X.fit(rocket_data)
        # Sample costs (positive range).
        self.scaler_y.fit(np.array([[800000], [4500000]]))

        self.equipment_type = '火箭炮'

    @staticmethod
    def _to_float(value):
        """Coerce a raw feature value to float; missing/invalid becomes 0.0.

        This mirrors how DataPreparation fills feature values, so the model
        receives a numeric array instead of an object array holding ``None``.
        """
        try:
            return float(value) if value is not None else 0.0
        except (ValueError, TypeError):
            return 0.0

    def predict(self, data):
        """Predict the cost for one piece of equipment.

        Args:
            data: mapping of feature name to value; ``data['type']`` selects
                the equipment type whose trained model is loaded.

        Returns:
            dict with ``predicted_cost`` and ``confidence_interval``
            (``lower`` / ``upper``), all floats clamped to sane minimums.

        Raises:
            ValueError: when no trained model exists for the type.
            Exception: any other failure is logged and re-raised.
        """
        try:
            equipment_type = data.get('type')

            # Load the trained model for this equipment type.
            trainer = ModelTrainer()
            if not trainer.load_model(equipment_type):
                raise ValueError(f"No trained model found for {equipment_type}")

            # Build a single-row numeric feature matrix.
            features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
            X = np.array(
                [[self._to_float(data.get(feature)) for feature in features]],
                dtype=float
            )

            y_pred = trainer.predict(X)

            confidence_interval = self._calculate_confidence_interval(y_pred[0])

            # Keep the prediction and both bounds positive and ordered.
            predicted_cost = max(self._MIN_COST, float(y_pred[0]))
            lower_bound = max(self._MIN_COST, float(confidence_interval[0]))
            upper_bound = max(
                predicted_cost * self._UPPER_MARGIN,
                float(confidence_interval[1])
            )

            return {
                'predicted_cost': predicted_cost,
                'confidence_interval': {
                    'lower': lower_bound,
                    'upper': upper_bound
                }
            }

        except Exception as e:
            logging.error(f"Prediction error: {str(e)}")
            raise

    def _calculate_confidence_interval(self, prediction, confidence=0.95):
        """Return ``[lower, upper]`` around ``prediction``.

        The interval is a normal interval whose standard deviation is taken
        as 20% of ``|prediction|``. The lower bound is floored at the minimum
        cost, and the upper bound is never allowed to fall below the lower
        bound (which previously happened for predictions under the floor).
        """
        try:
            # 20% of the prediction is used as the standard deviation.
            std = abs(prediction) * 0.2

            interval = stats.norm.interval(confidence, loc=prediction, scale=std)

            lower = max(self._MIN_COST, interval[0])
            upper = max(prediction * self._UPPER_MARGIN, interval[1], lower)

            logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")

            return [lower, upper]

        except Exception as e:
            logging.error(f"Error calculating confidence interval: {str(e)}")
            # Fall back to a plain +/-20% band, still ordered and floored.
            lower = max(self._MIN_COST, prediction * 0.8)
            upper = max(prediction * self._UPPER_MARGIN, lower)
            return [lower, upper]

    def evaluate(self, y_true, y_pred):
        """Return MAE, MSE, RMSE and R^2 for the given predictions as floats."""
        mse = mean_squared_error(y_true, y_pred)
        return {
            'mae': float(mean_absolute_error(y_true, y_pred)),
            'mse': float(mse),
            'rmse': float(np.sqrt(mse)),
            'r2': float(r2_score(y_true, y_pred))
        }
|
||||
152
src/create_template.py
Normal file
152
src/create_template.py
Normal file
@ -0,0 +1,152 @@
|
||||
import pandas as pd
|
||||
import openpyxl
|
||||
from openpyxl.styles import PatternFill, Font, Alignment
|
||||
from openpyxl.worksheet.datavalidation import DataValidation
|
||||
import os
|
||||
|
||||
def _style_header_row(sheet, column_count):
    """Give the first ``column_count`` header cells a grey, bold, centred look."""
    for col in range(1, column_count + 1):
        cell = sheet.cell(row=1, column=col)
        cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
        cell.font = Font(bold=True)
        cell.alignment = Alignment(horizontal='center')


def _add_positive_number_validation(sheet, headers, text_headers):
    """Require a decimal > 0 in every column whose header is not textual."""
    number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
    number_validation.error = "请输入大于0的数值"
    number_validation.errorTitle = "输入错误"

    for index, header in enumerate(headers, start=1):
        if header in text_headers:
            continue
        letter = sheet.cell(row=1, column=index).column_letter
        number_validation.add(f"{letter}2:{letter}1000")

    sheet.add_data_validation(number_validation)


def _autosize_columns(sheet):
    """Set each column's width to its longest cell text plus a small margin."""
    for col in sheet.columns:
        letter = col[0].column_letter
        max_length = max((len(str(cell.value)) for cell in col), default=0)
        sheet.column_dimensions[letter].width = max_length + 2


def create_excel_template():
    """Create the equipment data-entry workbook and return its path.

    The workbook is written to ``<project>/data/equipment_data_template.xlsx``
    with three sheets (火箭炮, 巡飞弹, 特殊参数), styled headers, fill-in notes
    and a "greater than 0" validation on the numeric columns.

    The validation columns are derived from the headers of each sheet.
    Applying one fixed column range (D-O) to both sheets blocked text
    columns of the 巡飞弹 sheet (目标类型, 战斗部类型, 发射方式) and left other
    numeric columns unchecked.

    Returns:
        str: full path of the created template.

    Raises:
        Exception: when the template cannot be created; the original error
            is chained as the cause.
    """
    try:
        # Make sure the data directory exists.
        data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        os.makedirs(data_dir, exist_ok=True)

        template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')

        # Rocket artillery basic parameters.
        rocket_artillery_columns = [
            '名称', '类型', '制造商', '口径_mm',
            '反射管数量', '乘员数', '总长_m',
            '宽度_m', '高度_m', '重量_kg',
            '战斗重_kg', '速度_km/h', '最大射程_km',
            '最小射程_km', '方向射界_度', '高低射界_度',
            '火箭弹长度_m', '火箭弹重量_kg',
            '火箭弹最大速度_m/s', '射速_发',
            '战斗部重量_kg', '行走方式',
            '结构布局', '发动机型号', '发动机参数',
            '功率_hp', '行程_km', '成本_元'
        ]
        rocket_text_columns = {
            '名称', '类型', '制造商', '行走方式',
            '结构布局', '发动机型号', '发动机参数'
        }

        # Loitering munition basic parameters.
        loitering_munition_columns = [
            '名称', '类型', '制造商', '目标类型',
            '弹长_m', '弹径_mm', '翼展_m',
            '重量_kg', '战斗部重量_kg',
            '最大射程_km', '最大速度_m/s',
            '巡航速度_kmh', '巡飞时间_min',
            '战斗部类型', '发射方式',
            '折叠长度_mm', '折叠宽度_mm',
            '折叠高度_mm', '动力装置',
            '制导体制', '成本_元'
        ]
        missile_text_columns = {
            '名称', '类型', '制造商', '目标类型',
            '战斗部类型', '发射方式', '动力装置', '制导体制'
        }

        # Free-form special parameters; the first column links to the
        # equipment name used on the basic-parameter sheets.
        special_params_columns = [
            '装备名称',
            '参数名称',
            '参数值',
            '参数单位',
            '参数说明'
        ]

        # The context manager closes the writer (and its file handle) even
        # when a later step raises.
        with pd.ExcelWriter(template_path, engine='openpyxl') as writer:
            for sheet_name, columns in (
                ('火箭炮', rocket_artillery_columns),
                ('巡飞弹', loitering_munition_columns),
                ('特殊参数', special_params_columns),
            ):
                pd.DataFrame(columns=columns).to_excel(
                    writer, sheet_name=sheet_name, index=False
                )

            workbook = writer.book
            rocket_sheet = workbook['火箭炮']
            missile_sheet = workbook['巡飞弹']
            special_sheet = workbook['特殊参数']

            # Header formatting.
            _style_header_row(rocket_sheet, len(rocket_artillery_columns))
            _style_header_row(missile_sheet, len(loitering_munition_columns))
            _style_header_row(special_sheet, len(special_params_columns))

            # Numeric validation, per sheet.
            _add_positive_number_validation(
                rocket_sheet, rocket_artillery_columns, rocket_text_columns
            )
            _add_positive_number_validation(
                missile_sheet, loitering_munition_columns, missile_text_columns
            )

            # Fill-in notes, placed to the right of the data columns.
            rocket_sheet['AD1'] = "填写说明:"
            rocket_sheet['AD2'] = "1. 所有数值必须大于0"
            rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
            rocket_sheet['AD4'] = "3. 成本单位为元"

            missile_sheet['V1'] = "填写说明:"
            missile_sheet['V2'] = "1. 所有数值必须大于0"
            missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
            missile_sheet['V4'] = "3. 成本单位为元"

            special_sheet['G1'] = "填写说明:"
            special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
            special_sheet['G3'] = "2. 参数值必须包含单位"
            special_sheet['G4'] = "3. 参数说明应简明扼要"

            # Column widths.
            for sheet in (rocket_sheet, missile_sheet, special_sheet):
                _autosize_columns(sheet)

        return template_path

    except Exception as e:
        raise Exception(f"创建模板文件失败: {str(e)}") from e
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the template and report where it was written.
    try:
        created_path = create_excel_template()
    except Exception as exc:
        print(f"错误: {exc}")
    else:
        print(f"模板文件已创建: {created_path}")
|
||||
199
src/data_preparation.py
Normal file
199
src/data_preparation.py
Normal file
@ -0,0 +1,199 @@
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from datetime import datetime
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
from xgboost import XGBRegressor
|
||||
from lightgbm import LGBMRegressor
|
||||
from sklearn.model_selection import cross_val_score, LeaveOneOut
|
||||
import json
|
||||
import logging
|
||||
from src.database.db_connection import get_db_connection
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||
|
||||
class DataPreparation:
    """Turn raw equipment records into feature/target arrays for modelling."""

    def __init__(self):
        """Create the feature analyzer and the feature/target scalers."""
        self.feature_analyzer = FeatureAnalysis()
        self.feature_scaler = StandardScaler()
        # Scaler for the target (cost) values.
        self.target_scaler = StandardScaler()

    @staticmethod
    def _feature_row(item, feature_names):
        """Return ``item``'s feature values as floats (missing/invalid -> 0.0)."""
        row = []
        for name in feature_names:
            value = item.get(name)
            try:
                row.append(float(value) if value is not None else 0.0)
            except (ValueError, TypeError):
                row.append(0.0)
        return row

    @staticmethod
    def _parse_cost(item, require_positive):
        """Return ``item['actual_cost']`` as float, or None when unusable.

        A missing or non-numeric cost is logged as an error. When
        ``require_positive`` is set, a cost that is not greater than zero is
        logged as a warning and rejected as well.
        """
        try:
            cost = float(item['actual_cost'])
        except (ValueError, TypeError, KeyError):
            logging.error(f"Invalid cost value: {item.get('actual_cost')}")
            return None

        if require_positive and not cost > 0:
            logging.warning(f"Skipping non-positive cost value: {cost}")
            return None
        return cost

    def prepare_training_data(self, equipment_data, equipment_type):
        """Build scaled training arrays from raw records.

        Args:
            equipment_data: either a numpy array of features (returned after
                cleaning, unscaled and without ``y``) or an iterable of
                record mappings containing feature values and
                ``actual_cost``.
            equipment_type: equipment type used to select the feature list.

        Returns:
            dict with ``X``, ``feature_names``, ``feature_scaler`` and
            ``target_scaler``; record input additionally yields ``y``.

        Raises:
            Exception: ``"Training error: ..."`` wrapping any failure.
        """
        try:
            logging.info(f"Preparing training data for {equipment_type}")
            logging.info(f"Raw data size: {len(equipment_data)}")

            # Numpy input is only cleaned, not scaled.
            if isinstance(equipment_data, np.ndarray):
                X = equipment_data
                logging.info(f"Input is already numpy array with shape: {X.shape}")

                # Replace NaN/inf with 0.
                X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

                return {
                    'X': X,
                    'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
                    'feature_scaler': self.feature_scaler,
                    'target_scaler': self.target_scaler
                }

            feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
            features = []
            targets = []

            for item in equipment_data:
                # Validate the cost first: a record is kept only when both
                # its features and its target are usable, so X and y always
                # have the same number of rows.
                cost = self._parse_cost(item, require_positive=True)
                if cost is None:
                    continue
                features.append(self._feature_row(item, feature_names))
                targets.append(cost)

            X = np.array(features, dtype=float)
            y = np.array(targets, dtype=float)

            if y.size == 0:
                raise ValueError("No samples with a valid positive cost")

            # Ranges before scaling.
            logging.info(f"Original X range: min={X.min()}, max={X.max()}")
            logging.info(f"Original y range: min={y.min()}, max={y.max()}")

            # Replace NaN/inf with 0.
            X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

            # Standardise features and targets.
            X_scaled = self.feature_scaler.fit_transform(X)
            y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()

            # Ranges after scaling.
            logging.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
            logging.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")

            return {
                'X': X_scaled,
                'y': y_scaled,
                'feature_names': feature_names,
                'feature_scaler': self.feature_scaler,
                'target_scaler': self.target_scaler
            }

        except Exception as e:
            logging.error(f"Error in data preparation: {str(e)}")
            raise Exception(f"Training error: {str(e)}") from e

    def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
        """Build validation arrays, reusing the training feature scaler.

        Args:
            validation_data: numpy feature array or iterable of record
                mappings with feature values and ``actual_cost``.
            equipment_type: equipment type used to look up feature names
                when ``feature_names`` is not given.
            feature_names: optional explicit feature order.
            scalers: optional dict; its ``'feature_scaler'`` entry, when
                present, is used to transform the features.

        Returns:
            dict with ``X`` and ``y``. ``y`` is None for numpy input and the
            unscaled costs for record input. Records whose cost is missing
            or not numeric are skipped so ``X`` and ``y`` stay aligned.

        Raises:
            Exception: ``"Validation error: ..."`` wrapping any failure.
        """
        try:
            logging.info(f"Preparing validation data for {equipment_type}")
            logging.info(f"Raw validation data size: {len(validation_data)}")

            if isinstance(validation_data, np.ndarray):
                X = validation_data
                logging.info(f"Input is already numpy array with shape: {X.shape}")
                y = None  # numpy input carries no labels
            else:
                if not feature_names:
                    feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)

                features = []
                targets = []

                for item in validation_data:
                    cost = self._parse_cost(item, require_positive=False)
                    if cost is None:
                        continue
                    features.append(self._feature_row(item, feature_names))
                    targets.append(cost)

                X = np.array(features, dtype=float)
                y = np.array(targets, dtype=float)  # costs stay unscaled

            # Replace NaN/inf with 0.
            X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

            # Use the scaler fitted on the training data when provided;
            # otherwise hand back the cleaned array unchanged.
            if scalers and 'feature_scaler' in scalers:
                X_scaled = scalers['feature_scaler'].transform(X)
            else:
                X_scaled = X

            logging.info(f"Preprocessed data shape: {X_scaled.shape}")
            logging.info(f"Validation features shape: {X_scaled.shape}")
            logging.info(f"Validation features type: {X_scaled.dtype}")

            return {
                'X': X_scaled,
                'y': y
            }

        except Exception as e:
            logging.error(f"Error in validation data preparation: {str(e)}")
            logging.error(f"Feature names: {feature_names}")
            logging.error(f"Equipment type: {equipment_type}")
            raise Exception(f"Validation error: {str(e)}") from e

    def calculate_derived_features(self, data, equipment_type):
        """Delegate derived-feature calculation to the feature analyzer.

        Raises:
            Exception: ``"Feature calculation error: ..."`` on failure.
        """
        try:
            return self.feature_analyzer.calculate_derived_features(data, equipment_type)
        except Exception as e:
            logging.error(f"Error calculating derived features: {str(e)}")
            raise Exception(f"Feature calculation error: {str(e)}") from e
|
||||
1
src/database/__init__.py
Normal file
1
src/database/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .db_connection import get_db_connection
|
||||
28
src/database/db_connection.py
Normal file
28
src/database/db_connection.py
Normal file
@ -0,0 +1,28 @@
|
||||
import logging
import os
from contextlib import contextmanager

import mysql.connector
from mysql.connector import Error
|
||||
|
||||
# Database configuration.
# Every value can be overridden through an environment variable so that real
# credentials do not have to live in source control; the literals are the
# development defaults.
DB_CONFIG = {
    'host': os.environ.get('EQUIPMENT_DB_HOST', 'localhost'),
    'user': os.environ.get('EQUIPMENT_DB_USER', 'root'),
    'password': os.environ.get('EQUIPMENT_DB_PASSWORD', '123456'),
    'database': os.environ.get('EQUIPMENT_DB_NAME', 'equipment_cost_db')
}
|
||||
|
||||
@contextmanager
def get_db_connection():
    """Context manager yielding a MySQL connection built from ``DB_CONFIG``.

    Connecting and using the connection are handled separately so that an
    error raised inside the caller's ``with`` body is no longer logged as a
    connection failure. The connection is closed on exit if still open.

    Yields:
        The open ``mysql.connector`` connection.

    Raises:
        mysql.connector.Error: logged and re-raised, both for connection
            failures and for database errors raised while the connection
            is in use.
    """
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
    except Error as e:
        logging.error(f"Error connecting to MySQL: {str(e)}")
        raise

    try:
        yield conn
    except Error as e:
        logging.error(f"MySQL error while using connection: {str(e)}")
        raise
    finally:
        if conn.is_connected():
            conn.close()
|
||||
266
src/feature_analysis.py
Normal file
266
src/feature_analysis.py
Normal file
@ -0,0 +1,266 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy import stats
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.metrics import r2_score
|
||||
import logging
|
||||
|
||||
class FeatureAnalysis:
|
||||
def __init__(self):
    """Create the scaler and the feature-name to display-label mapping."""
    self.scaler = StandardScaler()
    self.important_features = []

    # Parameters shared by every equipment type.
    common_labels = {
        'length_m': '总长(m)',
        'width_m': '宽度(m)',
        'height_m': '高度(m)',
        'weight_kg': '重量(kg)',
        'max_range_km': '最大射程(km)',
    }

    # Rocket-artillery parameters.
    rocket_labels = {
        'firing_angle_horizontal': '方向射界(度)',
        'firing_angle_vertical': '高低射界(度)',
        'rocket_length_m': '火箭弹长度(m)',
        'rocket_diameter_mm': '口径(mm)',
        'rocket_weight_kg': '火箭弹重量(kg)',
        'rate_of_fire': '射速(发/分)',
        'combat_weight_kg': '战斗重量(kg)',
        'speed_kmh': '速度(km/h)',
        'min_range_km': '最小射程(km)',
        'power_hp': '功率(hp)',
    }

    # Rocket-artillery derived features.
    rocket_derived_labels = {
        'fire_density': '火力密度',
        'mobility_index': '机动性指标',
        'range_ratio': '射程比',
        'power_weight_ratio': '功重比',
        'volume_density': '体积密度',
    }

    # Loitering-munition parameters.
    missile_labels = {
        'wingspan_m': '翼展(m)',
        'warhead_weight_kg': '战斗部重量(kg)',
        'max_speed_ms': '最大速度(m/s)',
        'cruise_speed_kmh': '巡航速度(km/h)',
        'flight_time_min': '巡飞时间(min)',
        'folded_length_mm': '折叠长度(mm)',
        'folded_width_mm': '折叠宽度(mm)',
        'folded_height_mm': '折叠高度(mm)',
    }

    # Loitering-munition derived features ('volume_density' is shared with
    # the rocket-artillery group above and keeps the same label).
    missile_derived_labels = {
        'warhead_ratio': '战斗部比重',
        'speed_ratio': '速度比',
        'range_time_ratio': '射程时间比',
        'aspect_ratio': '展弦比',
        'volume_density': '体积密度',
    }

    self.feature_names_map = {
        **common_labels,
        **rocket_labels,
        **rocket_derived_labels,
        **missile_labels,
        **missile_derived_labels,
    }
|
||||
|
||||
def get_equipment_specific_features(self, equipment_type):
    """Return the ordered feature names for ``equipment_type``.

    The list is the shared parameters followed by the type-specific
    parameters and then the derived features. '火箭炮' selects the
    rocket-artillery set; any other value selects the loitering-munition
    set.
    """
    shared = [
        'length_m',      # overall length (m)
        'width_m',       # width (m)
        'height_m',      # height (m)
        'weight_kg',     # weight (kg)
        'max_range_km',  # maximum range (km)
    ]

    if equipment_type != '火箭炮':
        # Loitering munition: specific parameters, then derived features.
        missile_specific = [
            'wingspan_m',
            'warhead_weight_kg',
            'max_speed_ms',
            'cruise_speed_kmh',
            'flight_time_min',
            'folded_length_mm',
            'folded_width_mm',
            'folded_height_mm',
        ]
        missile_derived = [
            'warhead_ratio',     # warhead weight / total weight
            'speed_ratio',       # cruise speed / maximum speed
            'range_time_ratio',  # maximum range / loiter time
            'aspect_ratio',      # wingspan^2 / reference area
            'volume_density',    # weight / (length * width * height)
        ]
        return shared + missile_specific + missile_derived

    # Rocket artillery: specific parameters, then derived features.
    rocket_specific = [
        'firing_angle_horizontal',
        'firing_angle_vertical',
        'rocket_length_m',
        'rocket_diameter_mm',
        'rocket_weight_kg',
        'rate_of_fire',
        'combat_weight_kg',
        'speed_kmh',
        'min_range_km',
        'power_hp',
    ]
    rocket_derived = [
        'fire_density',        # rate of fire * rocket weight
        'mobility_index',      # speed / combat weight
        'range_ratio',         # maximum range / minimum range
        'power_weight_ratio',  # power / combat weight
        'volume_density',      # weight / (length * width * height)
    ]
    return shared + rocket_specific + rocket_derived
|
||||
|
||||
def calculate_derived_features(self, data, equipment_type):
    """Add the derived feature columns for ``equipment_type`` to ``data``.

    ``data`` is modified in place and also returned. A derived column is
    computed only when all of its source columns are present; otherwise it
    is filled with 0. '火箭炮' selects the rocket-artillery features, any
    other value the loitering-munition features. Errors are logged and
    re-raised.
    """
    def present(*names):
        # True when every named source column exists in the frame.
        return all(name in data.columns for name in names)

    size_columns = ('weight_kg', 'length_m', 'width_m', 'height_m')

    try:
        if equipment_type == '火箭炮':
            data['fire_density'] = (
                data['rate_of_fire'] * data['rocket_weight_kg']
                if present('rate_of_fire', 'rocket_weight_kg') else 0
            )
            data['mobility_index'] = (
                data['speed_kmh'] / data['combat_weight_kg']
                if present('speed_kmh', 'combat_weight_kg') else 0
            )
            data['range_ratio'] = (
                data['max_range_km'] / data['min_range_km']
                if present('max_range_km', 'min_range_km') else 0
            )
            data['power_weight_ratio'] = (
                data['power_hp'] / data['combat_weight_kg']
                if present('power_hp', 'combat_weight_kg') else 0
            )
        else:
            # Loitering munition.
            data['warhead_ratio'] = (
                data['warhead_weight_kg'] / data['weight_kg']
                if present('warhead_weight_kg', 'weight_kg') else 0
            )
            # Maximum speed is in m/s; 3.6 converts it to km/h.
            data['speed_ratio'] = (
                data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
                if present('cruise_speed_kmh', 'max_speed_ms') else 0
            )
            data['range_time_ratio'] = (
                data['max_range_km'] / data['flight_time_min']
                if present('max_range_km', 'flight_time_min') else 0
            )
            data['aspect_ratio'] = (
                (data['wingspan_m'] ** 2) / data['length_m']
                if present('wingspan_m', 'length_m') else 0
            )

        # Volume density is shared by both equipment types and comes last.
        data['volume_density'] = (
            data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
            if present(*size_columns) else 0
        )

        return data

    except Exception as exc:
        logging.error(f"Error calculating derived features: {exc}")
        raise
|
||||
|
||||
def analyze_features(self, features, target, feature_names):
    """Rank features by importance and compute their pairwise correlations.

    The features are standardised with ``self.scaler``, a random forest
    (100 trees, fixed seed) is fitted to obtain importances, and the
    correlation matrix of the standardised features is flattened into
    ``[row, col, value]`` triples.

    Args:
        features: Array-like of feature values; passed to
            ``self.scaler.fit_transform`` after conversion with ``np.array``.
        target: Array-like of target values used to fit the forest.
        feature_names: Column names, in the same order as the feature
            columns. Names are mapped through ``self.feature_names_map``
            for display and fall back to the raw name when unmapped.

    Returns:
        dict with two keys:
            ``important_features``: list of ``{'name', 'importance'}``
            dicts sorted by descending importance.
            ``correlation_analysis``: ``{'features': display names,
            'matrix': [[i, j, corr rounded to 2 decimals], ...]}``.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        X = np.array(features)
        y = np.array(target)

        # Standardise before fitting / correlating.
        X_scaled = self.scaler.fit_transform(X)

        # Feature importance from a random forest with a fixed seed.
        forest = RandomForestRegressor(n_estimators=100, random_state=42)
        forest.fit(X_scaled, y)
        importances = forest.feature_importances_

        # Display names (mapped through feature_names_map where available).
        display_names = [self.feature_names_map.get(name, name) for name in feature_names]

        important_features = [
            {'name': display_names[i], 'importance': float(importances[i])}
            for i in np.argsort(importances)[::-1]
        ]

        # Pairwise correlation of the standardised features.
        correlation_matrix = pd.DataFrame(X_scaled, columns=feature_names).corr().values
        size = len(feature_names)
        correlation_data = [
            [i, j, round(float(correlation_matrix[i][j]), 2)]
            for i in range(size)
            for j in range(size)
        ]

        return {
            'important_features': important_features,
            'correlation_analysis': {
                'features': display_names,
                'matrix': correlation_data
            }
        }

    except Exception as e:
        # Use the logging module like the sibling methods instead of print().
        logging.error(f"Error in feature analysis: {str(e)}")
        raise
|
||||
|
||||
def preprocess_features(self, equipment_data, equipment_type):
    """Build the feature frame: derived features plus mean imputation.

    Args:
        equipment_data: Anything ``pd.DataFrame`` accepts (e.g. a list of
            row dicts).
        equipment_type: Forwarded to ``self.calculate_derived_features``.

    Returns:
        pandas.DataFrame with derived columns added and missing values in
        numeric columns replaced by the column mean.

    Raises:
        Exception: wraps any failure with a "Feature preprocessing error"
            message; the original exception is attached as ``__cause__``.
    """
    try:
        df = pd.DataFrame(equipment_data)

        # Add the ratio-style derived columns.
        df = self.calculate_derived_features(df, equipment_type)

        numeric_columns = df.select_dtypes(include=[np.number]).columns
        for col in numeric_columns:
            column = pd.to_numeric(df[col], errors='coerce')
            # Derived ratios divide one column by another; a zero
            # denominator yields +/-inf, which would poison the mean and
            # any later scaling. Treat those cells as missing.
            column = column.replace([np.inf, -np.inf], np.nan)
            # NOTE: a column that is entirely missing has a NaN mean and
            # therefore stays NaN after this step.
            df[col] = column.fillna(column.mean())

        logging.info(f"Preprocessed data shape: {df.shape}")
        return df

    except Exception as e:
        logging.error(f"Error preprocessing features: {str(e)}")
        # Chain explicitly so the traceback shows the real root cause.
        raise Exception(f"Feature preprocessing error: {str(e)}") from e
|
||||
259
src/import_data.py
Normal file
259
src/import_data.py
Normal file
@ -0,0 +1,259 @@
|
||||
import pandas as pd
|
||||
import logging
|
||||
from src.database.db_connection import get_db_connection
|
||||
|
||||
import os

# logging.basicConfig opens the log file right away, so the directory has
# to exist before the call or the module fails on import.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    filename='logs/import.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
|
||||
|
||||
# Excel columns feeding rocket_artillery_params, in INSERT column order.
_ROCKET_PARAM_COLUMNS = (
    '方向射界_度', '高低射界_度', '火箭弹长度_m', '口径_mm', '火箭弹重量_kg',
    '射速_发', '战斗重_kg', '速度_km/h', '最小射程_km', '行走方式',
    '结构布局', '发动机型号', '发动机参数', '功率_hp', '行程_km',
)

# (Excel column, cast) pairs feeding loitering_munition_params, in INSERT
# column order.
_MISSILE_PARAM_COLUMNS = (
    ('翼展_m', float), ('战斗部重量_kg', float), ('最大速度_m/s', float),
    ('巡航速度_km/h', float), ('巡飞时间_min', float),
    ('战斗部类型', str), ('发射方式', str),
    ('折叠长度_mm', float), ('折叠宽度_mm', float), ('折叠高度_mm', float),
    ('动力装置', str), ('制导体制', str),
)


def _cell(row, column, cast=None):
    """Return ``row[column]``, or ``None`` when the Excel cell is empty (NaN).

    ``cast`` (e.g. ``float`` or ``str``) is applied to non-empty values only.
    """
    value = row[column]
    if not pd.notna(value):
        return None
    return value if cast is None else cast(value)


def _equipment_exists(cursor, name, equipment_type):
    """Return True when an equipment row with this name and type exists."""
    cursor.execute("""
        SELECT id FROM equipment
        WHERE name = %s AND type = %s
    """, (name, equipment_type))
    return bool(cursor.fetchone())


def _insert_equipment(cursor, name, equipment_type, manufacturer):
    """Insert the base equipment row and return its generated id."""
    cursor.execute("""
        INSERT INTO equipment (name, type, manufacturer)
        VALUES (%s, %s, %s)
    """, (name, equipment_type, manufacturer))
    return cursor.lastrowid


def _insert_common_params(cursor, equipment_id, values):
    """Insert (length_m, width_m, height_m, weight_kg, max_range_km)."""
    cursor.execute("""
        INSERT INTO common_params
        (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
        VALUES (%s, %s, %s, %s, %s, %s)
    """, (equipment_id, *values))


def _insert_cost(cursor, equipment_id, cost):
    """Insert the actual cost; rows without a cost are skipped."""
    if cost is None:
        return
    cursor.execute("""
        INSERT INTO cost_data (equipment_id, actual_cost)
        VALUES (%s, %s)
    """, (equipment_id, cost))


def _import_rockets(cursor, rocket_df, equipment_names):
    """Import the rocket artillery sheet, skipping names already stored."""
    logging.info("开始导入火箭炮数据...")
    for _, row in rocket_df.iterrows():
        name = row['名称']
        equipment_names.add(name)

        if _equipment_exists(cursor, name, '火箭炮'):
            logging.warning(f"火箭炮 '{name}' 已存在,跳过导入")
            continue

        # The manufacturer cell may be empty; store NULL rather than NaN,
        # matching what the loitering munition import does.
        equipment_id = _insert_equipment(cursor, name, '火箭炮', _cell(row, '制造商'))

        _insert_common_params(cursor, equipment_id, [
            _cell(row, column)
            for column in ('总长_m', '宽度_m', '高度_m', '重量_kg', '最大射程_km')
        ])

        cursor.execute("""
            INSERT INTO rocket_artillery_params
            (equipment_id, firing_angle_horizontal, firing_angle_vertical,
            rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
            combat_weight_kg, speed_kmh, min_range_km, mobility_type,
            structure_layout, engine_model, engine_params, power_hp,
            travel_range_km)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (equipment_id, *[_cell(row, column) for column in _ROCKET_PARAM_COLUMNS]))

        _insert_cost(cursor, equipment_id, _cell(row, '成本_元'))

    logging.info("火箭炮数据导入完成")


def _import_missiles(cursor, missile_df, equipment_names):
    """Import the loitering munition sheet, skipping names already stored."""
    logging.info("开始导入巡飞弹数据...")
    for index, row in missile_df.iterrows():
        # Log which cells are empty (index + 2 is the number logged for the row).
        null_values = row[row.isna()].index.tolist()
        if null_values:
            logging.info(f"行 {index + 2} 中的空值字段: {null_values}")

        name = row['名称']
        equipment_names.add(name)

        if _equipment_exists(cursor, name, '巡飞弹'):
            logging.warning(f"巡飞弹 '{name}' 已存在,跳过导入")
            continue

        equipment_id = _insert_equipment(cursor, name, '巡飞弹', _cell(row, '制造商'))

        # The body diameter (mm) is stored as both width and height, in metres.
        diameter_m = _cell(row, '弹径_mm', float)
        if diameter_m is not None:
            diameter_m = diameter_m / 1000

        _insert_common_params(cursor, equipment_id, [
            _cell(row, '弹长_m', float),
            diameter_m,
            diameter_m,
            _cell(row, '重量_kg', float),
            _cell(row, '最大射程_km', float),
        ])

        cursor.execute("""
            INSERT INTO loitering_munition_params
            (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
            cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
            folded_length_mm, folded_width_mm, folded_height_mm,
            power_system, guidance_system)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            equipment_id,
            *[_cell(row, column, cast) for column, cast in _MISSILE_PARAM_COLUMNS],
        ))

        _insert_cost(cursor, equipment_id, _cell(row, '成本_元', float))

    logging.info("巡飞弹数据导入完成")


def _import_special_params(conn, special_df, equipment_names):
    """Import the custom parameter sheet for equipment listed in this workbook."""
    logging.info("开始导入特殊参数...")
    for index, row in special_df.iterrows():
        equipment_name = row['装备名称']
        param_name = row['参数名称']
        logging.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")

        # Only equipment named on one of the two equipment sheets qualifies.
        if equipment_name not in equipment_names:
            logging.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
            continue

        # Resolve the equipment id (a fresh cursor per statement).
        logging.debug(f"查询装备ID: {equipment_name}")
        with conn.cursor() as id_cursor:
            id_cursor.execute("""
                SELECT id FROM equipment WHERE name = %s
            """, (equipment_name,))
            result = id_cursor.fetchone()

        if not result:
            logging.warning(f"未找到装备: {equipment_name}")
            continue

        equipment_id = result[0]
        logging.debug(f"找到装备ID: {equipment_id}")

        # Skip parameters that are already stored for this equipment.
        logging.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
        with conn.cursor() as check_cursor:
            check_cursor.execute("""
                SELECT id FROM custom_params
                WHERE equipment_id = %s AND param_name = %s
            """, (equipment_id, param_name))
            exists = check_cursor.fetchone()

        if exists:
            logging.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
            continue

        param_value = _cell(row, '参数值', str)
        param_unit = _cell(row, '参数单位')
        param_desc = _cell(row, '参数说明')

        logging.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
        with conn.cursor() as insert_cursor:
            insert_cursor.execute("""
                INSERT INTO custom_params
                (equipment_id, param_name, param_value, param_unit, description)
                VALUES (%s, %s, %s, %s, %s)
            """, (equipment_id, param_name, param_value, param_unit, param_desc))
        logging.debug("成功插入参数记录")


def import_training_data(excel_file):
    """Import training data from an Excel workbook into the database.

    The workbook must contain the sheets '火箭炮', '巡飞弹' and '特殊参数'.
    Equipment rows whose name already exists for the same type are skipped.
    Equipment data is committed first, then the custom parameters.

    Args:
        excel_file: Path (or anything ``pd.read_excel`` accepts) of the
            workbook.

    Returns:
        True when the import finished.

    Raises:
        Exception: any failure is logged and re-raised; work that was not
            yet committed is rolled back first.
    """
    try:
        rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
        missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
        special_df = pd.read_excel(excel_file, sheet_name='特殊参数')

        # Names seen on the equipment sheets; used to validate the
        # custom parameter sheet.
        equipment_names = set()

        with get_db_connection() as conn:
            try:
                cursor = conn.cursor()
                try:
                    _import_rockets(cursor, rocket_df, equipment_names)
                    _import_missiles(cursor, missile_df, equipment_names)
                finally:
                    # Release the cursor even when an insert fails.
                    cursor.close()
                conn.commit()

                _import_special_params(conn, special_df, equipment_names)
                conn.commit()
            except Exception:
                # Drop whatever the failed phase left uncommitted.
                conn.rollback()
                raise

            logging.info("特殊参数导入完成")
            logging.info("所有数据导入成功")
            return True

    except Exception as e:
        logging.error(f"Error importing data: {str(e)}")
        raise
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: import the bundled workbook and report the outcome.
    excel_file = 'data/equipment_data_20241108.xlsx'
    try:
        import_training_data(excel_file)
    except Exception as e:
        logging.error(f"Import failed: {str(e)}")
        # Exit with a non-zero status so shells and schedulers can tell
        # that the import did not succeed.
        raise SystemExit(1)
    logging.info("All data imported successfully")
|
||||
277
src/init_data.sql
Normal file
277
src/init_data.sql
Normal file
@ -0,0 +1,277 @@
|
||||
-- Seed data for equipment / technical_params / cost_data.
-- Child rows resolve equipment_id by (name, type) instead of hard-coding
-- auto-increment values, so the inserts stay correct when the equipment
-- table is not empty or its ids do not start at 1.

-- Base equipment
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
    ('胜利-2', '火箭炮', '伊朗', '地面固定目标');

-- Loitering munition technical parameters
INSERT INTO technical_params (
    equipment_id, length_m, width_m, height_m, weight_kg,
    max_speed_kmh, cruise_speed_kmh, max_range_km, flight_time_min,
    warhead_type, launch_mode,
    folded_length_mm, folded_width_mm, folded_height_mm
) VALUES (
    (SELECT id FROM equipment WHERE name = '终结者' AND type = '巡飞弹'),
    0.56, 0.15, 0.20, 2.72, 160.93, 96.56, 24, 15,
    '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200
);

-- Rocket artillery technical parameters
INSERT INTO technical_params (
    equipment_id, length_m, width_m, height_m, weight_kg, max_range_km
) VALUES (
    (SELECT id FROM equipment WHERE name = '胜利-2' AND type = '火箭炮'),
    10, 2.5, 3.34, 15000, 23
);

-- Cost data (sample values)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
    ((SELECT id FROM equipment WHERE name = '终结者' AND type = '巡飞弹'), 1000000),
    ((SELECT id FROM equipment WHERE name = '胜利-2' AND type = '火箭炮'), 5000000);

-- Additional loitering munition variants used as training data
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
    ('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
    ('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');

-- Variant technical parameters
INSERT INTO technical_params (
    equipment_id, length_m, width_m, height_m, weight_kg,
    max_speed_kmh, cruise_speed_kmh, max_range_km, flight_time_min,
    warhead_type, launch_mode,
    folded_length_mm, folded_width_mm, folded_height_mm
) VALUES
    -- 终结者-A (slightly larger)
    ((SELECT id FROM equipment WHERE name = '终结者-A' AND type = '巡飞弹'),
     0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
    -- 终结者-B (slightly smaller)
    ((SELECT id FROM equipment WHERE name = '终结者-B' AND type = '巡飞弹'),
     0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
    -- 终结者-C (improved standard model)
    ((SELECT id FROM equipment WHERE name = '终结者-C' AND type = '巡飞弹'),
     0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);

-- Variant cost data
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
    ((SELECT id FROM equipment WHERE name = '终结者-A' AND type = '巡飞弹'), 1100000),  -- higher
    ((SELECT id FROM equipment WHERE name = '终结者-B' AND type = '巡飞弹'), 900000),   -- lower
    ((SELECT id FROM equipment WHERE name = '终结者-C' AND type = '巡飞弹'), 1050000);  -- middle

-- More loitering munitions
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
    ('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
    ('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
    ('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
    ('彩虹-4', '巡飞弹', '中国', '地面固定目标');

-- Their technical parameters
INSERT INTO technical_params (
    equipment_id, length_m, width_m, height_m, weight_kg,
    max_speed_kmh, cruise_speed_kmh, max_range_km, flight_time_min,
    warhead_type, launch_mode,
    folded_length_mm, folded_width_mm, folded_height_mm
) VALUES
    ((SELECT id FROM equipment WHERE name = '哈比' AND type = '巡飞弹'),
     2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
    ((SELECT id FROM equipment WHERE name = '游荡者' AND type = '巡飞弹'),
     2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
    ((SELECT id FROM equipment WHERE name = '凤凰' AND type = '巡飞弹'),
     2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
    ((SELECT id FROM equipment WHERE name = '弹簧刀' AND type = '巡飞弹'),
     1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
    ((SELECT id FROM equipment WHERE name = '彩虹-4' AND type = '巡飞弹'),
     3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);

-- Their cost data
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
    ((SELECT id FROM equipment WHERE name = '哈比' AND type = '巡飞弹'), 800000),
    ((SELECT id FROM equipment WHERE name = '游荡者' AND type = '巡飞弹'), 500000),
    ((SELECT id FROM equipment WHERE name = '凤凰' AND type = '巡飞弹'), 450000),
    ((SELECT id FROM equipment WHERE name = '弹簧刀' AND type = '巡飞弹'), 480000),
    ((SELECT id FROM equipment WHERE name = '彩虹-4' AND type = '巡飞弹'), 1500000);
|
||||
|
||||
-- Rocket artillery and loitering munition seed data for
-- common_params / rocket_artillery_params / loitering_munition_params.
-- equipment_id is resolved by (name, type): earlier statements in this
-- script already insert equipment rows, so literal ids 1..14 would attach
-- these parameters and costs to the wrong equipment.

-- Rocket artillery
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('BM-21', '火箭炮', '俄罗斯', '面目标'),
    ('SR5', '火箭炮', '中国', '面目标和点目标'),
    ('HIMARS', '火箭炮', '美国', '战术目标'),
    ('LAR-160', '火箭炮', '以色列', '面目标'),
    ('T-122', '火箭炮', '土耳其', '面目标'),
    ('RM-70', '火箭炮', '捷克', '面目标'),
    ('ASTROS II', '火箭炮', '巴西', '面目标和点目标');

-- Rocket artillery common parameters
INSERT INTO common_params (
    equipment_id, length_m, width_m, height_m, weight_kg, max_range_km
) VALUES
    ((SELECT id FROM equipment WHERE name = 'BM-21' AND type = '火箭炮'), 7.35, 2.4, 3.1, 13700, 20.4),
    ((SELECT id FROM equipment WHERE name = 'SR5' AND type = '火箭炮'), 10.2, 2.8, 3.2, 28500, 70),
    ((SELECT id FROM equipment WHERE name = 'HIMARS' AND type = '火箭炮'), 7.0, 2.4, 3.2, 16250, 70),
    ((SELECT id FROM equipment WHERE name = 'LAR-160' AND type = '火箭炮'), 6.7, 2.5, 2.8, 15000, 45),
    ((SELECT id FROM equipment WHERE name = 'T-122' AND type = '火箭炮'), 7.2, 2.5, 2.9, 18000, 40),
    ((SELECT id FROM equipment WHERE name = 'RM-70' AND type = '火箭炮'), 7.5, 2.5, 3.0, 17200, 20.3),
    ((SELECT id FROM equipment WHERE name = 'ASTROS II' AND type = '火箭炮'), 8.0, 2.7, 3.1, 24500, 90);

-- Rocket artillery specific parameters
INSERT INTO rocket_artillery_params (
    equipment_id, firing_angle_horizontal, firing_angle_vertical,
    rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire
) VALUES
    ((SELECT id FROM equipment WHERE name = 'BM-21' AND type = '火箭炮'), 102, 55, 2.87, 122, 66.6, 40),
    ((SELECT id FROM equipment WHERE name = 'SR5' AND type = '火箭炮'), 110, 60, 4.1, 220, 150, 60),
    ((SELECT id FROM equipment WHERE name = 'HIMARS' AND type = '火箭炮'), 90, 65, 3.94, 227, 301, 6),
    ((SELECT id FROM equipment WHERE name = 'LAR-160' AND type = '火箭炮'), 100, 58, 3.3, 160, 110, 18),
    ((SELECT id FROM equipment WHERE name = 'T-122' AND type = '火箭炮'), 110, 65, 2.95, 122, 65.5, 40),
    ((SELECT id FROM equipment WHERE name = 'RM-70' AND type = '火箭炮'), 100, 50, 2.87, 122, 66.6, 40),
    ((SELECT id FROM equipment WHERE name = 'ASTROS II' AND type = '火箭炮'), 90, 65, 4.3, 300, 550, 30);

-- Rocket artillery cost data (sample values)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
    ((SELECT id FROM equipment WHERE name = 'BM-21' AND type = '火箭炮'), 800000),
    ((SELECT id FROM equipment WHERE name = 'SR5' AND type = '火箭炮'), 4500000),
    ((SELECT id FROM equipment WHERE name = 'HIMARS' AND type = '火箭炮'), 5500000),
    ((SELECT id FROM equipment WHERE name = 'LAR-160' AND type = '火箭炮'), 3500000),
    ((SELECT id FROM equipment WHERE name = 'T-122' AND type = '火箭炮'), 2800000),
    ((SELECT id FROM equipment WHERE name = 'RM-70' AND type = '火箭炮'), 1500000),
    ((SELECT id FROM equipment WHERE name = 'ASTROS II' AND type = '火箭炮'), 4800000);

-- Loitering munitions
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('Hero-120', '巡飞弹', '以色列', '装甲目标'),
    ('Switchblade 600', '巡飞弹', '美国', '装甲车辆'),
    ('Warmate', '巡飞弹', '波兰', '轻型装甲目标'),
    ('CH-901', '巡飞弹', '中国', '人员和轻型装甲车辆'),
    ('HAROP', '巡飞弹', '以色列', '防空系统'),
    ('Coyote', '巡飞弹', '美国', '无人机和巡飞弹'),
    ('WS-43', '巡飞弹', '中国', '固定目标和装甲目标');

-- Loitering munition common parameters
INSERT INTO common_params (
    equipment_id, length_m, width_m, height_m, weight_kg, max_range_km
) VALUES
    ((SELECT id FROM equipment WHERE name = 'Hero-120' AND type = '巡飞弹'), 1.3, 0.23, 0.23, 12.5, 40),
    ((SELECT id FROM equipment WHERE name = 'Switchblade 600' AND type = '巡飞弹'), 1.3, 0.22, 0.22, 15.0, 40),
    ((SELECT id FROM equipment WHERE name = 'Warmate' AND type = '巡飞弹'), 1.1, 0.15, 0.15, 5.7, 15),
    ((SELECT id FROM equipment WHERE name = 'CH-901' AND type = '巡飞弹'), 1.2, 0.18, 0.18, 9.0, 20),
    ((SELECT id FROM equipment WHERE name = 'HAROP' AND type = '巡飞弹'), 2.5, 0.43, 0.43, 135, 1000),
    ((SELECT id FROM equipment WHERE name = 'Coyote' AND type = '巡飞弹'), 0.9, 0.12, 0.12, 5.9, 20),
    ((SELECT id FROM equipment WHERE name = 'WS-43' AND type = '巡飞弹'), 1.8, 0.35, 0.35, 20, 60);

-- Loitering munition specific parameters
-- NOTE(review): this statement writes max_speed_kmh, while
-- src/import_data.py inserts into max_speed_ms on the same table;
-- confirm which column the schema actually defines.
INSERT INTO loitering_munition_params (
    equipment_id, max_speed_kmh, cruise_speed_kmh, flight_time_min,
    warhead_type, launch_mode,
    folded_length_mm, folded_width_mm, folded_height_mm
) VALUES
    ((SELECT id FROM equipment WHERE name = 'Hero-120' AND type = '巡飞弹'),
     180, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230),
    ((SELECT id FROM equipment WHERE name = 'Switchblade 600' AND type = '巡飞弹'),
     185, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220),
    ((SELECT id FROM equipment WHERE name = 'Warmate' AND type = '巡飞弹'),
     150, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150),
    ((SELECT id FROM equipment WHERE name = 'CH-901' AND type = '巡飞弹'),
     160, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180),
    ((SELECT id FROM equipment WHERE name = 'HAROP' AND type = '巡飞弹'),
     185, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430),
    ((SELECT id FROM equipment WHERE name = 'Coyote' AND type = '巡飞弹'),
     150, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120),
    ((SELECT id FROM equipment WHERE name = 'WS-43' AND type = '巡飞弹'),
     170, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350);

-- Loitering munition cost data (sample values)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
    ((SELECT id FROM equipment WHERE name = 'Hero-120' AND type = '巡飞弹'), 150000),
    ((SELECT id FROM equipment WHERE name = 'Switchblade 600' AND type = '巡飞弹'), 180000),
    ((SELECT id FROM equipment WHERE name = 'Warmate' AND type = '巡飞弹'), 80000),
    ((SELECT id FROM equipment WHERE name = 'CH-901' AND type = '巡飞弹'), 100000),
    ((SELECT id FROM equipment WHERE name = 'HAROP' AND type = '巡飞弹'), 850000),
    ((SELECT id FROM equipment WHERE name = 'Coyote' AND type = '巡飞弹'), 75000),
    ((SELECT id FROM equipment WHERE name = 'WS-43' AND type = '巡飞弹'), 120000);
|
||||
398
src/model_trainer.py
Normal file
398
src/model_trainer.py
Normal file
@ -0,0 +1,398 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.model_selection import cross_val_score
|
||||
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
||||
from sklearn.impute import SimpleImputer
|
||||
import xgboost as xgb
|
||||
import lightgbm as lgb
|
||||
import logging
|
||||
import joblib
|
||||
import os
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from datetime import datetime
|
||||
import json
|
||||
from src.database import get_db_connection
|
||||
from src.data_preparation import DataPreparation
|
||||
|
||||
class ModelTrainer:
    """Train candidate regressors, keep the best one, persist it and predict.

    Attributes:
        models: name -> unfitted estimator ('xgboost', 'lightgbm', 'gbdt',
            'rf').
        best_model: estimator selected by ``fit_model`` or restored by
            ``load_model``; ``None`` until then.
        imputer: mean-strategy ``SimpleImputer`` (not used inside this
            class).
        feature_scaler / target_scaler: start as ``None``. They are set by
            ``load_model``; before ``_save_best_model`` can succeed they
            must hold ``StandardScaler`` instances assigned from outside.
        feature_importance: importances of the best model from the last
            ``fit_model`` call, or ``None``.
    """

    def __init__(self):
        self.models = {
            'xgboost': self._create_xgboost_model(),
            'lightgbm': self._create_lightgbm_model(),
            'gbdt': self._create_gbdt_model(),
            'rf': self._create_rf_model()
        }
        self.best_model = None
        self.imputer = SimpleImputer(strategy='mean')
        self.feature_scaler = None
        self.target_scaler = None
        self.feature_importance = None

    @staticmethod
    def _as_native(metrics):
        """Convert r2/mae/rmse to plain floats, keeping ``None`` as ``None``."""
        return {
            key: (None if metrics[key] is None else float(metrics[key]))
            for key in ('r2', 'mae', 'rmse')
        }

    def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
        """Train the requested models and report their metrics.

        Args:
            X_train, y_train: Training features and targets.
            model_names: Iterable of keys of ``self.models``; unknown names
                are logged and skipped.
            X_val, y_val: Optional validation set. Without it the validation
                R2 comes from 5-fold cross-validation on the training data
                and validation MAE/RMSE are ``None``.
            equipment_type: When given (and a best model was found) the best
                model is persisted via ``_save_best_model``.

        Returns:
            dict with ``metrics`` (per-model train/validation metrics),
            ``best_model`` (type and validation metrics of the model with
            the highest validation R2, or ``None``) and
            ``feature_importance`` (list of floats or ``None``).

        Raises:
            Exception: any training failure is logged and re-raised.
        """
        try:
            logging.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
            logging.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")

            results = {}
            best_score = -float('inf')
            best_model_info = None

            for model_name in model_names:
                if model_name not in self.models:
                    logging.warning(f"Unknown model: {model_name}")
                    continue

                logging.info(f"Training {model_name}...")
                model = self.models[model_name]
                model.fit(X_train, y_train)

                metrics = self._calculate_metrics(model, X_train, y_train, X_val, y_val)
                validation = self._as_native(metrics['validation'])

                # Keep the model with the highest validation R2.
                if metrics['validation']['r2'] > best_score:
                    best_score = metrics['validation']['r2']
                    self.best_model = model
                    best_model_info = {'type': model_name, **validation}

                # Plain Python floats so the result is JSON-serialisable.
                results[model_name] = {
                    'train': self._as_native(metrics['train']),
                    'validation': validation
                }

            # Compute the importances before saving so that they are
            # actually written to the trained_models record.
            feature_importance = None
            if self.best_model is not None and hasattr(self.best_model, 'feature_importances_'):
                feature_importance = [float(x) for x in self.best_model.feature_importances_]
            self.feature_importance = feature_importance

            if equipment_type and best_model_info:
                self._save_best_model(equipment_type, best_model_info, X_train)

            return {
                'metrics': results,
                'best_model': best_model_info,
                'feature_importance': feature_importance
            }

        except Exception as e:
            logging.error(f"Error in model training: {str(e)}")
            raise

    def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
        """Return train and validation metrics (r2, mae, rmse) for a fitted model.

        With no validation set, validation R2 is the mean 5-fold
        cross-validation score on the training data and MAE/RMSE are
        ``None``.
        """
        from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

        def evaluate(X, y):
            predicted = model.predict(X)
            return {
                'r2': r2_score(y, predicted),
                'mae': mean_absolute_error(y, predicted),
                'rmse': np.sqrt(mean_squared_error(y, predicted))
            }

        train_metrics = evaluate(X_train, y_train)

        if X_val is not None and y_val is not None:
            val_metrics = evaluate(X_val, y_val)
        else:
            cv_scores = cross_val_score(model, X_train, y_train, cv=5)
            val_metrics = {
                'r2': cv_scores.mean(),
                'mae': None,
                'rmse': None
            }

        return {
            'train': train_metrics,
            'validation': val_metrics
        }

    def _create_xgboost_model(self):
        """Create the XGBoost regressor with strengthened regularisation."""
        return xgb.XGBRegressor(
            n_estimators=50,        # few trees
            learning_rate=0.05,     # small learning rate
            max_depth=3,            # shallow trees
            min_child_weight=3,     # larger minimum child weight
            subsample=0.7,          # row subsampling ratio
            colsample_bytree=0.7,   # feature subsampling ratio
            reg_alpha=0.1,          # L1 regularisation
            reg_lambda=1,           # L2 regularisation
            random_state=42
        )

    def _create_lightgbm_model(self):
        """Create the LightGBM regressor with strengthened regularisation."""
        return lgb.LGBMRegressor(
            n_estimators=50,
            learning_rate=0.05,
            max_depth=3,
            num_leaves=7,
            min_data_in_leaf=3,
            min_sum_hessian_in_leaf=1e-3,
            subsample=0.7,
            colsample_bytree=0.7,
            reg_alpha=0.1,
            reg_lambda=1,
            random_state=42,
            verbose=-1
        )

    def _create_gbdt_model(self):
        """Create the gradient boosting regressor, regularised against overfitting."""
        return GradientBoostingRegressor(
            n_estimators=20,        # few trees
            learning_rate=0.01,     # small learning rate
            max_depth=2,            # shallow trees
            min_samples_split=4,    # more samples required to split
            min_samples_leaf=3,     # more samples required per leaf
            subsample=0.5,          # row subsampling ratio
            random_state=42,
            # NOTE(review): scikit-learn only uses validation_fraction when
            # n_iter_no_change is set, so as configured no data is held out.
            validation_fraction=0.2
        )

    def _create_rf_model(self):
        """Create the random forest regressor tuned for small samples."""
        return RandomForestRegressor(
            n_estimators=100,       # number of trees
            max_depth=4,            # limit tree depth
            min_samples_split=2,    # minimum samples to split
            min_samples_leaf=1,     # minimum samples per leaf
            max_features='sqrt',    # feature subsampling
            bootstrap=True,         # bootstrap sampling
            oob_score=True,         # compute the out-of-bag score
            random_state=42
        )

    def _save_best_model(self, equipment_type, best_model_info, X_train):
        """Persist the best model and its scalers and register them in the DB.

        Files go to ``models/<equipment_type>_<timestamp>`` (``.json`` for
        XGBoost, ``.joblib`` otherwise, plus ``_scaler.joblib``). Existing
        ``trained_models`` rows for the equipment type are deactivated and a
        new active row is inserted.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            # Validate first so that a missing scaler does not leave an
            # orphaned model file on disk.
            if not isinstance(self.feature_scaler, StandardScaler):
                raise ValueError("Invalid feature scaler")
            if not isinstance(self.target_scaler, StandardScaler):
                raise ValueError("Invalid target scaler")

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_dir = 'models'
            os.makedirs(model_dir, exist_ok=True)

            model_path = f'{model_dir}/{equipment_type}_{timestamp}'
            if isinstance(self.best_model, xgb.XGBRegressor):
                self.best_model.save_model(f'{model_path}.json')
                model_format = 'json'
            else:
                joblib.dump(self.best_model, f'{model_path}.joblib')
                model_format = 'joblib'

            scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
            joblib.dump({
                'feature_scaler': self.feature_scaler,
                'target_scaler': self.target_scaler
            }, scaler_path)

            logging.info(f"Saved model to {model_path}.{model_format}")
            logging.info(f"Saved scalers to {scaler_path}")

            importance = getattr(self, 'feature_importance', None)

            with get_db_connection() as conn:
                cursor = conn.cursor()

                # Deactivate the previously active model(s).
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s
                """, (equipment_type,))

                # Register the new model as the active one.
                cursor.execute("""
                    INSERT INTO trained_models (
                        model_name, model_type, equipment_type, model_path,
                        scaler_path, r2_score, mae, rmse, feature_importance,
                        training_data_size, training_date, is_active, created_by
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, 'system')
                """, (
                    f'{best_model_info["type"]}_{timestamp}',
                    best_model_info["type"],
                    equipment_type,
                    f'{model_path}.{model_format}',
                    scaler_path,
                    best_model_info["r2"],
                    best_model_info["mae"],
                    best_model_info["rmse"],
                    json.dumps(importance) if importance is not None else None,
                    len(X_train)
                ))

                conn.commit()

            logging.info(f"Best model saved: {model_path}")
            return True

        except Exception as e:
            logging.error(f"Error saving best model: {str(e)}")
            return False

    def load_model(self, equipment_type):
        """Load the latest active model and its scalers for an equipment type.

        Returns:
            True when model and scalers were loaded, False otherwise (the
            error is logged with a traceback).
        """
        try:
            logging.info(f"Loading model for {equipment_type}")

            # Look up the most recent active model record.
            with get_db_connection() as conn:
                cursor = conn.cursor(dictionary=True)
                cursor.execute("""
                    SELECT * FROM trained_models
                    WHERE equipment_type = %s AND is_active = TRUE
                    ORDER BY training_date DESC LIMIT 1
                """, (equipment_type,))

                model_record = cursor.fetchone()
                if not model_record:
                    raise ValueError(f"No active model found for {equipment_type}")

                logging.info(f"Found model: {model_record['model_name']}")
                logging.info(f"Model path: {model_record['model_path']}")
                logging.info(f"Scaler path: {model_record['scaler_path']}")

                if not os.path.exists(model_record['model_path']):
                    raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
                if not os.path.exists(model_record['scaler_path']):
                    raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")

                # NOTE(security): joblib.load unpickles the file; the paths
                # come from the database, so only trusted files must be
                # registered there.
                if model_record['model_type'] == 'xgboost':
                    self.best_model = xgb.XGBRegressor()
                    self.best_model.load_model(model_record['model_path'])
                else:
                    self.best_model = joblib.load(model_record['model_path'])

                try:
                    scalers = joblib.load(model_record['scaler_path'])
                    logging.info(f"Loaded scalers: {scalers.keys()}")

                    if 'feature_scaler' not in scalers or 'target_scaler' not in scalers:
                        raise ValueError("Missing scalers in saved file")

                    self.feature_scaler = scalers['feature_scaler']
                    self.target_scaler = scalers['target_scaler']

                    if not hasattr(self.feature_scaler, 'transform') or not hasattr(self.target_scaler, 'transform'):
                        raise ValueError("Invalid scaler objects")

                    logging.info("Model and scalers loaded successfully")
                    logging.info(f"Feature scaler type: {type(self.feature_scaler)}")
                    logging.info(f"Target scaler type: {type(self.target_scaler)}")

                except Exception as e:
                    logging.error(f"Error loading scalers: {str(e)}")
                    logging.error(f"Scaler file content: {scalers if 'scalers' in locals() else 'Not loaded'}")
                    raise ValueError(f"Failed to load scalers: {str(e)}")

                return True

        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            logging.error("Detailed traceback:", exc_info=True)
            return False

    def predict(self, features):
        """Predict with the loaded model and return values in original units.

        Args:
            features: Array-like of feature rows (2-D), or a single row
                (1-D). NaN cells are replaced by 0 before scaling.

        Returns:
            1-D numpy array of predictions after inverse target scaling.

        Raises:
            ValueError: when the model or one of the scalers is missing.
            Exception: any other failure is logged and re-raised.
        """
        try:
            # Compare with None: some estimators define __len__, so their
            # truth value is not a reliable "is loaded" check.
            if self.best_model is None:
                raise ValueError("No model loaded")
            if self.feature_scaler is None:
                raise ValueError("Feature scaler not loaded")
            if self.target_scaler is None:
                raise ValueError("Target scaler not loaded")

            logging.info("Starting prediction")

            # Convert first so lists are accepted as well as arrays.
            features_filled = np.array(features, dtype=float)
            if features_filled.ndim == 1:
                features_filled = features_filled.reshape(1, -1)

            logging.info(f"Input features shape: {features_filled.shape}")
            logging.info(f"Input features: \n{features}")

            # nan= must be passed by keyword; the second positional
            # argument of nan_to_num is `copy`.
            features_filled = np.nan_to_num(features_filled, nan=0.0)
            logging.info(f"Filled features: \n{features_filled}")

            X = self.feature_scaler.transform(features_filled)
            logging.info(f"Transformed features shape: {X.shape}")
            logging.info(f"Transformed features: \n{X}")

            y_pred_scaled = self.best_model.predict(X)
            logging.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
            logging.info(f"Scaled prediction: {y_pred_scaled}")

            # Undo the target scaling.
            y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
            logging.info(f"Final prediction shape: {y_pred.shape}")
            logging.info(f"Final prediction: {y_pred}")

            logging.info("Target scaler params:")
            logging.info(f"Mean: {self.target_scaler.mean_}")
            logging.info(f"Scale: {self.target_scaler.scale_}")

            return y_pred.ravel()

        except Exception as e:
            logging.error(f"Error in prediction: {str(e)}")
            raise
|
||||
313
src/pls_regression.py
Normal file
313
src/pls_regression.py
Normal file
@ -0,0 +1,313 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
from sklearn.metrics import r2_score, mean_absolute_error
|
||||
from sklearn.model_selection import LeaveOneOut
|
||||
import os
|
||||
from datetime import datetime
|
||||
import joblib
|
||||
from src.database.db_connection import get_db_connection
|
||||
|
||||
class PLSPredictor:
    """
    Cost predictor based on partial least squares (PLS) regression.

    Wraps a PLSRegression model together with the feature/target
    StandardScalers, and persists them as a single joblib bundle whose
    location is tracked in the `trained_models` database table.
    """

    def __init__(self, n_components=2):
        """
        Create the predictor and try to restore the active saved model.

        Args:
            n_components: Number of PLS components for a fresh model.
        """
        self.model = PLSRegression(
            n_components=n_components,
            scale=True,
            max_iter=500,
            tol=1e-6
        )
        self.scaler_X = StandardScaler()
        self.scaler_y = StandardScaler()
        self.feature_names = None
        self.model_path = None
        # R² of the most recent fit (or of the loaded model record);
        # None until one of them is available.
        self.r2_score_ = None

        # Fit the scalers on example data only when no trained model could
        # be restored; otherwise the scalers loaded from disk would be
        # overwritten and predictions would use the wrong scaling.
        if not self.load_model():
            self._initialize_scalers()

    def _initialize_scalers(self):
        """
        Fit both scalers on a small built-in example data set.
        """
        # Example feature rows
        example_data = pd.DataFrame({
            'length_m': [0.56, 0.58, 0.54],
            'width_m': [0.15, 0.16, 0.14],
            'height_m': [0.20, 0.21, 0.19],
            'weight_kg': [2.72, 2.85, 2.60],
            'max_range_km': [24, 26, 22],
            'max_speed_kmh': [160.93, 170, 155],
            'cruise_speed_kmh': [96.56, 100, 93],
            'flight_time_min': [15, 16, 14],
            'folded_length_mm': [560, 580, 540],
            'folded_width_mm': [150, 160, 140],
            'folded_height_mm': [200, 210, 190]
        })

        # Feature scaler
        self.scaler_X.fit(example_data)

        # Target scaler
        example_costs = np.array([[1000000], [1100000], [900000]])
        self.scaler_y.fit(example_costs)

    def predict(self, features):
        """
        Predict the cost for one set of features.

        Args:
            features: A pandas DataFrame, or a mapping of feature name to
                value describing a single sample.

        Returns:
            dict with 'predicted_cost' (absolute value of the first
            prediction) and 'confidence_interval' ({'lower', 'upper'}).

        Raises:
            Exception: Wrapping any error raised during prediction.
        """
        try:
            # Convert the input to a DataFrame
            if not isinstance(features, pd.DataFrame):
                features = pd.DataFrame([features])

            # Prefer the column order recorded at training time so that
            # extra or re-ordered request fields cannot shift the columns;
            # fall back to all numeric columns when it is not available.
            if self.feature_names and all(
                name in features.columns for name in self.feature_names
            ):
                X = features[self.feature_names]
            else:
                numeric_features = features.select_dtypes(include=[np.number]).columns
                X = features[numeric_features]

            # Standardise the features
            X_scaled = self.scaler_X.transform(X)

            # The model is trained on a 1-D target, so predict() may return
            # a 1-D array; the scaler needs a column vector.
            y_pred_scaled = np.asarray(self.model.predict(X_scaled)).reshape(-1, 1)
            y_pred = self.scaler_y.inverse_transform(y_pred_scaled)

            # Interval around the prediction
            ci = self._calculate_confidence_intervals(y_pred)
            # Taking absolute values can flip the bounds for a negative
            # prediction, so re-order them.
            lower, upper = sorted((abs(ci['lower']), abs(ci['upper'])))

            return {
                'predicted_cost': float(abs(y_pred[0][0])),
                'confidence_interval': {
                    'lower': float(lower),
                    'upper': float(upper)
                }
            }

        except Exception as e:
            logging.error(f"Error in PLS prediction: {str(e)}")
            raise Exception(f"PLS prediction error: {str(e)}") from e

    def fit(self, X, y, equipment_type='missile'):
        """
        Train the PLS model and persist it.

        Args:
            X: Feature matrix (DataFrame or array-like, samples x features).
            y: Target values, one per sample.
            equipment_type: Equipment type the model is stored under.

        Returns:
            dict with 'r2_score', 'n_components' and 'feature_importance'.

        Raises:
            Exception: Wrapping any error raised during training or saving.
        """
        try:
            logging.info("=== PLS Training Start ===")

            # 1. Log the input data
            logging.info(f"Input X type: {type(X)}, shape: {X.shape if hasattr(X, 'shape') else 'no shape'}")
            logging.info(f"Input y type: {type(y)}, shape: {y.shape if hasattr(y, 'shape') else 'no shape'}")
            logging.info(f"X data:\n{X}")
            logging.info(f"y data:\n{y}")

            # 2. Convert to numpy arrays
            if isinstance(X, pd.DataFrame):
                # Remember the feature names (and their order)
                self.feature_names = X.columns.tolist()
                X = X.values
            X = np.array(X, dtype=float)
            y = np.array(y, dtype=float)

            # 3. Standardise the data
            logging.info("Standardizing data...")
            X_scaled = self.scaler_X.fit_transform(X)
            y_scaled = self.scaler_y.fit_transform(y.reshape(-1, 1))
            logging.info(f"X_scaled shape: {X_scaled.shape}")
            logging.info(f"y_scaled shape: {y_scaled.shape}")

            # 4. Train the model
            logging.info("Training PLS model...")
            self.model.fit(X_scaled, y_scaled.ravel())
            logging.info("PLS model training completed")

            # 5. Compute the training R² in the original target units
            logging.info("Calculating R² score...")
            y_pred = np.asarray(self.model.predict(X_scaled))
            y_pred = self.scaler_y.inverse_transform(y_pred.reshape(-1, 1))
            r2 = r2_score(y.reshape(-1, 1), y_pred)
            logging.info(f"R² score: {r2}")

            # save_model() stores this value in the model record.
            self.r2_score_ = float(r2)

            result = {
                'r2_score': float(r2),
                'n_components': int(self.model.n_components),
                'feature_importance': None
            }
            logging.info(f"Final result: {result}")
            logging.info("=== PLS Training End ===")

            # Persist the trained model
            self.save_model(equipment_type)

            return result

        except Exception as e:
            logging.error(f"Error in PLS training: {str(e)}")
            logging.error(f"Error traceback:", exc_info=True)
            raise Exception(f"PLS training error: {str(e)}") from e

    def _calculate_confidence_intervals(self, predictions, confidence=0.95):
        """
        Estimate an interval around the predictions.

        The interval is the central `confidence` quantile range of the
        predictions perturbed by Gaussian noise whose standard deviation is
        5% of the absolute mean prediction. It is a heuristic, not a
        statistical bootstrap over the training data.

        Args:
            predictions: Array of predicted values.
            confidence: Coverage of the interval, between 0 and 1.

        Returns:
            dict with float 'lower' and 'upper' bounds (lower <= upper).
        """
        try:
            values = np.asarray(predictions, dtype=float)
            n_predictions = 1000

            # A standard deviation must be non-negative even when the mean
            # prediction is negative.
            noise_scale = abs(values.mean()) * 0.05
            noise = np.random.normal(0, noise_scale, (n_predictions,) + values.shape)
            simulated = (values + noise).ravel()

            # Central quantile range
            tail = (1 - confidence) / 2
            lower = np.percentile(simulated, tail * 100)
            upper = np.percentile(simulated, (1 - tail) * 100)

            return {
                'lower': float(lower),
                'upper': float(upper)
            }

        except Exception as e:
            logging.error(f"Error calculating confidence intervals: {str(e)}")
            # Fall back to +/-10% of the mean prediction, ordered so that
            # lower <= upper also holds for a negative mean.
            mean_pred = float(np.mean(predictions))
            lower, upper = sorted((mean_pred * 0.9, mean_pred * 1.1))
            return {
                'lower': lower,
                'upper': upper
            }

    def _get_feature_importance(self):
        """
        Compute VIP (variable importance in projection) scores.

        Returns:
            dict mapping feature name to VIP score, sorted by descending
            score; an empty dict when the model is not fitted or the
            computation fails.
        """
        try:
            if not hasattr(self.model, 'x_weights_'):
                return {}

            t = self.model.x_scores_    # (n_samples, n_components)
            w = self.model.x_weights_   # (n_features, n_components)
            q = self.model.y_loadings_  # (n_targets, n_components)

            n_features, _ = w.shape

            # Variance of y explained by each component
            s = np.diag(t.T @ t @ q.T @ q)
            total_s = np.sum(s)

            # Squared weights, normalised per component (column)
            w_norm_sq = (w / np.linalg.norm(w, axis=0)) ** 2
            vips = np.sqrt(n_features * (w_norm_sq @ s) / total_s)

            # Use the recorded names only when they match the model
            names = self.feature_names
            if names is None or len(names) != n_features:
                names = [f"feature_{i}" for i in range(n_features)]

            feature_importance = {
                name: float(score) for name, score in zip(names, vips)
            }

            # Sort by importance
            return dict(sorted(feature_importance.items(), key=lambda x: x[1], reverse=True))

        except Exception as e:
            logging.error(f"Error calculating feature importance: {str(e)}")
            return {}

    def save_model(self, equipment_type):
        """
        Save the model bundle to disk and register it as the active model.

        Args:
            equipment_type: Equipment type the model record is stored under.

        Raises:
            Exception: Wrapping any file or database error.
        """
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_dir = 'models'
            os.makedirs(model_dir, exist_ok=True)

            # Model and scalers are stored together in one joblib file
            model_path = f'{model_dir}/pls_{equipment_type}_{timestamp}.joblib'
            joblib.dump({
                'model': self.model,
                'scaler_X': self.scaler_X,
                'scaler_y': self.scaler_y,
                'feature_names': self.feature_names
            }, model_path)

            # The score is unknown when save_model() is called without a
            # preceding fit(); store NULL in that case.
            r2 = getattr(self, 'r2_score_', None)
            r2_value = None if r2 is None else float(r2)

            # Update the model records in the database
            with get_db_connection() as conn:
                cursor = conn.cursor()
                try:
                    # Deactivate the previously active model
                    cursor.execute("""
                        UPDATE trained_models
                        SET is_active = FALSE
                        WHERE equipment_type = %s AND model_type = 'pls'
                    """, (equipment_type,))

                    # Insert the new record. scaler_path is NOT NULL in the
                    # trained_models table; the scalers live in the same
                    # bundle as the model, so both paths are identical.
                    cursor.execute("""
                        INSERT INTO trained_models (
                            model_name, model_type, equipment_type, model_path,
                            scaler_path, r2_score, training_date, is_active, created_by
                        ) VALUES (%s, %s, %s, %s, %s, %s, NOW(), TRUE, 'system')
                    """, (
                        f'PLS_{timestamp}',
                        'pls',
                        equipment_type,
                        model_path,
                        model_path,
                        r2_value
                    ))

                    conn.commit()
                finally:
                    cursor.close()

            self.model_path = model_path
            logging.info(f"Model saved to {self.model_path}")

        except Exception as e:
            logging.error(f"Error saving model: {str(e)}")
            raise Exception(f"Failed to save model: {str(e)}") from e

    def load_model(self):
        """
        Load the most recently trained active PLS model, if any.

        Returns:
            True when a model bundle was loaded, False when there is no
            active record, the file is missing, or loading failed.
        """
        try:
            with get_db_connection() as conn:
                cursor = conn.cursor(dictionary=True)
                try:
                    # Latest active model record
                    cursor.execute("""
                        SELECT * FROM trained_models
                        WHERE model_type = 'pls' AND is_active = TRUE
                        ORDER BY training_date DESC LIMIT 1
                    """)
                    model_record = cursor.fetchone()
                finally:
                    cursor.close()

            if model_record and os.path.exists(model_record['model_path']):
                # NOTE: joblib.load unpickles the file; the path must only
                # ever point at bundles written by save_model().
                saved_data = joblib.load(model_record['model_path'])
                self.model = saved_data['model']
                self.scaler_X = saved_data['scaler_X']
                self.scaler_y = saved_data['scaler_y']
                self.feature_names = saved_data['feature_names']
                self.model_path = model_record['model_path']
                self.r2_score_ = model_record.get('r2_score')

                logging.info(f"Loaded model from {self.model_path}")
                return True

            return False

        except Exception as e:
            # Best effort: the predictor must still be constructible when
            # the database or the model file is unavailable.
            logging.error(f"Error loading model: {str(e)}")
            return False
|
||||
9
src/real_data.sql
Normal file
9
src/real_data.sql
Normal file
@ -0,0 +1,9 @@
|
||||
-- Rocket artillery data (13 systems)
-- NOTE(review): `target_type` is not among the columns of the `equipment`
-- table created in src/schema.sql; confirm the column exists before
-- running this script.
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('BM-21', '火箭炮', '俄罗斯', '面目标');
-- TODO: add the other 12 rocket artillery rows to the VALUES list above.

-- Loitering munition data (18 systems)
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
    ('Hero-120', '巡飞弹', '以色列', '装甲目标');
-- TODO: add the other 17 loitering munition rows to the VALUES list above.
|
||||
1298
src/routes.py
Normal file
1298
src/routes.py
Normal file
File diff suppressed because it is too large
Load Diff
28
src/run.py
Normal file
28
src/run.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
import logging
|
||||
from src.app import app
|
||||
|
||||
# Make sure the directories the server writes to exist
for _required_dir in ('logs', 'models', 'data'):
    os.makedirs(_required_dir, exist_ok=True)

# Configure logging
logging.basicConfig(
    filename='logs/server.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

if __name__ == "__main__":
    # Debug mode (and with it the auto-reloader) stays on by default, as
    # before, but can be switched off with FLASK_DEBUG=0: the interactive
    # debugger must never be reachable outside a development machine.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in (
        '0', 'false', 'no', 'off'
    )

    try:
        logging.info("Starting server...")
        app.run(
            host='localhost',
            port=5001,
            debug=debug_enabled,
            use_reloader=debug_enabled
        )
    except Exception as e:
        logging.error(f"Server failed to start: {str(e)}", exc_info=True)
        raise
|
||||
129
src/schema.sql
Normal file
129
src/schema.sql
Normal file
@ -0,0 +1,129 @@
|
||||
-- Equipment master table
CREATE TABLE equipment (
    id INT AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(100),               -- name
    type VARCHAR(50),                -- type (rocket artillery / loitering munition)
    manufacturer VARCHAR(100),       -- manufacturer
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Parameters common to all equipment
CREATE TABLE common_params (
    id INT AUTO_INCREMENT PRIMARY KEY,
    equipment_id INT,
    length_m FLOAT,                  -- overall length (m)
    width_m FLOAT,                   -- width (m)
    height_m FLOAT,                  -- height (m)
    weight_kg FLOAT,                 -- weight (kg)
    max_range_km FLOAT,              -- maximum range (km)
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Rocket-artillery-specific parameters
CREATE TABLE rocket_artillery_params (
    id INT AUTO_INCREMENT PRIMARY KEY,
    equipment_id INT,
    firing_angle_horizontal FLOAT,   -- traverse (degrees)
    firing_angle_vertical FLOAT,     -- elevation (degrees)
    rocket_length_m FLOAT,           -- rocket length (m)
    rocket_diameter_mm FLOAT,        -- rocket body diameter (mm)
    rocket_weight_kg FLOAT,          -- rocket weight (kg)
    rate_of_fire FLOAT,              -- rate of fire (rounds/min)
    combat_weight_kg FLOAT,          -- combat weight (kg)
    speed_kmh FLOAT,                 -- speed (km/h)
    min_range_km FLOAT,              -- minimum range (km)
    mobility_type VARCHAR(50),       -- mobility type
    structure_layout VARCHAR(100),   -- structural layout
    engine_model VARCHAR(100),       -- engine model
    engine_params TEXT,              -- engine parameters
    power_hp FLOAT,                  -- power (hp)
    travel_range_km FLOAT,           -- travel range (km)
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Loitering-munition-specific parameters
CREATE TABLE loitering_munition_params (
    id INT AUTO_INCREMENT PRIMARY KEY,
    equipment_id INT,
    wingspan_m FLOAT,                -- wingspan (m)
    warhead_weight_kg FLOAT,         -- warhead weight (kg)
    max_speed_ms FLOAT,              -- maximum speed (m/s)
    cruise_speed_kmh FLOAT,          -- cruise speed (km/h)
    flight_time_min FLOAT,           -- loiter time (min)
    warhead_type VARCHAR(50),        -- warhead type
    launch_mode VARCHAR(50),         -- launch mode
    folded_length_mm FLOAT,          -- folded length (mm)
    folded_width_mm FLOAT,           -- folded width (mm)
    folded_height_mm FLOAT,          -- folded height (mm)
    power_system VARCHAR(100),       -- power plant
    guidance_system VARCHAR(100),    -- guidance system
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Cost data
CREATE TABLE cost_data (
    id INT AUTO_INCREMENT PRIMARY KEY,
    equipment_id INT,
    actual_cost DECIMAL(15,2),       -- actual cost (yuan)
    predicted_cost DECIMAL(15,2),    -- predicted cost (yuan)
    prediction_date TIMESTAMP,       -- prediction date
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Custom (free-form) parameters
CREATE TABLE custom_params (
    id INT AUTO_INCREMENT PRIMARY KEY,
    equipment_id INT,
    param_name VARCHAR(100),         -- parameter name
    param_value VARCHAR(255),        -- parameter value
    param_unit VARCHAR(50),          -- parameter unit
    description TEXT,                -- parameter description
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Indexes
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);

-- Data sets
-- Engine and charset are set explicitly, like the other tables: the text
-- columns hold Chinese values and the foreign keys below need InnoDB.
CREATE TABLE datasets (
    id INT AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(100) NOT NULL,            -- data set name
    description TEXT,                      -- data set description
    equipment_type VARCHAR(50) NOT NULL,   -- equipment type
    purpose VARCHAR(50) NOT NULL,          -- purpose (training / validation)
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Data set <-> equipment association
CREATE TABLE dataset_equipment (
    dataset_id INT NOT NULL,
    equipment_id INT NOT NULL,
    PRIMARY KEY (dataset_id, equipment_id),
    FOREIGN KEY (dataset_id) REFERENCES datasets(id),
    FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Trained models
CREATE TABLE trained_models (
    id INT AUTO_INCREMENT PRIMARY KEY,
    model_name VARCHAR(100) NOT NULL,      -- model name
    model_type VARCHAR(50) NOT NULL,       -- model type
    equipment_type VARCHAR(50) NOT NULL,   -- equipment type
    model_path VARCHAR(255) NOT NULL,      -- model file path
    scaler_path VARCHAR(255) NOT NULL,     -- scaler file path
    r2_score FLOAT,                        -- R² score
    mae FLOAT,                             -- mean absolute error
    rmse FLOAT,                            -- root mean squared error
    feature_importance JSON,               -- feature importance
    training_data_size INT,                -- number of training samples
    training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,  -- training time
    is_active BOOLEAN DEFAULT FALSE,       -- whether this is the active model
    created_by VARCHAR(50)                 -- creator
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Indexes
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);
|
||||
192
src/test_api.py
Normal file
192
src/test_api.py
Normal file
@ -0,0 +1,192 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
def _print_response(response):
    """Pretty-print a JSON response, or the status and raw body if it is not JSON."""
    try:
        payload = response.json()
    except ValueError:
        # Error pages (e.g. an HTML 500) are not JSON; show them instead
        # of crashing the whole run.
        print(f"HTTP {response.status_code}: {response.text}")
        return
    print(json.dumps(payload, indent=2, ensure_ascii=False))


def test_api_endpoints(base_url='http://localhost:5001/api'):
    """
    Exercise each API endpoint once and print the responses.

    This is a manual smoke test: it needs a running server and only prints
    what the server returns, it does not assert on the results.

    Args:
        base_url: Root URL of the API under test.
    """
    # requests has no default timeout; without one a hung server blocks
    # the script forever.
    timeout = 60

    fields = (
        "length_m", "width_m", "height_m", "weight_kg", "max_range_km",
        "max_speed_kmh", "cruise_speed_kmh", "flight_time_min",
        "folded_length_mm", "folded_width_mm", "folded_height_mm",
    )

    def sample(*values):
        """Map positional values onto the numeric feature names."""
        return dict(zip(fields, values))

    base_sample = sample(0.56, 0.15, 0.20, 2.72, 24, 160.93, 96.56, 15, 560, 150, 200)

    # 1. API root
    print("\n1. 测试 API 根路由")
    _print_response(requests.get(f'{base_url}/', timeout=timeout))

    # 2. Machine-learning prediction endpoint
    print("\n2. 测试机器学习预测接口")
    predict_data = {
        "type": "巡飞弹",
        **base_sample,
        "warhead_type": "破片杀伤战斗部",
        "launch_mode": "凭自身动力起飞"
    }
    _print_response(requests.post(f'{base_url}/predict', json=predict_data, timeout=timeout))

    # 3. PLS prediction endpoint
    print("\n3. 测试 PLS 预测接口")
    _print_response(requests.post(f'{base_url}/pls/predict', json=predict_data, timeout=timeout))

    # 4. Feature analysis endpoint
    print("\n4. 测试特征分析接口")
    analysis_data = {
        "data": [{"type": "巡飞弹", **base_sample}],
        "cost": [1000000]
    }
    _print_response(requests.post(f'{base_url}/analyze-features', json=analysis_data, timeout=timeout))

    # 5. Machine-learning training endpoint
    print("\n5. 测试机器学习模型训练接口")
    ml_rows = (
        (base_sample, 1000000),
        (sample(0.58, 0.16, 0.21, 2.85, 26, 170, 100, 16, 580, 160, 210), 1100000),
        (sample(0.54, 0.14, 0.19, 2.60, 22, 155, 93, 14, 540, 140, 190), 900000),
    )
    training_data = {
        "training_data": [
            {"type": "巡飞弹", **row, "cost": cost} for row, cost in ml_rows
        ],
        "equipment_type": "巡飞弹"
    }
    _print_response(requests.post(f'{base_url}/train', json=training_data, timeout=timeout))

    # 6. PLS training endpoint
    print("\n6. 测试 PLS 模型训练接口")
    training_data = {
        "training_data": [
            sample(1.3, 0.23, 0.23, 12.5, 40, 180, 100, 60, 1300, 230, 230),     # Harpy
            sample(2.5, 0.43, 0.43, 135, 1000, 185, 110, 360, 2500, 430, 430),   # HAROP
            sample(1.1, 0.15, 0.15, 5.7, 15, 150, 90, 30, 1100, 150, 150),       # Warmate
        ],
        "actual_costs": [150000, 850000, 80000]  # matching actual costs
    }
    _print_response(requests.post(f'{base_url}/pls/train', json=training_data, timeout=timeout))
|
||||
|
||||
if __name__ == "__main__":
    try:
        test_api_endpoints()
    except Exception as e:
        print(f"测试过程中出现错误: {str(e)}")
        # Report the failure through the exit status as well, so shells
        # and CI jobs do not treat a failed run as a success.
        raise SystemExit(1)
|
||||
18
vite.config.js
Normal file
18
vite.config.js
Normal file
@ -0,0 +1,18 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import path from 'path'
|
||||
|
||||
// Vite configuration; reference: https://vitejs.dev/config/

// Directory that the '@' import alias points at.
const srcDir = path.resolve(__dirname, 'src')

// Global constants substituted at build time (Vue feature flags).
const featureFlags = {
  __VUE_OPTIONS_API__: true,
  __VUE_PROD_DEVTOOLS__: false,
  __VUE_PROD_HYDRATION_MISMATCH_DETAILS__: false,
}

export default defineConfig({
  plugins: [vue()],
  resolve: {
    alias: { '@': srcDir },
  },
  define: featureFlags,
})
|
||||
Loading…
Reference in New Issue
Block a user