增加了部署文档,调整了前端代码

This commit is contained in:
Tian jianyong 2024-11-11 18:04:22 +08:00
parent fccd4c4366
commit 30d4b58cdf
68 changed files with 9424 additions and 638 deletions

25
.env.example Normal file
View File

@ -0,0 +1,25 @@
# 数据库配置
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=your_password_here
MYSQL_DATABASE=equipment_cost_db
# 服务配置
PORT=5001
DEBUG=False
# 日志配置
LOG_LEVEL=INFO
LOG_DIR=logs
# 模型配置
MODEL_DIR=models
DATA_DIR=data
# 安全配置
SECRET_KEY=your_secret_key_here
ALLOWED_HOSTS=localhost,127.0.0.1
# 其他配置
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
ALLOWED_FILE_TYPES=.xlsx,.xls

View File

@ -8,8 +8,8 @@ DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
SECRET_KEY = secrets.token_hex(16)
# 环境配置
DEBUG = True
ENV = 'development'
DEBUG = False
ENV = 'production'
# 文件上传配置
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')

Binary file not shown.

View File

@ -0,0 +1,25 @@
# 数据库配置
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=your_password_here
MYSQL_DATABASE=equipment_cost_db
# 服务配置
PORT=5001
DEBUG=False
# 日志配置
LOG_LEVEL=INFO
LOG_DIR=logs
# 模型配置
MODEL_DIR=models
DATA_DIR=data
# 安全配置
SECRET_KEY=your_secret_key_here
ALLOWED_HOSTS=localhost,127.0.0.1
# 其他配置
UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
ALLOWED_FILE_TYPES=.xlsx,.xls

View File

@ -0,0 +1,663 @@
# 装备成本估算系统 API 文档
这个 API 文档提供了完整的接口说明,包括:
- 每个端点的详细描述
- 请求和响应的具体示例
- 清晰的参数格式要求
- 统一的错误处理说明
- 重要的注意事项
文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
## 基本信息
- 基础URL: `http://localhost:5001/api`
- 版本: 1.0.0
- 响应格式: JSON
## API 端点列表
### 1. 获取 API 信息
获取 API 版本信息和可用端点列表。
- **URL**: `/`
- **方法**: `GET`
- **响应示例**:
json
{
"name": "装备成本估算系统 API",
"version": "1.0.0",
"endpoints": {
"predict": {
"url": "/api/predict",
"method": "POST",
"description": "成本预测"
},
"analyze-features": {
"url": "/api/analyze-features",
"method": "POST",
"description": "特征分析"
},
"train": {
"url": "/api/train",
"method": "POST",
"description": "模型训练"
},
"evaluate": {
"url": "/api/evaluate",
"method": "POST",
"description": "模型评估"
}
}
}
### 2. 单模型预测
使用当前激活的最优模型进行成本预测。
- **URL**: `/predict`
- **方法**: `POST`
- **请求体示例** (巡飞弹):
```json
{
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
```
- **响应示例**:
```json
{
"predicted_cost": 150000.0,
"model_info": {
"type": "xgboost",
"name": "巡飞弹_20241111_model",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
}
```
### 3. PLS 模型预测
使用 PLS 回归模型进行预测。
- **URL**: `/pls/predict`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"predicted_cost": 148000.0,
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
```
### 4. 多模型预测
使用所有激活的模型进行预测并返回综合结果。
- **URL**: `/predict/all`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"individual_predictions": {
"xgboost": {
"predicted_cost": 150000.0,
"model_info": {
"name": "巡飞弹_xgboost_model",
"type": "xgboost",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
},
"pls": {
"predicted_cost": 148000.0,
"model_info": {
"name": "巡飞弹_pls_model",
"type": "pls",
"r2_score": 0.92,
"mae": 5500.0,
"rmse": 8000.0
},
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
},
"ensemble_prediction": {
"predicted_cost": 149000.0,
"standard_deviation": 1414.21,
"confidence_interval": {
"lower": 146228.15,
"upper": 151771.85
}
}
}
```
### 5. 特征分析
分析数据集中特征的重要性和相关性。
- **URL**: `/analyze-features`
- **方法**: `POST`
- **请求体示例**:
```json
{
"dataset_id": 1,
"equipment_type": "巡飞弹"
}
```
- **响应示例**:
```json
{
"important_features": [
{
"name": "最大射程(km)",
"importance": 0.35
},
{
"name": "重量(kg)",
"importance": 0.25
}
],
"correlation_analysis": {
"features": ["最大射程(km)", "重量(kg)"],
"matrix": [[1.0, 0.8], [0.8, 1.0]]
}
}
```
### 6. 模型训练
训练新的模型。
- **URL**: `/train`
- **方法**: `POST`
- **请求体示例**:
```json
{
"type": "巡飞弹",
"train_dataset_id": 1,
"validation_dataset_id": 2,
"models": ["xgboost", "lightgbm", "rf"]
}
```
- **响应示例**:
```json
{
"metrics": {
"xgboost": {
"train": {
"r2": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"validation": {
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
},
"best_model": {
"type": "xgboost",
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
```
### 7. 数据集管理
#### 7.1 获取数据集列表
- **URL**: `/datasets`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"equipment_count": 10,
"equipment_names": ["设备1", "设备2"],
"purpose": "训练",
"created_at": "2024-11-11T10:00:00"
}
]
```
#### 7.2 获取数据集详情
- **URL**: `/datasets/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"created_at": "2024-11-11T10:00:00",
"equipment": [
{
"id": 1,
"name": "设备1",
"type": "巡飞弹",
"manufacturer": "制造商1",
"actual_cost": 150000
}
],
"statistics": {
"equipment_count": 10,
"total_cost": 1500000,
"average_cost": 150000
}
}
```
#### 7.3 创建数据集
- **URL**: `/datasets`
- **方法**: `POST`
- **请求体示例**:
```json
{
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": [1, 2, 3]
}
```
- **响应示例**:
```json
{
"id": 2,
"message": "数据集创建成功"
}
```
#### 7.4 更新数据集
- **URL**: `/datasets/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "更新后的数据集名称",
"description": "更新后的描述",
"equipment_type": "巡飞弹",
"purpose": "验证",
"equipment_ids": [1, 2, 3, 4]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "数据集更新成功"
}
```
#### 7.5 删除数据集
- **URL**: `/datasets/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定的数据集及其关联关系
- **响应示例**:
```json
{
"success": true,
"message": "数据集删除成功"
}
```
注意事项:
1. 数据集删除后不会删除关联的装备数据
2. 不能删除正在被模型使用的数据集
3. 更新数据集时会重新计算统计信息
4. 数据集的装备类型一旦创建后不能更改
### 8. 模型管理
#### 8.1 获取模型列表
- **URL**: `/models`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00"
}
]
```
#### 8.2 获取最新模型
- **URL**: `/models/{equipment_type}/latest`
- **方法**: `GET`
- **响应示例**: 与模型列表的单个模型格式相同
#### 8.3 获取模型详情
- **URL**: `/models/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00",
"feature_importance": {
"max_range_km": 0.35,
"weight_kg": 0.25,
"length_m": 0.20
},
"training_data_size": 100,
"created_by": "system"
}
```
#### 8.4 激活模型
- **URL**: `/models/{id}/activate`
- **方法**: `POST`
- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
- **响应示例**:
```json
{
"success": true,
"message": "模型已激活"
}
```
#### 8.5 删除模型
- **URL**: `/models/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定模型,包括模型文件和数据库记录
- **响应示例**:
```json
{
"success": true,
"message": "模型已删除"
}
```
注意事项:
1. 删除模型时会同时删除相关的文件和数据库记录
2. 不能删除当前正在使用(已激活)的模型
3. 激活模型时会自动取消同类型其他模型的激活状态
4. 模型详情包含了更多的训练相关信息,如特征重要性等
### 9. 数据管理
#### 9.1 获取装备数据列表
- **URL**: `/data`
- **方法**: `GET`
- **响应示例**:
```json
{
"rocket_artillery": [
{
"id": 1,
"name": "BM-21",
"type": "火箭炮",
"manufacturer": "俄罗斯",
"length_m": 7.35,
"width_m": 2.4,
"height_m": 3.1,
"weight_kg": 13700,
"max_range_km": 20.4,
"actual_cost": 800000
}
],
"loitering_munition": [
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"actual_cost": 150000
}
]
}
```
#### 9.2 获取装备详情
- **URL**: `/data/details/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"common_params": {
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40
},
"specific_params": {
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "箱式发射",
"power_system": "电动机",
"guidance_system": "GPS/INS"
},
"cost_data": {
"actual_cost": 150000,
"prediction_date": "2024-11-11T10:00:00",
"predicted_cost": 148000
},
"custom_params": [
{
"id": 1,
"param_name": "续航时间",
"param_value": "2小时",
"param_unit": "小时",
"description": "最大续航时间"
}
]
}
```
#### 9.3 更新装备数据
- **URL**: `/data/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "Hero-120",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"actual_cost": 150000,
"custom_params": [
{
"id": 1,
"param_value": "2.5小时"
}
]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "装备数据更新成功"
}
```
#### 9.4 删除装备数据
- **URL**: `/data/{id}`
- **方法**: `DELETE`
- **响应示例**:
```json
{
"success": true,
"message": "装备数据删除成功"
}
```
#### 9.5 下载数据模板
- **URL**: `/data/template`
- **方法**: `GET`
- **描述**: 下载Excel格式的数据导入模板
- **响应**: Excel文件下载
#### 9.6 导入数据
- **URL**: `/data/import`
- **方法**: `POST`
- **请求体**:
- Content-Type: multipart/form-data
- 参数名: file
- 文件类型: .xlsx 或 .xls
- **响应示例**:
```json
{
"success": true,
"message": "数据导入成功",
"imported_count": {
"rocket_artillery": 3,
"loitering_munition": 5
}
}
```
注意事项:
1. 导入数据时必须使用系统提供的模板
2. 更新装备数据时会同时更新关联的参数表
3. 删除装备数据会同时删除相关的参数和成本数据
4. 导入的Excel文件大小不应超过10MB
5. 所有数值字段必须符合指定的单位和范围要求
6. 特殊参数的值必须包含单位信息
## 错误响应
所有接口在发生错误时都会返回以下格式的响应:
```json
{
"error": "错误描述信息"
}
```
## 注意事项
1. 所有数值参数必须大于0
2. 所有单位必须按照参数名称中指定的单位提供
3. 预测结果中的成本单位为元
4. 置信区间表示预测结果的95%置信水平范围
5. 所有请求和响应的编码均为 UTF-8

View File

@ -0,0 +1,120 @@
# 装备成本估算系统部署指南
## 一、系统要求
### 1. 基础软件
- Linux 操作系统 (推荐 Ubuntu 20.04+)
- Python 3.8+ 及相关组件
```bash
sudo apt update
sudo apt install python3 python3-pip python3-venv
sudo apt install python3-dev build-essential
```
- Node.js 14+ 及 npm
```bash
# 使用 nvm 安装 Node.js
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
source ~/.bashrc
nvm install 14
nvm use 14
```
### 2. 数据库
- MySQL 8.0+
```bash
sudo apt install mysql-server mysql-client
sudo apt install libmysqlclient-dev
```
### 3. Python包依赖
```bash
# 科学计算相关
sudo apt install libatlas-base-dev # numpy依赖
sudo apt install libopenblas-dev # 线性代数库
sudo apt install liblapack-dev # 线性代数包
sudo apt install gfortran # Fortran编译器(scipy依赖)
# XML处理相关(用于Excel文件处理)
sudo apt install libxml2-dev
sudo apt install libxslt1-dev
```
## 二、部署运行
### 1. 安装服务
```bash
sh scripts/install.sh
```
### 2. 启动服务
```bash
sh scripts/start.sh
```
### 3. 停止服务
```bash
sh scripts/stop.sh
```
## 三、维护说明
### 1. 日志管理
```bash
# 后端日志
tail -f logs/api.log
# 数据库日志
tail -f /var/log/mysql/error.log
```
## 四、安全建议
1. 系统安全
- 使用防火墙限制端口访问
- 定期更新系统和依赖包
2. 数据安全
- 定期备份数据库
- 加密敏感信息
- 限制数据库远程访问
3. 访问控制
- 使用强密码
- 配置适当的文件权限
- 使用非root用户运行服务
## 五、监控方案
### 1. 系统监控
```bash
# 资源使用
top -b -n 1
df -h
free -m
# 服务状态
ps aux | grep gunicorn
ps aux | grep node
```
### 2. 应用监控
```bash
# API 响应时间
curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
# 错误日志
grep "ERROR" logs/api.log
```

View File

@ -0,0 +1,5 @@
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}

View File

@ -0,0 +1,61 @@
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"engines": {
"node": ">=16",
"npm": ">=8"
},
"scripts": {
"serve": "vue-cli-service serve",
"build": "vue-cli-service build",
"lint": "vue-cli-service lint"
},
"dependencies": {
"axios": "^1.6.0",
"core-js": "^3.8.3",
"echarts": "^5.4.3",
"element-plus": "^2.4.2",
"vue": "^3.2.13",
"vue-router": "^4.0.3",
"vuex": "^4.0.0"
},
"devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@element-plus/icons-vue": "^2.3.1",
"@vue/cli-plugin-babel": "~5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-plugin-router": "~5.0.0",
"@vue/cli-plugin-vuex": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"@vue/compiler-sfc": "^3.2.13",
"eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3",
"sass": "^1.32.7",
"sass-loader": "^12.0.0"
},
"eslintConfig": {
"root": true,
"env": {
"node": true
},
"extends": [
"plugin:vue/vue3-essential",
"eslint:recommended"
],
"parserOptions": {
"parser": "@babel/eslint-parser"
},
"rules": {
"vue/multi-word-component-names": "off",
"no-unused-vars": "warn"
}
},
"browserslist": [
"> 1%",
"last 2 versions",
"not dead",
"not ie 11"
]
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

View File

@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title><%= htmlWebpackPlugin.options.title %></title>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>

View File

@ -0,0 +1,42 @@
<template>
<el-container>
<el-header>
<el-menu
:router="true"
mode="horizontal"
:default-active="$route.path"
>
<el-menu-item index="/">首页</el-menu-item>
<el-menu-item index="/predict">成本预测</el-menu-item>
<el-menu-item index="/analysis">特征分析</el-menu-item>
<el-menu-item index="/training">模型训练</el-menu-item>
<el-menu-item index="/models">模型管理</el-menu-item>
<el-menu-item index="/datasets">数据集管理</el-menu-item>
<el-menu-item index="/data">数据管理</el-menu-item>
</el-menu>
</el-header>
<el-main>
<router-view v-slot="{ Component }">
<keep-alive>
<component :is="Component" :key="$route.fullPath" />
</keep-alive>
</router-view>
</el-main>
</el-container>
</template>
<style lang="scss" scoped>
.el-header {
padding: 0;
.el-menu {
border-bottom: none;
}
}
.el-main {
background-color: #f5f7fa;
min-height: calc(100vh - 60px);
}
</style>

View File

@ -0,0 +1,43 @@
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const api = axios.create({
baseURL: API_BASE_URL,
timeout: 10000
})
export const predict = (data) => {
return api.post('/predict', data)
}
export const analyzeFeatures = (data) => {
return api.post('/analyze-features', data)
}
export const trainModel = (data) => {
return api.post('/train', data)
}
export const evaluateModel = (data) => {
return api.post('/evaluate', data)
}
export const importData = (formData) => {
return api.post('/data/import', formData, {
headers: {
'Content-Type': 'multipart/form-data'
}
})
}
export const getEquipmentData = () => {
return api.get('/data')
}
export const updateEquipment = (id, data) => {
return api.put(`/data/${id}`, data)
}
export const deleteEquipment = (id) => {
return api.delete(`/data/${id}`)
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.7 KiB

View File

@ -0,0 +1,39 @@
/* 禁用 ResizeObserver 警告 */
iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] {
display: none !important;
}
.el-overlay {
overflow: hidden !important;
}
/* 添加全局样式 */
body {
margin: 0;
padding: 0;
overflow-x: hidden;
}
/* 修复 Element Plus 的一些已知问题 */
.el-dialog__wrapper {
overflow: hidden !important;
}
.el-select-dropdown {
overflow: hidden !important;
}
/* 禁用 ResizeObserver 相关的警告样式 */
.resize-observer-warning {
display: none !important;
}
/* 优化滚动行为 */
* {
scroll-behavior: smooth;
}
/* 防止页面抖动 */
.el-main {
overflow-x: hidden;
}

View File

@ -0,0 +1,8 @@
export const API_BASE_URL = 'http://localhost:5001/api';
export const DB_CONFIG = {
host: 'localhost',
user: 'root',
password: '123456',
database: 'equipment_cost_db'
};

View File

@ -0,0 +1,55 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import store from './store'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import './assets/styles/global.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
// 创建应用实例
const app = createApp(App)
// 注册插件
app.use(ElementPlus, {
size: 'default'
})
app.use(router)
app.use(store)
// 注册图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
app.component(key, component)
}
// 全局错误处理
app.config.errorHandler = (err) => {
if (err.message && err.message.includes('ResizeObserver')) {
return
}
console.error(err)
}
// 全局警告处理
app.config.warnHandler = (msg, trace) => {
if (msg.includes('ResizeObserver')) {
return
}
console.warn(msg, trace)
}
// 挂载应用
app.mount('#app')
// 处理 ResizeObserver 错误
const _ResizeObserver = window.ResizeObserver
window.ResizeObserver = class ResizeObserver extends _ResizeObserver {
constructor(callback) {
super((entries, observer) => {
requestAnimationFrame(() => {
if (!Array.isArray(entries)) return
callback(entries, observer)
})
})
}
}

View File

@ -0,0 +1,52 @@
import { createRouter, createWebHistory } from 'vue-router'
import HomePage from '@/views/HomePage.vue'
import DataPage from '@/views/DataPage.vue'
import DatasetPage from '@/views/DatasetPage.vue'
import PredictPage from '@/views/PredictPage.vue'
import AnalysisPage from '@/views/AnalysisPage.vue'
import TrainingPage from '@/views/TrainingPage.vue'
const routes = [
{
path: '/',
name: 'Home',
component: HomePage
},
{
path: '/data',
name: 'Data',
component: DataPage
},
{
path: '/datasets',
name: 'Datasets',
component: DatasetPage
},
{
path: '/predict',
name: 'Predict',
component: PredictPage
},
{
path: '/analysis',
name: 'Analysis',
component: AnalysisPage
},
{
path: '/training',
name: 'Training',
component: TrainingPage
},
{
path: '/models',
name: 'Models',
component: () => import('../views/ModelPage.vue')
}
]
const router = createRouter({
history: createWebHistory(),
routes
})
export default router

View File

@ -0,0 +1,14 @@
import { createStore } from 'vuex'
export default createStore({
state: {
},
getters: {
},
mutations: {
},
actions: {
},
modules: {
}
})

View File

@ -0,0 +1,73 @@
// 处理 ResizeObserver 错误
const resizeHandler = () => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
// 添加全局错误处理
const handler = (event) => {
if (event && event.message && resizeObserverErrors.includes(event.message)) {
event.stopPropagation()
event.preventDefault()
event.stopImmediatePropagation()
return false
}
}
// 添加多个事件监听器
window.addEventListener('error', handler, true)
window.addEventListener('unhandledrejection', handler, true)
// 添加 ResizeObserver 错误处理
if (window.ResizeObserver) {
const resizeObserverPrototype = ResizeObserver.prototype
const originalObserve = resizeObserverPrototype.observe
resizeObserverPrototype.observe = function (...args) {
try {
return originalObserve.apply(this, args)
} catch (e) {
if (resizeObserverErrors.includes(e.message)) {
return null
}
throw e
}
}
}
}
export default {
install: (app) => {
resizeHandler()
// 添加全局错误处理器
app.config.errorHandler = (err, vm, info) => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
if (err && err.message && resizeObserverErrors.includes(err.message)) {
return
}
console.error('Vue Error:', err, info)
}
// 添加全局警告处理器
app.config.warnHandler = (msg, vm, trace) => {
const resizeObserverErrors = [
"ResizeObserver loop completed with undelivered notifications.",
"ResizeObserver loop limit exceeded",
"ResizeObserver loop completed"
]
if (resizeObserverErrors.includes(msg)) {
return
}
console.warn('Vue Warning:', msg, trace)
}
}
}

View File

