diff --git a/deploy/equipment_cost_system.tar.gz b/deploy/equipment_cost_system.tar.gz
deleted file mode 100644
index c906e30..0000000
Binary files a/deploy/equipment_cost_system.tar.gz and /dev/null differ
diff --git a/deploy/equipment_cost_system/config/.env.template b/deploy/equipment_cost_system/config/.env.template
deleted file mode 100644
index 195987f..0000000
--- a/deploy/equipment_cost_system/config/.env.template
+++ /dev/null
@@ -1,25 +0,0 @@
-# 数据库配置
-MYSQL_HOST=localhost
-MYSQL_USER=root
-MYSQL_PASSWORD=your_password_here
-MYSQL_DATABASE=equipment_cost_db
-
-# 服务配置
-PORT=5001
-DEBUG=False
-
-# 日志配置
-LOG_LEVEL=INFO
-LOG_DIR=logs
-
-# 模型配置
-MODEL_DIR=models
-DATA_DIR=data
-
-# 安全配置
-SECRET_KEY=your_secret_key_here
-ALLOWED_HOSTS=localhost,127.0.0.1
-
-# 其他配置
-UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
-ALLOWED_FILE_TYPES=.xlsx,.xls
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/docs/api.md b/deploy/equipment_cost_system/docs/api.md
deleted file mode 100644
index 000bd35..0000000
--- a/deploy/equipment_cost_system/docs/api.md
+++ /dev/null
@@ -1,663 +0,0 @@
-# 装备成本估算系统 API 文档
-
-这个 API 文档提供了完整的接口说明,包括:
-
-- 每个端点的详细描述
-- 请求和响应的具体示例
-- 清晰的参数格式要求
-- 统一的错误处理说明
-- 重要的注意事项
-
-文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
-
-## 基本信息
-
-- 基础URL: `http://localhost:5001/api`
-- 版本: 1.0.0
-- 响应格式: JSON
-
-## API 端点列表
-
-### 1. 获取 API 信息
-
-获取 API 版本信息和可用端点列表。
-
-- **URL**: `/`
-- **方法**: `GET`
-- **响应示例**:
-json
-{
-"name": "装备成本估算系统 API",
-"version": "1.0.0",
-"endpoints": {
-"predict": {
-"url": "/api/predict",
-"method": "POST",
-"description": "成本预测"
-},
-"analyze-features": {
-"url": "/api/analyze-features",
-"method": "POST",
-"description": "特征分析"
-},
-"train": {
-"url": "/api/train",
-"method": "POST",
-"description": "模型训练"
-},
-"evaluate": {
-"url": "/api/evaluate",
-"method": "POST",
-"description": "模型评估"
-}
-}
-}
-
-### 2. 单模型预测
-
-使用当前激活的最优模型进行成本预测。
-
-- **URL**: `/predict`
-- **方法**: `POST`
-- **请求体示例** (巡飞弹):
-
-```json
-{
- "type": "巡飞弹",
- "length_m": 1.3,
- "width_m": 0.23,
- "height_m": 0.23,
- "weight_kg": 12.5,
- "max_range_km": 40,
- "max_speed_ms": 50,
- "cruise_speed_kmh": 100,
- "flight_time_min": 60,
- "folded_length_mm": 1300,
- "folded_width_mm": 230,
- "folded_height_mm": 230,
- "warhead_type": "破片杀伤战斗部",
- "launch_mode": "凭自身动力起飞"
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "predicted_cost": 150000.0,
- "model_info": {
- "type": "xgboost",
- "name": "巡飞弹_20241111_model",
- "r2_score": 0.95,
- "mae": 5000.0,
- "rmse": 7500.0
- },
- "confidence_interval": {
- "lower": 135000.0,
- "upper": 165000.0
- }
-}
-```
-
-### 3. PLS 模型预测
-
-使用 PLS 回归模型进行预测。
-
-- **URL**: `/pls/predict`
-- **方法**: `POST`
-- **请求体**: 与单模型预测相同
-- **响应示例**:
-
-```json
-{
- "predicted_cost": 148000.0,
- "confidence_interval": {
- "lower": 133000.0,
- "upper": 163000.0
- }
-}
-```
-
-### 4. 多模型预测
-
-使用所有激活的模型进行预测并返回综合结果。
-
-- **URL**: `/predict/all`
-- **方法**: `POST`
-- **请求体**: 与单模型预测相同
-- **响应示例**:
-
-```json
-{
- "individual_predictions": {
- "xgboost": {
- "predicted_cost": 150000.0,
- "model_info": {
- "name": "巡飞弹_xgboost_model",
- "type": "xgboost",
- "r2_score": 0.95,
- "mae": 5000.0,
- "rmse": 7500.0
- },
- "confidence_interval": {
- "lower": 135000.0,
- "upper": 165000.0
- }
- },
- "pls": {
- "predicted_cost": 148000.0,
- "model_info": {
- "name": "巡飞弹_pls_model",
- "type": "pls",
- "r2_score": 0.92,
- "mae": 5500.0,
- "rmse": 8000.0
- },
- "confidence_interval": {
- "lower": 133000.0,
- "upper": 163000.0
- }
- }
- },
- "ensemble_prediction": {
- "predicted_cost": 149000.0,
- "standard_deviation": 1414.21,
- "confidence_interval": {
- "lower": 146228.15,
- "upper": 151771.85
- }
- }
-}
-```
-
-### 5. 特征分析
-
-分析数据集中特征的重要性和相关性。
-
-- **URL**: `/analyze-features`
-- **方法**: `POST`
-- **请求体示例**:
-
-```json
-{
- "dataset_id": 1,
- "equipment_type": "巡飞弹"
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "important_features": [
- {
- "name": "最大射程(km)",
- "importance": 0.35
- },
- {
- "name": "重量(kg)",
- "importance": 0.25
- }
- ],
- "correlation_analysis": {
- "features": ["最大射程(km)", "重量(kg)"],
- "matrix": [[1.0, 0.8], [0.8, 1.0]]
- }
-}
-```
-
-### 6. 模型训练
-
-训练新的模型。
-
-- **URL**: `/train`
-- **方法**: `POST`
-- **请求体示例**:
-
-```json
-{
- "type": "巡飞弹",
- "train_dataset_id": 1,
- "validation_dataset_id": 2,
- "models": ["xgboost", "lightgbm", "rf"]
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "metrics": {
- "xgboost": {
- "train": {
- "r2": 0.95,
- "mae": 5000.0,
- "rmse": 7500.0
- },
- "validation": {
- "r2": 0.92,
- "mae": 5500.0,
- "rmse": 8000.0
- }
- }
- },
- "best_model": {
- "type": "xgboost",
- "r2": 0.92,
- "mae": 5500.0,
- "rmse": 8000.0
- }
-}
-```
-
-### 7. 数据集管理
-
-#### 7.1 获取数据集列表
-
-- **URL**: `/datasets`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-[
- {
- "id": 1,
- "name": "训练数据集",
- "description": "用于训练的数据集",
- "equipment_type": "巡飞弹",
- "equipment_count": 10,
- "equipment_names": ["设备1", "设备2"],
- "purpose": "训练",
- "created_at": "2024-11-11T10:00:00"
- }
-]
-```
-
-#### 7.2 获取数据集详情
-
-- **URL**: `/datasets/{id}`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-{
- "id": 1,
- "name": "训练数据集",
- "description": "用于训练的数据集",
- "equipment_type": "巡飞弹",
- "purpose": "训练",
- "created_at": "2024-11-11T10:00:00",
- "equipment": [
- {
- "id": 1,
- "name": "设备1",
- "type": "巡飞弹",
- "manufacturer": "制造商1",
- "actual_cost": 150000
- }
- ],
- "statistics": {
- "equipment_count": 10,
- "total_cost": 1500000,
- "average_cost": 150000
- }
-}
-```
-
-#### 7.3 创建数据集
-
-- **URL**: `/datasets`
-- **方法**: `POST`
-- **请求体示例**:
-
-```json
-{
- "name": "测试数据集",
- "description": "用于测试的数据集",
- "equipment_type": "巡飞弹",
- "purpose": "训练",
- "equipment_ids": [1, 2, 3]
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "id": 2,
- "message": "数据集创建成功"
-}
-```
-
-#### 7.4 更新数据集
-
-- **URL**: `/datasets/{id}`
-- **方法**: `PUT`
-- **请求体示例**:
-
-```json
-{
- "name": "更新后的数据集名称",
- "description": "更新后的描述",
- "equipment_type": "巡飞弹",
- "purpose": "验证",
- "equipment_ids": [1, 2, 3, 4]
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "数据集更新成功"
-}
-```
-
-#### 7.5 删除数据集
-
-- **URL**: `/datasets/{id}`
-- **方法**: `DELETE`
-- **描述**: 删除指定的数据集及其关联关系
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "数据集删除成功"
-}
-```
-
-注意事项:
-
-1. 数据集删除后不会删除关联的装备数据
-2. 不能删除正在被模型使用的数据集
-3. 更新数据集时会重新计算统计信息
-4. 数据集的装备类型一旦创建后不能更改
-
-### 8. 模型管理
-
-#### 8.1 获取模型列表
-
-- **URL**: `/models`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-[
- {
- "id": 1,
- "model_name": "巡飞弹_xgboost_model",
- "model_type": "xgboost",
- "equipment_type": "巡飞弹",
- "r2_score": 0.95,
- "mae": 5000.0,
- "rmse": 7500.0,
- "is_active": true,
- "training_date": "2024-11-11T10:00:00"
- }
-]
-```
-
-#### 8.2 获取最新模型
-
-- **URL**: `/models/{equipment_type}/latest`
-- **方法**: `GET`
-- **响应示例**: 与模型列表的单个模型格式相同
-
-#### 8.3 获取模型详情
-
-- **URL**: `/models/{id}`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-{
- "id": 1,
- "model_name": "巡飞弹_xgboost_model",
- "model_type": "xgboost",
- "equipment_type": "巡飞弹",
- "r2_score": 0.95,
- "mae": 5000.0,
- "rmse": 7500.0,
- "is_active": true,
- "training_date": "2024-11-11T10:00:00",
- "feature_importance": {
- "max_range_km": 0.35,
- "weight_kg": 0.25,
- "length_m": 0.20
- },
- "training_data_size": 100,
- "created_by": "system"
-}
-```
-
-#### 8.4 激活模型
-
-- **URL**: `/models/{id}/activate`
-- **方法**: `POST`
-- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "模型已激活"
-}
-```
-
-#### 8.5 删除模型
-
-- **URL**: `/models/{id}`
-- **方法**: `DELETE`
-- **描述**: 删除指定模型,包括模型文件和数据库记录
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "模型已删除"
-}
-```
-
-注意事项:
-
-1. 删除模型时会同时删除相关的文件和数据库记录
-2. 不能删除当前正在使用(已激活)的模型
-3. 激活模型时会自动取消同类型其他模型的激活状态
-4. 模型详情包含了更多的训练相关信息,如特征重要性等
-
-### 9. 数据管理
-
-#### 9.1 获取装备数据列表
-
-- **URL**: `/data`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-{
- "rocket_artillery": [
- {
- "id": 1,
- "name": "BM-21",
- "type": "火箭炮",
- "manufacturer": "俄罗斯",
- "length_m": 7.35,
- "width_m": 2.4,
- "height_m": 3.1,
- "weight_kg": 13700,
- "max_range_km": 20.4,
- "actual_cost": 800000
- }
- ],
- "loitering_munition": [
- {
- "id": 8,
- "name": "Hero-120",
- "type": "巡飞弹",
- "manufacturer": "以色列",
- "length_m": 1.3,
- "width_m": 0.23,
- "height_m": 0.23,
- "weight_kg": 12.5,
- "max_range_km": 40,
- "actual_cost": 150000
- }
- ]
-}
-```
-
-#### 9.2 获取装备详情
-
-- **URL**: `/data/details/{id}`
-- **方法**: `GET`
-- **响应示例**:
-
-```json
-{
- "id": 8,
- "name": "Hero-120",
- "type": "巡飞弹",
- "manufacturer": "以色列",
- "common_params": {
- "length_m": 1.3,
- "width_m": 0.23,
- "height_m": 0.23,
- "weight_kg": 12.5,
- "max_range_km": 40
- },
- "specific_params": {
- "wingspan_m": 2.1,
- "warhead_weight_kg": 3.5,
- "max_speed_ms": 50,
- "cruise_speed_kmh": 100,
- "flight_time_min": 60,
- "warhead_type": "破片杀伤战斗部",
- "launch_mode": "箱式发射",
- "power_system": "电动机",
- "guidance_system": "GPS/INS"
- },
- "cost_data": {
- "actual_cost": 150000,
- "prediction_date": "2024-11-11T10:00:00",
- "predicted_cost": 148000
- },
- "custom_params": [
- {
- "id": 1,
- "param_name": "续航时间",
- "param_value": "2小时",
- "param_unit": "小时",
- "description": "最大续航时间"
- }
- ]
-}
-```
-
-#### 9.3 更新装备数据
-
-- **URL**: `/data/{id}`
-- **方法**: `PUT`
-- **请求体示例**:
-
-```json
-{
- "name": "Hero-120",
- "manufacturer": "以色列",
- "length_m": 1.3,
- "width_m": 0.23,
- "height_m": 0.23,
- "weight_kg": 12.5,
- "max_range_km": 40,
- "wingspan_m": 2.1,
- "warhead_weight_kg": 3.5,
- "max_speed_ms": 50,
- "cruise_speed_kmh": 100,
- "flight_time_min": 60,
- "actual_cost": 150000,
- "custom_params": [
- {
- "id": 1,
- "param_value": "2.5小时"
- }
- ]
-}
-```
-
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "装备数据更新成功"
-}
-```
-
-#### 9.4 删除装备数据
-
-- **URL**: `/data/{id}`
-- **方法**: `DELETE`
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "装备数据删除成功"
-}
-```
-
-#### 9.5 下载数据模板
-
-- **URL**: `/data/template`
-- **方法**: `GET`
-- **描述**: 下载Excel格式的数据导入模板
-- **响应**: Excel文件下载
-
-#### 9.6 导入数据
-
-- **URL**: `/data/import`
-- **方法**: `POST`
-- **请求体**:
- - Content-Type: multipart/form-data
- - 参数名: file
- - 文件类型: .xlsx 或 .xls
-- **响应示例**:
-
-```json
-{
- "success": true,
- "message": "数据导入成功",
- "imported_count": {
- "rocket_artillery": 3,
- "loitering_munition": 5
- }
-}
-```
-
-注意事项:
-
-1. 导入数据时必须使用系统提供的模板
-2. 更新装备数据时会同时更新关联的参数表
-3. 删除装备数据会同时删除相关的参数和成本数据
-4. 导入的Excel文件大小不应超过10MB
-5. 所有数值字段必须符合指定的单位和范围要求
-6. 特殊参数的值必须包含单位信息
-
-## 错误响应
-
-所有接口在发生错误时都会返回以下格式的响应:
-
-```json
-{
- "error": "错误描述信息"
-}
-```
-
-## 注意事项
-
-1. 所有数值参数必须大于0
-2. 所有单位必须按照参数名称中指定的单位提供
-3. 预测结果中的成本单位为元
-4. 置信区间表示预测结果的95%置信水平范围
-5. 所有请求和响应的编码均为 UTF-8
diff --git a/deploy/equipment_cost_system/docs/deploy.md b/deploy/equipment_cost_system/docs/deploy.md
deleted file mode 100644
index 1f52359..0000000
--- a/deploy/equipment_cost_system/docs/deploy.md
+++ /dev/null
@@ -1,120 +0,0 @@
-# 装备成本估算系统部署指南
-
-## 一、系统要求
-
-### 1. 基础软件
-
-- Linux 操作系统 (推荐 Ubuntu 20.04+)
-- Python 3.8+ 及相关组件
-
- ```bash
- sudo apt update
- sudo apt install python3 python3-pip python3-venv
- sudo apt install python3-dev build-essential
- ```
-
-- Node.js 14+ 及 npm
-
- ```bash
- # 使用 nvm 安装 Node.js
- curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
- source ~/.bashrc
- nvm install 14
- nvm use 14
- ```
-
-### 2. 数据库
-
-- MySQL 8.0+
-
- ```bash
- sudo apt install mysql-server mysql-client
- sudo apt install libmysqlclient-dev
- ```
-
-### 3. Python包依赖
-
-```bash
-# 科学计算相关
-sudo apt install libatlas-base-dev # numpy依赖
-sudo apt install libopenblas-dev # 线性代数库
-sudo apt install liblapack-dev # 线性代数包
-sudo apt install gfortran # Fortran编译器(scipy依赖)
-
-# XML处理相关(用于Excel文件处理)
-sudo apt install libxml2-dev
-sudo apt install libxslt1-dev
-```
-
-## 二、部署运行
-
-### 1. 安装服务
-
-```bash
-sh scripts/install.sh
-```
-
-### 2. 启动服务
-
-```bash
-sh scripts/start.sh
-```
-
-### 3. 停止服务
-
-```bash
-sh scripts/stop.sh
-```
-
-## 三、维护说明
-
-### 1. 日志管理
-
-```bash
-# 后端日志
-tail -f logs/api.log
-
-# 数据库日志
-tail -f /var/log/mysql/error.log
-```
-
-## 四、安全建议
-
-1. 系统安全
- - 使用防火墙限制端口访问
- - 定期更新系统和依赖包
-
-2. 数据安全
- - 定期备份数据库
- - 加密敏感信息
- - 限制数据库远程访问
-
-3. 访问控制
- - 使用强密码
- - 配置适当的文件权限
- - 使用非root用户运行服务
-
-## 五、监控方案
-
-### 1. 系统监控
-
-```bash
-# 资源使用
-top -b -n 1
-df -h
-free -m
-
-# 服务状态
-ps aux | grep gunicorn
-ps aux | grep node
-```
-
-### 2. 应用监控
-
-```bash
-# API 响应时间
-curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
-
-# 错误日志
-grep "ERROR" logs/api.log
-```
diff --git a/deploy/equipment_cost_system/frontend/babel.config.js b/deploy/equipment_cost_system/frontend/babel.config.js
deleted file mode 100644
index e955840..0000000
--- a/deploy/equipment_cost_system/frontend/babel.config.js
+++ /dev/null
@@ -1,5 +0,0 @@
-module.exports = {
- presets: [
- '@vue/cli-plugin-babel/preset'
- ]
-}
diff --git a/deploy/equipment_cost_system/frontend/package.json b/deploy/equipment_cost_system/frontend/package.json
deleted file mode 100644
index 5e6409e..0000000
--- a/deploy/equipment_cost_system/frontend/package.json
+++ /dev/null
@@ -1,61 +0,0 @@
-{
- "name": "frontend",
- "version": "0.1.0",
- "private": true,
- "engines": {
- "node": ">=16",
- "npm": ">=8"
- },
- "scripts": {
- "serve": "vue-cli-service serve",
- "build": "vue-cli-service build",
- "lint": "vue-cli-service lint"
- },
- "dependencies": {
- "axios": "^1.6.0",
- "core-js": "^3.8.3",
- "echarts": "^5.4.3",
- "element-plus": "^2.4.2",
- "vue": "^3.2.13",
- "vue-router": "^4.0.3",
- "vuex": "^4.0.0"
- },
- "devDependencies": {
- "@babel/core": "^7.12.16",
- "@babel/eslint-parser": "^7.12.16",
- "@element-plus/icons-vue": "^2.3.1",
- "@vue/cli-plugin-babel": "~5.0.0",
- "@vue/cli-plugin-eslint": "~5.0.0",
- "@vue/cli-plugin-router": "~5.0.0",
- "@vue/cli-plugin-vuex": "~5.0.0",
- "@vue/cli-service": "~5.0.0",
- "@vue/compiler-sfc": "^3.2.13",
- "eslint": "^7.32.0",
- "eslint-plugin-vue": "^8.0.3",
- "sass": "^1.32.7",
- "sass-loader": "^12.0.0"
- },
- "eslintConfig": {
- "root": true,
- "env": {
- "node": true
- },
- "extends": [
- "plugin:vue/vue3-essential",
- "eslint:recommended"
- ],
- "parserOptions": {
- "parser": "@babel/eslint-parser"
- },
- "rules": {
- "vue/multi-word-component-names": "off",
- "no-unused-vars": "warn"
- }
- },
- "browserslist": [
- "> 1%",
- "last 2 versions",
- "not dead",
- "not ie 11"
- ]
-}
diff --git a/deploy/equipment_cost_system/frontend/public/favicon.ico b/deploy/equipment_cost_system/frontend/public/favicon.ico
deleted file mode 100644
index df36fcf..0000000
Binary files a/deploy/equipment_cost_system/frontend/public/favicon.ico and /dev/null differ
diff --git a/deploy/equipment_cost_system/frontend/public/index.html b/deploy/equipment_cost_system/frontend/public/index.html
deleted file mode 100644
index 3e5a139..0000000
--- a/deploy/equipment_cost_system/frontend/public/index.html
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
- <%= htmlWebpackPlugin.options.title %>
-
-
-
-
-
-
-
diff --git a/deploy/equipment_cost_system/frontend/src/App.vue b/deploy/equipment_cost_system/frontend/src/App.vue
deleted file mode 100644
index c4278c5..0000000
--- a/deploy/equipment_cost_system/frontend/src/App.vue
+++ /dev/null
@@ -1,42 +0,0 @@
-
-
-
-
- 首页
- 成本预测
- 特征分析
- 模型训练
- 模型管理
- 数据集管理
- 数据管理
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/deploy/equipment_cost_system/frontend/src/api/index.js b/deploy/equipment_cost_system/frontend/src/api/index.js
deleted file mode 100644
index 1090d1a..0000000
--- a/deploy/equipment_cost_system/frontend/src/api/index.js
+++ /dev/null
@@ -1,43 +0,0 @@
-import axios from 'axios'
-import { API_BASE_URL } from '@/config'
-
-const api = axios.create({
- baseURL: API_BASE_URL,
- timeout: 10000
-})
-
-export const predict = (data) => {
- return api.post('/predict', data)
-}
-
-export const analyzeFeatures = (data) => {
- return api.post('/analyze-features', data)
-}
-
-export const trainModel = (data) => {
- return api.post('/train', data)
-}
-
-export const evaluateModel = (data) => {
- return api.post('/evaluate', data)
-}
-
-export const importData = (formData) => {
- return api.post('/data/import', formData, {
- headers: {
- 'Content-Type': 'multipart/form-data'
- }
- })
-}
-
-export const getEquipmentData = () => {
- return api.get('/data')
-}
-
-export const updateEquipment = (id, data) => {
- return api.put(`/data/${id}`, data)
-}
-
-export const deleteEquipment = (id) => {
- return api.delete(`/data/${id}`)
-}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/assets/logo.png b/deploy/equipment_cost_system/frontend/src/assets/logo.png
deleted file mode 100644
index f3d2503..0000000
Binary files a/deploy/equipment_cost_system/frontend/src/assets/logo.png and /dev/null differ
diff --git a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css
deleted file mode 100644
index a9e81e1..0000000
--- a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css
+++ /dev/null
@@ -1,39 +0,0 @@
-/* 禁用 ResizeObserver 警告 */
-iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] {
- display: none !important;
-}
-
-.el-overlay {
- overflow: hidden !important;
-}
-
-/* 添加全局样式 */
-body {
- margin: 0;
- padding: 0;
- overflow-x: hidden;
-}
-
-/* 修复 Element Plus 的一些已知问题 */
-.el-dialog__wrapper {
- overflow: hidden !important;
-}
-
-.el-select-dropdown {
- overflow: hidden !important;
-}
-
-/* 禁用 ResizeObserver 相关的警告样式 */
-.resize-observer-warning {
- display: none !important;
-}
-
-/* 优化滚动行为 */
-* {
- scroll-behavior: smooth;
-}
-
-/* 防止页面抖动 */
-.el-main {
- overflow-x: hidden;
-}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/config.js b/deploy/equipment_cost_system/frontend/src/config.js
deleted file mode 100644
index 0a697ee..0000000
--- a/deploy/equipment_cost_system/frontend/src/config.js
+++ /dev/null
@@ -1,8 +0,0 @@
-export const API_BASE_URL = 'http://localhost:5001/api';
-
-export const DB_CONFIG = {
- host: 'localhost',
- user: 'root',
- password: '123456',
- database: 'equipment_cost_db'
-};
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/main.js b/deploy/equipment_cost_system/frontend/src/main.js
deleted file mode 100644
index 360a4a4..0000000
--- a/deploy/equipment_cost_system/frontend/src/main.js
+++ /dev/null
@@ -1,55 +0,0 @@
-import { createApp } from 'vue'
-import App from './App.vue'
-import router from './router'
-import store from './store'
-import ElementPlus from 'element-plus'
-import 'element-plus/dist/index.css'
-import './assets/styles/global.css'
-import * as ElementPlusIconsVue from '@element-plus/icons-vue'
-
-// 创建应用实例
-const app = createApp(App)
-
-// 注册插件
-app.use(ElementPlus, {
- size: 'default'
-})
-app.use(router)
-app.use(store)
-
-// 注册图标
-for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
- app.component(key, component)
-}
-
-// 全局错误处理
-app.config.errorHandler = (err) => {
- if (err.message && err.message.includes('ResizeObserver')) {
- return
- }
- console.error(err)
-}
-
-// 全局警告处理
-app.config.warnHandler = (msg, trace) => {
- if (msg.includes('ResizeObserver')) {
- return
- }
- console.warn(msg, trace)
-}
-
-// 挂载应用
-app.mount('#app')
-
-// 处理 ResizeObserver 错误
-const _ResizeObserver = window.ResizeObserver
-window.ResizeObserver = class ResizeObserver extends _ResizeObserver {
- constructor(callback) {
- super((entries, observer) => {
- requestAnimationFrame(() => {
- if (!Array.isArray(entries)) return
- callback(entries, observer)
- })
- })
- }
-}
diff --git a/deploy/equipment_cost_system/frontend/src/router/index.js b/deploy/equipment_cost_system/frontend/src/router/index.js
deleted file mode 100644
index 3dfcdee..0000000
--- a/deploy/equipment_cost_system/frontend/src/router/index.js
+++ /dev/null
@@ -1,52 +0,0 @@
-import { createRouter, createWebHistory } from 'vue-router'
-import HomePage from '@/views/HomePage.vue'
-import DataPage from '@/views/DataPage.vue'
-import DatasetPage from '@/views/DatasetPage.vue'
-import PredictPage from '@/views/PredictPage.vue'
-import AnalysisPage from '@/views/AnalysisPage.vue'
-import TrainingPage from '@/views/TrainingPage.vue'
-
-const routes = [
- {
- path: '/',
- name: 'Home',
- component: HomePage
- },
- {
- path: '/data',
- name: 'Data',
- component: DataPage
- },
- {
- path: '/datasets',
- name: 'Datasets',
- component: DatasetPage
- },
- {
- path: '/predict',
- name: 'Predict',
- component: PredictPage
- },
- {
- path: '/analysis',
- name: 'Analysis',
- component: AnalysisPage
- },
- {
- path: '/training',
- name: 'Training',
- component: TrainingPage
- },
- {
- path: '/models',
- name: 'Models',
- component: () => import('../views/ModelPage.vue')
- }
-]
-
-const router = createRouter({
- history: createWebHistory(),
- routes
-})
-
-export default router
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/store/index.js b/deploy/equipment_cost_system/frontend/src/store/index.js
deleted file mode 100644
index 0faaec7..0000000
--- a/deploy/equipment_cost_system/frontend/src/store/index.js
+++ /dev/null
@@ -1,14 +0,0 @@
-import { createStore } from 'vuex'
-
-export default createStore({
- state: {
- },
- getters: {
- },
- mutations: {
- },
- actions: {
- },
- modules: {
- }
-})
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js
deleted file mode 100644
index de54055..0000000
--- a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js
+++ /dev/null
@@ -1,73 +0,0 @@
-// 处理 ResizeObserver 错误
-const resizeHandler = () => {
- const resizeObserverErrors = [
- "ResizeObserver loop completed with undelivered notifications.",
- "ResizeObserver loop limit exceeded",
- "ResizeObserver loop completed"
- ]
-
- // 添加全局错误处理
- const handler = (event) => {
- if (event && event.message && resizeObserverErrors.includes(event.message)) {
- event.stopPropagation()
- event.preventDefault()
- event.stopImmediatePropagation()
- return false
- }
- }
-
- // 添加多个事件监听器
- window.addEventListener('error', handler, true)
- window.addEventListener('unhandledrejection', handler, true)
-
- // 添加 ResizeObserver 错误处理
- if (window.ResizeObserver) {
- const resizeObserverPrototype = ResizeObserver.prototype
- const originalObserve = resizeObserverPrototype.observe
-
- resizeObserverPrototype.observe = function (...args) {
- try {
- return originalObserve.apply(this, args)
- } catch (e) {
- if (resizeObserverErrors.includes(e.message)) {
- return null
- }
- throw e
- }
- }
- }
-}
-
-export default {
- install: (app) => {
- resizeHandler()
-
- // 添加全局错误处理器
- app.config.errorHandler = (err, vm, info) => {
- const resizeObserverErrors = [
- "ResizeObserver loop completed with undelivered notifications.",
- "ResizeObserver loop limit exceeded",
- "ResizeObserver loop completed"
- ]
-
- if (err && err.message && resizeObserverErrors.includes(err.message)) {
- return
- }
- console.error('Vue Error:', err, info)
- }
-
- // 添加全局警告处理器
- app.config.warnHandler = (msg, vm, trace) => {
- const resizeObserverErrors = [
- "ResizeObserver loop completed with undelivered notifications.",
- "ResizeObserver loop limit exceeded",
- "ResizeObserver loop completed"
- ]
-
- if (resizeObserverErrors.includes(msg)) {
- return
- }
- console.warn('Vue Warning:', msg, trace)
- }
- }
-}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue
deleted file mode 100644
index f8f8ba1..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue
+++ /dev/null
@@ -1,304 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{ selectedDataset.name }}
- {{ selectedDataset.equipment_count }}
- {{ selectedDataset.description }}
-
-
-
-
-
-
- {{ analyzing ? '分析中...' : '开始分析' }}
-
-
-
-
-
-
分析结果
-
-
-
特征重要性
-
-
-
-
相关性分析
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue
deleted file mode 100644
index 4689f70..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue
+++ /dev/null
@@ -1,725 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 详情
-
-
- 编辑
-
-
- 删除
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 详情
-
-
- 编辑
-
-
- 删除
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 基本信息
-
-
-
- {{ selectedData?.name }}
- {{ selectedData?.manufacturer }}
-
-
- {{ formatNumber(selectedData?.length_m) }}
- {{ formatNumber(selectedData?.width_m) }}
- {{ formatNumber(selectedData?.height_m) }}
- {{ formatNumber(selectedData?.weight_kg) }}
- {{ formatNumber(selectedData?.max_range_km) }}
-
-
-
-
-
-
- 火箭炮参数
-
-
-
- {{ formatNumber(selectedData?.rocket_diameter_mm) }}
- {{ formatNumber(selectedData?.rate_of_fire) }}
- {{ formatNumber(selectedData?.rocket_length_m) }}
- {{ formatNumber(selectedData?.rocket_weight_kg) }}
- {{ formatNumber(selectedData?.firing_angle_horizontal) }}
- {{ formatNumber(selectedData?.firing_angle_vertical) }}
- {{ selectedData?.mobility_type }}
- {{ selectedData?.structure_layout }}
- {{ selectedData?.engine_model }}
- {{ selectedData?.engine_params }}
- {{ formatNumber(selectedData?.power_hp) }}
- {{ formatNumber(selectedData?.travel_range_km) }}
-
-
-
-
-
-
- 巡飞弹参数
-
-
-
- {{ formatNumber(selectedData?.max_speed_ms) }}
- {{ formatNumber(selectedData?.cruise_speed_kmh) }}
- {{ formatNumber(selectedData?.flight_time_min) }}
- {{ selectedData?.warhead_type }}
- {{ selectedData?.launch_mode }}
- {{ selectedData?.power_system }}
- {{ selectedData?.guidance_system }}
-
-
-
-
- {{ console.log('Rendering custom params:', selectedData.custom_params) }}
-
-
- 特殊参数
-
-
-
-
- {{ formatCustomParamValue(param) }}
-
-
-
-
-
-
-
- 成本信息
-
-
-
-
- {{ formatMoney(selectedData?.actual_cost) }}
-
-
- {{ formatMoney(selectedData?.predicted_cost) }}
-
-
- {{ selectedData?.cost_estimate_date || '-' }}
-
-
-
-
-
-
-
-
-
-
- 基本信息
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 火箭炮参数
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 巡飞弹参数
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 特殊参数
-
-
-
- {{ param.param_unit }}
-
-
-
-
- {{ param.param_unit }}
-
-
-
-
-
-
- 成本信息
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue
deleted file mode 100644
index 1079611..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue
+++ /dev/null
@@ -1,322 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{ formatDateTime(scope.row.created_at) }}
-
-
-
-
- {{ formatDateTime(scope.row.updated_at) }}
-
-
-
-
- 查看
- 编辑
- 删除
-
-
-
-
-
-
-
-
- {{ selectedDataset?.equipment_type }}
- {{ selectedDataset?.purpose }}
- {{ selectedDataset?.description }}
-
-
-
-
- 统计信息
-
- {{ selectedDataset?.statistics?.equipment_count || 0 }}
- {{ formatMoney(selectedDataset?.statistics?.average_cost) }}
- {{ formatMoney(selectedDataset?.statistics?.total_cost) }}
-
-
-
-
-
- 包含装备
-
-
-
-
-
- {{ formatMoney(scope.row.actual_cost) }}
-
-
-
-
- 详情
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{ formatMoney(scope.row.actual_cost) }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue
deleted file mode 100644
index 60c0676..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue
+++ /dev/null
@@ -1,101 +0,0 @@
-
-
-
-
- 装备成本估算系统
-
-
-
-
-
- 成本预测
- 基于机器学习和 PLS 回归模型的成本预测
-
-
-
-
-
- 特征分析
- 分析参数对成本的影响
-
-
-
-
-
- 模型训练
- 训练和优化预测模型
-
-
-
-
-
- 模型管理
- 管理训练好的模型
-
-
-
-
-
- 数据集管理
- 管理训练和验证数据集
-
-
-
-
-
- 数据管理
- 管理装备数据和成本数据
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue
deleted file mode 100644
index e9632f9..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue
+++ /dev/null
@@ -1,279 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
- {{ getModelName(scope.row.model_type) }}
-
-
-
-
-
-
- {{ scope.row.r2_score.toFixed(4) }}
-
-
-
-
- {{ scope.row.mae.toFixed(2) }}
-
-
-
-
- {{ scope.row.rmse.toFixed(2) }}
-
-
-
-
- {{ formatDateTime(scope.row.training_date) }}
-
-
-
-
-
- {{ scope.row.is_active ? '使用中' : '未使用' }}
-
-
-
-
-
-
- 激活
-
-
- 详情
-
-
- 删除
-
-
-
-
-
-
-
-
-
- {{ selectedModel?.model_name }}
- {{ formatModelType(selectedModel?.model_type) }}
- {{ selectedModel?.equipment_type }}
- {{ selectedModel?.training_data_size }}
- {{ formatDateTime(selectedModel?.training_date) }}
-
-
- {{ selectedModel?.is_active ? '使用中' : '未使用' }}
-
-
-
-
-
-
- 评估指标
-
- {{ selectedModel?.r2_score.toFixed(4) }}
- {{ selectedModel?.mae.toFixed(2) }} 元
- {{ selectedModel?.rmse.toFixed(2) }} 元
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue
deleted file mode 100644
index a42fa9f..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue
+++ /dev/null
@@ -1,321 +0,0 @@
-
-
-
-
- 成本预测
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 预测成本
- 重置
-
-
-
-
-
-
预测结果
-
-
-
-
机器学习模型预测
-
-
- {{ getModelName(mlPrediction.model_info.type) }}
-
-
- {{ mlPrediction.model_info.name }}
-
-
- {{ formatMoney(mlPrediction.predicted_cost) }}
-
-
- {{ formatMoney(mlPrediction.confidence_interval.lower) }} ~
- {{ formatMoney(mlPrediction.confidence_interval.upper) }}
-
-
-
-
-
-
-
PLS回归预测
-
-
- {{ getModelName(plsPrediction.model_info.type) }}
-
-
- {{ plsPrediction.model_info.name }}
-
-
- {{ formatMoney(plsPrediction.predicted_cost) }}
-
-
- {{ formatMoney(plsPrediction.confidence_interval.lower) }} ~
- {{ formatMoney(plsPrediction.confidence_interval.upper) }}
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue
deleted file mode 100644
index edd20c1..0000000
--- a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue
+++ /dev/null
@@ -1,370 +0,0 @@
-
-
-
-
- 模型训练
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- PLS回归
- XGBoost
- LightGBM
- GBM
- Random Forest
-
-
-
-
-
- 开始训练
-
-
-
-
-
-
-
训练结果
-
-
-
-
最佳模型: {{ getModelName(trainingResult.best_model.type) }}
-
R²分数: {{ formatNumber(trainingResult.best_model.r2) }}
-
MAE: {{ formatNumber(trainingResult.best_model.mae) }} 元
-
RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} 元
-
-
-
-
-
-
- {{ getModelName(scope.row.model) }}
-
-
-
-
-
-
-
- {{ formatNumber(scope.row.train.r2) }}
-
-
-
-
- {{ formatNumber(scope.row.train.mae) }}
-
-
-
-
- {{ formatNumber(scope.row.train.rmse) }}
-
-
-
-
-
-
-
-
- {{ formatNumber(scope.row.validation.r2) }}
-
-
-
-
- {{ formatNumber(scope.row.validation.mae) }}
-
-
-
-
- {{ formatNumber(scope.row.validation.rmse) }}
-
-
-
-
-
-
-
-
特征重要性
-
-
-
-
- {{ formatNumber(scope.row.importance) }}
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/vue.config.js b/deploy/equipment_cost_system/frontend/vue.config.js
deleted file mode 100644
index b2a2970..0000000
--- a/deploy/equipment_cost_system/frontend/vue.config.js
+++ /dev/null
@@ -1,5 +0,0 @@
-module.exports = {
- devServer: {
- port: 8080
- }
-}
diff --git a/deploy/equipment_cost_system/requirements.txt b/deploy/equipment_cost_system/requirements.txt
deleted file mode 100644
index 6eec783..0000000
--- a/deploy/equipment_cost_system/requirements.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-flask==2.0.1
-flask-cors==3.0.10
-sqlalchemy==1.4.23
-pymysql==1.0.2
-cryptography==3.4.7 # MySQL 8.0+ 认证需要
-numpy==1.21.2
-pandas==1.3.3
-scikit-learn==0.24.2
-tensorflow==2.6.0
-urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
-openpyxl==3.1.2 # 用于读取 .xlsx 文件
-xlrd==2.0.1 # 用于读取 .xls 文件
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/scripts/install.sh b/deploy/equipment_cost_system/scripts/install.sh
deleted file mode 100644
index a7a80d6..0000000
--- a/deploy/equipment_cost_system/scripts/install.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-
-echo "开始安装装备成本估算系统..."
-
-# 检查Python版本
-python3 -V || {
- echo "错误: 需要 Python 3.8+"
- exit 1
-}
-
-# 检查Node.js版本
-node -v || {
- echo "错误: 需要 Node.js 14+"
- exit 1
-}
-
-# 创建必要的目录
-echo "创建系统目录..."
-mkdir -p {logs,data,models}
-
-# 安装后端依赖
-echo "安装后端依赖..."
-python3 -m venv venv
-source venv/bin/activate
-pip install -r requirements.txt
-
-# 安装前端依赖
-echo "安装前端依赖..."
-cd frontend
-npm install
-npm run build
-cd ..
-
-# 配置文件
-if [ ! -f config/.env ]; then
- echo "创建配置文件..."
- cp config/.env.template config/.env
- echo "请修改 config/.env 中的配置"
-fi
-
-# 初始化数据库
-echo "初始化数据库..."
-read -p "请输入MySQL root密码: " mysqlpass
-mysql -u root -p$mysqlpass < src/schema.sql
-
-# 导入测试数据(可选)
-read -p "是否导入测试数据?(y/n) " import_test_data
-if [ "$import_test_data" = "y" ]; then
- mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql
-fi
-
-# 设置权限
-echo "设置文件权限..."
-chmod +x scripts/*.sh
-chmod 755 logs models data
-chmod 600 config/.env
-
-echo "安装完成!"
-echo "请检查并修改 config/.env 中的配置"
-echo "使用 ./scripts/start.sh 启动服务"
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/scripts/start.sh b/deploy/equipment_cost_system/scripts/start.sh
deleted file mode 100644
index 5b28ac4..0000000
--- a/deploy/equipment_cost_system/scripts/start.sh
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/bin/bash
-
-echo "启动装备成本估算系统..."
-
-# 检查配置文件
-if [ ! -f config/.env ]; then
- echo "错误: 配置文件不存在"
- echo "请先运行 install.sh"
- exit 1
-fi
-
-# 检查日志目录
-if [ ! -d logs ]; then
- mkdir -p logs
-fi
-
-# 激活虚拟环境
-source venv/bin/activate
-
-# 导出环境变量
-export $(cat config/.env | xargs)
-
-# 启动后端服务
-echo "启动后端服务..."
-gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \
- --daemon \
- --pid gunicorn.pid \
- --access-logfile logs/access.log \
- --error-logfile logs/error.log
-
-# 等待后端服务启动
-sleep 2
-
-# 检查后端服务是否成功启动
-if ! curl -s http://localhost:5001/api/ > /dev/null; then
- echo "错误: 后端服务启动失败"
- exit 1
-fi
-
-# 启动前端服务
-echo "启动前端服务..."
-cd frontend
-npm run serve -- --port 8080 --host 0.0.0.0 &
-echo $! > ../frontend.pid
-cd ..
-
-echo "服务已启动!"
-echo "后端API: http://localhost:5001"
-echo "前端界面: http://localhost:8080"
-echo "查看后端日志: tail -f logs/access.log"
-echo "查看前端日志: tail -f logs/frontend.log"
diff --git a/deploy/equipment_cost_system/scripts/stop.sh b/deploy/equipment_cost_system/scripts/stop.sh
deleted file mode 100644
index 085d432..0000000
--- a/deploy/equipment_cost_system/scripts/stop.sh
+++ /dev/null
@@ -1,44 +0,0 @@
-#!/bin/bash
-
-echo "停止装备成本估算系统..."
-
-# 停止后端服务
-if [ -f gunicorn.pid ]; then
- echo "停止后端服务..."
- pid=$(cat gunicorn.pid)
- kill -TERM $pid
- rm gunicorn.pid
- echo "后端服务已停止 (PID: $pid)"
-else
- echo "未找到后端服务PID文件"
- # 尝试查找并停止所有gunicorn进程
- pkill -f gunicorn
- echo "已尝试停止所有gunicorn进程"
-fi
-
-# 停止前端服务
-if [ -f frontend.pid ]; then
- echo "停止前端服务..."
- pid=$(cat frontend.pid)
- kill -TERM $pid
- rm frontend.pid
- echo "前端服务已停止 (PID: $pid)"
-else
- echo "未找到前端服务PID文件"
- # 尝试查找并停止前端服务进程
- pkill -f "vite preview"
- echo "已尝试停止所有前端服务进程"
-fi
-
-# 检查是否还有相关进程在运行
-if pgrep -f gunicorn > /dev/null; then
- echo "警告: 仍有gunicorn进程在运行"
- ps aux | grep gunicorn | grep -v grep
-fi
-
-if pgrep -f "vite preview" > /dev/null; then
- echo "警告: 仍有前端服务进程在运行"
- ps aux | grep "vite preview" | grep -v grep
-fi
-
-echo "所有服务已停止"
diff --git a/deploy/equipment_cost_system/src/__init__.py b/deploy/equipment_cost_system/src/__init__.py
deleted file mode 100644
index 497b4a4..0000000
--- a/deploy/equipment_cost_system/src/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# 这个文件可以为空,但必须存在
diff --git a/deploy/equipment_cost_system/src/app.py b/deploy/equipment_cost_system/src/app.py
deleted file mode 100644
index 037fb79..0000000
--- a/deploy/equipment_cost_system/src/app.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from flask import Flask
-from flask_cors import CORS
-from .routes import api_bp
-from .logger import setup_logger
-import os
-
-# 获取logger
-logger = setup_logger(__name__)
-
-def create_app():
- """
- 创建并配置Flask应用
- """
- try:
- # 创建必要的目录
- os.makedirs('logs', exist_ok=True)
- os.makedirs('data', exist_ok=True)
- os.makedirs('models', exist_ok=True)
-
- logger.info("=== Server Starting ===")
- logger.info("Initializing directories...")
-
- # 创建Flask应用
- app = Flask(__name__)
-
- # 配置CORS
- CORS(app)
- logger.info("CORS enabled")
-
- # 注册API蓝图
- app.register_blueprint(api_bp, url_prefix='/api')
- logger.info("API blueprint registered")
-
- # 配置数据库连接
- app.config['MYSQL_HOST'] = 'localhost'
- app.config['MYSQL_USER'] = 'root'
- app.config['MYSQL_PASSWORD'] = '123456'
- app.config['MYSQL_DB'] = 'equipment_cost_db'
-
- logger.info("Starting server...")
-
- return app
-
- except Exception as e:
- logger.error(f"Error creating app: {str(e)}")
- raise
-
-if __name__ == '__main__':
- app = create_app()
- app.run(host='localhost', port=5001)
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/cost_prediction.py b/deploy/equipment_cost_system/src/cost_prediction.py
deleted file mode 100644
index 23ccce5..0000000
--- a/deploy/equipment_cost_system/src/cost_prediction.py
+++ /dev/null
@@ -1,342 +0,0 @@
-import numpy as np
-from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
-from sklearn.preprocessing import StandardScaler
-import tensorflow as tf
-from scipy import stats
-import joblib
-import os
-import pandas as pd
-from .feature_analysis import FeatureAnalysis
-import logging
-from src.model_trainer import ModelTrainer
-from src.database import get_db_connection
-from .logger import setup_logger
-
-logger = setup_logger(__name__)
-
-class CostPredictor:
- def __init__(self):
- self.scaler_X = StandardScaler()
- self.scaler_y = StandardScaler()
- self.model = None
- self.feature_analyzer = FeatureAnalysis()
- self.equipment_type = None
-
- # 添加 TensorFlow 配置
- tf.config.run_functions_eagerly(False) # 启用图执行模式
-
- # 创建预测函数
- @tf.function(reduce_retracing=True, jit_compile=True)
- def predict_fn(x):
- return self.model(x, training=False)
-
- self._predict_fn = predict_fn
-
- self.load_model()
-
- def load_model(self):
- """
- 加载预训练型和标准化器
- """
- try:
- model_dir = 'models'
- os.makedirs(model_dir, exist_ok=True)
-
- # 创建默认模型
- self._create_default_model()
-
- # 创建预测函数
- @tf.function(reduce_retracing=True, jit_compile=True)
- def predict_fn(x):
- return self.model(x, training=False)
-
- self._predict_fn = predict_fn
-
- except Exception as e:
- logging.error(f"Error loading model: {str(e)}")
- self._create_default_model()
-
- def _create_default_model(self):
- """
- 创建默认模型并进行初始化训练
- """
- # 创建输入层
- inputs = tf.keras.Input(shape=(11,))
-
- # 创建隐藏层
- x = tf.keras.layers.Dense(64, activation='relu')(inputs)
- x = tf.keras.layers.Dense(32, activation='relu')(x)
-
- # 创建输出层
- outputs = tf.keras.layers.Dense(1)(x)
-
- # 创建模型
- self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
-
- # 编译模型
- self.model.compile(
- optimizer='adam',
- loss=tf.keras.losses.mean_squared_error,
- metrics=[tf.keras.metrics.mean_absolute_error]
- )
-
- # 创建示例数据
- example_data = pd.DataFrame({
- 'length_m': [7.35, 10.2],
- 'width_m': [2.4, 2.8],
- 'height_m': [3.1, 3.2],
- 'weight_kg': [13700, 28500],
- 'max_range_km': [20.4, 70],
- 'firing_angle_horizontal': [102, 110],
- 'firing_angle_vertical': [55, 60],
- 'rocket_length_m': [2.87, 4.1],
- 'rocket_diameter_mm': [122, 220],
- 'rocket_weight_kg': [66.6, 150],
- 'rate_of_fire': [40, 60]
- })
-
- # 训练标准化器
- self.scaler_X.fit(example_data)
- self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
-
- # 设置默认装备类型
- self.equipment_type = '火箭炮'
-
- def _create_example_data(self):
- """
- 创建示例数据来训练标准化器
- """
- # 火箭炮示例数据
- rocket_data = pd.DataFrame({
- 'length_m': [7.35, 10.2],
- 'width_m': [2.4, 2.8],
- 'height_m': [3.1, 3.2],
- 'weight_kg': [13700, 28500],
- 'max_range_km': [20.4, 70],
- 'firing_angle_horizontal': [102, 110],
- 'firing_angle_vertical': [55, 60],
- 'rocket_length_m': [2.87, 4.1],
- 'rocket_diameter_mm': [122, 220],
- 'rocket_weight_kg': [66.6, 150],
- 'rate_of_fire': [40, 60]
- })
-
- # 巡飞弹示例数据
- missile_data = pd.DataFrame({
- 'length_m': [1.3, 2.5],
- 'width_m': [0.23, 0.6],
- 'height_m': [0.23, 0.6],
- 'weight_kg': [12.5, 135],
- 'max_range_km': [40, 250],
- 'max_speed_kmh': [180, 185],
- 'cruise_speed_kmh': [100, 110],
- 'flight_time_min': [60, 120],
- 'folded_length_mm': [1300, 2500],
- 'folded_width_mm': [230, 600],
- 'folded_height_mm': [230, 600]
- })
-
- # 训练标准化器
- self.scaler_X.fit(rocket_data) # 使用火箭炮数据
- self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
-
- # 设置默认装备类型
- self.equipment_type = '火箭炮'
-
- def predict(self, data):
- """
- 使用训练好的最优模型进行预测
- """
- try:
- logger.info(f"Starting prediction for {data.get('type')}")
- equipment_type = data.get('type')
-
- # 加载已训练的最优模型
- trainer = ModelTrainer()
- if not trainer.load_model(equipment_type):
- raise ValueError(f"No trained model found for {equipment_type}")
-
- # 准备特征数据
- features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- y_pred = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
-
- # 获取模型类型
- model_type = trainer.get_model_type()
-
- return {
- 'predicted_cost': float(y_pred[0]),
- 'model_type': model_type, # 返回使用的模型类型
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- except Exception as e:
- logger.error(f"Prediction error: {str(e)}")
- raise
-
- def _calculate_confidence_interval(self, prediction, confidence=0.95):
- """
- 计算预测值的置信区间
- """
- try:
- # 使用预测值的20%作为标准差(增加不确定性)
- std = abs(prediction) * 0.2
-
- # 计算置信区间
- from scipy import stats
- interval = stats.norm.interval(confidence, loc=prediction, scale=std)
-
- # 确保区间值为正数且合理
- lower = max(1000, interval[0]) # 最小值设为1000元
- upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
-
- logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
-
- return [lower, upper]
-
- except Exception as e:
- logging.error(f"Error calculating confidence interval: {str(e)}")
- # 如果计算失败,返回基于20%的简单区间
- lower = max(1000, prediction * 0.8)
- upper = prediction * 1.2
- return [lower, upper]
-
- def evaluate(self, y_true, y_pred):
- """
- 模型评估
- """
- return {
- 'mae': float(mean_absolute_error(y_true, y_pred)),
- 'mse': float(mean_squared_error(y_true, y_pred)),
- 'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
- 'r2': float(r2_score(y_true, y_pred))
- }
-
- def predict_pls(self, data):
- """
- 使用 PLS ���型预测成本
- """
- try:
- logger.info(f"Starting PLS prediction for {data.get('type')}")
- equipment_type = data.get('type')
-
- # 加载 PLS 模型
- trainer = ModelTrainer()
- if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
- raise ValueError(f"No trained PLS model found for {equipment_type}")
-
- # 准备特征数据
- features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- y_pred = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
-
- return {
- 'predicted_cost': float(y_pred[0]),
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- except Exception as e:
- logger.error(f"PLS prediction error: {str(e)}")
- raise
-
- def predict_all(self, data):
- """
- 使用所有可用模型进行预测
- """
- try:
- logger.info(f"Starting multi-model prediction for {data.get('type')}")
- equipment_type = data.get('type')
- results = {}
-
- # 1. 获取所有激活的模型
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT id, model_type, model_name, r2_score, mae, rmse
- FROM trained_models
- WHERE equipment_type = %s AND is_active = TRUE
- """, (equipment_type,))
- active_models = cursor.fetchall()
-
- if not active_models:
- raise ValueError(f"No active models found for {equipment_type}")
-
- # 2. 使用每个模型进行预测
- trainer = ModelTrainer()
- for model_info in active_models:
- try:
- # 加载特定模型
- if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
- logger.warning(f"Failed to load model: {model_info['model_name']}")
- continue
-
- # 准备特征数据
- features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- y_pred = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
-
- # 保存结果
- results[model_info['model_type']] = {
- 'predicted_cost': float(y_pred[0]),
- 'model_info': {
- 'name': model_info['model_name'],
- 'type': model_info['model_type'],
- 'r2_score': float(model_info['r2_score']),
- 'mae': float(model_info['mae']),
- 'rmse': float(model_info['rmse'])
- },
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- except Exception as e:
- logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
- continue
-
- if not results:
- raise ValueError("No successful predictions from any model")
-
- # 3. 计算综合预测结果
- all_predictions = [result['predicted_cost'] for result in results.values()]
- ensemble_prediction = float(np.mean(all_predictions))
- prediction_std = float(np.std(all_predictions))
-
- # 4. 返回所有结果
- return {
- 'individual_predictions': results,
- 'ensemble_prediction': {
- 'predicted_cost': ensemble_prediction,
- 'standard_deviation': prediction_std,
- 'confidence_interval': {
- 'lower': float(ensemble_prediction - 1.96 * prediction_std),
- 'upper': float(ensemble_prediction + 1.96 * prediction_std)
- }
- }
- }
-
- except Exception as e:
- logger.error(f"Error in multi-model prediction: {str(e)}")
- raise
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/create_template.py b/deploy/equipment_cost_system/src/create_template.py
deleted file mode 100644
index 0f7d545..0000000
--- a/deploy/equipment_cost_system/src/create_template.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import pandas as pd
-import openpyxl
-from openpyxl.styles import PatternFill, Font, Alignment
-from openpyxl.worksheet.datavalidation import DataValidation
-import os
-from .logger import setup_logger
-
-logger = setup_logger(__name__)
-
-def create_excel_template():
- """
- 创建数据模板
- """
- try:
- # 确保data目录存在
- data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
- os.makedirs(data_dir, exist_ok=True)
-
- # 创建完整的文件路径
- template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')
-
- # 创建 Excel 写入器
- writer = pd.ExcelWriter(template_path, engine='openpyxl')
-
- # 火箭炮基本参数表
- rocket_artillery_columns = [
- '名称', '类型', '制造商', '口径_mm',
- '反射管数量', '乘员数', '总长_m',
- '宽度_m', '高度_m', '重量_kg',
- '战斗重_kg', '速度_km/h', '最大射程_km',
- '最小射程_km', '方向射界_度', '高低射界_度',
- '火箭弹长度_m', '火箭弹重量_kg',
- '火箭弹最大速度_m/s', '射速_发',
- '战斗部重量_kg', '行走方式',
- '结构布局', '发动机型号', '发动机参数',
- '功率_hp', '行程_km', '成本_元'
- ]
-
- # 巡飞弹基本参数表
- loitering_munition_columns = [
- '名称', '类型', '制造商', '目标类型',
- '弹长_m', '弹径_mm', '翼展_m',
- '重量_kg', '战斗部重量_kg',
- '最大射程_km', '最大速度_m/s',
- '巡航速度_kmh', '巡飞时间_min',
- '战斗部类型', '发射方式',
- '折叠长度_mm', '折叠宽度_mm',
- '折叠高度_mm', '动力装置',
- '制导体制', '成本_元'
- ]
-
- # 特殊参数表
- special_params_columns = [
- '装备名称', # 关联字段
- '参数名称',
- '参数值',
- '参数单位',
- '参数说明'
- ]
-
- # 创建工作表
- pd.DataFrame(columns=rocket_artillery_columns).to_excel(
- writer, sheet_name='火箭炮', index=False
- )
- pd.DataFrame(columns=loitering_munition_columns).to_excel(
- writer, sheet_name='巡飞弹', index=False
- )
- pd.DataFrame(columns=special_params_columns).to_excel(
- writer, sheet_name='特殊参数', index=False
- )
-
- # 获取工作簿
- workbook = writer.book
-
- # 设置火箭炮工作表格式
- rocket_sheet = workbook['火箭炮']
- for col in range(1, len(rocket_artillery_columns) + 1):
- cell = rocket_sheet.cell(row=1, column=col)
- cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
- cell.font = Font(bold=True)
- cell.alignment = Alignment(horizontal='center')
-
- # 设置巡飞弹工作表格式
- missile_sheet = workbook['巡飞弹']
- for col in range(1, len(loitering_munition_columns) + 1):
- cell = missile_sheet.cell(row=1, column=col)
- cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
- cell.font = Font(bold=True)
- cell.alignment = Alignment(horizontal='center')
-
- # 设置特殊参数工作表格式
- special_sheet = workbook['特殊参数']
- for col in range(1, len(special_params_columns) + 1):
- cell = special_sheet.cell(row=1, column=col)
- cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
- cell.font = Font(bold=True)
- cell.alignment = Alignment(horizontal='center')
-
- # 添加数据验证
- for sheet in [rocket_sheet, missile_sheet]:
- # 数值验证
- number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
- number_validation.error = "请输入大于0的数值"
- number_validation.errorTitle = "输入错误"
-
- # 应用到相应列
- for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']:
- number_validation.add(f"{col}2:{col}1000")
-
- sheet.add_data_validation(number_validation)
-
- # 添加说明
- rocket_sheet['AD1'] = "填写说明:"
- rocket_sheet['AD2'] = "1. 所有数值必须大于0"
- rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
- rocket_sheet['AD4'] = "3. 成本单位为元"
-
- missile_sheet['V1'] = "填写说明:"
- missile_sheet['V2'] = "1. 所有数值必须大于0"
- missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
- missile_sheet['V4'] = "3. 成本单位为元"
-
- special_sheet['G1'] = "填写说明:"
- special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
- special_sheet['G3'] = "2. 参数值必须包含单位"
- special_sheet['G4'] = "3. 参数说明应简明扼要"
-
- # 调整列宽
- for sheet in [rocket_sheet, missile_sheet, special_sheet]:
- for col in sheet.columns:
- max_length = 0
- column = col[0].column_letter
- for cell in col:
- try:
- if len(str(cell.value)) > max_length:
- max_length = len(str(cell.value))
- except:
- pass
- adjusted_width = (max_length + 2)
- sheet.column_dimensions[column].width = adjusted_width
-
- # 保存文件
- writer.close()
-
- return template_path
-
- except Exception as e:
- raise Exception(f"创建模板文件失败: {str(e)}")
-
-if __name__ == "__main__":
- try:
- template_path = create_excel_template()
- print(f"模板文件已创建: {template_path}")
- except Exception as e:
- print(f"错误: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/data_preparation.py b/deploy/equipment_cost_system/src/data_preparation.py
deleted file mode 100644
index 6380647..0000000
--- a/deploy/equipment_cost_system/src/data_preparation.py
+++ /dev/null
@@ -1,233 +0,0 @@
-from sklearn.preprocessing import StandardScaler
-from datetime import datetime
-import os
-import joblib
-import pandas as pd
-import numpy as np
-from src.feature_analysis import FeatureAnalysis
-from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
-from xgboost import XGBRegressor
-from lightgbm import LGBMRegressor
-from sklearn.model_selection import cross_val_score, LeaveOneOut
-import json
-import logging
-from src.database.db_connection import get_db_connection
-from sklearn.metrics import mean_absolute_error, mean_squared_error
-from .logger import setup_logger
-
-logger = setup_logger(__name__)
-
-class DataPreparation:
- def __init__(self):
- self.feature_analyzer = FeatureAnalysis()
- self.feature_scaler = StandardScaler()
- self.target_scaler = StandardScaler() # 添加目标值标准化器
-
- def prepare_training_data(self, equipment_data, equipment_type):
- """
- 准备训练数据
- """
- try:
- logger.info(f"Preparing training data for {equipment_type}")
- logger.info(f"Raw data size: {len(equipment_data)}")
-
- # 如果输入已经是 numpy 数组,直接返回
- if isinstance(equipment_data, np.ndarray):
- X = equipment_data
- logger.info(f"Input is already numpy array with shape: {X.shape}")
-
- # 处理无效值
- X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
-
- return {
- 'X': X,
- 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
- 'feature_scaler': self.feature_scaler,
- 'target_scaler': self.target_scaler
- }
-
- # 从原始数据中提取特征和目标值
- feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
- features = []
- targets = []
-
- for item in equipment_data:
- # 提取特征值
- feature_values = []
- for name in feature_names:
- value = item.get(name)
- try:
- feature_values.append(float(value) if value is not None else 0.0)
- except (ValueError, TypeError):
- feature_values.append(0.0)
- features.append(feature_values)
-
- # 提取目标值(成本)
- try:
- cost = float(item['actual_cost'])
- if cost > 0: # 只使用正数成本值
- targets.append(cost)
- else:
- logger.warning(f"Skipping non-positive cost value: {cost}")
- except (ValueError, TypeError, KeyError):
- logger.error(f"Invalid cost value: {item.get('actual_cost')}")
- continue
-
- # 转换为numpy数组
- X = np.array(features, dtype=float)
- y = np.array(targets, dtype=float)
-
- # 记录原始数据范围
- logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
- logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
-
- # 标准化特征和目标值
- X_scaled = self.feature_scaler.fit_transform(X)
- y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
-
- # 记录标准化后的数据范围
- logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
- logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
-
- # 记录标准化器参数
- logger.info("Feature scaler params:")
- logger.info(f"Mean: {self.feature_scaler.mean_}")
- logger.info(f"Scale: {self.feature_scaler.scale_}")
-
- logger.info("Target scaler params:")
- logger.info(f"Mean: {self.target_scaler.mean_}")
- logger.info(f"Scale: {self.target_scaler.scale_}")
-
- return {
- 'X': X_scaled,
- 'y': y_scaled,
- 'feature_names': feature_names,
- 'feature_scaler': self.feature_scaler,
- 'target_scaler': self.target_scaler
- }
-
- except Exception as e:
- logger.error(f"Error in data preparation: {str(e)}")
- raise Exception(f"Training error: {str(e)}")
-
- def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
- """
- 准备验证数据
- """
- try:
- logger.info(f"Preparing validation data for {equipment_type}")
- logger.info(f"Raw validation data size: {len(validation_data)}")
-
- # 如果输入已经是 numpy 数组,直接使用
- if isinstance(validation_data, np.ndarray):
- X = validation_data
- logger.info(f"Input is already numpy array with shape: {X.shape}")
-
- # 处理无效值
- X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
-
- # 使用训练数据的标准化器
- if scalers and 'feature_scaler' in scalers:
- X_scaled = scalers['feature_scaler'].transform(X)
- else:
- X_scaled = X
-
- logger.info(f"Preprocessed data shape: {X_scaled.shape}")
- logger.info(f"Validation features shape: {X_scaled.shape}")
- logger.info(f"Validation features type: {X_scaled.dtype}")
-
- return {
- 'X': X_scaled,
- 'y': None # 验证数据可能没有标签
- }
-
- # 从原始数据中提取特征和目标值
- if not feature_names:
- feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
-
- features = []
- targets = []
-
- for item in validation_data:
- # 提取特征值
- feature_values = []
- for name in feature_names:
- value = item.get(name)
- try:
- feature_values.append(float(value) if value is not None else 0.0)
- except (ValueError, TypeError):
- feature_values.append(0.0)
- features.append(feature_values)
-
- # 提取目标值(成本)并验证范围
- try:
- cost = float(item['actual_cost'])
- logger.info(f"Raw cost value: {cost}")
- if cost > 0: # 只使用正数成本值
- targets.append(cost)
- else:
- logger.warning(f"Skipping non-positive cost value: {cost}")
- except (ValueError, TypeError):
- logger.error(f"Invalid cost value: {item.get('actual_cost')}")
- continue
-
- # 转换为numpy数组
- X = np.array(features, dtype=float)
- y = np.array(targets, dtype=float)
-
- # 记录数据范围
- logger.info(f"Features range: min={X.min()}, max={X.max()}")
- logger.info(f"Targets range: min={y.min()}, max={y.max()}")
-
- # 处理无效值
- X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
-
- # 使用训练数据的标准化器
- if scalers and 'feature_scaler' in scalers:
- X_scaled = scalers['feature_scaler'].transform(X)
- if 'target_scaler' in scalers:
- y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel()
- else:
- y_scaled = y
- else:
- X_scaled = X
- y_scaled = y
-
- logger.info(f"Preprocessed data shape: {X_scaled.shape}")
- logger.info(f"Validation features shape: {X_scaled.shape}")
- logger.info(f"Validation features type: {X_scaled.dtype}")
-
- # 记录标准化后的数据范围
- logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}")
- logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}")
-
- # 确保特征维度一致
- if not feature_names:
- feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
-
- logger.info(f"Expected features: {len(feature_names)}")
- logger.info(f"Actual features: {X_scaled.shape[1]}")
-
- if X_scaled.shape[1] != len(feature_names):
- raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}")
-
- return {
- 'X': X_scaled,
- 'y': y_scaled # 返回标准化后的成本值
- }
-
- except Exception as e:
- logger.error(f"Error in validation data preparation: {str(e)}")
- logger.error(f"Feature names: {feature_names}")
- logger.error(f"Equipment type: {equipment_type}")
- raise Exception(f"Validation error: {str(e)}")
-
- def calculate_derived_features(self, data, equipment_type):
- """
- 计算衍生特征
- """
- try:
- return self.feature_analyzer.calculate_derived_features(data, equipment_type)
- except Exception as e:
- logger.error(f"Error calculating derived features: {str(e)}")
- raise Exception(f"Feature calculation error: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/database/__init__.py b/deploy/equipment_cost_system/src/database/__init__.py
deleted file mode 100644
index 6df49c9..0000000
--- a/deploy/equipment_cost_system/src/database/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .db_connection import get_db_connection
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/database/db_connection.py b/deploy/equipment_cost_system/src/database/db_connection.py
deleted file mode 100644
index 361a4d2..0000000
--- a/deploy/equipment_cost_system/src/database/db_connection.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import mysql.connector
-from mysql.connector import Error
-from contextlib import contextmanager
-import os
-from dotenv import load_dotenv
-from ..logger import setup_logger
-
-# 获取logger
-logger = setup_logger(__name__)
-
-# 加载环境变量
-load_dotenv()
-
-@contextmanager
-def get_db_connection():
- """
- 数据库连接上下文管理器
- """
- connection = None
- try:
- connection = mysql.connector.connect(
- host=os.getenv('MYSQL_HOST', 'localhost'),
- user=os.getenv('MYSQL_USER', 'root'),
- password=os.getenv('MYSQL_PASSWORD', '123456'),
- database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
- )
- logger.info("Database connection established")
- yield connection
-
- except Error as e:
- logger.error(f"Error connecting to MySQL: {str(e)}")
- raise
-
- finally:
- if connection and connection.is_connected():
- connection.close()
- logger.info("Database connection closed")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/feature_analysis.py b/deploy/equipment_cost_system/src/feature_analysis.py
deleted file mode 100644
index 5b87972..0000000
--- a/deploy/equipment_cost_system/src/feature_analysis.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import numpy as np
-import pandas as pd
-from scipy import stats
-from sklearn.preprocessing import StandardScaler
-from sklearn.ensemble import RandomForestRegressor
-from sklearn.metrics import r2_score
-import logging
-from .logger import setup_logger
-
-logger = setup_logger(__name__)
-
-class FeatureAnalysis:
- def __init__(self):
- self.scaler = StandardScaler()
- self.important_features = []
-
- # 添加特征名称映射
- self.feature_names_map = {
- # 通用参数
- 'length_m': '总长(m)',
- 'width_m': '宽度(m)',
- 'height_m': '高度(m)',
- 'weight_kg': '重量(kg)',
- 'max_range_km': '最大射程(km)',
-
- # 火箭炮特有参数
- 'firing_angle_horizontal': '方向射界(度)',
- 'firing_angle_vertical': '高低射界(度)',
- 'rocket_length_m': '火箭弹长度(m)',
- 'rocket_diameter_mm': '口径(mm)',
- 'rocket_weight_kg': '火箭弹重量(kg)',
- 'rate_of_fire': '射速(发/分)',
- 'combat_weight_kg': '战斗重量(kg)',
- 'speed_kmh': '速度(km/h)',
- 'min_range_km': '最小射程(km)',
- 'power_hp': '功率(hp)',
-
- # 火箭炮衍生特征
- 'fire_density': '火力密度',
- 'mobility_index': '机动性指标',
- 'range_ratio': '射程比',
- 'power_weight_ratio': '功重比',
- 'volume_density': '体积密度',
-
- # 巡飞弹特有参数
- 'wingspan_m': '翼展(m)',
- 'warhead_weight_kg': '战斗部重量(kg)',
- 'max_speed_ms': '最大速度(m/s)',
- 'cruise_speed_kmh': '巡航速度(km/h)',
- 'flight_time_min': '巡飞时间(min)',
- 'folded_length_mm': '折叠长度(mm)',
- 'folded_width_mm': '折叠宽度(mm)',
- 'folded_height_mm': '折叠高度(mm)',
-
- # 巡飞弹衍生特征
- 'warhead_ratio': '战斗部比重',
- 'speed_ratio': '速度比',
- 'range_time_ratio': '射程时间比',
- 'aspect_ratio': '展弦比',
- 'volume_density': '体积密度'
- }
-
- def get_equipment_specific_features(self, equipment_type):
- """
- 获取特定装备类型的特征列表
- """
- # 通用参数
- common_features = [
- 'length_m', # 总长(m)
- 'width_m', # 宽度(m)
- 'height_m', # 高度(m)
- 'weight_kg', # 重量(kg)
- 'max_range_km' # 最大射程(km)
- ]
-
- if equipment_type == '火箭炮':
- # 火箭炮特有参数
- specific_features = [
- 'firing_angle_horizontal', # 方向射界(度)
- 'firing_angle_vertical', # 高低射界(度)
- 'rocket_length_m', # 火箭弹长度(m)
- 'rocket_diameter_mm', # 口径(mm)
- 'rocket_weight_kg', # 火箭弹重量(kg)
- 'rate_of_fire', # 射速(发/分)
- 'combat_weight_kg', # 战斗重量(kg)
- 'speed_kmh', # 速度(km/h)
- 'min_range_km', # 最小射程(km)
- 'power_hp' # 功率(hp)
- ]
-
- # 火箭炮衍生特征
- derived_features = [
- 'fire_density', # 火力密度 = 射速 * 火箭弹重量
- 'mobility_index', # 机动性指标 = 速度 / 战斗重量
- 'range_ratio', # 射程比 = 最大射程 / 最小射程
- 'power_weight_ratio', # 功重比 = 功率 / 战斗重量
- 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
- ]
-
- return common_features + specific_features + derived_features
-
- else: # 巡飞弹
- # 巡飞弹特有参数
- specific_features = [
- 'wingspan_m', # 翼展(m)
- 'warhead_weight_kg', # 战斗部重量(kg)
- 'max_speed_ms', # 最大速度(m/s)
- 'cruise_speed_kmh', # 巡航速度(km/h)
- 'flight_time_min', # 巡飞时间(min)
- 'folded_length_mm', # 折叠长度(mm)
- 'folded_width_mm', # 折叠宽度(mm)
- 'folded_height_mm' # 折叠高度(mm)
- ]
-
- # 巡飞弹衍生特征
- derived_features = [
- 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
- 'speed_ratio', # 速度比 = 巡航速度 / 最大速度
- 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
- 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
- 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
- ]
-
- return common_features + specific_features + derived_features
-
- def calculate_derived_features(self, data, equipment_type):
- """
- 计算衍生特征
- """
- try:
- if equipment_type == '火箭炮':
- # 火箭炮衍生特征计算
- if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
- data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
- else:
- data['fire_density'] = 0 # 或者其他默认值
-
- if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
- data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
- else:
- data['mobility_index'] = 0
-
- if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
- data['range_ratio'] = data['max_range_km'] / data['min_range_km']
- else:
- data['range_ratio'] = 0
-
- if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
- data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
- else:
- data['power_weight_ratio'] = 0
-
- if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
- data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
- else:
- data['volume_density'] = 0
-
- else: # 巡飞弹
- # 巡飞弹衍生特征计算
- if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
- data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
- else:
- data['warhead_ratio'] = 0
-
- if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
- data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
- else:
- data['speed_ratio'] = 0
-
- if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
- data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
- else:
- data['range_time_ratio'] = 0
-
- if 'wingspan_m' in data.columns and 'length_m' in data.columns:
- data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
- else:
- data['aspect_ratio'] = 0
-
- if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
- data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
- else:
- data['volume_density'] = 0
-
- return data
-
- except Exception as e:
- logger.error(f"Error calculating derived features: {str(e)}")
- raise
-
- def analyze_features(self, features, target, feature_names):
- """
- 分析特征重要性和相关性
- """
- try:
- # 转换为numpy数组
- X = np.array(features)
- y = np.array(target)
-
- # 数据标准化
- X_scaled = self.scaler.fit_transform(X)
-
- # 特征重要性分析
- rf = RandomForestRegressor(n_estimators=100, random_state=42)
- rf.fit(X_scaled, y)
- importances = rf.feature_importances_
-
- # 按重要性排序,使用中文特征名
- importance_indices = np.argsort(importances)[::-1]
- important_features = [
- {
- 'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
- 'importance': float(importances[i])
- }
- for i in importance_indices
- ]
-
- # 相关性分析
- df = pd.DataFrame(X_scaled, columns=feature_names)
- correlation_matrix = df.corr().values
-
- # 生成相关性分析数据,保留2位小数
- correlation_data = []
- chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
- for i in range(len(feature_names)):
- for j in range(len(feature_names)):
- correlation_data.append([
- i, j,
- round(correlation_matrix[i][j], 2) # 修改为保留2位小数
- ])
-
- return {
- 'important_features': important_features,
- 'correlation_analysis': {
- 'features': chinese_feature_names, # 使用中文特征名
- 'matrix': correlation_data
- }
- }
-
- except Exception as e:
- logger.error(f"Error in feature analysis: {str(e)}")
- raise
-
- def preprocess_features(self, equipment_data, equipment_type):
- """
- 预处理特征数据
- """
- try:
- # 转换为 DataFrame
- df = pd.DataFrame(equipment_data)
-
- # 计算衍生特征
- df = self.calculate_derived_features(df, equipment_type)
-
- # 处理缺失值
- numeric_columns = df.select_dtypes(include=[np.number]).columns
- for col in numeric_columns:
- # 转换为数值类型
- df[col] = pd.to_numeric(df[col], errors='coerce')
- # 使用新的方式填充缺失值
- mean_value = df[col].mean()
- df[col] = df[col].fillna(mean_value)
-
- logger.info(f"Preprocessed data shape: {df.shape}")
- return df
-
- except Exception as e:
- logger.error(f"Error preprocessing features: {str(e)}")
- raise Exception(f"Feature preprocessing error: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/import_data.py b/deploy/equipment_cost_system/src/import_data.py
deleted file mode 100644
index 03dfa36..0000000
--- a/deploy/equipment_cost_system/src/import_data.py
+++ /dev/null
@@ -1,255 +0,0 @@
-import pandas as pd
-from .logger import setup_logger
-from src.database.db_connection import get_db_connection
-
-logger = setup_logger(__name__)
-
-def import_training_data(excel_file):
- """
- 从Excel导入训练数据到数据库
- """
- try:
- # 读取所有sheet
- rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
- missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
- special_df = pd.read_excel(excel_file, sheet_name='特殊参数')
-
- # 记录所有装备名称,用于后续检查
- equipment_names = set()
-
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 1. 先导入火箭炮数据
- logger.info("开始导入火箭炮数据...")
- for _, row in rocket_df.iterrows():
- equipment_names.add(row['名称'])
- # 检查是否已存在相同名称的装备
- cursor.execute("""
- SELECT id FROM equipment
- WHERE name = %s AND type = '火箭炮'
- """, (row['名称'],))
-
- existing_equipment = cursor.fetchone()
- if existing_equipment:
- logger.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入")
- continue
-
- # 插入基本信息
- cursor.execute("""
- INSERT INTO equipment (name, type, manufacturer)
- VALUES (%s, %s, %s)
- """, (row['名称'], '火箭炮', row['制造商']))
-
- equipment_id = cursor.lastrowid
-
- # 插入通用参数
- cursor.execute("""
- INSERT INTO common_params
- (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
- VALUES (%s, %s, %s, %s, %s, %s)
- """, (
- equipment_id,
- row['总长_m'] if pd.notna(row['总长_m']) else None,
- row['宽度_m'] if pd.notna(row['宽度_m']) else None,
- row['高度_m'] if pd.notna(row['高度_m']) else None,
- row['重量_kg'] if pd.notna(row['重量_kg']) else None,
- row['最大射程_km'] if pd.notna(row['最大射程_km']) else None
- ))
-
- # 插入火箭炮特有参数
- cursor.execute("""
- INSERT INTO rocket_artillery_params
- (equipment_id, firing_angle_horizontal, firing_angle_vertical,
- rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
- combat_weight_kg, speed_kmh, min_range_km, mobility_type,
- structure_layout, engine_model, engine_params, power_hp,
- travel_range_km)
- VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
- """, (
- equipment_id,
- row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
- row['高低射界_度'] if pd.notna(row['高低射界_度']) else None,
- row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None,
- row['口径_mm'] if pd.notna(row['口径_mm']) else None,
- row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None,
- row['射速_发'] if pd.notna(row['射速_发']) else None,
- row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None,
- row['速度_km/h'] if pd.notna(row['速度_km/h']) else None,
- row['最小射程_km'] if pd.notna(row['最小射程_km']) else None,
- row['行走方式'] if pd.notna(row['行走方式']) else None,
- row['结构布局'] if pd.notna(row['结构布局']) else None,
- row['发动机型号'] if pd.notna(row['发动机型号']) else None,
- row['发动机参数'] if pd.notna(row['发动机参数']) else None,
- row['功率_hp'] if pd.notna(row['功率_hp']) else None,
- row['行程_km'] if pd.notna(row['行程_km']) else None
- ))
-
- # 插入成本数据
- if pd.notna(row['成本_元']):
- cursor.execute("""
- INSERT INTO cost_data (equipment_id, actual_cost)
- VALUES (%s, %s)
- """, (equipment_id, row['成本_元']))
-
- logger.info("火箭炮数据导入完成")
-
- # 2. 导入巡飞弹数据
- logger.info("开始导入巡飞弹数据...")
- for index, row in missile_df.iterrows():
- # 记录每行数据的空值情况
- null_values = row[row.isna()].index.tolist()
- if null_values:
- logger.info(f"行 {index + 2} 中的空值字段: {null_values}")
-
- equipment_names.add(row['名称'])
- # 检查是否已存在相同名称的装备
- cursor.execute("""
- SELECT id FROM equipment
- WHERE name = %s AND type = '巡飞弹'
- """, (row['名称'],))
-
- existing_equipment = cursor.fetchone()
- if existing_equipment:
- logger.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入")
- continue
-
- # 插入基本信息
- cursor.execute("""
- INSERT INTO equipment (name, type, manufacturer)
- VALUES (%s, %s, %s)
- """, (
- row['名称'],
- '巡飞弹',
- row['制造商'] if pd.notna(row['制造商']) else None
- ))
-
- equipment_id = cursor.lastrowid
-
- # 插入通用参数
- cursor.execute("""
- INSERT INTO common_params
- (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
- VALUES (%s, %s, %s, %s, %s, %s)
- """, (
- equipment_id,
- float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
- float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
- float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
- float(row['重量_kg']) if pd.notna(row['重量_kg']) else None,
- float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None
- ))
-
- # 插入巡飞弹特有参数
- cursor.execute("""
- INSERT INTO loitering_munition_params
- (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
- cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
- folded_length_mm, folded_width_mm, folded_height_mm,
- power_system, guidance_system)
- VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
- """, (
- equipment_id,
- float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
- float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None,
- float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None,
- float(row['巡航速度_km/h']) if pd.notna(row['巡航速度_km/h']) else None,
- float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None,
- str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None,
- str(row['发射方式']) if pd.notna(row['发射方式']) else None,
- float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None,
- float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None,
- float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None,
- str(row['动力装置']) if pd.notna(row['动力装置']) else None,
- str(row['制导体制']) if pd.notna(row['制导体制']) else None
- ))
-
- # 插入成本数据
- if pd.notna(row['成本_元']):
- cursor.execute("""
- INSERT INTO cost_data (equipment_id, actual_cost)
- VALUES (%s, %s)
- """, (equipment_id, float(row['成本_元'])))
-
- logger.info("巡飞弹数据导入完成")
-
- # 提交之前的更改并关闭原有游标
- cursor.close()
- conn.commit()
-
- # 3. 导入特殊参数
- logger.info("开始导入特殊参数...")
- for index, row in special_df.iterrows():
- equipment_name = row['装备名称']
- param_name = row['参数名称']
- logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")
-
- if equipment_name not in equipment_names:
- logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
- continue
-
- # 获取装备ID - 使用新的游标
- logger.debug(f"查询装备ID: {equipment_name}")
- with conn.cursor() as id_cursor:
- id_cursor.execute("""
- SELECT id FROM equipment WHERE name = %s
- """, (equipment_name,))
- result = id_cursor.fetchone()
-
- if not result:
- logger.warning(f"未找到装备: {equipment_name}")
- continue
-
- equipment_id = result[0]
- logger.debug(f"找到装备ID: {equipment_id}")
-
- # 检查参数是否存在 - 使用新的游标
- logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
- with conn.cursor() as check_cursor:
- check_cursor.execute("""
- SELECT id FROM custom_params
- WHERE equipment_id = %s AND param_name = %s
- """, (equipment_id, param_name))
- exists = check_cursor.fetchone()
-
- if exists:
- logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
- continue
-
- # 插入新的参数 - 使用新的游标
- param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
- param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
- param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
-
- logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
- with conn.cursor() as insert_cursor:
- insert_cursor.execute("""
- INSERT INTO custom_params
- (equipment_id, param_name, param_value, param_unit, description)
- VALUES (%s, %s, %s, %s, %s)
- """, (
- equipment_id,
- param_name,
- param_value,
- param_unit,
- param_desc
- ))
- logger.debug(f"成功插入参数记录")
-
- # 最终提交
- conn.commit()
- logger.info("特殊参数导入完成")
- logger.info("所有数据导入成功")
- return True
-
- except Exception as e:
- logger.error(f"Error importing data: {str(e)}")
- raise
-
-if __name__ == "__main__":
- try:
- excel_file = 'data/equipment_data_20241108.xlsx'
- import_training_data(excel_file)
- logger.info("All data imported successfully")
- except Exception as e:
- logger.error(f"Import failed: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/init_data.sql b/deploy/equipment_cost_system/src/init_data.sql
deleted file mode 100644
index ee99db3..0000000
--- a/deploy/equipment_cost_system/src/init_data.sql
+++ /dev/null
@@ -1,319 +0,0 @@
-/*
-这是用于开发和测试环境的示例数据。
-生产环境请使用系统的数据导入功能添加实际数据。
-
-主要用途:
-1. 提供开发测试数据
-2. 作为数据格式参考
-3. 用于系统功能验证
-*/
-
--- 插入装备基本信息
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('胜利-2', '火箭炮', '伊朗', '地面固定目标');
-
--- 插入巡飞弹技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES (
- 1, -- 终结者巡飞弹
- 0.56,
- 0.15,
- 0.20,
- 2.72,
- 160.93,
- 96.56,
- 24,
- 15,
- '破片杀伤战斗部',
- '凭自身动力起飞',
- 560,
- 150,
- 200
-);
-
--- 插入火箭炮技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES (
- 2, -- 胜利-2火箭炮
- 10,
- 2.5,
- 3.34,
- 15000,
- 23
-);
-
--- 插入成本数据(示例数据)
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(1, 1000000), -- 终结者巡飞弹成本
-(2, 5000000); -- 胜利-2火箭炮成本
-
--- 插入更多巡飞弹变体数据用于训练
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
-('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
-
--- 插入变体技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES
--- 终结者-A(稍大型号)
-(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
--- 终结者-B(稍小型号)
-(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
--- 终结者-C(标准型号的改进版)
-(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
-
--- 插入变体成本数据
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(3, 1100000), -- 终结者-A成本(较高)
-(4, 900000), -- 终结者-B成本(较低)
-(5, 1050000); -- 终结者-C成本(中等)
-
--- 添加更多巡飞弹数据
-INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
-('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
-('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
-('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
-('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
-('彩虹-4', '巡飞弹', '中国', '地面固定目标');
-
--- 添加它们的技术参数
-INSERT INTO technical_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_speed_kmh,
- cruise_speed_kmh,
- max_range_km,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm
-) VALUES
--- 哈比
-(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
--- 游荡者
-(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
--- 凤凰
-(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
--- 弹簧刀
-(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
--- 彩虹-4
-(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
-
--- 添加成本数据
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
-(6, 800000), -- 哈比
-(7, 500000), -- 游荡者
-(8, 450000), -- 凤凰
-(9, 480000), -- 弹簧刀
-(10, 1500000); -- 彩虹-4
-
--- 火箭炮数据
-INSERT INTO equipment (name, type, manufacturer) VALUES
-('BM-21', '火箭炮', '俄罗斯'),
-('SR5', '火箭炮', '中国'),
-('HIMARS', '火箭炮', '美国'),
-('LAR-160', '火箭炮', '以色列'),
-('T-122', '火箭炮', '土耳其'),
-('RM-70', '火箭炮', '捷克'),
-('ASTROS II', '火箭炮', '巴西');
-
--- 火箭炮通用参数
-INSERT INTO common_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES
--- BM-21
-(1, 7.35, 2.4, 3.1, 13700, 20.4),
--- SR5
-(2, 10.2, 2.8, 3.2, 28500, 70),
--- HIMARS
-(3, 7.0, 2.4, 3.2, 16250, 70),
--- LAR-160
-(4, 6.7, 2.5, 2.8, 15000, 45),
--- T-122
-(5, 7.2, 2.5, 2.9, 18000, 40),
--- RM-70
-(6, 7.5, 2.5, 3.0, 17200, 20.3),
--- ASTROS II
-(7, 8.0, 2.7, 3.1, 24500, 90);
-
--- 火箭炮特有参数
-INSERT INTO rocket_artillery_params (
- equipment_id,
- firing_angle_horizontal,
- firing_angle_vertical,
- rocket_length_m,
- rocket_diameter_mm,
- rocket_weight_kg,
- rate_of_fire,
- combat_weight_kg,
- speed_kmh,
- min_range_km,
- mobility_type,
- structure_layout,
- engine_model,
- engine_params,
- power_hp,
- travel_range_km
-) VALUES
--- BM-21
-(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
--- SR5
-(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
--- HIMARS
-(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
--- LAR-160
-(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
--- T-122
-(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
--- RM-70
-(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
--- ASTROS II
-(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
-
--- 巡飞弹数据
-INSERT INTO equipment (name, type, manufacturer) VALUES
-('Hero-120', '巡飞弹', '以色列'),
-('Switchblade 600', '巡飞弹', '美国'),
-('Warmate', '巡飞弹', '波兰'),
-('CH-901', '巡飞弹', '中国'),
-('HAROP', '巡飞弹', '以色列'),
-('Coyote', '巡飞弹', '美国'),
-('WS-43', '巡飞弹', '中国');
-
--- 巡飞弹通用参数
-INSERT INTO common_params (
- equipment_id,
- length_m,
- width_m,
- height_m,
- weight_kg,
- max_range_km
-) VALUES
--- Hero-120
-(8, 1.3, 0.23, 0.23, 12.5, 40),
--- Switchblade 600
-(9, 1.3, 0.22, 0.22, 15.0, 40),
--- Warmate
-(10, 1.1, 0.15, 0.15, 5.7, 15),
--- CH-901
-(11, 1.2, 0.18, 0.18, 9.0, 20),
--- HAROP
-(12, 2.5, 0.43, 0.43, 135, 1000),
--- Coyote
-(13, 0.9, 0.12, 0.12, 5.9, 20),
--- WS-43
-(14, 1.8, 0.35, 0.35, 20, 60);
-
--- 巡飞弹特有参数
-INSERT INTO loitering_munition_params (
- equipment_id,
- wingspan_m,
- warhead_weight_kg,
- max_speed_ms,
- cruise_speed_kmh,
- flight_time_min,
- warhead_type,
- launch_mode,
- folded_length_mm,
- folded_width_mm,
- folded_height_mm,
- power_system,
- guidance_system
-) VALUES
--- Hero-120
-(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
--- Switchblade 600
-(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
--- Warmate
-(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
--- CH-901
-(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
--- HAROP
-(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
--- Coyote
-(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
--- WS-43
-(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
-
--- 插入成本数据(示例成本)
-INSERT INTO cost_data (equipment_id, actual_cost) VALUES
--- 火箭炮
-(1, 800000), -- BM-21
-(2, 4500000), -- SR5
-(3, 5500000), -- HIMARS
-(4, 3500000), -- LAR-160
-(5, 2800000), -- T-122
-(6, 1500000), -- RM-70
-(7, 4800000), -- ASTROS II
--- 巡飞弹
-(8, 150000), -- Hero-120
-(9, 180000), -- Switchblade 600
-(10, 80000), -- Warmate
-(11, 100000), -- CH-901
-(12, 850000), -- HAROP
-(13, 75000), -- Coyote
-(14, 120000); -- WS-43
-
--- 创建初始数据集
-INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
-('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
-('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
-('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
-('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
-
--- 关联装备到数据集
-INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
--- 火箭炮训练集
-(1, 1), (1, 2), (1, 3), (1, 4),
--- 巡飞弹训练集
-(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
--- 火箭炮验证集
-(3, 5), (3, 6), (3, 7),
--- 巡飞弹验证集
-(4, 13), (4, 14);
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/logger.py b/deploy/equipment_cost_system/src/logger.py
deleted file mode 100644
index 52bb879..0000000
--- a/deploy/equipment_cost_system/src/logger.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import logging
-import os
-from datetime import datetime
-
-def setup_logger(name):
- """
- 创建并配置logger
- """
- # 创建logger
- logger = logging.getLogger(name)
-
- # 如果logger已经有处理器,直接返回
- if logger.handlers:
- return logger
-
- # 设置日志级别
- logger.setLevel(logging.INFO)
-
- # 确保日志目录存在
- os.makedirs('logs', exist_ok=True)
-
- # 创建文件处理器
- file_handler = logging.FileHandler('logs/api.log')
- file_handler.setLevel(logging.INFO)
-
- # 创建格式化器
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- file_handler.setFormatter(formatter)
-
- # 添加处理器
- logger.addHandler(file_handler)
-
- return logger
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/model_trainer.py b/deploy/equipment_cost_system/src/model_trainer.py
deleted file mode 100644
index 660c38e..0000000
--- a/deploy/equipment_cost_system/src/model_trainer.py
+++ /dev/null
@@ -1,612 +0,0 @@
-import numpy as np
-import pandas as pd
-from sklearn.preprocessing import StandardScaler
-from sklearn.model_selection import cross_val_score
-from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
-from sklearn.impute import SimpleImputer
-import xgboost as xgb
-import lightgbm as lgb
-import logging
-import joblib
-import os
-from src.feature_analysis import FeatureAnalysis
-from datetime import datetime
-import json
-from src.database import get_db_connection
-from src.data_preparation import DataPreparation
-from sklearn.cross_decomposition import PLSRegression
-from .logger import setup_logger
-
-logger = setup_logger(__name__)
-
-class ModelTrainer:
- def __init__(self):
- """
- 初始化 ModelTrainer
- """
- self.models = {
- 'xgboost': self._create_xgboost_model(),
- 'lightgbm': self._create_lightgbm_model(),
- 'gbm': self._create_gbm_model(),
- 'rf': self._create_rf_model(),
- 'pls': self._create_pls_model()
- }
- self.best_model = None
- self.imputer = SimpleImputer(strategy='mean')
- self.feature_scaler = None
- self.target_scaler = None
- self.equipment_type = None
- self.feature_analyzer = FeatureAnalysis()
-
- def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
- """
- 训练模型并返回评估结果
- """
- try:
- self.equipment_type = equipment_type
- logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
- logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
-
- results = {}
- best_score = -float('inf')
- best_model_info = None
-
- # 首先训练 PLS 模型
- logger.info("Training pls...")
- pls_model = self.models['pls']
- pls_model.fit(X_train, y_train)
- pls_metrics = self._calculate_metrics(
- pls_model,
- X_train, y_train,
- X_val, y_val
- )
- results['pls'] = pls_metrics
-
- # 训练其他机器学习模型
- for model_name in model_names:
- if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
- continue
-
- if model_name not in self.models:
- logger.warning(f"Unknown model: {model_name}")
- continue
-
- logger.info(f"Training {model_name}...")
- model = self.models[model_name]
-
- # 训练模型
- model.fit(X_train, y_train)
-
- # 计算评估指标
- metrics = self._calculate_metrics(
- model,
- X_train, y_train,
- X_val, y_val
- )
-
- results[model_name] = metrics
-
- # 更新最佳模型(只在机器学习模型中比较)
- if metrics['validation']['r2'] > best_score:
- best_score = metrics['validation']['r2']
- best_model_info = {
- 'type': model_name,
- 'r2': metrics['validation']['r2'],
- 'mae': metrics['validation']['mae'],
- 'rmse': metrics['validation']['rmse']
- }
- self.best_model = model
-
- # 保存最佳模型和 PLS 模型
- if equipment_type and best_model_info:
- self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
-
- return {
- 'metrics': results,
- 'best_model': best_model_info
- }
-
- except Exception as e:
- logger.error(f"Error in model training: {str(e)}")
- raise
-
- def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
- """
- 计算模型评估指标
- """
- from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
-
- # 训练集评估
- train_pred = model.predict(X_train)
-
- # 如果使用了标准化,需要转换回原始范围
- if hasattr(self, 'target_scaler'):
- train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
- y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
- else:
- y_train_orig = y_train
-
- # 记录预测范围
- logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
- logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
-
- train_metrics = {
- 'r2': r2_score(y_train_orig, train_pred),
- 'mae': mean_absolute_error(y_train_orig, train_pred),
- 'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
- }
-
- # 验证集评估
- if X_val is not None and y_val is not None:
- val_pred = model.predict(X_val)
-
- # 如果使用了标准化,需要转换回原始范围
- if hasattr(self, 'target_scaler'):
- val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
- y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
- else:
- y_val_orig = y_val
-
- # 记录预测范围
- logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
- logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
-
- val_metrics = {
- 'r2': r2_score(y_val_orig, val_pred),
- 'mae': mean_absolute_error(y_val_orig, val_pred),
- 'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
- }
- else:
- # 使用交叉验证
- cv_scores = cross_val_score(model, X_train, y_train, cv=5)
- val_metrics = {
- 'r2': cv_scores.mean(),
- 'mae': None,
- 'rmse': None
- }
-
- return {
- 'train': train_metrics,
- 'validation': val_metrics
- }
-
- def _create_xgboost_model(self):
- """
- 创建 XGBoost 模型,增强正则化
- """
- return xgb.XGBRegressor(
- n_estimators=50, # 减少树的数量
- learning_rate=0.05, # 学习率
- max_depth=3, # 减小树的深
- min_child_weight=3, # 增加节点权重
- subsample=0.7, # 减小样本采样比例
- colsample_bytree=0.7, # 减小特征采样比例
- reg_alpha=0.1, # L1 正则化
- reg_lambda=1, # L2 正则化
- random_state=42
- )
-
- def _create_lightgbm_model(self):
- """
- 创建 LightGBM 模型,增强正则化
- """
- return lgb.LGBMRegressor(
- n_estimators=50,
- learning_rate=0.05,
- max_depth=3,
- num_leaves=7,
- min_data_in_leaf=3,
- min_sum_hessian_in_leaf=1e-3,
- subsample=0.7,
- colsample_bytree=0.7,
- reg_alpha=0.1,
- reg_lambda=1,
- random_state=42,
- verbose=-1
- )
-
- def _create_gbm_model(self):
- """
- 创建 GBM 模型,增强正则化以减轻过拟合
- """
- return GradientBoostingRegressor(
- n_estimators=100,
- learning_rate=0.1,
- max_depth=3,
- random_state=42,
- subsample=0.8,
- min_samples_split=3,
- min_samples_leaf=2
- )
-
- def _create_rf_model(self):
- """
- 创建随机森林模型,针对小样本数据调整参数
- """
- return RandomForestRegressor(
- n_estimators=100,
- max_depth=3,
- random_state=42,
- min_samples_split=3,
- min_samples_leaf=2
- )
-
- def _create_pls_model(self):
- """
- 创建 PLS 模型,优化参数配置
- """
- return PLSRegression(
- n_components=2, # 减少主成分数量,从5减到2
- scale=True, # 保持数据标准化
- max_iter=500, # 减少最大迭代次数,避免过拟合
- tol=1e-6 # 降低收敛精度,避免过拟合
- )
-
- def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
- """
- 保存最佳模型和 PLS 模型
- """
- try:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- model_dir = 'models'
- os.makedirs(model_dir, exist_ok=True)
-
- # 1. 保存最佳机器学习模型
- model_path = f'{model_dir}/{equipment_type}_{timestamp}'
- if isinstance(self.best_model, xgb.XGBRegressor):
- self.best_model.save_model(f'{model_path}.json')
- model_format = 'json'
- else:
- joblib.dump(self.best_model, f'{model_path}.joblib')
- model_format = 'joblib'
-
- # 2. 保存 PLS 模型
- pls_model = self.models['pls']
- pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
- joblib.dump(pls_model, pls_path)
-
- # 3. 保存标准化器
- scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
- joblib.dump({
- 'feature_scaler': self.feature_scaler,
- 'target_scaler': self.target_scaler
- }, scaler_path)
-
- logger.info(f"Saved best model to {model_path}.{model_format}")
- logger.info(f"Saved PLS model to {pls_path}")
- logger.info(f"Saved scalers to {scaler_path}")
-
- # 4. 更新数据库中的模型记录
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 将所有模型设置为非激活
- cursor.execute("""
- UPDATE trained_models
- SET is_active = FALSE
- WHERE equipment_type = %s
- """, (equipment_type,))
-
- # 获取 PLS 模型的评估指标
- pls_metrics = self._calculate_metrics(
- self.models['pls'],
- X_train,
- y_train,
- X_val,
- y_val
- )
-
- # 保存最佳机器学习模型记录
- self.best_model.equipment_type = equipment_type # 设置装备类型
- ml_feature_importance = self._get_feature_importance(self.best_model)
-
- cursor.execute("""
- INSERT INTO trained_models (
- model_name, model_type, equipment_type, model_path, scaler_path,
- r2_score, mae, rmse, feature_importance, training_data_size,
- training_date, is_active, created_by
- ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
- """, (
- f"{equipment_type}_{timestamp}", # model_name
- best_model_info['type'], # model_type
- equipment_type, # equipment_type
- f"{model_path}.{model_format}", # model_path
- scaler_path, # scaler_path
- best_model_info['r2'], # r2_score
- best_model_info['mae'], # mae
- best_model_info['rmse'], # rmse
- json.dumps(ml_feature_importance), # feature_importance
- len(X_train), # training_data_size
- 'system' # created_by
- ))
-
- # 保存 PLS 模型记录
- pls_model.equipment_type = equipment_type # 设置装备类型
- pls_feature_importance = self._get_feature_importance(pls_model)
-
- cursor.execute("""
- INSERT INTO trained_models (
- model_name, model_type, equipment_type, model_path, scaler_path,
- r2_score, mae, rmse, feature_importance, training_data_size,
- training_date, is_active, created_by
- ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
- """, (
- f"{equipment_type}_{timestamp}_pls", # model_name
- 'pls', # model_type
- equipment_type, # equipment_type
- pls_path, # model_path
- scaler_path, # scaler_path
- float(pls_metrics['validation']['r2']), # r2_score
- float(pls_metrics['validation']['mae']), # mae
- float(pls_metrics['validation']['rmse']), # rmse
- json.dumps(pls_feature_importance), # feature_importance
- len(X_train), # training_data_size
- 'system' # created_by
- ))
-
- conn.commit()
-
- except Exception as e:
- logger.error(f"Error saving models: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- raise
-
- def load_model(self, equipment_type, model_type='ml'):
- """
- 加载已训练的模型
- """
- try:
- logger.info(f"Loading {model_type} model for {equipment_type}")
-
- # 从数据库获取激活的模型
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
-
- # 构建查询语句
- if model_type == 'pls':
- query = """
- SELECT * FROM trained_models
- WHERE equipment_type = %s
- AND model_type = 'pls'
- AND is_active = TRUE
- LIMIT 1
- """
- params = (equipment_type,)
- else:
- query = """
- SELECT * FROM trained_models
- WHERE equipment_type = %s
- AND model_type != 'pls'
- AND is_active = TRUE
- LIMIT 1
- """
- params = (equipment_type,)
-
- # 记录查询信息
- logger.info(f"Executing query: {query}")
- logger.info(f"Query params: {params}")
-
- cursor.execute(query, params)
- model_record = cursor.fetchone()
-
- # 记录查询结果
- if model_record:
- logger.info(f"Found model record: {model_record}")
- else:
- logger.warning(f"No active model found for type {model_type}")
- return False
-
- # 检查文件是否存在
- logger.info(f"Checking model file: {model_record['model_path']}")
- logger.info(f"Checking scaler file: {model_record['scaler_path']}")
-
- if not os.path.exists(model_record['model_path']):
- logger.error(f"Model file not found: {model_record['model_path']}")
- raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
-
- if not os.path.exists(model_record['scaler_path']):
- logger.error(f"Scaler file not found: {model_record['scaler_path']}")
- raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
-
- # 加载模型文件
- logger.info(f"Loading model from {model_record['model_path']}")
- if model_type == 'pls':
- self.best_model = joblib.load(model_record['model_path'])
- logger.info("Loaded PLS model")
- else:
- if model_record['model_type'] == 'xgboost':
- self.best_model = xgb.XGBRegressor()
- self.best_model.load_model(model_record['model_path'])
- logger.info("Loaded XGBoost model")
- else:
- self.best_model = joblib.load(model_record['model_path'])
- logger.info(f"Loaded {model_record['model_type']} model")
-
- # 加载标准化器
- logger.info(f"Loading scalers from {model_record['scaler_path']}")
- scalers = joblib.load(model_record['scaler_path'])
- self.feature_scaler = scalers['feature_scaler']
- self.target_scaler = scalers['target_scaler']
- logger.info("Loaded scalers successfully")
-
- return True
-
- except Exception as e:
- logger.error(f"Error loading model: {str(e)}")
- logger.error(f"Detailed traceback:", exc_info=True)
- return False
-
- def predict(self, features):
- """
- 使用加载的模型进行预测
- """
- try:
- if not self.best_model:
- raise ValueError("No model loaded")
-
- if not self.feature_scaler:
- raise ValueError("Feature scaler not loaded")
-
- if not self.target_scaler:
- raise ValueError("Target scaler not loaded")
-
- logger.info("Starting prediction")
- logger.info(f"Input features shape: {features.shape}")
- logger.info(f"Input features: \n{features}")
-
- # 处理缺失值
- features_filled = np.array(features, dtype=float)
- features_filled[np.isnan(features_filled)] = 0
- features_filled = np.nan_to_num(features_filled, 0)
-
- logger.info(f"Filled features: \n{features_filled}")
-
- # 标准化特征
- X = self.feature_scaler.transform(features_filled)
- logger.info(f"Transformed features shape: {X.shape}")
- logger.info(f"Transformed features: \n{X}")
-
- # 预测
- y_pred_scaled = self.best_model.predict(X)
- logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
- logger.info(f"Scaled prediction: {y_pred_scaled}")
-
- # ��标准化
- y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
- logger.info(f"Final prediction shape: {y_pred.shape}")
- logger.info(f"Final prediction: {y_pred}")
-
- # 记录标准化器的参数
- logger.info("Target scaler params:")
- logger.info(f"Mean: {self.target_scaler.mean_}")
- logger.info(f"Scale: {self.target_scaler.scale_}")
-
- return y_pred.ravel()
-
- except Exception as e:
- logger.error(f"Error in prediction: {str(e)}")
- raise
-
- def _get_feature_importance(self, model):
- """
- 获取特征重要性
- """
- try:
- if not model:
- return {}
-
- # 获取特征名称
- feature_analyzer = FeatureAnalysis()
- feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
-
- # 获取特���重要性
- if hasattr(model, 'feature_importances_'):
- importances = model.feature_importances_
- elif hasattr(model, 'coef_'):
- if len(model.coef_.shape) > 1: # 如果是二维数组
- importances = np.abs(model.coef_[0]) # 取第一行
- else:
- importances = np.abs(model.coef_)
- else:
- return {}
-
- # 创建特征重要性字典
- importance_dict = {}
- for name, importance in zip(feature_names, importances):
- importance_dict[name] = float(importance) # 确保转换为 Python 标量
-
- # 按重要性降序排序
- sorted_dict = dict(sorted(
- importance_dict.items(),
- key=lambda x: x[1],
- reverse=True
- ))
-
- # 过滤掉重要性为0的特征
- return {k: v for k, v in sorted_dict.items() if v > 0}
-
- except Exception as e:
- logger.error(f"Error getting feature importance: {str(e)}")
- return {}
-
- def _calculate_confidence_interval(self, prediction, confidence=0.95):
- """
- 计算预测值的置信区间
- """
- try:
- # 使用预测值的20%作为标准差(增加不确定性)
- std = abs(prediction) * 0.2
-
- # 计算置信区间
- from scipy import stats
- interval = stats.norm.interval(confidence, loc=prediction, scale=std)
-
- # 确保区间值为正数且合理
- lower = max(1000, interval[0]) # 最小值设为1000元
- upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
-
- logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
-
- return [lower, upper]
-
- except Exception as e:
- logger.error(f"Error calculating confidence interval: {str(e)}")
- # 如果计算失败,返回基于20%的简单区间
- lower = max(1000, prediction * 0.8)
- upper = prediction * 1.2
- return [lower, upper]
-
- def get_model_type(self):
- """
- 获取当前模型的类型
- """
- if isinstance(self.best_model, xgb.XGBRegressor):
- return 'xgboost'
- elif isinstance(self.best_model, lgb.LGBMRegressor):
- return 'lightgbm'
- elif isinstance(self.best_model, GradientBoostingRegressor):
- return 'gbm'
- elif isinstance(self.best_model, RandomForestRegressor):
- return 'rf'
- else:
- return 'unknown'
-
- def _get_pls_feature_importance(self):
- """
- 获取 PLS 模型的特征重要性
- """
- try:
- if not self.models['pls']:
- return {}
-
- # 获取特征名称
- feature_analyzer = FeatureAnalysis()
- feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
-
- # 获取 PLS 模型的系数作为特征重要性
- pls_model = self.models['pls']
- if hasattr(pls_model, 'coef_'):
- # 使用绝对值作为重要性指标
- importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
- else:
- return {}
-
- # 创建特征重要性字典
- importance_dict = {}
- for name, importance in zip(feature_names, importances):
- importance_dict[name] = float(importance) # 确保转换为 Python 标量
-
- # 按重要性降序排序
- sorted_dict = dict(sorted(
- importance_dict.items(),
- key=lambda x: x[1],
- reverse=True
- ))
-
- # 过滤掉重要性为0的特征
- return {k: v for k, v in sorted_dict.items() if v > 0}
-
- except Exception as e:
- logger.error(f"Error getting PLS feature importance: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return {}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/routes.py b/deploy/equipment_cost_system/src/routes.py
deleted file mode 100644
index df4e825..0000000
--- a/deploy/equipment_cost_system/src/routes.py
+++ /dev/null
@@ -1,1254 +0,0 @@
-from flask import Blueprint, request, jsonify, send_file
-from .cost_prediction import CostPredictor
-from .feature_analysis import FeatureAnalysis
-import pandas as pd
-from datetime import datetime
-import numpy as np
-import mysql.connector
-from sklearn.metrics import mean_absolute_error
-from .create_template import create_excel_template
-import json
-import os
-from .data_preparation import DataPreparation
-from .model_trainer import ModelTrainer
-from .logger import setup_logger
-
-# 创建蓝图
-api_bp = Blueprint('api', __name__)
-
-# 获取logger
-logger = setup_logger(__name__)
-
-@api_bp.route('/', methods=['GET'])
-def index():
- """
- API根路由
- 返回API版本信息和可用端点列表
- """
- return jsonify({
- 'name': '装备成本估算系统 API',
- 'version': '1.0.0',
- 'endpoints': {
- 'predict': {
- 'url': '/api/predict',
- 'method': 'POST',
- 'description': '成本预测'
- },
- 'analyze-features': {
- 'url': '/api/analyze-features',
- 'method': 'POST',
- 'description': '特征分析'
- },
- 'train': {
- 'url': '/api/train',
- 'method': 'POST',
- 'description': '模型训练'
- },
- 'evaluate': {
- 'url': '/api/evaluate',
- 'method': 'POST',
- 'description': '模型评估'
- }
- }
- })
-
-@api_bp.route('/predict', methods=['POST'])
-def predict_cost():
- """
- 成本预测接口
- """
- try:
- data = request.get_json()
- logger.info(f"Received prediction request for equipment type: {data.get('type')}")
-
- # 验证装备类型
- if 'type' not in data:
- return jsonify({'error': 'Equipment type is required'}), 400
-
- # 预测成本
- predictor = CostPredictor()
- result = predictor.predict(data)
-
- # 获取当前使用的模型信息
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT model_type, model_name, r2_score, mae, rmse
- FROM trained_models
- WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
- LIMIT 1
- """, (data['type'],))
- model_info = cursor.fetchone()
-
- # 在结果中添加模型信息
- result.update({
- 'model_info': {
- 'type': model_info['model_type'],
- 'name': model_info['model_name'],
- 'r2_score': float(model_info['r2_score']),
- 'mae': float(model_info['mae']),
- 'rmse': float(model_info['rmse'])
- }
- })
-
- logger.info(f"Prediction completed: {result}")
- return jsonify(result)
-
- except Exception as e:
- logger.error(f"Error in prediction: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/analyze-features', methods=['POST'])
-def analyze_features():
- """
- 基于数据集进行特征分析
- """
- try:
- data = request.get_json()
- dataset_id = data.get('dataset_id')
-
- logger.info(f"Starting feature analysis for dataset {dataset_id}")
-
- if not dataset_id:
- logger.warning("No dataset_id provided")
- return jsonify({'error': '请选择数据集'}), 400
-
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
-
- # 获取数据集信息
- cursor.execute("""
- SELECT d.*,
- e.type as equipment_type
- FROM datasets d
- JOIN dataset_equipment de ON d.id = de.dataset_id
- JOIN equipment e ON de.equipment_id = e.id
- WHERE d.id = %s
- LIMIT 1
- """, (dataset_id,))
- dataset = cursor.fetchone()
-
- if not dataset:
- logger.warning(f"Dataset {dataset_id} not found")
- return jsonify({'error': '数据集不存在'}), 404
-
- logger.info(f"Dataset info: {dataset}")
-
- # 创建特征分析实例
- from src.feature_analysis import FeatureAnalysis
- analyzer = FeatureAnalysis()
-
- # 获取特征列表
- feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
- logger.info(f"Feature names: {feature_names}")
-
- # 获取数据集中的装备数据
- if dataset['equipment_type'] == '火箭炮':
- cursor.execute("""
- SELECT e.*, cp.*, rap.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (dataset_id,))
- else:
- cursor.execute("""
- SELECT e.*, cp.*, lmp.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (dataset_id,))
-
- equipment_data = cursor.fetchall()
- logger.info(f"Found {len(equipment_data)} equipment records")
-
- if not equipment_data:
- logger.warning("No valid equipment data found in dataset")
- return jsonify({'error': '数据集没有有效的成本数据'}), 400
-
- # 统计每个特征的缺失率
- missing_rates = {}
- for name in feature_names:
- missing_count = sum(1 for item in equipment_data if item.get(name) is None)
- missing_rate = missing_count / len(equipment_data)
- missing_rates[name] = missing_rate
- logger.info(f"Feature {name} missing rate: {missing_rate:.2%}")
-
- # 过滤掉缺失率过高的特征
- valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
- logger.info(f"Valid features after filtering: {valid_features}")
-
- if len(valid_features) < 3: # 至少需要3个特征
- return jsonify({'error': '有效特征数量不足'}), 400
-
- # 计算每个特征的均值
- feature_means = {}
- for name in valid_features:
- values = [float(item[name]) for item in equipment_data if item.get(name) is not None]
- feature_means[name] = sum(values) / len(values) if values else 0
- logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")
-
- # 准备特征和目标值
- features = []
- target = []
-
- # 提取特征和目标值,使用均值填充缺失值
- for item in equipment_data:
- feature_values = []
- for name in valid_features:
- value = item.get(name)
- try:
- # 确保数值类型转换正确
- feature_values.append(float(value) if value is not None else feature_means[name])
- except (ValueError, TypeError) as e:
- logger.error(f"Error converting value for feature {name}: {value}")
- logger.error(f"Error details: {str(e)}")
- return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
- features.append(feature_values)
-
- # 确保成本值是值类型
- try:
- target.append(float(item['actual_cost']))
- except (ValueError, TypeError) as e:
- logger.error(f"Error converting actual_cost: {item['actual_cost']}")
- logger.error(f"Error details: {str(e)}")
- return jsonify({'error': '成本值无法换为数值'}), 400
-
- logger.info(f"Prepared {len(features)} feature vectors")
- logger.info(f"First feature vector: {features[0] if features else None}")
- logger.info(f"First target value: {target[0] if target else None}")
-
- # 调用特征分析方法
- result = analyzer.analyze_features(features, target, valid_features)
- logger.info("Analysis completed successfully")
-
- return jsonify(result)
-
- except Exception as e:
- logger.error(f"Error analyzing features: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/train', methods=['POST'])
-def train_model():
- """
- 训练模型
- """
- try:
- data = request.get_json()
- logger.info(f"Starting model training for {data.get('type')}")
- equipment_type = data.get('type')
- train_dataset_id = data.get('train_dataset_id')
- validation_dataset_id = data.get('validation_dataset_id')
- models = data.get('models', [])
-
- logger.info(f"Training dataset: {train_dataset_id}")
- logger.info(f"Validation dataset: {validation_dataset_id}")
- logger.info(f"Selected models: {models}")
-
- # 获取训练数据
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
-
- # 获取训练集数据
- if equipment_type == '火箭炮':
- cursor.execute("""
- SELECT e.*, cp.*, rap.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (train_dataset_id,))
- else:
- cursor.execute("""
- SELECT e.*, cp.*, lmp.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (train_dataset_id,))
-
- train_data = cursor.fetchall()
-
- # 获取验证集数据(如果有)
- validation_data = None
- if validation_dataset_id:
- if equipment_type == '火箭炮':
- cursor.execute("""
- SELECT e.*, cp.*, rap.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (validation_dataset_id,))
- else:
- cursor.execute("""
- SELECT e.*, cp.*, lmp.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- AND cd.actual_cost IS NOT NULL
- """, (validation_dataset_id,))
- validation_data = cursor.fetchall()
-
- if not train_data:
- return jsonify({'error': '训练数据集为空'}), 400
-
- # 1. 准备数据
- data_processor = DataPreparation()
-
- # 准备训练数据
- train_prepared = data_processor.prepare_training_data(train_data, equipment_type)
-
- # 准备验证数据(如果有)
- validation_prepared = None
- if validation_data:
- validation_prepared = data_processor.prepare_validation_data(
- validation_data,
- equipment_type,
- train_prepared['feature_names'],
- {
- 'feature_scaler': train_prepared['feature_scaler'],
- 'target_scaler': train_prepared['target_scaler']
- }
- )
-
- # 2. 训练模型
- model_trainer = ModelTrainer()
- model_trainer.feature_scaler = train_prepared['feature_scaler']
- model_trainer.target_scaler = train_prepared['target_scaler']
-
- # 执行训练,传入 equipment_type 参数
- training_result = model_trainer.fit_model(
- train_prepared['X'],
- train_prepared['y'],
- models,
- validation_prepared['X'] if validation_prepared else None,
- validation_prepared['y'] if validation_prepared else None,
- equipment_type=equipment_type
- )
-
- return jsonify(training_result)
-
- except Exception as e:
- logger.error(f"Error in model training: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/evaluate', methods=['POST'])
-def evaluate_model():
- """
- 模型评估接口
- """
- try:
- data = request.get_json()
- logger.info("Received model evaluation request")
-
- if 'test_data' not in data:
- return jsonify({'error': 'Test data is required'}), 400
-
- predictor = CostPredictor()
- evaluation_result = predictor.evaluate(
- data['test_data']['actual'],
- data['test_data']['predicted']
- )
-
- logger.info("Model evaluation completed")
- return jsonify(evaluation_result)
-
- except Exception as e:
- logger.error(f"Error in model evaluation: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-def get_required_params(equipment_type):
- """
- 根据装备类型获取必要参数
- """
- common_params = [
- 'length_m',
- 'width_m',
- 'height_m',
- 'weight_kg',
- 'max_range_km'
- ]
-
- if equipment_type == '火箭炮':
- return common_params + [
- 'firing_angle_horizontal',
- 'firing_angle_vertical',
- 'rocket_length_m',
- 'rocket_diameter_mm',
- 'rocket_weight_kg'
- ]
- elif equipment_type == '巡飞弹':
- return common_params + [
- 'max_speed_kmh',
- 'cruise_speed_kmh',
- 'flight_time_min',
- 'folded_length_mm',
- 'folded_width_mm',
- 'folded_height_mm'
- ]
-
- return common_params
-
-@api_bp.errorhandler(404)
-def not_found(error):
- return jsonify({'error': 'Not found'}), 404
-
-@api_bp.errorhandler(500)
-def internal_error(error):
- logger.error(f"Internal server error: {str(error)}")
- return jsonify({'error': 'Internal server error'}), 500
-
-@api_bp.route('/data', methods=['GET'])
-def get_equipment_data():
- """
- 获取装备数据
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute('SET SESSION group_concat_max_len = 1000000')
-
- # 先测试特殊参数查询
- cursor.execute("""
- SELECT equipment_id, param_name, param_value, param_unit
- FROM custom_params
- WHERE param_name IS NOT NULL
- AND param_value IS NOT NULL
- LIMIT 5
- """)
- test_params = cursor.fetchall()
- logger.info(f"Test custom params: {test_params}")
-
- # 获取火箭炮数据
- logger.info("Fetching rocket artillery data...")
- cursor.execute("""
- SELECT
- e.id,
- e.name,
- e.type,
- e.manufacturer,
- e.created_at,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- cp.max_range_km,
- rap.firing_angle_horizontal,
- rap.firing_angle_vertical,
- rap.rocket_length_m,
- rap.rocket_diameter_mm,
- rap.rocket_weight_kg,
- rap.rate_of_fire,
- rap.combat_weight_kg,
- rap.speed_kmh,
- rap.min_range_km,
- rap.mobility_type,
- rap.structure_layout,
- rap.engine_model,
- rap.engine_params,
- rap.power_hp,
- rap.travel_range_km,
- cd.actual_cost,
- (
- SELECT COALESCE(
- JSON_ARRAYAGG(
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- ),
- '[]'
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
- FROM equipment e
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE e.type = '火箭炮'
- """)
- rocket_artillery = cursor.fetchall()
- logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
- if rocket_artillery:
- logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
- logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")
-
- # 获取巡飞弹数据
- logger.info("Fetching missile data...")
- cursor.execute("""
- SELECT
- e.id,
- e.name,
- e.type,
- e.manufacturer,
- e.created_at,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- cp.max_range_km,
- lmp.wingspan_m,
- lmp.warhead_weight_kg,
- lmp.max_speed_ms,
- lmp.cruise_speed_kmh,
- lmp.flight_time_min,
- lmp.warhead_type,
- lmp.launch_mode,
- lmp.folded_length_mm,
- lmp.folded_width_mm,
- lmp.folded_height_mm,
- lmp.power_system,
- lmp.guidance_system,
- cd.actual_cost,
- (
- SELECT COALESCE(
- JSON_ARRAYAGG(
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- ),
- '[]'
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
- FROM equipment e
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE e.type = '巡飞弹'
- """)
- loitering_munition = cursor.fetchall()
- logger.info(f"Found {len(loitering_munition)} missile records")
- if loitering_munition:
- logger.info(f"First missile: {loitering_munition[0]['name']}")
- logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")
-
- # 处理 custom_params,保为 NULL
- for item in rocket_artillery + loitering_munition:
- if item['custom_params'] is None:
- item['custom_params'] = []
- logger.debug(f"Set empty custom_params for equipment {item['id']}")
- else:
- logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params")
-
- logger.info("Data fetching completed")
- return jsonify({
- 'rocket_artillery': rocket_artillery,
- 'loitering_munition': loitering_munition
- })
-
- except Exception as e:
- logger.error(f"Error getting equipment data: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/data/', methods=['DELETE'])
-def delete_equipment(id):
- """
- 删除装备数据
- """
- try:
- db = get_db_connection()
- cursor = db.cursor()
-
- # 删除相关数据
- cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,))
- cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
- cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
- cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
- cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
-
- db.commit()
- cursor.close()
- db.close()
-
- return jsonify({'status': 'success'})
-
- except Exception as e:
- logger.error(f"Error deleting equipment: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/data/template', methods=['GET'])
-def download_template():
- """
- 下载数据模板
- """
- try:
- # 创建模板文件
- from .create_template import create_excel_template
- template_path = create_excel_template()
-
- # 检查文件是否存
- if not os.path.exists(template_path):
- raise FileNotFoundError("模板文件不存在")
-
- # 返回文件
- return send_file(
- template_path,
- as_attachment=True,
- download_name='equipment_data_template.xlsx',
- mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
- )
-
- except Exception as e:
- logger.error(f"Error creating template: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-def get_db_connection():
- """
- 获取数据库连接
- """
- return mysql.connector.connect(
- host="localhost",
- user="root",
- password="123456",
- database="equipment_cost_db"
- )
-
-@api_bp.route('/pls/predict', methods=['POST'])
-def pls_predict():
- """
- PLS回归预测接口
- """
- try:
- data = request.get_json()
- logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")
-
- # 验证装备类型
- if 'type' not in data:
- return jsonify({'error': 'Equipment type is required'}), 400
-
- # 使用 ModelTrainer 中的 PLS 模型进行预测
- trainer = ModelTrainer()
- if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型
- return jsonify({'error': '未找到可用的模型'}), 404
-
- # 准备特征数据
- feature_analyzer = FeatureAnalysis()
- features = feature_analyzer.get_equipment_specific_features(data['type'])
- X = np.array([[data.get(feature) for feature in features]])
-
- # 预测
- result = trainer.predict(X)
-
- # 计算置信区间
- confidence_interval = trainer._calculate_confidence_interval(result[0])
-
- # 获取模型信息
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT model_type, model_name, r2_score, mae, rmse
- FROM trained_models
- WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
- LIMIT 1
- """, (data['type'],))
- model_info = cursor.fetchone()
-
- # 确保返回的数据可以序列化为JSON
- response = {
- 'predicted_cost': float(result[0]),
- 'model_info': {
- 'type': model_info['model_type'],
- 'name': model_info['model_name'],
- 'r2_score': model_info['r2_score'],
- 'mae': model_info['mae'],
- 'rmse': model_info['rmse']
- },
- 'confidence_interval': {
- 'lower': float(confidence_interval[0]),
- 'upper': float(confidence_interval[1])
- }
- }
-
- logger.info(f"PLS prediction completed: {response}")
- return jsonify(response)
-
- except Exception as e:
- logger.error(f"Error in PLS prediction: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/data/import', methods=['POST'])
-def import_data():
- """
- 导入数据接口
- """
- try:
- if 'file' not in request.files:
- return jsonify({'error': '没有上传文件'}), 400
-
- file = request.files['file']
- if not file.filename.endswith(('.xls', '.xlsx')):
- return jsonify({'error': '请上传Excel文件'}), 400
-
- # 保存上的文件
- upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
- os.makedirs(upload_dir, exist_ok=True)
- file_path = os.path.join(upload_dir, file.filename)
- file.save(file_path)
-
- # 导入数据
- from .import_data import import_training_data
- import_training_data(file_path)
-
- return jsonify({
- 'success': True,
- 'message': '数据导入成功'
- })
-
- except Exception as e:
- logger.error(f"Error importing data: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/data/', methods=['PUT'])
-def update_equipment(id):
- """
- 更新装备数据
- """
- try:
- data = request.get_json()
- logger.info(f"Updating equipment ID: {id}")
- logger.info(f"Update data: {data}")
-
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 更新基本信息
- cursor.execute("""
- UPDATE equipment
- SET name = %s, manufacturer = %s
- WHERE id = %s
- """, (data['name'], data['manufacturer'], id))
- logger.info("Basic info updated")
-
- # 更新通用参数
- cursor.execute("""
- UPDATE common_params
- SET length_m = %s, width_m = %s, height_m = %s,
- weight_kg = %s, max_range_km = %s
- WHERE equipment_id = %s
- """, (
- data['length_m'], data['width_m'], data['height_m'],
- data['weight_kg'], data['max_range_km'], id
- ))
- logger.info("Common params updated")
-
- # 根据备类型更新特有参数
- if data['type'] == '火箭炮':
- cursor.execute("""
- UPDATE rocket_artillery_params
- SET firing_angle_horizontal = %s, firing_angle_vertical = %s,
- rocket_length_m = %s, rocket_diameter_mm = %s,
- rocket_weight_kg = %s, rate_of_fire = %s
- WHERE equipment_id = %s
- """, (
- data['firing_angle_horizontal'], data['firing_angle_vertical'],
- data['rocket_length_m'], data['rocket_diameter_mm'],
- data['rocket_weight_kg'], data['rate_of_fire'], id
- ))
- logger.info("Rocket artillery params updated")
- else:
- cursor.execute("""
- UPDATE loitering_munition_params
- SET max_speed_ms = %s, cruise_speed_kmh = %s,
- flight_time_min = %s, warhead_type = %s,
- launch_mode = %s, folded_length_mm = %s,
- folded_width_mm = %s, folded_height_mm = %s
- WHERE equipment_id = %s
- """, (
- data['max_speed_ms'], data['cruise_speed_kmh'],
- data['flight_time_min'], data['warhead_type'],
- data['launch_mode'], data['folded_length_mm'],
- data['folded_width_mm'], data['folded_height_mm'], id
- ))
- logger.info("Missile params updated")
-
- # 更新成本数据
- if 'actual_cost' in data:
- cursor.execute("""
- UPDATE cost_data
- SET actual_cost = %s
- WHERE equipment_id = %s
- """, (data['actual_cost'], id))
- logger.info("Cost data updated")
-
- # 更新特殊参数
- if 'custom_params' in data and data['custom_params']:
- logger.info(f"Updating custom params: {data['custom_params']}")
- for param in data['custom_params']:
- cursor.execute("""
- UPDATE custom_params
- SET param_value = %s
- WHERE id = %s AND equipment_id = %s
- """, (param['param_value'], param['id'], id))
- logger.info("Custom params updated")
-
- conn.commit()
- logger.info("All updates committed successfully")
-
- return jsonify({'success': True})
-
- except Exception as e:
- logger.error(f"Error updating equipment: {str(e)}")
- logger.error("Detailed traceback:", exc_info=True)
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/data/details/', methods=['GET'])
-def get_equipment_details(id):
- """
- 获取装备详数据
- """
- try:
- logger.info(f"Getting details for equipment ID: {id}")
-
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
-
- # 先获取装备类型
- cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,))
- equipment = cursor.fetchone()
-
- if not equipment:
- logger.warning(f"Equipment not found: {id}")
- return jsonify({'error': 'Equipment not found'}), 404
-
- equipment_type = equipment['type']
- logger.info(f"Equipment type: {equipment_type}")
-
- # 根据装备类型选择查询
- if equipment_type == '火箭炮':
- query = """
- SELECT
- e.*,
- cp.*,
- rap.*,
- cd.actual_cost,
- cd.prediction_date as cost_estimate_date,
- cd.predicted_cost,
- (
- SELECT JSON_ARRAYAGG(
- CASE
- WHEN csp.id IS NOT NULL THEN
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- END
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
- FROM equipment e
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE e.id = %s
- """
- else:
- query = """
- SELECT
- e.*,
- cp.*,
- lmp.*,
- cd.actual_cost,
- cd.prediction_date as cost_estimate_date,
- cd.predicted_cost,
- (
- SELECT JSON_ARRAYAGG(
- CASE
- WHEN csp.id IS NOT NULL THEN
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- END
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
- FROM equipment e
- LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE e.id = %s
- """
-
- cursor.execute(query, (id,))
- result = cursor.fetchone()
-
- if result:
- logger.info(f"Found equipment details: {result['name']}")
- logger.info(f"Custom params: {result.get('custom_params')}")
-
- return jsonify(result)
-
- except Exception as e:
- logger.error(f"Error getting equipment details: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-# 添加数据集相关的路由
-@api_bp.route('/datasets', methods=['GET'])
-def get_datasets():
- """
- 获取数据集列表
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT d.*,
- COUNT(de.equipment_id) as equipment_count,
- GROUP_CONCAT(e.name) as equipment_names
- FROM datasets d
- LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
- LEFT JOIN equipment e ON de.equipment_id = e.id
- GROUP BY d.id
- """)
- datasets = cursor.fetchall()
-
- # 理装备名称列表
- for dataset in datasets:
- if dataset['equipment_names']:
- dataset['equipment_names'] = dataset['equipment_names'].split(',')
- else:
- dataset['equipment_names'] = []
-
- return jsonify(datasets)
- except Exception as e:
- logger.error(f"Error getting datasets: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/datasets/', methods=['GET'])
-def get_dataset(id):
- """
- 获取数据集详情
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- # 获取数据集基本信息
- cursor.execute("""
- SELECT d.*,
- COUNT(de.equipment_id) as equipment_count
- FROM datasets d
- LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
- WHERE d.id = %s
- GROUP BY d.id
- """, (id,))
- dataset = cursor.fetchone()
-
- if not dataset:
- return jsonify({'error': 'Dataset not found'}), 404
-
- # 获取数据集中的装备
- cursor.execute("""
- SELECT e.*, cd.actual_cost
- FROM equipment e
- JOIN dataset_equipment de ON e.id = de.equipment_id
- LEFT JOIN cost_data cd ON e.id = cd.equipment_id
- WHERE de.dataset_id = %s
- """, (id,))
- equipment = cursor.fetchall()
-
- # 计算统计信息
- if equipment:
- total_cost = sum(item['actual_cost'] or 0 for item in equipment)
- avg_cost = total_cost / len(equipment)
- dataset['statistics'] = {
- 'equipment_count': len(equipment),
- 'total_cost': total_cost,
- 'average_cost': avg_cost
- }
- else:
- dataset['statistics'] = {
- 'equipment_count': 0,
- 'total_cost': 0,
- 'average_cost': 0
- }
-
- dataset['equipment'] = equipment
- return jsonify(dataset)
- except Exception as e:
- logger.error(f"Error getting dataset: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/datasets', methods=['POST'])
-def create_dataset():
- """
- 建数据集
- """
- try:
- data = request.get_json()
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 创建数据集
- cursor.execute("""
- INSERT INTO datasets (name, description, equipment_type, purpose)
- VALUES (%s, %s, %s, %s)
- """, (data['name'], data['description'], data['equipment_type'], data['purpose']))
-
- dataset_id = cursor.lastrowid
-
- # 添加装备关联
- if 'equipment_ids' in data and data['equipment_ids']:
- values = [(dataset_id, equipment_id) for equipment_id in data['equipment_ids']]
- cursor.executemany("""
- INSERT INTO dataset_equipment (dataset_id, equipment_id)
- VALUES (%s, %s)
- """, values)
-
- conn.commit()
- return jsonify({'id': dataset_id, 'message': '数据集创建成功'})
- except Exception as e:
- logger.error(f"Error creating dataset: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/datasets/', methods=['PUT'])
-def update_dataset(id):
- """
- 更新数据集
- """
- try:
- data = request.get_json()
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 更新数据集基本信息
- cursor.execute("""
- UPDATE datasets
- SET name = %s, description = %s, equipment_type = %s, purpose = %s
- WHERE id = %s
- """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
-
- # 删除旧的装备关联
- cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
-
- # 加新的装备关联
- if 'equipment_ids' in data:
- for equipment_id in data['equipment_ids']:
- cursor.execute("""
- INSERT INTO dataset_equipment (dataset_id, equipment_id)
- VALUES (%s, %s)
- """, (id, equipment_id))
-
- conn.commit()
- return jsonify({'success': True})
- except Exception as e:
- logger.error(f"Error updating dataset: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/datasets/', methods=['DELETE'])
-def delete_dataset(id):
- """
- 删除数据集
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 删除装备关联
- cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
-
- # 删除数据集
- cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
-
- conn.commit()
- return jsonify({'success': True})
- except Exception as e:
- logger.error(f"Error deleting dataset: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/models//latest', methods=['GET'])
-def get_latest_model(equipment_type):
- """
- 获取最新训练的型信息
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT * FROM trained_models
- WHERE equipment_type = %s AND is_active = TRUE
- ORDER BY training_date DESC LIMIT 1
- """, (equipment_type,))
-
- model = cursor.fetchone()
- return jsonify(model)
-
- except Exception as e:
- logger.error(f"Error getting latest model: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/models', methods=['GET'])
-def get_models():
- """
- 获取模型列表
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor(dictionary=True)
- cursor.execute("""
- SELECT * FROM trained_models
- ORDER BY training_date DESC
- """)
-
- models = cursor.fetchall()
-
- # 确保数值类型字段是 float
- for model in models:
- if model['r2_score'] is not None:
- model['r2_score'] = float(model['r2_score'])
- if model['mae'] is not None:
- model['mae'] = float(model['mae'])
- if model['rmse'] is not None:
- model['rmse'] = float(model['rmse'])
-
- # 解析特征重要性
- if model['feature_importance']:
- model['feature_importance'] = json.loads(model['feature_importance'])
-
- return jsonify(models)
-
- except Exception as e:
- logger.error(f"Error getting models: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/models//activate', methods=['POST'])
-def activate_model(id):
- """
- 激活指定的模型
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 获取模型信息
- cursor.execute("""
- SELECT equipment_type FROM trained_models
- WHERE id = %s
- """, (id,))
- model = cursor.fetchone()
-
- if not model:
- return jsonify({'error': 'Model not found'}), 404
-
- # 将同类型的其他模型设置为非激活
- cursor.execute("""
- UPDATE trained_models
- SET is_active = FALSE
- WHERE equipment_type = %s
- """, (model[0],))
-
- # 激活指定模型
- cursor.execute("""
- UPDATE trained_models
- SET is_active = TRUE
- WHERE id = %s
- """, (id,))
-
- conn.commit()
- return jsonify({'success': True})
-
- except Exception as e:
- logger.error(f"Error activating model: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/models/', methods=['DELETE'])
-def delete_model(id):
- """
- 删除指定的模型
- """
- try:
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 获取模型文件路径
- cursor.execute("""
- SELECT model_path, scaler_path
- FROM trained_models
- WHERE id = %s
- """, (id,))
- model = cursor.fetchone()
-
- if not model:
- return jsonify({'error': 'Model not found'}), 404
-
- # 删除模型文件
- if os.path.exists(model[0]):
- os.remove(model[0])
- if os.path.exists(model[1]):
- os.remove(model[1])
-
- # 删除数据库记录
- cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
- conn.commit()
-
- return jsonify({'success': True})
-
- except Exception as e:
- logger.error(f"Error deleting model: {str(e)}")
- return jsonify({'error': str(e)}), 500
-
-@api_bp.route('/predict/all', methods=['POST'])
-def predict_all():
- """
- 获取所有机器学习模型的预测结果
- """
- try:
- data = request.get_json()
- logger.info(f"Received prediction request for all models, equipment type: {data.get('type')}")
-
- predictor = CostPredictor()
- results = predictor.predict_all(data)
-
- return jsonify(results)
-
- except Exception as e:
- logger.error(f"Error in prediction: {str(e)}")
- return jsonify({'error': str(e)}), 500
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/schema.sql b/deploy/equipment_cost_system/src/schema.sql
deleted file mode 100644
index a0abd10..0000000
--- a/deploy/equipment_cost_system/src/schema.sql
+++ /dev/null
@@ -1,140 +0,0 @@
--- 如果数据库已存在则删除
-DROP DATABASE IF EXISTS equipment_cost_db;
-
--- 创建数据库
-CREATE DATABASE equipment_cost_db
-DEFAULT CHARACTER SET utf8mb4
-COLLATE utf8mb4_unicode_ci;
-
--- 使用数据库
-USE equipment_cost_db;
-
--- 装备基本信息表
-CREATE TABLE equipment (
- id INT AUTO_INCREMENT PRIMARY KEY,
- name VARCHAR(100), -- 名称
- type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
- manufacturer VARCHAR(100), -- 制造商
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 通用参数表
-CREATE TABLE common_params (
- id INT AUTO_INCREMENT PRIMARY KEY,
- equipment_id INT,
- length_m FLOAT, -- 总长(m)
- width_m FLOAT, -- 宽度(m)
- height_m FLOAT, -- 高度(m)
- weight_kg FLOAT, -- 重量(kg)
- max_range_km FLOAT, -- 最大射程(km)
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 火箭炮特有参数表
-CREATE TABLE rocket_artillery_params (
- id INT AUTO_INCREMENT PRIMARY KEY,
- equipment_id INT,
- firing_angle_horizontal FLOAT, -- 方向射界(度)
- firing_angle_vertical FLOAT, -- 高低射界(度)
- rocket_length_m FLOAT, -- 火箭弹长度(m)
- rocket_diameter_mm FLOAT, -- 弹体直径(mm)
- rocket_weight_kg FLOAT, -- 火箭弹重量(kg)
- rate_of_fire FLOAT, -- 射速(发/分钟)
- combat_weight_kg FLOAT, -- 战斗重量(kg)
- speed_kmh FLOAT, -- 速度(km/h)
- min_range_km FLOAT, -- 最小射程(km)
- mobility_type VARCHAR(50), -- 行走方式
- structure_layout VARCHAR(100), -- 结构布局
- engine_model VARCHAR(100), -- 发动机型号
- engine_params TEXT, -- 发动机参数
- power_hp FLOAT, -- 功率(hp)
- travel_range_km FLOAT, -- 行程(km)
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 巡飞弹特有参数表
-CREATE TABLE loitering_munition_params (
- id INT AUTO_INCREMENT PRIMARY KEY,
- equipment_id INT,
- wingspan_m FLOAT, -- 翼展(m)
- warhead_weight_kg FLOAT, -- 战斗部重量(kg)
- max_speed_ms FLOAT, -- 最大速度(m/s)
- cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
- flight_time_min FLOAT, -- 巡飞时间(min)
- warhead_type VARCHAR(50), -- 战斗部类型
- launch_mode VARCHAR(50), -- 发射方式
- folded_length_mm FLOAT, -- 折叠长度(mm)
- folded_width_mm FLOAT, -- 折叠宽度(mm)
- folded_height_mm FLOAT, -- 折叠高度(mm)
- power_system VARCHAR(100), -- 动力装置
- guidance_system VARCHAR(100), -- 制导体制
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 成本数据表
-CREATE TABLE cost_data (
- id INT AUTO_INCREMENT PRIMARY KEY,
- equipment_id INT,
- actual_cost DECIMAL(15,2), -- 实际成本(元)
- predicted_cost DECIMAL(15,2), -- 预测成本(元)
- prediction_date TIMESTAMP, -- 预测日期
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 特殊参数表
-CREATE TABLE custom_params (
- id INT AUTO_INCREMENT PRIMARY KEY,
- equipment_id INT,
- param_name VARCHAR(100), -- 参数名称
- param_value VARCHAR(255), -- 参数值
- param_unit VARCHAR(50), -- 参数单位
- description TEXT, -- 参数说明
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 添加索引
-CREATE INDEX idx_equipment_type ON equipment(type);
-CREATE INDEX idx_equipment_name ON equipment(name);
-CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
-
--- 数据集表
-CREATE TABLE datasets (
- id INT AUTO_INCREMENT PRIMARY KEY,
- name VARCHAR(100) NOT NULL, -- 数据集名称
- description TEXT, -- 数据集描述
- equipment_type VARCHAR(50) NOT NULL, -- 装备类型
- purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证)
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
-);
-
--- 数据集-装备关联表
-CREATE TABLE dataset_equipment (
- dataset_id INT NOT NULL,
- equipment_id INT NOT NULL,
- PRIMARY KEY (dataset_id, equipment_id),
- FOREIGN KEY (dataset_id) REFERENCES datasets(id),
- FOREIGN KEY (equipment_id) REFERENCES equipment(id)
-);
-
--- 训练模型表
-CREATE TABLE trained_models (
- id INT AUTO_INCREMENT PRIMARY KEY,
- model_name VARCHAR(100) NOT NULL, -- 模型名称
- model_type VARCHAR(50) NOT NULL, -- 模型类型
- equipment_type VARCHAR(50) NOT NULL, -- 装备类型
- model_path VARCHAR(255) NOT NULL, -- 模型文件路径
- scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径
- r2_score FLOAT, -- R²分数
- mae FLOAT, -- 平均绝对误差
- rmse FLOAT, -- 均方根误差
- feature_importance JSON, -- 特征重要性
- training_data_size INT, -- 训练数据量
- training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
- is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
- created_by VARCHAR(50) -- 创建者
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-
--- 添加索引
-CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
-CREATE INDEX idx_model_active ON trained_models(is_active);
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/test_api.py b/deploy/equipment_cost_system/src/test_api.py
deleted file mode 100644
index f68d88d..0000000
--- a/deploy/equipment_cost_system/src/test_api.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import requests
-import json
-import logging
-from datetime import datetime
-import os
-import sys
-
-# 添加项目根目录到 Python 路径
-current_dir = os.path.dirname(os.path.abspath(__file__))
-project_root = os.path.dirname(current_dir)
-if project_root not in sys.path:
- sys.path.insert(0, project_root)
-
-# 配置基本日志
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('logs/test_api.log'),
- logging.StreamHandler()
- ]
-)
-
-logger = logging.getLogger(__name__)
-
-def test_api_endpoints():
- """
- 测试 API 各个端点
- """
- base_url = 'http://localhost:5001/api'
-
- try:
- # 1. 测试根路由
- logger.info("\n1. 测试 API 根路由")
- response = requests.get(f'{base_url}/')
- print_response(response, "API 根路由")
-
- # 2. 测试机器学习预测接口
- logger.info("\n2. 测试机器学习预测接口")
- predict_data = {
- "type": "巡飞弹",
- "length_m": 1.3,
- "width_m": 0.23,
- "height_m": 0.23,
- "weight_kg": 12.5,
- "max_range_km": 40,
- "max_speed_ms": 50,
- "cruise_speed_kmh": 100,
- "flight_time_min": 60,
- "folded_length_mm": 1300,
- "folded_width_mm": 230,
- "folded_height_mm": 230,
- "warhead_type": "破片杀伤战斗部",
- "launch_mode": "凭自身动力起飞"
- }
-
- response = requests.post(
- f'{base_url}/predict',
- json=predict_data
- )
- print_response(response, "机器学习预测")
-
- # 3. 测试 PLS 预测接口
- logger.info("\n3. 测试 PLS 预测接口")
- response = requests.post(
- f'{base_url}/pls/predict',
- json=predict_data
- )
- print_response(response, "PLS 预测")
-
- # 4. 测试特征分析接口
- logger.info("\n4. 测试特征分析接口")
- analysis_data = {
- "dataset_id": 1, # 假设数据集 ID 为 1
- "equipment_type": "巡飞弹"
- }
-
- response = requests.post(
- f'{base_url}/analyze-features',
- json=analysis_data
- )
- print_response(response, "特征分析")
-
- # 5. 测试机器学习模型训练接口
- logger.info("\n5. 测试机器学习模型训练接口")
- training_data = {
- "type": "巡飞弹",
- "train_dataset_id": 1, # 训练数据集 ID
- "validation_dataset_id": 2, # 验证数据集 ID
- "models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型
- }
-
- response = requests.post(
- f'{base_url}/train',
- json=training_data
- )
- print_response(response, "模型训练")
-
- # 6. 测试数据集相关接口
- logger.info("\n6. 测试数据集相关接口")
-
- # 6.1 获取数据集列表
- response = requests.get(f'{base_url}/datasets')
- print_response(response, "获取数据集列表")
-
- # 6.2 获取可用的装备列表
- response = requests.get(f'{base_url}/data')
- equipment_data = response.json()
-
- # 获取巡飞弹类型的装备ID
- available_equipment_ids = []
- if 'loitering_munition' in equipment_data:
- available_equipment_ids = [
- item['id']
- for item in equipment_data['loitering_munition']
- if item['id'] is not None
- ][:3] # 取前3个可用的ID
-
- if not available_equipment_ids:
- logger.warning("没有找到可用的装备ID,跳过创建数据集测试")
- else:
- # 6.3 创建新数据集
- new_dataset = {
- "name": "测试数据集",
- "description": "用于测试的数据集",
- "equipment_type": "巡飞弹",
- "purpose": "训练",
- "equipment_ids": available_equipment_ids
- }
-
- logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}")
-
- response = requests.post(
- f'{base_url}/datasets',
- json=new_dataset
- )
- print_response(response, "创建数据集")
-
- # 7. 测试模型相关接口
- logger.info("\n7. 测试模型相关接口")
-
- # 7.1 获取模型列表
- response = requests.get(f'{base_url}/models')
- print_response(response, "获取模型列表")
-
- # 7.2 获取最新模型
- response = requests.get(f'{base_url}/models/巡飞弹/latest')
- print_response(response, "获取最新模型")
-
- # 8. 测试多模型预测接口
- logger.info("\n8. 测试多模型预测接口")
- response = requests.post(
- f'{base_url}/predict/all',
- json=predict_data
- )
- print_response(response, "多模型预测")
-
- logger.info("所有测试完成")
-
- except requests.exceptions.RequestException as e:
- logger.error(f"API 请求错误: {str(e)}")
- except Exception as e:
- logger.error(f"测试过程中出现错误: {str(e)}")
-
-def print_response(response, endpoint_name):
- """
- 打印响应结果
- """
- try:
- logger.info(f"\n=== {endpoint_name} 测试结果 ===")
- logger.info(f"状态码: {response.status_code}")
-
- if response.status_code == 200:
- result = response.json()
- logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}")
- else:
- logger.error(f"错误响应:\n{response.text}")
-
- except Exception as e:
- logger.error(f"处理响应时出错: {str(e)}")
-
-if __name__ == "__main__":
- try:
- # 确保日志目录存在
- os.makedirs('logs', exist_ok=True)
-
- logger.info(f"=== API 测试开始 - {datetime.now()} ===")
- test_api_endpoints()
- logger.info(f"=== API 测试结束 - {datetime.now()} ===")
- except Exception as e:
- logger.error(f"测试执行失败: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/venv b/deploy/equipment_cost_system/venv
deleted file mode 100644
index e69de29..0000000
diff --git a/deploy/setup.md b/deploy/setup.md
deleted file mode 100644
index b905ab0..0000000
--- a/deploy/setup.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# 安装说明
-
-## 1. 系统要求
-
-- Linux 服务器 (推荐 Ubuntu 20.04+)
-- Python 3.8+
-- MySQL 8.0+
-
-## 2. 安装部署
-
-```bash
-# 解压部署包
-tar -xzf equipment_cost_prediction.tar.gz
-cd equipment_cost_prediction
-
-# 运行安装脚本
-bash scripts/install.sh
-
-# 修改配置文件
-vim config/.env
-
-# 启动服务
-bash scripts/start.sh
-```
-
-## 3. 验证部署
-
-```bash
-# 检查服务状态
-curl http://localhost:5001/api/
-
-# 检查日志
-tail -f logs/api.log
-```
diff --git a/docs/dev/debug.md b/docs/dev/debug.md
index ce13431..13c274c 100644
--- a/docs/dev/debug.md
+++ b/docs/dev/debug.md
@@ -643,3 +643,11 @@ trainingResult.value = null
- Feature max_speed_ms missing rate: 77.78%
- Feature cruise_speed_kmh missing rate: 61.11%
- Feature flight_time_min missing rate: 33.33%
+
+## 前端特征分析页面未正确显示相关性分析数据(常见问题)
+
+- 确保相关性分析数据的正确格式化和返回
+- 添加了详细的日志记录
+- 增加了数据验证步骤
+- 处理了可能的NaN值
+- 确保所有数值都转换为Python原生类型(使用float())
diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue
index 385b995..6e47559 100644
--- a/frontend/src/views/AnalysisPage.vue
+++ b/frontend/src/views/AnalysisPage.vue
@@ -11,14 +11,23 @@
-
+
-
+
特征重要性
相关性分析
+
+
+
+
+ 特征工程参数分析
+
+
+
+ 发动机性能与作战参数分析
+
+
@@ -82,6 +106,10 @@ const analyzing = ref(false)
const analysisResult = ref(null)
const importanceChartRef = ref(null)
const correlationChartRef = ref(null)
+const newFeatureChartRef = ref(null)
+const engineChartRef = ref(null)
+const newFeatureChart = ref(null)
+const engineChart = ref(null)
// 图表实例引用
const importanceChart = ref(null)
@@ -103,31 +131,74 @@ watch(() => analysisResult.value, async (newResult) => {
// 加载可用数据集
const loadDatasets = async (type) => {
try {
+ // 如果没有传入type,直接返回
+ if (!type) {
+ availableDatasets.value = []
+ return
+ }
+
+ console.log('Loading datasets for type:', type)
const response = await axios.get(`${API_BASE_URL}/datasets`, {
params: { equipment_type: type, purpose: '训练' }
})
+
+ // 验证响应数据
+ if (!response.data) {
+ console.warn('No datasets returned from API')
+ availableDatasets.value = []
+ return
+ }
+
+ console.log('Datasets loaded:', response.data)
availableDatasets.value = response.data
} catch (error) {
+ console.error('Error loading datasets:', error)
ElMessage.error('获取数据集列表失败')
+ availableDatasets.value = []
}
}
// 处理装备类型变化
const handleEquipmentTypeChange = () => {
+ console.log('Equipment type changed to:', analysisForm.value.equipment_type)
+
+ // 重置相关状态
analysisForm.value.dataset_id = null
selectedDataset.value = null
analysisResult.value = null
- loadDatasets(analysisForm.value.equipment_type)
+
+ // 只有当有装备类型时才加载数据集
+ if (analysisForm.value.equipment_type) {
+ loadDatasets(analysisForm.value.equipment_type)
+ } else {
+ availableDatasets.value = []
+ }
}
// 处理数据集选择变化
const handleDatasetChange = async () => {
try {
+ // 如果没有选择数据集,直接返回
+ if (!analysisForm.value.dataset_id) {
+ selectedDataset.value = null
+ analysisResult.value = null
+ return
+ }
+
+ console.log('Dataset changed to:', analysisForm.value.dataset_id)
const response = await axios.get(`${API_BASE_URL}/datasets/${analysisForm.value.dataset_id}`)
+
+ // 验证响应数据
+ if (!response.data) {
+ throw new Error('获取数据集详情失败:服务器返回空数据')
+ }
+
selectedDataset.value = response.data
analysisResult.value = null
} catch (error) {
- ElMessage.error('获取数据集详情失败')
+ console.error('Error getting dataset details:', error)
+ ElMessage.error(error.message || '获取数据集详情失败')
+ selectedDataset.value = null
}
}
@@ -140,14 +211,69 @@ const startAnalysis = async () => {
analyzing.value = true
try {
+ // 打印请求参数
+ console.log('Analysis request params:', {
+ dataset_id: analysisForm.value.dataset_id,
+ equipment_type: analysisForm.value.equipment_type
+ })
+
const response = await axios.post(`${API_BASE_URL}/analyze-features`, {
dataset_id: analysisForm.value.dataset_id
})
+
+ // 打印原始响应数据
+ console.log('Raw API response:', response)
+ console.log('Response data type:', typeof response.data)
+ console.log('Response data:', response.data)
+
+ // 检查响应数据的结构
+ if (!response.data) {
+ throw new Error('API返回的数据为空')
+ }
+
+ // 确保数据正确赋值
analysisResult.value = response.data
- console.log('Analysis completed, result:', analysisResult.value)
+
+ // 验证数赋值是否成功
+ console.log('Analysis result after assignment:', {
+ value: analysisResult.value,
+ important_features: analysisResult.value?.important_features,
+ correlation_analysis: analysisResult.value?.correlation_analysis,
+ equipment_names: analysisResult.value?.equipment_names,
+ length_width_ratio: analysisResult.value?.length_width_ratio
+ })
+
+ // 如果是巡飞弹类型,检查特定数据
+ if (analysisForm.value.equipment_type === '巡飞弹') {
+ const missileData = {
+ equipment_names: analysisResult.value?.equipment_names || [],
+ length_width_ratio: analysisResult.value?.length_width_ratio || [],
+ engine_power_kw: analysisResult.value?.engine_power_kw || [],
+ guidance_system_score: analysisResult.value?.guidance_system_score || [],
+ warhead_power_score: analysisResult.value?.warhead_power_score || []
+ }
+
+ console.log('Missile specific data:', missileData)
+
+ // 验证数据完整性
+ const missingFields = Object.entries(missileData)
+ .filter(([key, value]) => !Array.isArray(value) || value.length === 0)
+ .map(([key]) => key)
+
+ if (missingFields.length > 0) {
+ console.warn('Missing or empty missile data fields:', missingFields)
+ ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`)
+ }
+ }
+
} catch (error) {
- ElMessage.error('特征分析失败')
console.error('Analysis error:', error)
+ console.error('Error details:', {
+ message: error.message,
+ response: error.response?.data,
+ status: error.response?.status
+ })
+ ElMessage.error(error.message || '特征析失败')
} finally {
analyzing.value = false
}
@@ -166,6 +292,12 @@ const createResizeHandler = () => {
if (correlationChart.value && !correlationChart.value.isDisposed()) {
correlationChart.value.resize()
}
+ if (newFeatureChart.value && !newFeatureChart.value.isDisposed()) {
+ newFeatureChart.value.resize()
+ }
+ if (engineChart.value && !engineChart.value.isDisposed()) {
+ engineChart.value.resize()
+ }
} catch (error) {
console.error('Error in resize handler:', error)
}
@@ -180,6 +312,11 @@ onMounted(() => {
// 创建并保存 resize 处理函数的引用
resizeHandler.value = createResizeHandler()
window.addEventListener('resize', resizeHandler.value)
+
+ // 如果已经选择了装备类型,加载对应的数据集
+ if (analysisForm.value.equipment_type) {
+ loadDatasets(analysisForm.value.equipment_type)
+ }
})
// 组件卸载时
@@ -190,131 +327,353 @@ onUnmounted(() => {
resizeHandler.value = null
}
- // 销毁图表实例
- try {
- if (importanceChart.value) {
- importanceChart.value.dispose()
- importanceChart.value = null
+ // 销毁所有图表实例
+ [importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => {
+ if (chart.value && !chart.value.isDisposed()) {
+ try {
+ chart.value.dispose()
+ } catch (e) {
+ console.error('Error disposing chart:', e)
+ }
+ chart.value = null
}
- if (correlationChart.value) {
- correlationChart.value.dispose()
- correlationChart.value = null
- }
- } catch (error) {
- console.error('Error disposing charts:', error)
- }
+ })
})
// 渲染图表
const renderCharts = () => {
console.log('Starting to render charts')
- // 检查分析结果是否存在且包含必要数据
- if (!analysisResult.value ||
- !analysisResult.value.important_features ||
- !analysisResult.value.correlation_analysis) {
- console.log('Analysis result not ready')
+ if (!analysisResult.value) {
+ console.error('No analysis result available')
return
}
-
- // 检查DOM元素
- if (!importanceChartRef.value || !correlationChartRef.value) {
- console.log('Chart DOM elements not ready')
- return
- }
-
+
try {
- // 销毁旧的图表实例
- if (importanceChart.value) {
- importanceChart.value.dispose()
- importanceChart.value = null
- }
- if (correlationChart.value) {
- correlationChart.value.dispose()
- correlationChart.value = null
- }
-
- // 创建新的图表实例
- importanceChart.value = echarts.init(importanceChartRef.value)
- correlationChart.value = echarts.init(correlationChartRef.value)
-
- // 设置图表选项
- const importanceOption = {
- title: { text: '特征重要性排序' },
- tooltip: {},
- xAxis: {
- type: 'value',
- name: '重要性得分'
- },
- yAxis: {
- type: 'category',
- data: analysisResult.value.important_features.map(f => f.name)
- },
- series: [{
- name: '重要性',
- type: 'bar',
- data: analysisResult.value.important_features.map(f => f.importance)
- }]
- }
-
- const correlationOption = {
- title: { text: '特征相关性热力图' },
- tooltip: {
- position: 'top',
- formatter: function (params) {
- const value = params.data[2].toFixed(2)
- const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
- const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
- return `${feature1} 与 ${feature2} 的相关性: ${value}`
- }
- },
- grid: {
- height: '50%',
- top: '10%'
- },
- xAxis: {
- type: 'category',
- data: analysisResult.value.correlation_analysis.features,
- splitArea: { show: true },
- axisLabel: {
- interval: 0,
- rotate: 45
- }
- },
- yAxis: {
- type: 'category',
- data: analysisResult.value.correlation_analysis.features,
- splitArea: { show: true }
- },
- visualMap: {
- min: -1,
- max: 1,
- calculable: true,
- orient: 'horizontal',
- left: 'center',
- bottom: '15%',
- color: ['#cc3333', '#eeeeee', '#00007f']
- },
- series: [{
- name: '相关性',
- type: 'heatmap',
- data: analysisResult.value.correlation_analysis.matrix,
- label: {
- show: true,
- formatter: function(params) {
- return params.data[2].toFixed(2)
+ // 先销毁所有现有的图表实例
+ [importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => {
+ if (chart.value && !chart.value.isDisposed()) {
+ chart.value.dispose()
+ chart.value = null
+ }
+ })
+
+ // 等待 DOM 更新
+ nextTick(() => {
+ try {
+ // 重新初始化基本图表
+ if (importanceChartRef.value && correlationChartRef.value) {
+ importanceChart.value = echarts.init(importanceChartRef.value)
+ correlationChart.value = echarts.init(correlationChartRef.value)
+
+ // 设置基本图表的选项
+ const importanceOption = {
+ title: { text: '特征重要性排序' },
+ tooltip: {
+ trigger: 'axis',
+ axisPointer: {
+ type: 'shadow'
+ },
+ formatter: function(params) {
+ const data = params[0]
+ return `${data.name}: ${data.value.toFixed(4)}`
+ }
+ },
+ xAxis: {
+ type: 'value',
+ name: '重要性得分'
+ },
+ yAxis: {
+ type: 'category',
+ data: analysisResult.value.important_features.map(f => f.name)
+ },
+ series: [{
+ name: '重要性',
+ type: 'bar',
+ data: analysisResult.value.important_features.map(f => f.importance),
+ itemStyle: {
+ color: '#3a5fcd'
+ }
+ }]
}
+
+ const correlationOption = {
+ title: { text: '特征相关性热力图' },
+ tooltip: {
+ position: 'top',
+ trigger: 'item',
+ formatter: function (params) {
+ if (!params.data) return ''
+ const value = params.data[2].toFixed(2)
+ const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]]
+ const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]]
+ return `${feature1} 与 ${feature2}
相关性: ${value}`
+ }
+ },
+ grid: {
+ height: '75%',
+ top: '10%',
+ bottom: '15%',
+ left: '10%',
+ right: '10%',
+ containLabel: true
+ },
+ xAxis: {
+ type: 'category',
+ data: analysisResult.value.correlation_analysis.features,
+ splitArea: { show: true },
+ axisLabel: {
+ interval: 0,
+ rotate: 45,
+ margin: 15
+ }
+ },
+ yAxis: {
+ type: 'category',
+ data: analysisResult.value.correlation_analysis.features,
+ splitArea: { show: true },
+ axisLabel: {
+ interval: 0,
+ margin: 15
+ }
+ },
+ visualMap: {
+ min: -1,
+ max: 1,
+ calculable: true,
+ orient: 'horizontal',
+ left: 'center',
+ bottom: '5%',
+ inRange: {
+ color: ['#cc3333', '#eeeeee', '#00007f']
+ }
+ },
+ series: [{
+ name: '相关性',
+ type: 'heatmap',
+ data: analysisResult.value.correlation_analysis.matrix,
+ emphasis: {
+ itemStyle: {
+ shadowBlur: 10,
+ shadowColor: 'rgba(0, 0, 0, 0.5)'
+ }
+ },
+ label: {
+ show: true,
+ formatter: function(params) {
+ return params.data[2].toFixed(2)
+ }
+ },
+ itemStyle: {
+ borderWidth: 1,
+ borderColor: '#fff'
+ }
+ }]
+ }
+
+ // 使用 clear 方法清除旧的图表内容
+ importanceChart.value.clear()
+ correlationChart.value.clear()
+
+ // 设置新的选项
+ importanceChart.value.setOption(importanceOption, { notMerge: true })
+ correlationChart.value.setOption(correlationOption, { notMerge: true })
}
- }]
- }
-
- // 设置图表选项
- importanceChart.value.setOption(importanceOption)
- correlationChart.value.setOption(correlationOption)
-
- console.log('Charts rendered successfully')
+
+ // 如果是巡飞弹类型,渲染额外的图表
+ if (analysisForm.value.equipment_type === '巡飞弹' &&
+ newFeatureChartRef.value &&
+ engineChartRef.value) {
+
+ // 初始化巡飞弹特有图表
+ newFeatureChart.value = echarts.init(newFeatureChartRef.value)
+ engineChart.value = echarts.init(engineChartRef.value)
+
+ // 准备数据
+ const chartData = {
+ names: analysisResult.value.equipment_names || [],
+ lengthWidthRatio: analysisResult.value.length_width_ratio || [],
+ weightRangeRatio: analysisResult.value.weight_range_ratio || [],
+ speedWeightRatio: analysisResult.value.speed_weight_ratio || [],
+ guidanceSystemScore: analysisResult.value.guidance_system_score || [],
+ warheadPowerScore: analysisResult.value.warhead_power_score || [],
+ enginePowerKw: analysisResult.value.engine_power_kw || [],
+ engineThrustN: analysisResult.value.engine_thrust_n || [],
+ minAltitudeM: analysisResult.value.min_altitude_m || [],
+ maxAltitudeM: analysisResult.value.max_altitude_m || []
+ }
+
+ // 特征工程参数分析图表配置
+ const newFeatureOption = {
+ animation: false, // 禁用动画
+ title: {
+ text: '特征工程参数分析',
+ left: 'center'
+ },
+ tooltip: {
+ trigger: 'axis',
+ axisPointer: {
+ type: 'cross'
+ },
+ formatter: function(params) {
+ // 获取当前数据点对应的装备名称
+ const equipmentName = chartData.names[params[0].dataIndex]
+ let result = `${equipmentName}
`
+ // 添加每个系列的数据
+ params.forEach(param => {
+ result += `${param.seriesName}: ${param.value.toFixed(2)}
`
+ })
+ return result
+ }
+ },
+ legend: {
+ top: 30,
+ data: ['长宽比', '重量射程比', '速度重量比', '制导系统评分', '战斗部威力评分']
+ },
+ grid: {
+ top: 80,
+ bottom: 50,
+ containLabel: true
+ },
+ xAxis: {
+ type: 'category',
+ data: chartData.names
+ },
+ yAxis: [
+ {
+ type: 'value',
+ name: '比率',
+ position: 'left'
+ },
+ {
+ type: 'value',
+ name: '评分',
+ position: 'right',
+ min: 0,
+ max: 10
+ }
+ ],
+ series: [
+ {
+ name: '长宽比',
+ type: 'line',
+ data: chartData.lengthWidthRatio
+ },
+ {
+ name: '重量射程比',
+ type: 'line',
+ data: chartData.weightRangeRatio
+ },
+ {
+ name: '速度重量比',
+ type: 'line',
+ data: chartData.speedWeightRatio
+ },
+ {
+ name: '制导系统评分',
+ type: 'bar',
+ yAxisIndex: 1,
+ data: chartData.guidanceSystemScore
+ },
+ {
+ name: '战斗部威力评分',
+ type: 'bar',
+ yAxisIndex: 1,
+ data: chartData.warheadPowerScore
+ }
+ ]
+ }
+
+ // 发动机性能分析图表配置
+ const engineOption = {
+ animation: false, // 禁用动画
+ title: {
+ text: '发动机性能与作战参数分析',
+ left: 'center'
+ },
+ tooltip: {
+ trigger: 'axis',
+ axisPointer: {
+ type: 'cross'
+ },
+ formatter: function(params) {
+ // 获取当前数据点对应的装备名称
+ const equipmentName = chartData.names[params[0].dataIndex]
+ let result = `${equipmentName}
`
+ // 添加每个系列的数据
+ params.forEach(param => {
+ result += `${param.seriesName}: ${param.value.toFixed(2)}
`
+ })
+ return result
+ }
+ },
+ legend: {
+ top: 30,
+ data: ['发动机功率(kw)', '发动机推力(N)', '最小作战高度(m)', '最大作战高度(m)']
+ },
+ grid: {
+ top: 80,
+ bottom: 50,
+ containLabel: true
+ },
+ xAxis: {
+ type: 'category',
+ data: chartData.names
+ },
+ yAxis: [
+ {
+ type: 'value',
+ name: '功率/推力',
+ position: 'left'
+ },
+ {
+ type: 'value',
+ name: '高度(m)',
+ position: 'right'
+ }
+ ],
+ series: [
+ {
+ name: '发动机功率(kw)',
+ type: 'bar',
+ data: chartData.enginePowerKw
+ },
+ {
+ name: '发动机推力(N)',
+ type: 'bar',
+ data: chartData.engineThrustN
+ },
+ {
+ name: '最小作战高度(m)',
+ type: 'line',
+ yAxisIndex: 1,
+ data: chartData.minAltitudeM
+ },
+ {
+ name: '最大作战高度(m)',
+ type: 'line',
+ yAxisIndex: 1,
+ data: chartData.maxAltitudeM
+ }
+ ]
+ }
+
+ // 清除旧的内容
+ newFeatureChart.value.clear()
+ engineChart.value.clear()
+
+ // 设置新的选项
+ newFeatureChart.value.setOption(newFeatureOption, { notMerge: true })
+ engineChart.value.setOption(engineOption, { notMerge: true })
+ }
+
+ console.log('Charts rendered successfully')
+ } catch (error) {
+ console.error('Error in chart rendering:', error)
+ }
+ })
} catch (error) {
- console.error('Error rendering charts:', error)
+ console.error('Error in renderCharts:', error)
}
}
diff --git a/frontend/src/views/DataPage.vue b/frontend/src/views/DataPage.vue
index 4689f70..0a02aa8 100644
--- a/frontend/src/views/DataPage.vue
+++ b/frontend/src/views/DataPage.vue
@@ -85,23 +85,24 @@
-
+
-
+
+
+
+
+
+ {{ formatMoney(scope.row.actual_cost) }}
+
+
-
- 详情
-
-
- 编辑
-
-
- 删除
-
+ 详情
+ 编辑
+ 删除
@@ -162,7 +163,7 @@
{{ formatNumber(selectedData?.max_speed_ms) }}
{{ formatNumber(selectedData?.cruise_speed_kmh) }}
- {{ formatNumber(selectedData?.flight_time_min) }}
+ {{ formatNumber(selectedData?.endurance_min) }}
{{ selectedData?.warhead_type }}
{{ selectedData?.launch_mode }}
{{ selectedData?.power_system }}
@@ -291,8 +292,8 @@
-
-
+
+
diff --git a/frontend/src/views/PredictPage.vue b/frontend/src/views/PredictPage.vue
index a42fa9f..9b3992c 100644
--- a/frontend/src/views/PredictPage.vue
+++ b/frontend/src/views/PredictPage.vue
@@ -55,36 +55,56 @@
-
-
+
+
+
+
+
+
+
+
-
+
-
-
+
+
+
+
-
+
+
+
+
+
-
-
+
+
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
@@ -154,7 +174,16 @@ const formData = reactive({
width_m: null,
height_m: null,
weight_kg: null,
- max_range_km: null
+ max_range_km: null,
+ wingspan_m: null,
+ warhead_weight_kg: null,
+ max_speed_ms: null,
+ cruise_speed_kmh: null,
+ endurance_min: null,
+ warhead_type: '',
+ launch_mode: '',
+ power_system: '',
+ guidance_system: ''
})
const predictionResults = ref(null)
@@ -171,14 +200,15 @@ const handleTypeChange = () => {
formData.rocket_weight_kg = null
formData.rate_of_fire = null
} else if (formData.type === '巡飞弹') {
- formData.max_speed_kmh = null
+ formData.wingspan_m = null
+ formData.warhead_weight_kg = null
+ formData.max_speed_ms = null
formData.cruise_speed_kmh = null
- formData.flight_time_min = null
+ formData.endurance_min = null
formData.warhead_type = ''
formData.launch_mode = ''
- formData.folded_length_mm = null
- formData.folded_width_mm = null
- formData.folded_height_mm = null
+ formData.power_system = ''
+ formData.guidance_system = ''
}
}
@@ -210,8 +240,8 @@ const submitForm = async () => {
}
} else if (formData.type === '巡飞弹') {
const missileFields = [
- 'max_speed_kmh', 'cruise_speed_kmh', 'flight_time_min',
- 'folded_length_mm', 'folded_width_mm', 'folded_height_mm'
+ 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
+ 'endurance_min', 'warhead_type', 'launch_mode', 'power_system', 'guidance_system'
]
for (const field of missileFields) {
if (!formData[field]) {
diff --git a/src/feature_analysis.py b/src/feature_analysis.py
index 5b87972..51939ca 100644
--- a/src/feature_analysis.py
+++ b/src/feature_analysis.py
@@ -1,269 +1,192 @@
import numpy as np
-import pandas as pd
-from scipy import stats
+from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.preprocessing import StandardScaler
-from sklearn.ensemble import RandomForestRegressor
-from sklearn.metrics import r2_score
import logging
-from .logger import setup_logger
-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)
class FeatureAnalysis:
def __init__(self):
self.scaler = StandardScaler()
- self.important_features = []
-
# 添加特征名称映射
- self.feature_names_map = {
- # 通用参数
- 'length_m': '总长(m)',
+ self.feature_name_map = {
+ 'length_m': '长度(m)',
'width_m': '宽度(m)',
'height_m': '高度(m)',
'weight_kg': '重量(kg)',
'max_range_km': '最大射程(km)',
-
- # 火箭炮特有参数
- 'firing_angle_horizontal': '方向射界(度)',
- 'firing_angle_vertical': '高低射界(度)',
- 'rocket_length_m': '火箭弹长度(m)',
- 'rocket_diameter_mm': '口径(mm)',
- 'rocket_weight_kg': '火箭弹重量(kg)',
- 'rate_of_fire': '射速(发/分)',
- 'combat_weight_kg': '战斗重量(kg)',
- 'speed_kmh': '速度(km/h)',
- 'min_range_km': '最小射程(km)',
- 'power_hp': '功率(hp)',
-
- # 火箭炮衍生特征
- 'fire_density': '火力密度',
- 'mobility_index': '机动性指标',
- 'range_ratio': '射程比',
- 'power_weight_ratio': '功重比',
- 'volume_density': '体积密度',
-
- # 巡飞弹特有参数
'wingspan_m': '翼展(m)',
- 'warhead_weight_kg': '战斗部重量(kg)',
+ 'warhead_weight_kg': '弹头重量(kg)',
'max_speed_ms': '最大速度(m/s)',
'cruise_speed_kmh': '巡航速度(km/h)',
- 'flight_time_min': '巡飞时间(min)',
- 'folded_length_mm': '折叠长度(mm)',
- 'folded_width_mm': '折叠宽度(mm)',
- 'folded_height_mm': '折叠高度(mm)',
-
- # 巡飞弹衍生特征
- 'warhead_ratio': '战斗部比重',
- 'speed_ratio': '速度比',
- 'range_time_ratio': '射程时间比',
- 'aspect_ratio': '展弦比',
- 'volume_density': '体积密度'
+ 'endurance_min': '续航时间(min)',
+ 'payload_weight_kg': '载荷重量(kg)',
+ 'min_combat_radius_km': '最小作战半径(km)',
+ 'engine_power_kw': '发动机功率(kw)',
+ 'engine_thrust_n': '发动机推力(N)',
+ 'datalink_range_km': '数据链距离(km)',
+ 'guidance_accuracy_m': '制导精度(m)',
+ 'min_altitude_m': '最小飞行高度(m)',
+ 'max_altitude_m': '最大飞行高度(m)',
+ 'length_width_ratio': '长宽比',
+ 'weight_range_ratio': '重量射程比',
+ 'speed_weight_ratio': '速度重量比',
+ 'guidance_system_score': '制导系统评分',
+ 'warhead_power_score': '战斗部威力评分',
+ 'firing_angle_horizontal': '水平射角(°)',
+ 'firing_angle_vertical': '垂直射角(°)',
+ 'rocket_length_m': '火箭长度(m)',
+ 'rocket_diameter_mm': '火箭直径(mm)',
+ 'rocket_weight_kg': '火箭重量(kg)'
}
-
+
def get_equipment_specific_features(self, equipment_type):
- """
- 获取特定装备类型的特征列表
- """
- # 通用参数
+ """获取特定装备类型的特征列表"""
common_features = [
- 'length_m', # 总长(m)
- 'width_m', # 宽度(m)
- 'height_m', # 高度(m)
- 'weight_kg', # 重量(kg)
- 'max_range_km' # 最大射程(km)
+ 'length_m', 'width_m', 'height_m',
+ 'weight_kg', 'max_range_km'
]
if equipment_type == '火箭炮':
- # 火箭炮特有参数
- specific_features = [
- 'firing_angle_horizontal', # 方向射界(度)
- 'firing_angle_vertical', # 高低射界(度)
- 'rocket_length_m', # 火箭弹长度(m)
- 'rocket_diameter_mm', # 口径(mm)
- 'rocket_weight_kg', # 火箭弹重量(kg)
- 'rate_of_fire', # 射速(发/分)
- 'combat_weight_kg', # 战斗重量(kg)
- 'speed_kmh', # 速度(km/h)
- 'min_range_km', # 最小射程(km)
- 'power_hp' # 功率(hp)
+ return common_features + [
+ 'firing_angle_horizontal',
+ 'firing_angle_vertical',
+ 'rocket_length_m',
+ 'rocket_diameter_mm',
+ 'rocket_weight_kg'
]
-
- # 火箭炮衍生特征
- derived_features = [
- 'fire_density', # 火力密度 = 射速 * 火箭弹重量
- 'mobility_index', # 机动性指标 = 速度 / 战斗重量
- 'range_ratio', # 射程比 = 最大射程 / 最小射程
- 'power_weight_ratio', # 功重比 = 功率 / 战斗重量
- 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
+ elif equipment_type == '巡飞弹':
+ return common_features + [
+ 'wingspan_m',
+ 'warhead_weight_kg',
+ 'max_speed_ms',
+ 'cruise_speed_kmh',
+ 'endurance_min',
+ 'payload_weight_kg',
+ 'min_combat_radius_km',
+ 'engine_power_kw',
+ 'engine_thrust_n',
+ 'datalink_range_km',
+ 'guidance_accuracy_m',
+ 'min_altitude_m',
+ 'max_altitude_m',
+ 'length_width_ratio',
+ 'weight_range_ratio',
+ 'speed_weight_ratio',
+ 'guidance_system_score',
+ 'warhead_power_score'
]
-
- return common_features + specific_features + derived_features
-
- else: # 巡飞弹
- # 巡飞弹特有参数
- specific_features = [
- 'wingspan_m', # 翼展(m)
- 'warhead_weight_kg', # 战斗部重量(kg)
- 'max_speed_ms', # 最大速度(m/s)
- 'cruise_speed_kmh', # 巡航速度(km/h)
- 'flight_time_min', # 巡飞时间(min)
- 'folded_length_mm', # 折叠长度(mm)
- 'folded_width_mm', # 折叠宽度(mm)
- 'folded_height_mm' # 折叠高度(mm)
- ]
-
- # 巡飞弹衍生特征
- derived_features = [
- 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
- 'speed_ratio', # 速度比 = 巡航速度 / 最大速度
- 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
- 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
- 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
- ]
-
- return common_features + specific_features + derived_features
-
- def calculate_derived_features(self, data, equipment_type):
- """
- 计算衍生特征
- """
- try:
- if equipment_type == '火箭炮':
- # 火箭炮衍生特征计算
- if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
- data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
- else:
- data['fire_density'] = 0 # 或者其他默认值
-
- if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
- data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
- else:
- data['mobility_index'] = 0
-
- if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
- data['range_ratio'] = data['max_range_km'] / data['min_range_km']
- else:
- data['range_ratio'] = 0
-
- if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
- data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
- else:
- data['power_weight_ratio'] = 0
-
- if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
- data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
- else:
- data['volume_density'] = 0
-
- else: # 巡飞弹
- # 巡飞弹衍生特征计算
- if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
- data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
- else:
- data['warhead_ratio'] = 0
-
- if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
- data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
- else:
- data['speed_ratio'] = 0
-
- if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
- data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
- else:
- data['range_time_ratio'] = 0
-
- if 'wingspan_m' in data.columns and 'length_m' in data.columns:
- data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
- else:
- data['aspect_ratio'] = 0
-
- if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
- data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
- else:
- data['volume_density'] = 0
-
- return data
-
- except Exception as e:
- logger.error(f"Error calculating derived features: {str(e)}")
- raise
+ return common_features
def analyze_features(self, features, target, feature_names):
- """
- 分析特征重要性和相关性
- """
+ """分析特征重要性和相关性"""
try:
# 转换为numpy数组
- X = np.array(features)
- y = np.array(target)
+ X = np.array(features, dtype=np.float64) # 明确指定数据类型
+ y = np.array(target, dtype=np.float64)
- # 数据标准化
+ # 打印原始数据的统计信息
+ logger.info("Feature statistics before scaling:")
+ for i, name in enumerate(feature_names):
+ feature_data = X[:, i]
+ logger.info(f"{self.feature_name_map.get(name, name)}: "
+ f"min={np.min(feature_data)}, "
+ f"max={np.max(feature_data)}, "
+ f"mean={np.mean(feature_data)}, "
+ f"null_count={np.sum(np.isnan(feature_data))}")
+
+ # 处理可能的无穷大和NaN值
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # 标准化特征
X_scaled = self.scaler.fit_transform(X)
# 特征重要性分析
- rf = RandomForestRegressor(n_estimators=100, random_state=42)
- rf.fit(X_scaled, y)
- importances = rf.feature_importances_
+ selector = SelectKBest(score_func=f_regression, k='all')
+ selector.fit(X_scaled, y)
+ importance_scores = selector.scores_
- # 按重要性排序,使用中文特征名
- importance_indices = np.argsort(importances)[::-1]
- important_features = [
- {
- 'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
- 'importance': float(importances[i])
- }
- for i in importance_indices
- ]
+ # 计算相关性矩阵前检查特征的方差
+ X = np.array(features, dtype=np.float64)
+ feature_std = np.std(X, axis=0)
+ constant_features = []
- # 相关性分析
- df = pd.DataFrame(X_scaled, columns=feature_names)
- correlation_matrix = df.corr().values
+ # 记录标准差为0的特征
+ for i, (name, std) in enumerate(zip(feature_names, feature_std)):
+ if std == 0:
+ logger.warning(f"Feature '{self.feature_name_map.get(name, name)}' has zero standard deviation "
+ f"(constant value: {X[0, i]})")
+ constant_features.append(name)
- # 生成相关性分析数据,保留2位小数
+ # 计算相关性矩阵
+ correlation_matrix = np.corrcoef(X.T)
+
+ # 处理相关性矩阵中的无效值
correlation_data = []
- chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
+ chinese_feature_names = [self.feature_name_map.get(name, name) for name in feature_names]
+
for i in range(len(feature_names)):
for j in range(len(feature_names)):
- correlation_data.append([
- i, j,
- round(correlation_matrix[i][j], 2) # 修改为保留2位小数
- ])
+ corr_value = correlation_matrix[i, j]
+ if np.isnan(corr_value):
+ # 如果是常量特征,设置相关系数
+ if feature_names[i] in constant_features or feature_names[j] in constant_features:
+ if i == j:
+ # 自身相关性设为1
+ corr_value = 1.0
+ else:
+ # 与其他特征的相关性设为0
+ corr_value = 0.0
+ logger.info(f"Setting correlation for constant feature: "
+ f"{chinese_feature_names[i]} vs {chinese_feature_names[j]} = {corr_value}")
+ correlation_data.append([i, j, float(corr_value)])
- return {
+ # 记录数据形状
+ logger.info(f"Features shape: {X.shape}")
+ logger.info(f"Target shape: {y.shape}")
+ logger.info(f"Correlation matrix shape: {correlation_matrix.shape}")
+
+ # 创建特征重要性列表(使用中文名称)
+ important_features = []
+ for idx, (name, score) in enumerate(zip(feature_names, importance_scores)):
+ if not np.isnan(score):
+ important_features.append({
+ 'name': self.feature_name_map.get(name, name), # 使用中文名称
+ 'importance': float(score)
+ })
+
+ # 按重要性排序
+ important_features.sort(key=lambda x: x['importance'], reverse=True)
+
+ # 返回结果
+ result = {
'important_features': important_features,
'correlation_analysis': {
- 'features': chinese_feature_names, # 使用中文特征名
+ 'features': chinese_feature_names, # 使用中文名称
'matrix': correlation_data
}
}
- except Exception as e:
- logger.error(f"Error in feature analysis: {str(e)}")
- raise
-
- def preprocess_features(self, equipment_data, equipment_type):
- """
- 预处理特征数据
- """
- try:
- # 转换为 DataFrame
- df = pd.DataFrame(equipment_data)
+ # 添加数据验证
+ logger.info("Correlation data validation:")
+ expected_pairs = len(feature_names) * len(feature_names)
+ actual_pairs = len(correlation_data)
+ logger.info(f"Expected correlation pairs: {expected_pairs}")
+ logger.info(f"Actual correlation pairs: {actual_pairs}")
+ if expected_pairs != actual_pairs:
+ logger.warning("Missing correlation pairs detected!")
- # 计算衍生特征
- df = self.calculate_derived_features(df, equipment_type)
+ # 验证返回的数据
+ logger.info("Validation of return data:")
+ logger.info(f"Has important_features: {bool(result['important_features'])}")
+ logger.info(f"Important features count: {len(result['important_features'])}")
+ logger.info(f"Has correlation_analysis: {bool(result['correlation_analysis'])}")
+ logger.info(f"Correlation features count: {len(result['correlation_analysis']['features'])}")
+ logger.info(f"Correlation matrix size: {len(result['correlation_analysis']['matrix'])}")
- # 处理缺失值
- numeric_columns = df.select_dtypes(include=[np.number]).columns
- for col in numeric_columns:
- # 转换为数值类型
- df[col] = pd.to_numeric(df[col], errors='coerce')
- # 使用新的方式填充缺失值
- mean_value = df[col].mean()
- df[col] = df[col].fillna(mean_value)
-
- logger.info(f"Preprocessed data shape: {df.shape}")
- return df
+ return result
except Exception as e:
- logger.error(f"Error preprocessing features: {str(e)}")
- raise Exception(f"Feature preprocessing error: {str(e)}")
\ No newline at end of file
+ logger.error(f"Error in analyze_features: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ raise
\ No newline at end of file
diff --git a/src/real_data.sql b/src/real_data.sql
new file mode 100644
index 0000000..d00ab59
--- /dev/null
+++ b/src/real_data.sql
@@ -0,0 +1,485 @@
+-- 清空现有数据
+SET FOREIGN_KEY_CHECKS=0;
+TRUNCATE TABLE dataset_equipment;
+TRUNCATE TABLE datasets;
+TRUNCATE TABLE cost_data;
+TRUNCATE TABLE loitering_munition_params;
+TRUNCATE TABLE common_params;
+TRUNCATE TABLE equipment;
+SET FOREIGN_KEY_CHECKS=1;
+
+-- 按系列插入装备数据,确保ID连续
+-- 1. HAROP/Harpy 系列 (ID: 1-3)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(1, 'IAI Harop', '巡飞弹', '以色列'),
+(2, 'IAI Harpy', '巡飞弹', '以色列'),
+(3, 'IAI Mini Harpy', '巡飞弹', '以色列');
+
+-- 2. Hero 系列 (ID: 4-9)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(4, 'Hero-30', '巡飞弹', '以色列 UVision'),
+(5, 'Hero-70', '巡飞弹', '以色列 UVision'),
+(6, 'Hero-120', '巡飞弹', '以色列 UVision'),
+(7, 'Hero-250', '巡飞弹', '以色列 UVision'),
+(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'),
+(9, 'Hero-900', '巡飞弹', '以色列 UVision');
+
+-- 3. Switchblade 系列 (ID: 10-13)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'),
+(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'),
+(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'),
+(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment');
+
+-- 4. Warmate 系列 (ID: 14-18)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'),
+(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'),
+(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'),
+(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'),
+(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics');
+
+-- 5. CH-901/902 系列 (ID: 19-23)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(19, 'CH-901', '巡飞弹', '中国航天科工'),
+(20, 'CH-901A', '巡飞弹', '中国航天科工'),
+(21, 'CH-901H', '巡飞弹', '中国航天科工'),
+(22, 'CH-902', '巡飞弹', '中国航天科工'),
+(23, 'CH-902A', '巡飞弹', '中国航天科工');
+
+-- 6. WS-43/61 系列 (ID: 24-28)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(24, 'WS-43', '巡飞弹', '中国航天科工'),
+(25, 'WS-43A', '巡飞弹', '中国航天科工'),
+(26, 'WS-43B', '巡飞弹', '中国航天科工'),
+(27, 'WS-61', '巡飞弹', '中国航天科工'),
+(28, 'WS-61A', '巡飞弹', '中国航天科工');
+
+-- 7. Kargu/Alpagu 系列 (ID: 29-33)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(29, 'Kargu', '巡飞弹', '土耳其 STM'),
+(30, 'Kargu-2', '巡飞弹', '土耳其 STM'),
+(31, 'Alpagu', '巡飞弹', '土耳其 STM'),
+(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'),
+(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM');
+
+-- 8. Shahed 系列 (ID: 34-38)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(34, 'Shahed-131', '巡飞弹', '伊朗'),
+(35, 'Shahed-131B', '巡飞弹', '伊朗'),
+(36, 'Shahed-136', '巡飞弹', '伊朗'),
+(37, 'Shahed-136B', '巡飞弹', '伊朗'),
+(38, 'Shahed-136C', '巡飞弹', '伊朗');
+
+-- 9. Green Dragon 系列 (ID: 39-43)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(39, 'Green Dragon', '巡飞弹', '以色列 IAI'),
+(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'),
+(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'),
+(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'),
+(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI');
+
+-- 10. Phoenix Ghost 系列 (ID: 44-48)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'),
+(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'),
+(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'),
+(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'),
+(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace');
+
+-- 11. ZALA Lancet 系列 (ID: 49-52)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'),
+(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'),
+(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'),
+(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA');
+
+-- 12. Rotem L 系列 (ID: 53-56)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(53, 'Rotem L', '巡飞弹', '以色列 IAI'),
+(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'),
+(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'),
+(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI');
+
+-- 13. KUB-BLA 系列 (ID: 57-60)
+INSERT INTO equipment (id, name, type, manufacturer) VALUES
+(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'),
+(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'),
+(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'),
+(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA');
+
+-- 插入通用参数
+INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES
+(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop
+(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy
+(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy
+(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30
+(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70
+(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120
+(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250
+(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC
+(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900
+(10, 0.58, 0.12, 0.12, 2.5, 10),
+(11, 1.30, 0.22, 0.22, 15.0, 40),
+(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10
+(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range
+(14, 0.68, 0.12, 0.12, 2.5, 10),
+(15, 1.30, 0.22, 0.22, 15.0, 40),
+(16, 0.68, 0.12, 0.12, 2.5, 10),
+(17, 1.30, 0.22, 0.22, 15.0, 40),
+(18, 0.68, 0.12, 0.12, 2.5, 10),
+(19, 1.2, 0.18, 0.18, 9.0, 20),
+(20, 1.2, 0.18, 0.18, 9.3, 25),
+(21, 1.2, 0.18, 0.18, 9.5, 20),
+(22, 1.4, 0.22, 0.22, 15.0, 30),
+(23, 1.4, 0.22, 0.22, 15.5, 35),
+(24, 1.8, 0.35, 0.35, 20, 60),
+(25, 1.8, 0.35, 0.35, 21, 70),
+(26, 1.9, 0.35, 0.35, 22, 80),
+(27, 2.2, 0.40, 0.40, 35, 100),
+(28, 2.2, 0.40, 0.40, 37, 120),
+(29, 0.6, 0.35, 0.35, 7.0, 10),
+(30, 0.6, 0.35, 0.35, 7.2, 15),
+(31, 1.0, 0.23, 0.23, 3.7, 5),
+(32, 1.0, 0.23, 0.23, 3.9, 8),
+(33, 0.6, 0.35, 0.35, 7.5, 15),
+(34, 2.6, 0.34, 0.34, 135, 900),
+(35, 2.6, 0.34, 0.34, 140, 1000),
+(36, 3.5, 0.42, 0.42, 200, 2000),
+(37, 3.5, 0.42, 0.42, 210, 2200),
+(38, 3.5, 0.42, 0.42, 215, 2500),
+(39, 1.5, 0.20, 0.20, 15, 40),
+(40, 1.6, 0.20, 0.20, 16, 50),
+(41, 1.5, 0.20, 0.20, 15.5, 45),
+(42, 1.5, 0.20, 0.20, 15.8, 40),
+(43, 1.2, 0.18, 0.18, 12, 30),
+(44, 1.5, 0.25, 0.25, 14.0, 30),
+(45, 1.5, 0.25, 0.25, 14.5, 35),
+(46, 1.6, 0.26, 0.26, 15.0, 40),
+(47, 1.5, 0.25, 0.25, 14.8, 30),
+(48, 1.7, 0.27, 0.27, 16.0, 50),
+(49, 1.0, 0.20, 0.20, 5.0, 40),
+(50, 1.65, 0.35, 0.35, 12.0, 70),
+(51, 1.65, 0.35, 0.35, 12.5, 80),
+(52, 1.80, 0.40, 0.40, 15.0, 100),
+(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L
+(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X
+(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M
+(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER
+(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA
+(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E
+(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M
+(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER
+
+-- 插入特有参数
+INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh,
+ endurance_min,
+ warhead_type,
+ launch_mode,
+ power_system,
+ guidance_system
+) VALUES
+-- HAROP/Harpy系列
+(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'),
+(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'),
+(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'),
+
+-- Hero系列
+(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'),
+(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'),
+(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'),
+
+-- Switchblade系列
+(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
+(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
+(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+
+-- Warmate系列
+(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'),
+(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
+(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'),
+(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'),
+
+-- CH-901/902系列
+(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
+(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
+(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
+(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
+(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
+(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'),
+(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'),
+(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'),
+(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'),
+(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
+(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'),
+(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'),
+(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
+(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
+(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'),
+(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'),
+(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'),
+(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'),
+(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'),
+(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'),
+(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
+(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
+(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'),
+(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'),
+(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'),
+(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
+(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'),
+(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'),
+(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'),
+(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'),
+(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'),
+(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外');
+
+-- 插入成本数据
+INSERT INTO cost_data (equipment_id, actual_cost) VALUES
+(1, 800000), -- IAI Harop
+(2, 700000), -- IAI Harpy
+(3, 350000), -- IAI Mini Harpy
+(4, 70000), -- Hero-30
+(5, 120000), -- Hero-70
+(6, 150000), -- Hero-120
+(7, 300000), -- Hero-250
+(8, 400000), -- Hero-400EC
+(9, 650000), -- Hero-900
+(10, 60000), -- Switchblade 300
+(11, 180000), -- Switchblade 600
+(12, 75000), -- Switchblade 300 Block 10
+(13, 200000), -- Switchblade 600 Extended Range
+(14, 60000), -- Warmate 1.0
+(15, 180000), -- Warmate 2.0
+(16, 60000), -- Warmate-V
+(17, 180000), -- Warmate-L
+(18, 60000), -- Warmate 3.0
+(19, 100000), -- CH-901
+(20, 120000), -- CH-901A
+(21, 130000), -- CH-901H
+(22, 180000), -- CH-902
+(23, 200000), -- CH-902A
+(24, 120000), -- WS-43
+(25, 150000), -- WS-43A
+(26, 180000), -- WS-43B
+(27, 300000), -- WS-61
+(28, 350000), -- WS-61A
+(29, 70000), -- Kargu
+(30, 85000), -- Kargu-2
+(31, 45000), -- Alpagu
+(32, 55000), -- Alpagu Block-II
+(33, 95000), -- Kargu Autonomous
+(34, 20000), -- Shahed-131
+(35, 25000), -- Shahed-131B
+(36, 40000), -- Shahed-136
+(37, 45000), -- Shahed-136B
+(38, 50000), -- Shahed-136C
+(39, 160000), -- Green Dragon
+(40, 200000), -- Green Dragon Extended Range
+(41, 180000), -- Green Dragon Block 2
+(42, 190000), -- Green Dragon Maritime
+(43, 140000), -- Green Dragon-S
+(44, 150000), -- Phoenix Ghost
+(45, 180000), -- Phoenix Ghost Block I
+(46, 220000), -- Phoenix Ghost Block II
+(47, 190000), -- Phoenix Ghost Maritime
+(48, 250000), -- Phoenix Ghost-ER
+(49, 80000), -- Lancet-1
+(50, 150000), -- Lancet-3
+(51, 180000), -- Lancet-3M
+(52, 250000), -- Lancet-4
+(53, 65000), -- Rotem L
+(54, 85000), -- Rotem L-X
+(55, 75000), -- Rotem L-M
+(56, 95000), -- Rotem L-ER
+(57, 95000), -- KUB-BLA
+(58, 120000), -- KUB-BLA-E
+(59, 110000), -- KUB-BLA-M
+(60, 150000); -- KUB-BLA-ER
+
+-- 创建数据集
+INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES
+(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
+(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证');
+
+-- 关联装备到数据集(按照制造商和型号分配)
+INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+-- 训练集(约80%的数据,48个型号)
+-- 以色列系列
+(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列
+(1, 4), (1, 5), (1, 6), -- Hero系列基础型号
+(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列
+(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列
+
+-- 美国系列
+(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列
+(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列
+
+-- 中国系列
+(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列
+(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列
+
+-- 波兰和土耳其系列
+(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列
+(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列
+
+-- 俄罗斯系列
+(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列
+
+-- 验证集(约20%的数据,12个型号)
+-- 混合系列
+(2, 7), (2, 8), (2, 9), -- Hero系列高级型号
+(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列
+(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列
+
+-- 添加分类特征编码
+INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES
+-- 战斗部类型编码
+('warhead_type', '破片杀伤战斗部', 1),
+('warhead_type', '破甲战斗部', 2),
+('warhead_type', '高爆战斗部', 3),
+('warhead_type', '破片杀伤/破甲双用战斗部', 4),
+('warhead_type', '模块化战斗部', 5),
+
+-- 发射方式编码
+('launch_mode', '箱式发射', 1),
+('launch_mode', '弹射式发射', 2),
+('launch_mode', '垂直起降', 3),
+('launch_mode', '单兵发射管', 4),
+('launch_mode', '箱式发射/弹射式', 5),
+('launch_mode', '箱式发射/空中发射', 6),
+
+-- 动力装置编码(按复杂度递增)
+('power_system', '电动机', 1),
+('power_system', '活塞发动机', 2),
+
+-- 制导系统编码(按复杂度递增)
+('guidance_system', 'GPS/INS', 1),
+('guidance_system', 'GPS/INS/光电', 2),
+('guidance_system', 'GPS/INS/光电/数据链', 3),
+('guidance_system', 'GPS/INS/光电/AI识别', 4),
+('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5),
+('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6),
+('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7);
+
+-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段
+UPDATE loitering_munition_params l
+JOIN common_params c ON l.equipment_id = c.equipment_id
+SET
+ -- 新增关键参数
+ l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20%
+ l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10%
+ l.engine_power_kw =
+ CASE
+ WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15
+ WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25
+ END,
+ l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30%
+ l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80%
+ l.guidance_accuracy_m =
+ CASE
+ WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0
+ WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0
+ ELSE 3.0
+ END,
+ l.min_altitude_m = -- 最小作战高度
+ CASE
+ -- 大型巡飞弹(体型大、重量大)
+ WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列
+
+ -- 中型巡飞弹
+ WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列
+
+ -- 中小型巡飞弹
+ WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等
+
+ -- 小型巡飞弹
+ WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等
+
+ -- 超小型巡飞弹
+ WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列
+
+ -- 其他型号使用默认值
+ ELSE 50
+ END,
+ l.max_altitude_m =
+ CASE
+ WHEN c.max_range_km > 500 THEN 5000
+ WHEN c.max_range_km > 100 THEN 3000
+ ELSE 1500
+ END,
+
+ -- 特征工程字段
+ l.length_width_ratio = c.length_m / c.width_m,
+ l.weight_range_ratio = c.weight_kg / c.max_range_km,
+ l.speed_weight_ratio = l.max_speed_ms / c.weight_kg,
+ l.guidance_system_score =
+ CASE
+ WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10
+ WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8
+ WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6
+ WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4
+ ELSE 2
+ END,
+ l.warhead_power_score =
+ CASE
+ WHEN l.warhead_type = '模块化战斗部' THEN 10
+ WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8
+ WHEN l.warhead_type = '高爆战斗部' THEN 7
+ WHEN l.warhead_type = '破甲战斗部' THEN 6
+ WHEN l.warhead_type = '破片杀伤战斗部' THEN 5
+ ELSE 4
+ END,
+
+ -- 分类特征编码
+ l.warhead_type_code =
+ CASE
+ WHEN l.warhead_type = '破片杀伤战斗部' THEN 1
+ WHEN l.warhead_type = '破甲战斗部' THEN 2
+ WHEN l.warhead_type = '高爆战斗部' THEN 3
+ WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4
+ WHEN l.warhead_type = '模块化战斗部' THEN 5
+ ELSE 0
+ END,
+ l.launch_mode_code =
+ CASE
+ WHEN l.launch_mode = '箱式发射' THEN 1
+ WHEN l.launch_mode = '弹射式发射' THEN 2
+ WHEN l.launch_mode = '垂直起降' THEN 3
+ WHEN l.launch_mode = '单兵发射管' THEN 4
+ WHEN l.launch_mode = '箱式发射/弹射式' THEN 5
+ WHEN l.launch_mode = '箱式发射/空中发射' THEN 6
+ ELSE 0
+ END,
+ l.power_system_code =
+ CASE
+ WHEN l.power_system = '电动机' THEN 1
+ WHEN l.power_system = '活塞发动机' THEN 2
+ ELSE 0
+ END,
+ l.guidance_system_code =
+ CASE
+ WHEN l.guidance_system = 'GPS/INS' THEN 1
+ WHEN l.guidance_system = 'GPS/INS/光电' THEN 2
+ WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3
+ WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4
+ WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5
+ WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6
+ WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7
+ ELSE 0
+ END;
diff --git a/src/routes.py b/src/routes.py
index df4e825..ba30c45 100644
--- a/src/routes.py
+++ b/src/routes.py
@@ -135,7 +135,6 @@ def analyze_features():
logger.info(f"Dataset info: {dataset}")
# 创建特征分析实例
- from src.feature_analysis import FeatureAnalysis
analyzer = FeatureAnalysis()
# 获取特征列表
@@ -143,20 +142,46 @@ def analyze_features():
logger.info(f"Feature names: {feature_names}")
# 获取数据集中的装备数据
- if dataset['equipment_type'] == '火箭炮':
+ if dataset['equipment_type'] == '巡飞弹':
cursor.execute("""
- SELECT e.*, cp.*, rap.*, cd.actual_cost
+ SELECT
+ e.name,
+ e.*,
+ cp.*,
+ lmp.*,
+ cd.actual_cost,
+ lmp.length_width_ratio,
+ lmp.weight_range_ratio,
+ lmp.speed_weight_ratio,
+ lmp.guidance_system_score,
+ lmp.warhead_power_score,
+ lmp.engine_power_kw,
+ lmp.engine_thrust_n,
+ lmp.min_altitude_m,
+ lmp.max_altitude_m
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
- LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (dataset_id,))
else:
cursor.execute("""
- SELECT e.*, cp.*, lmp.*, cd.actual_cost
+ SELECT e.name,
+ e.*, cp.*, lmp.*,
+ cp.max_range_km,
+ lmp.length_width_ratio,
+ lmp.weight_range_ratio,
+ lmp.speed_weight_ratio,
+ lmp.guidance_system_score,
+ lmp.warhead_power_score,
+ lmp.engine_power_kw,
+ lmp.engine_thrust_n,
+ lmp.min_altitude_m,
+ lmp.max_altitude_m,
+ cd.actual_cost
FROM equipment e
JOIN dataset_equipment de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
@@ -173,61 +198,52 @@ def analyze_features():
logger.warning("No valid equipment data found in dataset")
return jsonify({'error': '数据集没有有效的成本数据'}), 400
- # 统计每个特征的缺失率
- missing_rates = {}
- for name in feature_names:
- missing_count = sum(1 for item in equipment_data if item.get(name) is None)
- missing_rate = missing_count / len(equipment_data)
- missing_rates[name] = missing_rate
- logger.info(f"Feature {name} missing rate: {missing_rate:.2%}")
-
- # 过滤掉缺失率过高的特征
- valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
- logger.info(f"Valid features after filtering: {valid_features}")
-
- if len(valid_features) < 3: # 至少需要3个特征
- return jsonify({'error': '有效特征数量不足'}), 400
-
- # 计算每个特征的均值
- feature_means = {}
- for name in valid_features:
- values = [float(item[name]) for item in equipment_data if item.get(name) is not None]
- feature_means[name] = sum(values) / len(values) if values else 0
- logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")
-
# 准备特征和目标值
features = []
target = []
+ equipment_names = [] # 新增:存储装备名称
- # 提取特征和目标值,使用均值填充缺失值
+ # 提取特征和目标值
for item in equipment_data:
feature_values = []
- for name in valid_features:
+ equipment_names.append(item['name']) # 保存装备名称
+
+ for name in feature_names:
value = item.get(name)
try:
- # 确保数值类型转换正确
- feature_values.append(float(value) if value is not None else feature_means[name])
+ feature_values.append(float(value) if value is not None else 0)
except (ValueError, TypeError) as e:
logger.error(f"Error converting value for feature {name}: {value}")
- logger.error(f"Error details: {str(e)}")
return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
- features.append(feature_values)
- # 确保成本值是值类型
- try:
- target.append(float(item['actual_cost']))
- except (ValueError, TypeError) as e:
- logger.error(f"Error converting actual_cost: {item['actual_cost']}")
- logger.error(f"Error details: {str(e)}")
- return jsonify({'error': '成本值无法换为数值'}), 400
-
- logger.info(f"Prepared {len(features)} feature vectors")
- logger.info(f"First feature vector: {features[0] if features else None}")
- logger.info(f"First target value: {target[0] if target else None}")
+ features.append(feature_values)
+ target.append(float(item['actual_cost']))
# 调用特征分析方法
- result = analyzer.analyze_features(features, target, valid_features)
- logger.info("Analysis completed successfully")
+ result = analyzer.analyze_features(features, target, feature_names)
+
+ # 如果是巡飞弹类型,添加额外的数据
+ if dataset['equipment_type'] == '巡飞弹':
+ missile_data = {
+ 'equipment_names': equipment_names,
+ 'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
+ 'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
+ 'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
+ 'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
+ 'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
+ 'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
+ 'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data],
+ 'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
+ 'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data]
+ }
+
+ # 验证数据完整性
+ for key, value in missile_data.items():
+ logger.info(f"{key} data length: {len(value)}")
+ logger.info(f"{key} sample data: {value[:3]}")
+
+ # 更新结果
+ result.update(missile_data)
return jsonify(result)
@@ -428,148 +444,39 @@ def get_equipment_data():
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
- cursor.execute('SET SESSION group_concat_max_len = 1000000')
-
- # 先测试特殊参数查询
- cursor.execute("""
- SELECT equipment_id, param_name, param_value, param_unit
- FROM custom_params
- WHERE param_name IS NOT NULL
- AND param_value IS NOT NULL
- LIMIT 5
- """)
- test_params = cursor.fetchall()
- logger.info(f"Test custom params: {test_params}")
# 获取火箭炮数据
logger.info("Fetching rocket artillery data...")
cursor.execute("""
- SELECT
- e.id,
- e.name,
- e.type,
- e.manufacturer,
- e.created_at,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- cp.max_range_km,
- rap.firing_angle_horizontal,
- rap.firing_angle_vertical,
- rap.rocket_length_m,
- rap.rocket_diameter_mm,
- rap.rocket_weight_kg,
- rap.rate_of_fire,
- rap.combat_weight_kg,
- rap.speed_kmh,
- rap.min_range_km,
- rap.mobility_type,
- rap.structure_layout,
- rap.engine_model,
- rap.engine_params,
- rap.power_hp,
- rap.travel_range_km,
- cd.actual_cost,
- (
- SELECT COALESCE(
- JSON_ARRAYAGG(
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- ),
- '[]'
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
+ SELECT e.*, cp.*, rap.*, cd.actual_cost, cd.predicted_cost
FROM equipment e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.type = '火箭炮'
""")
- rocket_artillery = cursor.fetchall()
- logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
- if rocket_artillery:
- logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
- logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")
+ rocket_data = cursor.fetchall()
+ logger.info(f"Found {len(rocket_data)} rocket artillery records")
# 获取巡飞弹数据
logger.info("Fetching missile data...")
cursor.execute("""
- SELECT
- e.id,
- e.name,
- e.type,
- e.manufacturer,
- e.created_at,
- cp.length_m,
- cp.width_m,
- cp.height_m,
- cp.weight_kg,
- cp.max_range_km,
- lmp.wingspan_m,
- lmp.warhead_weight_kg,
- lmp.max_speed_ms,
- lmp.cruise_speed_kmh,
- lmp.flight_time_min,
- lmp.warhead_type,
- lmp.launch_mode,
- lmp.folded_length_mm,
- lmp.folded_width_mm,
- lmp.folded_height_mm,
- lmp.power_system,
- lmp.guidance_system,
- cd.actual_cost,
- (
- SELECT COALESCE(
- JSON_ARRAYAGG(
- JSON_OBJECT(
- 'id', csp.id,
- 'param_name', csp.param_name,
- 'param_value', csp.param_value,
- 'param_unit', csp.param_unit,
- 'description', csp.description
- )
- ),
- '[]'
- )
- FROM custom_params csp
- WHERE csp.equipment_id = e.id
- AND csp.param_name IS NOT NULL
- AND csp.param_value IS NOT NULL
- ) as custom_params
+ SELECT e.*, cp.*, lmp.*, cd.actual_cost, cd.predicted_cost,
+ lmp.wingspan_m, lmp.warhead_weight_kg, lmp.max_speed_ms,
+ lmp.cruise_speed_kmh, lmp.endurance_min, lmp.warhead_type,
+ lmp.launch_mode, lmp.power_system, lmp.guidance_system
FROM equipment e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.type = '巡飞弹'
""")
- loitering_munition = cursor.fetchall()
- logger.info(f"Found {len(loitering_munition)} missile records")
- if loitering_munition:
- logger.info(f"First missile: {loitering_munition[0]['name']}")
- logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")
+ missile_data = cursor.fetchall()
+ logger.info(f"Found {len(missile_data)} missile records")
- # 处理 custom_params,保为 NULL
- for item in rocket_artillery + loitering_munition:
- if item['custom_params'] is None:
- item['custom_params'] = []
- logger.debug(f"Set empty custom_params for equipment {item['id']}")
- else:
- logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params")
-
- logger.info("Data fetching completed")
return jsonify({
- 'rocket_artillery': rocket_artillery,
- 'loitering_munition': loitering_munition
+ 'rocket_artillery': rocket_data,
+ 'loitering_munition': missile_data
})
except Exception as e:
diff --git a/src/schema.sql b/src/schema.sql
index a0abd10..84066eb 100644
--- a/src/schema.sql
+++ b/src/schema.sql
@@ -60,17 +60,47 @@ CREATE TABLE loitering_munition_params (
warhead_weight_kg FLOAT, -- 战斗部重量(kg)
max_speed_ms FLOAT, -- 最大速度(m/s)
cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
- flight_time_min FLOAT, -- 巡飞时间(min)
+ endurance_min FLOAT, -- 续航时间(min)
warhead_type VARCHAR(50), -- 战斗部类型
launch_mode VARCHAR(50), -- 发射方式
- folded_length_mm FLOAT, -- 折叠长度(mm)
- folded_width_mm FLOAT, -- 折叠宽度(mm)
- folded_height_mm FLOAT, -- 折叠高度(mm)
power_system VARCHAR(100), -- 动力装置
guidance_system VARCHAR(100), -- 制导体制
+
+ -- 新增关键参数
+ payload_weight_kg FLOAT, -- 有效载荷重量(kg)
+ min_combat_radius_km FLOAT, -- 最小作战半径(km)
+ engine_power_kw FLOAT, -- 发动机功率(kw)
+ engine_thrust_n FLOAT, -- 发动机推力(N)
+ datalink_range_km FLOAT, -- 通信链路距离(km)
+ guidance_accuracy_m FLOAT, -- 制导精度(m)
+ min_altitude_m FLOAT, -- 最小作战高度(m)
+ max_altitude_m FLOAT, -- 最大作战高度(m)
+
+ -- 特征工程字段
+ length_width_ratio FLOAT, -- 长宽比
+ weight_range_ratio FLOAT, -- 重量/射程比
+ speed_weight_ratio FLOAT, -- 速度/重量比
+ guidance_system_score INT, -- 制导系统复杂度评分(1-10)
+ warhead_power_score INT, -- 战斗部威力评分(1-10)
+
+ -- 分类特征编码
+ warhead_type_code INT, -- 战斗部类型编码
+ launch_mode_code INT, -- 发射方式编码
+ power_system_code INT, -- 动力装置编码
+ guidance_system_code INT, -- 制导系统编码
+
FOREIGN KEY (equipment_id) REFERENCES equipment(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+-- 分类特征编码表
+CREATE TABLE feature_encoding (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ feature_type VARCHAR(50), -- 特征类型(warhead_type/launch_mode/power_system/guidance_system)
+ feature_value VARCHAR(100), -- 特征值
+ code INT, -- 编码值
+ UNIQUE KEY unique_feature (feature_type, feature_value)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
-- 成本数据表
CREATE TABLE cost_data (
id INT AUTO_INCREMENT PRIMARY KEY,