diff --git a/deploy/equipment_cost_system.tar.gz b/deploy/equipment_cost_system.tar.gz deleted file mode 100644 index c906e30..0000000 Binary files a/deploy/equipment_cost_system.tar.gz and /dev/null differ diff --git a/deploy/equipment_cost_system/config/.env.template b/deploy/equipment_cost_system/config/.env.template deleted file mode 100644 index 195987f..0000000 --- a/deploy/equipment_cost_system/config/.env.template +++ /dev/null @@ -1,25 +0,0 @@ -# 数据库配置 -MYSQL_HOST=localhost -MYSQL_USER=root -MYSQL_PASSWORD=your_password_here -MYSQL_DATABASE=equipment_cost_db - -# 服务配置 -PORT=5001 -DEBUG=False - -# 日志配置 -LOG_LEVEL=INFO -LOG_DIR=logs - -# 模型配置 -MODEL_DIR=models -DATA_DIR=data - -# 安全配置 -SECRET_KEY=your_secret_key_here -ALLOWED_HOSTS=localhost,127.0.0.1 - -# 其他配置 -UPLOAD_MAX_SIZE=10485760 # 10MB in bytes -ALLOWED_FILE_TYPES=.xlsx,.xls \ No newline at end of file diff --git a/deploy/equipment_cost_system/docs/api.md b/deploy/equipment_cost_system/docs/api.md deleted file mode 100644 index 000bd35..0000000 --- a/deploy/equipment_cost_system/docs/api.md +++ /dev/null @@ -1,663 +0,0 @@ -# 装备成本估算系统 API 文档 - -这个 API 文档提供了完整的接口说明,包括: - -- 每个端点的详细描述 -- 请求和响应的具体示例 -- 清晰的参数格式要求 -- 统一的错误处理说明 -- 重要的注意事项 - -文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。 - -## 基本信息 - -- 基础URL: `http://localhost:5001/api` -- 版本: 1.0.0 -- 响应格式: JSON - -## API 端点列表 - -### 1. 获取 API 信息 - -获取 API 版本信息和可用端点列表。 - -- **URL**: `/` -- **方法**: `GET` -- **响应示例**: -json -{ -"name": "装备成本估算系统 API", -"version": "1.0.0", -"endpoints": { -"predict": { -"url": "/api/predict", -"method": "POST", -"description": "成本预测" -}, -"analyze-features": { -"url": "/api/analyze-features", -"method": "POST", -"description": "特征分析" -}, -"train": { -"url": "/api/train", -"method": "POST", -"description": "模型训练" -}, -"evaluate": { -"url": "/api/evaluate", -"method": "POST", -"description": "模型评估" -} -} -} - -### 2. 单模型预测 - -使用当前激活的最优模型进行成本预测。 - -- **URL**: `/predict` -- **方法**: `POST` -- **请求体示例** (巡飞弹): - -```json -{ - "type": "巡飞弹", - "length_m": 1.3, - "width_m": 0.23, - "height_m": 0.23, - "weight_kg": 12.5, - "max_range_km": 40, - "max_speed_ms": 50, - "cruise_speed_kmh": 100, - "flight_time_min": 60, - "folded_length_mm": 1300, - "folded_width_mm": 230, - "folded_height_mm": 230, - "warhead_type": "破片杀伤战斗部", - "launch_mode": "凭自身动力起飞" -} -``` - -- **响应示例**: - -```json -{ - "predicted_cost": 150000.0, - "model_info": { - "type": "xgboost", - "name": "巡飞弹_20241111_model", - "r2_score": 0.95, - "mae": 5000.0, - "rmse": 7500.0 - }, - "confidence_interval": { - "lower": 135000.0, - "upper": 165000.0 - } -} -``` - -### 3. PLS 模型预测 - -使用 PLS 回归模型进行预测。 - -- **URL**: `/pls/predict` -- **方法**: `POST` -- **请求体**: 与单模型预测相同 -- **响应示例**: - -```json -{ - "predicted_cost": 148000.0, - "confidence_interval": { - "lower": 133000.0, - "upper": 163000.0 - } -} -``` - -### 4. 多模型预测 - -使用所有激活的模型进行预测并返回综合结果。 - -- **URL**: `/predict/all` -- **方法**: `POST` -- **请求体**: 与单模型预测相同 -- **响应示例**: - -```json -{ - "individual_predictions": { - "xgboost": { - "predicted_cost": 150000.0, - "model_info": { - "name": "巡飞弹_xgboost_model", - "type": "xgboost", - "r2_score": 0.95, - "mae": 5000.0, - "rmse": 7500.0 - }, - "confidence_interval": { - "lower": 135000.0, - "upper": 165000.0 - } - }, - "pls": { - "predicted_cost": 148000.0, - "model_info": { - "name": "巡飞弹_pls_model", - "type": "pls", - "r2_score": 0.92, - "mae": 5500.0, - "rmse": 8000.0 - }, - "confidence_interval": { - "lower": 133000.0, - "upper": 163000.0 - } - } - }, - "ensemble_prediction": { - "predicted_cost": 149000.0, - "standard_deviation": 1414.21, - "confidence_interval": { - "lower": 146228.15, - "upper": 151771.85 - } - } -} -``` - -### 5. 特征分析 - -分析数据集中特征的重要性和相关性。 - -- **URL**: `/analyze-features` -- **方法**: `POST` -- **请求体示例**: - -```json -{ - "dataset_id": 1, - "equipment_type": "巡飞弹" -} -``` - -- **响应示例**: - -```json -{ - "important_features": [ - { - "name": "最大射程(km)", - "importance": 0.35 - }, - { - "name": "重量(kg)", - "importance": 0.25 - } - ], - "correlation_analysis": { - "features": ["最大射程(km)", "重量(kg)"], - "matrix": [[1.0, 0.8], [0.8, 1.0]] - } -} -``` - -### 6. 模型训练 - -训练新的模型。 - -- **URL**: `/train` -- **方法**: `POST` -- **请求体示例**: - -```json -{ - "type": "巡飞弹", - "train_dataset_id": 1, - "validation_dataset_id": 2, - "models": ["xgboost", "lightgbm", "rf"] -} -``` - -- **响应示例**: - -```json -{ - "metrics": { - "xgboost": { - "train": { - "r2": 0.95, - "mae": 5000.0, - "rmse": 7500.0 - }, - "validation": { - "r2": 0.92, - "mae": 5500.0, - "rmse": 8000.0 - } - } - }, - "best_model": { - "type": "xgboost", - "r2": 0.92, - "mae": 5500.0, - "rmse": 8000.0 - } -} -``` - -### 7. 数据集管理 - -#### 7.1 获取数据集列表 - -- **URL**: `/datasets` -- **方法**: `GET` -- **响应示例**: - -```json -[ - { - "id": 1, - "name": "训练数据集", - "description": "用于训练的数据集", - "equipment_type": "巡飞弹", - "equipment_count": 10, - "equipment_names": ["设备1", "设备2"], - "purpose": "训练", - "created_at": "2024-11-11T10:00:00" - } -] -``` - -#### 7.2 获取数据集详情 - -- **URL**: `/datasets/{id}` -- **方法**: `GET` -- **响应示例**: - -```json -{ - "id": 1, - "name": "训练数据集", - "description": "用于训练的数据集", - "equipment_type": "巡飞弹", - "purpose": "训练", - "created_at": "2024-11-11T10:00:00", - "equipment": [ - { - "id": 1, - "name": "设备1", - "type": "巡飞弹", - "manufacturer": "制造商1", - "actual_cost": 150000 - } - ], - "statistics": { - "equipment_count": 10, - "total_cost": 1500000, - "average_cost": 150000 - } -} -``` - -#### 7.3 创建数据集 - -- **URL**: `/datasets` -- **方法**: `POST` -- **请求体示例**: - -```json -{ - "name": "测试数据集", - "description": "用于测试的数据集", - "equipment_type": "巡飞弹", - "purpose": "训练", - "equipment_ids": [1, 2, 3] -} -``` - -- **响应示例**: - -```json -{ - "id": 2, - "message": "数据集创建成功" -} -``` - -#### 7.4 更新数据集 - -- **URL**: `/datasets/{id}` -- **方法**: `PUT` -- **请求体示例**: - -```json -{ - "name": "更新后的数据集名称", - "description": "更新后的描述", - "equipment_type": "巡飞弹", - "purpose": "验证", - "equipment_ids": [1, 2, 3, 4] -} -``` - -- **响应示例**: - -```json -{ - "success": true, - "message": "数据集更新成功" -} -``` - -#### 7.5 删除数据集 - -- **URL**: `/datasets/{id}` -- **方法**: `DELETE` -- **描述**: 删除指定的数据集及其关联关系 -- **响应示例**: - -```json -{ - "success": true, - "message": "数据集删除成功" -} -``` - -注意事项: - -1. 数据集删除后不会删除关联的装备数据 -2. 不能删除正在被模型使用的数据集 -3. 更新数据集时会重新计算统计信息 -4. 数据集的装备类型一旦创建后不能更改 - -### 8. 模型管理 - -#### 8.1 获取模型列表 - -- **URL**: `/models` -- **方法**: `GET` -- **响应示例**: - -```json -[ - { - "id": 1, - "model_name": "巡飞弹_xgboost_model", - "model_type": "xgboost", - "equipment_type": "巡飞弹", - "r2_score": 0.95, - "mae": 5000.0, - "rmse": 7500.0, - "is_active": true, - "training_date": "2024-11-11T10:00:00" - } -] -``` - -#### 8.2 获取最新模型 - -- **URL**: `/models/{equipment_type}/latest` -- **方法**: `GET` -- **响应示例**: 与模型列表的单个模型格式相同 - -#### 8.3 获取模型详情 - -- **URL**: `/models/{id}` -- **方法**: `GET` -- **响应示例**: - -```json -{ - "id": 1, - "model_name": "巡飞弹_xgboost_model", - "model_type": "xgboost", - "equipment_type": "巡飞弹", - "r2_score": 0.95, - "mae": 5000.0, - "rmse": 7500.0, - "is_active": true, - "training_date": "2024-11-11T10:00:00", - "feature_importance": { - "max_range_km": 0.35, - "weight_kg": 0.25, - "length_m": 0.20 - }, - "training_data_size": 100, - "created_by": "system" -} -``` - -#### 8.4 激活模型 - -- **URL**: `/models/{id}/activate` -- **方法**: `POST` -- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态 -- **响应示例**: - -```json -{ - "success": true, - "message": "模型已激活" -} -``` - -#### 8.5 删除模型 - -- **URL**: `/models/{id}` -- **方法**: `DELETE` -- **描述**: 删除指定模型,包括模型文件和数据库记录 -- **响应示例**: - -```json -{ - "success": true, - "message": "模型已删除" -} -``` - -注意事项: - -1. 删除模型时会同时删除相关的文件和数据库记录 -2. 不能删除当前正在使用(已激活)的模型 -3. 激活模型时会自动取消同类型其他模型的激活状态 -4. 模型详情包含了更多的训练相关信息,如特征重要性等 - -### 9. 数据管理 - -#### 9.1 获取装备数据列表 - -- **URL**: `/data` -- **方法**: `GET` -- **响应示例**: - -```json -{ - "rocket_artillery": [ - { - "id": 1, - "name": "BM-21", - "type": "火箭炮", - "manufacturer": "俄罗斯", - "length_m": 7.35, - "width_m": 2.4, - "height_m": 3.1, - "weight_kg": 13700, - "max_range_km": 20.4, - "actual_cost": 800000 - } - ], - "loitering_munition": [ - { - "id": 8, - "name": "Hero-120", - "type": "巡飞弹", - "manufacturer": "以色列", - "length_m": 1.3, - "width_m": 0.23, - "height_m": 0.23, - "weight_kg": 12.5, - "max_range_km": 40, - "actual_cost": 150000 - } - ] -} -``` - -#### 9.2 获取装备详情 - -- **URL**: `/data/details/{id}` -- **方法**: `GET` -- **响应示例**: - -```json -{ - "id": 8, - "name": "Hero-120", - "type": "巡飞弹", - "manufacturer": "以色列", - "common_params": { - "length_m": 1.3, - "width_m": 0.23, - "height_m": 0.23, - "weight_kg": 12.5, - "max_range_km": 40 - }, - "specific_params": { - "wingspan_m": 2.1, - "warhead_weight_kg": 3.5, - "max_speed_ms": 50, - "cruise_speed_kmh": 100, - "flight_time_min": 60, - "warhead_type": "破片杀伤战斗部", - "launch_mode": "箱式发射", - "power_system": "电动机", - "guidance_system": "GPS/INS" - }, - "cost_data": { - "actual_cost": 150000, - "prediction_date": "2024-11-11T10:00:00", - "predicted_cost": 148000 - }, - "custom_params": [ - { - "id": 1, - "param_name": "续航时间", - "param_value": "2小时", - "param_unit": "小时", - "description": "最大续航时间" - } - ] -} -``` - -#### 9.3 更新装备数据 - -- **URL**: `/data/{id}` -- **方法**: `PUT` -- **请求体示例**: - -```json -{ - "name": "Hero-120", - "manufacturer": "以色列", - "length_m": 1.3, - "width_m": 0.23, - "height_m": 0.23, - "weight_kg": 12.5, - "max_range_km": 40, - "wingspan_m": 2.1, - "warhead_weight_kg": 3.5, - "max_speed_ms": 50, - "cruise_speed_kmh": 100, - "flight_time_min": 60, - "actual_cost": 150000, - "custom_params": [ - { - "id": 1, - "param_value": "2.5小时" - } - ] -} -``` - -- **响应示例**: - -```json -{ - "success": true, - "message": "装备数据更新成功" -} -``` - -#### 9.4 删除装备数据 - -- **URL**: `/data/{id}` -- **方法**: `DELETE` -- **响应示例**: - -```json -{ - "success": true, - "message": "装备数据删除成功" -} -``` - -#### 9.5 下载数据模板 - -- **URL**: `/data/template` -- **方法**: `GET` -- **描述**: 下载Excel格式的数据导入模板 -- **响应**: Excel文件下载 - -#### 9.6 导入数据 - -- **URL**: `/data/import` -- **方法**: `POST` -- **请求体**: - - Content-Type: multipart/form-data - - 参数名: file - - 文件类型: .xlsx 或 .xls -- **响应示例**: - -```json -{ - "success": true, - "message": "数据导入成功", - "imported_count": { - "rocket_artillery": 3, - "loitering_munition": 5 - } -} -``` - -注意事项: - -1. 导入数据时必须使用系统提供的模板 -2. 更新装备数据时会同时更新关联的参数表 -3. 删除装备数据会同时删除相关的参数和成本数据 -4. 导入的Excel文件大小不应超过10MB -5. 所有数值字段必须符合指定的单位和范围要求 -6. 特殊参数的值必须包含单位信息 - -## 错误响应 - -所有接口在发生错误时都会返回以下格式的响应: - -```json -{ - "error": "错误描述信息" -} -``` - -## 注意事项 - -1. 所有数值参数必须大于0 -2. 所有单位必须按照参数名称中指定的单位提供 -3. 预测结果中的成本单位为元 -4. 置信区间表示预测结果的95%置信水平范围 -5. 所有请求和响应的编码均为 UTF-8 diff --git a/deploy/equipment_cost_system/docs/deploy.md b/deploy/equipment_cost_system/docs/deploy.md deleted file mode 100644 index 1f52359..0000000 --- a/deploy/equipment_cost_system/docs/deploy.md +++ /dev/null @@ -1,120 +0,0 @@ -# 装备成本估算系统部署指南 - -## 一、系统要求 - -### 1. 基础软件 - -- Linux 操作系统 (推荐 Ubuntu 20.04+) -- Python 3.8+ 及相关组件 - - ```bash - sudo apt update - sudo apt install python3 python3-pip python3-venv - sudo apt install python3-dev build-essential - ``` - -- Node.js 14+ 及 npm - - ```bash - # 使用 nvm 安装 Node.js - curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash - source ~/.bashrc - nvm install 14 - nvm use 14 - ``` - -### 2. 数据库 - -- MySQL 8.0+ - - ```bash - sudo apt install mysql-server mysql-client - sudo apt install libmysqlclient-dev - ``` - -### 3. Python包依赖 - -```bash -# 科学计算相关 -sudo apt install libatlas-base-dev # numpy依赖 -sudo apt install libopenblas-dev # 线性代数库 -sudo apt install liblapack-dev # 线性代数包 -sudo apt install gfortran # Fortran编译器(scipy依赖) - -# XML处理相关(用于Excel文件处理) -sudo apt install libxml2-dev -sudo apt install libxslt1-dev -``` - -## 二、部署运行 - -### 1. 安装服务 - -```bash -sh scripts/install.sh -``` - -### 2. 启动服务 - -```bash -sh scripts/start.sh -``` - -### 3. 停止服务 - -```bash -sh scripts/stop.sh -``` - -## 三、维护说明 - -### 1. 日志管理 - -```bash -# 后端日志 -tail -f logs/api.log - -# 数据库日志 -tail -f /var/log/mysql/error.log -``` - -## 四、安全建议 - -1. 系统安全 - - 使用防火墙限制端口访问 - - 定期更新系统和依赖包 - -2. 数据安全 - - 定期备份数据库 - - 加密敏感信息 - - 限制数据库远程访问 - -3. 访问控制 - - 使用强密码 - - 配置适当的文件权限 - - 使用非root用户运行服务 - -## 五、监控方案 - -### 1. 系统监控 - -```bash -# 资源使用 -top -b -n 1 -df -h -free -m - -# 服务状态 -ps aux | grep gunicorn -ps aux | grep node -``` - -### 2. 应用监控 - -```bash -# API 响应时间 -curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/" - -# 错误日志 -grep "ERROR" logs/api.log -``` diff --git a/deploy/equipment_cost_system/frontend/babel.config.js b/deploy/equipment_cost_system/frontend/babel.config.js deleted file mode 100644 index e955840..0000000 --- a/deploy/equipment_cost_system/frontend/babel.config.js +++ /dev/null @@ -1,5 +0,0 @@ -module.exports = { - presets: [ - '@vue/cli-plugin-babel/preset' - ] -} diff --git a/deploy/equipment_cost_system/frontend/package.json b/deploy/equipment_cost_system/frontend/package.json deleted file mode 100644 index 5e6409e..0000000 --- a/deploy/equipment_cost_system/frontend/package.json +++ /dev/null @@ -1,61 +0,0 @@ -{ - "name": "frontend", - "version": "0.1.0", - "private": true, - "engines": { - "node": ">=16", - "npm": ">=8" - }, - "scripts": { - "serve": "vue-cli-service serve", - "build": "vue-cli-service build", - "lint": "vue-cli-service lint" - }, - "dependencies": { - "axios": "^1.6.0", - "core-js": "^3.8.3", - "echarts": "^5.4.3", - "element-plus": "^2.4.2", - "vue": "^3.2.13", - "vue-router": "^4.0.3", - "vuex": "^4.0.0" - }, - "devDependencies": { - "@babel/core": "^7.12.16", - "@babel/eslint-parser": "^7.12.16", - "@element-plus/icons-vue": "^2.3.1", - "@vue/cli-plugin-babel": "~5.0.0", - "@vue/cli-plugin-eslint": "~5.0.0", - "@vue/cli-plugin-router": "~5.0.0", - "@vue/cli-plugin-vuex": "~5.0.0", - "@vue/cli-service": "~5.0.0", - "@vue/compiler-sfc": "^3.2.13", - "eslint": "^7.32.0", - "eslint-plugin-vue": "^8.0.3", - "sass": "^1.32.7", - "sass-loader": "^12.0.0" - }, - "eslintConfig": { - "root": true, - "env": { - "node": true - }, - "extends": [ - "plugin:vue/vue3-essential", - "eslint:recommended" - ], - "parserOptions": { - "parser": "@babel/eslint-parser" - }, - "rules": { - "vue/multi-word-component-names": "off", - "no-unused-vars": "warn" - } - }, - "browserslist": [ - "> 1%", - "last 2 versions", - "not dead", - "not ie 11" - ] -} diff --git a/deploy/equipment_cost_system/frontend/public/favicon.ico b/deploy/equipment_cost_system/frontend/public/favicon.ico deleted file mode 100644 index df36fcf..0000000 Binary files a/deploy/equipment_cost_system/frontend/public/favicon.ico and /dev/null differ diff --git a/deploy/equipment_cost_system/frontend/public/index.html b/deploy/equipment_cost_system/frontend/public/index.html deleted file mode 100644 index 3e5a139..0000000 --- a/deploy/equipment_cost_system/frontend/public/index.html +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - <%= htmlWebpackPlugin.options.title %> - - - -
- - - diff --git a/deploy/equipment_cost_system/frontend/src/App.vue b/deploy/equipment_cost_system/frontend/src/App.vue deleted file mode 100644 index c4278c5..0000000 --- a/deploy/equipment_cost_system/frontend/src/App.vue +++ /dev/null @@ -1,42 +0,0 @@ - - - diff --git a/deploy/equipment_cost_system/frontend/src/api/index.js b/deploy/equipment_cost_system/frontend/src/api/index.js deleted file mode 100644 index 1090d1a..0000000 --- a/deploy/equipment_cost_system/frontend/src/api/index.js +++ /dev/null @@ -1,43 +0,0 @@ -import axios from 'axios' -import { API_BASE_URL } from '@/config' - -const api = axios.create({ - baseURL: API_BASE_URL, - timeout: 10000 -}) - -export const predict = (data) => { - return api.post('/predict', data) -} - -export const analyzeFeatures = (data) => { - return api.post('/analyze-features', data) -} - -export const trainModel = (data) => { - return api.post('/train', data) -} - -export const evaluateModel = (data) => { - return api.post('/evaluate', data) -} - -export const importData = (formData) => { - return api.post('/data/import', formData, { - headers: { - 'Content-Type': 'multipart/form-data' - } - }) -} - -export const getEquipmentData = () => { - return api.get('/data') -} - -export const updateEquipment = (id, data) => { - return api.put(`/data/${id}`, data) -} - -export const deleteEquipment = (id) => { - return api.delete(`/data/${id}`) -} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/assets/logo.png b/deploy/equipment_cost_system/frontend/src/assets/logo.png deleted file mode 100644 index f3d2503..0000000 Binary files a/deploy/equipment_cost_system/frontend/src/assets/logo.png and /dev/null differ diff --git a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css deleted file mode 100644 index a9e81e1..0000000 --- a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css +++ /dev/null @@ -1,39 +0,0 @@ -/* 禁用 ResizeObserver 警告 */ -iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] { - display: none !important; -} - -.el-overlay { - overflow: hidden !important; -} - -/* 添加全局样式 */ -body { - margin: 0; - padding: 0; - overflow-x: hidden; -} - -/* 修复 Element Plus 的一些已知问题 */ -.el-dialog__wrapper { - overflow: hidden !important; -} - -.el-select-dropdown { - overflow: hidden !important; -} - -/* 禁用 ResizeObserver 相关的警告样式 */ -.resize-observer-warning { - display: none !important; -} - -/* 优化滚动行为 */ -* { - scroll-behavior: smooth; -} - -/* 防止页面抖动 */ -.el-main { - overflow-x: hidden; -} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/config.js b/deploy/equipment_cost_system/frontend/src/config.js deleted file mode 100644 index 0a697ee..0000000 --- a/deploy/equipment_cost_system/frontend/src/config.js +++ /dev/null @@ -1,8 +0,0 @@ -export const API_BASE_URL = 'http://localhost:5001/api'; - -export const DB_CONFIG = { - host: 'localhost', - user: 'root', - password: '123456', - database: 'equipment_cost_db' -}; \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/main.js b/deploy/equipment_cost_system/frontend/src/main.js deleted file mode 100644 index 360a4a4..0000000 --- a/deploy/equipment_cost_system/frontend/src/main.js +++ /dev/null @@ -1,55 +0,0 @@ -import { createApp } from 'vue' -import App from './App.vue' -import router from './router' -import store from './store' -import ElementPlus from 'element-plus' -import 'element-plus/dist/index.css' -import './assets/styles/global.css' -import * as ElementPlusIconsVue from '@element-plus/icons-vue' - -// 创建应用实例 -const app = createApp(App) - -// 注册插件 -app.use(ElementPlus, { - size: 'default' -}) -app.use(router) -app.use(store) - -// 注册图标 -for (const [key, component] of Object.entries(ElementPlusIconsVue)) { - app.component(key, component) -} - -// 全局错误处理 -app.config.errorHandler = (err) => { - if (err.message && err.message.includes('ResizeObserver')) { - return - } - console.error(err) -} - -// 全局警告处理 -app.config.warnHandler = (msg, trace) => { - if (msg.includes('ResizeObserver')) { - return - } - console.warn(msg, trace) -} - -// 挂载应用 -app.mount('#app') - -// 处理 ResizeObserver 错误 -const _ResizeObserver = window.ResizeObserver -window.ResizeObserver = class ResizeObserver extends _ResizeObserver { - constructor(callback) { - super((entries, observer) => { - requestAnimationFrame(() => { - if (!Array.isArray(entries)) return - callback(entries, observer) - }) - }) - } -} diff --git a/deploy/equipment_cost_system/frontend/src/router/index.js b/deploy/equipment_cost_system/frontend/src/router/index.js deleted file mode 100644 index 3dfcdee..0000000 --- a/deploy/equipment_cost_system/frontend/src/router/index.js +++ /dev/null @@ -1,52 +0,0 @@ -import { createRouter, createWebHistory } from 'vue-router' -import HomePage from '@/views/HomePage.vue' -import DataPage from '@/views/DataPage.vue' -import DatasetPage from '@/views/DatasetPage.vue' -import PredictPage from '@/views/PredictPage.vue' -import AnalysisPage from '@/views/AnalysisPage.vue' -import TrainingPage from '@/views/TrainingPage.vue' - -const routes = [ - { - path: '/', - name: 'Home', - component: HomePage - }, - { - path: '/data', - name: 'Data', - component: DataPage - }, - { - path: '/datasets', - name: 'Datasets', - component: DatasetPage - }, - { - path: '/predict', - name: 'Predict', - component: PredictPage - }, - { - path: '/analysis', - name: 'Analysis', - component: AnalysisPage - }, - { - path: '/training', - name: 'Training', - component: TrainingPage - }, - { - path: '/models', - name: 'Models', - component: () => import('../views/ModelPage.vue') - } -] - -const router = createRouter({ - history: createWebHistory(), - routes -}) - -export default router \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/store/index.js b/deploy/equipment_cost_system/frontend/src/store/index.js deleted file mode 100644 index 0faaec7..0000000 --- a/deploy/equipment_cost_system/frontend/src/store/index.js +++ /dev/null @@ -1,14 +0,0 @@ -import { createStore } from 'vuex' - -export default createStore({ - state: { - }, - getters: { - }, - mutations: { - }, - actions: { - }, - modules: { - } -}) \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js deleted file mode 100644 index de54055..0000000 --- a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js +++ /dev/null @@ -1,73 +0,0 @@ -// 处理 ResizeObserver 错误 -const resizeHandler = () => { - const resizeObserverErrors = [ - "ResizeObserver loop completed with undelivered notifications.", - "ResizeObserver loop limit exceeded", - "ResizeObserver loop completed" - ] - - // 添加全局错误处理 - const handler = (event) => { - if (event && event.message && resizeObserverErrors.includes(event.message)) { - event.stopPropagation() - event.preventDefault() - event.stopImmediatePropagation() - return false - } - } - - // 添加多个事件监听器 - window.addEventListener('error', handler, true) - window.addEventListener('unhandledrejection', handler, true) - - // 添加 ResizeObserver 错误处理 - if (window.ResizeObserver) { - const resizeObserverPrototype = ResizeObserver.prototype - const originalObserve = resizeObserverPrototype.observe - - resizeObserverPrototype.observe = function (...args) { - try { - return originalObserve.apply(this, args) - } catch (e) { - if (resizeObserverErrors.includes(e.message)) { - return null - } - throw e - } - } - } -} - -export default { - install: (app) => { - resizeHandler() - - // 添加全局错误处理器 - app.config.errorHandler = (err, vm, info) => { - const resizeObserverErrors = [ - "ResizeObserver loop completed with undelivered notifications.", - "ResizeObserver loop limit exceeded", - "ResizeObserver loop completed" - ] - - if (err && err.message && resizeObserverErrors.includes(err.message)) { - return - } - console.error('Vue Error:', err, info) - } - - // 添加全局警告处理器 - app.config.warnHandler = (msg, vm, trace) => { - const resizeObserverErrors = [ - "ResizeObserver loop completed with undelivered notifications.", - "ResizeObserver loop limit exceeded", - "ResizeObserver loop completed" - ] - - if (resizeObserverErrors.includes(msg)) { - return - } - console.warn('Vue Warning:', msg, trace) - } - } -} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue deleted file mode 100644 index f8f8ba1..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue +++ /dev/null @@ -1,304 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue deleted file mode 100644 index 4689f70..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue +++ /dev/null @@ -1,725 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue deleted file mode 100644 index 1079611..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue +++ /dev/null @@ -1,322 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue deleted file mode 100644 index 60c0676..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue +++ /dev/null @@ -1,101 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue deleted file mode 100644 index e9632f9..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue +++ /dev/null @@ -1,279 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue deleted file mode 100644 index a42fa9f..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue +++ /dev/null @@ -1,321 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue deleted file mode 100644 index edd20c1..0000000 --- a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue +++ /dev/null @@ -1,370 +0,0 @@ - - - - - \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/vue.config.js b/deploy/equipment_cost_system/frontend/vue.config.js deleted file mode 100644 index b2a2970..0000000 --- a/deploy/equipment_cost_system/frontend/vue.config.js +++ /dev/null @@ -1,5 +0,0 @@ -module.exports = { - devServer: { - port: 8080 - } -} diff --git a/deploy/equipment_cost_system/requirements.txt b/deploy/equipment_cost_system/requirements.txt deleted file mode 100644 index 6eec783..0000000 --- a/deploy/equipment_cost_system/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -flask==2.0.1 -flask-cors==3.0.10 -sqlalchemy==1.4.23 -pymysql==1.0.2 -cryptography==3.4.7 # MySQL 8.0+ 认证需要 -numpy==1.21.2 -pandas==1.3.3 -scikit-learn==0.24.2 -tensorflow==2.6.0 -urllib3<2.0.0 # 添加这一行,限制 urllib3 版本 -openpyxl==3.1.2 # 用于读取 .xlsx 文件 -xlrd==2.0.1 # 用于读取 .xls 文件 \ No newline at end of file diff --git a/deploy/equipment_cost_system/scripts/install.sh b/deploy/equipment_cost_system/scripts/install.sh deleted file mode 100644 index a7a80d6..0000000 --- a/deploy/equipment_cost_system/scripts/install.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash - -echo "开始安装装备成本估算系统..." - -# 检查Python版本 -python3 -V || { - echo "错误: 需要 Python 3.8+" - exit 1 -} - -# 检查Node.js版本 -node -v || { - echo "错误: 需要 Node.js 14+" - exit 1 -} - -# 创建必要的目录 -echo "创建系统目录..." -mkdir -p {logs,data,models} - -# 安装后端依赖 -echo "安装后端依赖..." -python3 -m venv venv -source venv/bin/activate -pip install -r requirements.txt - -# 安装前端依赖 -echo "安装前端依赖..." -cd frontend -npm install -npm run build -cd .. - -# 配置文件 -if [ ! -f config/.env ]; then - echo "创建配置文件..." - cp config/.env.template config/.env - echo "请修改 config/.env 中的配置" -fi - -# 初始化数据库 -echo "初始化数据库..." -read -p "请输入MySQL root密码: " mysqlpass -mysql -u root -p$mysqlpass < src/schema.sql - -# 导入测试数据(可选) -read -p "是否导入测试数据?(y/n) " import_test_data -if [ "$import_test_data" = "y" ]; then - mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql -fi - -# 设置权限 -echo "设置文件权限..." -chmod +x scripts/*.sh -chmod 755 logs models data -chmod 600 config/.env - -echo "安装完成!" -echo "请检查并修改 config/.env 中的配置" -echo "使用 ./scripts/start.sh 启动服务" \ No newline at end of file diff --git a/deploy/equipment_cost_system/scripts/start.sh b/deploy/equipment_cost_system/scripts/start.sh deleted file mode 100644 index 5b28ac4..0000000 --- a/deploy/equipment_cost_system/scripts/start.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -echo "启动装备成本估算系统..." - -# 检查配置文件 -if [ ! -f config/.env ]; then - echo "错误: 配置文件不存在" - echo "请先运行 install.sh" - exit 1 -fi - -# 检查日志目录 -if [ ! -d logs ]; then - mkdir -p logs -fi - -# 激活虚拟环境 -source venv/bin/activate - -# 导出环境变量 -export $(cat config/.env | xargs) - -# 启动后端服务 -echo "启动后端服务..." -gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \ - --daemon \ - --pid gunicorn.pid \ - --access-logfile logs/access.log \ - --error-logfile logs/error.log - -# 等待后端服务启动 -sleep 2 - -# 检查后端服务是否成功启动 -if ! curl -s http://localhost:5001/api/ > /dev/null; then - echo "错误: 后端服务启动失败" - exit 1 -fi - -# 启动前端服务 -echo "启动前端服务..." -cd frontend -npm run serve -- --port 8080 --host 0.0.0.0 & -echo $! > ../frontend.pid -cd .. - -echo "服务已启动!" -echo "后端API: http://localhost:5001" -echo "前端界面: http://localhost:8080" -echo "查看后端日志: tail -f logs/access.log" -echo "查看前端日志: tail -f logs/frontend.log" diff --git a/deploy/equipment_cost_system/scripts/stop.sh b/deploy/equipment_cost_system/scripts/stop.sh deleted file mode 100644 index 085d432..0000000 --- a/deploy/equipment_cost_system/scripts/stop.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -echo "停止装备成本估算系统..." - -# 停止后端服务 -if [ -f gunicorn.pid ]; then - echo "停止后端服务..." - pid=$(cat gunicorn.pid) - kill -TERM $pid - rm gunicorn.pid - echo "后端服务已停止 (PID: $pid)" -else - echo "未找到后端服务PID文件" - # 尝试查找并停止所有gunicorn进程 - pkill -f gunicorn - echo "已尝试停止所有gunicorn进程" -fi - -# 停止前端服务 -if [ -f frontend.pid ]; then - echo "停止前端服务..." - pid=$(cat frontend.pid) - kill -TERM $pid - rm frontend.pid - echo "前端服务已停止 (PID: $pid)" -else - echo "未找到前端服务PID文件" - # 尝试查找并停止前端服务进程 - pkill -f "vite preview" - echo "已尝试停止所有前端服务进程" -fi - -# 检查是否还有相关进程在运行 -if pgrep -f gunicorn > /dev/null; then - echo "警告: 仍有gunicorn进程在运行" - ps aux | grep gunicorn | grep -v grep -fi - -if pgrep -f "vite preview" > /dev/null; then - echo "警告: 仍有前端服务进程在运行" - ps aux | grep "vite preview" | grep -v grep -fi - -echo "所有服务已停止" diff --git a/deploy/equipment_cost_system/src/__init__.py b/deploy/equipment_cost_system/src/__init__.py deleted file mode 100644 index 497b4a4..0000000 --- a/deploy/equipment_cost_system/src/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# 这个文件可以为空,但必须存在 diff --git a/deploy/equipment_cost_system/src/app.py b/deploy/equipment_cost_system/src/app.py deleted file mode 100644 index 037fb79..0000000 --- a/deploy/equipment_cost_system/src/app.py +++ /dev/null @@ -1,50 +0,0 @@ -from flask import Flask -from flask_cors import CORS -from .routes import api_bp -from .logger import setup_logger -import os - -# 获取logger -logger = setup_logger(__name__) - -def create_app(): - """ - 创建并配置Flask应用 - """ - try: - # 创建必要的目录 - os.makedirs('logs', exist_ok=True) - os.makedirs('data', exist_ok=True) - os.makedirs('models', exist_ok=True) - - logger.info("=== Server Starting ===") - logger.info("Initializing directories...") - - # 创建Flask应用 - app = Flask(__name__) - - # 配置CORS - CORS(app) - logger.info("CORS enabled") - - # 注册API蓝图 - app.register_blueprint(api_bp, url_prefix='/api') - logger.info("API blueprint registered") - - # 配置数据库连接 - app.config['MYSQL_HOST'] = 'localhost' - app.config['MYSQL_USER'] = 'root' - app.config['MYSQL_PASSWORD'] = '123456' - app.config['MYSQL_DB'] = 'equipment_cost_db' - - logger.info("Starting server...") - - return app - - except Exception as e: - logger.error(f"Error creating app: {str(e)}") - raise - -if __name__ == '__main__': - app = create_app() - app.run(host='localhost', port=5001) \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/cost_prediction.py b/deploy/equipment_cost_system/src/cost_prediction.py deleted file mode 100644 index 23ccce5..0000000 --- a/deploy/equipment_cost_system/src/cost_prediction.py +++ /dev/null @@ -1,342 +0,0 @@ -import numpy as np -from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score -from sklearn.preprocessing import StandardScaler -import tensorflow as tf -from scipy import stats -import joblib -import os -import pandas as pd -from .feature_analysis import FeatureAnalysis -import logging -from src.model_trainer import ModelTrainer -from src.database import get_db_connection -from .logger import setup_logger - -logger = setup_logger(__name__) - -class CostPredictor: - def __init__(self): - self.scaler_X = StandardScaler() - self.scaler_y = StandardScaler() - self.model = None - self.feature_analyzer = FeatureAnalysis() - self.equipment_type = None - - # 添加 TensorFlow 配置 - tf.config.run_functions_eagerly(False) # 启用图执行模式 - - # 创建预测函数 - @tf.function(reduce_retracing=True, jit_compile=True) - def predict_fn(x): - return self.model(x, training=False) - - self._predict_fn = predict_fn - - self.load_model() - - def load_model(self): - """ - 加载预训练型和标准化器 - """ - try: - model_dir = 'models' - os.makedirs(model_dir, exist_ok=True) - - # 创建默认模型 - self._create_default_model() - - # 创建预测函数 - @tf.function(reduce_retracing=True, jit_compile=True) - def predict_fn(x): - return self.model(x, training=False) - - self._predict_fn = predict_fn - - except Exception as e: - logging.error(f"Error loading model: {str(e)}") - self._create_default_model() - - def _create_default_model(self): - """ - 创建默认模型并进行初始化训练 - """ - # 创建输入层 - inputs = tf.keras.Input(shape=(11,)) - - # 创建隐藏层 - x = tf.keras.layers.Dense(64, activation='relu')(inputs) - x = tf.keras.layers.Dense(32, activation='relu')(x) - - # 创建输出层 - outputs = tf.keras.layers.Dense(1)(x) - - # 创建模型 - self.model = tf.keras.Model(inputs=inputs, outputs=outputs) - - # 编译模型 - self.model.compile( - optimizer='adam', - loss=tf.keras.losses.mean_squared_error, - metrics=[tf.keras.metrics.mean_absolute_error] - ) - - # 创建示例数据 - example_data = pd.DataFrame({ - 'length_m': [7.35, 10.2], - 'width_m': [2.4, 2.8], - 'height_m': [3.1, 3.2], - 'weight_kg': [13700, 28500], - 'max_range_km': [20.4, 70], - 'firing_angle_horizontal': [102, 110], - 'firing_angle_vertical': [55, 60], - 'rocket_length_m': [2.87, 4.1], - 'rocket_diameter_mm': [122, 220], - 'rocket_weight_kg': [66.6, 150], - 'rate_of_fire': [40, 60] - }) - - # 训练标准化器 - self.scaler_X.fit(example_data) - self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围 - - # 设置默认装备类型 - self.equipment_type = '火箭炮' - - def _create_example_data(self): - """ - 创建示例数据来训练标准化器 - """ - # 火箭炮示例数据 - rocket_data = pd.DataFrame({ - 'length_m': [7.35, 10.2], - 'width_m': [2.4, 2.8], - 'height_m': [3.1, 3.2], - 'weight_kg': [13700, 28500], - 'max_range_km': [20.4, 70], - 'firing_angle_horizontal': [102, 110], - 'firing_angle_vertical': [55, 60], - 'rocket_length_m': [2.87, 4.1], - 'rocket_diameter_mm': [122, 220], - 'rocket_weight_kg': [66.6, 150], - 'rate_of_fire': [40, 60] - }) - - # 巡飞弹示例数据 - missile_data = pd.DataFrame({ - 'length_m': [1.3, 2.5], - 'width_m': [0.23, 0.6], - 'height_m': [0.23, 0.6], - 'weight_kg': [12.5, 135], - 'max_range_km': [40, 250], - 'max_speed_kmh': [180, 185], - 'cruise_speed_kmh': [100, 110], - 'flight_time_min': [60, 120], - 'folded_length_mm': [1300, 2500], - 'folded_width_mm': [230, 600], - 'folded_height_mm': [230, 600] - }) - - # 训练标准化器 - self.scaler_X.fit(rocket_data) # 使用火箭炮数据 - self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据 - - # 设置默认装备类型 - self.equipment_type = '火箭炮' - - def predict(self, data): - """ - 使用训练好的最优模型进行预测 - """ - try: - logger.info(f"Starting prediction for {data.get('type')}") - equipment_type = data.get('type') - - # 加载已训练的最优模型 - trainer = ModelTrainer() - if not trainer.load_model(equipment_type): - raise ValueError(f"No trained model found for {equipment_type}") - - # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - y_pred = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) - - # 获取模型类型 - model_type = trainer.get_model_type() - - return { - 'predicted_cost': float(y_pred[0]), - 'model_type': model_type, # 返回使用的模型类型 - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - except Exception as e: - logger.error(f"Prediction error: {str(e)}") - raise - - def _calculate_confidence_interval(self, prediction, confidence=0.95): - """ - 计算预测值的置信区间 - """ - try: - # 使用预测值的20%作为标准差(增加不确定性) - std = abs(prediction) * 0.2 - - # 计算置信区间 - from scipy import stats - interval = stats.norm.interval(confidence, loc=prediction, scale=std) - - # 确保区间值为正数且合理 - lower = max(1000, interval[0]) # 最小值设为1000元 - upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20% - - logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]") - - return [lower, upper] - - except Exception as e: - logging.error(f"Error calculating confidence interval: {str(e)}") - # 如果计算失败,返回基于20%的简单区间 - lower = max(1000, prediction * 0.8) - upper = prediction * 1.2 - return [lower, upper] - - def evaluate(self, y_true, y_pred): - """ - 模型评估 - """ - return { - 'mae': float(mean_absolute_error(y_true, y_pred)), - 'mse': float(mean_squared_error(y_true, y_pred)), - 'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))), - 'r2': float(r2_score(y_true, y_pred)) - } - - def predict_pls(self, data): - """ - 使用 PLS ���型预测成本 - """ - try: - logger.info(f"Starting PLS prediction for {data.get('type')}") - equipment_type = data.get('type') - - # 加载 PLS 模型 - trainer = ModelTrainer() - if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型 - raise ValueError(f"No trained PLS model found for {equipment_type}") - - # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - y_pred = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) - - return { - 'predicted_cost': float(y_pred[0]), - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - except Exception as e: - logger.error(f"PLS prediction error: {str(e)}") - raise - - def predict_all(self, data): - """ - 使用所有可用模型进行预测 - """ - try: - logger.info(f"Starting multi-model prediction for {data.get('type')}") - equipment_type = data.get('type') - results = {} - - # 1. 获取所有激活的模型 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT id, model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND is_active = TRUE - """, (equipment_type,)) - active_models = cursor.fetchall() - - if not active_models: - raise ValueError(f"No active models found for {equipment_type}") - - # 2. 使用每个模型进行预测 - trainer = ModelTrainer() - for model_info in active_models: - try: - # 加载特定模型 - if not trainer.load_model(equipment_type, model_type=model_info['model_type']): - logger.warning(f"Failed to load model: {model_info['model_name']}") - continue - - # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - y_pred = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) - - # 保存结果 - results[model_info['model_type']] = { - 'predicted_cost': float(y_pred[0]), - 'model_info': { - 'name': model_info['model_name'], - 'type': model_info['model_type'], - 'r2_score': float(model_info['r2_score']), - 'mae': float(model_info['mae']), - 'rmse': float(model_info['rmse']) - }, - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - except Exception as e: - logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}") - continue - - if not results: - raise ValueError("No successful predictions from any model") - - # 3. 计算综合预测结果 - all_predictions = [result['predicted_cost'] for result in results.values()] - ensemble_prediction = float(np.mean(all_predictions)) - prediction_std = float(np.std(all_predictions)) - - # 4. 返回所有结果 - return { - 'individual_predictions': results, - 'ensemble_prediction': { - 'predicted_cost': ensemble_prediction, - 'standard_deviation': prediction_std, - 'confidence_interval': { - 'lower': float(ensemble_prediction - 1.96 * prediction_std), - 'upper': float(ensemble_prediction + 1.96 * prediction_std) - } - } - } - - except Exception as e: - logger.error(f"Error in multi-model prediction: {str(e)}") - raise \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/create_template.py b/deploy/equipment_cost_system/src/create_template.py deleted file mode 100644 index 0f7d545..0000000 --- a/deploy/equipment_cost_system/src/create_template.py +++ /dev/null @@ -1,155 +0,0 @@ -import pandas as pd -import openpyxl -from openpyxl.styles import PatternFill, Font, Alignment -from openpyxl.worksheet.datavalidation import DataValidation -import os -from .logger import setup_logger - -logger = setup_logger(__name__) - -def create_excel_template(): - """ - 创建数据模板 - """ - try: - # 确保data目录存在 - data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') - os.makedirs(data_dir, exist_ok=True) - - # 创建完整的文件路径 - template_path = os.path.join(data_dir, 'equipment_data_template.xlsx') - - # 创建 Excel 写入器 - writer = pd.ExcelWriter(template_path, engine='openpyxl') - - # 火箭炮基本参数表 - rocket_artillery_columns = [ - '名称', '类型', '制造商', '口径_mm', - '反射管数量', '乘员数', '总长_m', - '宽度_m', '高度_m', '重量_kg', - '战斗重_kg', '速度_km/h', '最大射程_km', - '最小射程_km', '方向射界_度', '高低射界_度', - '火箭弹长度_m', '火箭弹重量_kg', - '火箭弹最大速度_m/s', '射速_发', - '战斗部重量_kg', '行走方式', - '结构布局', '发动机型号', '发动机参数', - '功率_hp', '行程_km', '成本_元' - ] - - # 巡飞弹基本参数表 - loitering_munition_columns = [ - '名称', '类型', '制造商', '目标类型', - '弹长_m', '弹径_mm', '翼展_m', - '重量_kg', '战斗部重量_kg', - '最大射程_km', '最大速度_m/s', - '巡航速度_kmh', '巡飞时间_min', - '战斗部类型', '发射方式', - '折叠长度_mm', '折叠宽度_mm', - '折叠高度_mm', '动力装置', - '制导体制', '成本_元' - ] - - # 特殊参数表 - special_params_columns = [ - '装备名称', # 关联字段 - '参数名称', - '参数值', - '参数单位', - '参数说明' - ] - - # 创建工作表 - pd.DataFrame(columns=rocket_artillery_columns).to_excel( - writer, sheet_name='火箭炮', index=False - ) - pd.DataFrame(columns=loitering_munition_columns).to_excel( - writer, sheet_name='巡飞弹', index=False - ) - pd.DataFrame(columns=special_params_columns).to_excel( - writer, sheet_name='特殊参数', index=False - ) - - # 获取工作簿 - workbook = writer.book - - # 设置火箭炮工作表格式 - rocket_sheet = workbook['火箭炮'] - for col in range(1, len(rocket_artillery_columns) + 1): - cell = rocket_sheet.cell(row=1, column=col) - cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') - cell.font = Font(bold=True) - cell.alignment = Alignment(horizontal='center') - - # 设置巡飞弹工作表格式 - missile_sheet = workbook['巡飞弹'] - for col in range(1, len(loitering_munition_columns) + 1): - cell = missile_sheet.cell(row=1, column=col) - cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') - cell.font = Font(bold=True) - cell.alignment = Alignment(horizontal='center') - - # 设置特殊参数工作表格式 - special_sheet = workbook['特殊参数'] - for col in range(1, len(special_params_columns) + 1): - cell = special_sheet.cell(row=1, column=col) - cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') - cell.font = Font(bold=True) - cell.alignment = Alignment(horizontal='center') - - # 添加数据验证 - for sheet in [rocket_sheet, missile_sheet]: - # 数值验证 - number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0") - number_validation.error = "请输入大于0的数值" - number_validation.errorTitle = "输入错误" - - # 应用到相应列 - for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']: - number_validation.add(f"{col}2:{col}1000") - - sheet.add_data_validation(number_validation) - - # 添加说明 - rocket_sheet['AD1'] = "填写说明:" - rocket_sheet['AD2'] = "1. 所有数值必须大于0" - rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写" - rocket_sheet['AD4'] = "3. 成本单位为元" - - missile_sheet['V1'] = "填写说明:" - missile_sheet['V2'] = "1. 所有数值必须大于0" - missile_sheet['V3'] = "2. 单位必须按照表头要求填写" - missile_sheet['V4'] = "3. 成本单位为元" - - special_sheet['G1'] = "填写说明:" - special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致" - special_sheet['G3'] = "2. 参数值必须包含单位" - special_sheet['G4'] = "3. 参数说明应简明扼要" - - # 调整列宽 - for sheet in [rocket_sheet, missile_sheet, special_sheet]: - for col in sheet.columns: - max_length = 0 - column = col[0].column_letter - for cell in col: - try: - if len(str(cell.value)) > max_length: - max_length = len(str(cell.value)) - except: - pass - adjusted_width = (max_length + 2) - sheet.column_dimensions[column].width = adjusted_width - - # 保存文件 - writer.close() - - return template_path - - except Exception as e: - raise Exception(f"创建模板文件失败: {str(e)}") - -if __name__ == "__main__": - try: - template_path = create_excel_template() - print(f"模板文件已创建: {template_path}") - except Exception as e: - print(f"错误: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/data_preparation.py b/deploy/equipment_cost_system/src/data_preparation.py deleted file mode 100644 index 6380647..0000000 --- a/deploy/equipment_cost_system/src/data_preparation.py +++ /dev/null @@ -1,233 +0,0 @@ -from sklearn.preprocessing import StandardScaler -from datetime import datetime -import os -import joblib -import pandas as pd -import numpy as np -from src.feature_analysis import FeatureAnalysis -from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor -from xgboost import XGBRegressor -from lightgbm import LGBMRegressor -from sklearn.model_selection import cross_val_score, LeaveOneOut -import json -import logging -from src.database.db_connection import get_db_connection -from sklearn.metrics import mean_absolute_error, mean_squared_error -from .logger import setup_logger - -logger = setup_logger(__name__) - -class DataPreparation: - def __init__(self): - self.feature_analyzer = FeatureAnalysis() - self.feature_scaler = StandardScaler() - self.target_scaler = StandardScaler() # 添加目标值标准化器 - - def prepare_training_data(self, equipment_data, equipment_type): - """ - 准备训练数据 - """ - try: - logger.info(f"Preparing training data for {equipment_type}") - logger.info(f"Raw data size: {len(equipment_data)}") - - # 如果输入已经是 numpy 数组,直接返回 - if isinstance(equipment_data, np.ndarray): - X = equipment_data - logger.info(f"Input is already numpy array with shape: {X.shape}") - - # 处理无效值 - X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) - - return { - 'X': X, - 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type), - 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler - } - - # 从原始数据中提取特征和目标值 - feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) - features = [] - targets = [] - - for item in equipment_data: - # 提取特征值 - feature_values = [] - for name in feature_names: - value = item.get(name) - try: - feature_values.append(float(value) if value is not None else 0.0) - except (ValueError, TypeError): - feature_values.append(0.0) - features.append(feature_values) - - # 提取目标值(成本) - try: - cost = float(item['actual_cost']) - if cost > 0: # 只使用正数成本值 - targets.append(cost) - else: - logger.warning(f"Skipping non-positive cost value: {cost}") - except (ValueError, TypeError, KeyError): - logger.error(f"Invalid cost value: {item.get('actual_cost')}") - continue - - # 转换为numpy数组 - X = np.array(features, dtype=float) - y = np.array(targets, dtype=float) - - # 记录原始数据范围 - logger.info(f"Raw X range: min={X.min()}, max={X.max()}") - logger.info(f"Raw y range: min={y.min()}, max={y.max()}") - - # 标准化特征和目标值 - X_scaled = self.feature_scaler.fit_transform(X) - y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel() - - # 记录标准化后的数据范围 - logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}") - logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}") - - # 记录标准化器参数 - logger.info("Feature scaler params:") - logger.info(f"Mean: {self.feature_scaler.mean_}") - logger.info(f"Scale: {self.feature_scaler.scale_}") - - logger.info("Target scaler params:") - logger.info(f"Mean: {self.target_scaler.mean_}") - logger.info(f"Scale: {self.target_scaler.scale_}") - - return { - 'X': X_scaled, - 'y': y_scaled, - 'feature_names': feature_names, - 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler - } - - except Exception as e: - logger.error(f"Error in data preparation: {str(e)}") - raise Exception(f"Training error: {str(e)}") - - def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None): - """ - 准备验证数据 - """ - try: - logger.info(f"Preparing validation data for {equipment_type}") - logger.info(f"Raw validation data size: {len(validation_data)}") - - # 如果输入已经是 numpy 数组,直接使用 - if isinstance(validation_data, np.ndarray): - X = validation_data - logger.info(f"Input is already numpy array with shape: {X.shape}") - - # 处理无效值 - X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) - - # 使用训练数据的标准化器 - if scalers and 'feature_scaler' in scalers: - X_scaled = scalers['feature_scaler'].transform(X) - else: - X_scaled = X - - logger.info(f"Preprocessed data shape: {X_scaled.shape}") - logger.info(f"Validation features shape: {X_scaled.shape}") - logger.info(f"Validation features type: {X_scaled.dtype}") - - return { - 'X': X_scaled, - 'y': None # 验证数据可能没有标签 - } - - # 从原始数据中提取特征和目标值 - if not feature_names: - feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) - - features = [] - targets = [] - - for item in validation_data: - # 提取特征值 - feature_values = [] - for name in feature_names: - value = item.get(name) - try: - feature_values.append(float(value) if value is not None else 0.0) - except (ValueError, TypeError): - feature_values.append(0.0) - features.append(feature_values) - - # 提取目标值(成本)并验证范围 - try: - cost = float(item['actual_cost']) - logger.info(f"Raw cost value: {cost}") - if cost > 0: # 只使用正数成本值 - targets.append(cost) - else: - logger.warning(f"Skipping non-positive cost value: {cost}") - except (ValueError, TypeError): - logger.error(f"Invalid cost value: {item.get('actual_cost')}") - continue - - # 转换为numpy数组 - X = np.array(features, dtype=float) - y = np.array(targets, dtype=float) - - # 记录数据范围 - logger.info(f"Features range: min={X.min()}, max={X.max()}") - logger.info(f"Targets range: min={y.min()}, max={y.max()}") - - # 处理无效值 - X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) - - # 使用训练数据的标准化器 - if scalers and 'feature_scaler' in scalers: - X_scaled = scalers['feature_scaler'].transform(X) - if 'target_scaler' in scalers: - y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel() - else: - y_scaled = y - else: - X_scaled = X - y_scaled = y - - logger.info(f"Preprocessed data shape: {X_scaled.shape}") - logger.info(f"Validation features shape: {X_scaled.shape}") - logger.info(f"Validation features type: {X_scaled.dtype}") - - # 记录标准化后的数据范围 - logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}") - logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}") - - # 确保特征维度一致 - if not feature_names: - feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) - - logger.info(f"Expected features: {len(feature_names)}") - logger.info(f"Actual features: {X_scaled.shape[1]}") - - if X_scaled.shape[1] != len(feature_names): - raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}") - - return { - 'X': X_scaled, - 'y': y_scaled # 返回标准化后的成本值 - } - - except Exception as e: - logger.error(f"Error in validation data preparation: {str(e)}") - logger.error(f"Feature names: {feature_names}") - logger.error(f"Equipment type: {equipment_type}") - raise Exception(f"Validation error: {str(e)}") - - def calculate_derived_features(self, data, equipment_type): - """ - 计算衍生特征 - """ - try: - return self.feature_analyzer.calculate_derived_features(data, equipment_type) - except Exception as e: - logger.error(f"Error calculating derived features: {str(e)}") - raise Exception(f"Feature calculation error: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/database/__init__.py b/deploy/equipment_cost_system/src/database/__init__.py deleted file mode 100644 index 6df49c9..0000000 --- a/deploy/equipment_cost_system/src/database/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .db_connection import get_db_connection \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/database/db_connection.py b/deploy/equipment_cost_system/src/database/db_connection.py deleted file mode 100644 index 361a4d2..0000000 --- a/deploy/equipment_cost_system/src/database/db_connection.py +++ /dev/null @@ -1,37 +0,0 @@ -import mysql.connector -from mysql.connector import Error -from contextlib import contextmanager -import os -from dotenv import load_dotenv -from ..logger import setup_logger - -# 获取logger -logger = setup_logger(__name__) - -# 加载环境变量 -load_dotenv() - -@contextmanager -def get_db_connection(): - """ - 数据库连接上下文管理器 - """ - connection = None - try: - connection = mysql.connector.connect( - host=os.getenv('MYSQL_HOST', 'localhost'), - user=os.getenv('MYSQL_USER', 'root'), - password=os.getenv('MYSQL_PASSWORD', '123456'), - database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db') - ) - logger.info("Database connection established") - yield connection - - except Error as e: - logger.error(f"Error connecting to MySQL: {str(e)}") - raise - - finally: - if connection and connection.is_connected(): - connection.close() - logger.info("Database connection closed") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/feature_analysis.py b/deploy/equipment_cost_system/src/feature_analysis.py deleted file mode 100644 index 5b87972..0000000 --- a/deploy/equipment_cost_system/src/feature_analysis.py +++ /dev/null @@ -1,269 +0,0 @@ -import numpy as np -import pandas as pd -from scipy import stats -from sklearn.preprocessing import StandardScaler -from sklearn.ensemble import RandomForestRegressor -from sklearn.metrics import r2_score -import logging -from .logger import setup_logger - -logger = setup_logger(__name__) - -class FeatureAnalysis: - def __init__(self): - self.scaler = StandardScaler() - self.important_features = [] - - # 添加特征名称映射 - self.feature_names_map = { - # 通用参数 - 'length_m': '总长(m)', - 'width_m': '宽度(m)', - 'height_m': '高度(m)', - 'weight_kg': '重量(kg)', - 'max_range_km': '最大射程(km)', - - # 火箭炮特有参数 - 'firing_angle_horizontal': '方向射界(度)', - 'firing_angle_vertical': '高低射界(度)', - 'rocket_length_m': '火箭弹长度(m)', - 'rocket_diameter_mm': '口径(mm)', - 'rocket_weight_kg': '火箭弹重量(kg)', - 'rate_of_fire': '射速(发/分)', - 'combat_weight_kg': '战斗重量(kg)', - 'speed_kmh': '速度(km/h)', - 'min_range_km': '最小射程(km)', - 'power_hp': '功率(hp)', - - # 火箭炮衍生特征 - 'fire_density': '火力密度', - 'mobility_index': '机动性指标', - 'range_ratio': '射程比', - 'power_weight_ratio': '功重比', - 'volume_density': '体积密度', - - # 巡飞弹特有参数 - 'wingspan_m': '翼展(m)', - 'warhead_weight_kg': '战斗部重量(kg)', - 'max_speed_ms': '最大速度(m/s)', - 'cruise_speed_kmh': '巡航速度(km/h)', - 'flight_time_min': '巡飞时间(min)', - 'folded_length_mm': '折叠长度(mm)', - 'folded_width_mm': '折叠宽度(mm)', - 'folded_height_mm': '折叠高度(mm)', - - # 巡飞弹衍生特征 - 'warhead_ratio': '战斗部比重', - 'speed_ratio': '速度比', - 'range_time_ratio': '射程时间比', - 'aspect_ratio': '展弦比', - 'volume_density': '体积密度' - } - - def get_equipment_specific_features(self, equipment_type): - """ - 获取特定装备类型的特征列表 - """ - # 通用参数 - common_features = [ - 'length_m', # 总长(m) - 'width_m', # 宽度(m) - 'height_m', # 高度(m) - 'weight_kg', # 重量(kg) - 'max_range_km' # 最大射程(km) - ] - - if equipment_type == '火箭炮': - # 火箭炮特有参数 - specific_features = [ - 'firing_angle_horizontal', # 方向射界(度) - 'firing_angle_vertical', # 高低射界(度) - 'rocket_length_m', # 火箭弹长度(m) - 'rocket_diameter_mm', # 口径(mm) - 'rocket_weight_kg', # 火箭弹重量(kg) - 'rate_of_fire', # 射速(发/分) - 'combat_weight_kg', # 战斗重量(kg) - 'speed_kmh', # 速度(km/h) - 'min_range_km', # 最小射程(km) - 'power_hp' # 功率(hp) - ] - - # 火箭炮衍生特征 - derived_features = [ - 'fire_density', # 火力密度 = 射速 * 火箭弹重量 - 'mobility_index', # 机动性指标 = 速度 / 战斗重量 - 'range_ratio', # 射程比 = 最大射程 / 最小射程 - 'power_weight_ratio', # 功重比 = 功率 / 战斗重量 - 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) - ] - - return common_features + specific_features + derived_features - - else: # 巡飞弹 - # 巡飞弹特有参数 - specific_features = [ - 'wingspan_m', # 翼展(m) - 'warhead_weight_kg', # 战斗部重量(kg) - 'max_speed_ms', # 最大速度(m/s) - 'cruise_speed_kmh', # 巡航速度(km/h) - 'flight_time_min', # 巡飞时间(min) - 'folded_length_mm', # 折叠长度(mm) - 'folded_width_mm', # 折叠宽度(mm) - 'folded_height_mm' # 折叠高度(mm) - ] - - # 巡飞弹衍生特征 - derived_features = [ - 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量 - 'speed_ratio', # 速度比 = 巡航速度 / 最大速度 - 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间 - 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积 - 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) - ] - - return common_features + specific_features + derived_features - - def calculate_derived_features(self, data, equipment_type): - """ - 计算衍生特征 - """ - try: - if equipment_type == '火箭炮': - # 火箭炮衍生特征计算 - if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns: - data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg'] - else: - data['fire_density'] = 0 # 或者其他默认值 - - if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns: - data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg'] - else: - data['mobility_index'] = 0 - - if 'max_range_km' in data.columns and 'min_range_km' in data.columns: - data['range_ratio'] = data['max_range_km'] / data['min_range_km'] - else: - data['range_ratio'] = 0 - - if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns: - data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg'] - else: - data['power_weight_ratio'] = 0 - - if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']): - data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) - else: - data['volume_density'] = 0 - - else: # 巡飞弹 - # 巡飞弹衍生特征计算 - if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns: - data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg'] - else: - data['warhead_ratio'] = 0 - - if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns: - data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6) - else: - data['speed_ratio'] = 0 - - if 'max_range_km' in data.columns and 'flight_time_min' in data.columns: - data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min'] - else: - data['range_time_ratio'] = 0 - - if 'wingspan_m' in data.columns and 'length_m' in data.columns: - data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m'] - else: - data['aspect_ratio'] = 0 - - if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']): - data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) - else: - data['volume_density'] = 0 - - return data - - except Exception as e: - logger.error(f"Error calculating derived features: {str(e)}") - raise - - def analyze_features(self, features, target, feature_names): - """ - 分析特征重要性和相关性 - """ - try: - # 转换为numpy数组 - X = np.array(features) - y = np.array(target) - - # 数据标准化 - X_scaled = self.scaler.fit_transform(X) - - # 特征重要性分析 - rf = RandomForestRegressor(n_estimators=100, random_state=42) - rf.fit(X_scaled, y) - importances = rf.feature_importances_ - - # 按重要性排序,使用中文特征名 - importance_indices = np.argsort(importances)[::-1] - important_features = [ - { - 'name': self.feature_names_map.get(feature_names[i], feature_names[i]), - 'importance': float(importances[i]) - } - for i in importance_indices - ] - - # 相关性分析 - df = pd.DataFrame(X_scaled, columns=feature_names) - correlation_matrix = df.corr().values - - # 生成相关性分析数据,保留2位小数 - correlation_data = [] - chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names] - for i in range(len(feature_names)): - for j in range(len(feature_names)): - correlation_data.append([ - i, j, - round(correlation_matrix[i][j], 2) # 修改为保留2位小数 - ]) - - return { - 'important_features': important_features, - 'correlation_analysis': { - 'features': chinese_feature_names, # 使用中文特征名 - 'matrix': correlation_data - } - } - - except Exception as e: - logger.error(f"Error in feature analysis: {str(e)}") - raise - - def preprocess_features(self, equipment_data, equipment_type): - """ - 预处理特征数据 - """ - try: - # 转换为 DataFrame - df = pd.DataFrame(equipment_data) - - # 计算衍生特征 - df = self.calculate_derived_features(df, equipment_type) - - # 处理缺失值 - numeric_columns = df.select_dtypes(include=[np.number]).columns - for col in numeric_columns: - # 转换为数值类型 - df[col] = pd.to_numeric(df[col], errors='coerce') - # 使用新的方式填充缺失值 - mean_value = df[col].mean() - df[col] = df[col].fillna(mean_value) - - logger.info(f"Preprocessed data shape: {df.shape}") - return df - - except Exception as e: - logger.error(f"Error preprocessing features: {str(e)}") - raise Exception(f"Feature preprocessing error: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/import_data.py b/deploy/equipment_cost_system/src/import_data.py deleted file mode 100644 index 03dfa36..0000000 --- a/deploy/equipment_cost_system/src/import_data.py +++ /dev/null @@ -1,255 +0,0 @@ -import pandas as pd -from .logger import setup_logger -from src.database.db_connection import get_db_connection - -logger = setup_logger(__name__) - -def import_training_data(excel_file): - """ - 从Excel导入训练数据到数据库 - """ - try: - # 读取所有sheet - rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮') - missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹') - special_df = pd.read_excel(excel_file, sheet_name='特殊参数') - - # 记录所有装备名称,用于后续检查 - equipment_names = set() - - with get_db_connection() as conn: - cursor = conn.cursor() - - # 1. 先导入火箭炮数据 - logger.info("开始导入火箭炮数据...") - for _, row in rocket_df.iterrows(): - equipment_names.add(row['名称']) - # 检查是否已存在相同名称的装备 - cursor.execute(""" - SELECT id FROM equipment - WHERE name = %s AND type = '火箭炮' - """, (row['名称'],)) - - existing_equipment = cursor.fetchone() - if existing_equipment: - logger.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入") - continue - - # 插入基本信息 - cursor.execute(""" - INSERT INTO equipment (name, type, manufacturer) - VALUES (%s, %s, %s) - """, (row['名称'], '火箭炮', row['制造商'])) - - equipment_id = cursor.lastrowid - - # 插入通用参数 - cursor.execute(""" - INSERT INTO common_params - (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) - VALUES (%s, %s, %s, %s, %s, %s) - """, ( - equipment_id, - row['总长_m'] if pd.notna(row['总长_m']) else None, - row['宽度_m'] if pd.notna(row['宽度_m']) else None, - row['高度_m'] if pd.notna(row['高度_m']) else None, - row['重量_kg'] if pd.notna(row['重量_kg']) else None, - row['最大射程_km'] if pd.notna(row['最大射程_km']) else None - )) - - # 插入火箭炮特有参数 - cursor.execute(""" - INSERT INTO rocket_artillery_params - (equipment_id, firing_angle_horizontal, firing_angle_vertical, - rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire, - combat_weight_kg, speed_kmh, min_range_km, mobility_type, - structure_layout, engine_model, engine_params, power_hp, - travel_range_km) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - equipment_id, - row['方向射界_度'] if pd.notna(row['方向射界_度']) else None, - row['高低射界_度'] if pd.notna(row['高低射界_度']) else None, - row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None, - row['口径_mm'] if pd.notna(row['口径_mm']) else None, - row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None, - row['射速_发'] if pd.notna(row['射速_发']) else None, - row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None, - row['速度_km/h'] if pd.notna(row['速度_km/h']) else None, - row['最小射程_km'] if pd.notna(row['最小射程_km']) else None, - row['行走方式'] if pd.notna(row['行走方式']) else None, - row['结构布局'] if pd.notna(row['结构布局']) else None, - row['发动机型号'] if pd.notna(row['发动机型号']) else None, - row['发动机参数'] if pd.notna(row['发动机参数']) else None, - row['功率_hp'] if pd.notna(row['功率_hp']) else None, - row['行程_km'] if pd.notna(row['行程_km']) else None - )) - - # 插入成本数据 - if pd.notna(row['成本_元']): - cursor.execute(""" - INSERT INTO cost_data (equipment_id, actual_cost) - VALUES (%s, %s) - """, (equipment_id, row['成本_元'])) - - logger.info("火箭炮数据导入完成") - - # 2. 导入巡飞弹数据 - logger.info("开始导入巡飞弹数据...") - for index, row in missile_df.iterrows(): - # 记录每行数据的空值情况 - null_values = row[row.isna()].index.tolist() - if null_values: - logger.info(f"行 {index + 2} 中的空值字段: {null_values}") - - equipment_names.add(row['名称']) - # 检查是否已存在相同名称的装备 - cursor.execute(""" - SELECT id FROM equipment - WHERE name = %s AND type = '巡飞弹' - """, (row['名称'],)) - - existing_equipment = cursor.fetchone() - if existing_equipment: - logger.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入") - continue - - # 插入基本信息 - cursor.execute(""" - INSERT INTO equipment (name, type, manufacturer) - VALUES (%s, %s, %s) - """, ( - row['名称'], - '巡飞弹', - row['制造商'] if pd.notna(row['制造商']) else None - )) - - equipment_id = cursor.lastrowid - - # 插入通用参数 - cursor.execute(""" - INSERT INTO common_params - (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) - VALUES (%s, %s, %s, %s, %s, %s) - """, ( - equipment_id, - float(row['弹长_m']) if pd.notna(row['弹长_m']) else None, - float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米 - float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米 - float(row['重量_kg']) if pd.notna(row['重量_kg']) else None, - float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None - )) - - # 插入巡飞弹特有参数 - cursor.execute(""" - INSERT INTO loitering_munition_params - (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, - cruise_speed_kmh, flight_time_min, warhead_type, launch_mode, - folded_length_mm, folded_width_mm, folded_height_mm, - power_system, guidance_system) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - equipment_id, - float(row['翼展_m']) if pd.notna(row['翼展_m']) else None, - float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None, - float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None, - float(row['巡航速度_km/h']) if pd.notna(row['巡航速度_km/h']) else None, - float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None, - str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None, - str(row['发射方式']) if pd.notna(row['发射方式']) else None, - float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None, - float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None, - float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None, - str(row['动力装置']) if pd.notna(row['动力装置']) else None, - str(row['制导体制']) if pd.notna(row['制导体制']) else None - )) - - # 插入成本数据 - if pd.notna(row['成本_元']): - cursor.execute(""" - INSERT INTO cost_data (equipment_id, actual_cost) - VALUES (%s, %s) - """, (equipment_id, float(row['成本_元']))) - - logger.info("巡飞弹数据导入完成") - - # 提交之前的更改并关闭原有游标 - cursor.close() - conn.commit() - - # 3. 导入特殊参数 - logger.info("开始导入特殊参数...") - for index, row in special_df.iterrows(): - equipment_name = row['装备名称'] - param_name = row['参数名称'] - logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'") - - if equipment_name not in equipment_names: - logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确") - continue - - # 获取装备ID - 使用新的游标 - logger.debug(f"查询装备ID: {equipment_name}") - with conn.cursor() as id_cursor: - id_cursor.execute(""" - SELECT id FROM equipment WHERE name = %s - """, (equipment_name,)) - result = id_cursor.fetchone() - - if not result: - logger.warning(f"未找到装备: {equipment_name}") - continue - - equipment_id = result[0] - logger.debug(f"找到装备ID: {equipment_id}") - - # 检查参数是否存在 - 使用新的游标 - logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'") - with conn.cursor() as check_cursor: - check_cursor.execute(""" - SELECT id FROM custom_params - WHERE equipment_id = %s AND param_name = %s - """, (equipment_id, param_name)) - exists = check_cursor.fetchone() - - if exists: - logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入") - continue - - # 插入新的参数 - 使用新的游标 - param_value = str(row['参数值']) if pd.notna(row['参数值']) else None - param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None - param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None - - logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'") - with conn.cursor() as insert_cursor: - insert_cursor.execute(""" - INSERT INTO custom_params - (equipment_id, param_name, param_value, param_unit, description) - VALUES (%s, %s, %s, %s, %s) - """, ( - equipment_id, - param_name, - param_value, - param_unit, - param_desc - )) - logger.debug(f"成功插入参数记录") - - # 最终提交 - conn.commit() - logger.info("特殊参数导入完成") - logger.info("所有数据导入成功") - return True - - except Exception as e: - logger.error(f"Error importing data: {str(e)}") - raise - -if __name__ == "__main__": - try: - excel_file = 'data/equipment_data_20241108.xlsx' - import_training_data(excel_file) - logger.info("All data imported successfully") - except Exception as e: - logger.error(f"Import failed: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/init_data.sql b/deploy/equipment_cost_system/src/init_data.sql deleted file mode 100644 index ee99db3..0000000 --- a/deploy/equipment_cost_system/src/init_data.sql +++ /dev/null @@ -1,319 +0,0 @@ -/* -这是用于开发和测试环境的示例数据。 -生产环境请使用系统的数据导入功能添加实际数据。 - -主要用途: -1. 提供开发测试数据 -2. 作为数据格式参考 -3. 用于系统功能验证 -*/ - --- 插入装备基本信息 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('胜利-2', '火箭炮', '伊朗', '地面固定目标'); - --- 插入巡飞弹技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES ( - 1, -- 终结者巡飞弹 - 0.56, - 0.15, - 0.20, - 2.72, - 160.93, - 96.56, - 24, - 15, - '破片杀伤战斗部', - '凭自身动力起飞', - 560, - 150, - 200 -); - --- 插入火箭炮技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES ( - 2, -- 胜利-2火箭炮 - 10, - 2.5, - 3.34, - 15000, - 23 -); - --- 插入成本数据(示例数据) -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(1, 1000000), -- 终结者巡飞弹成本 -(2, 5000000); -- 胜利-2火箭炮成本 - --- 插入更多巡飞弹变体数据用于训练 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'); - --- 插入变体技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES --- 终结者-A(稍大型号) -(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210), --- 终结者-B(稍小型号) -(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190), --- 终结者-C(标准型号的改进版) -(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200); - --- 插入变体成本数据 -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(3, 1100000), -- 终结者-A成本(较高) -(4, 900000), -- 终结者-B成本(较低) -(5, 1050000); -- 终结者-C成本(中等) - --- 添加更多巡飞弹数据 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('哈比', '巡飞弹', '以色列', '防空系统和雷达站'), -('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'), -('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'), -('弹簧刀', '巡飞弹', '波兰', '装甲目标'), -('彩虹-4', '巡飞弹', '中国', '地面固定目标'); - --- 添加它们的技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES --- 哈比 -(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600), --- 游荡者 -(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400), --- 凤凰 -(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300), --- 弹簧刀 -(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350), --- 彩虹-4 -(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800); - --- 添加成本数据 -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(6, 800000), -- 哈比 -(7, 500000), -- 游荡者 -(8, 450000), -- 凤凰 -(9, 480000), -- 弹簧刀 -(10, 1500000); -- 彩虹-4 - --- 火箭炮数据 -INSERT INTO equipment (name, type, manufacturer) VALUES -('BM-21', '火箭炮', '俄罗斯'), -('SR5', '火箭炮', '中国'), -('HIMARS', '火箭炮', '美国'), -('LAR-160', '火箭炮', '以色列'), -('T-122', '火箭炮', '土耳其'), -('RM-70', '火箭炮', '捷克'), -('ASTROS II', '火箭炮', '巴西'); - --- 火箭炮通用参数 -INSERT INTO common_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES --- BM-21 -(1, 7.35, 2.4, 3.1, 13700, 20.4), --- SR5 -(2, 10.2, 2.8, 3.2, 28500, 70), --- HIMARS -(3, 7.0, 2.4, 3.2, 16250, 70), --- LAR-160 -(4, 6.7, 2.5, 2.8, 15000, 45), --- T-122 -(5, 7.2, 2.5, 2.9, 18000, 40), --- RM-70 -(6, 7.5, 2.5, 3.0, 17200, 20.3), --- ASTROS II -(7, 8.0, 2.7, 3.1, 24500, 90); - --- 火箭炮特有参数 -INSERT INTO rocket_artillery_params ( - equipment_id, - firing_angle_horizontal, - firing_angle_vertical, - rocket_length_m, - rocket_diameter_mm, - rocket_weight_kg, - rate_of_fire, - combat_weight_kg, - speed_kmh, - min_range_km, - mobility_type, - structure_layout, - engine_model, - engine_params, - power_hp, - travel_range_km -) VALUES --- BM-21 -(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500), --- SR5 -(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650), --- HIMARS -(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480), --- LAR-160 -(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550), --- T-122 -(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600), --- RM-70 -(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520), --- ASTROS II -(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700); - --- 巡飞弹数据 -INSERT INTO equipment (name, type, manufacturer) VALUES -('Hero-120', '巡飞弹', '以色列'), -('Switchblade 600', '巡飞弹', '美国'), -('Warmate', '巡飞弹', '波兰'), -('CH-901', '巡飞弹', '中国'), -('HAROP', '巡飞弹', '以色列'), -('Coyote', '巡飞弹', '美国'), -('WS-43', '巡飞弹', '中国'); - --- 巡飞弹通用参数 -INSERT INTO common_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES --- Hero-120 -(8, 1.3, 0.23, 0.23, 12.5, 40), --- Switchblade 600 -(9, 1.3, 0.22, 0.22, 15.0, 40), --- Warmate -(10, 1.1, 0.15, 0.15, 5.7, 15), --- CH-901 -(11, 1.2, 0.18, 0.18, 9.0, 20), --- HAROP -(12, 2.5, 0.43, 0.43, 135, 1000), --- Coyote -(13, 0.9, 0.12, 0.12, 5.9, 20), --- WS-43 -(14, 1.8, 0.35, 0.35, 20, 60); - --- 巡飞弹特有参数 -INSERT INTO loitering_munition_params ( - equipment_id, - wingspan_m, - warhead_weight_kg, - max_speed_ms, - cruise_speed_kmh, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm, - power_system, - guidance_system -) VALUES --- Hero-120 -(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'), --- Switchblade 600 -(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'), --- Warmate -(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'), --- CH-901 -(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'), --- HAROP -(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'), --- Coyote -(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'), --- WS-43 -(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电'); - --- 插入成本数据(示例成本) -INSERT INTO cost_data (equipment_id, actual_cost) VALUES --- 火箭炮 -(1, 800000), -- BM-21 -(2, 4500000), -- SR5 -(3, 5500000), -- HIMARS -(4, 3500000), -- LAR-160 -(5, 2800000), -- T-122 -(6, 1500000), -- RM-70 -(7, 4800000), -- ASTROS II --- 巡飞弹 -(8, 150000), -- Hero-120 -(9, 180000), -- Switchblade 600 -(10, 80000), -- Warmate -(11, 100000), -- CH-901 -(12, 850000), -- HAROP -(13, 75000), -- Coyote -(14, 120000); -- WS-43 - --- 创建初始数据集 -INSERT INTO datasets (name, description, equipment_type, purpose) VALUES -('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'), -('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'), -('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'), -('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证'); - --- 关联装备到数据集 -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES --- 火箭炮训练集 -(1, 1), (1, 2), (1, 3), (1, 4), --- 巡飞弹训练集 -(2, 8), (2, 9), (2, 10), (2, 11), (2, 12), --- 火箭炮验证集 -(3, 5), (3, 6), (3, 7), --- 巡飞弹验证集 -(4, 13), (4, 14); \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/logger.py b/deploy/equipment_cost_system/src/logger.py deleted file mode 100644 index 52bb879..0000000 --- a/deploy/equipment_cost_system/src/logger.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging -import os -from datetime import datetime - -def setup_logger(name): - """ - 创建并配置logger - """ - # 创建logger - logger = logging.getLogger(name) - - # 如果logger已经有处理器,直接返回 - if logger.handlers: - return logger - - # 设置日志级别 - logger.setLevel(logging.INFO) - - # 确保日志目录存在 - os.makedirs('logs', exist_ok=True) - - # 创建文件处理器 - file_handler = logging.FileHandler('logs/api.log') - file_handler.setLevel(logging.INFO) - - # 创建格式化器 - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) - - # 添加处理器 - logger.addHandler(file_handler) - - return logger \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/model_trainer.py b/deploy/equipment_cost_system/src/model_trainer.py deleted file mode 100644 index 660c38e..0000000 --- a/deploy/equipment_cost_system/src/model_trainer.py +++ /dev/null @@ -1,612 +0,0 @@ -import numpy as np -import pandas as pd -from sklearn.preprocessing import StandardScaler -from sklearn.model_selection import cross_val_score -from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor -from sklearn.impute import SimpleImputer -import xgboost as xgb -import lightgbm as lgb -import logging -import joblib -import os -from src.feature_analysis import FeatureAnalysis -from datetime import datetime -import json -from src.database import get_db_connection -from src.data_preparation import DataPreparation -from sklearn.cross_decomposition import PLSRegression -from .logger import setup_logger - -logger = setup_logger(__name__) - -class ModelTrainer: - def __init__(self): - """ - 初始化 ModelTrainer - """ - self.models = { - 'xgboost': self._create_xgboost_model(), - 'lightgbm': self._create_lightgbm_model(), - 'gbm': self._create_gbm_model(), - 'rf': self._create_rf_model(), - 'pls': self._create_pls_model() - } - self.best_model = None - self.imputer = SimpleImputer(strategy='mean') - self.feature_scaler = None - self.target_scaler = None - self.equipment_type = None - self.feature_analyzer = FeatureAnalysis() - - def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None): - """ - 训练模型并返回评估结果 - """ - try: - self.equipment_type = equipment_type - logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}") - logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}") - - results = {} - best_score = -float('inf') - best_model_info = None - - # 首先训练 PLS 模型 - logger.info("Training pls...") - pls_model = self.models['pls'] - pls_model.fit(X_train, y_train) - pls_metrics = self._calculate_metrics( - pls_model, - X_train, y_train, - X_val, y_val - ) - results['pls'] = pls_metrics - - # 训练其他机器学习模型 - for model_name in model_names: - if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了 - continue - - if model_name not in self.models: - logger.warning(f"Unknown model: {model_name}") - continue - - logger.info(f"Training {model_name}...") - model = self.models[model_name] - - # 训练模型 - model.fit(X_train, y_train) - - # 计算评估指标 - metrics = self._calculate_metrics( - model, - X_train, y_train, - X_val, y_val - ) - - results[model_name] = metrics - - # 更新最佳模型(只在机器学习模型中比较) - if metrics['validation']['r2'] > best_score: - best_score = metrics['validation']['r2'] - best_model_info = { - 'type': model_name, - 'r2': metrics['validation']['r2'], - 'mae': metrics['validation']['mae'], - 'rmse': metrics['validation']['rmse'] - } - self.best_model = model - - # 保存最佳模型和 PLS 模型 - if equipment_type and best_model_info: - self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val) - - return { - 'metrics': results, - 'best_model': best_model_info - } - - except Exception as e: - logger.error(f"Error in model training: {str(e)}") - raise - - def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None): - """ - 计算模型评估指标 - """ - from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error - - # 训练集评估 - train_pred = model.predict(X_train) - - # 如果使用了标准化,需要转换回原始范围 - if hasattr(self, 'target_scaler'): - train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel() - y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel() - else: - y_train_orig = y_train - - # 记录预测范围 - logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}") - logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}") - - train_metrics = { - 'r2': r2_score(y_train_orig, train_pred), - 'mae': mean_absolute_error(y_train_orig, train_pred), - 'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred)) - } - - # 验证集评估 - if X_val is not None and y_val is not None: - val_pred = model.predict(X_val) - - # 如果使用了标准化,需要转换回原始范围 - if hasattr(self, 'target_scaler'): - val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel() - y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel() - else: - y_val_orig = y_val - - # 记录预测范围 - logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}") - logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}") - - val_metrics = { - 'r2': r2_score(y_val_orig, val_pred), - 'mae': mean_absolute_error(y_val_orig, val_pred), - 'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred)) - } - else: - # 使用交叉验证 - cv_scores = cross_val_score(model, X_train, y_train, cv=5) - val_metrics = { - 'r2': cv_scores.mean(), - 'mae': None, - 'rmse': None - } - - return { - 'train': train_metrics, - 'validation': val_metrics - } - - def _create_xgboost_model(self): - """ - 创建 XGBoost 模型,增强正则化 - """ - return xgb.XGBRegressor( - n_estimators=50, # 减少树的数量 - learning_rate=0.05, # 学习率 - max_depth=3, # 减小树的深 - min_child_weight=3, # 增加节点权重 - subsample=0.7, # 减小样本采样比例 - colsample_bytree=0.7, # 减小特征采样比例 - reg_alpha=0.1, # L1 正则化 - reg_lambda=1, # L2 正则化 - random_state=42 - ) - - def _create_lightgbm_model(self): - """ - 创建 LightGBM 模型,增强正则化 - """ - return lgb.LGBMRegressor( - n_estimators=50, - learning_rate=0.05, - max_depth=3, - num_leaves=7, - min_data_in_leaf=3, - min_sum_hessian_in_leaf=1e-3, - subsample=0.7, - colsample_bytree=0.7, - reg_alpha=0.1, - reg_lambda=1, - random_state=42, - verbose=-1 - ) - - def _create_gbm_model(self): - """ - 创建 GBM 模型,增强正则化以减轻过拟合 - """ - return GradientBoostingRegressor( - n_estimators=100, - learning_rate=0.1, - max_depth=3, - random_state=42, - subsample=0.8, - min_samples_split=3, - min_samples_leaf=2 - ) - - def _create_rf_model(self): - """ - 创建随机森林模型,针对小样本数据调整参数 - """ - return RandomForestRegressor( - n_estimators=100, - max_depth=3, - random_state=42, - min_samples_split=3, - min_samples_leaf=2 - ) - - def _create_pls_model(self): - """ - 创建 PLS 模型,优化参数配置 - """ - return PLSRegression( - n_components=2, # 减少主成分数量,从5减到2 - scale=True, # 保持数据标准化 - max_iter=500, # 减少最大迭代次数,避免过拟合 - tol=1e-6 # 降低收敛精度,避免过拟合 - ) - - def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None): - """ - 保存最佳模型和 PLS 模型 - """ - try: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - model_dir = 'models' - os.makedirs(model_dir, exist_ok=True) - - # 1. 保存最佳机器学习模型 - model_path = f'{model_dir}/{equipment_type}_{timestamp}' - if isinstance(self.best_model, xgb.XGBRegressor): - self.best_model.save_model(f'{model_path}.json') - model_format = 'json' - else: - joblib.dump(self.best_model, f'{model_path}.joblib') - model_format = 'joblib' - - # 2. 保存 PLS 模型 - pls_model = self.models['pls'] - pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib' - joblib.dump(pls_model, pls_path) - - # 3. 保存标准化器 - scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib' - joblib.dump({ - 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler - }, scaler_path) - - logger.info(f"Saved best model to {model_path}.{model_format}") - logger.info(f"Saved PLS model to {pls_path}") - logger.info(f"Saved scalers to {scaler_path}") - - # 4. 更新数据库中的模型记录 - with get_db_connection() as conn: - cursor = conn.cursor() - - # 将所有模型设置为非激活 - cursor.execute(""" - UPDATE trained_models - SET is_active = FALSE - WHERE equipment_type = %s - """, (equipment_type,)) - - # 获取 PLS 模型的评估指标 - pls_metrics = self._calculate_metrics( - self.models['pls'], - X_train, - y_train, - X_val, - y_val - ) - - # 保存最佳机器学习模型记录 - self.best_model.equipment_type = equipment_type # 设置装备类型 - ml_feature_importance = self._get_feature_importance(self.best_model) - - cursor.execute(""" - INSERT INTO trained_models ( - model_name, model_type, equipment_type, model_path, scaler_path, - r2_score, mae, rmse, feature_importance, training_data_size, - training_date, is_active, created_by - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s) - """, ( - f"{equipment_type}_{timestamp}", # model_name - best_model_info['type'], # model_type - equipment_type, # equipment_type - f"{model_path}.{model_format}", # model_path - scaler_path, # scaler_path - best_model_info['r2'], # r2_score - best_model_info['mae'], # mae - best_model_info['rmse'], # rmse - json.dumps(ml_feature_importance), # feature_importance - len(X_train), # training_data_size - 'system' # created_by - )) - - # 保存 PLS 模型记录 - pls_model.equipment_type = equipment_type # 设置装备类型 - pls_feature_importance = self._get_feature_importance(pls_model) - - cursor.execute(""" - INSERT INTO trained_models ( - model_name, model_type, equipment_type, model_path, scaler_path, - r2_score, mae, rmse, feature_importance, training_data_size, - training_date, is_active, created_by - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s) - """, ( - f"{equipment_type}_{timestamp}_pls", # model_name - 'pls', # model_type - equipment_type, # equipment_type - pls_path, # model_path - scaler_path, # scaler_path - float(pls_metrics['validation']['r2']), # r2_score - float(pls_metrics['validation']['mae']), # mae - float(pls_metrics['validation']['rmse']), # rmse - json.dumps(pls_feature_importance), # feature_importance - len(X_train), # training_data_size - 'system' # created_by - )) - - conn.commit() - - except Exception as e: - logger.error(f"Error saving models: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - raise - - def load_model(self, equipment_type, model_type='ml'): - """ - 加载已训练的模型 - """ - try: - logger.info(f"Loading {model_type} model for {equipment_type}") - - # 从数据库获取激活的模型 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - - # 构建查询语句 - if model_type == 'pls': - query = """ - SELECT * FROM trained_models - WHERE equipment_type = %s - AND model_type = 'pls' - AND is_active = TRUE - LIMIT 1 - """ - params = (equipment_type,) - else: - query = """ - SELECT * FROM trained_models - WHERE equipment_type = %s - AND model_type != 'pls' - AND is_active = TRUE - LIMIT 1 - """ - params = (equipment_type,) - - # 记录查询信息 - logger.info(f"Executing query: {query}") - logger.info(f"Query params: {params}") - - cursor.execute(query, params) - model_record = cursor.fetchone() - - # 记录查询结果 - if model_record: - logger.info(f"Found model record: {model_record}") - else: - logger.warning(f"No active model found for type {model_type}") - return False - - # 检查文件是否存在 - logger.info(f"Checking model file: {model_record['model_path']}") - logger.info(f"Checking scaler file: {model_record['scaler_path']}") - - if not os.path.exists(model_record['model_path']): - logger.error(f"Model file not found: {model_record['model_path']}") - raise FileNotFoundError(f"Model file not found: {model_record['model_path']}") - - if not os.path.exists(model_record['scaler_path']): - logger.error(f"Scaler file not found: {model_record['scaler_path']}") - raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}") - - # 加载模型文件 - logger.info(f"Loading model from {model_record['model_path']}") - if model_type == 'pls': - self.best_model = joblib.load(model_record['model_path']) - logger.info("Loaded PLS model") - else: - if model_record['model_type'] == 'xgboost': - self.best_model = xgb.XGBRegressor() - self.best_model.load_model(model_record['model_path']) - logger.info("Loaded XGBoost model") - else: - self.best_model = joblib.load(model_record['model_path']) - logger.info(f"Loaded {model_record['model_type']} model") - - # 加载标准化器 - logger.info(f"Loading scalers from {model_record['scaler_path']}") - scalers = joblib.load(model_record['scaler_path']) - self.feature_scaler = scalers['feature_scaler'] - self.target_scaler = scalers['target_scaler'] - logger.info("Loaded scalers successfully") - - return True - - except Exception as e: - logger.error(f"Error loading model: {str(e)}") - logger.error(f"Detailed traceback:", exc_info=True) - return False - - def predict(self, features): - """ - 使用加载的模型进行预测 - """ - try: - if not self.best_model: - raise ValueError("No model loaded") - - if not self.feature_scaler: - raise ValueError("Feature scaler not loaded") - - if not self.target_scaler: - raise ValueError("Target scaler not loaded") - - logger.info("Starting prediction") - logger.info(f"Input features shape: {features.shape}") - logger.info(f"Input features: \n{features}") - - # 处理缺失值 - features_filled = np.array(features, dtype=float) - features_filled[np.isnan(features_filled)] = 0 - features_filled = np.nan_to_num(features_filled, 0) - - logger.info(f"Filled features: \n{features_filled}") - - # 标准化特征 - X = self.feature_scaler.transform(features_filled) - logger.info(f"Transformed features shape: {X.shape}") - logger.info(f"Transformed features: \n{X}") - - # 预测 - y_pred_scaled = self.best_model.predict(X) - logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}") - logger.info(f"Scaled prediction: {y_pred_scaled}") - - # ��标准化 - y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)) - logger.info(f"Final prediction shape: {y_pred.shape}") - logger.info(f"Final prediction: {y_pred}") - - # 记录标准化器的参数 - logger.info("Target scaler params:") - logger.info(f"Mean: {self.target_scaler.mean_}") - logger.info(f"Scale: {self.target_scaler.scale_}") - - return y_pred.ravel() - - except Exception as e: - logger.error(f"Error in prediction: {str(e)}") - raise - - def _get_feature_importance(self, model): - """ - 获取特征重要性 - """ - try: - if not model: - return {} - - # 获取特征名称 - feature_analyzer = FeatureAnalysis() - feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type) - - # 获取特���重要性 - if hasattr(model, 'feature_importances_'): - importances = model.feature_importances_ - elif hasattr(model, 'coef_'): - if len(model.coef_.shape) > 1: # 如果是二维数组 - importances = np.abs(model.coef_[0]) # 取第一行 - else: - importances = np.abs(model.coef_) - else: - return {} - - # 创建特征重要性字典 - importance_dict = {} - for name, importance in zip(feature_names, importances): - importance_dict[name] = float(importance) # 确保转换为 Python 标量 - - # 按重要性降序排序 - sorted_dict = dict(sorted( - importance_dict.items(), - key=lambda x: x[1], - reverse=True - )) - - # 过滤掉重要性为0的特征 - return {k: v for k, v in sorted_dict.items() if v > 0} - - except Exception as e: - logger.error(f"Error getting feature importance: {str(e)}") - return {} - - def _calculate_confidence_interval(self, prediction, confidence=0.95): - """ - 计算预测值的置信区间 - """ - try: - # 使用预测值的20%作为标准差(增加不确定性) - std = abs(prediction) * 0.2 - - # 计算置信区间 - from scipy import stats - interval = stats.norm.interval(confidence, loc=prediction, scale=std) - - # 确保区间值为正数且合理 - lower = max(1000, interval[0]) # 最小值设为1000元 - upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20% - - logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]") - - return [lower, upper] - - except Exception as e: - logger.error(f"Error calculating confidence interval: {str(e)}") - # 如果计算失败,返回基于20%的简单区间 - lower = max(1000, prediction * 0.8) - upper = prediction * 1.2 - return [lower, upper] - - def get_model_type(self): - """ - 获取当前模型的类型 - """ - if isinstance(self.best_model, xgb.XGBRegressor): - return 'xgboost' - elif isinstance(self.best_model, lgb.LGBMRegressor): - return 'lightgbm' - elif isinstance(self.best_model, GradientBoostingRegressor): - return 'gbm' - elif isinstance(self.best_model, RandomForestRegressor): - return 'rf' - else: - return 'unknown' - - def _get_pls_feature_importance(self): - """ - 获取 PLS 模型的特征重要性 - """ - try: - if not self.models['pls']: - return {} - - # 获取特征名称 - feature_analyzer = FeatureAnalysis() - feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type) - - # 获取 PLS 模型的系数作为特征重要性 - pls_model = self.models['pls'] - if hasattr(pls_model, 'coef_'): - # 使用绝对值作为重要性指标 - importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组 - else: - return {} - - # 创建特征重要性字典 - importance_dict = {} - for name, importance in zip(feature_names, importances): - importance_dict[name] = float(importance) # 确保转换为 Python 标量 - - # 按重要性降序排序 - sorted_dict = dict(sorted( - importance_dict.items(), - key=lambda x: x[1], - reverse=True - )) - - # 过滤掉重要性为0的特征 - return {k: v for k, v in sorted_dict.items() if v > 0} - - except Exception as e: - logger.error(f"Error getting PLS feature importance: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return {} \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/routes.py b/deploy/equipment_cost_system/src/routes.py deleted file mode 100644 index df4e825..0000000 --- a/deploy/equipment_cost_system/src/routes.py +++ /dev/null @@ -1,1254 +0,0 @@ -from flask import Blueprint, request, jsonify, send_file -from .cost_prediction import CostPredictor -from .feature_analysis import FeatureAnalysis -import pandas as pd -from datetime import datetime -import numpy as np -import mysql.connector -from sklearn.metrics import mean_absolute_error -from .create_template import create_excel_template -import json -import os -from .data_preparation import DataPreparation -from .model_trainer import ModelTrainer -from .logger import setup_logger - -# 创建蓝图 -api_bp = Blueprint('api', __name__) - -# 获取logger -logger = setup_logger(__name__) - -@api_bp.route('/', methods=['GET']) -def index(): - """ - API根路由 - 返回API版本信息和可用端点列表 - """ - return jsonify({ - 'name': '装备成本估算系统 API', - 'version': '1.0.0', - 'endpoints': { - 'predict': { - 'url': '/api/predict', - 'method': 'POST', - 'description': '成本预测' - }, - 'analyze-features': { - 'url': '/api/analyze-features', - 'method': 'POST', - 'description': '特征分析' - }, - 'train': { - 'url': '/api/train', - 'method': 'POST', - 'description': '模型训练' - }, - 'evaluate': { - 'url': '/api/evaluate', - 'method': 'POST', - 'description': '模型评估' - } - } - }) - -@api_bp.route('/predict', methods=['POST']) -def predict_cost(): - """ - 成本预测接口 - """ - try: - data = request.get_json() - logger.info(f"Received prediction request for equipment type: {data.get('type')}") - - # 验证装备类型 - if 'type' not in data: - return jsonify({'error': 'Equipment type is required'}), 400 - - # 预测成本 - predictor = CostPredictor() - result = predictor.predict(data) - - # 获取当前使用的模型信息 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE - LIMIT 1 - """, (data['type'],)) - model_info = cursor.fetchone() - - # 在结果中添加模型信息 - result.update({ - 'model_info': { - 'type': model_info['model_type'], - 'name': model_info['model_name'], - 'r2_score': float(model_info['r2_score']), - 'mae': float(model_info['mae']), - 'rmse': float(model_info['rmse']) - } - }) - - logger.info(f"Prediction completed: {result}") - return jsonify(result) - - except Exception as e: - logger.error(f"Error in prediction: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/analyze-features', methods=['POST']) -def analyze_features(): - """ - 基于数据集进行特征分析 - """ - try: - data = request.get_json() - dataset_id = data.get('dataset_id') - - logger.info(f"Starting feature analysis for dataset {dataset_id}") - - if not dataset_id: - logger.warning("No dataset_id provided") - return jsonify({'error': '请选择数据集'}), 400 - - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - - # 获取数据集信息 - cursor.execute(""" - SELECT d.*, - e.type as equipment_type - FROM datasets d - JOIN dataset_equipment de ON d.id = de.dataset_id - JOIN equipment e ON de.equipment_id = e.id - WHERE d.id = %s - LIMIT 1 - """, (dataset_id,)) - dataset = cursor.fetchone() - - if not dataset: - logger.warning(f"Dataset {dataset_id} not found") - return jsonify({'error': '数据集不存在'}), 404 - - logger.info(f"Dataset info: {dataset}") - - # 创建特征分析实例 - from src.feature_analysis import FeatureAnalysis - analyzer = FeatureAnalysis() - - # 获取特征列表 - feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type']) - logger.info(f"Feature names: {feature_names}") - - # 获取数据集中的装备数据 - if dataset['equipment_type'] == '火箭炮': - cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (dataset_id,)) - else: - cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (dataset_id,)) - - equipment_data = cursor.fetchall() - logger.info(f"Found {len(equipment_data)} equipment records") - - if not equipment_data: - logger.warning("No valid equipment data found in dataset") - return jsonify({'error': '数据集没有有效的成本数据'}), 400 - - # 统计每个特征的缺失率 - missing_rates = {} - for name in feature_names: - missing_count = sum(1 for item in equipment_data if item.get(name) is None) - missing_rate = missing_count / len(equipment_data) - missing_rates[name] = missing_rate - logger.info(f"Feature {name} missing rate: {missing_rate:.2%}") - - # 过滤掉缺失率过高的特征 - valid_features = [name for name in feature_names if missing_rates[name] < 0.7] - logger.info(f"Valid features after filtering: {valid_features}") - - if len(valid_features) < 3: # 至少需要3个特征 - return jsonify({'error': '有效特征数量不足'}), 400 - - # 计算每个特征的均值 - feature_means = {} - for name in valid_features: - values = [float(item[name]) for item in equipment_data if item.get(name) is not None] - feature_means[name] = sum(values) / len(values) if values else 0 - logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}") - - # 准备特征和目标值 - features = [] - target = [] - - # 提取特征和目标值,使用均值填充缺失值 - for item in equipment_data: - feature_values = [] - for name in valid_features: - value = item.get(name) - try: - # 确保数值类型转换正确 - feature_values.append(float(value) if value is not None else feature_means[name]) - except (ValueError, TypeError) as e: - logger.error(f"Error converting value for feature {name}: {value}") - logger.error(f"Error details: {str(e)}") - return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400 - features.append(feature_values) - - # 确保成本值是值类型 - try: - target.append(float(item['actual_cost'])) - except (ValueError, TypeError) as e: - logger.error(f"Error converting actual_cost: {item['actual_cost']}") - logger.error(f"Error details: {str(e)}") - return jsonify({'error': '成本值无法换为数值'}), 400 - - logger.info(f"Prepared {len(features)} feature vectors") - logger.info(f"First feature vector: {features[0] if features else None}") - logger.info(f"First target value: {target[0] if target else None}") - - # 调用特征分析方法 - result = analyzer.analyze_features(features, target, valid_features) - logger.info("Analysis completed successfully") - - return jsonify(result) - - except Exception as e: - logger.error(f"Error analyzing features: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/train', methods=['POST']) -def train_model(): - """ - 训练模型 - """ - try: - data = request.get_json() - logger.info(f"Starting model training for {data.get('type')}") - equipment_type = data.get('type') - train_dataset_id = data.get('train_dataset_id') - validation_dataset_id = data.get('validation_dataset_id') - models = data.get('models', []) - - logger.info(f"Training dataset: {train_dataset_id}") - logger.info(f"Validation dataset: {validation_dataset_id}") - logger.info(f"Selected models: {models}") - - # 获取训练数据 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - - # 获取训练集数据 - if equipment_type == '火箭炮': - cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (train_dataset_id,)) - else: - cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (train_dataset_id,)) - - train_data = cursor.fetchall() - - # 获取验证集数据(如果有) - validation_data = None - if validation_dataset_id: - if equipment_type == '火箭炮': - cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (validation_dataset_id,)) - else: - cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - AND cd.actual_cost IS NOT NULL - """, (validation_dataset_id,)) - validation_data = cursor.fetchall() - - if not train_data: - return jsonify({'error': '训练数据集为空'}), 400 - - # 1. 准备数据 - data_processor = DataPreparation() - - # 准备训练数据 - train_prepared = data_processor.prepare_training_data(train_data, equipment_type) - - # 准备验证数据(如果有) - validation_prepared = None - if validation_data: - validation_prepared = data_processor.prepare_validation_data( - validation_data, - equipment_type, - train_prepared['feature_names'], - { - 'feature_scaler': train_prepared['feature_scaler'], - 'target_scaler': train_prepared['target_scaler'] - } - ) - - # 2. 训练模型 - model_trainer = ModelTrainer() - model_trainer.feature_scaler = train_prepared['feature_scaler'] - model_trainer.target_scaler = train_prepared['target_scaler'] - - # 执行训练,传入 equipment_type 参数 - training_result = model_trainer.fit_model( - train_prepared['X'], - train_prepared['y'], - models, - validation_prepared['X'] if validation_prepared else None, - validation_prepared['y'] if validation_prepared else None, - equipment_type=equipment_type - ) - - return jsonify(training_result) - - except Exception as e: - logger.error(f"Error in model training: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/evaluate', methods=['POST']) -def evaluate_model(): - """ - 模型评估接口 - """ - try: - data = request.get_json() - logger.info("Received model evaluation request") - - if 'test_data' not in data: - return jsonify({'error': 'Test data is required'}), 400 - - predictor = CostPredictor() - evaluation_result = predictor.evaluate( - data['test_data']['actual'], - data['test_data']['predicted'] - ) - - logger.info("Model evaluation completed") - return jsonify(evaluation_result) - - except Exception as e: - logger.error(f"Error in model evaluation: {str(e)}") - return jsonify({'error': str(e)}), 500 - -def get_required_params(equipment_type): - """ - 根据装备类型获取必要参数 - """ - common_params = [ - 'length_m', - 'width_m', - 'height_m', - 'weight_kg', - 'max_range_km' - ] - - if equipment_type == '火箭炮': - return common_params + [ - 'firing_angle_horizontal', - 'firing_angle_vertical', - 'rocket_length_m', - 'rocket_diameter_mm', - 'rocket_weight_kg' - ] - elif equipment_type == '巡飞弹': - return common_params + [ - 'max_speed_kmh', - 'cruise_speed_kmh', - 'flight_time_min', - 'folded_length_mm', - 'folded_width_mm', - 'folded_height_mm' - ] - - return common_params - -@api_bp.errorhandler(404) -def not_found(error): - return jsonify({'error': 'Not found'}), 404 - -@api_bp.errorhandler(500) -def internal_error(error): - logger.error(f"Internal server error: {str(error)}") - return jsonify({'error': 'Internal server error'}), 500 - -@api_bp.route('/data', methods=['GET']) -def get_equipment_data(): - """ - 获取装备数据 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute('SET SESSION group_concat_max_len = 1000000') - - # 先测试特殊参数查询 - cursor.execute(""" - SELECT equipment_id, param_name, param_value, param_unit - FROM custom_params - WHERE param_name IS NOT NULL - AND param_value IS NOT NULL - LIMIT 5 - """) - test_params = cursor.fetchall() - logger.info(f"Test custom params: {test_params}") - - # 获取火箭炮数据 - logger.info("Fetching rocket artillery data...") - cursor.execute(""" - SELECT - e.id, - e.name, - e.type, - e.manufacturer, - e.created_at, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - cp.max_range_km, - rap.firing_angle_horizontal, - rap.firing_angle_vertical, - rap.rocket_length_m, - rap.rocket_diameter_mm, - rap.rocket_weight_kg, - rap.rate_of_fire, - rap.combat_weight_kg, - rap.speed_kmh, - rap.min_range_km, - rap.mobility_type, - rap.structure_layout, - rap.engine_model, - rap.engine_params, - rap.power_hp, - rap.travel_range_km, - cd.actual_cost, - ( - SELECT COALESCE( - JSON_ARRAYAGG( - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - ), - '[]' - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params - FROM equipment e - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE e.type = '火箭炮' - """) - rocket_artillery = cursor.fetchall() - logger.info(f"Found {len(rocket_artillery)} rocket artillery records") - if rocket_artillery: - logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}") - logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}") - - # 获取巡飞弹数据 - logger.info("Fetching missile data...") - cursor.execute(""" - SELECT - e.id, - e.name, - e.type, - e.manufacturer, - e.created_at, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - cp.max_range_km, - lmp.wingspan_m, - lmp.warhead_weight_kg, - lmp.max_speed_ms, - lmp.cruise_speed_kmh, - lmp.flight_time_min, - lmp.warhead_type, - lmp.launch_mode, - lmp.folded_length_mm, - lmp.folded_width_mm, - lmp.folded_height_mm, - lmp.power_system, - lmp.guidance_system, - cd.actual_cost, - ( - SELECT COALESCE( - JSON_ARRAYAGG( - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - ), - '[]' - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params - FROM equipment e - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE e.type = '巡飞弹' - """) - loitering_munition = cursor.fetchall() - logger.info(f"Found {len(loitering_munition)} missile records") - if loitering_munition: - logger.info(f"First missile: {loitering_munition[0]['name']}") - logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}") - - # 处理 custom_params,保为 NULL - for item in rocket_artillery + loitering_munition: - if item['custom_params'] is None: - item['custom_params'] = [] - logger.debug(f"Set empty custom_params for equipment {item['id']}") - else: - logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params") - - logger.info("Data fetching completed") - return jsonify({ - 'rocket_artillery': rocket_artillery, - 'loitering_munition': loitering_munition - }) - - except Exception as e: - logger.error(f"Error getting equipment data: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/data/', methods=['DELETE']) -def delete_equipment(id): - """ - 删除装备数据 - """ - try: - db = get_db_connection() - cursor = db.cursor() - - # 删除相关数据 - cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,)) - cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,)) - cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,)) - cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,)) - cursor.execute("DELETE FROM equipment WHERE id = %s", (id,)) - - db.commit() - cursor.close() - db.close() - - return jsonify({'status': 'success'}) - - except Exception as e: - logger.error(f"Error deleting equipment: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/data/template', methods=['GET']) -def download_template(): - """ - 下载数据模板 - """ - try: - # 创建模板文件 - from .create_template import create_excel_template - template_path = create_excel_template() - - # 检查文件是否存 - if not os.path.exists(template_path): - raise FileNotFoundError("模板文件不存在") - - # 返回文件 - return send_file( - template_path, - as_attachment=True, - download_name='equipment_data_template.xlsx', - mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' - ) - - except Exception as e: - logger.error(f"Error creating template: {str(e)}") - return jsonify({'error': str(e)}), 500 - -def get_db_connection(): - """ - 获取数据库连接 - """ - return mysql.connector.connect( - host="localhost", - user="root", - password="123456", - database="equipment_cost_db" - ) - -@api_bp.route('/pls/predict', methods=['POST']) -def pls_predict(): - """ - PLS回归预测接口 - """ - try: - data = request.get_json() - logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}") - - # 验证装备类型 - if 'type' not in data: - return jsonify({'error': 'Equipment type is required'}), 400 - - # 使用 ModelTrainer 中的 PLS 模型进行预测 - trainer = ModelTrainer() - if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型 - return jsonify({'error': '未找到可用的模型'}), 404 - - # 准备特征数据 - feature_analyzer = FeatureAnalysis() - features = feature_analyzer.get_equipment_specific_features(data['type']) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - result = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(result[0]) - - # 获取模型信息 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE - LIMIT 1 - """, (data['type'],)) - model_info = cursor.fetchone() - - # 确保返回的数据可以序列化为JSON - response = { - 'predicted_cost': float(result[0]), - 'model_info': { - 'type': model_info['model_type'], - 'name': model_info['model_name'], - 'r2_score': model_info['r2_score'], - 'mae': model_info['mae'], - 'rmse': model_info['rmse'] - }, - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - logger.info(f"PLS prediction completed: {response}") - return jsonify(response) - - except Exception as e: - logger.error(f"Error in PLS prediction: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/data/import', methods=['POST']) -def import_data(): - """ - 导入数据接口 - """ - try: - if 'file' not in request.files: - return jsonify({'error': '没有上传文件'}), 400 - - file = request.files['file'] - if not file.filename.endswith(('.xls', '.xlsx')): - return jsonify({'error': '请上传Excel文件'}), 400 - - # 保存上的文件 - upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') - os.makedirs(upload_dir, exist_ok=True) - file_path = os.path.join(upload_dir, file.filename) - file.save(file_path) - - # 导入数据 - from .import_data import import_training_data - import_training_data(file_path) - - return jsonify({ - 'success': True, - 'message': '数据导入成功' - }) - - except Exception as e: - logger.error(f"Error importing data: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/data/', methods=['PUT']) -def update_equipment(id): - """ - 更新装备数据 - """ - try: - data = request.get_json() - logger.info(f"Updating equipment ID: {id}") - logger.info(f"Update data: {data}") - - with get_db_connection() as conn: - cursor = conn.cursor() - - # 更新基本信息 - cursor.execute(""" - UPDATE equipment - SET name = %s, manufacturer = %s - WHERE id = %s - """, (data['name'], data['manufacturer'], id)) - logger.info("Basic info updated") - - # 更新通用参数 - cursor.execute(""" - UPDATE common_params - SET length_m = %s, width_m = %s, height_m = %s, - weight_kg = %s, max_range_km = %s - WHERE equipment_id = %s - """, ( - data['length_m'], data['width_m'], data['height_m'], - data['weight_kg'], data['max_range_km'], id - )) - logger.info("Common params updated") - - # 根据备类型更新特有参数 - if data['type'] == '火箭炮': - cursor.execute(""" - UPDATE rocket_artillery_params - SET firing_angle_horizontal = %s, firing_angle_vertical = %s, - rocket_length_m = %s, rocket_diameter_mm = %s, - rocket_weight_kg = %s, rate_of_fire = %s - WHERE equipment_id = %s - """, ( - data['firing_angle_horizontal'], data['firing_angle_vertical'], - data['rocket_length_m'], data['rocket_diameter_mm'], - data['rocket_weight_kg'], data['rate_of_fire'], id - )) - logger.info("Rocket artillery params updated") - else: - cursor.execute(""" - UPDATE loitering_munition_params - SET max_speed_ms = %s, cruise_speed_kmh = %s, - flight_time_min = %s, warhead_type = %s, - launch_mode = %s, folded_length_mm = %s, - folded_width_mm = %s, folded_height_mm = %s - WHERE equipment_id = %s - """, ( - data['max_speed_ms'], data['cruise_speed_kmh'], - data['flight_time_min'], data['warhead_type'], - data['launch_mode'], data['folded_length_mm'], - data['folded_width_mm'], data['folded_height_mm'], id - )) - logger.info("Missile params updated") - - # 更新成本数据 - if 'actual_cost' in data: - cursor.execute(""" - UPDATE cost_data - SET actual_cost = %s - WHERE equipment_id = %s - """, (data['actual_cost'], id)) - logger.info("Cost data updated") - - # 更新特殊参数 - if 'custom_params' in data and data['custom_params']: - logger.info(f"Updating custom params: {data['custom_params']}") - for param in data['custom_params']: - cursor.execute(""" - UPDATE custom_params - SET param_value = %s - WHERE id = %s AND equipment_id = %s - """, (param['param_value'], param['id'], id)) - logger.info("Custom params updated") - - conn.commit() - logger.info("All updates committed successfully") - - return jsonify({'success': True}) - - except Exception as e: - logger.error(f"Error updating equipment: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/data/details/', methods=['GET']) -def get_equipment_details(id): - """ - 获取装备详数据 - """ - try: - logger.info(f"Getting details for equipment ID: {id}") - - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - - # 先获取装备类型 - cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,)) - equipment = cursor.fetchone() - - if not equipment: - logger.warning(f"Equipment not found: {id}") - return jsonify({'error': 'Equipment not found'}), 404 - - equipment_type = equipment['type'] - logger.info(f"Equipment type: {equipment_type}") - - # 根据装备类型选择查询 - if equipment_type == '火箭炮': - query = """ - SELECT - e.*, - cp.*, - rap.*, - cd.actual_cost, - cd.prediction_date as cost_estimate_date, - cd.predicted_cost, - ( - SELECT JSON_ARRAYAGG( - CASE - WHEN csp.id IS NOT NULL THEN - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - END - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params - FROM equipment e - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE e.id = %s - """ - else: - query = """ - SELECT - e.*, - cp.*, - lmp.*, - cd.actual_cost, - cd.prediction_date as cost_estimate_date, - cd.predicted_cost, - ( - SELECT JSON_ARRAYAGG( - CASE - WHEN csp.id IS NOT NULL THEN - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - END - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params - FROM equipment e - LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE e.id = %s - """ - - cursor.execute(query, (id,)) - result = cursor.fetchone() - - if result: - logger.info(f"Found equipment details: {result['name']}") - logger.info(f"Custom params: {result.get('custom_params')}") - - return jsonify(result) - - except Exception as e: - logger.error(f"Error getting equipment details: {str(e)}") - return jsonify({'error': str(e)}), 500 - -# 添加数据集相关的路由 -@api_bp.route('/datasets', methods=['GET']) -def get_datasets(): - """ - 获取数据集列表 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT d.*, - COUNT(de.equipment_id) as equipment_count, - GROUP_CONCAT(e.name) as equipment_names - FROM datasets d - LEFT JOIN dataset_equipment de ON d.id = de.dataset_id - LEFT JOIN equipment e ON de.equipment_id = e.id - GROUP BY d.id - """) - datasets = cursor.fetchall() - - # 理装备名称列表 - for dataset in datasets: - if dataset['equipment_names']: - dataset['equipment_names'] = dataset['equipment_names'].split(',') - else: - dataset['equipment_names'] = [] - - return jsonify(datasets) - except Exception as e: - logger.error(f"Error getting datasets: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/datasets/', methods=['GET']) -def get_dataset(id): - """ - 获取数据集详情 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - # 获取数据集基本信息 - cursor.execute(""" - SELECT d.*, - COUNT(de.equipment_id) as equipment_count - FROM datasets d - LEFT JOIN dataset_equipment de ON d.id = de.dataset_id - WHERE d.id = %s - GROUP BY d.id - """, (id,)) - dataset = cursor.fetchone() - - if not dataset: - return jsonify({'error': 'Dataset not found'}), 404 - - # 获取数据集中的装备 - cursor.execute(""" - SELECT e.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id - LEFT JOIN cost_data cd ON e.id = cd.equipment_id - WHERE de.dataset_id = %s - """, (id,)) - equipment = cursor.fetchall() - - # 计算统计信息 - if equipment: - total_cost = sum(item['actual_cost'] or 0 for item in equipment) - avg_cost = total_cost / len(equipment) - dataset['statistics'] = { - 'equipment_count': len(equipment), - 'total_cost': total_cost, - 'average_cost': avg_cost - } - else: - dataset['statistics'] = { - 'equipment_count': 0, - 'total_cost': 0, - 'average_cost': 0 - } - - dataset['equipment'] = equipment - return jsonify(dataset) - except Exception as e: - logger.error(f"Error getting dataset: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/datasets', methods=['POST']) -def create_dataset(): - """ - 建数据集 - """ - try: - data = request.get_json() - with get_db_connection() as conn: - cursor = conn.cursor() - - # 创建数据集 - cursor.execute(""" - INSERT INTO datasets (name, description, equipment_type, purpose) - VALUES (%s, %s, %s, %s) - """, (data['name'], data['description'], data['equipment_type'], data['purpose'])) - - dataset_id = cursor.lastrowid - - # 添加装备关联 - if 'equipment_ids' in data and data['equipment_ids']: - values = [(dataset_id, equipment_id) for equipment_id in data['equipment_ids']] - cursor.executemany(""" - INSERT INTO dataset_equipment (dataset_id, equipment_id) - VALUES (%s, %s) - """, values) - - conn.commit() - return jsonify({'id': dataset_id, 'message': '数据集创建成功'}) - except Exception as e: - logger.error(f"Error creating dataset: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/datasets/', methods=['PUT']) -def update_dataset(id): - """ - 更新数据集 - """ - try: - data = request.get_json() - with get_db_connection() as conn: - cursor = conn.cursor() - - # 更新数据集基本信息 - cursor.execute(""" - UPDATE datasets - SET name = %s, description = %s, equipment_type = %s, purpose = %s - WHERE id = %s - """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id)) - - # 删除旧的装备关联 - cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) - - # 加新的装备关联 - if 'equipment_ids' in data: - for equipment_id in data['equipment_ids']: - cursor.execute(""" - INSERT INTO dataset_equipment (dataset_id, equipment_id) - VALUES (%s, %s) - """, (id, equipment_id)) - - conn.commit() - return jsonify({'success': True}) - except Exception as e: - logger.error(f"Error updating dataset: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/datasets/', methods=['DELETE']) -def delete_dataset(id): - """ - 删除数据集 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor() - - # 删除装备关联 - cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) - - # 删除数据集 - cursor.execute("DELETE FROM datasets WHERE id = %s", (id,)) - - conn.commit() - return jsonify({'success': True}) - except Exception as e: - logger.error(f"Error deleting dataset: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/models//latest', methods=['GET']) -def get_latest_model(equipment_type): - """ - 获取最新训练的型信息 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT * FROM trained_models - WHERE equipment_type = %s AND is_active = TRUE - ORDER BY training_date DESC LIMIT 1 - """, (equipment_type,)) - - model = cursor.fetchone() - return jsonify(model) - - except Exception as e: - logger.error(f"Error getting latest model: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/models', methods=['GET']) -def get_models(): - """ - 获取模型列表 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT * FROM trained_models - ORDER BY training_date DESC - """) - - models = cursor.fetchall() - - # 确保数值类型字段是 float - for model in models: - if model['r2_score'] is not None: - model['r2_score'] = float(model['r2_score']) - if model['mae'] is not None: - model['mae'] = float(model['mae']) - if model['rmse'] is not None: - model['rmse'] = float(model['rmse']) - - # 解析特征重要性 - if model['feature_importance']: - model['feature_importance'] = json.loads(model['feature_importance']) - - return jsonify(models) - - except Exception as e: - logger.error(f"Error getting models: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/models//activate', methods=['POST']) -def activate_model(id): - """ - 激活指定的模型 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor() - - # 获取模型信息 - cursor.execute(""" - SELECT equipment_type FROM trained_models - WHERE id = %s - """, (id,)) - model = cursor.fetchone() - - if not model: - return jsonify({'error': 'Model not found'}), 404 - - # 将同类型的其他模型设置为非激活 - cursor.execute(""" - UPDATE trained_models - SET is_active = FALSE - WHERE equipment_type = %s - """, (model[0],)) - - # 激活指定模型 - cursor.execute(""" - UPDATE trained_models - SET is_active = TRUE - WHERE id = %s - """, (id,)) - - conn.commit() - return jsonify({'success': True}) - - except Exception as e: - logger.error(f"Error activating model: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/models/', methods=['DELETE']) -def delete_model(id): - """ - 删除指定的模型 - """ - try: - with get_db_connection() as conn: - cursor = conn.cursor() - - # 获取模型文件路径 - cursor.execute(""" - SELECT model_path, scaler_path - FROM trained_models - WHERE id = %s - """, (id,)) - model = cursor.fetchone() - - if not model: - return jsonify({'error': 'Model not found'}), 404 - - # 删除模型文件 - if os.path.exists(model[0]): - os.remove(model[0]) - if os.path.exists(model[1]): - os.remove(model[1]) - - # 删除数据库记录 - cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,)) - conn.commit() - - return jsonify({'success': True}) - - except Exception as e: - logger.error(f"Error deleting model: {str(e)}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route('/predict/all', methods=['POST']) -def predict_all(): - """ - 获取所有机器学习模型的预测结果 - """ - try: - data = request.get_json() - logger.info(f"Received prediction request for all models, equipment type: {data.get('type')}") - - predictor = CostPredictor() - results = predictor.predict_all(data) - - return jsonify(results) - - except Exception as e: - logger.error(f"Error in prediction: {str(e)}") - return jsonify({'error': str(e)}), 500 \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/schema.sql b/deploy/equipment_cost_system/src/schema.sql deleted file mode 100644 index a0abd10..0000000 --- a/deploy/equipment_cost_system/src/schema.sql +++ /dev/null @@ -1,140 +0,0 @@ --- 如果数据库已存在则删除 -DROP DATABASE IF EXISTS equipment_cost_db; - --- 创建数据库 -CREATE DATABASE equipment_cost_db -DEFAULT CHARACTER SET utf8mb4 -COLLATE utf8mb4_unicode_ci; - --- 使用数据库 -USE equipment_cost_db; - --- 装备基本信息表 -CREATE TABLE equipment ( - id INT AUTO_INCREMENT PRIMARY KEY, - name VARCHAR(100), -- 名称 - type VARCHAR(50), -- 类型(火箭炮/巡飞弹) - manufacturer VARCHAR(100), -- 制造商 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 通用参数表 -CREATE TABLE common_params ( - id INT AUTO_INCREMENT PRIMARY KEY, - equipment_id INT, - length_m FLOAT, -- 总长(m) - width_m FLOAT, -- 宽度(m) - height_m FLOAT, -- 高度(m) - weight_kg FLOAT, -- 重量(kg) - max_range_km FLOAT, -- 最大射程(km) - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 火箭炮特有参数表 -CREATE TABLE rocket_artillery_params ( - id INT AUTO_INCREMENT PRIMARY KEY, - equipment_id INT, - firing_angle_horizontal FLOAT, -- 方向射界(度) - firing_angle_vertical FLOAT, -- 高低射界(度) - rocket_length_m FLOAT, -- 火箭弹长度(m) - rocket_diameter_mm FLOAT, -- 弹体直径(mm) - rocket_weight_kg FLOAT, -- 火箭弹重量(kg) - rate_of_fire FLOAT, -- 射速(发/分钟) - combat_weight_kg FLOAT, -- 战斗重量(kg) - speed_kmh FLOAT, -- 速度(km/h) - min_range_km FLOAT, -- 最小射程(km) - mobility_type VARCHAR(50), -- 行走方式 - structure_layout VARCHAR(100), -- 结构布局 - engine_model VARCHAR(100), -- 发动机型号 - engine_params TEXT, -- 发动机参数 - power_hp FLOAT, -- 功率(hp) - travel_range_km FLOAT, -- 行程(km) - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 巡飞弹特有参数表 -CREATE TABLE loitering_munition_params ( - id INT AUTO_INCREMENT PRIMARY KEY, - equipment_id INT, - wingspan_m FLOAT, -- 翼展(m) - warhead_weight_kg FLOAT, -- 战斗部重量(kg) - max_speed_ms FLOAT, -- 最大速度(m/s) - cruise_speed_kmh FLOAT, -- 巡航速度(km/h) - flight_time_min FLOAT, -- 巡飞时间(min) - warhead_type VARCHAR(50), -- 战斗部类型 - launch_mode VARCHAR(50), -- 发射方式 - folded_length_mm FLOAT, -- 折叠长度(mm) - folded_width_mm FLOAT, -- 折叠宽度(mm) - folded_height_mm FLOAT, -- 折叠高度(mm) - power_system VARCHAR(100), -- 动力装置 - guidance_system VARCHAR(100), -- 制导体制 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 成本数据表 -CREATE TABLE cost_data ( - id INT AUTO_INCREMENT PRIMARY KEY, - equipment_id INT, - actual_cost DECIMAL(15,2), -- 实际成本(元) - predicted_cost DECIMAL(15,2), -- 预测成本(元) - prediction_date TIMESTAMP, -- 预测日期 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 特殊参数表 -CREATE TABLE custom_params ( - id INT AUTO_INCREMENT PRIMARY KEY, - equipment_id INT, - param_name VARCHAR(100), -- 参数名称 - param_value VARCHAR(255), -- 参数值 - param_unit VARCHAR(50), -- 参数单位 - description TEXT, -- 参数说明 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 添加索引 -CREATE INDEX idx_equipment_type ON equipment(type); -CREATE INDEX idx_equipment_name ON equipment(name); -CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id); - --- 数据集表 -CREATE TABLE datasets ( - id INT AUTO_INCREMENT PRIMARY KEY, - name VARCHAR(100) NOT NULL, -- 数据集名称 - description TEXT, -- 数据集描述 - equipment_type VARCHAR(50) NOT NULL, -- 装备类型 - purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证) - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP -); - --- 数据集-装备关联表 -CREATE TABLE dataset_equipment ( - dataset_id INT NOT NULL, - equipment_id INT NOT NULL, - PRIMARY KEY (dataset_id, equipment_id), - FOREIGN KEY (dataset_id) REFERENCES datasets(id), - FOREIGN KEY (equipment_id) REFERENCES equipment(id) -); - --- 训练模型表 -CREATE TABLE trained_models ( - id INT AUTO_INCREMENT PRIMARY KEY, - model_name VARCHAR(100) NOT NULL, -- 模型名称 - model_type VARCHAR(50) NOT NULL, -- 模型类型 - equipment_type VARCHAR(50) NOT NULL, -- 装备类型 - model_path VARCHAR(255) NOT NULL, -- 模型文件路径 - scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径 - r2_score FLOAT, -- R²分数 - mae FLOAT, -- 平均绝对误差 - rmse FLOAT, -- 均方根误差 - feature_importance JSON, -- 特征重要性 - training_data_size INT, -- 训练数据量 - training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间 - is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型 - created_by VARCHAR(50) -- 创建者 -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; - --- 添加索引 -CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type); -CREATE INDEX idx_model_active ON trained_models(is_active); \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/test_api.py b/deploy/equipment_cost_system/src/test_api.py deleted file mode 100644 index f68d88d..0000000 --- a/deploy/equipment_cost_system/src/test_api.py +++ /dev/null @@ -1,191 +0,0 @@ -import requests -import json -import logging -from datetime import datetime -import os -import sys - -# 添加项目根目录到 Python 路径 -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.dirname(current_dir) -if project_root not in sys.path: - sys.path.insert(0, project_root) - -# 配置基本日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('logs/test_api.log'), - logging.StreamHandler() - ] -) - -logger = logging.getLogger(__name__) - -def test_api_endpoints(): - """ - 测试 API 各个端点 - """ - base_url = 'http://localhost:5001/api' - - try: - # 1. 测试根路由 - logger.info("\n1. 测试 API 根路由") - response = requests.get(f'{base_url}/') - print_response(response, "API 根路由") - - # 2. 测试机器学习预测接口 - logger.info("\n2. 测试机器学习预测接口") - predict_data = { - "type": "巡飞弹", - "length_m": 1.3, - "width_m": 0.23, - "height_m": 0.23, - "weight_kg": 12.5, - "max_range_km": 40, - "max_speed_ms": 50, - "cruise_speed_kmh": 100, - "flight_time_min": 60, - "folded_length_mm": 1300, - "folded_width_mm": 230, - "folded_height_mm": 230, - "warhead_type": "破片杀伤战斗部", - "launch_mode": "凭自身动力起飞" - } - - response = requests.post( - f'{base_url}/predict', - json=predict_data - ) - print_response(response, "机器学习预测") - - # 3. 测试 PLS 预测接口 - logger.info("\n3. 测试 PLS 预测接口") - response = requests.post( - f'{base_url}/pls/predict', - json=predict_data - ) - print_response(response, "PLS 预测") - - # 4. 测试特征分析接口 - logger.info("\n4. 测试特征分析接口") - analysis_data = { - "dataset_id": 1, # 假设数据集 ID 为 1 - "equipment_type": "巡飞弹" - } - - response = requests.post( - f'{base_url}/analyze-features', - json=analysis_data - ) - print_response(response, "特征分析") - - # 5. 测试机器学习模型训练接口 - logger.info("\n5. 测试机器学习模型训练接口") - training_data = { - "type": "巡飞弹", - "train_dataset_id": 1, # 训练数据集 ID - "validation_dataset_id": 2, # 验证数据集 ID - "models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型 - } - - response = requests.post( - f'{base_url}/train', - json=training_data - ) - print_response(response, "模型训练") - - # 6. 测试数据集相关接口 - logger.info("\n6. 测试数据集相关接口") - - # 6.1 获取数据集列表 - response = requests.get(f'{base_url}/datasets') - print_response(response, "获取数据集列表") - - # 6.2 获取可用的装备列表 - response = requests.get(f'{base_url}/data') - equipment_data = response.json() - - # 获取巡飞弹类型的装备ID - available_equipment_ids = [] - if 'loitering_munition' in equipment_data: - available_equipment_ids = [ - item['id'] - for item in equipment_data['loitering_munition'] - if item['id'] is not None - ][:3] # 取前3个可用的ID - - if not available_equipment_ids: - logger.warning("没有找到可用的装备ID,跳过创建数据集测试") - else: - # 6.3 创建新数据集 - new_dataset = { - "name": "测试数据集", - "description": "用于测试的数据集", - "equipment_type": "巡飞弹", - "purpose": "训练", - "equipment_ids": available_equipment_ids - } - - logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}") - - response = requests.post( - f'{base_url}/datasets', - json=new_dataset - ) - print_response(response, "创建数据集") - - # 7. 测试模型相关接口 - logger.info("\n7. 测试模型相关接口") - - # 7.1 获取模型列表 - response = requests.get(f'{base_url}/models') - print_response(response, "获取模型列表") - - # 7.2 获取最新模型 - response = requests.get(f'{base_url}/models/巡飞弹/latest') - print_response(response, "获取最新模型") - - # 8. 测试多模型预测接口 - logger.info("\n8. 测试多模型预测接口") - response = requests.post( - f'{base_url}/predict/all', - json=predict_data - ) - print_response(response, "多模型预测") - - logger.info("所有测试完成") - - except requests.exceptions.RequestException as e: - logger.error(f"API 请求错误: {str(e)}") - except Exception as e: - logger.error(f"测试过程中出现错误: {str(e)}") - -def print_response(response, endpoint_name): - """ - 打印响应结果 - """ - try: - logger.info(f"\n=== {endpoint_name} 测试结果 ===") - logger.info(f"状态码: {response.status_code}") - - if response.status_code == 200: - result = response.json() - logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}") - else: - logger.error(f"错误响应:\n{response.text}") - - except Exception as e: - logger.error(f"处理响应时出错: {str(e)}") - -if __name__ == "__main__": - try: - # 确保日志目录存在 - os.makedirs('logs', exist_ok=True) - - logger.info(f"=== API 测试开始 - {datetime.now()} ===") - test_api_endpoints() - logger.info(f"=== API 测试结束 - {datetime.now()} ===") - except Exception as e: - logger.error(f"测试执行失败: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/venv b/deploy/equipment_cost_system/venv deleted file mode 100644 index e69de29..0000000 diff --git a/deploy/setup.md b/deploy/setup.md deleted file mode 100644 index b905ab0..0000000 --- a/deploy/setup.md +++ /dev/null @@ -1,34 +0,0 @@ -# 安装说明 - -## 1. 系统要求 - -- Linux 服务器 (推荐 Ubuntu 20.04+) -- Python 3.8+ -- MySQL 8.0+ - -## 2. 安装部署 - -```bash -# 解压部署包 -tar -xzf equipment_cost_prediction.tar.gz -cd equipment_cost_prediction - -# 运行安装脚本 -bash scripts/install.sh - -# 修改配置文件 -vim config/.env - -# 启动服务 -bash scripts/start.sh -``` - -## 3. 验证部署 - -```bash -# 检查服务状态 -curl http://localhost:5001/api/ - -# 检查日志 -tail -f logs/api.log -``` diff --git a/docs/dev/debug.md b/docs/dev/debug.md index ce13431..13c274c 100644 --- a/docs/dev/debug.md +++ b/docs/dev/debug.md @@ -643,3 +643,11 @@ trainingResult.value = null - Feature max_speed_ms missing rate: 77.78% - Feature cruise_speed_kmh missing rate: 61.11% - Feature flight_time_min missing rate: 33.33% + +## 前端特征分析页面未正确显示相关性分析数据(常见问题) + +- 确保相关性分析数据的正确格式化和返回 +- 添加了详细的日志记录 +- 增加了数据验证步骤 +- 处理了可能的NaN值 +- 确保所有数值都转换为Python原生类型(使用float()) diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue index 385b995..6e47559 100644 --- a/frontend/src/views/AnalysisPage.vue +++ b/frontend/src/views/AnalysisPage.vue @@ -11,14 +11,23 @@
- + - +