@ -0,0 +1,304 @@
<template>
<div class="analysis-page">
<el-card class="analysis-card">
<template #header>
<div class="header-content">
<h2>特征分析</h2>
</div>
</template>
<!-- 数据集选择 -->
<div class="dataset-section">
<el-form :model="analysisForm" label-width="120px">
<el-form-item label="装备类型" required>
<el-select v-model="analysisForm.equipment_type" @change="handleEquipmentTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<el-form-item label="选择数据集" required>
<el-select v-model="analysisForm.dataset_id" @change="handleDatasetChange">
<el-option
v-for="dataset in availableDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
></el-option>
</el-select>
</el-form-item>
</el-form>
<!-- 数据集信息 -->
<el-descriptions v-if="selectedDataset" :column="2" border>
<el-descriptions-item label="数据集名称">{{ selectedDataset.name }}</el-descriptions-item>
<el-descriptions-item label="装备数量">{{ selectedDataset.equipment_count }}</el-descriptions-item>
<el-descriptions-item label="描述" :span="2">{{ selectedDataset.description }}</el-descriptions-item>
</el-descriptions>
</div>
<!-- 分析按钮 -->
<div class="action-section">
<el-button type="primary" @click="startAnalysis" :loading="analyzing" :disabled="!analysisForm.dataset_id">
{{ analyzing ? '分析中...' : '开始分析' }}
</el-button>
</div>
<!-- 分析结果 -->
<div v-if="analysisResult" class="result-section">
<el-divider content-position="left">分析结果</el-divider>
<!-- 特征重要性 -->
<h3>特征重要性</h3>
<div class="chart-container">
<div ref="importanceChartRef" style="width: 100%; height: 400px"></div>
</div>
<!-- 相关性分析 -->
<h3>相关性分析</h3>
<div class="chart-container">
<div ref="correlationChartRef" style="width: 100%; height: 500px"></div>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, onMounted, watch, nextTick, onUnmounted } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
import * as echarts from 'echarts'
//
const analysisForm = ref({
equipment_type: '',
dataset_id: null
})
const availableDatasets = ref([])
const selectedDataset = ref(null)
const analyzing = ref(false)
const analysisResult = ref(null)
const importanceChartRef = ref(null)
const correlationChartRef = ref(null)
//
const importanceChart = ref(null)
const correlationChart = ref(null)
//
watch(() => analysisResult.value, async (newResult) => {
if (newResult) {
console.log('Analysis result updated:', newResult)
// tick DOM
await nextTick()
//
setTimeout(() => {
renderCharts()
}, 100)
}
}, { deep: true })
//
const loadDatasets = async (type) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets`, {
params: { equipment_type: type, purpose: '训练' }
})
availableDatasets.value = response.data
} catch (error) {
ElMessage.error('获取数据集列表失败')
}
}
//
const handleEquipmentTypeChange = () => {
analysisForm.value.dataset_id = null
selectedDataset.value = null
analysisResult.value = null
loadDatasets(analysisForm.value.equipment_type)
}
//
const handleDatasetChange = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${analysisForm.value.dataset_id}`)
selectedDataset.value = response.data
analysisResult.value = null
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const startAnalysis = async () => {
if (!analysisForm.value.dataset_id) {
ElMessage.warning('请先选择数据集')
return
}
analyzing.value = true
try {
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
analysisResult.value = response.data
console.log('Analysis completed, result:', analysisResult.value)
} catch (error) {
ElMessage.error('特征分析失败')
console.error('Analysis error:', error)
} finally {
analyzing.value = false
}
}
//
const renderCharts = () => {
console.log('Starting to render charts')
//
if (importanceChart.value) {
importanceChart.value.dispose()
}
if (correlationChart.value) {
correlationChart.value.dispose()
}
// DOM
if (!importanceChartRef.value || !correlationChartRef.value) {
console.log('Chart DOM elements not ready')
return
}
try {
//
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
//
importanceChart.value.setOption({
title: { text: '特征重要性排序' },
tooltip: {},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance)
}]
})
correlationChart.value.setOption({
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
formatter: function (params) {
const value = params.data[2].toFixed(2)
const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
return `${feature1}${feature2} 的相关性: ${value}`
}
},
grid: {
height: '50%',
top: '10%'
},
xAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true },
axisLabel: {
interval: 0,
rotate: 45
}
},
yAxis: {
type: 'category',
data: analysisResult.value.correlation_analysis.features,
splitArea: { show: true }
},
visualMap: {
min: -1,
max: 1,
calculable: true,
orient: 'horizontal',
left: 'center',
bottom: '15%',
color: ['#cc3333', '#eeeeee', '#00007f']
},
series: [{
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
label: {
show: true,
formatter: function(params) {
return params.data[2].toFixed(2)
}
}
}]
})
//
window.addEventListener('resize', () => {
importanceChart.value?.resize()
correlationChart.value?.resize()
})
console.log('Charts rendered successfully')
} catch (error) {
console.error('Error rendering charts:', error)
}
}
//
onMounted(() => {
//
})
//
onUnmounted(() => {
importanceChart.value?.dispose()
correlationChart.value?.dispose()
})
</script>
<style lang="scss" scoped>
.analysis-page {
padding: 20px;
.analysis-card {
.header-content {
h2 {
margin: 0;
}
}
.dataset-section {
margin-bottom: 20px;
}
.action-section {
margin: 20px 0;
text-align: center;
}
.result-section {
h3 {
margin: 20px 0;
}
.chart-container {
margin: 20px 0;
border: 1px solid #ebeef5;
border-radius: 4px;
}
}
}
}
</style>

View File

@ -0,0 +1,725 @@
<template>
<div class="data-page">
<el-card class="data-card">
<template #header>
<div class="header-content">
<h2>数据管理</h2>
<div class="header-buttons">
<el-upload
:action="null"
:auto-upload="false"
:show-file-list="false"
accept=".xls,.xlsx"
@change="handleFileChange"
>
<el-button type="primary">导入数据</el-button>
</el-upload>
<el-button type="primary" @click="downloadTemplate">下载模板</el-button>
</div>
</div>
</template>
<!-- 数据列表 -->
<el-tabs v-model="activeTab" @tab-click="handleTabClick">
<el-tab-pane label="火箭炮数据" name="rocket">
<!-- 搜索和过滤 -->
<div class="filter-section">
<el-input
v-model="searchQuery"
placeholder="搜索装备名称或制造商"
style="width: 200px; margin-right: 10px;"
/>
<el-select v-model="filterManufacturer" placeholder="制造商" clearable>
<el-option
v-for="item in manufacturers"
:key="item"
:label="item"
:value="item"
/>
</el-select>
</div>
<!-- 数据表格 -->
<el-table :data="filteredRocketData" border style="width: 100%">
<el-table-column prop="name" label="名称" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造商" sortable></el-table-column>
<el-table-column prop="length_m" label="总长(m)" sortable></el-table-column>
<el-table-column prop="weight_kg" label="重量(kg)" sortable></el-table-column>
<el-table-column prop="max_range_km" label="最大射程(km)" sortable></el-table-column>
<el-table-column prop="rocket_diameter_mm" label="口径(mm)" sortable></el-table-column>
<el-table-column prop="rate_of_fire" label="射速(发/分)" sortable></el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDetails(scope.row)">
详情
</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">
编辑
</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="巡飞弹数据" name="missile">
<!-- 搜索和过滤 -->
<div class="filter-section">
<el-input
v-model="searchQuery"
placeholder="搜索装备名称或制造商"
style="width: 200px; margin-right: 10px;"
/>
<el-select v-model="filterManufacturer" placeholder="制造商" clearable>
<el-option
v-for="item in manufacturers"
:key="item"
:label="item"
:value="item"
/>
</el-select>
</div>
<!-- 数据表格 -->
<el-table :data="filteredMissileData" border style="width: 100%">
<el-table-column prop="name" label="名称" sortable></el-table-column>
<el-table-column prop="manufacturer" label="制造" sortable></el-table-column>
<el-table-column prop="length_m" label="弹长(m)" sortable></el-table-column>
<el-table-column prop="weight_kg" label="重量(kg)" sortable></el-table-column>
<el-table-column prop="max_range_km" label="最大射程(km)" sortable></el-table-column>
<el-table-column prop="max_speed_ms" label="最大速度(m/s)" sortable></el-table-column>
<el-table-column prop="flight_time_min" label="巡飞时间(min)" sortable></el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDetails(scope.row)">
详情
</el-button>
<el-button size="small" type="primary" @click="editData(scope.row)">
编辑
</el-button>
<el-button size="small" type="danger" @click="deleteData(scope.row)">
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
</el-tabs>
<!-- 详情对话框 -->
<el-dialog v-model="detailsVisible" title="装备详情" width="70%">
<el-descriptions :column="2" border>
<!-- 基本信息 -->
<template>
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">基本信息</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="名称">{{ selectedData?.name }}</el-descriptions-item>
<el-descriptions-item label="制造商">{{ selectedData?.manufacturer }}</el-descriptions-item>
<!-- 通用参数 -->
<el-descriptions-item label="总长(m)">{{ formatNumber(selectedData?.length_m) }}</el-descriptions-item>
<el-descriptions-item label="宽度(m)">{{ formatNumber(selectedData?.width_m) }}</el-descriptions-item>
<el-descriptions-item label="高度(m)">{{ formatNumber(selectedData?.height_m) }}</el-descriptions-item>
<el-descriptions-item label="重量(kg)">{{ formatNumber(selectedData?.weight_kg) }}</el-descriptions-item>
<el-descriptions-item label="最大射程(km)">{{ formatNumber(selectedData?.max_range_km) }}</el-descriptions-item>
</template>
<!-- 火箭炮特有参数 -->
<template v-if="selectedData?.type === '火箭炮'">
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">火箭炮参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="口径(mm)">{{ formatNumber(selectedData?.rocket_diameter_mm) }}</el-descriptions-item>
<el-descriptions-item label="射速(发/分)">{{ formatNumber(selectedData?.rate_of_fire) }}</el-descriptions-item>
<el-descriptions-item label="火箭弹长度(m)">{{ formatNumber(selectedData?.rocket_length_m) }}</el-descriptions-item>
<el-descriptions-item label="火箭弹重量(kg)">{{ formatNumber(selectedData?.rocket_weight_kg) }}</el-descriptions-item>
<el-descriptions-item label="方向射界(度)">{{ formatNumber(selectedData?.firing_angle_horizontal) }}</el-descriptions-item>
<el-descriptions-item label="高低射界(度)">{{ formatNumber(selectedData?.firing_angle_vertical) }}</el-descriptions-item>
<el-descriptions-item label="机动方式">{{ selectedData?.mobility_type }}</el-descriptions-item>
<el-descriptions-item label="结构布局">{{ selectedData?.structure_layout }}</el-descriptions-item>
<el-descriptions-item label="发动机型号">{{ selectedData?.engine_model }}</el-descriptions-item>
<el-descriptions-item label="发动机参数">{{ selectedData?.engine_params }}</el-descriptions-item>
<el-descriptions-item label="功率(hp)">{{ formatNumber(selectedData?.power_hp) }}</el-descriptions-item>
<el-descriptions-item label="行程(km)">{{ formatNumber(selectedData?.travel_range_km) }}</el-descriptions-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="selectedData?.type === '巡飞弹'">
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">巡飞弹参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="最大速度(m/s)">{{ formatNumber(selectedData?.max_speed_ms) }}</el-descriptions-item>
<el-descriptions-item label="巡航速度(km/h)">{{ formatNumber(selectedData?.cruise_speed_kmh) }}</el-descriptions-item>
<el-descriptions-item label="巡飞时间(min)">{{ formatNumber(selectedData?.flight_time_min) }}</el-descriptions-item>
<el-descriptions-item label="战斗部类型">{{ selectedData?.warhead_type }}</el-descriptions-item>
<el-descriptions-item label="发射方式">{{ selectedData?.launch_mode }}</el-descriptions-item>
<el-descriptions-item label="动力装置">{{ selectedData?.power_system }}</el-descriptions-item>
<el-descriptions-item label="制导体制">{{ selectedData?.guidance_system }}</el-descriptions-item>
</template>
<!-- 特殊参数 -->
<template v-if="selectedData?.custom_params?.length > 0">
{{ console.log('Rendering custom params:', selectedData.custom_params) }}
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">特殊参数</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item
v-for="param in selectedData.custom_params"
:key="param.id"
:label="param.param_name"
>
{{ formatCustomParamValue(param) }}
</el-descriptions-item>
</template>
<!-- 成本信息 -->
<template>
<el-descriptions-item :span="2">
<template #label>
<el-divider content-position="left">成本信息</el-divider>
</template>
</el-descriptions-item>
<el-descriptions-item label="实际成本(元)">
{{ formatMoney(selectedData?.actual_cost) }}
</el-descriptions-item>
<el-descriptions-item label="预测成本(元)">
{{ formatMoney(selectedData?.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="成本估算时间">
{{ selectedData?.cost_estimate_date || '-' }}
</el-descriptions-item>
</template>
</el-descriptions>
</el-dialog>
<!-- 编辑对话框 -->
<el-dialog v-model="editVisible" title="编辑装备数据" width="70%">
<el-form :model="editForm" label-width="120px">
<!-- 基本信息 -->
<template>
<el-divider content-position="left">基本信息</el-divider>
<el-form-item label="名称">
<el-input v-model="editForm.name"></el-input>
</el-form-item>
<el-form-item label="制造商">
<el-input v-model="editForm.manufacturer"></el-input>
</el-form-item>
<!-- 通用参数 -->
<el-form-item label="总长(m)">
<el-input-number v-model="editForm.length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="宽度(m)">
<el-input-number v-model="editForm.width_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="高度(m)">
<el-input-number v-model="editForm.height_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="重量(kg)">
<el-input-number v-model="editForm.weight_kg" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="最大射程(km)">
<el-input-number v-model="editForm.max_range_km" :precision="2"></el-input-number>
</el-form-item>
</template>
<!-- 火箭炮特有参数 -->
<template v-if="editForm.type === '火箭炮'">
<el-divider content-position="left">火箭炮参数</el-divider>
<el-form-item label="口径(mm)">
<el-input-number v-model="editForm.rocket_diameter_mm" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="射速(发/分)">
<el-input-number v-model="editForm.rate_of_fire" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹长度(m)">
<el-input-number v-model="editForm.rocket_length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹重量(kg)">
<el-input-number v-model="editForm.rocket_weight_kg" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="方向射界(度)">
<el-input-number v-model="editForm.firing_angle_horizontal" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="低射界(度)">
<el-input-number v-model="editForm.firing_angle_vertical" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="机动方式">
<el-select v-model="editForm.mobility_type">
<el-option v-for="option in getSelectOptions('mobility_type')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="结构布局">
<el-input v-model="editForm.structure_layout"></el-input>
</el-form-item>
<el-form-item label="发动型号">
<el-input v-model="editForm.engine_model"></el-input>
</el-form-item>
<el-form-item label="发动机参数">
<el-input v-model="editForm.engine_params"></el-input>
</el-form-item>
<el-form-item label="功率(hp)">
<el-input-number v-model="editForm.power_hp" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="行程(km)">
<el-input-number v-model="editForm.travel_range_km" :precision="2"></el-input-number>
</el-form-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="editForm.type === '巡飞弹'">
<el-divider content-position="left">巡飞弹参数</el-divider>
<el-form-item label="最大速度(m/s)">
<el-input-number v-model="editForm.max_speed_ms" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="editForm.cruise_speed_kmh" :precision="1"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="editForm.flight_time_min" :precision="0"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="editForm.warhead_type">
<el-option v-for="option in getSelectOptions('warhead_type')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="发射方式">
<el-select v-model="editForm.launch_mode">
<el-option v-for="option in getSelectOptions('launch_mode')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="动力装置">
<el-select v-model="editForm.power_system">
<el-option v-for="option in getSelectOptions('power_system')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
<el-form-item label="制导体制">
<el-select v-model="editForm.guidance_system">
<el-option v-for="option in getSelectOptions('guidance_system')" :key="option" :label="option" :value="option"></el-option>
</el-select>
</el-form-item>
</template>
<!-- 特殊参数 -->
<template v-if="editForm.custom_params?.length">
<el-divider content-position="left">特殊参数</el-divider>
<el-form-item
v-for="(param, index) in editForm.custom_params"
:key="param.id"
:label="param.param_name"
>
<el-input-number
v-if="isNumericParam(param)"
v-model="editForm.custom_params[index].param_value"
:precision="2"
:step="0.1"
>
<template #append v-if="param.param_unit">
{{ param.param_unit }}
</template>
</el-input-number>
<el-input
v-else
v-model="editForm.custom_params[index].param_value"
>
<template #append v-if="param.param_unit">
{{ param.param_unit }}
</template>
</el-input>
</el-form-item>
</template>
<!-- 成本信息 -->
<el-divider content-position="left">成本信息</el-divider>
<el-form-item label="实际成本(元)">
<el-input-number
v-model="editForm.actual_cost"
:precision="0"
:min="0"
:step="1000"
:controls="true"
style="width: 200px"
></el-input-number>
</el-form-item>
<el-form-item label="预测成本(元)">
<el-input-number
v-model="editForm.predicted_cost"
:precision="0"
disabled
style="width: 200px"
></el-input-number>
</el-form-item>
<el-form-item label="成本估算时间">
<el-date-picker
v-model="editForm.cost_estimate_date"
type="datetime"
placeholder="选择日期时间"
format="YYYY-MM-DD HH:mm:ss"
value-format="YYYY-MM-DD HH:mm:ss"
disabled
></el-date-picker>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" @click="saveEdit">保存</el-button>
</span>
</template>
</el-dialog>
</el-card>
</div>
</template>
<script setup>
import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const activeTab = ref('rocket')
const rocketData = ref([])
const missileData = ref([])
const detailsVisible = ref(false)
const editVisible = ref(false)
const selectedData = ref(null)
const editForm = ref(null)
const searchQuery = ref('')
const filterManufacturer = ref('')
//
const manufacturers = computed(() => {
const data = activeTab.value === 'rocket' ? rocketData.value : missileData.value
return [...new Set(data.map(item => item.manufacturer))]
})
//
const filteredRocketData = computed(() => {
return filterData(rocketData.value)
})
const filteredMissileData = computed(() => {
return filterData(missileData.value)
})
const filterData = (data) => {
return data.filter(item => {
const matchesSearch = searchQuery.value ?
(item.name.toLowerCase().includes(searchQuery.value.toLowerCase()) ||
item.manufacturer.toLowerCase().includes(searchQuery.value.toLowerCase())) : true
const matchesManufacturer = filterManufacturer.value ?
item.manufacturer === filterManufacturer.value : true
return matchesSearch && matchesManufacturer
})
}
//
const loadData = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/data`)
if (response.data.error) {
throw new Error(response.data.error)
}
//
rocketData.value = response.data.rocket_artillery.map(item => ({
...item,
custom_params: item.custom_params || []
}))
//
missileData.value = response.data.loitering_munition.map(item => ({
...item,
custom_params: item.custom_params || []
}))
} catch (error) {
ElMessage.error('加载数据失败')
console.error('Error loading data:', error)
}
}
//
const handleFileChange = async (file) => {
try {
const formData = new FormData()
formData.append('file', file.raw)
const response = await axios.post(
`${API_BASE_URL}/data/import`,
formData,
{
headers: {
'Content-Type': 'multipart/form-data'
}
}
)
if (response.data.success) {
ElMessage.success('数据导入成功')
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
ElMessage.error(error.message || '数据导入失败')
}
}
//
const downloadTemplate = () => {
window.open(`${API_BASE_URL}/data/template`, '_blank')
}
//
const viewDetails = async (row) => {
try {
console.log('Requesting details for row:', row)
const response = await axios.get(`${API_BASE_URL}/data/details/${row.id}`)
if (response.data.error) {
throw new Error(response.data.error)
}
console.log('Details response:', response.data)
// custom_params
if (typeof response.data.custom_params === 'string') {
response.data.custom_params = JSON.parse(response.data.custom_params)
}
console.log('Parsed custom params:', response.data.custom_params)
selectedData.value = response.data
console.log('Selected data:', selectedData.value)
detailsVisible.value = true
} catch (error) {
ElMessage.error('获取详情失败')
console.error('Error getting details:', error)
}
}
//
const editData = async (row) => {
try {
console.log('Editing row:', row)
const response = await axios.get(`${API_BASE_URL}/data/details/${row.id}`)
if (response.data.error) {
throw new Error(response.data.error)
}
console.log('Edit data response:', response.data)
// custom_params
if (typeof response.data.custom_params === 'string') {
response.data.custom_params = JSON.parse(response.data.custom_params)
}
//
const data = { ...response.data }
Object.keys(data).forEach(key => {
if (isNumberInput(key) && data[key] !== null && data[key] !== undefined) {
data[key] = Number(data[key])
}
})
//
if (data.custom_params) {
data.custom_params = data.custom_params.map(param => ({
...param,
param_value: !isNaN(param.param_value) ? Number(param.param_value) : param.param_value
}))
}
console.log('Parsed custom params:', data.custom_params)
editForm.value = data
console.log('Edit form data:', editForm.value)
editVisible.value = true
} catch (error) {
ElMessage.error('获取编辑数据失败')
console.error('Error getting edit data:', error)
}
}
//
const saveEdit = async () => {
try {
//
const saveData = {
...editForm.value,
custom_params: editForm.value.custom_params.map(param => ({
id: param.id,
param_name: param.param_name,
param_value: param.param_value,
param_unit: param.param_unit
}))
}
const response = await axios.put(
`${API_BASE_URL}/data/${editForm.value.id}`,
saveData
)
if (response.data.success) {
ElMessage.success('保存成功')
editVisible.value = false
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
ElMessage.error(error.message || '保存失败')
}
}
//
const deleteData = async (row) => {
try {
await ElMessageBox.confirm('确定要删除这条数据吗?', '警告', {
type: 'warning'
})
const response = await axios.delete(`${API_BASE_URL}/data/${row.id}`)
if (response.data.success) {
ElMessage.success('删除成功')
loadData() //
} else {
throw new Error(response.data.error)
}
} catch (error) {
if (error !== 'cancel') {
ElMessage.error(error.message || '删除失败')
}
}
}
//
const formatNumber = (value) => {
if (value === null || value === undefined) return '-'
return Number(value).toFixed(2)
}
//
const formatMoney = (value) => {
if (value === null || value === undefined) return '-'
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
//
const formatCustomParamValue = (param) => {
if (!param || !param.param_value) return '-'
let value = param.param_value
//
if (!isNaN(value)) {
value = Number(value).toFixed(2)
}
//
if (param.param_unit) {
value = `${value} ${param.param_unit}`
}
return value
}
//
const isNumericParam = (param) => {
return !isNaN(param.param_value) && param.param_unit !== undefined
}
// handleTabClick
const handleTabClick = () => {
//
searchQuery.value = ''
filterManufacturer.value = ''
}
//
onMounted(() => {
loadData()
})
//
const isNumberInput = (key) => {
const numberFields = [
'length_m',
'width_m',
'height_m',
'weight_kg',
'max_range_km',
'firing_angle_horizontal',
'firing_angle_vertical',
'rocket_length_m',
'rocket_diameter_mm',
'rocket_weight_kg',
'rate_of_fire',
'combat_weight_kg',
'speed_kmh',
'min_range_km',
'power_hp',
'travel_range_km',
'max_speed_ms',
'cruise_speed_kmh',
'flight_time_min',
'folded_length_mm',
'folded_width_mm',
'folded_height_mm',
'actual_cost',
'predicted_cost'
]
return numberFields.includes(key)
}
</script>
<style lang="scss" scoped>
.data-page {
padding: 20px;
.data-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
.header-buttons {
display: flex;
gap: 10px;
}
}
}
.filter-section {
margin-bottom: 20px;
display: flex;
gap: 10px;
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -0,0 +1,322 @@
<template>
<div class="dataset-page">
<el-card class="dataset-card">
<template #header>
<div class="header-content">
<h2>数据集管理</h2>
<div class="header-buttons">
<el-button type="primary" @click="createDataset">创建数据集</el-button>
</div>
</div>
</template>
<!-- 数据集列表 -->
<el-table :data="datasets" border style="width: 100%">
<el-table-column prop="name" label="数据集名称"></el-table-column>
<el-table-column prop="equipment_type" label="装备类型"></el-table-column>
<el-table-column prop="purpose" label="用途"></el-table-column>
<el-table-column prop="description" label="描述"></el-table-column>
<el-table-column prop="equipment_count" label="装备数量"></el-table-column>
<el-table-column prop="created_at" label="创建时间">
<template #default="scope">
{{ formatDateTime(scope.row.created_at) }}
</template>
</el-table-column>
<el-table-column prop="updated_at" label="修改时间">
<template #default="scope">
{{ formatDateTime(scope.row.updated_at) }}
</template>
</el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button size="small" @click="viewDataset(scope.row)">查看</el-button>
<el-button size="small" type="primary" @click="editDataset(scope.row)">编辑</el-button>
<el-button size="small" type="danger" @click="deleteDataset(scope.row)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<!-- 数据集详情对话框 -->
<el-dialog v-model="detailsVisible" :title="selectedDataset?.name" width="70%">
<el-descriptions :column="2" border>
<el-descriptions-item label="装备类型">{{ selectedDataset?.equipment_type }}</el-descriptions-item>
<el-descriptions-item label="用途">{{ selectedDataset?.purpose }}</el-descriptions-item>
<el-descriptions-item label="描述" :span="2">{{ selectedDataset?.description }}</el-descriptions-item>
</el-descriptions>
<!-- 数据集统计信息 -->
<div style="margin-top: 20px">
<el-divider content-position="left">统计信息</el-divider>
<el-descriptions :column="2" border>
<el-descriptions-item label="装备数量">{{ selectedDataset?.statistics?.equipment_count || 0 }}</el-descriptions-item>
<el-descriptions-item label="平均成本">{{ formatMoney(selectedDataset?.statistics?.average_cost) }}</el-descriptions-item>
<el-descriptions-item label="总成本">{{ formatMoney(selectedDataset?.statistics?.total_cost) }}</el-descriptions-item>
</el-descriptions>
</div>
<!-- 包含的装备列表 -->
<div style="margin-top: 20px">
<el-divider content-position="left">包含装备</el-divider>
<el-table :data="selectedDataset?.equipment" border>
<el-table-column prop="name" label="名称"></el-table-column>
<el-table-column prop="manufacturer" label="制造商"></el-table-column>
<el-table-column prop="actual_cost" label="成本(元)">
<template #default="scope">
{{ formatMoney(scope.row.actual_cost) }}
</template>
</el-table-column>
<el-table-column label="操作" width="100">
<template #default="scope">
<el-button size="small" @click="viewEquipment(scope.row)">详情</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-dialog>
<!-- 创建/编辑数据集对话框 -->
<el-dialog v-model="editVisible" :title="datasetForm.id ? '编辑数据集' : '创建数据集'" width="70%">
<el-form :model="datasetForm" label-width="120px">
<el-form-item label="数据集名称" required>
<el-input v-model="datasetForm.name"></el-input>
</el-form-item>
<el-form-item label="装备类型" required>
<el-select v-model="datasetForm.equipment_type" @change="handleEquipmentTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<el-form-item label="用途" required>
<el-select v-model="datasetForm.purpose">
<el-option label="训练" value="训练"></el-option>
<el-option label="验证" value="验证"></el-option>
</el-select>
</el-form-item>
<el-form-item label="描述">
<el-input type="textarea" v-model="datasetForm.description"></el-input>
</el-form-item>
<!-- 选择装备数据 -->
<el-form-item label="选择装备" required>
<el-table
:data="availableEquipment"
border
@selection-change="handleSelectionChange"
:max-height="400"
>
<el-table-column type="selection" width="55"></el-table-column>
<el-table-column prop="name" label="名称"></el-table-column>
<el-table-column prop="manufacturer" label="制造商"></el-table-column>
<el-table-column prop="actual_cost" label="成本(元)">
<template #default="scope">
{{ formatMoney(scope.row.actual_cost) }}
</template>
</el-table-column>
</el-table>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click="editVisible = false">取消</el-button>
<el-button type="primary" @click="saveDataset">保存</el-button>
</span>
</template>
</el-dialog>
</div>
</template>
<script setup>
import { ref, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
//
const datasets = ref([]) //
const selectedDataset = ref(null) //
const detailsVisible = ref(false) //
const editVisible = ref(false) //
const availableEquipment = ref([]) //
const selectedEquipment = ref([]) //
//
const datasetForm = ref({
id: null,
name: '',
equipment_type: '',
purpose: '',
description: '',
selected_equipment: []
})
//
const loadDatasets = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets`)
datasets.value = response.data
} catch (error) {
ElMessage.error('获取数据集列表失败')
}
}
//
const createDataset = () => {
datasetForm.value = {
id: null,
name: '',
equipment_type: '',
purpose: '',
description: '',
selected_equipment: []
}
editVisible.value = true
}
//
const editDataset = async (dataset) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
datasetForm.value = response.data
await loadAvailableEquipment()
editVisible.value = true
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const viewDataset = async (dataset) => {
try {
const response = await axios.get(`${API_BASE_URL}/datasets/${dataset.id}`)
selectedDataset.value = response.data
detailsVisible.value = true
} catch (error) {
ElMessage.error('获取数据集详情失败')
}
}
//
const deleteDataset = async (dataset) => {
try {
await ElMessageBox.confirm('确定要删除这个数据集吗?', '警告', {
type: 'warning'
})
await axios.delete(`${API_BASE_URL}/datasets/${dataset.id}`)
ElMessage.success('删除成功')
loadDatasets()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('删除失败')
}
}
}
//
const loadAvailableEquipment = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/data`)
availableEquipment.value = datasetForm.value.equipment_type === '火箭炮'
? response.data.rocket_artillery
: response.data.loitering_munition
} catch (error) {
ElMessage.error('获取装备列表失败')
}
}
//
const handleEquipmentTypeChange = () => {
selectedEquipment.value = [] //
loadAvailableEquipment() //
}
//
const handleSelectionChange = (selection) => {
selectedEquipment.value = selection
}
//
const saveDataset = async () => {
try {
const data = {
...datasetForm.value,
equipment_ids: selectedEquipment.value.map(item => item.id)
}
if (data.id) {
await axios.put(`${API_BASE_URL}/datasets/${data.id}`, data)
} else {
await axios.post(`${API_BASE_URL}/datasets`, data)
}
ElMessage.success('保存成功')
editVisible.value = false
loadDatasets()
} catch (error) {
ElMessage.error('保存失败')
}
}
//
const viewEquipment = (equipment) => {
//
console.log('View equipment:', equipment)
}
//
const formatMoney = (value) => {
if (value === null || value === undefined) return '-'
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
//
const formatDateTime = (value) => {
if (!value) return '-'
const date = new Date(value)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
})
}
//
onMounted(() => {
loadDatasets()
})
</script>
<style lang="scss" scoped>
.dataset-page {
padding: 20px;
.dataset-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
.header-buttons {
display: flex;
gap: 10px;
}
}
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -0,0 +1,101 @@
<template>
<div class="home-page">
<el-card class="welcome-card">
<template #header>
<h2>装备成本估算系统</h2>
</template>
<el-row :gutter="20">
<el-col :span="8">
<el-card @click="$router.push('/predict')">
<el-icon><Money /></el-icon>
<h3>成本预测</h3>
<p>基于机器学习和 PLS 回归模型的成本预测</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/analysis')">
<el-icon><DataAnalysis /></el-icon>
<h3>特征分析</h3>
<p>分析参数对成本的影响</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/training')">
<el-icon><Monitor /></el-icon>
<h3>模型训练</h3>
<p>训练和优化预测模型</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/models')">
<el-icon><Management /></el-icon>
<h3>模型管理</h3>
<p>管理训练好的模型</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/datasets')">
<el-icon><Collection /></el-icon>
<h3>数据集管理</h3>
<p>管理训练和验证数据集</p>
</el-card>
</el-col>
<el-col :span="8">
<el-card @click="$router.push('/data')">
<el-icon><Management /></el-icon>
<h3>数据管理</h3>
<p>管理装备数据和成本数据</p>
</el-card>
</el-col>
</el-row>
</el-card>
</div>
</template>
<script setup>
import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-plus/icons-vue'
</script>
<style lang="scss" scoped>
.home-page {
padding: 20px;
.welcome-card {
max-width: 1200px;
margin: 0 auto;
h2 {
text-align: center;
margin: 0;
}
}
.el-card {
text-align: center;
cursor: pointer;
transition: all 0.3s;
margin-bottom: 20px;
&:hover {
transform: translateY(-5px);
box-shadow: 0 2px 12px 0 rgba(0,0,0,.1);
}
.el-icon {
font-size: 48px;
color: #409EFF;
margin: 20px 0;
}
h3 {
margin: 10px 0;
font-size: 18px;
}
p {
color: #909399;
font-size: 14px;
}
}
}
</style>

View File

@ -0,0 +1,279 @@
<template>
<div class="model-page">
<el-card class="model-card">
<template #header>
<div class="header-content">
<h2>模型管理</h2>
</div>
</template>
<!-- 模型列表 -->
<el-table :data="modelList" border style="width: 100%">
<el-table-column prop="model_type" label="模型类型">
<template #default="scope">
{{ getModelName(scope.row.model_type) }}
</template>
</el-table-column>
<el-table-column prop="model_name" label="模型名称"></el-table-column>
<el-table-column prop="equipment_type" label="装备类型"></el-table-column>
<el-table-column prop="r2_score" label="R²分数">
<template #default="scope">
{{ scope.row.r2_score.toFixed(4) }}
</template>
</el-table-column>
<el-table-column prop="mae" label="MAE (元)">
<template #default="scope">
{{ scope.row.mae.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="rmse" label="RMSE (元)">
<template #default="scope">
{{ scope.row.rmse.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="training_date" label="训练时间">
<template #default="scope">
{{ formatDateTime(scope.row.training_date) }}
</template>
</el-table-column>
<el-table-column prop="is_active" label="状态">
<template #default="scope">
<el-tag :type="scope.row.is_active ? 'success' : 'info'">
{{ scope.row.is_active ? '使用中' : '未使用' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="200">
<template #default="scope">
<el-button
size="small"
type="primary"
:disabled="scope.row.is_active"
@click="activateModel(scope.row)"
>
激活
</el-button>
<el-button
size="small"
@click="viewDetails(scope.row)"
>
详情
</el-button>
<el-button
size="small"
type="danger"
:disabled="scope.row.is_active"
@click="deleteModel(scope.row)"
>
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<!-- 模型详情对话框 -->
<el-dialog v-model="detailsVisible" title="模型详情" width="70%">
<el-descriptions :column="2" border>
<el-descriptions-item label="模型名称">{{ selectedModel?.model_name }}</el-descriptions-item>
<el-descriptions-item label="模型类型">{{ formatModelType(selectedModel?.model_type) }}</el-descriptions-item>
<el-descriptions-item label="装备类型">{{ selectedModel?.equipment_type }}</el-descriptions-item>
<el-descriptions-item label="训练数据量">{{ selectedModel?.training_data_size }}</el-descriptions-item>
<el-descriptions-item label="训练时间">{{ formatDateTime(selectedModel?.training_date) }}</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag :type="selectedModel?.is_active ? 'success' : 'info'">
{{ selectedModel?.is_active ? '使用中' : '未使用' }}
</el-tag>
</el-descriptions-item>
</el-descriptions>
<!-- 评估指标 -->
<div style="margin-top: 20px">
<el-divider content-position="left">评估指标</el-divider>
<el-descriptions :column="3" border>
<el-descriptions-item label="R²分数">{{ selectedModel?.r2_score.toFixed(4) }}</el-descriptions-item>
<el-descriptions-item label="MAE">{{ selectedModel?.mae.toFixed(2) }} </el-descriptions-item>
<el-descriptions-item label="RMSE">{{ selectedModel?.rmse.toFixed(2) }} </el-descriptions-item>
</el-descriptions>
</div>
<!-- 特征重要性 -->
<div v-if="selectedModel?.feature_importance" style="margin-top: 20px">
<el-divider content-position="left">特征重要性</el-divider>
<div ref="importanceChartRef" style="width: 100%; height: 400px"></div>
</div>
</el-dialog>
</div>
</template>
<script setup>
import { ref, onMounted, watch, nextTick, onUnmounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
import * as echarts from 'echarts'
//
const modelList = ref([])
const selectedModel = ref(null)
const detailsVisible = ref(false)
const importanceChartRef = ref(null)
const importanceChart = ref(null)
//
const loadModels = async () => {
try {
const response = await axios.get(`${API_BASE_URL}/models`)
modelList.value = response.data
} catch (error) {
ElMessage.error('获取模型列表失败')
}
}
//
const activateModel = async (model) => {
try {
await ElMessageBox.confirm('确定要激活这个模型吗?', '提示', {
type: 'warning'
})
await axios.post(`${API_BASE_URL}/models/${model.id}/activate`)
ElMessage.success('模型激活成功')
loadModels()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('模型激活失败')
}
}
}
//
const deleteModel = async (model) => {
try {
await ElMessageBox.confirm('确定要删除这个模型吗?', '警告', {
type: 'warning'
})
await axios.delete(`${API_BASE_URL}/models/${model.id}`)
ElMessage.success('删除成功')
loadModels()
} catch (error) {
if (error !== 'cancel') {
ElMessage.error('删除失败')
}
}
}
//
const viewDetails = async (model) => {
selectedModel.value = model
detailsVisible.value = true
}
//
watch(() => detailsVisible.value, async (visible) => {
if (visible && selectedModel.value?.feature_importance) {
await nextTick()
renderImportanceChart()
}
})
//
const renderImportanceChart = () => {
if (importanceChart.value) {
importanceChart.value.dispose()
}
importanceChart.value = echarts.init(importanceChartRef.value)
const featureImportance = JSON.parse(selectedModel.value.feature_importance)
const features = Object.keys(featureImportance)
const values = Object.values(featureImportance)
importanceChart.value.setOption({
title: { text: '特征重要性' },
tooltip: {},
xAxis: {
type: 'value',
name: '重要性得分'
},
yAxis: {
type: 'category',
data: features
},
series: [{
type: 'bar',
data: values
}]
})
}
//
const formatModelType = (type) => {
const typeMap = {
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbdt': 'GBDT',
'rf': 'Random Forest'
}
return typeMap[type] || type
}
//
const formatDateTime = (value) => {
if (!value) return '-'
const date = new Date(value)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
})
}
//
onUnmounted(() => {
importanceChart.value?.dispose()
})
//
onMounted(() => {
loadModels()
})
const getModelName = (modelType) => {
const modelNames = {
'pls': 'PLS回归',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
</script>
<style lang="scss" scoped>
.model-page {
padding: 20px;
.model-card {
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
h2 {
margin: 0;
}
}
}
.el-table {
margin-top: 20px;
}
}
</style>

View File

@ -0,0 +1,321 @@
<template>
<div class="predict-page">
<el-card class="predict-card">
<template #header>
<h2>成本预测</h2>
</template>
<el-form :model="formData" label-width="120px">
<!-- 装备类型选择 -->
<el-form-item label="装备类型">
<el-select v-model="formData.type" @change="handleTypeChange">
<el-option label="火箭炮" value="火箭炮"></el-option>
<el-option label="巡飞弹" value="巡飞弹"></el-option>
</el-select>
</el-form-item>
<!-- 通用参数 -->
<el-form-item label="总长(m)">
<el-input-number v-model="formData.length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="宽度(m)">
<el-input-number v-model="formData.width_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="高度(m)">
<el-input-number v-model="formData.height_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="重量(kg)">
<el-input-number v-model="formData.weight_kg"></el-input-number>
</el-form-item>
<el-form-item label="最大射程(km)">
<el-input-number v-model="formData.max_range_km"></el-input-number>
</el-form-item>
<!-- 火箭炮特有参数 -->
<template v-if="formData.type === '火箭炮'">
<el-form-item label="方向射界(度)">
<el-input-number v-model="formData.firing_angle_horizontal"></el-input-number>
</el-form-item>
<el-form-item label="高低射界(度)">
<el-input-number v-model="formData.firing_angle_vertical"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹长度(m)">
<el-input-number v-model="formData.rocket_length_m" :precision="2"></el-input-number>
</el-form-item>
<el-form-item label="弹体直径(mm)">
<el-input-number v-model="formData.rocket_diameter_mm"></el-input-number>
</el-form-item>
<el-form-item label="火箭弹重量(kg)">
<el-input-number v-model="formData.rocket_weight_kg"></el-input-number>
</el-form-item>
<el-form-item label="射速(发/分钟)">
<el-input-number v-model="formData.rate_of_fire"></el-input-number>
</el-form-item>
</template>
<!-- 巡飞弹特有参数 -->
<template v-if="formData.type === '巡飞弹'">
<el-form-item label="最大速度(km/h)">
<el-input-number v-model="formData.max_speed_kmh"></el-input-number>
</el-form-item>
<el-form-item label="巡航速度(km/h)">
<el-input-number v-model="formData.cruise_speed_kmh"></el-input-number>
</el-form-item>
<el-form-item label="巡飞时间(min)">
<el-input-number v-model="formData.flight_time_min"></el-input-number>
</el-form-item>
<el-form-item label="战斗部类型">
<el-select v-model="formData.warhead_type">
<el-option label="破片杀伤战斗部" value="破片杀伤战斗部"></el-option>
<el-option label="破甲战斗部" value="破甲战斗部"></el-option>
<el-option label="高爆战斗部" value="高爆战斗部"></el-option>
</el-select>
</el-form-item>
<el-form-item label="发射方式">
<el-select v-model="formData.launch_mode">
<el-option label="箱式发射" value="箱式发射"></el-option>
<el-option label="凭自身动力起飞" value="凭自身动力起飞"></el-option>
</el-select>
</el-form-item>
<el-form-item label="折叠长度(mm)">
<el-input-number v-model="formData.folded_length_mm"></el-input-number>
</el-form-item>
<el-form-item label="折叠宽度(mm)">
<el-input-number v-model="formData.folded_width_mm"></el-input-number>
</el-form-item>
<el-form-item label="折叠高度(mm)">
<el-input-number v-model="formData.folded_height_mm"></el-input-number>
</el-form-item>
</template>
<el-form-item>
<el-button type="primary" @click="submitForm">预测成本</el-button>
<el-button @click="resetForm">重置</el-button>
</el-form-item>
</el-form>
<!-- 预测结果 -->
<div v-if="predictionResults" class="prediction-results">
<h3>预测结果</h3>
<!-- 机器学习模型预测结果 -->
<div class="ml-prediction">
<h4>机器学习模型预测</h4>
<el-descriptions border>
<el-descriptions-item label="模型类型">
{{ getModelName(mlPrediction.model_info.type) }}
</el-descriptions-item>
<el-descriptions-item label="模型名称">
{{ mlPrediction.model_info.name }}
</el-descriptions-item>
<el-descriptions-item label="预测成本">
{{ formatMoney(mlPrediction.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="置信区间">
{{ formatMoney(mlPrediction.confidence_interval.lower) }} ~
{{ formatMoney(mlPrediction.confidence_interval.upper) }}
</el-descriptions-item>
</el-descriptions>
</div>
<!-- PLS回归预测结果 -->
<div class="pls-prediction">
<h4>PLS回归预测</h4>
<el-descriptions border>
<el-descriptions-item label="模型类型">
{{ getModelName(plsPrediction.model_info.type) }}
</el-descriptions-item>
<el-descriptions-item label="模型名称">
{{ plsPrediction.model_info.name }}
</el-descriptions-item>
<el-descriptions-item label="预测成本">
{{ formatMoney(plsPrediction.predicted_cost) }}
</el-descriptions-item>
<el-descriptions-item label="置信区间">
{{ formatMoney(plsPrediction.confidence_interval.lower) }} ~
{{ formatMoney(plsPrediction.confidence_interval.upper) }}
</el-descriptions-item>
</el-descriptions>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, reactive } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
const formData = reactive({
type: '',
length_m: null,
width_m: null,
height_m: null,
weight_kg: null,
max_range_km: null
})
const predictionResults = ref(null)
const mlPrediction = ref(null)
const plsPrediction = ref(null)
const handleTypeChange = () => {
//
if (formData.type === '火箭炮') {
formData.firing_angle_horizontal = null
formData.firing_angle_vertical = null
formData.rocket_length_m = null
formData.rocket_diameter_mm = null
formData.rocket_weight_kg = null
formData.rate_of_fire = null
} else if (formData.type === '巡飞弹') {
formData.max_speed_kmh = null
formData.cruise_speed_kmh = null
formData.flight_time_min = null
formData.warhead_type = ''
formData.launch_mode = ''
formData.folded_length_mm = null
formData.folded_width_mm = null
formData.folded_height_mm = null
}
}
const submitForm = async () => {
try {
//
if (!formData.type) {
throw new Error('请选择装备类型')
}
//
const commonFields = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']
for (const field of commonFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
//
if (formData.type === '火箭炮') {
const rocketFields = [
'firing_angle_horizontal', 'firing_angle_vertical',
'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg', 'rate_of_fire'
]
for (const field of rocketFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
} else if (formData.type === '巡飞弹') {
const missileFields = [
'max_speed_kmh', 'cruise_speed_kmh', 'flight_time_min',
'folded_length_mm', 'folded_width_mm', 'folded_height_mm'
]
for (const field of missileFields) {
if (!formData[field]) {
throw new Error(`请输入${formatFieldName(field)}`)
}
}
}
//
const [mlResponse, plsResponse] = await Promise.all([
axios.post(`${API_BASE_URL}/predict`, formData),
axios.post(`${API_BASE_URL}/pls/predict`, formData)
])
mlPrediction.value = mlResponse.data
plsPrediction.value = plsResponse.data
predictionResults.value = true
} catch (error) {
ElMessage.error(error.message || '预测失败')
}
}
const resetForm = () => {
formData.type = ''
formData.length_m = null
formData.width_m = null
formData.height_m = null
formData.weight_kg = null
formData.max_range_km = null
predictionResults.value = null
mlPrediction.value = null
plsPrediction.value = null
}
const formatFieldName = (field) => {
const nameMap = {
'length_m': '总长',
'width_m': '宽度',
'height_m': '高度',
'weight_kg': '重量',
'max_range_km': '最大射程',
'firing_angle_horizontal': '方向射界',
'firing_angle_vertical': '高低射界',
'rocket_length_m': '火箭弹长度',
'rocket_diameter_mm': '弹体直径',
'rocket_weight_kg': '火箭弹重量',
'rate_of_fire': '射速',
'max_speed_kmh': '最大速度',
'cruise_speed_kmh': '巡航速度',
'flight_time_min': '巡飞时间',
'folded_length_mm': '折叠长度',
'folded_width_mm': '折叠宽度',
'folded_height_mm': '折叠高度'
}
return nameMap[field] || field
}
const formatMoney = (value) => {
return new Intl.NumberFormat('zh-CN', {
style: 'currency',
currency: 'CNY'
}).format(value)
}
const getModelName = (modelType) => {
const modelNames = {
'pls': 'PLS回归',
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
</script>
<style scoped>
.predict-page {
padding: 20px;
}
.predict-card {
max-width: 800px;
margin: 0 auto;
}
.prediction-results {
margin-top: 20px;
.ml-prediction, .pls-prediction {
margin-top: 20px;
padding: 20px;
background-color: #f5f7fa;
border-radius: 4px;
}
h4 {
margin-top: 0;
margin-bottom: 15px;
}
}
.el-descriptions {
margin-top: 10px;
}
</style>

View File

@ -0,0 +1,370 @@
<template>
<div class="training-page">
<el-card class="training-card">
<template #header>
<h2>模型训练</h2>
</template>
<!-- 训练配置 -->
<el-form :model="trainingConfig" label-width="120px">
<el-form-item label="装备类型">
<el-select v-model="trainingConfig.type" placeholder="选择装备类型">
<el-option label="火箭炮" value="火箭炮" />
<el-option label="巡飞弹" value="巡飞弹" />
</el-select>
</el-form-item>
<el-form-item label="训练数据集">
<el-select v-model="trainingConfig.train_dataset_id" placeholder="选择训练数据集">
<el-option
v-for="dataset in trainingDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
/>
</el-select>
</el-form-item>
<el-form-item label="验证数据集">
<el-select v-model="trainingConfig.validation_dataset_id" placeholder="选择验证数据集">
<el-option
v-for="dataset in validationDatasets"
:key="dataset.id"
:label="dataset.name"
:value="dataset.id"
/>
</el-select>
</el-form-item>
<el-form-item label="选择模型">
<el-checkbox-group v-model="trainingConfig.models">
<el-checkbox label="pls" disabled>PLS回归</el-checkbox>
<el-checkbox label="xgboost" checked>XGBoost</el-checkbox>
<el-checkbox label="lightgbm" checked>LightGBM</el-checkbox>
<el-checkbox label="gbm" checked>GBM</el-checkbox>
<el-checkbox label="rf" checked>Random Forest</el-checkbox>
</el-checkbox-group>
</el-form-item>
<el-form-item>
<el-button type="primary" @click="startTraining" :loading="isTraining">
开始训练
</el-button>
</el-form-item>
</el-form>
<!-- 训练结果 -->
<div v-if="trainingResult" class="training-result">
<h3>训练结果</h3>
<!-- 最佳模型信息 -->
<div class="best-model-info" v-if="trainingResult.best_model">
<h4>最佳模型: {{ getModelName(trainingResult.best_model.type) }}</h4>
<p>R²分数: {{ formatNumber(trainingResult.best_model.r2) }}</p>
<p>MAE: {{ formatNumber(trainingResult.best_model.mae) }} </p>
<p>RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} </p>
</div>
<!-- 所有模型评估结果 -->
<el-table :data="modelResults" border style="width: 100%; margin-top: 20px;">
<el-table-column prop="model" label="模型" width="120">
<template #default="scope">
{{ getModelName(scope.row.model) }}
</template>
</el-table-column>
<!-- 训练集评估 -->
<el-table-column label="训练集评估">
<el-table-column prop="train.r2" label="R²分数" width="120">
<template #default="scope">
{{ formatNumber(scope.row.train.r2) }}
</template>
</el-table-column>
<el-table-column prop="train.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.mae) }}
</template>
</el-table-column>
<el-table-column prop="train.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.train.rmse) }}
</template>
</el-table-column>
</el-table-column>
<!-- 验证集评估 -->
<el-table-column label="验证集评估">
<el-table-column prop="validation.r2" label="R²分数" width="120">
<template #default="scope">
{{ formatNumber(scope.row.validation.r2) }}
</template>
</el-table-column>
<el-table-column prop="validation.mae" label="MAE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.mae) }}
</template>
</el-table-column>
<el-table-column prop="validation.rmse" label="RMSE (元)" width="150">
<template #default="scope">
{{ formatNumber(scope.row.validation.rmse) }}
</template>
</el-table-column>
</el-table-column>
</el-table>
<!-- 特征重要性 -->
<div v-if="trainingResult.feature_importance" class="feature-importance">
<h4>特征重要性</h4>
<el-table
:data="featureImportanceData"
border
style="width: 100%; margin-top: 10px;"
>
<el-table-column prop="feature" label="特征" width="180" />
<el-table-column prop="importance" label="重要性" width="120">
<template #default="scope">
{{ formatNumber(scope.row.importance) }}
</template>
</el-table-column>
</el-table>
</div>
</div>
</el-card>
</div>
</template>
<script setup>
import { ref, computed, onMounted, watch } from 'vue'
import { ElMessage } from 'element-plus'
import axios from 'axios'
import { API_BASE_URL } from '@/config'
//
const trainingConfig = ref({
type: '',
train_dataset_id: null,
validation_dataset_id: null,
models: ['pls']
})
//
const trainingDatasets = ref([])
const validationDatasets = ref([])
//
const isTraining = ref(false)
const trainingResult = ref(null)
//
const loadDatasets = async () => {
try {
//
const trainResponse = await axios.get(
`${API_BASE_URL}/datasets`,
{ params: { equipment_type: trainingConfig.value.type, purpose: '训练' } }
)
trainingDatasets.value = trainResponse.data
//
const valResponse = await axios.get(
`${API_BASE_URL}/datasets`,
{ params: { equipment_type: trainingConfig.value.type, purpose: '验证' } }
)
validationDatasets.value = valResponse.data
} catch (error) {
ElMessage.error('加载数据集失败')
console.error('Error loading datasets:', error)
}
}
//
watch(() => trainingConfig.value.type, (newType) => {
if (newType) {
loadDatasets()
}
})
//
const startTraining = async () => {
try {
//
if (!trainingConfig.value.type) {
ElMessage.warning('请选择装备类型')
return
}
if (!trainingConfig.value.train_dataset_id) {
ElMessage.warning('请选择训练数据集')
return
}
if (!trainingConfig.value.validation_dataset_id) {
ElMessage.warning('请选择验证数据集')
return
}
if (trainingConfig.value.models.length === 0) {
ElMessage.warning('请至少选择一个模型')
return
}
isTraining.value = true
//
const response = await axios.post(`${API_BASE_URL}/train`, trainingConfig.value)
if (response.data.error) {
throw new Error(response.data.error)
}
trainingResult.value = response.data
ElMessage.success('训练完成')
} catch (error) {
ElMessage.error(error.message || '训练失败')
console.error('Training error:', error)
} finally {
isTraining.value = false
}
}
//
const formatNumber = (value) => {
if (value === null || value === undefined) return '-'
if (typeof value === 'number') {
if (Math.abs(value) >= 1000) {
return value.toLocaleString('zh-CN', { maximumFractionDigits: 2 })
}
return value.toFixed(4)
}
return value
}
//
const getModelName = (modelType) => {
const modelNames = {
'xgboost': 'XGBoost',
'lightgbm': 'LightGBM',
'gbm': 'GBM',
'rf': 'Random Forest'
}
return modelNames[modelType] || modelType
}
//
const modelResults = computed(() => {
if (!trainingResult.value?.metrics) return []
return Object.entries(trainingResult.value.metrics).map(([model, metrics]) => ({
model,
train: metrics.train,
validation: metrics.validation
}))
})
//
const featureNameMap = {
//
'length_m': '总长(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
//
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
//
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
//
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
//
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比'
}
//
const featureImportanceData = computed(() => {
if (!trainingResult.value?.feature_importance || !trainingResult.value?.feature_names) return []
//
const data = trainingResult.value.feature_importance
.map((importance, index) => ({
feature: featureNameMap[trainingResult.value.feature_names[index]] || trainingResult.value.feature_names[index],
importance
}))
// 0
.filter(item => item.importance > 0)
//
.sort((a, b) => b.importance - a.importance)
return data
})
//
onMounted(() => {
if (trainingConfig.value.type) {
loadDatasets()
}
})
</script>
<style lang="scss" scoped>
.training-page {
padding: 20px;
.training-card {
.training-result {
margin-top: 20px;
.best-model-info {
background-color: #f5f7fa;
padding: 15px;
border-radius: 4px;
margin-bottom: 20px;
}
.feature-importance {
margin-top: 20px;
.importance-bar {
width: 100%;
background-color: #f5f7fa;
border-radius: 4px;
.importance-value {
background-color: #409eff;
color: white;
padding: 4px 8px;
border-radius: 4px;
text-align: right;
transition: width 0.3s ease;
}
}
}
}
}
}
</style>

View File

@ -0,0 +1,5 @@
module.exports = {
devServer: {
port: 8080
}
}

View File

@ -0,0 +1,12 @@
flask==2.0.1
flask-cors==3.0.10
sqlalchemy==1.4.23
pymysql==1.0.2
cryptography==3.4.7 # MySQL 8.0+ 认证需要
numpy==1.21.2
pandas==1.3.3
scikit-learn==0.24.2
tensorflow==2.6.0
urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
openpyxl==3.1.2 # 用于读取 .xlsx 文件
xlrd==2.0.1 # 用于读取 .xls 文件

View File

@ -0,0 +1,60 @@
#!/bin/bash
echo "开始安装装备成本估算系统..."
# 检查Python版本
python3 -V || {
echo "错误: 需要 Python 3.8+"
exit 1
}
# 检查Node.js版本
node -v || {
echo "错误: 需要 Node.js 14+"
exit 1
}
# 创建必要的目录
echo "创建系统目录..."
mkdir -p {logs,data,models}
# 安装后端依赖
echo "安装后端依赖..."
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
# 安装前端依赖
echo "安装前端依赖..."
cd frontend
npm install
npm run build
cd ..
# 配置文件
if [ ! -f config/.env ]; then
echo "创建配置文件..."
cp config/.env.template config/.env
echo "请修改 config/.env 中的配置"
fi
# 初始化数据库
echo "初始化数据库..."
read -p "请输入MySQL root密码: " mysqlpass
mysql -u root -p$mysqlpass < src/schema.sql
# 导入测试数据(可选)
read -p "是否导入测试数据?(y/n) " import_test_data
if [ "$import_test_data" = "y" ]; then
mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql
fi
# 设置权限
echo "设置文件权限..."
chmod +x scripts/*.sh
chmod 755 logs models data
chmod 600 config/.env
echo "安装完成!"
echo "请检查并修改 config/.env 中的配置"
echo "使用 ./scripts/start.sh 启动服务"

View File

@ -0,0 +1,51 @@
#!/bin/bash
echo "启动装备成本估算系统..."
# 检查配置文件
if [ ! -f config/.env ]; then
echo "错误: 配置文件不存在"
echo "请先运行 install.sh"
exit 1
fi
# 检查日志目录
if [ ! -d logs ]; then
mkdir -p logs
fi
# 激活虚拟环境
source venv/bin/activate
# 导出环境变量
export $(cat config/.env | xargs)
# 启动后端服务
echo "启动后端服务..."
gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \
--daemon \
--pid gunicorn.pid \
--access-logfile logs/access.log \
--error-logfile logs/error.log
# 等待后端服务启动
sleep 2
# 检查后端服务是否成功启动
if ! curl -s http://localhost:5001/api/ > /dev/null; then
echo "错误: 后端服务启动失败"
exit 1
fi
# 启动前端服务
echo "启动前端服务..."
cd frontend
npm run serve -- --port 8080 --host 0.0.0.0 &
echo $! > ../frontend.pid
cd ..
echo "服务已启动!"
echo "后端API: http://localhost:5001"
echo "前端界面: http://localhost:8080"
echo "查看后端日志: tail -f logs/access.log"
echo "查看前端日志: tail -f logs/frontend.log"

View File

@ -0,0 +1,44 @@
#!/bin/bash
echo "停止装备成本估算系统..."
# 停止后端服务
if [ -f gunicorn.pid ]; then
echo "停止后端服务..."
pid=$(cat gunicorn.pid)
kill -TERM $pid
rm gunicorn.pid
echo "后端服务已停止 (PID: $pid)"
else
echo "未找到后端服务PID文件"
# 尝试查找并停止所有gunicorn进程
pkill -f gunicorn
echo "已尝试停止所有gunicorn进程"
fi
# 停止前端服务
if [ -f frontend.pid ]; then
echo "停止前端服务..."
pid=$(cat frontend.pid)
kill -TERM $pid
rm frontend.pid
echo "前端服务已停止 (PID: $pid)"
else
echo "未找到前端服务PID文件"
# 尝试查找并停止前端服务进程
pkill -f "vite preview"
echo "已尝试停止所有前端服务进程"
fi
# 检查是否还有相关进程在运行
if pgrep -f gunicorn > /dev/null; then
echo "警告: 仍有gunicorn进程在运行"
ps aux | grep gunicorn | grep -v grep
fi
if pgrep -f "vite preview" > /dev/null; then
echo "警告: 仍有前端服务进程在运行"
ps aux | grep "vite preview" | grep -v grep
fi
echo "所有服务已停止"

View File

@ -0,0 +1 @@
# 这个文件可以为空,但必须存在

View File

@ -0,0 +1,50 @@
from flask import Flask
from flask_cors import CORS
from .routes import api_bp
from .logger import setup_logger
import os
# 获取logger
logger = setup_logger(__name__)
def create_app():
"""
创建并配置Flask应用
"""
try:
# 创建必要的目录
os.makedirs('logs', exist_ok=True)
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
logger.info("=== Server Starting ===")
logger.info("Initializing directories...")
# 创建Flask应用
app = Flask(__name__)
# 配置CORS
CORS(app)
logger.info("CORS enabled")
# 注册API蓝图
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")
# 配置数据库连接
app.config['MYSQL_HOST'] = 'localhost'
app.config['MYSQL_USER'] = 'root'
app.config['MYSQL_PASSWORD'] = '123456'
app.config['MYSQL_DB'] = 'equipment_cost_db'
logger.info("Starting server...")
return app
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
raise
if __name__ == '__main__':
app = create_app()
app.run(host='localhost', port=5001)

View File

@ -0,0 +1,342 @@
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import stats
import joblib
import os
import pandas as pd
from .feature_analysis import FeatureAnalysis
import logging
from src.model_trainer import ModelTrainer
from src.database import get_db_connection
from .logger import setup_logger
logger = setup_logger(__name__)
class CostPredictor:
def __init__(self):
self.scaler_X = StandardScaler()
self.scaler_y = StandardScaler()
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
# 添加 TensorFlow 配置
tf.config.run_functions_eagerly(False) # 启用图执行模式
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
self.load_model()
def load_model(self):
"""
加载预训练型和标准化器
"""
try:
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 创建默认模型
self._create_default_model()
# 创建预测函数
@tf.function(reduce_retracing=True, jit_compile=True)
def predict_fn(x):
return self.model(x, training=False)
self._predict_fn = predict_fn
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
def _create_default_model(self):
"""
创建默认模型并进行初始化训练
"""
# 创建输入层
inputs = tf.keras.Input(shape=(11,))
# 创建隐藏层
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
x = tf.keras.layers.Dense(32, activation='relu')(x)
# 创建输出层
outputs = tf.keras.layers.Dense(1)(x)
# 创建模型
self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 编译模型
self.model.compile(
optimizer='adam',
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_absolute_error]
)
# 创建示例数据
example_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 训练标准化器
self.scaler_X.fit(example_data)
self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
# 设置默认装备类型
self.equipment_type = '火箭炮'
def _create_example_data(self):
"""
创建示例数据来训练标准化器
"""
# 火箭炮示例数据
rocket_data = pd.DataFrame({
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
})
# 巡飞弹示例数据
missile_data = pd.DataFrame({
'length_m': [1.3, 2.5],
'width_m': [0.23, 0.6],
'height_m': [0.23, 0.6],
'weight_kg': [12.5, 135],
'max_range_km': [40, 250],
'max_speed_kmh': [180, 185],
'cruise_speed_kmh': [100, 110],
'flight_time_min': [60, 120],
'folded_length_mm': [1300, 2500],
'folded_width_mm': [230, 600],
'folded_height_mm': [230, 600]
})
# 训练标准化器
self.scaler_X.fit(rocket_data) # 使用火箭炮数据
self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
# 设置默认装备类型
self.equipment_type = '火箭炮'
def predict(self, data):
"""
使用训练好的最优模型进行预测
"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载已训练的最优模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type):
raise ValueError(f"No trained model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
return {
'predicted_cost': float(y_pred[0]),
'model_type': model_type, # 返回使用的模型类型
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Prediction error: {str(e)}")
raise
def _calculate_confidence_interval(self, prediction, confidence=0.95):
"""
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
lower = max(1000, interval[0]) # 最小值设为1000元
upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
return [lower, upper]
except Exception as e:
logging.error(f"Error calculating confidence interval: {str(e)}")
# 如果计算失败返回基于20%的简单区间
lower = max(1000, prediction * 0.8)
upper = prediction * 1.2
return [lower, upper]
def evaluate(self, y_true, y_pred):
"""
模型评估
"""
return {
'mae': float(mean_absolute_error(y_true, y_pred)),
'mse': float(mean_squared_error(y_true, y_pred)),
'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
'r2': float(r2_score(y_true, y_pred))
}
def predict_pls(self, data):
"""
使用 PLS <EFBFBD><EFBFBD><EFBFBD>型预测成本
"""
try:
logger.info(f"Starting PLS prediction for {data.get('type')}")
equipment_type = data.get('type')
# 加载 PLS 模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
raise ValueError(f"No trained PLS model found for {equipment_type}")
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
return {
'predicted_cost': float(y_pred[0]),
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"PLS prediction error: {str(e)}")
raise
def predict_all(self, data):
"""
使用所有可用模型进行预测
"""
try:
logger.info(f"Starting multi-model prediction for {data.get('type')}")
equipment_type = data.get('type')
results = {}
# 1. 获取所有激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT id, model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
""", (equipment_type,))
active_models = cursor.fetchall()
if not active_models:
raise ValueError(f"No active models found for {equipment_type}")
# 2. 使用每个模型进行预测
trainer = ModelTrainer()
for model_info in active_models:
try:
# 加载特定模型
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
logger.warning(f"Failed to load model: {model_info['model_name']}")
continue
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 保存结果
results[model_info['model_type']] = {
'predicted_cost': float(y_pred[0]),
'model_info': {
'name': model_info['model_name'],
'type': model_info['model_type'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
continue
if not results:
raise ValueError("No successful predictions from any model")
# 3. 计算综合预测结果
all_predictions = [result['predicted_cost'] for result in results.values()]
ensemble_prediction = float(np.mean(all_predictions))
prediction_std = float(np.std(all_predictions))
# 4. 返回所有结果
return {
'individual_predictions': results,
'ensemble_prediction': {
'predicted_cost': ensemble_prediction,
'standard_deviation': prediction_std,
'confidence_interval': {
'lower': float(ensemble_prediction - 1.96 * prediction_std),
'upper': float(ensemble_prediction + 1.96 * prediction_std)
}
}
}
except Exception as e:
logger.error(f"Error in multi-model prediction: {str(e)}")
raise

View File

@ -0,0 +1,155 @@
import pandas as pd
import openpyxl
from openpyxl.styles import PatternFill, Font, Alignment
from openpyxl.worksheet.datavalidation import DataValidation
import os
from .logger import setup_logger
logger = setup_logger(__name__)
def create_excel_template():
"""
创建数据模板
"""
try:
# 确保data目录存在
data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
os.makedirs(data_dir, exist_ok=True)
# 创建完整的文件路径
template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')
# 创建 Excel 写入器
writer = pd.ExcelWriter(template_path, engine='openpyxl')
# 火箭炮基本参数表
rocket_artillery_columns = [
'名称', '类型', '制造商', '口径_mm',
'反射管数量', '乘员数', '总长_m',
'宽度_m', '高度_m', '重量_kg',
'战斗重_kg', '速度_km/h', '最大射程_km',
'最小射程_km', '方向射界_度', '高低射界_度',
'火箭弹长度_m', '火箭弹重量_kg',
'火箭弹最大速度_m/s', '射速_发',
'战斗部重量_kg', '行走方式',
'结构布局', '发动机型号', '发动机参数',
'功率_hp', '行程_km', '成本_元'
]
# 巡飞弹基本参数表
loitering_munition_columns = [
'名称', '类型', '制造商', '目标类型',
'弹长_m', '弹径_mm', '翼展_m',
'重量_kg', '战斗部重量_kg',
'最大射程_km', '最大速度_m/s',
'巡航速度_kmh', '巡飞时间_min',
'战斗部类型', '发射方式',
'折叠长度_mm', '折叠宽度_mm',
'折叠高度_mm', '动力装置',
'制导体制', '成本_元'
]
# 特殊参数表
special_params_columns = [
'装备名称', # 关联字段
'参数名称',
'参数值',
'参数单位',
'参数说明'
]
# 创建工作表
pd.DataFrame(columns=rocket_artillery_columns).to_excel(
writer, sheet_name='火箭炮', index=False
)
pd.DataFrame(columns=loitering_munition_columns).to_excel(
writer, sheet_name='巡飞弹', index=False
)
pd.DataFrame(columns=special_params_columns).to_excel(
writer, sheet_name='特殊参数', index=False
)
# 获取工作簿
workbook = writer.book
# 设置火箭炮工作表格式
rocket_sheet = workbook['火箭炮']
for col in range(1, len(rocket_artillery_columns) + 1):
cell = rocket_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 设置巡飞弹工作表格式
missile_sheet = workbook['巡飞弹']
for col in range(1, len(loitering_munition_columns) + 1):
cell = missile_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 设置特殊参数工作表格式
special_sheet = workbook['特殊参数']
for col in range(1, len(special_params_columns) + 1):
cell = special_sheet.cell(row=1, column=col)
cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
cell.font = Font(bold=True)
cell.alignment = Alignment(horizontal='center')
# 添加数据验证
for sheet in [rocket_sheet, missile_sheet]:
# 数值验证
number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
number_validation.error = "请输入大于0的数值"
number_validation.errorTitle = "输入错误"
# 应用到相应列
for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']:
number_validation.add(f"{col}2:{col}1000")
sheet.add_data_validation(number_validation)
# 添加说明
rocket_sheet['AD1'] = "填写说明:"
rocket_sheet['AD2'] = "1. 所有数值必须大于0"
rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
rocket_sheet['AD4'] = "3. 成本单位为元"
missile_sheet['V1'] = "填写说明:"
missile_sheet['V2'] = "1. 所有数值必须大于0"
missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
missile_sheet['V4'] = "3. 成本单位为元"
special_sheet['G1'] = "填写说明:"
special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
special_sheet['G3'] = "2. 参数值必须包含单位"
special_sheet['G4'] = "3. 参数说明应简明扼要"
# 调整列宽
for sheet in [rocket_sheet, missile_sheet, special_sheet]:
for col in sheet.columns:
max_length = 0
column = col[0].column_letter
for cell in col:
try:
if len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except:
pass
adjusted_width = (max_length + 2)
sheet.column_dimensions[column].width = adjusted_width
# 保存文件
writer.close()
return template_path
except Exception as e:
raise Exception(f"创建模板文件失败: {str(e)}")
if __name__ == "__main__":
try:
template_path = create_excel_template()
print(f"模板文件已创建: {template_path}")
except Exception as e:
print(f"错误: {str(e)}")

View File

@ -0,0 +1,233 @@
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import os
import joblib
import pandas as pd
import numpy as np
from src.feature_analysis import FeatureAnalysis
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.model_selection import cross_val_score, LeaveOneOut
import json
import logging
from src.database.db_connection import get_db_connection
from sklearn.metrics import mean_absolute_error, mean_squared_error
from .logger import setup_logger
logger = setup_logger(__name__)
class DataPreparation:
def __init__(self):
self.feature_analyzer = FeatureAnalysis()
self.feature_scaler = StandardScaler()
self.target_scaler = StandardScaler() # 添加目标值标准化器
def prepare_training_data(self, equipment_data, equipment_type):
"""
准备训练数据
"""
try:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,直接返回
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
return {
'X': X,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
# 从原始数据中提取特征和目标值
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in equipment_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 提取目标值(成本)
try:
cost = float(item['actual_cost'])
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError, KeyError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录原始数据范围
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
# 标准化特征和目标值
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 记录标准化后的数据范围
logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 记录标准化器参数
logger.info("Feature scaler params:")
logger.info(f"Mean: {self.feature_scaler.mean_}")
logger.info(f"Scale: {self.feature_scaler.scale_}")
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
return {
'X': X_scaled,
'y': y_scaled,
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}
except Exception as e:
logger.error(f"Error in data preparation: {str(e)}")
raise Exception(f"Training error: {str(e)}")
def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
"""
准备验证数据
"""
try:
logger.info(f"Preparing validation data for {equipment_type}")
logger.info(f"Raw validation data size: {len(validation_data)}")
# 如果输入已经是 numpy 数组,直接使用
if isinstance(validation_data, np.ndarray):
X = validation_data
logger.info(f"Input is already numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
else:
X_scaled = X
logger.info(f"Preprocessed data shape: {X_scaled.shape}")
logger.info(f"Validation features shape: {X_scaled.shape}")
logger.info(f"Validation features type: {X_scaled.dtype}")
return {
'X': X_scaled,
'y': None # 验证数据可能没有标签
}
# 从原始数据中提取特征和目标值
if not feature_names:
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
for item in validation_data:
# 提取特征值
feature_values = []
for name in feature_names:
value = item.get(name)
try:
feature_values.append(float(value) if value is not None else 0.0)
except (ValueError, TypeError):
feature_values.append(0.0)
features.append(feature_values)
# 提取目标值(成本)并验证范围
try:
cost = float(item['actual_cost'])
logger.info(f"Raw cost value: {cost}")
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:
logger.warning(f"Skipping non-positive cost value: {cost}")
except (ValueError, TypeError):
logger.error(f"Invalid cost value: {item.get('actual_cost')}")
continue
# 转换为numpy数组
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录数据范围
logger.info(f"Features range: min={X.min()}, max={X.max()}")
logger.info(f"Targets range: min={y.min()}, max={y.max()}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 使用训练数据的标准化器
if scalers and 'feature_scaler' in scalers:
X_scaled = scalers['feature_scaler'].transform(X)
if 'target_scaler' in scalers:
y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel()
else:
y_scaled = y
else:
X_scaled = X
y_scaled = y
logger.info(f"Preprocessed data shape: {X_scaled.shape}")
logger.info(f"Validation features shape: {X_scaled.shape}")
logger.info(f"Validation features type: {X_scaled.dtype}")
# 记录标准化后的数据范围
logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}")
logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}")
# 确保特征维度一致
if not feature_names:
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
logger.info(f"Expected features: {len(feature_names)}")
logger.info(f"Actual features: {X_scaled.shape[1]}")
if X_scaled.shape[1] != len(feature_names):
raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}")
return {
'X': X_scaled,
'y': y_scaled # 返回标准化后的成本值
}
except Exception as e:
logger.error(f"Error in validation data preparation: {str(e)}")
logger.error(f"Feature names: {feature_names}")
logger.error(f"Equipment type: {equipment_type}")
raise Exception(f"Validation error: {str(e)}")
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
return self.feature_analyzer.calculate_derived_features(data, equipment_type)
except Exception as e:
logger.error(f"Error calculating derived features: {str(e)}")
raise Exception(f"Feature calculation error: {str(e)}")

View File

@ -0,0 +1 @@
from .db_connection import get_db_connection

View File

@ -0,0 +1,37 @@
import mysql.connector
from mysql.connector import Error
from contextlib import contextmanager
import os
from dotenv import load_dotenv
from ..logger import setup_logger
# 获取logger
logger = setup_logger(__name__)
# 加载环境变量
load_dotenv()
@contextmanager
def get_db_connection():
"""
数据库连接上下文管理器
"""
connection = None
try:
connection = mysql.connector.connect(
host=os.getenv('MYSQL_HOST', 'localhost'),
user=os.getenv('MYSQL_USER', 'root'),
password=os.getenv('MYSQL_PASSWORD', '123456'),
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
)
logger.info("Database connection established")
yield connection
except Error as e:
logger.error(f"Error connecting to MySQL: {str(e)}")
raise
finally:
if connection and connection.is_connected():
connection.close()
logger.info("Database connection closed")

View File

@ -0,0 +1,269 @@
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
import logging
from .logger import setup_logger
logger = setup_logger(__name__)
class FeatureAnalysis:
def __init__(self):
self.scaler = StandardScaler()
self.important_features = []
# 添加特征名称映射
self.feature_names_map = {
# 通用参数
'length_m': '总长(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
# 火箭炮特有参数
'firing_angle_horizontal': '方向射界(度)',
'firing_angle_vertical': '高低射界(度)',
'rocket_length_m': '火箭弹长度(m)',
'rocket_diameter_mm': '口径(mm)',
'rocket_weight_kg': '火箭弹重量(kg)',
'rate_of_fire': '射速(发/分)',
'combat_weight_kg': '战斗重量(kg)',
'speed_kmh': '速度(km/h)',
'min_range_km': '最小射程(km)',
'power_hp': '功率(hp)',
# 火箭炮衍生特征
'fire_density': '火力密度',
'mobility_index': '机动性指标',
'range_ratio': '射程比',
'power_weight_ratio': '功重比',
'volume_density': '体积密度',
# 巡飞弹特有参数
'wingspan_m': '翼展(m)',
'warhead_weight_kg': '战斗部重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
'flight_time_min': '巡飞时间(min)',
'folded_length_mm': '折叠长度(mm)',
'folded_width_mm': '折叠宽度(mm)',
'folded_height_mm': '折叠高度(mm)',
# 巡飞弹衍生特征
'warhead_ratio': '战斗部比重',
'speed_ratio': '速度比',
'range_time_ratio': '射程时间比',
'aspect_ratio': '展弦比',
'volume_density': '体积密度'
}
def get_equipment_specific_features(self, equipment_type):
"""
获取特定装备类型的特征列表
"""
# 通用参数
common_features = [
'length_m', # 总长(m)
'width_m', # 宽度(m)
'height_m', # 高度(m)
'weight_kg', # 重量(kg)
'max_range_km' # 最大射程(km)
]
if equipment_type == '火箭炮':
# 火箭炮特有参数
specific_features = [
'firing_angle_horizontal', # 方向射界(度)
'firing_angle_vertical', # 高低射界(度)
'rocket_length_m', # 火箭弹长度(m)
'rocket_diameter_mm', # 口径(mm)
'rocket_weight_kg', # 火箭弹重量(kg)
'rate_of_fire', # 射速(发/分)
'combat_weight_kg', # 战斗重量(kg)
'speed_kmh', # 速度(km/h)
'min_range_km', # 最小射程(km)
'power_hp' # 功率(hp)
]
# 火箭炮衍生特征
derived_features = [
'fire_density', # 火力密度 = 射速 * 火箭弹重量
'mobility_index', # 机动性指标 = 速度 / 战斗重量
'range_ratio', # 射程比 = 最大射程 / 最小射程
'power_weight_ratio', # 功重比 = 功率 / 战斗重量
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
else: # 巡飞弹
# 巡飞弹特有参数
specific_features = [
'wingspan_m', # 翼展(m)
'warhead_weight_kg', # 战斗部重量(kg)
'max_speed_ms', # 最大速度(m/s)
'cruise_speed_kmh', # 巡航速度(km/h)
'flight_time_min', # 巡飞时间(min)
'folded_length_mm', # 折叠长度(mm)
'folded_width_mm', # 折叠宽度(mm)
'folded_height_mm' # 折叠高度(mm)
]
# 巡飞弹衍生特征
derived_features = [
'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
'speed_ratio', # 速度比 = 巡航速度 / 最大速度
'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
]
return common_features + specific_features + derived_features
def calculate_derived_features(self, data, equipment_type):
"""
计算衍生特征
"""
try:
if equipment_type == '火箭炮':
# 火箭炮衍生特征计算
if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
else:
data['fire_density'] = 0 # 或者其他默认值
if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
else:
data['mobility_index'] = 0
if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
data['range_ratio'] = data['max_range_km'] / data['min_range_km']
else:
data['range_ratio'] = 0
if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
else:
data['power_weight_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
else: # 巡飞弹
# 巡飞弹衍生特征计算
if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
else:
data['warhead_ratio'] = 0
if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
else:
data['speed_ratio'] = 0
if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
else:
data['range_time_ratio'] = 0
if 'wingspan_m' in data.columns and 'length_m' in data.columns:
data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
else:
data['aspect_ratio'] = 0
if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
else:
data['volume_density'] = 0
return data
except Exception as e:
logger.error(f"Error calculating derived features: {str(e)}")
raise
def analyze_features(self, features, target, feature_names):
"""
分析特征重要性和相关性
"""
try:
# 转换为numpy数组
X = np.array(features)
y = np.array(target)
# 数据标准化
X_scaled = self.scaler.fit_transform(X)
# 特征重要性分析
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_scaled, y)
importances = rf.feature_importances_
# 按重要性排序,使用中文特征名
importance_indices = np.argsort(importances)[::-1]
important_features = [
{
'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
'importance': float(importances[i])
}
for i in importance_indices
]
# 相关性分析
df = pd.DataFrame(X_scaled, columns=feature_names)
correlation_matrix = df.corr().values
# 生成相关性分析数据保留2位小数
correlation_data = []
chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
for i in range(len(feature_names)):
for j in range(len(feature_names)):
correlation_data.append([
i, j,
round(correlation_matrix[i][j], 2) # 修改为保留2位小数
])
return {
'important_features': important_features,
'correlation_analysis': {
'features': chinese_feature_names, # 使用中文特征名
'matrix': correlation_data
}
}
except Exception as e:
logger.error(f"Error in feature analysis: {str(e)}")
raise
def preprocess_features(self, equipment_data, equipment_type):
"""
预处理特征数据
"""
try:
# 转换为 DataFrame
df = pd.DataFrame(equipment_data)
# 计算衍生特征
df = self.calculate_derived_features(df, equipment_type)
# 处理缺失值
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
# 转换为数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
# 使用新的方式填充缺失值
mean_value = df[col].mean()
df[col] = df[col].fillna(mean_value)
logger.info(f"Preprocessed data shape: {df.shape}")
return df
except Exception as e:
logger.error(f"Error preprocessing features: {str(e)}")
raise Exception(f"Feature preprocessing error: {str(e)}")

View File

@ -0,0 +1,255 @@
import pandas as pd
from .logger import setup_logger
from src.database.db_connection import get_db_connection
logger = setup_logger(__name__)
def import_training_data(excel_file):
"""
从Excel导入训练数据到数据库
"""
try:
# 读取所有sheet
rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
special_df = pd.read_excel(excel_file, sheet_name='特殊参数')
# 记录所有装备名称,用于后续检查
equipment_names = set()
with get_db_connection() as conn:
cursor = conn.cursor()
# 1. 先导入火箭炮数据
logger.info("开始导入火箭炮数据...")
for _, row in rocket_df.iterrows():
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '火箭炮'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logger.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None,
row['宽度_m'] if pd.notna(row['宽度_m']) else None,
row['高度_m'] if pd.notna(row['高度_m']) else None,
row['重量_kg'] if pd.notna(row['重量_kg']) else None,
row['最大射程_km'] if pd.notna(row['最大射程_km']) else None
))
# 插入火箭炮特有参数
cursor.execute("""
INSERT INTO rocket_artillery_params
(equipment_id, firing_angle_horizontal, firing_angle_vertical,
rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp,
travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
row['高低射界_度'] if pd.notna(row['高低射界_度']) else None,
row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None,
row['口径_mm'] if pd.notna(row['口径_mm']) else None,
row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None,
row['射速_发'] if pd.notna(row['射速_发']) else None,
row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None,
row['速度_km/h'] if pd.notna(row['速度_km/h']) else None,
row['最小射程_km'] if pd.notna(row['最小射程_km']) else None,
row['行走方式'] if pd.notna(row['行走方式']) else None,
row['结构布局'] if pd.notna(row['结构布局']) else None,
row['发动机型号'] if pd.notna(row['发动机型号']) else None,
row['发动机参数'] if pd.notna(row['发动机参数']) else None,
row['功率_hp'] if pd.notna(row['功率_hp']) else None,
row['行程_km'] if pd.notna(row['行程_km']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, row['成本_元']))
logger.info("火箭炮数据导入完成")
# 2. 导入巡飞弹数据
logger.info("开始导入巡飞弹数据...")
for index, row in missile_df.iterrows():
# 记录每行数据的空值情况
null_values = row[row.isna()].index.tolist()
if null_values:
logger.info(f"{index + 2} 中的空值字段: {null_values}")
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '巡飞弹'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
if existing_equipment:
logger.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入")
continue
# 插入基本信息
cursor.execute("""
INSERT INTO equipment (name, type, manufacturer)
VALUES (%s, %s, %s)
""", (
row['名称'],
'巡飞弹',
row['制造商'] if pd.notna(row['制造商']) else None
))
equipment_id = cursor.lastrowid
# 插入通用参数
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
float(row['重量_kg']) if pd.notna(row['重量_kg']) else None,
float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None
))
# 插入巡飞弹特有参数
cursor.execute("""
INSERT INTO loitering_munition_params
(equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None,
float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None,
float(row['巡航速度_km/h']) if pd.notna(row['巡航速度_km/h']) else None,
float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None,
str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None,
str(row['发射方式']) if pd.notna(row['发射方式']) else None,
float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None,
float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None,
float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None,
str(row['动力装置']) if pd.notna(row['动力装置']) else None,
str(row['制导体制']) if pd.notna(row['制导体制']) else None
))
# 插入成本数据
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
""", (equipment_id, float(row['成本_元'])))
logger.info("巡飞弹数据导入完成")
# 提交之前的更改并关闭原有游标
cursor.close()
conn.commit()
# 3. 导入特殊参数
logger.info("开始导入特殊参数...")
for index, row in special_df.iterrows():
equipment_name = row['装备名称']
param_name = row['参数名称']
logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")
if equipment_name not in equipment_names:
logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
continue
# 获取装备ID - 使用新的游标
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipment WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
if not result:
logger.warning(f"未找到装备: {equipment_name}")
continue
equipment_id = result[0]
logger.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor:
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
if exists:
logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
continue
# 插入新的参数 - 使用新的游标
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor:
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s)
""", (
equipment_id,
param_name,
param_value,
param_unit,
param_desc
))
logger.debug(f"成功插入参数记录")
# 最终提交
conn.commit()
logger.info("特殊参数导入完成")
logger.info("所有数据导入成功")
return True
except Exception as e:
logger.error(f"Error importing data: {str(e)}")
raise
if __name__ == "__main__":
try:
excel_file = 'data/equipment_data_20241108.xlsx'
import_training_data(excel_file)
logger.info("All data imported successfully")
except Exception as e:
logger.error(f"Import failed: {str(e)}")

View File

@ -0,0 +1,319 @@
/*
使
1.
2.
3.
*/
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-- 插入巡飞弹技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES (
1, -- 终结者巡飞弹
0.56,
0.15,
0.20,
2.72,
160.93,
96.56,
24,
15,
'破片杀伤战斗部',
'凭自身动力起飞',
560,
150,
200
);
-- 插入火箭炮技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES (
2, -- 胜利-2火箭炮
10,
2.5,
3.34,
15000,
23
);
-- 插入成本数据(示例数据)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 1000000), -- 终结者巡飞弹成本
(2, 5000000); -- 胜利-2火箭炮成本
-- 插入更多巡飞弹变体数据用于训练
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-- 插入变体技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 终结者-A稍大型号
(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
-- 终结者-B稍小型号
(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
-- 终结者-C标准型号的改进版
(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-- 插入变体成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(3, 1100000), -- 终结者-A成本较高
(4, 900000), -- 终结者-B成本较低
(5, 1050000); -- 终结者-C成本中等
-- 添加更多巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-- 添加它们的技术参数
INSERT INTO technical_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_speed_kmh,
cruise_speed_kmh,
max_range_km,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
) VALUES
-- 哈比
(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
-- 游荡者
(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
-- 凤凰
(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
-- 弹簧刀
(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
-- 彩虹-4
(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-- 添加成本数据
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(6, 800000), -- 哈比
(7, 500000), -- 游荡者
(8, 450000), -- 凤凰
(9, 480000), -- 弹簧刀
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('BM-21', '火箭炮', '俄罗斯'),
('SR5', '火箭炮', '中国'),
('HIMARS', '火箭炮', '美国'),
('LAR-160', '火箭炮', '以色列'),
('T-122', '火箭炮', '土耳其'),
('RM-70', '火箭炮', '捷克'),
('ASTROS II', '火箭炮', '巴西');
-- 火箭炮通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- BM-21
(1, 7.35, 2.4, 3.1, 13700, 20.4),
-- SR5
(2, 10.2, 2.8, 3.2, 28500, 70),
-- HIMARS
(3, 7.0, 2.4, 3.2, 16250, 70),
-- LAR-160
(4, 6.7, 2.5, 2.8, 15000, 45),
-- T-122
(5, 7.2, 2.5, 2.9, 18000, 40),
-- RM-70
(6, 7.5, 2.5, 3.0, 17200, 20.3),
-- ASTROS II
(7, 8.0, 2.7, 3.1, 24500, 90);
-- 火箭炮特有参数
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
firing_angle_vertical,
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire,
combat_weight_kg,
speed_kmh,
min_range_km,
mobility_type,
structure_layout,
engine_model,
engine_params,
power_hp,
travel_range_km
) VALUES
-- BM-21
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
-- SR5
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
-- HIMARS
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
-- LAR-160
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
-- T-122
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
-- RM-70
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
-- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-- 巡飞弹数据
INSERT INTO equipment (name, type, manufacturer) VALUES
('Hero-120', '巡飞弹', '以色列'),
('Switchblade 600', '巡飞弹', '美国'),
('Warmate', '巡飞弹', '波兰'),
('CH-901', '巡飞弹', '中国'),
('HAROP', '巡飞弹', '以色列'),
('Coyote', '巡飞弹', '美国'),
('WS-43', '巡飞弹', '中国');
-- 巡飞弹通用参数
INSERT INTO common_params (
equipment_id,
length_m,
width_m,
height_m,
weight_kg,
max_range_km
) VALUES
-- Hero-120
(8, 1.3, 0.23, 0.23, 12.5, 40),
-- Switchblade 600
(9, 1.3, 0.22, 0.22, 15.0, 40),
-- Warmate
(10, 1.1, 0.15, 0.15, 5.7, 15),
-- CH-901
(11, 1.2, 0.18, 0.18, 9.0, 20),
-- HAROP
(12, 2.5, 0.43, 0.43, 135, 1000),
-- Coyote
(13, 0.9, 0.12, 0.12, 5.9, 20),
-- WS-43
(14, 1.8, 0.35, 0.35, 20, 60);
-- 巡飞弹特有参数
INSERT INTO loitering_munition_params (
equipment_id,
wingspan_m,
warhead_weight_kg,
max_speed_ms,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm,
power_system,
guidance_system
) VALUES
-- Hero-120
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
-- Switchblade 600
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
-- Warmate
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
-- CH-901
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
-- HAROP
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
-- Coyote
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
-- WS-43
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-- 火箭炮
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000), -- ASTROS II
-- 巡飞弹
(8, 150000), -- Hero-120
(9, 180000), -- Switchblade 600
(10, 80000), -- Warmate
(11, 100000), -- CH-901
(12, 850000), -- HAROP
(13, 75000), -- Coyote
(14, 120000); -- WS-43
-- 创建初始数据集
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-- 关联装备到数据集
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 火箭炮训练集
(1, 1), (1, 2), (1, 3), (1, 4),
-- 巡飞弹训练集
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
-- 火箭炮验证集
(3, 5), (3, 6), (3, 7),
-- 巡飞弹验证集
(4, 13), (4, 14);

View File

@ -0,0 +1,33 @@
import logging
import os
from datetime import datetime
def setup_logger(name):
"""
创建并配置logger
"""
# 创建logger
logger = logging.getLogger(name)
# 如果logger已经有处理器直接返回
if logger.handlers:
return logger
# 设置日志级别
logger.setLevel(logging.INFO)
# 确保日志目录存在
os.makedirs('logs', exist_ok=True)
# 创建文件处理器
file_handler = logging.FileHandler('logs/api.log')
file_handler.setLevel(logging.INFO)
# 创建格式化器
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# 添加处理器
logger.addHandler(file_handler)
return logger

View File

@ -0,0 +1,612 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
import xgboost as xgb
import lightgbm as lgb
import logging
import joblib
import os
from src.feature_analysis import FeatureAnalysis
from datetime import datetime
import json
from src.database import get_db_connection
from src.data_preparation import DataPreparation
from sklearn.cross_decomposition import PLSRegression
from .logger import setup_logger
logger = setup_logger(__name__)
class ModelTrainer:
def __init__(self):
"""
初始化 ModelTrainer
"""
self.models = {
'xgboost': self._create_xgboost_model(),
'lightgbm': self._create_lightgbm_model(),
'gbm': self._create_gbm_model(),
'rf': self._create_rf_model(),
'pls': self._create_pls_model()
}
self.best_model = None
self.imputer = SimpleImputer(strategy='mean')
self.feature_scaler = None
self.target_scaler = None
self.equipment_type = None
self.feature_analyzer = FeatureAnalysis()
def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
"""
训练模型并返回评估结果
"""
try:
self.equipment_type = equipment_type
logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
results = {}
best_score = -float('inf')
best_model_info = None
# 首先训练 PLS 模型
logger.info("Training pls...")
pls_model = self.models['pls']
pls_model.fit(X_train, y_train)
pls_metrics = self._calculate_metrics(
pls_model,
X_train, y_train,
X_val, y_val
)
results['pls'] = pls_metrics
# 训练其他机器学习模型
for model_name in model_names:
if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
continue
if model_name not in self.models:
logger.warning(f"Unknown model: {model_name}")
continue
logger.info(f"Training {model_name}...")
model = self.models[model_name]
# 训练模型
model.fit(X_train, y_train)
# 计算评估指标
metrics = self._calculate_metrics(
model,
X_train, y_train,
X_val, y_val
)
results[model_name] = metrics
# 更新最佳模型(只在机器学习模型中比较)
if metrics['validation']['r2'] > best_score:
best_score = metrics['validation']['r2']
best_model_info = {
'type': model_name,
'r2': metrics['validation']['r2'],
'mae': metrics['validation']['mae'],
'rmse': metrics['validation']['rmse']
}
self.best_model = model
# 保存最佳模型和 PLS 模型
if equipment_type and best_model_info:
self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
return {
'metrics': results,
'best_model': best_model_info
}
except Exception as e:
logger.error(f"Error in model training: {str(e)}")
raise
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
"""
计算模型评估指标
"""
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
# 训练集评估
train_pred = model.predict(X_train)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
else:
y_train_orig = y_train
# 记录预测范围
logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
train_metrics = {
'r2': r2_score(y_train_orig, train_pred),
'mae': mean_absolute_error(y_train_orig, train_pred),
'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
}
# 验证集评估
if X_val is not None and y_val is not None:
val_pred = model.predict(X_val)
# 如果使用了标准化,需要转换回原始范围
if hasattr(self, 'target_scaler'):
val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
else:
y_val_orig = y_val
# 记录预测范围
logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
val_metrics = {
'r2': r2_score(y_val_orig, val_pred),
'mae': mean_absolute_error(y_val_orig, val_pred),
'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
}
else:
# 使用交叉验证
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
val_metrics = {
'r2': cv_scores.mean(),
'mae': None,
'rmse': None
}
return {
'train': train_metrics,
'validation': val_metrics
}
def _create_xgboost_model(self):
"""
创建 XGBoost 模型增强正则化
"""
return xgb.XGBRegressor(
n_estimators=50, # 减少树的数量
learning_rate=0.05, # 学习率
max_depth=3, # 减小树的深
min_child_weight=3, # 增加节点权重
subsample=0.7, # 减小样本采样比例
colsample_bytree=0.7, # 减小特征采样比例
reg_alpha=0.1, # L1 正则化
reg_lambda=1, # L2 正则化
random_state=42
)
def _create_lightgbm_model(self):
"""
创建 LightGBM 模型增强正则化
"""
return lgb.LGBMRegressor(
n_estimators=50,
learning_rate=0.05,
max_depth=3,
num_leaves=7,
min_data_in_leaf=3,
min_sum_hessian_in_leaf=1e-3,
subsample=0.7,
colsample_bytree=0.7,
reg_alpha=0.1,
reg_lambda=1,
random_state=42,
verbose=-1
)
def _create_gbm_model(self):
"""
创建 GBM 模型增强正则化以减轻过拟合
"""
return GradientBoostingRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42,
subsample=0.8,
min_samples_split=3,
min_samples_leaf=2
)
def _create_rf_model(self):
"""
创建随机森林模型针对小样本数据调整参数
"""
return RandomForestRegressor(
n_estimators=100,
max_depth=3,
random_state=42,
min_samples_split=3,
min_samples_leaf=2
)
def _create_pls_model(self):
"""
创建 PLS 模型优化参数配置
"""
return PLSRegression(
n_components=2, # 减少主成分数量从5减到2
scale=True, # 保持数据标准化
max_iter=500, # 减少最大迭代次数,避免过拟合
tol=1e-6 # 降低收敛精度,避免过拟合
)
def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
"""
保存最佳模型和 PLS 模型
"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 1. 保存最佳机器学习模型
model_path = f'{model_dir}/{equipment_type}_{timestamp}'
if isinstance(self.best_model, xgb.XGBRegressor):
self.best_model.save_model(f'{model_path}.json')
model_format = 'json'
else:
joblib.dump(self.best_model, f'{model_path}.joblib')
model_format = 'joblib'
# 2. 保存 PLS 模型
pls_model = self.models['pls']
pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
joblib.dump(pls_model, pls_path)
# 3. 保存标准化器
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
joblib.dump({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
logger.info(f"Saved best model to {model_path}.{model_format}")
logger.info(f"Saved PLS model to {pls_path}")
logger.info(f"Saved scalers to {scaler_path}")
# 4. 更新数据库中的模型记录
with get_db_connection() as conn:
cursor = conn.cursor()
# 将所有模型设置为非激活
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s
""", (equipment_type,))
# 获取 PLS 模型的评估指标
pls_metrics = self._calculate_metrics(
self.models['pls'],
X_train,
y_train,
X_val,
y_val
)
# 保存最佳机器学习模型记录
self.best_model.equipment_type = equipment_type # 设置装备类型
ml_feature_importance = self._get_feature_importance(self.best_model)
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}", # model_name
best_model_info['type'], # model_type
equipment_type, # equipment_type
f"{model_path}.{model_format}", # model_path
scaler_path, # scaler_path
best_model_info['r2'], # r2_score
best_model_info['mae'], # mae
best_model_info['rmse'], # rmse
json.dumps(ml_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
))
# 保存 PLS 模型记录
pls_model.equipment_type = equipment_type # 设置装备类型
pls_feature_importance = self._get_feature_importance(pls_model)
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path, scaler_path,
r2_score, mae, rmse, feature_importance, training_data_size,
training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
""", (
f"{equipment_type}_{timestamp}_pls", # model_name
'pls', # model_type
equipment_type, # equipment_type
pls_path, # model_path
scaler_path, # scaler_path
float(pls_metrics['validation']['r2']), # r2_score
float(pls_metrics['validation']['mae']), # mae
float(pls_metrics['validation']['rmse']), # rmse
json.dumps(pls_feature_importance), # feature_importance
len(X_train), # training_data_size
'system' # created_by
))
conn.commit()
except Exception as e:
logger.error(f"Error saving models: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
def load_model(self, equipment_type, model_type='ml'):
"""
加载已训练的模型
"""
try:
logger.info(f"Loading {model_type} model for {equipment_type}")
# 从数据库获取激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 构建查询语句
if model_type == 'pls':
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type = 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
else:
query = """
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type != 'pls'
AND is_active = TRUE
LIMIT 1
"""
params = (equipment_type,)
# 记录查询信息
logger.info(f"Executing query: {query}")
logger.info(f"Query params: {params}")
cursor.execute(query, params)
model_record = cursor.fetchone()
# 记录查询结果
if model_record:
logger.info(f"Found model record: {model_record}")
else:
logger.warning(f"No active model found for type {model_type}")
return False
# 检查文件是否存在
logger.info(f"Checking model file: {model_record['model_path']}")
logger.info(f"Checking scaler file: {model_record['scaler_path']}")
if not os.path.exists(model_record['model_path']):
logger.error(f"Model file not found: {model_record['model_path']}")
raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
if not os.path.exists(model_record['scaler_path']):
logger.error(f"Scaler file not found: {model_record['scaler_path']}")
raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
# 加载模型文件
logger.info(f"Loading model from {model_record['model_path']}")
if model_type == 'pls':
self.best_model = joblib.load(model_record['model_path'])
logger.info("Loaded PLS model")
else:
if model_record['model_type'] == 'xgboost':
self.best_model = xgb.XGBRegressor()
self.best_model.load_model(model_record['model_path'])
logger.info("Loaded XGBoost model")
else:
self.best_model = joblib.load(model_record['model_path'])
logger.info(f"Loaded {model_record['model_type']} model")
# 加载标准化器
logger.info(f"Loading scalers from {model_record['scaler_path']}")
scalers = joblib.load(model_record['scaler_path'])
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
logger.info("Loaded scalers successfully")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
logger.error(f"Detailed traceback:", exc_info=True)
return False
def predict(self, features):
"""
使用加载的模型进行预测
"""
try:
if not self.best_model:
raise ValueError("No model loaded")
if not self.feature_scaler:
raise ValueError("Feature scaler not loaded")
if not self.target_scaler:
raise ValueError("Target scaler not loaded")
logger.info("Starting prediction")
logger.info(f"Input features shape: {features.shape}")
logger.info(f"Input features: \n{features}")
# 处理缺失值
features_filled = np.array(features, dtype=float)
features_filled[np.isnan(features_filled)] = 0
features_filled = np.nan_to_num(features_filled, 0)
logger.info(f"Filled features: \n{features_filled}")
# 标准化特征
X = self.feature_scaler.transform(features_filled)
logger.info(f"Transformed features shape: {X.shape}")
logger.info(f"Transformed features: \n{X}")
# 预测
y_pred_scaled = self.best_model.predict(X)
logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
logger.info(f"Scaled prediction: {y_pred_scaled}")
# <20><>标准化
y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
logger.info(f"Final prediction shape: {y_pred.shape}")
logger.info(f"Final prediction: {y_pred}")
# 记录标准化器的参数
logger.info("Target scaler params:")
logger.info(f"Mean: {self.target_scaler.mean_}")
logger.info(f"Scale: {self.target_scaler.scale_}")
return y_pred.ravel()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
def _get_feature_importance(self, model):
"""
获取特征重要性
"""
try:
if not model:
return {}
# 获取特征名称
feature_analyzer = FeatureAnalysis()
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
# 获取特<E58F96><E789B9><EFBFBD>重要性
if hasattr(model, 'feature_importances_'):
importances = model.feature_importances_
elif hasattr(model, 'coef_'):
if len(model.coef_.shape) > 1: # 如果是二维数组
importances = np.abs(model.coef_[0]) # 取第一行
else:
importances = np.abs(model.coef_)
else:
return {}
# 创建特征重要性字典
importance_dict = {}
for name, importance in zip(feature_names, importances):
importance_dict[name] = float(importance) # 确保转换为 Python 标量
# 按重要性降序排序
sorted_dict = dict(sorted(
importance_dict.items(),
key=lambda x: x[1],
reverse=True
))
# 过滤掉重要性为0的特征
return {k: v for k, v in sorted_dict.items() if v > 0}
except Exception as e:
logger.error(f"Error getting feature importance: {str(e)}")
return {}
def _calculate_confidence_interval(self, prediction, confidence=0.95):
"""
计算预测值的置信区间
"""
try:
# 使用预测值的20%作为标准差(增加不确定性)
std = abs(prediction) * 0.2
# 计算置信区间
from scipy import stats
interval = stats.norm.interval(confidence, loc=prediction, scale=std)
# 确保区间值为正数且合理
lower = max(1000, interval[0]) # 最小值设为1000元
upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
return [lower, upper]
except Exception as e:
logger.error(f"Error calculating confidence interval: {str(e)}")
# 如果计算失败返回基于20%的简单区间
lower = max(1000, prediction * 0.8)
upper = prediction * 1.2
return [lower, upper]
def get_model_type(self):
"""
获取当前模型的类型
"""
if isinstance(self.best_model, xgb.XGBRegressor):
return 'xgboost'
elif isinstance(self.best_model, lgb.LGBMRegressor):
return 'lightgbm'
elif isinstance(self.best_model, GradientBoostingRegressor):
return 'gbm'
elif isinstance(self.best_model, RandomForestRegressor):
return 'rf'
else:
return 'unknown'
def _get_pls_feature_importance(self):
"""
获取 PLS 模型的特征重要性
"""
try:
if not self.models['pls']:
return {}
# 获取特征名称
feature_analyzer = FeatureAnalysis()
feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
# 获取 PLS 模型的系数作为特征重要性
pls_model = self.models['pls']
if hasattr(pls_model, 'coef_'):
# 使用绝对值作为重要性指标
importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
else:
return {}
# 创建特征重要性字典
importance_dict = {}
for name, importance in zip(feature_names, importances):
importance_dict[name] = float(importance) # 确保转换为 Python 标量
# 按重要性降序排序
sorted_dict = dict(sorted(
importance_dict.items(),
key=lambda x: x[1],
reverse=True
))
# 过滤掉重要性为0的特征
return {k: v for k, v in sorted_dict.items() if v > 0}
except Exception as e:
logger.error(f"Error getting PLS feature importance: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
return {}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,140 @@
-- 如果数据库已存在则删除
DROP DATABASE IF EXISTS equipment_cost_db;
-- 创建数据库
CREATE DATABASE equipment_cost_db
DEFAULT CHARACTER SET utf8mb4
COLLATE utf8mb4_unicode_ci;
-- 使用数据库
USE equipment_cost_db;
-- 装备基本信息表
CREATE TABLE equipment (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100), -- 名称
type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
manufacturer VARCHAR(100), -- 制造商
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 通用参数表
CREATE TABLE common_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
length_m FLOAT, -- 总长(m)
width_m FLOAT, -- 宽度(m)
height_m FLOAT, -- 高度(m)
weight_kg FLOAT, -- 重量(kg)
max_range_km FLOAT, -- 最大射程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 火箭炮特有参数表
CREATE TABLE rocket_artillery_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
firing_angle_horizontal FLOAT, -- 方向射界(度)
firing_angle_vertical FLOAT, -- 高低射界(度)
rocket_length_m FLOAT, -- 火箭弹长度(m)
rocket_diameter_mm FLOAT, -- 弹体直径(mm)
rocket_weight_kg FLOAT, -- 火箭弹重量(kg)
rate_of_fire FLOAT, -- 射速(发/分钟)
combat_weight_kg FLOAT, -- 战斗重量(kg)
speed_kmh FLOAT, -- 速度(km/h)
min_range_km FLOAT, -- 最小射程(km)
mobility_type VARCHAR(50), -- 行走方式
structure_layout VARCHAR(100), -- 结构布局
engine_model VARCHAR(100), -- 发动机型号
engine_params TEXT, -- 发动机参数
power_hp FLOAT, -- 功率(hp)
travel_range_km FLOAT, -- 行程(km)
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 巡飞弹特有参数表
CREATE TABLE loitering_munition_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
wingspan_m FLOAT, -- 翼展(m)
warhead_weight_kg FLOAT, -- 战斗部重量(kg)
max_speed_ms FLOAT, -- 最大速度(m/s)
cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
flight_time_min FLOAT, -- 巡飞时间(min)
warhead_type VARCHAR(50), -- 战斗部类型
launch_mode VARCHAR(50), -- 发射方式
folded_length_mm FLOAT, -- 折叠长度(mm)
folded_width_mm FLOAT, -- 折叠宽度(mm)
folded_height_mm FLOAT, -- 折叠高度(mm)
power_system VARCHAR(100), -- 动力装置
guidance_system VARCHAR(100), -- 制导体制
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 成本数据表
CREATE TABLE cost_data (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
actual_cost DECIMAL(15,2), -- 实际成本(元)
predicted_cost DECIMAL(15,2), -- 预测成本(元)
prediction_date TIMESTAMP, -- 预测日期
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 特殊参数表
CREATE TABLE custom_params (
id INT AUTO_INCREMENT PRIMARY KEY,
equipment_id INT,
param_name VARCHAR(100), -- 参数名称
param_value VARCHAR(255), -- 参数值
param_unit VARCHAR(50), -- 参数单位
description TEXT, -- 参数说明
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_equipment_type ON equipment(type);
CREATE INDEX idx_equipment_name ON equipment(name);
CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-- 数据集表
CREATE TABLE datasets (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(100) NOT NULL, -- 数据集名称
description TEXT, -- 数据集描述
equipment_type VARCHAR(50) NOT NULL, -- 装备类型
purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
-- 数据集-装备关联表
CREATE TABLE dataset_equipment (
dataset_id INT NOT NULL,
equipment_id INT NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
);
-- 训练模型表
CREATE TABLE trained_models (
id INT AUTO_INCREMENT PRIMARY KEY,
model_name VARCHAR(100) NOT NULL, -- 模型名称
model_type VARCHAR(50) NOT NULL, -- 模型类型
equipment_type VARCHAR(50) NOT NULL, -- 装备类型
model_path VARCHAR(255) NOT NULL, -- 模型文件路径
scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径
r2_score FLOAT, -- R²分数
mae FLOAT, -- 平均绝对误差
rmse FLOAT, -- 均方根误差
feature_importance JSON, -- 特征重要性
training_data_size INT, -- 训练数据量
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
created_by VARCHAR(50) -- 创建者
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 添加索引
CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX idx_model_active ON trained_models(is_active);

View File

@ -0,0 +1,191 @@
import requests
import json
import logging
from datetime import datetime
import os
import sys
# 添加项目根目录到 Python 路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 配置基本日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('logs/test_api.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def test_api_endpoints():
"""
测试 API 各个端点
"""
base_url = 'http://localhost:5001/api'
try:
# 1. 测试根路由
logger.info("\n1. 测试 API 根路由")
response = requests.get(f'{base_url}/')
print_response(response, "API 根路由")
# 2. 测试机器学习预测接口
logger.info("\n2. 测试机器学习预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict',
json=predict_data
)
print_response(response, "机器学习预测")
# 3. 测试 PLS 预测接口
logger.info("\n3. 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print_response(response, "PLS 预测")
# 4. 测试特征分析接口
logger.info("\n4. 测试特征分析接口")
analysis_data = {
"dataset_id": 1, # 假设数据集 ID 为 1
"equipment_type": "巡飞弹"
}
response = requests.post(
f'{base_url}/analyze-features',
json=analysis_data
)
print_response(response, "特征分析")
# 5. 测试机器学习模型训练接口
logger.info("\n5. 测试机器学习模型训练接口")
training_data = {
"type": "巡飞弹",
"train_dataset_id": 1, # 训练数据集 ID
"validation_dataset_id": 2, # 验证数据集 ID
"models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型
}
response = requests.post(
f'{base_url}/train',
json=training_data
)
print_response(response, "模型训练")
# 6. 测试数据集相关接口
logger.info("\n6. 测试数据集相关接口")
# 6.1 获取数据集列表
response = requests.get(f'{base_url}/datasets')
print_response(response, "获取数据集列表")
# 6.2 获取可用的装备列表
response = requests.get(f'{base_url}/data')
equipment_data = response.json()
# 获取巡飞弹类型的装备ID
available_equipment_ids = []
if 'loitering_munition' in equipment_data:
available_equipment_ids = [
item['id']
for item in equipment_data['loitering_munition']
if item['id'] is not None
][:3] # 取前3个可用的ID
if not available_equipment_ids:
logger.warning("没有找到可用的装备ID跳过创建数据集测试")
else:
# 6.3 创建新数据集
new_dataset = {
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": available_equipment_ids
}
logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}")
response = requests.post(
f'{base_url}/datasets',
json=new_dataset
)
print_response(response, "创建数据集")
# 7. 测试模型相关接口
logger.info("\n7. 测试模型相关接口")
# 7.1 获取模型列表
response = requests.get(f'{base_url}/models')
print_response(response, "获取模型列表")
# 7.2 获取最新模型
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 8. 测试多模型预测接口
logger.info("\n8. 测试多模型预测接口")
response = requests.post(
f'{base_url}/predict/all',
json=predict_data
)
print_response(response, "多模型预测")
logger.info("所有测试完成")
except requests.exceptions.RequestException as e:
logger.error(f"API 请求错误: {str(e)}")
except Exception as e:
logger.error(f"测试过程中出现错误: {str(e)}")
def print_response(response, endpoint_name):
"""
打印响应结果
"""
try:
logger.info(f"\n=== {endpoint_name} 测试结果 ===")
logger.info(f"状态码: {response.status_code}")
if response.status_code == 200:
result = response.json()
logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}")
else:
logger.error(f"错误响应:\n{response.text}")
except Exception as e:
logger.error(f"处理响应时出错: {str(e)}")
if __name__ == "__main__":
try:
# 确保日志目录存在
os.makedirs('logs', exist_ok=True)
logger.info(f"=== API 测试开始 - {datetime.now()} ===")
test_api_endpoints()
logger.info(f"=== API 测试结束 - {datetime.now()} ===")
except Exception as e:
logger.error(f"测试执行失败: {str(e)}")

View File

34
deploy/setup.md Normal file
View File

@ -0,0 +1,34 @@
# 安装说明
## 1. 系统要求
- Linux 服务器 (推荐 Ubuntu 20.04+)
- Python 3.8+
- MySQL 8.0+
## 2. 安装部署
```bash
# 解压部署包
tar -xzf equipment_cost_prediction.tar.gz
cd equipment_cost_prediction
# 运行安装脚本
bash scripts/install.sh
# 修改配置文件
vim config/.env
# 启动服务
bash scripts/start.sh
```
## 3. 验证部署
```bash
# 检查服务状态
curl http://localhost:5001/api/
# 检查日志
tail -f logs/api.log
```

663
docs/deploy/api.md Normal file
View File

@ -0,0 +1,663 @@
# 装备成本估算系统 API 文档
这个 API 文档提供了完整的接口说明,包括:
- 每个端点的详细描述
- 请求和响应的具体示例
- 清晰的参数格式要求
- 统一的错误处理说明
- 重要的注意事项
文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
## 基本信息
- 基础URL: `http://localhost:5001/api`
- 版本: 1.0.0
- 响应格式: JSON
## API 端点列表
### 1. 获取 API 信息
获取 API 版本信息和可用端点列表。
- **URL**: `/`
- **方法**: `GET`
- **响应示例**:
json
{
"name": "装备成本估算系统 API",
"version": "1.0.0",
"endpoints": {
"predict": {
"url": "/api/predict",
"method": "POST",
"description": "成本预测"
},
"analyze-features": {
"url": "/api/analyze-features",
"method": "POST",
"description": "特征分析"
},
"train": {
"url": "/api/train",
"method": "POST",
"description": "模型训练"
},
"evaluate": {
"url": "/api/evaluate",
"method": "POST",
"description": "模型评估"
}
}
}
### 2. 单模型预测
使用当前激活的最优模型进行成本预测。
- **URL**: `/predict`
- **方法**: `POST`
- **请求体示例** (巡飞弹):
```json
{
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
```
- **响应示例**:
```json
{
"predicted_cost": 150000.0,
"model_info": {
"type": "xgboost",
"name": "巡飞弹_20241111_model",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
}
```
### 3. PLS 模型预测
使用 PLS 回归模型进行预测。
- **URL**: `/pls/predict`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"predicted_cost": 148000.0,
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
```
### 4. 多模型预测
使用所有激活的模型进行预测并返回综合结果。
- **URL**: `/predict/all`
- **方法**: `POST`
- **请求体**: 与单模型预测相同
- **响应示例**:
```json
{
"individual_predictions": {
"xgboost": {
"predicted_cost": 150000.0,
"model_info": {
"name": "巡飞弹_xgboost_model",
"type": "xgboost",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"confidence_interval": {
"lower": 135000.0,
"upper": 165000.0
}
},
"pls": {
"predicted_cost": 148000.0,
"model_info": {
"name": "巡飞弹_pls_model",
"type": "pls",
"r2_score": 0.92,
"mae": 5500.0,
"rmse": 8000.0
},
"confidence_interval": {
"lower": 133000.0,
"upper": 163000.0
}
}
},
"ensemble_prediction": {
"predicted_cost": 149000.0,
"standard_deviation": 1414.21,
"confidence_interval": {
"lower": 146228.15,
"upper": 151771.85
}
}
}
```
### 5. 特征分析
分析数据集中特征的重要性和相关性。
- **URL**: `/analyze-features`
- **方法**: `POST`
- **请求体示例**:
```json
{
"dataset_id": 1,
"equipment_type": "巡飞弹"
}
```
- **响应示例**:
```json
{
"important_features": [
{
"name": "最大射程(km)",
"importance": 0.35
},
{
"name": "重量(kg)",
"importance": 0.25
}
],
"correlation_analysis": {
"features": ["最大射程(km)", "重量(kg)"],
"matrix": [[1.0, 0.8], [0.8, 1.0]]
}
}
```
### 6. 模型训练
训练新的模型。
- **URL**: `/train`
- **方法**: `POST`
- **请求体示例**:
```json
{
"type": "巡飞弹",
"train_dataset_id": 1,
"validation_dataset_id": 2,
"models": ["xgboost", "lightgbm", "rf"]
}
```
- **响应示例**:
```json
{
"metrics": {
"xgboost": {
"train": {
"r2": 0.95,
"mae": 5000.0,
"rmse": 7500.0
},
"validation": {
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
},
"best_model": {
"type": "xgboost",
"r2": 0.92,
"mae": 5500.0,
"rmse": 8000.0
}
}
```
### 7. 数据集管理
#### 7.1 获取数据集列表
- **URL**: `/datasets`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"equipment_count": 10,
"equipment_names": ["设备1", "设备2"],
"purpose": "训练",
"created_at": "2024-11-11T10:00:00"
}
]
```
#### 7.2 获取数据集详情
- **URL**: `/datasets/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"name": "训练数据集",
"description": "用于训练的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"created_at": "2024-11-11T10:00:00",
"equipment": [
{
"id": 1,
"name": "设备1",
"type": "巡飞弹",
"manufacturer": "制造商1",
"actual_cost": 150000
}
],
"statistics": {
"equipment_count": 10,
"total_cost": 1500000,
"average_cost": 150000
}
}
```
#### 7.3 创建数据集
- **URL**: `/datasets`
- **方法**: `POST`
- **请求体示例**:
```json
{
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": [1, 2, 3]
}
```
- **响应示例**:
```json
{
"id": 2,
"message": "数据集创建成功"
}
```
#### 7.4 更新数据集
- **URL**: `/datasets/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "更新后的数据集名称",
"description": "更新后的描述",
"equipment_type": "巡飞弹",
"purpose": "验证",
"equipment_ids": [1, 2, 3, 4]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "数据集更新成功"
}
```
#### 7.5 删除数据集
- **URL**: `/datasets/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定的数据集及其关联关系
- **响应示例**:
```json
{
"success": true,
"message": "数据集删除成功"
}
```
注意事项:
1. 数据集删除后不会删除关联的装备数据
2. 不能删除正在被模型使用的数据集
3. 更新数据集时会重新计算统计信息
4. 数据集的装备类型一旦创建后不能更改
### 8. 模型管理
#### 8.1 获取模型列表
- **URL**: `/models`
- **方法**: `GET`
- **响应示例**:
```json
[
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00"
}
]
```
#### 8.2 获取最新模型
- **URL**: `/models/{equipment_type}/latest`
- **方法**: `GET`
- **响应示例**: 与模型列表的单个模型格式相同
#### 8.3 获取模型详情
- **URL**: `/models/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 1,
"model_name": "巡飞弹_xgboost_model",
"model_type": "xgboost",
"equipment_type": "巡飞弹",
"r2_score": 0.95,
"mae": 5000.0,
"rmse": 7500.0,
"is_active": true,
"training_date": "2024-11-11T10:00:00",
"feature_importance": {
"max_range_km": 0.35,
"weight_kg": 0.25,
"length_m": 0.20
},
"training_data_size": 100,
"created_by": "system"
}
```
#### 8.4 激活模型
- **URL**: `/models/{id}/activate`
- **方法**: `POST`
- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
- **响应示例**:
```json
{
"success": true,
"message": "模型已激活"
}
```
#### 8.5 删除模型
- **URL**: `/models/{id}`
- **方法**: `DELETE`
- **描述**: 删除指定模型,包括模型文件和数据库记录
- **响应示例**:
```json
{
"success": true,
"message": "模型已删除"
}
```
注意事项:
1. 删除模型时会同时删除相关的文件和数据库记录
2. 不能删除当前正在使用(已激活)的模型
3. 激活模型时会自动取消同类型其他模型的激活状态
4. 模型详情包含了更多的训练相关信息,如特征重要性等
### 9. 数据管理
#### 9.1 获取装备数据列表
- **URL**: `/data`
- **方法**: `GET`
- **响应示例**:
```json
{
"rocket_artillery": [
{
"id": 1,
"name": "BM-21",
"type": "火箭炮",
"manufacturer": "俄罗斯",
"length_m": 7.35,
"width_m": 2.4,
"height_m": 3.1,
"weight_kg": 13700,
"max_range_km": 20.4,
"actual_cost": 800000
}
],
"loitering_munition": [
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"actual_cost": 150000
}
]
}
```
#### 9.2 获取装备详情
- **URL**: `/data/details/{id}`
- **方法**: `GET`
- **响应示例**:
```json
{
"id": 8,
"name": "Hero-120",
"type": "巡飞弹",
"manufacturer": "以色列",
"common_params": {
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40
},
"specific_params": {
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "箱式发射",
"power_system": "电动机",
"guidance_system": "GPS/INS"
},
"cost_data": {
"actual_cost": 150000,
"prediction_date": "2024-11-11T10:00:00",
"predicted_cost": 148000
},
"custom_params": [
{
"id": 1,
"param_name": "续航时间",
"param_value": "2小时",
"param_unit": "小时",
"description": "最大续航时间"
}
]
}
```
#### 9.3 更新装备数据
- **URL**: `/data/{id}`
- **方法**: `PUT`
- **请求体示例**:
```json
{
"name": "Hero-120",
"manufacturer": "以色列",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"wingspan_m": 2.1,
"warhead_weight_kg": 3.5,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"actual_cost": 150000,
"custom_params": [
{
"id": 1,
"param_value": "2.5小时"
}
]
}
```
- **响应示例**:
```json
{
"success": true,
"message": "装备数据更新成功"
}
```
#### 9.4 删除装备数据
- **URL**: `/data/{id}`
- **方法**: `DELETE`
- **响应示例**:
```json
{
"success": true,
"message": "装备数据删除成功"
}
```
#### 9.5 下载数据模板
- **URL**: `/data/template`
- **方法**: `GET`
- **描述**: 下载Excel格式的数据导入模板
- **响应**: Excel文件下载
#### 9.6 导入数据
- **URL**: `/data/import`
- **方法**: `POST`
- **请求体**:
- Content-Type: multipart/form-data
- 参数名: file
- 文件类型: .xlsx 或 .xls
- **响应示例**:
```json
{
"success": true,
"message": "数据导入成功",
"imported_count": {
"rocket_artillery": 3,
"loitering_munition": 5
}
}
```
注意事项:
1. 导入数据时必须使用系统提供的模板
2. 更新装备数据时会同时更新关联的参数表
3. 删除装备数据会同时删除相关的参数和成本数据
4. 导入的Excel文件大小不应超过10MB
5. 所有数值字段必须符合指定的单位和范围要求
6. 特殊参数的值必须包含单位信息
## 错误响应
所有接口在发生错误时都会返回以下格式的响应:
```json
{
"error": "错误描述信息"
}
```
## 注意事项
1. 所有数值参数必须大于0
2. 所有单位必须按照参数名称中指定的单位提供
3. 预测结果中的成本单位为元
4. 置信区间表示预测结果的95%置信水平范围
5. 所有请求和响应的编码均为 UTF-8

120
docs/deploy/deploy.md Normal file
View File

@ -0,0 +1,120 @@
# 装备成本估算系统部署指南
## 一、系统要求
### 1. 基础软件
- Linux 操作系统 (推荐 Ubuntu 20.04+)
- Python 3.8+ 及相关组件
```bash
sudo apt update
sudo apt install python3 python3-pip python3-venv
sudo apt install python3-dev build-essential
```
- Node.js 14+ 及 npm
```bash
# 使用 nvm 安装 Node.js
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
source ~/.bashrc
nvm install 14
nvm use 14
```
### 2. 数据库
- MySQL 8.0+
```bash
sudo apt install mysql-server mysql-client
sudo apt install libmysqlclient-dev
```
### 3. Python包依赖
```bash
# 科学计算相关
sudo apt install libatlas-base-dev # numpy依赖
sudo apt install libopenblas-dev # 线性代数库
sudo apt install liblapack-dev # 线性代数包
sudo apt install gfortran # Fortran编译器(scipy依赖)
# XML处理相关(用于Excel文件处理)
sudo apt install libxml2-dev
sudo apt install libxslt1-dev
```
## 二、部署运行
### 1. 安装服务
```bash
sh scripts/install.sh
```
### 2. 启动服务
```bash
sh scripts/start.sh
```
### 3. 停止服务
```bash
sh scripts/stop.sh
```
## 三、维护说明
### 1. 日志管理
```bash
# 后端日志
tail -f logs/api.log
# 数据库日志
tail -f /var/log/mysql/error.log
```
## 四、安全建议
1. 系统安全
- 使用防火墙限制端口访问
- 定期更新系统和依赖包
2. 数据安全
- 定期备份数据库
- 加密敏感信息
- 限制数据库远程访问
3. 访问控制
- 使用强密码
- 配置适当的文件权限
- 使用非root用户运行服务
## 五、监控方案
### 1. 系统监控
```bash
# 资源使用
top -b -n 1
df -h
free -m
# 服务状态
ps aux | grep gunicorn
ps aux | grep node
```
### 2. 应用监控
```bash
# API 响应时间
curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
# 错误日志
grep "ERROR" logs/api.log
```

250
docs/dev/deployment.md Normal file
View File

@ -0,0 +1,250 @@
# 装备成本估算系统部署指南
## 一、系统打包
### 1. 创建部署包结构
```bash
mkdir -p deploy/equipment_cost_system/{backend,frontend,docs}
```
### 2. 准备部署文件
#### 2.1 后端文件
```bash
cd deploy/equipment_cost_system/backend
mkdir -p {src,scripts,config,data,logs,models}
cp -r ../../../src/* src/
cp ../../../requirements.txt ./
cp ../../../.env.example config/.env.template
```
#### 2.2 前端文件
```bash
cd ../frontend
mkdir -p {src,public,dist}
cp -r ../../frontend/src/* src/
cp -r ../../frontend/public/* public/
cp ../../../frontend/package.json ./
cp ../../../frontend/vite.config.js ./
cp ../../../frontend/.env.production ./
```
#### 2.3 复制文档
```bash
cd ../docs
cp -r ../../docs/deploy/* ./
```
#### 2.4 创建部署脚本
```bash
touch scripts/{install.sh,start.sh,stop.sh}
```
### 3. 部署脚本内容
#### 3.1 安装脚本 (install.sh)
```bash
#!/bin/bash
# 安装 Python 依赖
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
# 创建必要的目录
mkdir -p logs
mkdir -p data
mkdir -p models
# 配置文件
if [ ! -f config/.env ]; then
cp config/.env.template config/.env
echo "请修改 config/.env 中的配置"
fi
# 初始化数据库
read -p "请输入MySQL root密码: " mysqlpass
mysql -u root -p$mysqlpass < src/schema.sql
# 设置权限
chmod +x scripts/*.sh
```
#### 3.2 启动脚本 (start.sh)
```bash
#!/bin/bash
# 激活虚拟环境
source venv/bin/activate
# 检查配置文件
if [ ! -f config/.env ]; then
echo "错误: 配置文件不存在"
exit 1
fi
# 启动服务
export $(cat config/.env | xargs)
gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" --daemon
echo "服务已启动,访问 http://localhost:5001"
```
#### 3.3 停止脚本 (stop.sh)
```bash
#!/bin/bash
# 查找并停止 gunicorn 进程
pkill -f gunicorn
echo "服务已停止"
```
### 4. 创建部署包
```bash
cd ..
tar -czf equipment_cost_system.tar.gz equipment_cost_system/
```
## 二、部署步骤
### 1. 系统要求
- Linux 服务器 (推荐 Ubuntu 20.04+)
- Python 3.8+
- MySQL 8.0+
### 2. 安装部署
```bash
# 解压部署包
tar -xzf equipment_cost_system.tar.gz
cd equipment_cost_system
# 运行安装脚本
bash scripts/install.sh
# 修改配置文件
vim config/.env
# 启动服务
bash scripts/start.sh
```
### 3. 验证部署
```bash
# 检查服务状态
curl http://localhost:5001/api/
# 检查日志
tail -f logs/api.log
```
## 三、维护说明
### 1. 日常维护
```bash
# 查看日志
tail -f logs/api.log
# 备份数据库
mysqldump -u root -p equipment_cost_db > backup/$(date +%Y%m%d).sql
# 清理旧日志
find logs/ -name "*.log" -mtime +30 -delete
```
### 2. 更新部署
```bash
# 停止服务
bash scripts/stop.sh
# 备份数据
cp -r data data_backup_$(date +%Y%m%d)
# 更新代码
# ... 更新相关文件 ...
# 重启服务
bash scripts/start.sh
```
### 3. 故障处理
```bash
# 检查服务状态
ps aux | grep gunicorn
# 检查数据库连接
mysql -u root -p -e "show databases;"
# 重启服务
bash scripts/stop.sh
bash scripts/start.sh
```
## 四、安全建议
1. 文件权限设置
```bash
# 设置适当的文件权限
chmod 755 scripts/*.sh
chmod 600 config/.env
chmod 700 logs models data
```
2. 数据库安全
- 使用强密码
- 限制数据库访问IP
- 定期备份数据
3. 服务器配置
- 配置防火墙规则
- 启用 SSL/TLS
- 定期更新系统
## 五、监控方案
### 1. 系统监控
```bash
# 检查CPU和内存使用
top -b -n 1
# 检查磁盘使用
df -h
# 检查网络连接
netstat -tunlp
```
### 2. 应用监控
```bash
# 检查API响应时间
curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
# 检查错误日志
grep "ERROR" logs/api.log
```
### 3. 告警设置
- 配置日志告警
- 设置资源使用阈值告警
- 配置服务可用性监控

View File

@ -16,11 +16,11 @@
2线性相关分析对于特征和标签皆为连续值的回归问题要检测二者的相关性最直接的做法就是求相关系数rxy本质是建立协方差矩阵分析数据和成本之间相关关系的类型和程度筛选出影响特征
3互信息 (mutual information) 用于特征选择,可以从两个角度进行解释:(1)、基于 KL 散度和 (2)、基于信息增益。
2. 数据一致性分析:对特征数据分层分组,计算组内一致性,目标是选择比较合适的一组数据,以此产生一个进行成本估算和分析的虚拟量.大部分的研究中报告的三个数据rwg、ICC(1)、ICC(2)要符合3个条件rwg>0.7、ICC(1)>0.05、ICC(2)>0.5
RWG值打分一致性
ICC1组内一致性
ICC2组间一致性。
RWG值打分一致性
ICC1组内一致性
ICC2组间一致性。
3. 回归模型:偏最小二乘回归(partial Least SquaresPLS)
4. 神经网络模型:采用 BP 网络
4. 神经网络模型:采用适用的神经网络模型
### 数据准备

164
docs/dev/run.md Normal file
View File

@ -0,0 +1,164 @@
# 装备成本估算系统运行说明
## 一、开发环境配置
### 1. 系统要求
- Linux/macOS/Windows
- Python 3.8+
- MySQL 8.0+
### 2. 安装依赖
```bash
# 创建并激活虚拟环境
python3 -m venv venv
source venv/bin/activate # Linux/macOS
# 或
.\venv\Scripts\activate # Windows
# 安装依赖包
pip install -r requirements.txt
```
## 二、初始化系统
### 1. 创建必要目录
```bash
mkdir -p {logs,data,models}
```
### 2. 配置数据库
```bash
# 执行数据库初始化脚本
mysql -u root -p < src/schema.sql
# [可选] 导入测试数据(仅用于开发环境)
mysql -u root -p equipment_cost_db < src/init_data.sql
```
### 3. 环境配置
创建 `.env` 文件:
```ini
MYSQL_HOST=localhost
MYSQL_USER=root
MYSQL_PASSWORD=123456
MYSQL_DATABASE=equipment_cost_db
```
## 三、启动服务
### 1. 开发模式
```bash
# 启动开发服务器
python run.py
```
### 2. 测试 API
```bash
# 运行 API 测试
python src/test_api.py
```
## 四、开发调试
### 1. 日志查看
```bash
# API 日志
tail -f logs/api.log
# 测试日志
tail -f logs/test_api.log
# 训练日志
tail -f logs/training.log
```
### 2. 数据库调试
```sql
-- 检查数据表
SHOW TABLES;
-- 查看示例数据
SELECT * FROM equipment LIMIT 5;
```
### 3. API 测试
```bash
# 测试 API 根路由
curl http://localhost:5001/api/
# 测试预测接口
curl -X POST http://localhost:5001/api/predict \
-H "Content-Type: application/json" \
-d '{
"type": "巡飞弹",
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40
}'
```
## 五、注意事项
1. 开发环境配置
- 使用虚拟环境隔离依赖
- 保持日志目录可写权限
- 定期清理日志文件
2. 数据库使用
- 使用 UTF-8 字符集
- 定期备份数据
- 避免直接修改生产数据
3. 代码调试
- 查看详细日志输出
- 使用测试数据验证功能
- 遵循代码规范
## 六、常见问题
1. 数据库连接错误
- 检查 MySQL 服务状态
- 验证数据库用户名密码
- 确认数据库字符集设置
2. API 访问问题
- 检查服务是否正常运行
- 验证请求格式是否正确
- 查看错误日志信息
3. 模型相关问题
- 确保训练数据完整性
- 检查模型文件权限
- 验证预测结果合理性
## 七、开发建议
1. 代码管理
- 使用版本控制
- 遵循项目结构
- 及时更新文档
2. 测试规范
- 运行完整测试套件
- 验证各个功能模块
- 记录测试结果
3. 安全注意
- 使用安全的数据库密码
- 避免敏感信息提交
- 保护测试数据安全
注:生产环境部署请参考 `deploy.md`

View File

@ -1,136 +0,0 @@
# Node.js 安装指南
## Windows 安装方法
### 1. 使用安装包
1. 访问 Node.js 官网 <https://nodejs.org/>
2. 下载 14.x LTS 版本安装包
3. 运行安装包,按提示完成安装
4. 验证安装:
```bash
node --version
npm --version
```
### 2. 使用 nvm-windows推荐
1. 下载 nvm-windows<https://github.com/coreybutler/nvm-windows/releases>
2. 安装 nvm-windows
3. 安装 Node.js
```bash
nvm install 14.21.3
nvm use 14.21.3
```
## Linux 安装方法
### 1. 使用 nvm推荐
```bash
# 安装 nvm
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
# 重新加载配置
source ~/.bashrc
# 安装 Node.js 14
nvm install 14
nvm use 14
```
### 2. 使用包管理器
#### Ubuntu/Debian
```bash
# 添加 NodeSource 仓库
curl -fsSL https://deb.nodesource.com/setup_14.x | sudo -E bash -
# 安装 Node.js
sudo apt-get install -y nodejs
```
#### CentOS/RHEL
```bash
# 添加 NodeSource 仓库
curl -fsSL https://rpm.nodesource.com/setup_14.x | sudo bash -
# 安装 Node.js
sudo yum install -y nodejs
```
## macOS 安装方法
### 1. 使用 Homebrew推荐
```bash
# 安装 Homebrew如果未安装
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# 安装 Node.js 14
brew install node@14
# 添加环境变量
echo 'export PATH="/usr/local/opt/node@14/bin:$PATH"' >> ~/.zshrc
source ~/.zshrc
```
### 2. 使用 nvm
```bash
# 安装 nvm
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
# 重新加载配置
source ~/.zshrc
# 安装 Node.js 14
nvm install 14
nvm use 14
```
## 验证安装
安装完成后,运行以下命令验证:
```bash
# 检查 Node.js 版本
node --version # 应显示 v14.x.x
# 检查 npm 版本
npm --version # 应显示 6.x.x 或更高
```
## 常见问题
### 1. 权限问题
如果遇到权限错误,可以:
```bash
# Linux/macOS
sudo chown -R $USER /usr/local/lib/node_modules
```
### 2. 版本切换
如果需要在不同版本间切换:
```bash
# 使用 nvm
nvm list # 查看已安装版本
nvm use 14 # 切换到 14.x 版本
```
### 3. npm 配置
建议配置国内镜像源:
```bash
# 使用淘宝镜像
npm config set registry https://registry.npmmirror.com
```

View File

@ -1,163 +0,0 @@
# 系统运行说明
## 一、环境准备
### 1. 安装必要软件
```bash
# 安装 Python 3.8+
# 安装 MySQL 8.0+
# 安装 Node.js 14+
```
### 2. 安装 Python 依赖
```bash
pip install -r requirements.txt
```
### 3. 安装前端依赖
```bash
cd frontend
npm install
```
## 二、数据库配置
### 1. 创建数据库
```sql
CREATE DATABASE equipment_cost_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
```
### 2. 初始化数据库结构
```bash
# 执行数据库结构初始化脚本
mysql -u username -p equipment_cost_db < src/schema.sql
# 导入示例数据
mysql -u username -p equipment_cost_db < src/init_data.sql
# 导入真实数据
mysql -u username -p equipment_cost_db < src/real_data.sql
```
## 三、配置文件
### 1. 后端配置
创建 `config.py` 文件:
```python
# config.py
DATABASE_URI = "mysql+pymysql://username:password@localhost:3306/equipment_cost_db"
SECRET_KEY = "your-secret-key"
DEBUG = True
```
### 2. 前端配置
修改 `frontend/src/config.js`
```javascript
export const API_BASE_URL = 'http://localhost:5001/api';
```
## 四、启动系统
### 1. 启动后端服务
```bash
# 开发环境
python run.py # 服务将在 http://localhost:5001 启动
# 生产环境
gunicorn -w 4 -b 0.0.0.0:5001 run:app
```
### 2. 启动前端服务
```bash
# 开发环境
cd frontend
npm run serve # 前端将在 http://localhost:8080 启动
# 生产环境
npm run build
```
## 五、访问系统
- 后端API<http://localhost:5001/api>
- 前端界面:<http://localhost:8080>
## 六、常见问题
### 1. 数据库连接问题
- 检查 MySQL 服务是否启动
- 验证数据库用户名和密码
- 确认数据库端口是否正确
### 2. 模型训练
```bash
# 训练模型
python src/train_model.py
# 查看训练日志
tail -f logs/training.log
```
### 3. 系统监控
```bash
# 查看系统日志
tail -f logs/app.log
# 监控API请求
tail -f logs/access.log
```
## 七、开发调试
### 1. 后端调试
```bash
# 启动调试模式
python run.py --debug
# 运行测试
python -m pytest tests/
```
### 2. 前端调试
```bash
# 启动开发服务器
npm run serve
# 运行测试
npm run test
```
## 八、部署建议
### 1. 使用 Docker 部署
```bash
# 构建镜像
docker-compose build
# 启动服务
docker-compose up -d
```
### 2. 生产环境配置
- 使用 Nginx 作为反向代理
- 配置 SSL 证书
- 设置适当的防火墙规则
- 启用数据库备份

View File

@ -1,24 +1,29 @@
# frontend
## Project setup
```
```bash
npm install
```
### Compiles and hot-reloads for development
```
```bash
npm run serve
```
### Compiles and minifies for production
```
```bash
npm run build
```
### Lints and fixes files
```
```bash
npm run lint
```
### Customize configuration
See [Configuration Reference](https://cli.vuejs.org/config/).

View File

@ -153,31 +153,93 @@ const startAnalysis = async () => {
}
}
//
const resizeHandler = ref(null)
// resize
const createResizeHandler = () => {
const handler = () => {
try {
if (importanceChart.value && !importanceChart.value.isDisposed()) {
importanceChart.value.resize()
}
if (correlationChart.value && !correlationChart.value.isDisposed()) {
correlationChart.value.resize()
}
} catch (error) {
console.error('Error in resize handler:', error)
}
}
// 使
return debounce(handler, 200)
}
//
onMounted(() => {
// resize
resizeHandler.value = createResizeHandler()
window.addEventListener('resize', resizeHandler.value)
})
//
onUnmounted(() => {
//
if (resizeHandler.value) {
window.removeEventListener('resize', resizeHandler.value)
resizeHandler.value = null
}
//
try {
if (importanceChart.value) {
importanceChart.value.dispose()
importanceChart.value = null
}
if (correlationChart.value) {
correlationChart.value.dispose()
correlationChart.value = null
}
} catch (error) {
console.error('Error disposing charts:', error)
}
})
//
const renderCharts = () => {
console.log('Starting to render charts')
//
if (importanceChart.value) {
importanceChart.value.dispose()
}
if (correlationChart.value) {
correlationChart.value.dispose()
//
if (!analysisResult.value ||
!analysisResult.value.important_features ||
!analysisResult.value.correlation_analysis) {
console.log('Analysis result not ready')
return
}
// DOM
// DOM
if (!importanceChartRef.value || !correlationChartRef.value) {
console.log('Chart DOM elements not ready')
return
}
try {
//
if (importanceChart.value) {
importanceChart.value.dispose()
importanceChart.value = null
}
if (correlationChart.value) {
correlationChart.value.dispose()
correlationChart.value = null
}
//
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
//
importanceChart.value.setOption({
const importanceOption = {
title: { text: '特征重要性排序' },
tooltip: {},
xAxis: {
@ -189,12 +251,13 @@ const renderCharts = () => {
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
name: '重要性',
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance)
}]
})
}
correlationChart.value.setOption({
const correlationOption = {
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
@ -233,6 +296,7 @@ const renderCharts = () => {
color: ['#cc3333', '#eeeeee', '#00007f']
},
series: [{
name: '相关性',
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
label: {
@ -242,13 +306,11 @@ const renderCharts = () => {
}
}
}]
})
}
//
window.addEventListener('resize', () => {
importanceChart.value?.resize()
correlationChart.value?.resize()
})
//
importanceChart.value.setOption(importanceOption)
correlationChart.value.setOption(correlationOption)
console.log('Charts rendered successfully')
} catch (error) {
@ -256,16 +318,16 @@ const renderCharts = () => {
}
}
//
onMounted(() => {
//
})
//
onUnmounted(() => {
importanceChart.value?.dispose()
correlationChart.value?.dispose()
})
//
function debounce(fn, delay) {
let timer = null
return function (...args) {
if (timer) clearTimeout(timer)
timer = setTimeout(() => {
fn.apply(this, args)
}, delay)
}
}
</script>
<style lang="scss" scoped>

View File

@ -1,68 +0,0 @@
from flask import Flask, request, jsonify
from .model_trainer import ModelTrainer
from .cost_prediction import CostPredictor
from .feature_analysis import FeatureAnalysis
import pandas as pd
app = Flask(__name__)
@app.route('/api/predict', methods=['POST'])
def predict_cost():
"""
成本预测接口
"""
try:
data = request.get_json()
# 验证必要参数
required_params = [
'length_m', 'width_m', 'height_m', 'weight_standard_kg',
'weight_combat_kg', 'max_range_km', 'max_speed_ms'
]
for param in required_params:
if param not in data:
return jsonify({'error': f'Missing parameter: {param}'}), 400
predictor = CostPredictor()
result = predictor.predict(data)
return jsonify({
'predicted_cost': float(result['predicted_cost']),
'confidence_interval': {
'lower': float(result['confidence_intervals'][0]),
'upper': float(result['confidence_intervals'][1])
}
})
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/api/analyze-features', methods=['POST'])
def analyze_features():
"""
特征分析接口
"""
try:
data = request.get_json()
analyzer = FeatureAnalysis()
# 数据预处理
processed_data = analyzer.preprocess_features(pd.DataFrame(data))
# 特征重要性分析
important_features = analyzer.select_features(
processed_data,
data['cost']
)
return jsonify({
'important_features': important_features,
'correlation_analysis': analyzer.correlation_analysis(
processed_data,
data['cost']
).to_dict()
})
except Exception as e:
return jsonify({'error': str(e)}), 500

View File

@ -222,7 +222,7 @@ class CostPredictor:
def predict_pls(self, data):
"""
使用 PLS 模型预测成本
使用 PLS 型预<EFBFBD><EFBFBD><EFBFBD>成本
"""
try:
logger.info(f"Starting PLS prediction for {data.get('type')}")
@ -253,4 +253,90 @@ class CostPredictor:
except Exception as e:
logger.error(f"PLS prediction error: {str(e)}")
raise
def predict_all(self, data):
"""
使用所有可用模型进行预测
"""
try:
logger.info(f"Starting multi-model prediction for {data.get('type')}")
equipment_type = data.get('type')
results = {}
# 1. 获取所有激活的模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT id, model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
""", (equipment_type,))
active_models = cursor.fetchall()
if not active_models:
raise ValueError(f"No active models found for {equipment_type}")
# 2. 使用每个模型进行预测
trainer = ModelTrainer()
for model_info in active_models:
try:
# 加载特定模型
if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
logger.warning(f"Failed to load model: {model_info['model_name']}")
continue
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
X = np.array([[data.get(feature) for feature in features]])
# 预测
y_pred = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
# 保存结果
results[model_info['model_type']] = {
'predicted_cost': float(y_pred[0]),
'model_info': {
'name': model_info['model_name'],
'type': model_info['model_type'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
except Exception as e:
logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
continue
if not results:
raise ValueError("No successful predictions from any model")
# 3. 计算综合预测结果
all_predictions = [result['predicted_cost'] for result in results.values()]
ensemble_prediction = float(np.mean(all_predictions))
prediction_std = float(np.std(all_predictions))
# 4. 返回所有结果
return {
'individual_predictions': results,
'ensemble_prediction': {
'predicted_cost': ensemble_prediction,
'standard_deviation': prediction_std,
'confidence_interval': {
'lower': float(ensemble_prediction - 1.96 * prediction_std),
'upper': float(ensemble_prediction + 1.96 * prediction_std)
}
}
}
except Exception as e:
logger.error(f"Error in multi-model prediction: {str(e)}")
raise

View File

@ -1,3 +1,13 @@
/*
使
1.
2.
3.
*/
-- 插入装备基本信息
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
@ -139,16 +149,16 @@ INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(10, 1500000); -- 彩虹-4
-- 火箭炮数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('BM-21', '火箭炮', '俄罗斯', '面目标'),
('SR5', '火箭炮', '中国', '面目标和点目标'),
('HIMARS', '火箭炮', '美国', '战术目标'),
('LAR-160', '火箭炮', '以色列', '面目标'),
('T-122', '火箭炮', '土耳其', '面目标'),
('RM-70', '火箭炮', '捷克', '面目标'),
('ASTROS II', '火箭炮', '巴西', '面目标和点目标');
INSERT INTO equipment (name, type, manufacturer) VALUES
('BM-21', '火箭炮', '俄罗斯'),
('SR5', '火箭炮', '中国'),
('HIMARS', '火箭炮', '美国'),
('LAR-160', '火箭炮', '以色列'),
('T-122', '火箭炮', '土耳其'),
('RM-70', '火箭炮', '捷克'),
('ASTROS II', '火箭炮', '巴西');
-- 插入火箭炮参数
-- 火箭炮通用参数
INSERT INTO common_params (
equipment_id,
length_m,
@ -172,6 +182,7 @@ INSERT INTO common_params (
-- ASTROS II
(7, 8.0, 2.7, 3.1, 24500, 90);
-- 火箭炮特有参数
INSERT INTO rocket_artillery_params (
equipment_id,
firing_angle_horizontal,
@ -179,44 +190,43 @@ INSERT INTO rocket_artillery_params (
rocket_length_m,
rocket_diameter_mm,
rocket_weight_kg,
rate_of_fire
rate_of_fire,
combat_weight_kg,
speed_kmh,
min_range_km,
mobility_type,
structure_layout,
engine_model,
engine_params,
power_hp,
travel_range_km
) VALUES
-- BM-21
(1, 102, 55, 2.87, 122, 66.6, 40),
(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
-- SR5
(2, 110, 60, 4.1, 220, 150, 60),
(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
-- HIMARS
(3, 90, 65, 3.94, 227, 301, 6),
(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
-- LAR-160
(4, 100, 58, 3.3, 160, 110, 18),
(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
-- T-122
(5, 110, 65, 2.95, 122, 65.5, 40),
(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
-- RM-70
(6, 100, 50, 2.87, 122, 66.6, 40),
(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
-- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30);
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000); -- ASTROS II
(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-- 巡飞弹数据
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('Hero-120', '巡飞弹', '以色列', '装甲目标'),
('Switchblade 600', '巡飞弹', '美国', '装甲车辆'),
('Warmate', '巡飞弹', '波兰', '轻型装甲目标'),
('CH-901', '巡飞弹', '中国', '人员和轻型装甲车辆'),
('HAROP', '巡飞弹', '以色列', '防空系统'),
('Coyote', '巡飞弹', '美国', '无人机和巡飞弹'),
('WS-43', '巡飞弹', '中国', '固定目标和装甲目标');
INSERT INTO equipment (name, type, manufacturer) VALUES
('Hero-120', '巡飞弹', '以色列'),
('Switchblade 600', '巡飞弹', '美国'),
('Warmate', '巡飞弹', '波兰'),
('CH-901', '巡飞弹', '中国'),
('HAROP', '巡飞弹', '以色列'),
('Coyote', '巡飞弹', '美国'),
('WS-43', '巡飞弹', '中国');
-- 插入巡飞弹参数
-- 巡飞弹通用参数
INSERT INTO common_params (
equipment_id,
length_m,
@ -240,38 +250,70 @@ INSERT INTO common_params (
-- WS-43
(14, 1.8, 0.35, 0.35, 20, 60);
-- 巡飞弹特有参数
INSERT INTO loitering_munition_params (
equipment_id,
max_speed_kmh,
wingspan_m,
warhead_weight_kg,
max_speed_ms,
cruise_speed_kmh,
flight_time_min,
warhead_type,
launch_mode,
folded_length_mm,
folded_width_mm,
folded_height_mm
folded_height_mm,
power_system,
guidance_system
) VALUES
-- Hero-120
(8, 180, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230),
(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
-- Switchblade 600
(9, 185, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220),
(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
-- Warmate
(10, 150, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150),
(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
-- CH-901
(11, 160, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180),
(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
-- HAROP
(12, 185, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430),
(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
-- Coyote
(13, 150, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120),
(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
-- WS-43
(14, 170, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350);
(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-- 插入成本数据(示例成本)
INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-- 火箭炮
(1, 800000), -- BM-21
(2, 4500000), -- SR5
(3, 5500000), -- HIMARS
(4, 3500000), -- LAR-160
(5, 2800000), -- T-122
(6, 1500000), -- RM-70
(7, 4800000), -- ASTROS II
-- 巡飞弹
(8, 150000), -- Hero-120
(9, 180000), -- Switchblade 600
(10, 80000), -- Warmate
(11, 100000), -- CH-901
(12, 850000), -- HAROP
(13, 75000), -- Coyote
(14, 120000); -- WS-43
(14, 120000); -- WS-43
-- 创建初始数据集
INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-- 关联装备到数据集
INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
-- 火箭炮训练集
(1, 1), (1, 2), (1, 3), (1, 4),
-- 巡飞弹训练集
(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
-- 火箭炮验证集
(3, 5), (3, 6), (3, 7),
-- 巡飞弹验证集
(4, 13), (4, 14);

View File

@ -1,9 +0,0 @@
-- 火箭炮数据13种
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('BM-21', '火箭炮', '俄罗斯', '面目标'),
-- ... 其他12种火箭炮数据
-- 巡飞弹数据18种
INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
('Hero-120', '巡飞弹', '以色列', '装甲目标'),
-- ... 其他17种巡飞弹数据

View File

@ -2,7 +2,6 @@ from flask import Blueprint, request, jsonify, send_file
from .cost_prediction import CostPredictor
from .feature_analysis import FeatureAnalysis
import pandas as pd
import logging
from datetime import datetime
import numpy as np
import mysql.connector
@ -10,7 +9,6 @@ from sklearn.metrics import mean_absolute_error
from .create_template import create_excel_template
import json
import os
import time
from .data_preparation import DataPreparation
from .model_trainer import ModelTrainer
from .logger import setup_logger

View File

@ -1,3 +1,14 @@
-- 如果数据库已存在则删除
DROP DATABASE IF EXISTS equipment_cost_db;
-- 创建数据库
CREATE DATABASE equipment_cost_db
DEFAULT CHARACTER SET utf8mb4
COLLATE utf8mb4_unicode_ci;
-- 使用数据库
USE equipment_cost_db;
-- 装备基本信息表
CREATE TABLE equipment (
id INT AUTO_INCREMENT PRIMARY KEY,

View File

@ -1,192 +1,191 @@
import requests
import json
import logging
from datetime import datetime
import os
import sys
# 添加项目根目录到 Python 路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 配置基本日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('logs/test_api.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def test_api_endpoints():
"""
测试API各个端点
测试 API 各个端点
"""
base_url = 'http://localhost:5001/api'
# 1. 测试根路由
print("\n1. 测试 API 根路由")
response = requests.get(f'{base_url}/')
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 2. 测试机器学习预测接口
print("\n2. 测试机器学习预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict',
json=predict_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 3. 测试 PLS 预测接口
print("\n3. 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 4. 测试特征分析接口
print("\n4. 测试特征分析接口")
analysis_data = {
"data": [{
try:
# 1. 测试根路由
logger.info("\n1. 测试 API 根路由")
response = requests.get(f'{base_url}/')
print_response(response, "API 根路由")
# 2. 测试机器学习预测接口
logger.info("\n2. 测试机器学习预测接口")
predict_data = {
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200
}],
"cost": [1000000]
}
response = requests.post(
f'{base_url}/analyze-features',
json=analysis_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 5. 测试机器学习模型训练接口
print("\n5. 测试机器学习模型训练接口")
training_data = {
"training_data": [
{
"type": "巡飞弹",
"length_m": 0.56,
"width_m": 0.15,
"height_m": 0.20,
"weight_kg": 2.72,
"max_range_km": 24,
"max_speed_kmh": 160.93,
"cruise_speed_kmh": 96.56,
"flight_time_min": 15,
"folded_length_mm": 560,
"folded_width_mm": 150,
"folded_height_mm": 200,
"cost": 1000000
},
{
"type": "巡飞弹",
"length_m": 0.58,
"width_m": 0.16,
"height_m": 0.21,
"weight_kg": 2.85,
"max_range_km": 26,
"max_speed_kmh": 170,
"cruise_speed_kmh": 100,
"flight_time_min": 16,
"folded_length_mm": 580,
"folded_width_mm": 160,
"folded_height_mm": 210,
"cost": 1100000
},
{
"type": "巡飞弹",
"length_m": 0.54,
"width_m": 0.14,
"height_m": 0.19,
"weight_kg": 2.60,
"max_range_km": 22,
"max_speed_kmh": 155,
"cruise_speed_kmh": 93,
"flight_time_min": 14,
"folded_length_mm": 540,
"folded_width_mm": 140,
"folded_height_mm": 190,
"cost": 900000
"length_m": 1.3,
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_ms": 50,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230,
"warhead_type": "破片杀伤战斗部",
"launch_mode": "凭自身动力起飞"
}
response = requests.post(
f'{base_url}/predict',
json=predict_data
)
print_response(response, "机器学习预测")
# 3. 测试 PLS 预测接口
logger.info("\n3. 测试 PLS 预测接口")
response = requests.post(
f'{base_url}/pls/predict',
json=predict_data
)
print_response(response, "PLS 预测")
# 4. 测试特征分析接口
logger.info("\n4. 测试特征分析接口")
analysis_data = {
"dataset_id": 1, # 假设数据集 ID 为 1
"equipment_type": "巡飞弹"
}
response = requests.post(
f'{base_url}/analyze-features',
json=analysis_data
)
print_response(response, "特征分析")
# 5. 测试机器学习模型训练接口
logger.info("\n5. 测试机器学习模型训练接口")
training_data = {
"type": "巡飞弹",
"train_dataset_id": 1, # 训练数据集 ID
"validation_dataset_id": 2, # 验证数据集 ID
"models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型
}
response = requests.post(
f'{base_url}/train',
json=training_data
)
print_response(response, "模型训练")
# 6. 测试数据集相关接口
logger.info("\n6. 测试数据集相关接口")
# 6.1 获取数据集列表
response = requests.get(f'{base_url}/datasets')
print_response(response, "获取数据集列表")
# 6.2 获取可用的装备列表
response = requests.get(f'{base_url}/data')
equipment_data = response.json()
# 获取巡飞弹类型的装备ID
available_equipment_ids = []
if 'loitering_munition' in equipment_data:
available_equipment_ids = [
item['id']
for item in equipment_data['loitering_munition']
if item['id'] is not None
][:3] # 取前3个可用的ID
if not available_equipment_ids:
logger.warning("没有找到可用的装备ID跳过创建数据集测试")
else:
# 6.3 创建新数据集
new_dataset = {
"name": "测试数据集",
"description": "用于测试的数据集",
"equipment_type": "巡飞弹",
"purpose": "训练",
"equipment_ids": available_equipment_ids
}
],
"equipment_type": "巡飞弹"
}
response = requests.post(
f'{base_url}/train',
json=training_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
# 6. 测试 PLS 模型训练接口
print("\n6. 测试 PLS 模型训练接口")
# 使用真实的训练数据
training_data = {
"training_data": [
{
"length_m": 1.3, # 哈比
"width_m": 0.23,
"height_m": 0.23,
"weight_kg": 12.5,
"max_range_km": 40,
"max_speed_kmh": 180,
"cruise_speed_kmh": 100,
"flight_time_min": 60,
"folded_length_mm": 1300,
"folded_width_mm": 230,
"folded_height_mm": 230
},
{
"length_m": 2.5, # HAROP
"width_m": 0.43,
"height_m": 0.43,
"weight_kg": 135,
"max_range_km": 1000,
"max_speed_kmh": 185,
"cruise_speed_kmh": 110,
"flight_time_min": 360,
"folded_length_mm": 2500,
"folded_width_mm": 430,
"folded_height_mm": 430
},
{
"length_m": 1.1, # Warmate
"width_m": 0.15,
"height_m": 0.15,
"weight_kg": 5.7,
"max_range_km": 15,
"max_speed_kmh": 150,
"cruise_speed_kmh": 90,
"flight_time_min": 30,
"folded_length_mm": 1100,
"folded_width_mm": 150,
"folded_height_mm": 150
}
],
"actual_costs": [150000, 850000, 80000] # 对应的实际成本
}
response = requests.post(
f'{base_url}/pls/train',
json=training_data
)
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}")
response = requests.post(
f'{base_url}/datasets',
json=new_dataset
)
print_response(response, "创建数据集")
# 7. 测试模型相关接口
logger.info("\n7. 测试模型相关接口")
# 7.1 获取模型列表
response = requests.get(f'{base_url}/models')
print_response(response, "获取模型列表")
# 7.2 获取最新模型
response = requests.get(f'{base_url}/models/巡飞弹/latest')
print_response(response, "获取最新模型")
# 8. 测试多模型预测接口
logger.info("\n8. 测试多模型预测接口")
response = requests.post(
f'{base_url}/predict/all',
json=predict_data
)
print_response(response, "多模型预测")
logger.info("所有测试完成")
except requests.exceptions.RequestException as e:
logger.error(f"API 请求错误: {str(e)}")
except Exception as e:
logger.error(f"测试过程中出现错误: {str(e)}")
def print_response(response, endpoint_name):
"""
打印响应结果
"""
try:
logger.info(f"\n=== {endpoint_name} 测试结果 ===")
logger.info(f"状态码: {response.status_code}")
if response.status_code == 200:
result = response.json()
logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}")
else:
logger.error(f"错误响应:\n{response.text}")
except Exception as e:
logger.error(f"处理响应时出错: {str(e)}")
if __name__ == "__main__":
try:
# 确保日志目录存在
os.makedirs('logs', exist_ok=True)
logger.info(f"=== API 测试开始 - {datetime.now()} ===")
test_api_endpoints()
logger.info(f"=== API 测试结束 - {datetime.now()} ===")
except Exception as e:
print(f"测试过程中出现错误: {str(e)}")
logger.error(f"测试执行失败: {str(e)}")