特征重要性

-
+

相关性分析

-
+
+ + +
@@ -82,6 +106,10 @@ const analyzing = ref(false) const analysisResult = ref(null) const importanceChartRef = ref(null) const correlationChartRef = ref(null) +const newFeatureChartRef = ref(null) +const engineChartRef = ref(null) +const newFeatureChart = ref(null) +const engineChart = ref(null) // 图表实例引用 const importanceChart = ref(null) @@ -103,31 +131,74 @@ watch(() => analysisResult.value, async (newResult) => { // 加载可用数据集 const loadDatasets = async (type) => { try { + // 如果没有传入type,直接返回 + if (!type) { + availableDatasets.value = [] + return + } + + console.log('Loading datasets for type:', type) const response = await axios.get(`${API_BASE_URL}/datasets`, { params: { equipment_type: type, purpose: '训练' } }) + + // 验证响应数据 + if (!response.data) { + console.warn('No datasets returned from API') + availableDatasets.value = [] + return + } + + console.log('Datasets loaded:', response.data) availableDatasets.value = response.data } catch (error) { + console.error('Error loading datasets:', error) ElMessage.error('获取数据集列表失败') + availableDatasets.value = [] } } // 处理装备类型变化 const handleEquipmentTypeChange = () => { + console.log('Equipment type changed to:', analysisForm.value.equipment_type) + + // 重置相关状态 analysisForm.value.dataset_id = null selectedDataset.value = null analysisResult.value = null - loadDatasets(analysisForm.value.equipment_type) + + // 只有当有装备类型时才加载数据集 + if (analysisForm.value.equipment_type) { + loadDatasets(analysisForm.value.equipment_type) + } else { + availableDatasets.value = [] + } } // 处理数据集选择变化 const handleDatasetChange = async () => { try { + // 如果没有选择数据集,直接返回 + if (!analysisForm.value.dataset_id) { + selectedDataset.value = null + analysisResult.value = null + return + } + + console.log('Dataset changed to:', analysisForm.value.dataset_id) const response = await axios.get(`${API_BASE_URL}/datasets/${analysisForm.value.dataset_id}`) + + // 验证响应数据 + if (!response.data) { + throw new Error('获取数据集详情失败:服务器返回空数据') + } + selectedDataset.value = response.data analysisResult.value = null } catch (error) { - ElMessage.error('获取数据集详情失败') + console.error('Error getting dataset details:', error) + ElMessage.error(error.message || '获取数据集详情失败') + selectedDataset.value = null } } @@ -140,14 +211,69 @@ const startAnalysis = async () => { analyzing.value = true try { + // 打印请求参数 + console.log('Analysis request params:', { + dataset_id: analysisForm.value.dataset_id, + equipment_type: analysisForm.value.equipment_type + }) + const response = await axios.post(`${API_BASE_URL}/analyze-features`, { dataset_id: analysisForm.value.dataset_id }) + + // 打印原始响应数据 + console.log('Raw API response:', response) + console.log('Response data type:', typeof response.data) + console.log('Response data:', response.data) + + // 检查响应数据的结构 + if (!response.data) { + throw new Error('API返回的数据为空') + } + + // 确保数据正确赋值 analysisResult.value = response.data - console.log('Analysis completed, result:', analysisResult.value) + + // 验证数赋值是否成功 + console.log('Analysis result after assignment:', { + value: analysisResult.value, + important_features: analysisResult.value?.important_features, + correlation_analysis: analysisResult.value?.correlation_analysis, + equipment_names: analysisResult.value?.equipment_names, + length_width_ratio: analysisResult.value?.length_width_ratio + }) + + // 如果是巡飞弹类型,检查特定数据 + if (analysisForm.value.equipment_type === '巡飞弹') { + const missileData = { + equipment_names: analysisResult.value?.equipment_names || [], + length_width_ratio: analysisResult.value?.length_width_ratio || [], + engine_power_kw: analysisResult.value?.engine_power_kw || [], + guidance_system_score: analysisResult.value?.guidance_system_score || [], + warhead_power_score: analysisResult.value?.warhead_power_score || [] + } + + console.log('Missile specific data:', missileData) + + // 验证数据完整性 + const missingFields = Object.entries(missileData) + .filter(([key, value]) => !Array.isArray(value) || value.length === 0) + .map(([key]) => key) + + if (missingFields.length > 0) { + console.warn('Missing or empty missile data fields:', missingFields) + ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`) + } + } + } catch (error) { - ElMessage.error('特征分析失败') console.error('Analysis error:', error) + console.error('Error details:', { + message: error.message, + response: error.response?.data, + status: error.response?.status + }) + ElMessage.error(error.message || '特征析失败') } finally { analyzing.value = false } @@ -166,6 +292,12 @@ const createResizeHandler = () => { if (correlationChart.value && !correlationChart.value.isDisposed()) { correlationChart.value.resize() } + if (newFeatureChart.value && !newFeatureChart.value.isDisposed()) { + newFeatureChart.value.resize() + } + if (engineChart.value && !engineChart.value.isDisposed()) { + engineChart.value.resize() + } } catch (error) { console.error('Error in resize handler:', error) } @@ -180,6 +312,11 @@ onMounted(() => { // 创建并保存 resize 处理函数的引用 resizeHandler.value = createResizeHandler() window.addEventListener('resize', resizeHandler.value) + + // 如果已经选择了装备类型,加载对应的数据集 + if (analysisForm.value.equipment_type) { + loadDatasets(analysisForm.value.equipment_type) + } }) // 组件卸载时 @@ -190,131 +327,353 @@ onUnmounted(() => { resizeHandler.value = null } - // 销毁图表实例 - try { - if (importanceChart.value) { - importanceChart.value.dispose() - importanceChart.value = null + // 销毁所有图表实例 + [importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => { + if (chart.value && !chart.value.isDisposed()) { + try { + chart.value.dispose() + } catch (e) { + console.error('Error disposing chart:', e) + } + chart.value = null } - if (correlationChart.value) { - correlationChart.value.dispose() - correlationChart.value = null - } - } catch (error) { - console.error('Error disposing charts:', error) - } + }) }) // 渲染图表 const renderCharts = () => { console.log('Starting to render charts') - // 检查分析结果是否存在且包含必要数据 - if (!analysisResult.value || - !analysisResult.value.important_features || - !analysisResult.value.correlation_analysis) { - console.log('Analysis result not ready') + if (!analysisResult.value) { + console.error('No analysis result available') return } - - // 检查DOM元素 - if (!importanceChartRef.value || !correlationChartRef.value) { - console.log('Chart DOM elements not ready') - return - } - + try { - // 销毁旧的图表实例 - if (importanceChart.value) { - importanceChart.value.dispose() - importanceChart.value = null - } - if (correlationChart.value) { - correlationChart.value.dispose() - correlationChart.value = null - } - - // 创建新的图表实例 - importanceChart.value = echarts.init(importanceChartRef.value) - correlationChart.value = echarts.init(correlationChartRef.value) - - // 设置图表选项 - const importanceOption = { - title: { text: '特征重要性排序' }, - tooltip: {}, - xAxis: { - type: 'value', - name: '重要性得分' - }, - yAxis: { - type: 'category', - data: analysisResult.value.important_features.map(f => f.name) - }, - series: [{ - name: '重要性', - type: 'bar', - data: analysisResult.value.important_features.map(f => f.importance) - }] - } - - const correlationOption = { - title: { text: '特征相关性热力图' }, - tooltip: { - position: 'top', - formatter: function (params) { - const value = params.data[2].toFixed(2) - const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]] - const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]] - return `${feature1} 与 ${feature2} 的相关性: ${value}` - } - }, - grid: { - height: '50%', - top: '10%' - }, - xAxis: { - type: 'category', - data: analysisResult.value.correlation_analysis.features, - splitArea: { show: true }, - axisLabel: { - interval: 0, - rotate: 45 - } - }, - yAxis: { - type: 'category', - data: analysisResult.value.correlation_analysis.features, - splitArea: { show: true } - }, - visualMap: { - min: -1, - max: 1, - calculable: true, - orient: 'horizontal', - left: 'center', - bottom: '15%', - color: ['#cc3333', '#eeeeee', '#00007f'] - }, - series: [{ - name: '相关性', - type: 'heatmap', - data: analysisResult.value.correlation_analysis.matrix, - label: { - show: true, - formatter: function(params) { - return params.data[2].toFixed(2) + // 先销毁所有现有的图表实例 + [importanceChart, correlationChart, newFeatureChart, engineChart].forEach(chart => { + if (chart.value && !chart.value.isDisposed()) { + chart.value.dispose() + chart.value = null + } + }) + + // 等待 DOM 更新 + nextTick(() => { + try { + // 重新初始化基本图表 + if (importanceChartRef.value && correlationChartRef.value) { + importanceChart.value = echarts.init(importanceChartRef.value) + correlationChart.value = echarts.init(correlationChartRef.value) + + // 设置基本图表的选项 + const importanceOption = { + title: { text: '特征重要性排序' }, + tooltip: { + trigger: 'axis', + axisPointer: { + type: 'shadow' + }, + formatter: function(params) { + const data = params[0] + return `${data.name}: ${data.value.toFixed(4)}` + } + }, + xAxis: { + type: 'value', + name: '重要性得分' + }, + yAxis: { + type: 'category', + data: analysisResult.value.important_features.map(f => f.name) + }, + series: [{ + name: '重要性', + type: 'bar', + data: analysisResult.value.important_features.map(f => f.importance), + itemStyle: { + color: '#3a5fcd' + } + }] } + + const correlationOption = { + title: { text: '特征相关性热力图' }, + tooltip: { + position: 'top', + trigger: 'item', + formatter: function (params) { + if (!params.data) return '' + const value = params.data[2].toFixed(2) + const feature1 = analysisResult.value.correlation_analysis.features[params.data[0]] + const feature2 = analysisResult.value.correlation_analysis.features[params.data[1]] + return `${feature1} 与 ${feature2}
相关性: ${value}` + } + }, + grid: { + height: '75%', + top: '10%', + bottom: '15%', + left: '10%', + right: '10%', + containLabel: true + }, + xAxis: { + type: 'category', + data: analysisResult.value.correlation_analysis.features, + splitArea: { show: true }, + axisLabel: { + interval: 0, + rotate: 45, + margin: 15 + } + }, + yAxis: { + type: 'category', + data: analysisResult.value.correlation_analysis.features, + splitArea: { show: true }, + axisLabel: { + interval: 0, + margin: 15 + } + }, + visualMap: { + min: -1, + max: 1, + calculable: true, + orient: 'horizontal', + left: 'center', + bottom: '5%', + inRange: { + color: ['#cc3333', '#eeeeee', '#00007f'] + } + }, + series: [{ + name: '相关性', + type: 'heatmap', + data: analysisResult.value.correlation_analysis.matrix, + emphasis: { + itemStyle: { + shadowBlur: 10, + shadowColor: 'rgba(0, 0, 0, 0.5)' + } + }, + label: { + show: true, + formatter: function(params) { + return params.data[2].toFixed(2) + } + }, + itemStyle: { + borderWidth: 1, + borderColor: '#fff' + } + }] + } + + // 使用 clear 方法清除旧的图表内容 + importanceChart.value.clear() + correlationChart.value.clear() + + // 设置新的选项 + importanceChart.value.setOption(importanceOption, { notMerge: true }) + correlationChart.value.setOption(correlationOption, { notMerge: true }) } - }] - } - - // 设置图表选项 - importanceChart.value.setOption(importanceOption) - correlationChart.value.setOption(correlationOption) - - console.log('Charts rendered successfully') + + // 如果是巡飞弹类型,渲染额外的图表 + if (analysisForm.value.equipment_type === '巡飞弹' && + newFeatureChartRef.value && + engineChartRef.value) { + + // 初始化巡飞弹特有图表 + newFeatureChart.value = echarts.init(newFeatureChartRef.value) + engineChart.value = echarts.init(engineChartRef.value) + + // 准备数据 + const chartData = { + names: analysisResult.value.equipment_names || [], + lengthWidthRatio: analysisResult.value.length_width_ratio || [], + weightRangeRatio: analysisResult.value.weight_range_ratio || [], + speedWeightRatio: analysisResult.value.speed_weight_ratio || [], + guidanceSystemScore: analysisResult.value.guidance_system_score || [], + warheadPowerScore: analysisResult.value.warhead_power_score || [], + enginePowerKw: analysisResult.value.engine_power_kw || [], + engineThrustN: analysisResult.value.engine_thrust_n || [], + minAltitudeM: analysisResult.value.min_altitude_m || [], + maxAltitudeM: analysisResult.value.max_altitude_m || [] + } + + // 特征工程参数分析图表配置 + const newFeatureOption = { + animation: false, // 禁用动画 + title: { + text: '特征工程参数分析', + left: 'center' + }, + tooltip: { + trigger: 'axis', + axisPointer: { + type: 'cross' + }, + formatter: function(params) { + // 获取当前数据点对应的装备名称 + const equipmentName = chartData.names[params[0].dataIndex] + let result = `${equipmentName}
` + // 添加每个系列的数据 + params.forEach(param => { + result += `${param.seriesName}: ${param.value.toFixed(2)}
` + }) + return result + } + }, + legend: { + top: 30, + data: ['长宽比', '重量射程比', '速度重量比', '制导系统评分', '战斗部威力评分'] + }, + grid: { + top: 80, + bottom: 50, + containLabel: true + }, + xAxis: { + type: 'category', + data: chartData.names + }, + yAxis: [ + { + type: 'value', + name: '比率', + position: 'left' + }, + { + type: 'value', + name: '评分', + position: 'right', + min: 0, + max: 10 + } + ], + series: [ + { + name: '长宽比', + type: 'line', + data: chartData.lengthWidthRatio + }, + { + name: '重量射程比', + type: 'line', + data: chartData.weightRangeRatio + }, + { + name: '速度重量比', + type: 'line', + data: chartData.speedWeightRatio + }, + { + name: '制导系统评分', + type: 'bar', + yAxisIndex: 1, + data: chartData.guidanceSystemScore + }, + { + name: '战斗部威力评分', + type: 'bar', + yAxisIndex: 1, + data: chartData.warheadPowerScore + } + ] + } + + // 发动机性能分析图表配置 + const engineOption = { + animation: false, // 禁用动画 + title: { + text: '发动机性能与作战参数分析', + left: 'center' + }, + tooltip: { + trigger: 'axis', + axisPointer: { + type: 'cross' + }, + formatter: function(params) { + // 获取当前数据点对应的装备名称 + const equipmentName = chartData.names[params[0].dataIndex] + let result = `${equipmentName}
` + // 添加每个系列的数据 + params.forEach(param => { + result += `${param.seriesName}: ${param.value.toFixed(2)}
` + }) + return result + } + }, + legend: { + top: 30, + data: ['发动机功率(kw)', '发动机推力(N)', '最小作战高度(m)', '最大作战高度(m)'] + }, + grid: { + top: 80, + bottom: 50, + containLabel: true + }, + xAxis: { + type: 'category', + data: chartData.names + }, + yAxis: [ + { + type: 'value', + name: '功率/推力', + position: 'left' + }, + { + type: 'value', + name: '高度(m)', + position: 'right' + } + ], + series: [ + { + name: '发动机功率(kw)', + type: 'bar', + data: chartData.enginePowerKw + }, + { + name: '发动机推力(N)', + type: 'bar', + data: chartData.engineThrustN + }, + { + name: '最小作战高度(m)', + type: 'line', + yAxisIndex: 1, + data: chartData.minAltitudeM + }, + { + name: '最大作战高度(m)', + type: 'line', + yAxisIndex: 1, + data: chartData.maxAltitudeM + } + ] + } + + // 清除旧的内容 + newFeatureChart.value.clear() + engineChart.value.clear() + + // 设置新的选项 + newFeatureChart.value.setOption(newFeatureOption, { notMerge: true }) + engineChart.value.setOption(engineOption, { notMerge: true }) + } + + console.log('Charts rendered successfully') + } catch (error) { + console.error('Error in chart rendering:', error) + } + }) } catch (error) { - console.error('Error rendering charts:', error) + console.error('Error in renderCharts:', error) } } diff --git a/frontend/src/views/DataPage.vue b/frontend/src/views/DataPage.vue index 4689f70..0a02aa8 100644 --- a/frontend/src/views/DataPage.vue +++ b/frontend/src/views/DataPage.vue @@ -85,23 +85,24 @@ - + - + + + + + + @@ -162,7 +163,7 @@ {{ formatNumber(selectedData?.max_speed_ms) }} {{ formatNumber(selectedData?.cruise_speed_kmh) }} - {{ formatNumber(selectedData?.flight_time_min) }} + {{ formatNumber(selectedData?.endurance_min) }} {{ selectedData?.warhead_type }} {{ selectedData?.launch_mode }} {{ selectedData?.power_system }} @@ -291,8 +292,8 @@ - - + + diff --git a/frontend/src/views/PredictPage.vue b/frontend/src/views/PredictPage.vue index a42fa9f..9b3992c 100644 --- a/frontend/src/views/PredictPage.vue +++ b/frontend/src/views/PredictPage.vue @@ -55,36 +55,56 @@ @@ -154,7 +174,16 @@ const formData = reactive({ width_m: null, height_m: null, weight_kg: null, - max_range_km: null + max_range_km: null, + wingspan_m: null, + warhead_weight_kg: null, + max_speed_ms: null, + cruise_speed_kmh: null, + endurance_min: null, + warhead_type: '', + launch_mode: '', + power_system: '', + guidance_system: '' }) const predictionResults = ref(null) @@ -171,14 +200,15 @@ const handleTypeChange = () => { formData.rocket_weight_kg = null formData.rate_of_fire = null } else if (formData.type === '巡飞弹') { - formData.max_speed_kmh = null + formData.wingspan_m = null + formData.warhead_weight_kg = null + formData.max_speed_ms = null formData.cruise_speed_kmh = null - formData.flight_time_min = null + formData.endurance_min = null formData.warhead_type = '' formData.launch_mode = '' - formData.folded_length_mm = null - formData.folded_width_mm = null - formData.folded_height_mm = null + formData.power_system = '' + formData.guidance_system = '' } } @@ -210,8 +240,8 @@ const submitForm = async () => { } } else if (formData.type === '巡飞弹') { const missileFields = [ - 'max_speed_kmh', 'cruise_speed_kmh', 'flight_time_min', - 'folded_length_mm', 'folded_width_mm', 'folded_height_mm' + 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh', + 'endurance_min', 'warhead_type', 'launch_mode', 'power_system', 'guidance_system' ] for (const field of missileFields) { if (!formData[field]) { diff --git a/src/feature_analysis.py b/src/feature_analysis.py index 5b87972..51939ca 100644 --- a/src/feature_analysis.py +++ b/src/feature_analysis.py @@ -1,269 +1,192 @@ import numpy as np -import pandas as pd -from scipy import stats +from sklearn.feature_selection import SelectKBest, f_regression from sklearn.preprocessing import StandardScaler -from sklearn.ensemble import RandomForestRegressor -from sklearn.metrics import r2_score import logging -from .logger import setup_logger -logger = setup_logger(__name__) +logger = logging.getLogger(__name__) class FeatureAnalysis: def __init__(self): self.scaler = StandardScaler() - self.important_features = [] - # 添加特征名称映射 - self.feature_names_map = { - # 通用参数 - 'length_m': '总长(m)', + self.feature_name_map = { + 'length_m': '长度(m)', 'width_m': '宽度(m)', 'height_m': '高度(m)', 'weight_kg': '重量(kg)', 'max_range_km': '最大射程(km)', - - # 火箭炮特有参数 - 'firing_angle_horizontal': '方向射界(度)', - 'firing_angle_vertical': '高低射界(度)', - 'rocket_length_m': '火箭弹长度(m)', - 'rocket_diameter_mm': '口径(mm)', - 'rocket_weight_kg': '火箭弹重量(kg)', - 'rate_of_fire': '射速(发/分)', - 'combat_weight_kg': '战斗重量(kg)', - 'speed_kmh': '速度(km/h)', - 'min_range_km': '最小射程(km)', - 'power_hp': '功率(hp)', - - # 火箭炮衍生特征 - 'fire_density': '火力密度', - 'mobility_index': '机动性指标', - 'range_ratio': '射程比', - 'power_weight_ratio': '功重比', - 'volume_density': '体积密度', - - # 巡飞弹特有参数 'wingspan_m': '翼展(m)', - 'warhead_weight_kg': '战斗部重量(kg)', + 'warhead_weight_kg': '弹头重量(kg)', 'max_speed_ms': '最大速度(m/s)', 'cruise_speed_kmh': '巡航速度(km/h)', - 'flight_time_min': '巡飞时间(min)', - 'folded_length_mm': '折叠长度(mm)', - 'folded_width_mm': '折叠宽度(mm)', - 'folded_height_mm': '折叠高度(mm)', - - # 巡飞弹衍生特征 - 'warhead_ratio': '战斗部比重', - 'speed_ratio': '速度比', - 'range_time_ratio': '射程时间比', - 'aspect_ratio': '展弦比', - 'volume_density': '体积密度' + 'endurance_min': '续航时间(min)', + 'payload_weight_kg': '载荷重量(kg)', + 'min_combat_radius_km': '最小作战半径(km)', + 'engine_power_kw': '发动机功率(kw)', + 'engine_thrust_n': '发动机推力(N)', + 'datalink_range_km': '数据链距离(km)', + 'guidance_accuracy_m': '制导精度(m)', + 'min_altitude_m': '最小飞行高度(m)', + 'max_altitude_m': '最大飞行高度(m)', + 'length_width_ratio': '长宽比', + 'weight_range_ratio': '重量射程比', + 'speed_weight_ratio': '速度重量比', + 'guidance_system_score': '制导系统评分', + 'warhead_power_score': '战斗部威力评分', + 'firing_angle_horizontal': '水平射角(°)', + 'firing_angle_vertical': '垂直射角(°)', + 'rocket_length_m': '火箭长度(m)', + 'rocket_diameter_mm': '火箭直径(mm)', + 'rocket_weight_kg': '火箭重量(kg)' } - + def get_equipment_specific_features(self, equipment_type): - """ - 获取特定装备类型的特征列表 - """ - # 通用参数 + """获取特定装备类型的特征列表""" common_features = [ - 'length_m', # 总长(m) - 'width_m', # 宽度(m) - 'height_m', # 高度(m) - 'weight_kg', # 重量(kg) - 'max_range_km' # 最大射程(km) + 'length_m', 'width_m', 'height_m', + 'weight_kg', 'max_range_km' ] if equipment_type == '火箭炮': - # 火箭炮特有参数 - specific_features = [ - 'firing_angle_horizontal', # 方向射界(度) - 'firing_angle_vertical', # 高低射界(度) - 'rocket_length_m', # 火箭弹长度(m) - 'rocket_diameter_mm', # 口径(mm) - 'rocket_weight_kg', # 火箭弹重量(kg) - 'rate_of_fire', # 射速(发/分) - 'combat_weight_kg', # 战斗重量(kg) - 'speed_kmh', # 速度(km/h) - 'min_range_km', # 最小射程(km) - 'power_hp' # 功率(hp) + return common_features + [ + 'firing_angle_horizontal', + 'firing_angle_vertical', + 'rocket_length_m', + 'rocket_diameter_mm', + 'rocket_weight_kg' ] - - # 火箭炮衍生特征 - derived_features = [ - 'fire_density', # 火力密度 = 射速 * 火箭弹重量 - 'mobility_index', # 机动性指标 = 速度 / 战斗重量 - 'range_ratio', # 射程比 = 最大射程 / 最小射程 - 'power_weight_ratio', # 功重比 = 功率 / 战斗重量 - 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) + elif equipment_type == '巡飞弹': + return common_features + [ + 'wingspan_m', + 'warhead_weight_kg', + 'max_speed_ms', + 'cruise_speed_kmh', + 'endurance_min', + 'payload_weight_kg', + 'min_combat_radius_km', + 'engine_power_kw', + 'engine_thrust_n', + 'datalink_range_km', + 'guidance_accuracy_m', + 'min_altitude_m', + 'max_altitude_m', + 'length_width_ratio', + 'weight_range_ratio', + 'speed_weight_ratio', + 'guidance_system_score', + 'warhead_power_score' ] - - return common_features + specific_features + derived_features - - else: # 巡飞弹 - # 巡飞弹特有参数 - specific_features = [ - 'wingspan_m', # 翼展(m) - 'warhead_weight_kg', # 战斗部重量(kg) - 'max_speed_ms', # 最大速度(m/s) - 'cruise_speed_kmh', # 巡航速度(km/h) - 'flight_time_min', # 巡飞时间(min) - 'folded_length_mm', # 折叠长度(mm) - 'folded_width_mm', # 折叠宽度(mm) - 'folded_height_mm' # 折叠高度(mm) - ] - - # 巡飞弹衍生特征 - derived_features = [ - 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量 - 'speed_ratio', # 速度比 = 巡航速度 / 最大速度 - 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间 - 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积 - 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) - ] - - return common_features + specific_features + derived_features - - def calculate_derived_features(self, data, equipment_type): - """ - 计算衍生特征 - """ - try: - if equipment_type == '火箭炮': - # 火箭炮衍生特征计算 - if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns: - data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg'] - else: - data['fire_density'] = 0 # 或者其他默认值 - - if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns: - data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg'] - else: - data['mobility_index'] = 0 - - if 'max_range_km' in data.columns and 'min_range_km' in data.columns: - data['range_ratio'] = data['max_range_km'] / data['min_range_km'] - else: - data['range_ratio'] = 0 - - if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns: - data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg'] - else: - data['power_weight_ratio'] = 0 - - if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']): - data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) - else: - data['volume_density'] = 0 - - else: # 巡飞弹 - # 巡飞弹衍生特征计算 - if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns: - data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg'] - else: - data['warhead_ratio'] = 0 - - if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns: - data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6) - else: - data['speed_ratio'] = 0 - - if 'max_range_km' in data.columns and 'flight_time_min' in data.columns: - data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min'] - else: - data['range_time_ratio'] = 0 - - if 'wingspan_m' in data.columns and 'length_m' in data.columns: - data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m'] - else: - data['aspect_ratio'] = 0 - - if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']): - data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) - else: - data['volume_density'] = 0 - - return data - - except Exception as e: - logger.error(f"Error calculating derived features: {str(e)}") - raise + return common_features def analyze_features(self, features, target, feature_names): - """ - 分析特征重要性和相关性 - """ + """分析特征重要性和相关性""" try: # 转换为numpy数组 - X = np.array(features) - y = np.array(target) + X = np.array(features, dtype=np.float64) # 明确指定数据类型 + y = np.array(target, dtype=np.float64) - # 数据标准化 + # 打印原始数据的统计信息 + logger.info("Feature statistics before scaling:") + for i, name in enumerate(feature_names): + feature_data = X[:, i] + logger.info(f"{self.feature_name_map.get(name, name)}: " + f"min={np.min(feature_data)}, " + f"max={np.max(feature_data)}, " + f"mean={np.mean(feature_data)}, " + f"null_count={np.sum(np.isnan(feature_data))}") + + # 处理可能的无穷大和NaN值 + X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) + + # 标准化特征 X_scaled = self.scaler.fit_transform(X) # 特征重要性分析 - rf = RandomForestRegressor(n_estimators=100, random_state=42) - rf.fit(X_scaled, y) - importances = rf.feature_importances_ + selector = SelectKBest(score_func=f_regression, k='all') + selector.fit(X_scaled, y) + importance_scores = selector.scores_ - # 按重要性排序,使用中文特征名 - importance_indices = np.argsort(importances)[::-1] - important_features = [ - { - 'name': self.feature_names_map.get(feature_names[i], feature_names[i]), - 'importance': float(importances[i]) - } - for i in importance_indices - ] + # 计算相关性矩阵前检查特征的方差 + X = np.array(features, dtype=np.float64) + feature_std = np.std(X, axis=0) + constant_features = [] - # 相关性分析 - df = pd.DataFrame(X_scaled, columns=feature_names) - correlation_matrix = df.corr().values + # 记录标准差为0的特征 + for i, (name, std) in enumerate(zip(feature_names, feature_std)): + if std == 0: + logger.warning(f"Feature '{self.feature_name_map.get(name, name)}' has zero standard deviation " + f"(constant value: {X[0, i]})") + constant_features.append(name) - # 生成相关性分析数据,保留2位小数 + # 计算相关性矩阵 + correlation_matrix = np.corrcoef(X.T) + + # 处理相关性矩阵中的无效值 correlation_data = [] - chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names] + chinese_feature_names = [self.feature_name_map.get(name, name) for name in feature_names] + for i in range(len(feature_names)): for j in range(len(feature_names)): - correlation_data.append([ - i, j, - round(correlation_matrix[i][j], 2) # 修改为保留2位小数 - ]) + corr_value = correlation_matrix[i, j] + if np.isnan(corr_value): + # 如果是常量特征,设置相关系数 + if feature_names[i] in constant_features or feature_names[j] in constant_features: + if i == j: + # 自身相关性设为1 + corr_value = 1.0 + else: + # 与其他特征的相关性设为0 + corr_value = 0.0 + logger.info(f"Setting correlation for constant feature: " + f"{chinese_feature_names[i]} vs {chinese_feature_names[j]} = {corr_value}") + correlation_data.append([i, j, float(corr_value)]) - return { + # 记录数据形状 + logger.info(f"Features shape: {X.shape}") + logger.info(f"Target shape: {y.shape}") + logger.info(f"Correlation matrix shape: {correlation_matrix.shape}") + + # 创建特征重要性列表(使用中文名称) + important_features = [] + for idx, (name, score) in enumerate(zip(feature_names, importance_scores)): + if not np.isnan(score): + important_features.append({ + 'name': self.feature_name_map.get(name, name), # 使用中文名称 + 'importance': float(score) + }) + + # 按重要性排序 + important_features.sort(key=lambda x: x['importance'], reverse=True) + + # 返回结果 + result = { 'important_features': important_features, 'correlation_analysis': { - 'features': chinese_feature_names, # 使用中文特征名 + 'features': chinese_feature_names, # 使用中文名称 'matrix': correlation_data } } - except Exception as e: - logger.error(f"Error in feature analysis: {str(e)}") - raise - - def preprocess_features(self, equipment_data, equipment_type): - """ - 预处理特征数据 - """ - try: - # 转换为 DataFrame - df = pd.DataFrame(equipment_data) + # 添加数据验证 + logger.info("Correlation data validation:") + expected_pairs = len(feature_names) * len(feature_names) + actual_pairs = len(correlation_data) + logger.info(f"Expected correlation pairs: {expected_pairs}") + logger.info(f"Actual correlation pairs: {actual_pairs}") + if expected_pairs != actual_pairs: + logger.warning("Missing correlation pairs detected!") - # 计算衍生特征 - df = self.calculate_derived_features(df, equipment_type) + # 验证返回的数据 + logger.info("Validation of return data:") + logger.info(f"Has important_features: {bool(result['important_features'])}") + logger.info(f"Important features count: {len(result['important_features'])}") + logger.info(f"Has correlation_analysis: {bool(result['correlation_analysis'])}") + logger.info(f"Correlation features count: {len(result['correlation_analysis']['features'])}") + logger.info(f"Correlation matrix size: {len(result['correlation_analysis']['matrix'])}") - # 处理缺失值 - numeric_columns = df.select_dtypes(include=[np.number]).columns - for col in numeric_columns: - # 转换为数值类型 - df[col] = pd.to_numeric(df[col], errors='coerce') - # 使用新的方式填充缺失值 - mean_value = df[col].mean() - df[col] = df[col].fillna(mean_value) - - logger.info(f"Preprocessed data shape: {df.shape}") - return df + return result except Exception as e: - logger.error(f"Error preprocessing features: {str(e)}") - raise Exception(f"Feature preprocessing error: {str(e)}") \ No newline at end of file + logger.error(f"Error in analyze_features: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + raise \ No newline at end of file diff --git a/src/real_data.sql b/src/real_data.sql new file mode 100644 index 0000000..d00ab59 --- /dev/null +++ b/src/real_data.sql @@ -0,0 +1,485 @@ +-- 清空现有数据 +SET FOREIGN_KEY_CHECKS=0; +TRUNCATE TABLE dataset_equipment; +TRUNCATE TABLE datasets; +TRUNCATE TABLE cost_data; +TRUNCATE TABLE loitering_munition_params; +TRUNCATE TABLE common_params; +TRUNCATE TABLE equipment; +SET FOREIGN_KEY_CHECKS=1; + +-- 按系列插入装备数据,确保ID连续 +-- 1. HAROP/Harpy 系列 (ID: 1-3) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(1, 'IAI Harop', '巡飞弹', '以色列'), +(2, 'IAI Harpy', '巡飞弹', '以色列'), +(3, 'IAI Mini Harpy', '巡飞弹', '以色列'); + +-- 2. Hero 系列 (ID: 4-9) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(4, 'Hero-30', '巡飞弹', '以色列 UVision'), +(5, 'Hero-70', '巡飞弹', '以色列 UVision'), +(6, 'Hero-120', '巡飞弹', '以色列 UVision'), +(7, 'Hero-250', '巡飞弹', '以色列 UVision'), +(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'), +(9, 'Hero-900', '巡飞弹', '以色列 UVision'); + +-- 3. Switchblade 系列 (ID: 10-13) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'), +(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'), +(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'), +(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment'); + +-- 4. Warmate 系列 (ID: 14-18) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'), +(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'), +(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'), +(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'), +(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics'); + +-- 5. CH-901/902 系列 (ID: 19-23) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(19, 'CH-901', '巡飞弹', '中国航天科工'), +(20, 'CH-901A', '巡飞弹', '中国航天科工'), +(21, 'CH-901H', '巡飞弹', '中国航天科工'), +(22, 'CH-902', '巡飞弹', '中国航天科工'), +(23, 'CH-902A', '巡飞弹', '中国航天科工'); + +-- 6. WS-43/61 系列 (ID: 24-28) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(24, 'WS-43', '巡飞弹', '中国航天科工'), +(25, 'WS-43A', '巡飞弹', '中国航天科工'), +(26, 'WS-43B', '巡飞弹', '中国航天科工'), +(27, 'WS-61', '巡飞弹', '中国航天科工'), +(28, 'WS-61A', '巡飞弹', '中国航天科工'); + +-- 7. Kargu/Alpagu 系列 (ID: 29-33) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(29, 'Kargu', '巡飞弹', '土耳其 STM'), +(30, 'Kargu-2', '巡飞弹', '土耳其 STM'), +(31, 'Alpagu', '巡飞弹', '土耳其 STM'), +(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'), +(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'); + +-- 8. Shahed 系列 (ID: 34-38) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(34, 'Shahed-131', '巡飞弹', '伊朗'), +(35, 'Shahed-131B', '巡飞弹', '伊朗'), +(36, 'Shahed-136', '巡飞弹', '伊朗'), +(37, 'Shahed-136B', '巡飞弹', '伊朗'), +(38, 'Shahed-136C', '巡飞弹', '伊朗'); + +-- 9. Green Dragon 系列 (ID: 39-43) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(39, 'Green Dragon', '巡飞弹', '以色列 IAI'), +(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'), +(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'), +(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'), +(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI'); + +-- 10. Phoenix Ghost 系列 (ID: 44-48) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'), +(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'), +(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'), +(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'), +(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace'); + +-- 11. ZALA Lancet 系列 (ID: 49-52) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'), +(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'), +(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'), +(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA'); + +-- 12. Rotem L 系列 (ID: 53-56) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(53, 'Rotem L', '巡飞弹', '以色列 IAI'), +(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'), +(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'), +(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI'); + +-- 13. KUB-BLA 系列 (ID: 57-60) +INSERT INTO equipment (id, name, type, manufacturer) VALUES +(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'), +(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'), +(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'), +(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA'); + +-- 插入通用参数 +INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES +(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop +(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy +(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy +(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30 +(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70 +(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120 +(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250 +(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC +(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900 +(10, 0.58, 0.12, 0.12, 2.5, 10), +(11, 1.30, 0.22, 0.22, 15.0, 40), +(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10 +(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range +(14, 0.68, 0.12, 0.12, 2.5, 10), +(15, 1.30, 0.22, 0.22, 15.0, 40), +(16, 0.68, 0.12, 0.12, 2.5, 10), +(17, 1.30, 0.22, 0.22, 15.0, 40), +(18, 0.68, 0.12, 0.12, 2.5, 10), +(19, 1.2, 0.18, 0.18, 9.0, 20), +(20, 1.2, 0.18, 0.18, 9.3, 25), +(21, 1.2, 0.18, 0.18, 9.5, 20), +(22, 1.4, 0.22, 0.22, 15.0, 30), +(23, 1.4, 0.22, 0.22, 15.5, 35), +(24, 1.8, 0.35, 0.35, 20, 60), +(25, 1.8, 0.35, 0.35, 21, 70), +(26, 1.9, 0.35, 0.35, 22, 80), +(27, 2.2, 0.40, 0.40, 35, 100), +(28, 2.2, 0.40, 0.40, 37, 120), +(29, 0.6, 0.35, 0.35, 7.0, 10), +(30, 0.6, 0.35, 0.35, 7.2, 15), +(31, 1.0, 0.23, 0.23, 3.7, 5), +(32, 1.0, 0.23, 0.23, 3.9, 8), +(33, 0.6, 0.35, 0.35, 7.5, 15), +(34, 2.6, 0.34, 0.34, 135, 900), +(35, 2.6, 0.34, 0.34, 140, 1000), +(36, 3.5, 0.42, 0.42, 200, 2000), +(37, 3.5, 0.42, 0.42, 210, 2200), +(38, 3.5, 0.42, 0.42, 215, 2500), +(39, 1.5, 0.20, 0.20, 15, 40), +(40, 1.6, 0.20, 0.20, 16, 50), +(41, 1.5, 0.20, 0.20, 15.5, 45), +(42, 1.5, 0.20, 0.20, 15.8, 40), +(43, 1.2, 0.18, 0.18, 12, 30), +(44, 1.5, 0.25, 0.25, 14.0, 30), +(45, 1.5, 0.25, 0.25, 14.5, 35), +(46, 1.6, 0.26, 0.26, 15.0, 40), +(47, 1.5, 0.25, 0.25, 14.8, 30), +(48, 1.7, 0.27, 0.27, 16.0, 50), +(49, 1.0, 0.20, 0.20, 5.0, 40), +(50, 1.65, 0.35, 0.35, 12.0, 70), +(51, 1.65, 0.35, 0.35, 12.5, 80), +(52, 1.80, 0.40, 0.40, 15.0, 100), +(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L +(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X +(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M +(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER +(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA +(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E +(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M +(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER + +-- 插入特有参数 +INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh, + endurance_min, + warhead_type, + launch_mode, + power_system, + guidance_system +) VALUES +-- HAROP/Harpy系列 +(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'), +(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'), +(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'), + +-- Hero系列 +(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'), +(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'), +(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'), + +-- Switchblade系列 +(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'), +(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), +(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), + +-- Warmate系列 +(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'), +(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), +(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), +(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), + +-- CH-901/902系列 +(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), +(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), +(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), +(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'), +(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), +(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'), +(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'), +(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'), +(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'), +(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'), +(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'), +(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), +(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), +(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'), +(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'), +(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), +(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), +(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'), +(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'), +(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), +(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'), +(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'), +(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'), +(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'), +(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'), +(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), +(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'), +(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), +(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'), +(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'), +(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'), +(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'); + +-- 插入成本数据 +INSERT INTO cost_data (equipment_id, actual_cost) VALUES +(1, 800000), -- IAI Harop +(2, 700000), -- IAI Harpy +(3, 350000), -- IAI Mini Harpy +(4, 70000), -- Hero-30 +(5, 120000), -- Hero-70 +(6, 150000), -- Hero-120 +(7, 300000), -- Hero-250 +(8, 400000), -- Hero-400EC +(9, 650000), -- Hero-900 +(10, 60000), -- Switchblade 300 +(11, 180000), -- Switchblade 600 +(12, 75000), -- Switchblade 300 Block 10 +(13, 200000), -- Switchblade 600 Extended Range +(14, 60000), -- Warmate 1.0 +(15, 180000), -- Warmate 2.0 +(16, 60000), -- Warmate-V +(17, 180000), -- Warmate-L +(18, 60000), -- Warmate 3.0 +(19, 100000), -- CH-901 +(20, 120000), -- CH-901A +(21, 130000), -- CH-901H +(22, 180000), -- CH-902 +(23, 200000), -- CH-902A +(24, 120000), -- WS-43 +(25, 150000), -- WS-43A +(26, 180000), -- WS-43B +(27, 300000), -- WS-61 +(28, 350000), -- WS-61A +(29, 70000), -- Kargu +(30, 85000), -- Kargu-2 +(31, 45000), -- Alpagu +(32, 55000), -- Alpagu Block-II +(33, 95000), -- Kargu Autonomous +(34, 20000), -- Shahed-131 +(35, 25000), -- Shahed-131B +(36, 40000), -- Shahed-136 +(37, 45000), -- Shahed-136B +(38, 50000), -- Shahed-136C +(39, 160000), -- Green Dragon +(40, 200000), -- Green Dragon Extended Range +(41, 180000), -- Green Dragon Block 2 +(42, 190000), -- Green Dragon Maritime +(43, 140000), -- Green Dragon-S +(44, 150000), -- Phoenix Ghost +(45, 180000), -- Phoenix Ghost Block I +(46, 220000), -- Phoenix Ghost Block II +(47, 190000), -- Phoenix Ghost Maritime +(48, 250000), -- Phoenix Ghost-ER +(49, 80000), -- Lancet-1 +(50, 150000), -- Lancet-3 +(51, 180000), -- Lancet-3M +(52, 250000), -- Lancet-4 +(53, 65000), -- Rotem L +(54, 85000), -- Rotem L-X +(55, 75000), -- Rotem L-M +(56, 95000), -- Rotem L-ER +(57, 95000), -- KUB-BLA +(58, 120000), -- KUB-BLA-E +(59, 110000), -- KUB-BLA-M +(60, 150000); -- KUB-BLA-ER + +-- 创建数据集 +INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES +(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'), +(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证'); + +-- 关联装备到数据集(按照制造商和型号分配) +INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +-- 训练集(约80%的数据,48个型号) +-- 以色列系列 +(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列 +(1, 4), (1, 5), (1, 6), -- Hero系列基础型号 +(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列 +(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列 + +-- 美国系列 +(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列 +(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列 + +-- 中国系列 +(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列 +(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列 + +-- 波兰和土耳其系列 +(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列 +(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列 + +-- 俄罗斯系列 +(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列 + +-- 验证集(约20%的数据,12个型号) +-- 混合系列 +(2, 7), (2, 8), (2, 9), -- Hero系列高级型号 +(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列 +(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列 + +-- 添加分类特征编码 +INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES +-- 战斗部类型编码 +('warhead_type', '破片杀伤战斗部', 1), +('warhead_type', '破甲战斗部', 2), +('warhead_type', '高爆战斗部', 3), +('warhead_type', '破片杀伤/破甲双用战斗部', 4), +('warhead_type', '模块化战斗部', 5), + +-- 发射方式编码 +('launch_mode', '箱式发射', 1), +('launch_mode', '弹射式发射', 2), +('launch_mode', '垂直起降', 3), +('launch_mode', '单兵发射管', 4), +('launch_mode', '箱式发射/弹射式', 5), +('launch_mode', '箱式发射/空中发射', 6), + +-- 动力装置编码(按复杂度递增) +('power_system', '电动机', 1), +('power_system', '活塞发动机', 2), + +-- 制导系统编码(按复杂度递增) +('guidance_system', 'GPS/INS', 1), +('guidance_system', 'GPS/INS/光电', 2), +('guidance_system', 'GPS/INS/光电/数据链', 3), +('guidance_system', 'GPS/INS/光电/AI识别', 4), +('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5), +('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6), +('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7); + +-- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段 +UPDATE loitering_munition_params l +JOIN common_params c ON l.equipment_id = c.equipment_id +SET + -- 新增关键参数 + l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20% + l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10% + l.engine_power_kw = + CASE + WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15 + WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25 + END, + l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30% + l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80% + l.guidance_accuracy_m = + CASE + WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0 + WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0 + ELSE 3.0 + END, + l.min_altitude_m = -- 最小作战高度 + CASE + -- 大型巡飞弹(体型大、重量大) + WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列 + + -- 中型巡飞弹 + WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列 + + -- 中小型巡飞弹 + WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等 + + -- 小型巡飞弹 + WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等 + + -- 超小型巡飞弹 + WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列 + + -- 其他型号使用默认值 + ELSE 50 + END, + l.max_altitude_m = + CASE + WHEN c.max_range_km > 500 THEN 5000 + WHEN c.max_range_km > 100 THEN 3000 + ELSE 1500 + END, + + -- 特征工程字段 + l.length_width_ratio = c.length_m / c.width_m, + l.weight_range_ratio = c.weight_kg / c.max_range_km, + l.speed_weight_ratio = l.max_speed_ms / c.weight_kg, + l.guidance_system_score = + CASE + WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10 + WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8 + WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6 + WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4 + ELSE 2 + END, + l.warhead_power_score = + CASE + WHEN l.warhead_type = '模块化战斗部' THEN 10 + WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8 + WHEN l.warhead_type = '高爆战斗部' THEN 7 + WHEN l.warhead_type = '破甲战斗部' THEN 6 + WHEN l.warhead_type = '破片杀伤战斗部' THEN 5 + ELSE 4 + END, + + -- 分类特征编码 + l.warhead_type_code = + CASE + WHEN l.warhead_type = '破片杀伤战斗部' THEN 1 + WHEN l.warhead_type = '破甲战斗部' THEN 2 + WHEN l.warhead_type = '高爆战斗部' THEN 3 + WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4 + WHEN l.warhead_type = '模块化战斗部' THEN 5 + ELSE 0 + END, + l.launch_mode_code = + CASE + WHEN l.launch_mode = '箱式发射' THEN 1 + WHEN l.launch_mode = '弹射式发射' THEN 2 + WHEN l.launch_mode = '垂直起降' THEN 3 + WHEN l.launch_mode = '单兵发射管' THEN 4 + WHEN l.launch_mode = '箱式发射/弹射式' THEN 5 + WHEN l.launch_mode = '箱式发射/空中发射' THEN 6 + ELSE 0 + END, + l.power_system_code = + CASE + WHEN l.power_system = '电动机' THEN 1 + WHEN l.power_system = '活塞发动机' THEN 2 + ELSE 0 + END, + l.guidance_system_code = + CASE + WHEN l.guidance_system = 'GPS/INS' THEN 1 + WHEN l.guidance_system = 'GPS/INS/光电' THEN 2 + WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3 + WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4 + WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5 + WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6 + WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7 + ELSE 0 + END; diff --git a/src/routes.py b/src/routes.py index df4e825..ba30c45 100644 --- a/src/routes.py +++ b/src/routes.py @@ -135,7 +135,6 @@ def analyze_features(): logger.info(f"Dataset info: {dataset}") # 创建特征分析实例 - from src.feature_analysis import FeatureAnalysis analyzer = FeatureAnalysis() # 获取特征列表 @@ -143,20 +142,46 @@ def analyze_features(): logger.info(f"Feature names: {feature_names}") # 获取数据集中的装备数据 - if dataset['equipment_type'] == '火箭炮': + if dataset['equipment_type'] == '巡飞弹': cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost + SELECT + e.name, + e.*, + cp.*, + lmp.*, + cd.actual_cost, + lmp.length_width_ratio, + lmp.weight_range_ratio, + lmp.speed_weight_ratio, + lmp.guidance_system_score, + lmp.warhead_power_score, + lmp.engine_power_kw, + lmp.engine_thrust_n, + lmp.min_altitude_m, + lmp.max_altitude_m FROM equipment e JOIN dataset_equipment de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id - LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id + LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL """, (dataset_id,)) else: cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost + SELECT e.name, + e.*, cp.*, lmp.*, + cp.max_range_km, + lmp.length_width_ratio, + lmp.weight_range_ratio, + lmp.speed_weight_ratio, + lmp.guidance_system_score, + lmp.warhead_power_score, + lmp.engine_power_kw, + lmp.engine_thrust_n, + lmp.min_altitude_m, + lmp.max_altitude_m, + cd.actual_cost FROM equipment e JOIN dataset_equipment de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id @@ -173,61 +198,52 @@ def analyze_features(): logger.warning("No valid equipment data found in dataset") return jsonify({'error': '数据集没有有效的成本数据'}), 400 - # 统计每个特征的缺失率 - missing_rates = {} - for name in feature_names: - missing_count = sum(1 for item in equipment_data if item.get(name) is None) - missing_rate = missing_count / len(equipment_data) - missing_rates[name] = missing_rate - logger.info(f"Feature {name} missing rate: {missing_rate:.2%}") - - # 过滤掉缺失率过高的特征 - valid_features = [name for name in feature_names if missing_rates[name] < 0.7] - logger.info(f"Valid features after filtering: {valid_features}") - - if len(valid_features) < 3: # 至少需要3个特征 - return jsonify({'error': '有效特征数量不足'}), 400 - - # 计算每个特征的均值 - feature_means = {} - for name in valid_features: - values = [float(item[name]) for item in equipment_data if item.get(name) is not None] - feature_means[name] = sum(values) / len(values) if values else 0 - logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}") - # 准备特征和目标值 features = [] target = [] + equipment_names = [] # 新增:存储装备名称 - # 提取特征和目标值,使用均值填充缺失值 + # 提取特征和目标值 for item in equipment_data: feature_values = [] - for name in valid_features: + equipment_names.append(item['name']) # 保存装备名称 + + for name in feature_names: value = item.get(name) try: - # 确保数值类型转换正确 - feature_values.append(float(value) if value is not None else feature_means[name]) + feature_values.append(float(value) if value is not None else 0) except (ValueError, TypeError) as e: logger.error(f"Error converting value for feature {name}: {value}") - logger.error(f"Error details: {str(e)}") return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400 - features.append(feature_values) - # 确保成本值是值类型 - try: - target.append(float(item['actual_cost'])) - except (ValueError, TypeError) as e: - logger.error(f"Error converting actual_cost: {item['actual_cost']}") - logger.error(f"Error details: {str(e)}") - return jsonify({'error': '成本值无法换为数值'}), 400 - - logger.info(f"Prepared {len(features)} feature vectors") - logger.info(f"First feature vector: {features[0] if features else None}") - logger.info(f"First target value: {target[0] if target else None}") + features.append(feature_values) + target.append(float(item['actual_cost'])) # 调用特征分析方法 - result = analyzer.analyze_features(features, target, valid_features) - logger.info("Analysis completed successfully") + result = analyzer.analyze_features(features, target, feature_names) + + # 如果是巡飞弹类型,添加额外的数据 + if dataset['equipment_type'] == '巡飞弹': + missile_data = { + 'equipment_names': equipment_names, + 'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data], + 'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data], + 'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data], + 'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data], + 'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data], + 'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data], + 'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data], + 'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data], + 'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data] + } + + # 验证数据完整性 + for key, value in missile_data.items(): + logger.info(f"{key} data length: {len(value)}") + logger.info(f"{key} sample data: {value[:3]}") + + # 更新结果 + result.update(missile_data) return jsonify(result) @@ -428,148 +444,39 @@ def get_equipment_data(): try: with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - cursor.execute('SET SESSION group_concat_max_len = 1000000') - - # 先测试特殊参数查询 - cursor.execute(""" - SELECT equipment_id, param_name, param_value, param_unit - FROM custom_params - WHERE param_name IS NOT NULL - AND param_value IS NOT NULL - LIMIT 5 - """) - test_params = cursor.fetchall() - logger.info(f"Test custom params: {test_params}") # 获取火箭炮数据 logger.info("Fetching rocket artillery data...") cursor.execute(""" - SELECT - e.id, - e.name, - e.type, - e.manufacturer, - e.created_at, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - cp.max_range_km, - rap.firing_angle_horizontal, - rap.firing_angle_vertical, - rap.rocket_length_m, - rap.rocket_diameter_mm, - rap.rocket_weight_kg, - rap.rate_of_fire, - rap.combat_weight_kg, - rap.speed_kmh, - rap.min_range_km, - rap.mobility_type, - rap.structure_layout, - rap.engine_model, - rap.engine_params, - rap.power_hp, - rap.travel_range_km, - cd.actual_cost, - ( - SELECT COALESCE( - JSON_ARRAYAGG( - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - ), - '[]' - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params + SELECT e.*, cp.*, rap.*, cd.actual_cost, cd.predicted_cost FROM equipment e LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE e.type = '火箭炮' """) - rocket_artillery = cursor.fetchall() - logger.info(f"Found {len(rocket_artillery)} rocket artillery records") - if rocket_artillery: - logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}") - logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}") + rocket_data = cursor.fetchall() + logger.info(f"Found {len(rocket_data)} rocket artillery records") # 获取巡飞弹数据 logger.info("Fetching missile data...") cursor.execute(""" - SELECT - e.id, - e.name, - e.type, - e.manufacturer, - e.created_at, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - cp.max_range_km, - lmp.wingspan_m, - lmp.warhead_weight_kg, - lmp.max_speed_ms, - lmp.cruise_speed_kmh, - lmp.flight_time_min, - lmp.warhead_type, - lmp.launch_mode, - lmp.folded_length_mm, - lmp.folded_width_mm, - lmp.folded_height_mm, - lmp.power_system, - lmp.guidance_system, - cd.actual_cost, - ( - SELECT COALESCE( - JSON_ARRAYAGG( - JSON_OBJECT( - 'id', csp.id, - 'param_name', csp.param_name, - 'param_value', csp.param_value, - 'param_unit', csp.param_unit, - 'description', csp.description - ) - ), - '[]' - ) - FROM custom_params csp - WHERE csp.equipment_id = e.id - AND csp.param_name IS NOT NULL - AND csp.param_value IS NOT NULL - ) as custom_params + SELECT e.*, cp.*, lmp.*, cd.actual_cost, cd.predicted_cost, + lmp.wingspan_m, lmp.warhead_weight_kg, lmp.max_speed_ms, + lmp.cruise_speed_kmh, lmp.endurance_min, lmp.warhead_type, + lmp.launch_mode, lmp.power_system, lmp.guidance_system FROM equipment e LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE e.type = '巡飞弹' """) - loitering_munition = cursor.fetchall() - logger.info(f"Found {len(loitering_munition)} missile records") - if loitering_munition: - logger.info(f"First missile: {loitering_munition[0]['name']}") - logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}") + missile_data = cursor.fetchall() + logger.info(f"Found {len(missile_data)} missile records") - # 处理 custom_params,保为 NULL - for item in rocket_artillery + loitering_munition: - if item['custom_params'] is None: - item['custom_params'] = [] - logger.debug(f"Set empty custom_params for equipment {item['id']}") - else: - logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params") - - logger.info("Data fetching completed") return jsonify({ - 'rocket_artillery': rocket_artillery, - 'loitering_munition': loitering_munition + 'rocket_artillery': rocket_data, + 'loitering_munition': missile_data }) except Exception as e: diff --git a/src/schema.sql b/src/schema.sql index a0abd10..84066eb 100644 --- a/src/schema.sql +++ b/src/schema.sql @@ -60,17 +60,47 @@ CREATE TABLE loitering_munition_params ( warhead_weight_kg FLOAT, -- 战斗部重量(kg) max_speed_ms FLOAT, -- 最大速度(m/s) cruise_speed_kmh FLOAT, -- 巡航速度(km/h) - flight_time_min FLOAT, -- 巡飞时间(min) + endurance_min FLOAT, -- 续航时间(min) warhead_type VARCHAR(50), -- 战斗部类型 launch_mode VARCHAR(50), -- 发射方式 - folded_length_mm FLOAT, -- 折叠长度(mm) - folded_width_mm FLOAT, -- 折叠宽度(mm) - folded_height_mm FLOAT, -- 折叠高度(mm) power_system VARCHAR(100), -- 动力装置 guidance_system VARCHAR(100), -- 制导体制 + + -- 新增关键参数 + payload_weight_kg FLOAT, -- 有效载荷重量(kg) + min_combat_radius_km FLOAT, -- 最小作战半径(km) + engine_power_kw FLOAT, -- 发动机功率(kw) + engine_thrust_n FLOAT, -- 发动机推力(N) + datalink_range_km FLOAT, -- 通信链路距离(km) + guidance_accuracy_m FLOAT, -- 制导精度(m) + min_altitude_m FLOAT, -- 最小作战高度(m) + max_altitude_m FLOAT, -- 最大作战高度(m) + + -- 特征工程字段 + length_width_ratio FLOAT, -- 长宽比 + weight_range_ratio FLOAT, -- 重量/射程比 + speed_weight_ratio FLOAT, -- 速度/重量比 + guidance_system_score INT, -- 制导系统复杂度评分(1-10) + warhead_power_score INT, -- 战斗部威力评分(1-10) + + -- 分类特征编码 + warhead_type_code INT, -- 战斗部类型编码 + launch_mode_code INT, -- 发射方式编码 + power_system_code INT, -- 动力装置编码 + guidance_system_code INT, -- 制导系统编码 + FOREIGN KEY (equipment_id) REFERENCES equipment(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +-- 分类特征编码表 +CREATE TABLE feature_encoding ( + id INT AUTO_INCREMENT PRIMARY KEY, + feature_type VARCHAR(50), -- 特征类型(warhead_type/launch_mode/power_system/guidance_system) + feature_value VARCHAR(100), -- 特征值 + code INT, -- 编码值 + UNIQUE KEY unique_feature (feature_type, feature_value) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + -- 成本数据表 CREATE TABLE cost_data ( id INT AUTO_INCREMENT PRIMARY KEY,