diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..49204df --- /dev/null +++ b/.env.example @@ -0,0 +1,25 @@ +# 数据库配置 +MYSQL_HOST=localhost +MYSQL_USER=root +MYSQL_PASSWORD=your_password_here +MYSQL_DATABASE=equipment_cost_db + +# 服务配置 +PORT=5001 +DEBUG=False + +# 日志配置 +LOG_LEVEL=INFO +LOG_DIR=logs + +# 模型配置 +MODEL_DIR=models +DATA_DIR=data + +# 安全配置 +SECRET_KEY=your_secret_key_here +ALLOWED_HOSTS=localhost,127.0.0.1 + +# 其他配置 +UPLOAD_MAX_SIZE=10485760 # 10MB in bytes +ALLOWED_FILE_TYPES=.xlsx,.xls \ No newline at end of file diff --git a/config.py b/config.py index ee008e5..ec92cd1 100644 --- a/config.py +++ b/config.py @@ -8,8 +8,8 @@ DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db" SECRET_KEY = secrets.token_hex(16) # 环境配置 -DEBUG = True -ENV = 'development' +DEBUG = False +ENV = 'production' # 文件上传配置 UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads') diff --git a/data/equipment_data_20241108_training.xlsx b/data/equipment_data_20241108_training.xlsx deleted file mode 100644 index 30f7dcf..0000000 Binary files a/data/equipment_data_20241108_training.xlsx and /dev/null differ diff --git a/data/equipment_data_20241108_verify.xlsx b/data/equipment_data_20241108_verify.xlsx deleted file mode 100644 index 92a527a..0000000 Binary files a/data/equipment_data_20241108_verify.xlsx and /dev/null differ diff --git a/deploy/equipment_cost_system.tar.gz b/deploy/equipment_cost_system.tar.gz new file mode 100644 index 0000000..c906e30 Binary files /dev/null and b/deploy/equipment_cost_system.tar.gz differ diff --git a/deploy/equipment_cost_system/config/.env.template b/deploy/equipment_cost_system/config/.env.template new file mode 100644 index 0000000..195987f --- /dev/null +++ b/deploy/equipment_cost_system/config/.env.template @@ -0,0 +1,25 @@ +# 数据库配置 +MYSQL_HOST=localhost +MYSQL_USER=root +MYSQL_PASSWORD=your_password_here +MYSQL_DATABASE=equipment_cost_db + +# 服务配置 +PORT=5001 
+DEBUG=False + +# 日志配置 +LOG_LEVEL=INFO +LOG_DIR=logs + +# 模型配置 +MODEL_DIR=models +DATA_DIR=data + +# 安全配置 +SECRET_KEY=your_secret_key_here +ALLOWED_HOSTS=localhost,127.0.0.1 + +# 其他配置 +UPLOAD_MAX_SIZE=10485760 # 10MB in bytes +ALLOWED_FILE_TYPES=.xlsx,.xls \ No newline at end of file diff --git a/deploy/equipment_cost_system/docs/api.md b/deploy/equipment_cost_system/docs/api.md new file mode 100644 index 0000000..000bd35 --- /dev/null +++ b/deploy/equipment_cost_system/docs/api.md @@ -0,0 +1,663 @@ +# 装备成本估算系统 API 文档 + +这个 API 文档提供了完整的接口说明,包括: + +- 每个端点的详细描述 +- 请求和响应的具体示例 +- 清晰的参数格式要求 +- 统一的错误处理说明 +- 重要的注意事项 + +文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。 + +## 基本信息 + +- 基础URL: `http://localhost:5001/api` +- 版本: 1.0.0 +- 响应格式: JSON + +## API 端点列表 + +### 1. 获取 API 信息 + +获取 API 版本信息和可用端点列表。 + +- **URL**: `/` +- **方法**: `GET` +- **响应示例**: +json +{ +"name": "装备成本估算系统 API", +"version": "1.0.0", +"endpoints": { +"predict": { +"url": "/api/predict", +"method": "POST", +"description": "成本预测" +}, +"analyze-features": { +"url": "/api/analyze-features", +"method": "POST", +"description": "特征分析" +}, +"train": { +"url": "/api/train", +"method": "POST", +"description": "模型训练" +}, +"evaluate": { +"url": "/api/evaluate", +"method": "POST", +"description": "模型评估" +} +} +} + +### 2. 
单模型预测 + +使用当前激活的最优模型进行成本预测。 + +- **URL**: `/predict` +- **方法**: `POST` +- **请求体示例** (巡飞弹): + +```json +{ + "type": "巡飞弹", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "folded_length_mm": 1300, + "folded_width_mm": 230, + "folded_height_mm": 230, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "凭自身动力起飞" +} +``` + +- **响应示例**: + +```json +{ + "predicted_cost": 150000.0, + "model_info": { + "type": "xgboost", + "name": "巡飞弹_20241111_model", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "confidence_interval": { + "lower": 135000.0, + "upper": 165000.0 + } +} +``` + +### 3. PLS 模型预测 + +使用 PLS 回归模型进行预测。 + +- **URL**: `/pls/predict` +- **方法**: `POST` +- **请求体**: 与单模型预测相同 +- **响应示例**: + +```json +{ + "predicted_cost": 148000.0, + "confidence_interval": { + "lower": 133000.0, + "upper": 163000.0 + } +} +``` + +### 4. 多模型预测 + +使用所有激活的模型进行预测并返回综合结果。 + +- **URL**: `/predict/all` +- **方法**: `POST` +- **请求体**: 与单模型预测相同 +- **响应示例**: + +```json +{ + "individual_predictions": { + "xgboost": { + "predicted_cost": 150000.0, + "model_info": { + "name": "巡飞弹_xgboost_model", + "type": "xgboost", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "confidence_interval": { + "lower": 135000.0, + "upper": 165000.0 + } + }, + "pls": { + "predicted_cost": 148000.0, + "model_info": { + "name": "巡飞弹_pls_model", + "type": "pls", + "r2_score": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + }, + "confidence_interval": { + "lower": 133000.0, + "upper": 163000.0 + } + } + }, + "ensemble_prediction": { + "predicted_cost": 149000.0, + "standard_deviation": 1414.21, + "confidence_interval": { + "lower": 146228.15, + "upper": 151771.85 + } + } +} +``` + +### 5. 
特征分析 + +分析数据集中特征的重要性和相关性。 + +- **URL**: `/analyze-features` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "dataset_id": 1, + "equipment_type": "巡飞弹" +} +``` + +- **响应示例**: + +```json +{ + "important_features": [ + { + "name": "最大射程(km)", + "importance": 0.35 + }, + { + "name": "重量(kg)", + "importance": 0.25 + } + ], + "correlation_analysis": { + "features": ["最大射程(km)", "重量(kg)"], + "matrix": [[1.0, 0.8], [0.8, 1.0]] + } +} +``` + +### 6. 模型训练 + +训练新的模型。 + +- **URL**: `/train` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "type": "巡飞弹", + "train_dataset_id": 1, + "validation_dataset_id": 2, + "models": ["xgboost", "lightgbm", "rf"] +} +``` + +- **响应示例**: + +```json +{ + "metrics": { + "xgboost": { + "train": { + "r2": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "validation": { + "r2": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + } + } + }, + "best_model": { + "type": "xgboost", + "r2": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + } +} +``` + +### 7. 数据集管理 + +#### 7.1 获取数据集列表 + +- **URL**: `/datasets` +- **方法**: `GET` +- **响应示例**: + +```json +[ + { + "id": 1, + "name": "训练数据集", + "description": "用于训练的数据集", + "equipment_type": "巡飞弹", + "equipment_count": 10, + "equipment_names": ["设备1", "设备2"], + "purpose": "训练", + "created_at": "2024-11-11T10:00:00" + } +] +``` + +#### 7.2 获取数据集详情 + +- **URL**: `/datasets/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 1, + "name": "训练数据集", + "description": "用于训练的数据集", + "equipment_type": "巡飞弹", + "purpose": "训练", + "created_at": "2024-11-11T10:00:00", + "equipment": [ + { + "id": 1, + "name": "设备1", + "type": "巡飞弹", + "manufacturer": "制造商1", + "actual_cost": 150000 + } + ], + "statistics": { + "equipment_count": 10, + "total_cost": 1500000, + "average_cost": 150000 + } +} +``` + +#### 7.3 创建数据集 + +- **URL**: `/datasets` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "name": "测试数据集", + "description": "用于测试的数据集", + "equipment_type": "巡飞弹", + "purpose": "训练", + "equipment_ids": [1, 2, 3] +} +``` + +- **响应示例**: + +```json 
+{ + "id": 2, + "message": "数据集创建成功" +} +``` + +#### 7.4 更新数据集 + +- **URL**: `/datasets/{id}` +- **方法**: `PUT` +- **请求体示例**: + +```json +{ + "name": "更新后的数据集名称", + "description": "更新后的描述", + "equipment_type": "巡飞弹", + "purpose": "验证", + "equipment_ids": [1, 2, 3, 4] +} +``` + +- **响应示例**: + +```json +{ + "success": true, + "message": "数据集更新成功" +} +``` + +#### 7.5 删除数据集 + +- **URL**: `/datasets/{id}` +- **方法**: `DELETE` +- **描述**: 删除指定的数据集及其关联关系 +- **响应示例**: + +```json +{ + "success": true, + "message": "数据集删除成功" +} +``` + +注意事项: + +1. 数据集删除后不会删除关联的装备数据 +2. 不能删除正在被模型使用的数据集 +3. 更新数据集时会重新计算统计信息 +4. 数据集的装备类型一旦创建后不能更改 + +### 8. 模型管理 + +#### 8.1 获取模型列表 + +- **URL**: `/models` +- **方法**: `GET` +- **响应示例**: + +```json +[ + { + "id": 1, + "model_name": "巡飞弹_xgboost_model", + "model_type": "xgboost", + "equipment_type": "巡飞弹", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0, + "is_active": true, + "training_date": "2024-11-11T10:00:00" + } +] +``` + +#### 8.2 获取最新模型 + +- **URL**: `/models/{equipment_type}/latest` +- **方法**: `GET` +- **响应示例**: 与模型列表的单个模型格式相同 + +#### 8.3 获取模型详情 + +- **URL**: `/models/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 1, + "model_name": "巡飞弹_xgboost_model", + "model_type": "xgboost", + "equipment_type": "巡飞弹", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0, + "is_active": true, + "training_date": "2024-11-11T10:00:00", + "feature_importance": { + "max_range_km": 0.35, + "weight_kg": 0.25, + "length_m": 0.20 + }, + "training_data_size": 100, + "created_by": "system" +} +``` + +#### 8.4 激活模型 + +- **URL**: `/models/{id}/activate` +- **方法**: `POST` +- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态 +- **响应示例**: + +```json +{ + "success": true, + "message": "模型已激活" +} +``` + +#### 8.5 删除模型 + +- **URL**: `/models/{id}` +- **方法**: `DELETE` +- **描述**: 删除指定模型,包括模型文件和数据库记录 +- **响应示例**: + +```json +{ + "success": true, + "message": "模型已删除" +} +``` + +注意事项: + +1. 删除模型时会同时删除相关的文件和数据库记录 +2. 不能删除当前正在使用(已激活)的模型 +3. 激活模型时会自动取消同类型其他模型的激活状态 +4. 
模型详情包含了更多的训练相关信息,如特征重要性等 + +### 9. 数据管理 + +#### 9.1 获取装备数据列表 + +- **URL**: `/data` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "rocket_artillery": [ + { + "id": 1, + "name": "BM-21", + "type": "火箭炮", + "manufacturer": "俄罗斯", + "length_m": 7.35, + "width_m": 2.4, + "height_m": 3.1, + "weight_kg": 13700, + "max_range_km": 20.4, + "actual_cost": 800000 + } + ], + "loitering_munition": [ + { + "id": 8, + "name": "Hero-120", + "type": "巡飞弹", + "manufacturer": "以色列", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "actual_cost": 150000 + } + ] +} +``` + +#### 9.2 获取装备详情 + +- **URL**: `/data/details/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 8, + "name": "Hero-120", + "type": "巡飞弹", + "manufacturer": "以色列", + "common_params": { + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40 + }, + "specific_params": { + "wingspan_m": 2.1, + "warhead_weight_kg": 3.5, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "箱式发射", + "power_system": "电动机", + "guidance_system": "GPS/INS" + }, + "cost_data": { + "actual_cost": 150000, + "prediction_date": "2024-11-11T10:00:00", + "predicted_cost": 148000 + }, + "custom_params": [ + { + "id": 1, + "param_name": "续航时间", + "param_value": "2小时", + "param_unit": "小时", + "description": "最大续航时间" + } + ] +} +``` + +#### 9.3 更新装备数据 + +- **URL**: `/data/{id}` +- **方法**: `PUT` +- **请求体示例**: + +```json +{ + "name": "Hero-120", + "manufacturer": "以色列", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "wingspan_m": 2.1, + "warhead_weight_kg": 3.5, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "actual_cost": 150000, + "custom_params": [ + { + "id": 1, + "param_value": "2.5小时" + } + ] +} +``` + +- **响应示例**: + +```json +{ + "success": true, + "message": "装备数据更新成功" +} +``` + +#### 9.4 删除装备数据 + 
+- **URL**: `/data/{id}` +- **方法**: `DELETE` +- **响应示例**: + +```json +{ + "success": true, + "message": "装备数据删除成功" +} +``` + +#### 9.5 下载数据模板 + +- **URL**: `/data/template` +- **方法**: `GET` +- **描述**: 下载Excel格式的数据导入模板 +- **响应**: Excel文件下载 + +#### 9.6 导入数据 + +- **URL**: `/data/import` +- **方法**: `POST` +- **请求体**: + - Content-Type: multipart/form-data + - 参数名: file + - 文件类型: .xlsx 或 .xls +- **响应示例**: + +```json +{ + "success": true, + "message": "数据导入成功", + "imported_count": { + "rocket_artillery": 3, + "loitering_munition": 5 + } +} +``` + +注意事项: + +1. 导入数据时必须使用系统提供的模板 +2. 更新装备数据时会同时更新关联的参数表 +3. 删除装备数据会同时删除相关的参数和成本数据 +4. 导入的Excel文件大小不应超过10MB +5. 所有数值字段必须符合指定的单位和范围要求 +6. 特殊参数的值必须包含单位信息 + +## 错误响应 + +所有接口在发生错误时都会返回以下格式的响应: + +```json +{ + "error": "错误描述信息" +} +``` + +## 注意事项 + +1. 所有数值参数必须大于0 +2. 所有单位必须按照参数名称中指定的单位提供 +3. 预测结果中的成本单位为元 +4. 置信区间表示预测结果的95%置信水平范围 +5. 所有请求和响应的编码均为 UTF-8 diff --git a/deploy/equipment_cost_system/docs/deploy.md b/deploy/equipment_cost_system/docs/deploy.md new file mode 100644 index 0000000..1f52359 --- /dev/null +++ b/deploy/equipment_cost_system/docs/deploy.md @@ -0,0 +1,120 @@ +# 装备成本估算系统部署指南 + +## 一、系统要求 + +### 1. 基础软件 + +- Linux 操作系统 (推荐 Ubuntu 20.04+) +- Python 3.8+ 及相关组件 + + ```bash + sudo apt update + sudo apt install python3 python3-pip python3-venv + sudo apt install python3-dev build-essential + ``` + +- Node.js 14+ 及 npm + + ```bash + # 使用 nvm 安装 Node.js + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash + source ~/.bashrc + nvm install 14 + nvm use 14 + ``` + +### 2. 数据库 + +- MySQL 8.0+ + + ```bash + sudo apt install mysql-server mysql-client + sudo apt install libmysqlclient-dev + ``` + +### 3. 
Python包依赖 + +```bash +# 科学计算相关 +sudo apt install libatlas-base-dev # numpy依赖 +sudo apt install libopenblas-dev # 线性代数库 +sudo apt install liblapack-dev # 线性代数包 +sudo apt install gfortran # Fortran编译器(scipy依赖) + +# XML处理相关(用于Excel文件处理) +sudo apt install libxml2-dev +sudo apt install libxslt1-dev +``` + +## 二、部署运行 + +### 1. 安装服务 + +```bash +sh scripts/install.sh +``` + +### 2. 启动服务 + +```bash +sh scripts/start.sh +``` + +### 3. 停止服务 + +```bash +sh scripts/stop.sh +``` + +## 三、维护说明 + +### 1. 日志管理 + +```bash +# 后端日志 +tail -f logs/api.log + +# 数据库日志 +tail -f /var/log/mysql/error.log +``` + +## 四、安全建议 + +1. 系统安全 + - 使用防火墙限制端口访问 + - 定期更新系统和依赖包 + +2. 数据安全 + - 定期备份数据库 + - 加密敏感信息 + - 限制数据库远程访问 + +3. 访问控制 + - 使用强密码 + - 配置适当的文件权限 + - 使用非root用户运行服务 + +## 五、监控方案 + +### 1. 系统监控 + +```bash +# 资源使用 +top -b -n 1 +df -h +free -m + +# 服务状态 +ps aux | grep gunicorn +ps aux | grep node +``` + +### 2. 应用监控 + +```bash +# API 响应时间 +curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/" + +# 错误日志 +grep "ERROR" logs/api.log +``` diff --git a/deploy/equipment_cost_system/frontend/babel.config.js b/deploy/equipment_cost_system/frontend/babel.config.js new file mode 100644 index 0000000..e955840 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/babel.config.js @@ -0,0 +1,5 @@ +module.exports = { + presets: [ + '@vue/cli-plugin-babel/preset' + ] +} diff --git a/deploy/equipment_cost_system/frontend/package.json b/deploy/equipment_cost_system/frontend/package.json new file mode 100644 index 0000000..5e6409e --- /dev/null +++ b/deploy/equipment_cost_system/frontend/package.json @@ -0,0 +1,61 @@ +{ + "name": "frontend", + "version": "0.1.0", + "private": true, + "engines": { + "node": ">=16", + "npm": ">=8" + }, + "scripts": { + "serve": "vue-cli-service serve", + "build": "vue-cli-service build", + "lint": "vue-cli-service lint" + }, + "dependencies": { + "axios": "^1.6.0", + "core-js": "^3.8.3", + "echarts": "^5.4.3", + "element-plus": "^2.4.2", + "vue": "^3.2.13", + 
"vue-router": "^4.0.3", + "vuex": "^4.0.0" + }, + "devDependencies": { + "@babel/core": "^7.12.16", + "@babel/eslint-parser": "^7.12.16", + "@element-plus/icons-vue": "^2.3.1", + "@vue/cli-plugin-babel": "~5.0.0", + "@vue/cli-plugin-eslint": "~5.0.0", + "@vue/cli-plugin-router": "~5.0.0", + "@vue/cli-plugin-vuex": "~5.0.0", + "@vue/cli-service": "~5.0.0", + "@vue/compiler-sfc": "^3.2.13", + "eslint": "^7.32.0", + "eslint-plugin-vue": "^8.0.3", + "sass": "^1.32.7", + "sass-loader": "^12.0.0" + }, + "eslintConfig": { + "root": true, + "env": { + "node": true + }, + "extends": [ + "plugin:vue/vue3-essential", + "eslint:recommended" + ], + "parserOptions": { + "parser": "@babel/eslint-parser" + }, + "rules": { + "vue/multi-word-component-names": "off", + "no-unused-vars": "warn" + } + }, + "browserslist": [ + "> 1%", + "last 2 versions", + "not dead", + "not ie 11" + ] +} diff --git a/deploy/equipment_cost_system/frontend/public/favicon.ico b/deploy/equipment_cost_system/frontend/public/favicon.ico new file mode 100644 index 0000000..df36fcf Binary files /dev/null and b/deploy/equipment_cost_system/frontend/public/favicon.ico differ diff --git a/deploy/equipment_cost_system/frontend/public/index.html b/deploy/equipment_cost_system/frontend/public/index.html new file mode 100644 index 0000000..3e5a139 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/public/index.html @@ -0,0 +1,17 @@ + + + + + + + + <%= htmlWebpackPlugin.options.title %> + + + +
+ + + diff --git a/deploy/equipment_cost_system/frontend/src/App.vue b/deploy/equipment_cost_system/frontend/src/App.vue new file mode 100644 index 0000000..c4278c5 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/App.vue @@ -0,0 +1,42 @@ + + + diff --git a/deploy/equipment_cost_system/frontend/src/api/index.js b/deploy/equipment_cost_system/frontend/src/api/index.js new file mode 100644 index 0000000..1090d1a --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/api/index.js @@ -0,0 +1,43 @@ +import axios from 'axios' +import { API_BASE_URL } from '@/config' + +const api = axios.create({ + baseURL: API_BASE_URL, + timeout: 10000 +}) + +export const predict = (data) => { + return api.post('/predict', data) +} + +export const analyzeFeatures = (data) => { + return api.post('/analyze-features', data) +} + +export const trainModel = (data) => { + return api.post('/train', data) +} + +export const evaluateModel = (data) => { + return api.post('/evaluate', data) +} + +export const importData = (formData) => { + return api.post('/data/import', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} + +export const getEquipmentData = () => { + return api.get('/data') +} + +export const updateEquipment = (id, data) => { + return api.put(`/data/${id}`, data) +} + +export const deleteEquipment = (id) => { + return api.delete(`/data/${id}`) +} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/assets/logo.png b/deploy/equipment_cost_system/frontend/src/assets/logo.png new file mode 100644 index 0000000..f3d2503 Binary files /dev/null and b/deploy/equipment_cost_system/frontend/src/assets/logo.png differ diff --git a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css new file mode 100644 index 0000000..a9e81e1 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css @@ -0,0 +1,39 @@ +/* 禁用 
ResizeObserver 警告 */ +iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] { + display: none !important; +} + +.el-overlay { + overflow: hidden !important; +} + +/* 添加全局样式 */ +body { + margin: 0; + padding: 0; + overflow-x: hidden; +} + +/* 修复 Element Plus 的一些已知问题 */ +.el-dialog__wrapper { + overflow: hidden !important; +} + +.el-select-dropdown { + overflow: hidden !important; +} + +/* 禁用 ResizeObserver 相关的警告样式 */ +.resize-observer-warning { + display: none !important; +} + +/* 优化滚动行为 */ +* { + scroll-behavior: smooth; +} + +/* 防止页面抖动 */ +.el-main { + overflow-x: hidden; +} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/config.js b/deploy/equipment_cost_system/frontend/src/config.js new file mode 100644 index 0000000..0a697ee --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/config.js @@ -0,0 +1,8 @@ +export const API_BASE_URL = 'http://localhost:5001/api'; + +export const DB_CONFIG = { + host: 'localhost', + user: 'root', + password: '123456', + database: 'equipment_cost_db' +}; \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/main.js b/deploy/equipment_cost_system/frontend/src/main.js new file mode 100644 index 0000000..360a4a4 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/main.js @@ -0,0 +1,55 @@ +import { createApp } from 'vue' +import App from './App.vue' +import router from './router' +import store from './store' +import ElementPlus from 'element-plus' +import 'element-plus/dist/index.css' +import './assets/styles/global.css' +import * as ElementPlusIconsVue from '@element-plus/icons-vue' + +// 创建应用实例 +const app = createApp(App) + +// 注册插件 +app.use(ElementPlus, { + size: 'default' +}) +app.use(router) +app.use(store) + +// 注册图标 +for (const [key, component] of Object.entries(ElementPlusIconsVue)) { + app.component(key, component) +} + +// 全局错误处理 +app.config.errorHandler = (err) => { + if (err.message 
&& err.message.includes('ResizeObserver')) { + return + } + console.error(err) +} + +// 全局警告处理 +app.config.warnHandler = (msg, trace) => { + if (msg.includes('ResizeObserver')) { + return + } + console.warn(msg, trace) +} + +// 挂载应用 +app.mount('#app') + +// 处理 ResizeObserver 错误 +const _ResizeObserver = window.ResizeObserver +window.ResizeObserver = class ResizeObserver extends _ResizeObserver { + constructor(callback) { + super((entries, observer) => { + requestAnimationFrame(() => { + if (!Array.isArray(entries)) return + callback(entries, observer) + }) + }) + } +} diff --git a/deploy/equipment_cost_system/frontend/src/router/index.js b/deploy/equipment_cost_system/frontend/src/router/index.js new file mode 100644 index 0000000..3dfcdee --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/router/index.js @@ -0,0 +1,52 @@ +import { createRouter, createWebHistory } from 'vue-router' +import HomePage from '@/views/HomePage.vue' +import DataPage from '@/views/DataPage.vue' +import DatasetPage from '@/views/DatasetPage.vue' +import PredictPage from '@/views/PredictPage.vue' +import AnalysisPage from '@/views/AnalysisPage.vue' +import TrainingPage from '@/views/TrainingPage.vue' + +const routes = [ + { + path: '/', + name: 'Home', + component: HomePage + }, + { + path: '/data', + name: 'Data', + component: DataPage + }, + { + path: '/datasets', + name: 'Datasets', + component: DatasetPage + }, + { + path: '/predict', + name: 'Predict', + component: PredictPage + }, + { + path: '/analysis', + name: 'Analysis', + component: AnalysisPage + }, + { + path: '/training', + name: 'Training', + component: TrainingPage + }, + { + path: '/models', + name: 'Models', + component: () => import('../views/ModelPage.vue') + } +] + +const router = createRouter({ + history: createWebHistory(), + routes +}) + +export default router \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/store/index.js 
b/deploy/equipment_cost_system/frontend/src/store/index.js new file mode 100644 index 0000000..0faaec7 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/store/index.js @@ -0,0 +1,14 @@ +import { createStore } from 'vuex' + +export default createStore({ + state: { + }, + getters: { + }, + mutations: { + }, + actions: { + }, + modules: { + } +}) \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js new file mode 100644 index 0000000..de54055 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js @@ -0,0 +1,73 @@ +// 处理 ResizeObserver 错误 +const resizeHandler = () => { + const resizeObserverErrors = [ + "ResizeObserver loop completed with undelivered notifications.", + "ResizeObserver loop limit exceeded", + "ResizeObserver loop completed" + ] + + // 添加全局错误处理 + const handler = (event) => { + if (event && event.message && resizeObserverErrors.includes(event.message)) { + event.stopPropagation() + event.preventDefault() + event.stopImmediatePropagation() + return false + } + } + + // 添加多个事件监听器 + window.addEventListener('error', handler, true) + window.addEventListener('unhandledrejection', handler, true) + + // 添加 ResizeObserver 错误处理 + if (window.ResizeObserver) { + const resizeObserverPrototype = ResizeObserver.prototype + const originalObserve = resizeObserverPrototype.observe + + resizeObserverPrototype.observe = function (...args) { + try { + return originalObserve.apply(this, args) + } catch (e) { + if (resizeObserverErrors.includes(e.message)) { + return null + } + throw e + } + } + } +} + +export default { + install: (app) => { + resizeHandler() + + // 添加全局错误处理器 + app.config.errorHandler = (err, vm, info) => { + const resizeObserverErrors = [ + "ResizeObserver loop completed with undelivered notifications.", + "ResizeObserver loop limit exceeded", + "ResizeObserver loop completed" + ] + + if (err && 
err.message && resizeObserverErrors.includes(err.message)) { + return + } + console.error('Vue Error:', err, info) + } + + // 添加全局警告处理器 + app.config.warnHandler = (msg, vm, trace) => { + const resizeObserverErrors = [ + "ResizeObserver loop completed with undelivered notifications.", + "ResizeObserver loop limit exceeded", + "ResizeObserver loop completed" + ] + + if (resizeObserverErrors.includes(msg)) { + return + } + console.warn('Vue Warning:', msg, trace) + } + } +} \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue new file mode 100644 index 0000000..f8f8ba1 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue @@ -0,0 +1,304 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue new file mode 100644 index 0000000..4689f70 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue @@ -0,0 +1,725 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue new file mode 100644 index 0000000..1079611 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue @@ -0,0 +1,322 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue new file mode 100644 index 0000000..60c0676 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue @@ -0,0 +1,101 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue new file mode 100644 index 0000000..e9632f9 --- /dev/null +++ 
b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue @@ -0,0 +1,279 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue new file mode 100644 index 0000000..a42fa9f --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue @@ -0,0 +1,321 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue new file mode 100644 index 0000000..edd20c1 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue @@ -0,0 +1,370 @@ + + + + + \ No newline at end of file diff --git a/deploy/equipment_cost_system/frontend/vue.config.js b/deploy/equipment_cost_system/frontend/vue.config.js new file mode 100644 index 0000000..b2a2970 --- /dev/null +++ b/deploy/equipment_cost_system/frontend/vue.config.js @@ -0,0 +1,5 @@ +module.exports = { + devServer: { + port: 8080 + } +} diff --git a/deploy/equipment_cost_system/requirements.txt b/deploy/equipment_cost_system/requirements.txt new file mode 100644 index 0000000..6eec783 --- /dev/null +++ b/deploy/equipment_cost_system/requirements.txt @@ -0,0 +1,12 @@ +flask==2.0.1 +flask-cors==3.0.10 +sqlalchemy==1.4.23 +pymysql==1.0.2 +cryptography==3.4.7 # MySQL 8.0+ 认证需要 +numpy==1.21.2 +pandas==1.3.3 +scikit-learn==0.24.2 +tensorflow==2.6.0 +urllib3<2.0.0 # 添加这一行,限制 urllib3 版本 +openpyxl==3.1.2 # 用于读取 .xlsx 文件 +xlrd==2.0.1 # 用于读取 .xls 文件 \ No newline at end of file diff --git a/deploy/equipment_cost_system/scripts/install.sh b/deploy/equipment_cost_system/scripts/install.sh new file mode 100644 index 0000000..a7a80d6 --- /dev/null +++ b/deploy/equipment_cost_system/scripts/install.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +echo "开始安装装备成本估算系统..." 
+ +# 检查Python版本 +python3 -V || { + echo "错误: 需要 Python 3.8+" + exit 1 +} + +# 检查Node.js版本 +node -v || { + echo "错误: 需要 Node.js 14+" + exit 1 +} + +# 创建必要的目录 +echo "创建系统目录..." +mkdir -p {logs,data,models} + +# 安装后端依赖 +echo "安装后端依赖..." +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt + +# 安装前端依赖 +echo "安装前端依赖..." +cd frontend +npm install +npm run build +cd .. + +# 配置文件 +if [ ! -f config/.env ]; then + echo "创建配置文件..." + cp config/.env.template config/.env + echo "请修改 config/.env 中的配置" +fi + +# 初始化数据库 +echo "初始化数据库..." +read -p "请输入MySQL root密码: " mysqlpass +mysql -u root -p$mysqlpass < src/schema.sql + +# 导入测试数据(可选) +read -p "是否导入测试数据?(y/n) " import_test_data +if [ "$import_test_data" = "y" ]; then + mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql +fi + +# 设置权限 +echo "设置文件权限..." +chmod +x scripts/*.sh +chmod 755 logs models data +chmod 600 config/.env + +echo "安装完成!" +echo "请检查并修改 config/.env 中的配置" +echo "使用 ./scripts/start.sh 启动服务" \ No newline at end of file diff --git a/deploy/equipment_cost_system/scripts/start.sh b/deploy/equipment_cost_system/scripts/start.sh new file mode 100644 index 0000000..5b28ac4 --- /dev/null +++ b/deploy/equipment_cost_system/scripts/start.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +echo "启动装备成本估算系统..." + +# 检查配置文件 +if [ ! -f config/.env ]; then + echo "错误: 配置文件不存在" + echo "请先运行 install.sh" + exit 1 +fi + +# 检查日志目录 +if [ ! -d logs ]; then + mkdir -p logs +fi + +# 激活虚拟环境 +source venv/bin/activate + +# 导出环境变量 +export $(cat config/.env | xargs) + +# 启动后端服务 +echo "启动后端服务..." +gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \ + --daemon \ + --pid gunicorn.pid \ + --access-logfile logs/access.log \ + --error-logfile logs/error.log + +# 等待后端服务启动 +sleep 2 + +# 检查后端服务是否成功启动 +if ! curl -s http://localhost:5001/api/ > /dev/null; then + echo "错误: 后端服务启动失败" + exit 1 +fi + +# 启动前端服务 +echo "启动前端服务..." +cd frontend +npm run serve -- --port 8080 --host 0.0.0.0 & +echo $! > ../frontend.pid +cd .. 
+ +echo "服务已启动!" +echo "后端API: http://localhost:5001" +echo "前端界面: http://localhost:8080" +echo "查看后端日志: tail -f logs/access.log" +echo "查看前端日志: tail -f logs/frontend.log" diff --git a/deploy/equipment_cost_system/scripts/stop.sh b/deploy/equipment_cost_system/scripts/stop.sh new file mode 100644 index 0000000..085d432 --- /dev/null +++ b/deploy/equipment_cost_system/scripts/stop.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +echo "停止装备成本估算系统..." + +# 停止后端服务 +if [ -f gunicorn.pid ]; then + echo "停止后端服务..." + pid=$(cat gunicorn.pid) + kill -TERM $pid + rm gunicorn.pid + echo "后端服务已停止 (PID: $pid)" +else + echo "未找到后端服务PID文件" + # 尝试查找并停止所有gunicorn进程 + pkill -f gunicorn + echo "已尝试停止所有gunicorn进程" +fi + +# 停止前端服务 +if [ -f frontend.pid ]; then + echo "停止前端服务..." + pid=$(cat frontend.pid) + kill -TERM $pid + rm frontend.pid + echo "前端服务已停止 (PID: $pid)" +else + echo "未找到前端服务PID文件" + # 尝试查找并停止前端服务进程 + pkill -f "vite preview" + echo "已尝试停止所有前端服务进程" +fi + +# 检查是否还有相关进程在运行 +if pgrep -f gunicorn > /dev/null; then + echo "警告: 仍有gunicorn进程在运行" + ps aux | grep gunicorn | grep -v grep +fi + +if pgrep -f "vite preview" > /dev/null; then + echo "警告: 仍有前端服务进程在运行" + ps aux | grep "vite preview" | grep -v grep +fi + +echo "所有服务已停止" diff --git a/deploy/equipment_cost_system/src/__init__.py b/deploy/equipment_cost_system/src/__init__.py new file mode 100644 index 0000000..497b4a4 --- /dev/null +++ b/deploy/equipment_cost_system/src/__init__.py @@ -0,0 +1 @@ +# 这个文件可以为空,但必须存在 diff --git a/deploy/equipment_cost_system/src/app.py b/deploy/equipment_cost_system/src/app.py new file mode 100644 index 0000000..037fb79 --- /dev/null +++ b/deploy/equipment_cost_system/src/app.py @@ -0,0 +1,50 @@ +from flask import Flask +from flask_cors import CORS +from .routes import api_bp +from .logger import setup_logger +import os + +# 获取logger +logger = setup_logger(__name__) + +def create_app(): + """ + 创建并配置Flask应用 + """ + try: + # 创建必要的目录 + os.makedirs('logs', exist_ok=True) + os.makedirs('data', exist_ok=True) + 
os.makedirs('models', exist_ok=True) + + logger.info("=== Server Starting ===") + logger.info("Initializing directories...") + + # 创建Flask应用 + app = Flask(__name__) + + # 配置CORS + CORS(app) + logger.info("CORS enabled") + + # 注册API蓝图 + app.register_blueprint(api_bp, url_prefix='/api') + logger.info("API blueprint registered") + + # 配置数据库连接 + app.config['MYSQL_HOST'] = 'localhost' + app.config['MYSQL_USER'] = 'root' + app.config['MYSQL_PASSWORD'] = '123456' + app.config['MYSQL_DB'] = 'equipment_cost_db' + + logger.info("Starting server...") + + return app + + except Exception as e: + logger.error(f"Error creating app: {str(e)}") + raise + +if __name__ == '__main__': + app = create_app() + app.run(host='localhost', port=5001) \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/cost_prediction.py b/deploy/equipment_cost_system/src/cost_prediction.py new file mode 100644 index 0000000..23ccce5 --- /dev/null +++ b/deploy/equipment_cost_system/src/cost_prediction.py @@ -0,0 +1,342 @@ +import numpy as np +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score +from sklearn.preprocessing import StandardScaler +import tensorflow as tf +from scipy import stats +import joblib +import os +import pandas as pd +from .feature_analysis import FeatureAnalysis +import logging +from src.model_trainer import ModelTrainer +from src.database import get_db_connection +from .logger import setup_logger + +logger = setup_logger(__name__) + +class CostPredictor: + def __init__(self): + self.scaler_X = StandardScaler() + self.scaler_y = StandardScaler() + self.model = None + self.feature_analyzer = FeatureAnalysis() + self.equipment_type = None + + # 添加 TensorFlow 配置 + tf.config.run_functions_eagerly(False) # 启用图执行模式 + + # 创建预测函数 + @tf.function(reduce_retracing=True, jit_compile=True) + def predict_fn(x): + return self.model(x, training=False) + + self._predict_fn = predict_fn + + self.load_model() + + def load_model(self): + """ + 加载预训练型和标准化器 + 
""" + try: + model_dir = 'models' + os.makedirs(model_dir, exist_ok=True) + + # 创建默认模型 + self._create_default_model() + + # 创建预测函数 + @tf.function(reduce_retracing=True, jit_compile=True) + def predict_fn(x): + return self.model(x, training=False) + + self._predict_fn = predict_fn + + except Exception as e: + logging.error(f"Error loading model: {str(e)}") + self._create_default_model() + + def _create_default_model(self): + """ + 创建默认模型并进行初始化训练 + """ + # 创建输入层 + inputs = tf.keras.Input(shape=(11,)) + + # 创建隐藏层 + x = tf.keras.layers.Dense(64, activation='relu')(inputs) + x = tf.keras.layers.Dense(32, activation='relu')(x) + + # 创建输出层 + outputs = tf.keras.layers.Dense(1)(x) + + # 创建模型 + self.model = tf.keras.Model(inputs=inputs, outputs=outputs) + + # 编译模型 + self.model.compile( + optimizer='adam', + loss=tf.keras.losses.mean_squared_error, + metrics=[tf.keras.metrics.mean_absolute_error] + ) + + # 创建示例数据 + example_data = pd.DataFrame({ + 'length_m': [7.35, 10.2], + 'width_m': [2.4, 2.8], + 'height_m': [3.1, 3.2], + 'weight_kg': [13700, 28500], + 'max_range_km': [20.4, 70], + 'firing_angle_horizontal': [102, 110], + 'firing_angle_vertical': [55, 60], + 'rocket_length_m': [2.87, 4.1], + 'rocket_diameter_mm': [122, 220], + 'rocket_weight_kg': [66.6, 150], + 'rate_of_fire': [40, 60] + }) + + # 训练标准化器 + self.scaler_X.fit(example_data) + self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围 + + # 设置默认装备类型 + self.equipment_type = '火箭炮' + + def _create_example_data(self): + """ + 创建示例数据来训练标准化器 + """ + # 火箭炮示例数据 + rocket_data = pd.DataFrame({ + 'length_m': [7.35, 10.2], + 'width_m': [2.4, 2.8], + 'height_m': [3.1, 3.2], + 'weight_kg': [13700, 28500], + 'max_range_km': [20.4, 70], + 'firing_angle_horizontal': [102, 110], + 'firing_angle_vertical': [55, 60], + 'rocket_length_m': [2.87, 4.1], + 'rocket_diameter_mm': [122, 220], + 'rocket_weight_kg': [66.6, 150], + 'rate_of_fire': [40, 60] + }) + + # 巡飞弹示例数据 + missile_data = pd.DataFrame({ + 'length_m': [1.3, 2.5], + 
'width_m': [0.23, 0.6], + 'height_m': [0.23, 0.6], + 'weight_kg': [12.5, 135], + 'max_range_km': [40, 250], + 'max_speed_kmh': [180, 185], + 'cruise_speed_kmh': [100, 110], + 'flight_time_min': [60, 120], + 'folded_length_mm': [1300, 2500], + 'folded_width_mm': [230, 600], + 'folded_height_mm': [230, 600] + }) + + # 训练标准化器 + self.scaler_X.fit(rocket_data) # 使用火箭炮数据 + self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据 + + # 设置默认装备类型 + self.equipment_type = '火箭炮' + + def predict(self, data): + """ + 使用训练好的最优模型进行预测 + """ + try: + logger.info(f"Starting prediction for {data.get('type')}") + equipment_type = data.get('type') + + # 加载已训练的最优模型 + trainer = ModelTrainer() + if not trainer.load_model(equipment_type): + raise ValueError(f"No trained model found for {equipment_type}") + + # 准备特征数据 + features = self.feature_analyzer.get_equipment_specific_features(equipment_type) + X = np.array([[data.get(feature) for feature in features]]) + + # 预测 + y_pred = trainer.predict(X) + + # 计算置信区间 + confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) + + # 获取模型类型 + model_type = trainer.get_model_type() + + return { + 'predicted_cost': float(y_pred[0]), + 'model_type': model_type, # 返回使用的模型类型 + 'confidence_interval': { + 'lower': float(confidence_interval[0]), + 'upper': float(confidence_interval[1]) + } + } + + except Exception as e: + logger.error(f"Prediction error: {str(e)}") + raise + + def _calculate_confidence_interval(self, prediction, confidence=0.95): + """ + 计算预测值的置信区间 + """ + try: + # 使用预测值的20%作为标准差(增加不确定性) + std = abs(prediction) * 0.2 + + # 计算置信区间 + from scipy import stats + interval = stats.norm.interval(confidence, loc=prediction, scale=std) + + # 确保区间值为正数且合理 + lower = max(1000, interval[0]) # 最小值设为1000元 + upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20% + + logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]") + + return [lower, upper] + + except Exception as e: + logging.error(f"Error calculating confidence 
interval: {str(e)}") + # 如果计算失败,返回基于20%的简单区间 + lower = max(1000, prediction * 0.8) + upper = prediction * 1.2 + return [lower, upper] + + def evaluate(self, y_true, y_pred): + """ + 模型评估 + """ + return { + 'mae': float(mean_absolute_error(y_true, y_pred)), + 'mse': float(mean_squared_error(y_true, y_pred)), + 'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))), + 'r2': float(r2_score(y_true, y_pred)) + } + + def predict_pls(self, data): + """ + 使用 PLS ���型预测成本 + """ + try: + logger.info(f"Starting PLS prediction for {data.get('type')}") + equipment_type = data.get('type') + + # 加载 PLS 模型 + trainer = ModelTrainer() + if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型 + raise ValueError(f"No trained PLS model found for {equipment_type}") + + # 准备特征数据 + features = self.feature_analyzer.get_equipment_specific_features(equipment_type) + X = np.array([[data.get(feature) for feature in features]]) + + # 预测 + y_pred = trainer.predict(X) + + # 计算置信区间 + confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) + + return { + 'predicted_cost': float(y_pred[0]), + 'confidence_interval': { + 'lower': float(confidence_interval[0]), + 'upper': float(confidence_interval[1]) + } + } + + except Exception as e: + logger.error(f"PLS prediction error: {str(e)}") + raise + + def predict_all(self, data): + """ + 使用所有可用模型进行预测 + """ + try: + logger.info(f"Starting multi-model prediction for {data.get('type')}") + equipment_type = data.get('type') + results = {} + + # 1. 获取所有激活的模型 + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + cursor.execute(""" + SELECT id, model_type, model_name, r2_score, mae, rmse + FROM trained_models + WHERE equipment_type = %s AND is_active = TRUE + """, (equipment_type,)) + active_models = cursor.fetchall() + + if not active_models: + raise ValueError(f"No active models found for {equipment_type}") + + # 2. 
使用每个模型进行预测 + trainer = ModelTrainer() + for model_info in active_models: + try: + # 加载特定模型 + if not trainer.load_model(equipment_type, model_type=model_info['model_type']): + logger.warning(f"Failed to load model: {model_info['model_name']}") + continue + + # 准备特征数据 + features = self.feature_analyzer.get_equipment_specific_features(equipment_type) + X = np.array([[data.get(feature) for feature in features]]) + + # 预测 + y_pred = trainer.predict(X) + + # 计算置信区间 + confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) + + # 保存结果 + results[model_info['model_type']] = { + 'predicted_cost': float(y_pred[0]), + 'model_info': { + 'name': model_info['model_name'], + 'type': model_info['model_type'], + 'r2_score': float(model_info['r2_score']), + 'mae': float(model_info['mae']), + 'rmse': float(model_info['rmse']) + }, + 'confidence_interval': { + 'lower': float(confidence_interval[0]), + 'upper': float(confidence_interval[1]) + } + } + + except Exception as e: + logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}") + continue + + if not results: + raise ValueError("No successful predictions from any model") + + # 3. 计算综合预测结果 + all_predictions = [result['predicted_cost'] for result in results.values()] + ensemble_prediction = float(np.mean(all_predictions)) + prediction_std = float(np.std(all_predictions)) + + # 4. 
返回所有结果 + return { + 'individual_predictions': results, + 'ensemble_prediction': { + 'predicted_cost': ensemble_prediction, + 'standard_deviation': prediction_std, + 'confidence_interval': { + 'lower': float(ensemble_prediction - 1.96 * prediction_std), + 'upper': float(ensemble_prediction + 1.96 * prediction_std) + } + } + } + + except Exception as e: + logger.error(f"Error in multi-model prediction: {str(e)}") + raise \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/create_template.py b/deploy/equipment_cost_system/src/create_template.py new file mode 100644 index 0000000..0f7d545 --- /dev/null +++ b/deploy/equipment_cost_system/src/create_template.py @@ -0,0 +1,155 @@ +import pandas as pd +import openpyxl +from openpyxl.styles import PatternFill, Font, Alignment +from openpyxl.worksheet.datavalidation import DataValidation +import os +from .logger import setup_logger + +logger = setup_logger(__name__) + +def create_excel_template(): + """ + 创建数据模板 + """ + try: + # 确保data目录存在 + data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') + os.makedirs(data_dir, exist_ok=True) + + # 创建完整的文件路径 + template_path = os.path.join(data_dir, 'equipment_data_template.xlsx') + + # 创建 Excel 写入器 + writer = pd.ExcelWriter(template_path, engine='openpyxl') + + # 火箭炮基本参数表 + rocket_artillery_columns = [ + '名称', '类型', '制造商', '口径_mm', + '反射管数量', '乘员数', '总长_m', + '宽度_m', '高度_m', '重量_kg', + '战斗重_kg', '速度_km/h', '最大射程_km', + '最小射程_km', '方向射界_度', '高低射界_度', + '火箭弹长度_m', '火箭弹重量_kg', + '火箭弹最大速度_m/s', '射速_发', + '战斗部重量_kg', '行走方式', + '结构布局', '发动机型号', '发动机参数', + '功率_hp', '行程_km', '成本_元' + ] + + # 巡飞弹基本参数表 + loitering_munition_columns = [ + '名称', '类型', '制造商', '目标类型', + '弹长_m', '弹径_mm', '翼展_m', + '重量_kg', '战斗部重量_kg', + '最大射程_km', '最大速度_m/s', + '巡航速度_kmh', '巡飞时间_min', + '战斗部类型', '发射方式', + '折叠长度_mm', '折叠宽度_mm', + '折叠高度_mm', '动力装置', + '制导体制', '成本_元' + ] + + # 特殊参数表 + special_params_columns = [ + '装备名称', # 关联字段 + '参数名称', + '参数值', + '参数单位', + '参数说明' + ] + + # 
创建工作表 + pd.DataFrame(columns=rocket_artillery_columns).to_excel( + writer, sheet_name='火箭炮', index=False + ) + pd.DataFrame(columns=loitering_munition_columns).to_excel( + writer, sheet_name='巡飞弹', index=False + ) + pd.DataFrame(columns=special_params_columns).to_excel( + writer, sheet_name='特殊参数', index=False + ) + + # 获取工作簿 + workbook = writer.book + + # 设置火箭炮工作表格式 + rocket_sheet = workbook['火箭炮'] + for col in range(1, len(rocket_artillery_columns) + 1): + cell = rocket_sheet.cell(row=1, column=col) + cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') + cell.font = Font(bold=True) + cell.alignment = Alignment(horizontal='center') + + # 设置巡飞弹工作表格式 + missile_sheet = workbook['巡飞弹'] + for col in range(1, len(loitering_munition_columns) + 1): + cell = missile_sheet.cell(row=1, column=col) + cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') + cell.font = Font(bold=True) + cell.alignment = Alignment(horizontal='center') + + # 设置特殊参数工作表格式 + special_sheet = workbook['特殊参数'] + for col in range(1, len(special_params_columns) + 1): + cell = special_sheet.cell(row=1, column=col) + cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid') + cell.font = Font(bold=True) + cell.alignment = Alignment(horizontal='center') + + # 添加数据验证 + for sheet in [rocket_sheet, missile_sheet]: + # 数值验证 + number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0") + number_validation.error = "请输入大于0的数值" + number_validation.errorTitle = "输入错误" + + # 应用到相应列 + for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']: + number_validation.add(f"{col}2:{col}1000") + + sheet.add_data_validation(number_validation) + + # 添加说明 + rocket_sheet['AD1'] = "填写说明:" + rocket_sheet['AD2'] = "1. 所有数值必须大于0" + rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写" + rocket_sheet['AD4'] = "3. 成本单位为元" + + missile_sheet['V1'] = "填写说明:" + missile_sheet['V2'] = "1. 
所有数值必须大于0" + missile_sheet['V3'] = "2. 单位必须按照表头要求填写" + missile_sheet['V4'] = "3. 成本单位为元" + + special_sheet['G1'] = "填写说明:" + special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致" + special_sheet['G3'] = "2. 参数值必须包含单位" + special_sheet['G4'] = "3. 参数说明应简明扼要" + + # 调整列宽 + for sheet in [rocket_sheet, missile_sheet, special_sheet]: + for col in sheet.columns: + max_length = 0 + column = col[0].column_letter + for cell in col: + try: + if len(str(cell.value)) > max_length: + max_length = len(str(cell.value)) + except: + pass + adjusted_width = (max_length + 2) + sheet.column_dimensions[column].width = adjusted_width + + # 保存文件 + writer.close() + + return template_path + + except Exception as e: + raise Exception(f"创建模板文件失败: {str(e)}") + +if __name__ == "__main__": + try: + template_path = create_excel_template() + print(f"模板文件已创建: {template_path}") + except Exception as e: + print(f"错误: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/data_preparation.py b/deploy/equipment_cost_system/src/data_preparation.py new file mode 100644 index 0000000..6380647 --- /dev/null +++ b/deploy/equipment_cost_system/src/data_preparation.py @@ -0,0 +1,233 @@ +from sklearn.preprocessing import StandardScaler +from datetime import datetime +import os +import joblib +import pandas as pd +import numpy as np +from src.feature_analysis import FeatureAnalysis +from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor +from xgboost import XGBRegressor +from lightgbm import LGBMRegressor +from sklearn.model_selection import cross_val_score, LeaveOneOut +import json +import logging +from src.database.db_connection import get_db_connection +from sklearn.metrics import mean_absolute_error, mean_squared_error +from .logger import setup_logger + +logger = setup_logger(__name__) + +class DataPreparation: + def __init__(self): + self.feature_analyzer = FeatureAnalysis() + self.feature_scaler = StandardScaler() + self.target_scaler = StandardScaler() # 
添加目标值标准化器 + + def prepare_training_data(self, equipment_data, equipment_type): + """ + 准备训练数据 + """ + try: + logger.info(f"Preparing training data for {equipment_type}") + logger.info(f"Raw data size: {len(equipment_data)}") + + # 如果输入已经是 numpy 数组,直接返回 + if isinstance(equipment_data, np.ndarray): + X = equipment_data + logger.info(f"Input is already numpy array with shape: {X.shape}") + + # 处理无效值 + X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) + + return { + 'X': X, + 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type), + 'feature_scaler': self.feature_scaler, + 'target_scaler': self.target_scaler + } + + # 从原始数据中提取特征和目标值 + feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) + features = [] + targets = [] + + for item in equipment_data: + # 提取特征值 + feature_values = [] + for name in feature_names: + value = item.get(name) + try: + feature_values.append(float(value) if value is not None else 0.0) + except (ValueError, TypeError): + feature_values.append(0.0) + features.append(feature_values) + + # 提取目标值(成本) + try: + cost = float(item['actual_cost']) + if cost > 0: # 只使用正数成本值 + targets.append(cost) + else: + logger.warning(f"Skipping non-positive cost value: {cost}") + except (ValueError, TypeError, KeyError): + logger.error(f"Invalid cost value: {item.get('actual_cost')}") + continue + + # 转换为numpy数组 + X = np.array(features, dtype=float) + y = np.array(targets, dtype=float) + + # 记录原始数据范围 + logger.info(f"Raw X range: min={X.min()}, max={X.max()}") + logger.info(f"Raw y range: min={y.min()}, max={y.max()}") + + # 标准化特征和目标值 + X_scaled = self.feature_scaler.fit_transform(X) + y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel() + + # 记录标准化后的数据范围 + logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}") + logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}") + + # 记录标准化器参数 + logger.info("Feature scaler params:") + logger.info(f"Mean: 
{self.feature_scaler.mean_}") + logger.info(f"Scale: {self.feature_scaler.scale_}") + + logger.info("Target scaler params:") + logger.info(f"Mean: {self.target_scaler.mean_}") + logger.info(f"Scale: {self.target_scaler.scale_}") + + return { + 'X': X_scaled, + 'y': y_scaled, + 'feature_names': feature_names, + 'feature_scaler': self.feature_scaler, + 'target_scaler': self.target_scaler + } + + except Exception as e: + logger.error(f"Error in data preparation: {str(e)}") + raise Exception(f"Training error: {str(e)}") + + def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None): + """ + 准备验证数据 + """ + try: + logger.info(f"Preparing validation data for {equipment_type}") + logger.info(f"Raw validation data size: {len(validation_data)}") + + # 如果输入已经是 numpy 数组,直接使用 + if isinstance(validation_data, np.ndarray): + X = validation_data + logger.info(f"Input is already numpy array with shape: {X.shape}") + + # 处理无效值 + X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) + + # 使用训练数据的标准化器 + if scalers and 'feature_scaler' in scalers: + X_scaled = scalers['feature_scaler'].transform(X) + else: + X_scaled = X + + logger.info(f"Preprocessed data shape: {X_scaled.shape}") + logger.info(f"Validation features shape: {X_scaled.shape}") + logger.info(f"Validation features type: {X_scaled.dtype}") + + return { + 'X': X_scaled, + 'y': None # 验证数据可能没有标签 + } + + # 从原始数据中提取特征和目标值 + if not feature_names: + feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) + + features = [] + targets = [] + + for item in validation_data: + # 提取特征值 + feature_values = [] + for name in feature_names: + value = item.get(name) + try: + feature_values.append(float(value) if value is not None else 0.0) + except (ValueError, TypeError): + feature_values.append(0.0) + features.append(feature_values) + + # 提取目标值(成本)并验证范围 + try: + cost = float(item['actual_cost']) + logger.info(f"Raw cost value: {cost}") + if cost > 0: # 只使用正数成本值 + 
targets.append(cost) + else: + logger.warning(f"Skipping non-positive cost value: {cost}") + except (ValueError, TypeError): + logger.error(f"Invalid cost value: {item.get('actual_cost')}") + continue + + # 转换为numpy数组 + X = np.array(features, dtype=float) + y = np.array(targets, dtype=float) + + # 记录数据范围 + logger.info(f"Features range: min={X.min()}, max={X.max()}") + logger.info(f"Targets range: min={y.min()}, max={y.max()}") + + # 处理无效值 + X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) + + # 使用训练数据的标准化器 + if scalers and 'feature_scaler' in scalers: + X_scaled = scalers['feature_scaler'].transform(X) + if 'target_scaler' in scalers: + y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel() + else: + y_scaled = y + else: + X_scaled = X + y_scaled = y + + logger.info(f"Preprocessed data shape: {X_scaled.shape}") + logger.info(f"Validation features shape: {X_scaled.shape}") + logger.info(f"Validation features type: {X_scaled.dtype}") + + # 记录标准化后的数据范围 + logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}") + logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}") + + # 确保特征维度一致 + if not feature_names: + feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) + + logger.info(f"Expected features: {len(feature_names)}") + logger.info(f"Actual features: {X_scaled.shape[1]}") + + if X_scaled.shape[1] != len(feature_names): + raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}") + + return { + 'X': X_scaled, + 'y': y_scaled # 返回标准化后的成本值 + } + + except Exception as e: + logger.error(f"Error in validation data preparation: {str(e)}") + logger.error(f"Feature names: {feature_names}") + logger.error(f"Equipment type: {equipment_type}") + raise Exception(f"Validation error: {str(e)}") + + def calculate_derived_features(self, data, equipment_type): + """ + 计算衍生特征 + """ + try: + return 
self.feature_analyzer.calculate_derived_features(data, equipment_type) + except Exception as e: + logger.error(f"Error calculating derived features: {str(e)}") + raise Exception(f"Feature calculation error: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/database/__init__.py b/deploy/equipment_cost_system/src/database/__init__.py new file mode 100644 index 0000000..6df49c9 --- /dev/null +++ b/deploy/equipment_cost_system/src/database/__init__.py @@ -0,0 +1 @@ +from .db_connection import get_db_connection \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/database/db_connection.py b/deploy/equipment_cost_system/src/database/db_connection.py new file mode 100644 index 0000000..361a4d2 --- /dev/null +++ b/deploy/equipment_cost_system/src/database/db_connection.py @@ -0,0 +1,37 @@ +import mysql.connector +from mysql.connector import Error +from contextlib import contextmanager +import os +from dotenv import load_dotenv +from ..logger import setup_logger + +# 获取logger +logger = setup_logger(__name__) + +# 加载环境变量 +load_dotenv() + +@contextmanager +def get_db_connection(): + """ + 数据库连接上下文管理器 + """ + connection = None + try: + connection = mysql.connector.connect( + host=os.getenv('MYSQL_HOST', 'localhost'), + user=os.getenv('MYSQL_USER', 'root'), + password=os.getenv('MYSQL_PASSWORD', '123456'), + database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db') + ) + logger.info("Database connection established") + yield connection + + except Error as e: + logger.error(f"Error connecting to MySQL: {str(e)}") + raise + + finally: + if connection and connection.is_connected(): + connection.close() + logger.info("Database connection closed") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/feature_analysis.py b/deploy/equipment_cost_system/src/feature_analysis.py new file mode 100644 index 0000000..5b87972 --- /dev/null +++ b/deploy/equipment_cost_system/src/feature_analysis.py @@ -0,0 +1,269 @@ 
+import numpy as np +import pandas as pd +from scipy import stats +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import r2_score +import logging +from .logger import setup_logger + +logger = setup_logger(__name__) + +class FeatureAnalysis: + def __init__(self): + self.scaler = StandardScaler() + self.important_features = [] + + # 添加特征名称映射 + self.feature_names_map = { + # 通用参数 + 'length_m': '总长(m)', + 'width_m': '宽度(m)', + 'height_m': '高度(m)', + 'weight_kg': '重量(kg)', + 'max_range_km': '最大射程(km)', + + # 火箭炮特有参数 + 'firing_angle_horizontal': '方向射界(度)', + 'firing_angle_vertical': '高低射界(度)', + 'rocket_length_m': '火箭弹长度(m)', + 'rocket_diameter_mm': '口径(mm)', + 'rocket_weight_kg': '火箭弹重量(kg)', + 'rate_of_fire': '射速(发/分)', + 'combat_weight_kg': '战斗重量(kg)', + 'speed_kmh': '速度(km/h)', + 'min_range_km': '最小射程(km)', + 'power_hp': '功率(hp)', + + # 火箭炮衍生特征 + 'fire_density': '火力密度', + 'mobility_index': '机动性指标', + 'range_ratio': '射程比', + 'power_weight_ratio': '功重比', + 'volume_density': '体积密度', + + # 巡飞弹特有参数 + 'wingspan_m': '翼展(m)', + 'warhead_weight_kg': '战斗部重量(kg)', + 'max_speed_ms': '最大速度(m/s)', + 'cruise_speed_kmh': '巡航速度(km/h)', + 'flight_time_min': '巡飞时间(min)', + 'folded_length_mm': '折叠长度(mm)', + 'folded_width_mm': '折叠宽度(mm)', + 'folded_height_mm': '折叠高度(mm)', + + # 巡飞弹衍生特征 + 'warhead_ratio': '战斗部比重', + 'speed_ratio': '速度比', + 'range_time_ratio': '射程时间比', + 'aspect_ratio': '展弦比', + 'volume_density': '体积密度' + } + + def get_equipment_specific_features(self, equipment_type): + """ + 获取特定装备类型的特征列表 + """ + # 通用参数 + common_features = [ + 'length_m', # 总长(m) + 'width_m', # 宽度(m) + 'height_m', # 高度(m) + 'weight_kg', # 重量(kg) + 'max_range_km' # 最大射程(km) + ] + + if equipment_type == '火箭炮': + # 火箭炮特有参数 + specific_features = [ + 'firing_angle_horizontal', # 方向射界(度) + 'firing_angle_vertical', # 高低射界(度) + 'rocket_length_m', # 火箭弹长度(m) + 'rocket_diameter_mm', # 口径(mm) + 'rocket_weight_kg', # 火箭弹重量(kg) + 
'rate_of_fire', # 射速(发/分) + 'combat_weight_kg', # 战斗重量(kg) + 'speed_kmh', # 速度(km/h) + 'min_range_km', # 最小射程(km) + 'power_hp' # 功率(hp) + ] + + # 火箭炮衍生特征 + derived_features = [ + 'fire_density', # 火力密度 = 射速 * 火箭弹重量 + 'mobility_index', # 机动性指标 = 速度 / 战斗重量 + 'range_ratio', # 射程比 = 最大射程 / 最小射程 + 'power_weight_ratio', # 功重比 = 功率 / 战斗重量 + 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) + ] + + return common_features + specific_features + derived_features + + else: # 巡飞弹 + # 巡飞弹特有参数 + specific_features = [ + 'wingspan_m', # 翼展(m) + 'warhead_weight_kg', # 战斗部重量(kg) + 'max_speed_ms', # 最大速度(m/s) + 'cruise_speed_kmh', # 巡航速度(km/h) + 'flight_time_min', # 巡飞时间(min) + 'folded_length_mm', # 折叠长度(mm) + 'folded_width_mm', # 折叠宽度(mm) + 'folded_height_mm' # 折叠高度(mm) + ] + + # 巡飞弹衍生特征 + derived_features = [ + 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量 + 'speed_ratio', # 速度比 = 巡航速度 / 最大速度 + 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间 + 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积 + 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高) + ] + + return common_features + specific_features + derived_features + + def calculate_derived_features(self, data, equipment_type): + """ + 计算衍生特征 + """ + try: + if equipment_type == '火箭炮': + # 火箭炮衍生特征计算 + if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns: + data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg'] + else: + data['fire_density'] = 0 # 或者其他默认值 + + if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns: + data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg'] + else: + data['mobility_index'] = 0 + + if 'max_range_km' in data.columns and 'min_range_km' in data.columns: + data['range_ratio'] = data['max_range_km'] / data['min_range_km'] + else: + data['range_ratio'] = 0 + + if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns: + data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg'] + else: + data['power_weight_ratio'] = 0 + + if all(col in data.columns for col in 
['weight_kg', 'length_m', 'width_m', 'height_m']): + data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) + else: + data['volume_density'] = 0 + + else: # 巡飞弹 + # 巡飞弹衍生特征计算 + if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns: + data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg'] + else: + data['warhead_ratio'] = 0 + + if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns: + data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6) + else: + data['speed_ratio'] = 0 + + if 'max_range_km' in data.columns and 'flight_time_min' in data.columns: + data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min'] + else: + data['range_time_ratio'] = 0 + + if 'wingspan_m' in data.columns and 'length_m' in data.columns: + data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m'] + else: + data['aspect_ratio'] = 0 + + if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']): + data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m']) + else: + data['volume_density'] = 0 + + return data + + except Exception as e: + logger.error(f"Error calculating derived features: {str(e)}") + raise + + def analyze_features(self, features, target, feature_names): + """ + 分析特征重要性和相关性 + """ + try: + # 转换为numpy数组 + X = np.array(features) + y = np.array(target) + + # 数据标准化 + X_scaled = self.scaler.fit_transform(X) + + # 特征重要性分析 + rf = RandomForestRegressor(n_estimators=100, random_state=42) + rf.fit(X_scaled, y) + importances = rf.feature_importances_ + + # 按重要性排序,使用中文特征名 + importance_indices = np.argsort(importances)[::-1] + important_features = [ + { + 'name': self.feature_names_map.get(feature_names[i], feature_names[i]), + 'importance': float(importances[i]) + } + for i in importance_indices + ] + + # 相关性分析 + df = pd.DataFrame(X_scaled, columns=feature_names) + 
# Excel columns, in INSERT order, that feed common_params for rocket artillery.
_ROCKET_COMMON_COLUMNS = ('总长_m', '宽度_m', '高度_m', '重量_kg', '最大射程_km')

# Excel columns, in INSERT order, that feed rocket_artillery_params.
_ROCKET_SPECIFIC_COLUMNS = (
    '方向射界_度', '高低射界_度', '火箭弹长度_m', '口径_mm', '火箭弹重量_kg',
    '射速_发', '战斗重_kg', '速度_km/h', '最小射程_km', '行走方式',
    '结构布局', '发动机型号', '发动机参数', '功率_hp', '行程_km',
)

# (column, cast) pairs, in INSERT order, that feed loitering_munition_params.
_MISSILE_SPECIFIC_COLUMNS = (
    ('翼展_m', float), ('战斗部重量_kg', float), ('最大速度_m/s', float),
    ('巡航速度_km/h', float), ('巡飞时间_min', float),
    ('战斗部类型', str), ('发射方式', str),
    ('折叠长度_mm', float), ('折叠宽度_mm', float), ('折叠高度_mm', float),
    ('动力装置', str), ('制导体制', str),
)

_INSERT_COMMON_SQL = """
    INSERT INTO common_params
    (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
    VALUES (%s, %s, %s, %s, %s, %s)
"""

_INSERT_ROCKET_SQL = """
    INSERT INTO rocket_artillery_params
    (equipment_id, firing_angle_horizontal, firing_angle_vertical,
     rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
     combat_weight_kg, speed_kmh, min_range_km, mobility_type,
     structure_layout, engine_model, engine_params, power_hp,
     travel_range_km)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""

_INSERT_MISSILE_SQL = """
    INSERT INTO loitering_munition_params
    (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
     cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
     folded_length_mm, folded_width_mm, folded_height_mm,
     power_system, guidance_system)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""


def _cell(row, column, cast=None):
    """Return row[column] ready for the DB: NaN/empty cells become None, others are optionally cast."""
    value = row[column]
    if pd.isna(value):
        return None
    return cast(value) if cast is not None else value


def _insert_equipment(cursor, name, equipment_type, manufacturer):
    """Insert the base equipment row; return its new id, or None if (name, type) already exists."""
    cursor.execute("""
        SELECT id FROM equipment
        WHERE name = %s AND type = %s
    """, (name, equipment_type))
    if cursor.fetchone():
        logger.warning(f"{equipment_type} '{name}' 已存在,跳过导入")
        return None

    cursor.execute("""
        INSERT INTO equipment (name, type, manufacturer)
        VALUES (%s, %s, %s)
    """, (name, equipment_type, manufacturer))
    return cursor.lastrowid


def _insert_cost(cursor, equipment_id, cost):
    """Insert the actual cost for an equipment row; do nothing when the sheet has no cost."""
    if cost is None:
        return
    cursor.execute("""
        INSERT INTO cost_data (equipment_id, actual_cost)
        VALUES (%s, %s)
    """, (equipment_id, cost))


def _import_rockets(cursor, rocket_df, equipment_names):
    """Import every row of the rocket artillery sheet, recording each name in equipment_names."""
    logger.info("开始导入火箭炮数据...")
    for _, row in rocket_df.iterrows():
        equipment_names.add(row['名称'])
        # An empty manufacturer cell is stored as NULL, the same way the
        # loitering munition sheet is handled (previously NaN was passed through).
        equipment_id = _insert_equipment(
            cursor, row['名称'], '火箭炮', _cell(row, '制造商')
        )
        if equipment_id is None:
            continue

        cursor.execute(
            _INSERT_COMMON_SQL,
            (equipment_id, *(_cell(row, col) for col in _ROCKET_COMMON_COLUMNS)),
        )
        cursor.execute(
            _INSERT_ROCKET_SQL,
            (equipment_id, *(_cell(row, col) for col in _ROCKET_SPECIFIC_COLUMNS)),
        )
        _insert_cost(cursor, equipment_id, _cell(row, '成本_元'))
    logger.info("火箭炮数据导入完成")


def _import_missiles(cursor, missile_df, equipment_names):
    """Import every row of the loitering munition sheet, recording each name in equipment_names."""
    logger.info("开始导入巡飞弹数据...")
    for index, row in missile_df.iterrows():
        # Report which cells are empty; +2 turns the 0-based frame index into
        # the spreadsheet row number (header row plus 1-based counting).
        null_values = row[row.isna()].index.tolist()
        if null_values:
            logger.info(f"行 {index + 2} 中的空值字段: {null_values}")

        equipment_names.add(row['名称'])
        equipment_id = _insert_equipment(
            cursor, row['名称'], '巡飞弹', _cell(row, '制造商')
        )
        if equipment_id is None:
            continue

        # The sheet gives the body diameter in millimetres; it is stored in
        # metres as both width and height.
        diameter_m = _cell(row, '弹径_mm', float)
        if diameter_m is not None:
            diameter_m = diameter_m / 1000
        cursor.execute(_INSERT_COMMON_SQL, (
            equipment_id,
            _cell(row, '弹长_m', float),
            diameter_m,
            diameter_m,
            _cell(row, '重量_kg', float),
            _cell(row, '最大射程_km', float),
        ))
        cursor.execute(
            _INSERT_MISSILE_SQL,
            (equipment_id, *(_cell(row, col, cast) for col, cast in _MISSILE_SPECIFIC_COLUMNS)),
        )
        _insert_cost(cursor, equipment_id, _cell(row, '成本_元', float))
    logger.info("巡飞弹数据导入完成")


def _import_special_params(conn, special_df, equipment_names):
    """Import the custom parameter sheet, skipping unknown equipment and existing parameters."""
    logger.info("开始导入特殊参数...")
    for index, row in special_df.iterrows():
        equipment_name = row['装备名称']
        param_name = row['参数名称']
        logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")

        # Only equipment that appeared in one of the two equipment sheets is accepted.
        if equipment_name not in equipment_names:
            logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
            continue

        logger.debug(f"查询装备ID: {equipment_name}")
        with conn.cursor() as id_cursor:
            id_cursor.execute("""
                SELECT id FROM equipment WHERE name = %s
            """, (equipment_name,))
            result = id_cursor.fetchone()

        if not result:
            logger.warning(f"未找到装备: {equipment_name}")
            continue

        equipment_id = result[0]
        logger.debug(f"找到装备ID: {equipment_id}")

        logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
        with conn.cursor() as check_cursor:
            check_cursor.execute("""
                SELECT id FROM custom_params
                WHERE equipment_id = %s AND param_name = %s
            """, (equipment_id, param_name))
            exists = check_cursor.fetchone()

        if exists:
            logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
            continue

        param_value = _cell(row, '参数值', str)
        param_unit = _cell(row, '参数单位')
        param_desc = _cell(row, '参数说明')

        logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
        with conn.cursor() as insert_cursor:
            insert_cursor.execute("""
                INSERT INTO custom_params
                (equipment_id, param_name, param_value, param_unit, description)
                VALUES (%s, %s, %s, %s, %s)
            """, (equipment_id, param_name, param_value, param_unit, param_desc))
        logger.debug(f"成功插入参数记录")


def import_training_data(excel_file):
    """
    Import training data from an Excel workbook into the database.

    The workbook must contain the sheets '火箭炮' (rocket artillery),
    '巡飞弹' (loitering munitions) and '特殊参数' (custom parameters).
    Equipment whose (name, type) already exists is skipped, as are custom
    parameters that already exist for an equipment row.

    Args:
        excel_file: Path (or file-like object) accepted by ``pandas.read_excel``.

    Returns:
        True when the import completed.

    Raises:
        Exception: Any error raised while reading the workbook or talking to
            the database is logged and re-raised unchanged.
    """
    try:
        rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
        missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
        special_df = pd.read_excel(excel_file, sheet_name='特殊参数')

        # Names seen in the equipment sheets; used to validate the custom
        # parameter sheet.
        equipment_names = set()

        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                _import_rockets(cursor, rocket_df, equipment_names)
                _import_missiles(cursor, missile_df, equipment_names)
            finally:
                # Release the shared cursor even when an insert fails.
                cursor.close()

            # Equipment rows are committed before custom parameters are
            # imported, so the ID lookups below see them.
            conn.commit()

            _import_special_params(conn, special_df, equipment_names)

            conn.commit()
            logger.info("特殊参数导入完成")
            logger.info("所有数据导入成功")
            return True

    except Exception as e:
        logger.error(f"Error importing data: {str(e)}")
        raise
import_training_data(excel_file) + logger.info("All data imported successfully") + except Exception as e: + logger.error(f"Import failed: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/init_data.sql b/deploy/equipment_cost_system/src/init_data.sql new file mode 100644 index 0000000..ee99db3 --- /dev/null +++ b/deploy/equipment_cost_system/src/init_data.sql @@ -0,0 +1,319 @@ +/* +这是用于开发和测试环境的示例数据。 +生产环境请使用系统的数据导入功能添加实际数据。 + +主要用途: +1. 提供开发测试数据 +2. 作为数据格式参考 +3. 用于系统功能验证 +*/ + +-- 插入装备基本信息 +INSERT INTO equipment (name, type, manufacturer, target_type) VALUES +('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), +('胜利-2', '火箭炮', '伊朗', '地面固定目标'); + +-- 插入巡飞弹技术参数 +INSERT INTO technical_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_speed_kmh, + cruise_speed_kmh, + max_range_km, + flight_time_min, + warhead_type, + launch_mode, + folded_length_mm, + folded_width_mm, + folded_height_mm +) VALUES ( + 1, -- 终结者巡飞弹 + 0.56, + 0.15, + 0.20, + 2.72, + 160.93, + 96.56, + 24, + 15, + '破片杀伤战斗部', + '凭自身动力起飞', + 560, + 150, + 200 +); + +-- 插入火箭炮技术参数 +INSERT INTO technical_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_range_km +) VALUES ( + 2, -- 胜利-2火箭炮 + 10, + 2.5, + 3.34, + 15000, + 23 +); + +-- 插入成本数据(示例数据) +INSERT INTO cost_data (equipment_id, actual_cost) VALUES +(1, 1000000), -- 终结者巡飞弹成本 +(2, 5000000); -- 胜利-2火箭炮成本 + +-- 插入更多巡飞弹变体数据用于训练 +INSERT INTO equipment (name, type, manufacturer, target_type) VALUES +('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), +('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), +('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'); + +-- 插入变体技术参数 +INSERT INTO technical_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_speed_kmh, + cruise_speed_kmh, + max_range_km, + flight_time_min, + warhead_type, + launch_mode, + folded_length_mm, + folded_width_mm, + folded_height_mm +) VALUES +-- 终结者-A(稍大型号) +(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', 
'凭自身动力起飞', 580, 160, 210), +-- 终结者-B(稍小型号) +(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190), +-- 终结者-C(标准型号的改进版) +(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200); + +-- 插入变体成本数据 +INSERT INTO cost_data (equipment_id, actual_cost) VALUES +(3, 1100000), -- 终结者-A成本(较高) +(4, 900000), -- 终结者-B成本(较低) +(5, 1050000); -- 终结者-C成本(中等) + +-- 添加更多巡飞弹数据 +INSERT INTO equipment (name, type, manufacturer, target_type) VALUES +('哈比', '巡飞弹', '以色列', '防空系统和雷达站'), +('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'), +('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'), +('弹簧刀', '巡飞弹', '波兰', '装甲目标'), +('彩虹-4', '巡飞弹', '中国', '地面固定目标'); + +-- 添加它们的技术参数 +INSERT INTO technical_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_speed_kmh, + cruise_speed_kmh, + max_range_km, + flight_time_min, + warhead_type, + launch_mode, + folded_length_mm, + folded_width_mm, + folded_height_mm +) VALUES +-- 哈比 +(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600), +-- 游荡者 +(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400), +-- 凤凰 +(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300), +-- 弹簧刀 +(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350), +-- 彩虹-4 +(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800); + +-- 添加成本数据 +INSERT INTO cost_data (equipment_id, actual_cost) VALUES +(6, 800000), -- 哈比 +(7, 500000), -- 游荡者 +(8, 450000), -- 凤凰 +(9, 480000), -- 弹簧刀 +(10, 1500000); -- 彩虹-4 + +-- 火箭炮数据 +INSERT INTO equipment (name, type, manufacturer) VALUES +('BM-21', '火箭炮', '俄罗斯'), +('SR5', '火箭炮', '中国'), +('HIMARS', '火箭炮', '美国'), +('LAR-160', '火箭炮', '以色列'), +('T-122', '火箭炮', '土耳其'), +('RM-70', '火箭炮', '捷克'), +('ASTROS II', '火箭炮', '巴西'); + +-- 火箭炮通用参数 +INSERT INTO common_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_range_km +) VALUES +-- BM-21 +(1, 7.35, 2.4, 3.1, 13700, 20.4), +-- SR5 +(2, 10.2, 2.8, 
3.2, 28500, 70), +-- HIMARS +(3, 7.0, 2.4, 3.2, 16250, 70), +-- LAR-160 +(4, 6.7, 2.5, 2.8, 15000, 45), +-- T-122 +(5, 7.2, 2.5, 2.9, 18000, 40), +-- RM-70 +(6, 7.5, 2.5, 3.0, 17200, 20.3), +-- ASTROS II +(7, 8.0, 2.7, 3.1, 24500, 90); + +-- 火箭炮特有参数 +INSERT INTO rocket_artillery_params ( + equipment_id, + firing_angle_horizontal, + firing_angle_vertical, + rocket_length_m, + rocket_diameter_mm, + rocket_weight_kg, + rate_of_fire, + combat_weight_kg, + speed_kmh, + min_range_km, + mobility_type, + structure_layout, + engine_model, + engine_params, + power_hp, + travel_range_km +) VALUES +-- BM-21 +(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500), +-- SR5 +(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650), +-- HIMARS +(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480), +-- LAR-160 +(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550), +-- T-122 +(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600), +-- RM-70 +(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520), +-- ASTROS II +(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700); + +-- 巡飞弹数据 +INSERT INTO equipment (name, type, manufacturer) VALUES +('Hero-120', '巡飞弹', '以色列'), +('Switchblade 600', '巡飞弹', '美国'), +('Warmate', '巡飞弹', '波兰'), +('CH-901', '巡飞弹', '中国'), +('HAROP', '巡飞弹', '以色列'), +('Coyote', '巡飞弹', '美国'), +('WS-43', '巡飞弹', '中国'); + +-- 巡飞弹通用参数 +INSERT INTO common_params ( + equipment_id, + length_m, + width_m, + height_m, + weight_kg, + max_range_km +) VALUES +-- Hero-120 +(8, 1.3, 0.23, 0.23, 12.5, 40), +-- Switchblade 600 +(9, 1.3, 0.22, 0.22, 15.0, 40), +-- Warmate +(10, 1.1, 0.15, 0.15, 5.7, 15), +-- CH-901 +(11, 1.2, 0.18, 0.18, 9.0, 20), +-- HAROP +(12, 2.5, 0.43, 0.43, 135, 1000), +-- Coyote +(13, 0.9, 0.12, 0.12, 5.9, 20), +-- WS-43 
+(14, 1.8, 0.35, 0.35, 20, 60); + +-- 巡飞弹特有参数 +INSERT INTO loitering_munition_params ( + equipment_id, + wingspan_m, + warhead_weight_kg, + max_speed_ms, + cruise_speed_kmh, + flight_time_min, + warhead_type, + launch_mode, + folded_length_mm, + folded_width_mm, + folded_height_mm, + power_system, + guidance_system +) VALUES +-- Hero-120 +(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'), +-- Switchblade 600 +(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'), +-- Warmate +(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'), +-- CH-901 +(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'), +-- HAROP +(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'), +-- Coyote +(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'), +-- WS-43 +(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电'); + +-- 插入成本数据(示例成本) +INSERT INTO cost_data (equipment_id, actual_cost) VALUES +-- 火箭炮 +(1, 800000), -- BM-21 +(2, 4500000), -- SR5 +(3, 5500000), -- HIMARS +(4, 3500000), -- LAR-160 +(5, 2800000), -- T-122 +(6, 1500000), -- RM-70 +(7, 4800000), -- ASTROS II +-- 巡飞弹 +(8, 150000), -- Hero-120 +(9, 180000), -- Switchblade 600 +(10, 80000), -- Warmate +(11, 100000), -- CH-901 +(12, 850000), -- HAROP +(13, 75000), -- Coyote +(14, 120000); -- WS-43 + +-- 创建初始数据集 +INSERT INTO datasets (name, description, equipment_type, purpose) VALUES +('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'), +('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'), +('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'), +('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证'); + +-- 关联装备到数据集 +INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +-- 火箭炮训练集 +(1, 1), (1, 2), (1, 3), (1, 4), +-- 巡飞弹训练集 +(2, 8), (2, 9), (2, 10), (2, 11), (2, 12), +-- 火箭炮验证集 +(3, 5), (3, 6), (3, 7), +-- 巡飞弹验证集 +(4, 13), (4, 14); \ No newline at end 
def setup_logger(name):
    """
    Create (or fetch) a logger that writes INFO-level records to logs/api.log.

    The function is idempotent: if the named logger already has handlers it is
    returned untouched, so calling it from several modules does not produce
    duplicate log lines.

    Args:
        name: Logger name, normally the calling module's ``__name__``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)

    # Already configured by an earlier call: do not attach a second handler.
    if logger.handlers:
        return logger

    logger.setLevel(logging.INFO)

    # The log directory is relative to the current working directory.
    os.makedirs('logs', exist_ok=True)

    # Log messages in this project contain Chinese text; pin the encoding so
    # the file does not depend on the platform's locale default.
    file_handler = logging.FileHandler(os.path.join('logs', 'api.log'), encoding='utf-8')
    file_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)

    return logger
self._create_gbm_model(), + 'rf': self._create_rf_model(), + 'pls': self._create_pls_model() + } + self.best_model = None + self.imputer = SimpleImputer(strategy='mean') + self.feature_scaler = None + self.target_scaler = None + self.equipment_type = None + self.feature_analyzer = FeatureAnalysis() + + def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None): + """ + 训练模型并返回评估结果 + """ + try: + self.equipment_type = equipment_type + logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}") + logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}") + + results = {} + best_score = -float('inf') + best_model_info = None + + # 首先训练 PLS 模型 + logger.info("Training pls...") + pls_model = self.models['pls'] + pls_model.fit(X_train, y_train) + pls_metrics = self._calculate_metrics( + pls_model, + X_train, y_train, + X_val, y_val + ) + results['pls'] = pls_metrics + + # 训练其他机器学习模型 + for model_name in model_names: + if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了 + continue + + if model_name not in self.models: + logger.warning(f"Unknown model: {model_name}") + continue + + logger.info(f"Training {model_name}...") + model = self.models[model_name] + + # 训练模型 + model.fit(X_train, y_train) + + # 计算评估指标 + metrics = self._calculate_metrics( + model, + X_train, y_train, + X_val, y_val + ) + + results[model_name] = metrics + + # 更新最佳模型(只在机器学习模型中比较) + if metrics['validation']['r2'] > best_score: + best_score = metrics['validation']['r2'] + best_model_info = { + 'type': model_name, + 'r2': metrics['validation']['r2'], + 'mae': metrics['validation']['mae'], + 'rmse': metrics['validation']['rmse'] + } + self.best_model = model + + # 保存最佳模型和 PLS 模型 + if equipment_type and best_model_info: + self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val) + + return { + 'metrics': results, + 'best_model': best_model_info + } + + except Exception as e: + logger.error(f"Error in model 
def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
    """
    Compute R2 / MAE / RMSE for a fitted model on the training and validation data.

    When a target scaler is set on the trainer, predictions and targets are
    mapped back to the original scale before scoring. When no validation set
    is given, the validation R2 is the mean of a 5-fold cross-validation on
    the training data, and MAE / RMSE are reported as None.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_train: Training features.
        y_train: Training targets (array-like).
        X_val: Optional validation features.
        y_val: Optional validation targets.

    Returns:
        dict with keys 'train' and 'validation', each holding 'r2', 'mae', 'rmse'.
    """
    from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

    # __init__ always creates the attribute (initially None), so testing with
    # hasattr() was always true and crashed on None.inverse_transform.
    scaler = getattr(self, 'target_scaler', None)

    def _to_original_scale(values):
        """Undo target scaling when a scaler is available; otherwise pass through."""
        if scaler is None:
            return values
        # np.asarray lets pandas Series / lists through; they have no reshape().
        return scaler.inverse_transform(np.asarray(values).reshape(-1, 1)).ravel()

    def _score(actual, predicted):
        """Return the three regression metrics for one data split."""
        return {
            'r2': r2_score(actual, predicted),
            'mae': mean_absolute_error(actual, predicted),
            'rmse': np.sqrt(mean_squared_error(actual, predicted))
        }

    # Training split
    train_pred = _to_original_scale(model.predict(X_train))
    y_train_orig = _to_original_scale(y_train)

    logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
    logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")

    train_metrics = _score(y_train_orig, train_pred)

    # Validation split
    if X_val is not None and y_val is not None:
        val_pred = _to_original_scale(model.predict(X_val))
        y_val_orig = _to_original_scale(y_val)

        logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
        logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")

        val_metrics = _score(y_val_orig, val_pred)
    else:
        # No hold-out data: fall back to cross-validation on the training set.
        cv_scores = cross_val_score(model, X_train, y_train, cv=5)
        val_metrics = {
            'r2': cv_scores.mean(),
            'mae': None,
            'rmse': None
        }

    return {
        'train': train_metrics,
        'validation': val_metrics
    }
def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
    """
    Persist the best machine-learning model, the PLS model and the scalers,
    and register both models as active in the ``trained_models`` table.

    Files are written under ``models/`` with a timestamped name. All existing
    ``trained_models`` rows for the equipment type are deactivated before the
    two new rows are inserted.

    Args:
        equipment_type: Equipment type the models were trained for.
        best_model_info: dict with 'type', 'r2', 'mae', 'rmse' of the best model.
        X_train, y_train: Training data (used to re-evaluate the PLS model and
            to record the training set size).
        X_val, y_val: Optional validation data for the PLS evaluation.

    Raises:
        Exception: Any file or database error is logged and re-raised.
    """
    def _as_float(value):
        """Convert a metric to a plain float, keeping None (metric unavailable) as None."""
        return None if value is None else float(value)

    insert_sql = """
        INSERT INTO trained_models (
            model_name, model_type, equipment_type, model_path, scaler_path,
            r2_score, mae, rmse, feature_importance, training_data_size,
            training_date, is_active, created_by
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
    """

    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_dir = 'models'
        os.makedirs(model_dir, exist_ok=True)

        # 1. Save the best machine-learning model (XGBoost uses its own JSON format).
        model_path = f'{model_dir}/{equipment_type}_{timestamp}'
        if isinstance(self.best_model, xgb.XGBRegressor):
            self.best_model.save_model(f'{model_path}.json')
            model_format = 'json'
        else:
            joblib.dump(self.best_model, f'{model_path}.joblib')
            model_format = 'joblib'

        # 2. Save the PLS model.
        pls_model = self.models['pls']
        pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
        joblib.dump(pls_model, pls_path)

        # 3. Save both scalers in a single file.
        scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
        joblib.dump({
            'feature_scaler': self.feature_scaler,
            'target_scaler': self.target_scaler
        }, scaler_path)

        logger.info(f"Saved best model to {model_path}.{model_format}")
        logger.info(f"Saved PLS model to {pls_path}")
        logger.info(f"Saved scalers to {scaler_path}")

        # Evaluate before touching the database so a failure here cannot
        # leave the table half-updated.
        pls_validation = self._calculate_metrics(
            pls_model, X_train, y_train, X_val, y_val
        )['validation']

        # Kept from the original implementation: the equipment type is tagged
        # onto the in-memory estimator objects (after they were dumped to disk).
        self.best_model.equipment_type = equipment_type
        pls_model.equipment_type = equipment_type

        ml_feature_importance = self._get_feature_importance(self.best_model)
        pls_feature_importance = self._get_feature_importance(pls_model)

        training_size = len(X_train)
        records = [
            (
                f"{equipment_type}_{timestamp}",        # model_name
                best_model_info['type'],                # model_type
                equipment_type,                         # equipment_type
                f"{model_path}.{model_format}",         # model_path
                scaler_path,                            # scaler_path
                _as_float(best_model_info['r2']),       # r2_score
                _as_float(best_model_info['mae']),      # mae
                _as_float(best_model_info['rmse']),     # rmse
                json.dumps(ml_feature_importance),      # feature_importance
                training_size,                          # training_data_size
                'system'                                # created_by
            ),
            (
                f"{equipment_type}_{timestamp}_pls",    # model_name
                'pls',                                  # model_type
                equipment_type,                         # equipment_type
                pls_path,                               # model_path
                scaler_path,                            # scaler_path
                # mae / rmse are None when no validation set was supplied
                # (cross-validation fallback); float(None) used to crash here.
                _as_float(pls_validation['r2']),        # r2_score
                _as_float(pls_validation['mae']),       # mae
                _as_float(pls_validation['rmse']),      # rmse
                json.dumps(pls_feature_importance),     # feature_importance
                training_size,                          # training_data_size
                'system'                                # created_by
            ),
        ]

        # 4. Update the model registry.
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                # Deactivate every previous model of this equipment type.
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s
                """, (equipment_type,))

                for record in records:
                    cursor.execute(insert_sql, record)

                conn.commit()
            finally:
                cursor.close()

    except Exception as e:
        logger.error(f"Error saving models: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        raise
os.path.exists(model_record['scaler_path']): + logger.error(f"Scaler file not found: {model_record['scaler_path']}") + raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}") + + # 加载模型文件 + logger.info(f"Loading model from {model_record['model_path']}") + if model_type == 'pls': + self.best_model = joblib.load(model_record['model_path']) + logger.info("Loaded PLS model") + else: + if model_record['model_type'] == 'xgboost': + self.best_model = xgb.XGBRegressor() + self.best_model.load_model(model_record['model_path']) + logger.info("Loaded XGBoost model") + else: + self.best_model = joblib.load(model_record['model_path']) + logger.info(f"Loaded {model_record['model_type']} model") + + # 加载标准化器 + logger.info(f"Loading scalers from {model_record['scaler_path']}") + scalers = joblib.load(model_record['scaler_path']) + self.feature_scaler = scalers['feature_scaler'] + self.target_scaler = scalers['target_scaler'] + logger.info("Loaded scalers successfully") + + return True + + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + logger.error(f"Detailed traceback:", exc_info=True) + return False + + def predict(self, features): + """ + 使用加载的模型进行预测 + """ + try: + if not self.best_model: + raise ValueError("No model loaded") + + if not self.feature_scaler: + raise ValueError("Feature scaler not loaded") + + if not self.target_scaler: + raise ValueError("Target scaler not loaded") + + logger.info("Starting prediction") + logger.info(f"Input features shape: {features.shape}") + logger.info(f"Input features: \n{features}") + + # 处理缺失值 + features_filled = np.array(features, dtype=float) + features_filled[np.isnan(features_filled)] = 0 + features_filled = np.nan_to_num(features_filled, 0) + + logger.info(f"Filled features: \n{features_filled}") + + # 标准化特征 + X = self.feature_scaler.transform(features_filled) + logger.info(f"Transformed features shape: {X.shape}") + logger.info(f"Transformed features: \n{X}") + + # 预测 + 
def _get_feature_importance(self, model):
    """
    Return the model's feature importances keyed by feature name.

    Tree-based models contribute ``feature_importances_``; linear-style models
    (such as PLS) contribute the absolute value of ``coef_``. Feature names
    come from ``FeatureAnalysis.get_equipment_specific_features`` for the
    trainer's current equipment type and are paired with the importances in
    order.

    Args:
        model: A fitted estimator, or None.

    Returns:
        dict mapping feature name to importance (plain floats), sorted in
        descending order with zero-importance features removed. An empty dict
        is returned when the model is None, exposes neither attribute, or any
        error occurs (the error is logged).
    """
    try:
        # Compare against None explicitly: some estimators define __len__, so
        # plain truthiness is not a reliable "is there a model" test.
        if model is None:
            return {}

        feature_analyzer = FeatureAnalysis()
        feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)

        if hasattr(model, 'feature_importances_'):
            importances = model.feature_importances_
        elif hasattr(model, 'coef_'):
            coefficients = np.abs(np.asarray(model.coef_))
            if coefficients.ndim > 1:
                if 1 in coefficients.shape:
                    # Single-target model: flatten regardless of whether the
                    # layout is (1, n_features) or (n_features, 1). Taking
                    # row 0 of the latter yields one value only.
                    coefficients = coefficients.ravel()
                else:
                    # Multi-target model: keep the original choice of the first row.
                    coefficients = coefficients[0]
            importances = coefficients
        else:
            return {}

        # float() converts numpy scalars so the result is JSON-serialisable.
        importance_dict = {
            name: float(importance)
            for name, importance in zip(feature_names, importances)
        }

        ranked = sorted(importance_dict.items(), key=lambda item: item[1], reverse=True)

        # Drop features that contribute nothing.
        return {name: value for name, value in ranked if value > 0}

    except Exception as e:
        logger.error(f"Error getting feature importance: {str(e)}")
        return {}
def get_model_type(self):
    """
    Return the short type name of the currently held best model.

    Returns:
        One of 'xgboost', 'lightgbm', 'gbm', 'rf', 'pls', or 'unknown' when
        no model is loaded or its class is not recognised.
    """
    # Checked in order; the labels match the keys of self.models and the
    # model_type values written to the trained_models table.
    known_types = (
        (xgb.XGBRegressor, 'xgboost'),
        (lgb.LGBMRegressor, 'lightgbm'),
        (GradientBoostingRegressor, 'gbm'),
        (RandomForestRegressor, 'rf'),
        # load_model(model_type='pls') stores a PLS model in best_model;
        # it was previously reported as 'unknown'.
        (PLSRegression, 'pls'),
    )
    for model_class, label in known_types:
        if isinstance(self.best_model, model_class):
            return label
    return 'unknown'
@api_bp.route('/predict', methods=['POST'])
def predict_cost():
    """
    Cost prediction endpoint.

    Expects a JSON object containing at least the key 'type' (equipment type).
    The prediction from ``CostPredictor.predict`` is returned together with a
    'model_info' entry describing the active non-PLS model for that equipment
    type ('model_info' is null when no such model is registered).

    Returns:
        200 with the prediction result, 400 when the body is not a JSON object
        or lacks 'type', 500 with the error message on any other failure.
    """
    def _as_float(value):
        """Convert a DB metric to float; NULL metrics stay None instead of crashing."""
        return None if value is None else float(value)

    try:
        # silent=True: a missing or malformed body yields None instead of an
        # exception, so it is reported as a client error rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        logger.info(f"Received prediction request for equipment type: {data.get('type')}")

        if 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        predictor = CostPredictor()
        result = predictor.predict(data)

        # Look up the model currently active for this equipment type.
        # NOTE(review): get_db_connection is not among the imports visible at
        # the top of this file — confirm it is imported or defined elsewhere.
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            try:
                cursor.execute("""
                    SELECT model_type, model_name, r2_score, mae, rmse
                    FROM trained_models
                    WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
                    LIMIT 1
                """, (data['type'],))
                model_info = cursor.fetchone()
            finally:
                cursor.close()

        if model_info:
            result.update({
                'model_info': {
                    'type': model_info['model_type'],
                    'name': model_info['model_name'],
                    'r2_score': _as_float(model_info['r2_score']),
                    'mae': _as_float(model_info['mae']),
                    'rmse': _as_float(model_info['rmse'])
                }
            })
        else:
            # The prediction itself succeeded; do not fail the request just
            # because the registry row is missing.
            logger.warning(f"No active model record found for equipment type: {data['type']}")
            result.update({'model_info': None})

        logger.info(f"Prediction completed: {result}")
        return jsonify(result)

    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE de.dataset_id = %s + AND cd.actual_cost IS NOT NULL + """, (dataset_id,)) + + equipment_data = cursor.fetchall() + logger.info(f"Found {len(equipment_data)} equipment records") + + if not equipment_data: + logger.warning("No valid equipment data found in dataset") + return jsonify({'error': '数据集没有有效的成本数据'}), 400 + + # 统计每个特征的缺失率 + missing_rates = {} + for name in feature_names: + missing_count = sum(1 for item in equipment_data if item.get(name) is None) + missing_rate = missing_count / len(equipment_data) + missing_rates[name] = missing_rate + logger.info(f"Feature {name} missing rate: {missing_rate:.2%}") + + # 过滤掉缺失率过高的特征 + valid_features = [name for name in feature_names if missing_rates[name] < 0.7] + logger.info(f"Valid features after filtering: {valid_features}") + + if len(valid_features) < 3: # 至少需要3个特征 + return jsonify({'error': '有效特征数量不足'}), 400 + + # 计算每个特征的均值 + feature_means = {} + for name in valid_features: + values = [float(item[name]) for item in equipment_data if item.get(name) is not None] + feature_means[name] = sum(values) / len(values) if values else 0 + logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}") + + # 准备特征和目标值 + features = [] + target = [] + + # 提取特征和目标值,使用均值填充缺失值 + for item in equipment_data: + feature_values = [] + for name in valid_features: + value = item.get(name) + try: + # 确保数值类型转换正确 + feature_values.append(float(value) if value is not None else feature_means[name]) + except (ValueError, TypeError) as e: + logger.error(f"Error converting value for feature {name}: {value}") + logger.error(f"Error details: {str(e)}") + return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400 + features.append(feature_values) + + # 确保成本值是值类型 + try: + target.append(float(item['actual_cost'])) + except (ValueError, TypeError) as e: + logger.error(f"Error converting actual_cost: {item['actual_cost']}") + 
logger.error(f"Error details: {str(e)}") + return jsonify({'error': '成本值无法换为数值'}), 400 + + logger.info(f"Prepared {len(features)} feature vectors") + logger.info(f"First feature vector: {features[0] if features else None}") + logger.info(f"First target value: {target[0] if target else None}") + + # 调用特征分析方法 + result = analyzer.analyze_features(features, target, valid_features) + logger.info("Analysis completed successfully") + + return jsonify(result) + + except Exception as e: + logger.error(f"Error analyzing features: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + return jsonify({'error': str(e)}), 500 + +@api_bp.route('/train', methods=['POST']) +def train_model(): + """ + 训练模型 + """ + try: + data = request.get_json() + logger.info(f"Starting model training for {data.get('type')}") + equipment_type = data.get('type') + train_dataset_id = data.get('train_dataset_id') + validation_dataset_id = data.get('validation_dataset_id') + models = data.get('models', []) + + logger.info(f"Training dataset: {train_dataset_id}") + logger.info(f"Validation dataset: {validation_dataset_id}") + logger.info(f"Selected models: {models}") + + # 获取训练数据 + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + + # 获取训练集数据 + if equipment_type == '火箭炮': + cursor.execute(""" + SELECT e.*, cp.*, rap.*, cd.actual_cost + FROM equipment e + JOIN dataset_equipment de ON e.id = de.equipment_id + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE de.dataset_id = %s + AND cd.actual_cost IS NOT NULL + """, (train_dataset_id,)) + else: + cursor.execute(""" + SELECT e.*, cp.*, lmp.*, cd.actual_cost + FROM equipment e + JOIN dataset_equipment de ON e.id = de.equipment_id + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id + LEFT JOIN cost_data cd ON e.id = 
cd.equipment_id + WHERE de.dataset_id = %s + AND cd.actual_cost IS NOT NULL + """, (train_dataset_id,)) + + train_data = cursor.fetchall() + + # 获取验证集数据(如果有) + validation_data = None + if validation_dataset_id: + if equipment_type == '火箭炮': + cursor.execute(""" + SELECT e.*, cp.*, rap.*, cd.actual_cost + FROM equipment e + JOIN dataset_equipment de ON e.id = de.equipment_id + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE de.dataset_id = %s + AND cd.actual_cost IS NOT NULL + """, (validation_dataset_id,)) + else: + cursor.execute(""" + SELECT e.*, cp.*, lmp.*, cd.actual_cost + FROM equipment e + JOIN dataset_equipment de ON e.id = de.equipment_id + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE de.dataset_id = %s + AND cd.actual_cost IS NOT NULL + """, (validation_dataset_id,)) + validation_data = cursor.fetchall() + + if not train_data: + return jsonify({'error': '训练数据集为空'}), 400 + + # 1. 准备数据 + data_processor = DataPreparation() + + # 准备训练数据 + train_prepared = data_processor.prepare_training_data(train_data, equipment_type) + + # 准备验证数据(如果有) + validation_prepared = None + if validation_data: + validation_prepared = data_processor.prepare_validation_data( + validation_data, + equipment_type, + train_prepared['feature_names'], + { + 'feature_scaler': train_prepared['feature_scaler'], + 'target_scaler': train_prepared['target_scaler'] + } + ) + + # 2. 
def get_required_params(equipment_type):
    """Return the list of required parameter names for an equipment type.

    Every type needs the five shared dimension/weight/range parameters.
    '火箭炮' (rocket artillery) and '巡飞弹' (loitering munition) each append
    their own type-specific names; any other value gets the shared list only.
    """
    shared = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']

    type_specific = (
        ('火箭炮', (
            'firing_angle_horizontal',
            'firing_angle_vertical',
            'rocket_length_m',
            'rocket_diameter_mm',
            'rocket_weight_kg',
        )),
        ('巡飞弹', (
            'max_speed_kmh',
            'cruise_speed_kmh',
            'flight_time_min',
            'folded_length_mm',
            'folded_width_mm',
            'folded_height_mm',
        )),
    )

    # Compared with == rather than looked up in a dict, so unhashable inputs
    # simply fall through to the shared list.
    for kind, extras in type_specific:
        if equipment_type == kind:
            return shared + list(extras)

    return shared
internal_error(error): + logger.error(f"Internal server error: {str(error)}") + return jsonify({'error': 'Internal server error'}), 500 + +@api_bp.route('/data', methods=['GET']) +def get_equipment_data(): + """ + 获取装备数据 + """ + try: + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + cursor.execute('SET SESSION group_concat_max_len = 1000000') + + # 先测试特殊参数查询 + cursor.execute(""" + SELECT equipment_id, param_name, param_value, param_unit + FROM custom_params + WHERE param_name IS NOT NULL + AND param_value IS NOT NULL + LIMIT 5 + """) + test_params = cursor.fetchall() + logger.info(f"Test custom params: {test_params}") + + # 获取火箭炮数据 + logger.info("Fetching rocket artillery data...") + cursor.execute(""" + SELECT + e.id, + e.name, + e.type, + e.manufacturer, + e.created_at, + cp.length_m, + cp.width_m, + cp.height_m, + cp.weight_kg, + cp.max_range_km, + rap.firing_angle_horizontal, + rap.firing_angle_vertical, + rap.rocket_length_m, + rap.rocket_diameter_mm, + rap.rocket_weight_kg, + rap.rate_of_fire, + rap.combat_weight_kg, + rap.speed_kmh, + rap.min_range_km, + rap.mobility_type, + rap.structure_layout, + rap.engine_model, + rap.engine_params, + rap.power_hp, + rap.travel_range_km, + cd.actual_cost, + ( + SELECT COALESCE( + JSON_ARRAYAGG( + JSON_OBJECT( + 'id', csp.id, + 'param_name', csp.param_name, + 'param_value', csp.param_value, + 'param_unit', csp.param_unit, + 'description', csp.description + ) + ), + '[]' + ) + FROM custom_params csp + WHERE csp.equipment_id = e.id + AND csp.param_name IS NOT NULL + AND csp.param_value IS NOT NULL + ) as custom_params + FROM equipment e + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE e.type = '火箭炮' + """) + rocket_artillery = cursor.fetchall() + logger.info(f"Found {len(rocket_artillery)} rocket artillery records") + if rocket_artillery: + logger.info(f"First rocket 
artillery: {rocket_artillery[0]['name']}") + logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}") + + # 获取巡飞弹数据 + logger.info("Fetching missile data...") + cursor.execute(""" + SELECT + e.id, + e.name, + e.type, + e.manufacturer, + e.created_at, + cp.length_m, + cp.width_m, + cp.height_m, + cp.weight_kg, + cp.max_range_km, + lmp.wingspan_m, + lmp.warhead_weight_kg, + lmp.max_speed_ms, + lmp.cruise_speed_kmh, + lmp.flight_time_min, + lmp.warhead_type, + lmp.launch_mode, + lmp.folded_length_mm, + lmp.folded_width_mm, + lmp.folded_height_mm, + lmp.power_system, + lmp.guidance_system, + cd.actual_cost, + ( + SELECT COALESCE( + JSON_ARRAYAGG( + JSON_OBJECT( + 'id', csp.id, + 'param_name', csp.param_name, + 'param_value', csp.param_value, + 'param_unit', csp.param_unit, + 'description', csp.description + ) + ), + '[]' + ) + FROM custom_params csp + WHERE csp.equipment_id = e.id + AND csp.param_name IS NOT NULL + AND csp.param_value IS NOT NULL + ) as custom_params + FROM equipment e + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE e.type = '巡飞弹' + """) + loitering_munition = cursor.fetchall() + logger.info(f"Found {len(loitering_munition)} missile records") + if loitering_munition: + logger.info(f"First missile: {loitering_munition[0]['name']}") + logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}") + + # 处理 custom_params,保为 NULL + for item in rocket_artillery + loitering_munition: + if item['custom_params'] is None: + item['custom_params'] = [] + logger.debug(f"Set empty custom_params for equipment {item['id']}") + else: + logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params") + + logger.info("Data fetching completed") + return jsonify({ + 'rocket_artillery': rocket_artillery, + 'loitering_munition': loitering_munition + }) + + except 
@api_bp.route('/data/<int:id>', methods=['DELETE'])
def delete_equipment(id):
    """Delete one equipment record together with every row that references it.

    Args:
        id: Primary key of the ``equipment`` row, taken from the URL. The
            route needs a URL variable for Flask to supply this argument.

    Returns:
        ``{'status': 'success'}`` as JSON, or a 500 JSON error on failure.
        A non-existent id is not treated as an error.
    """
    # Child tables are cleared first. custom_params and dataset_equipment are
    # included because both declare a foreign key to equipment(id); leaving
    # their rows in place makes the final DELETE fail.
    child_deletes = (
        "DELETE FROM cost_data WHERE equipment_id = %s",
        "DELETE FROM custom_params WHERE equipment_id = %s",
        "DELETE FROM dataset_equipment WHERE equipment_id = %s",
        "DELETE FROM rocket_artillery_params WHERE equipment_id = %s",
        "DELETE FROM loitering_munition_params WHERE equipment_id = %s",
        "DELETE FROM common_params WHERE equipment_id = %s",
    )

    try:
        # Same connection handling as the other handlers in this module, so
        # the connection is released even when a statement raises.
        with get_db_connection() as conn:
            cursor = conn.cursor()

            for statement in child_deletes:
                cursor.execute(statement, (id,))
            cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))

            conn.commit()

        return jsonify({'status': 'success'})

    except Exception as e:
        logger.error(f"Error deleting equipment: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/import', methods=['POST'])
def import_data():
    """Accept an uploaded Excel workbook and import it as training data.

    The file is read from the multipart field ``file``, saved into the
    ``data`` directory one level above this package, and then passed to
    ``import_training_data``.

    Returns:
        A JSON success message; 400 when no file was sent or it is not an
        ``.xls``/``.xlsx`` file; 500 on any other failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': '没有上传文件'}), 400

        upload = request.files['file']

        # The filename is client-controlled. Keep only its final component
        # (treating backslashes as separators too) so that values such as
        # "../../x.xlsx" or an absolute path cannot escape the upload
        # directory. A plain basename is used because it preserves non-ASCII
        # (e.g. Chinese) file names.
        raw_name = upload.filename or ''
        safe_name = os.path.basename(raw_name.replace('\\', '/'))

        # Case-insensitive so that "DATA.XLSX" is accepted as well.
        if not safe_name.lower().endswith(('.xls', '.xlsx')):
            return jsonify({'error': '请上传Excel文件'}), 400

        # Save the uploaded file.
        upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        os.makedirs(upload_dir, exist_ok=True)
        file_path = os.path.join(upload_dir, safe_name)
        upload.save(file_path)

        # Import the saved workbook.
        from .import_data import import_training_data
        import_training_data(file_path)

        return jsonify({
            'success': True,
            'message': '数据导入成功'
        })

    except Exception as e:
        logger.error(f"Error importing data: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
data['folded_width_mm'], data['folded_height_mm'], id + )) + logger.info("Missile params updated") + + # 更新成本数据 + if 'actual_cost' in data: + cursor.execute(""" + UPDATE cost_data + SET actual_cost = %s + WHERE equipment_id = %s + """, (data['actual_cost'], id)) + logger.info("Cost data updated") + + # 更新特殊参数 + if 'custom_params' in data and data['custom_params']: + logger.info(f"Updating custom params: {data['custom_params']}") + for param in data['custom_params']: + cursor.execute(""" + UPDATE custom_params + SET param_value = %s + WHERE id = %s AND equipment_id = %s + """, (param['param_value'], param['id'], id)) + logger.info("Custom params updated") + + conn.commit() + logger.info("All updates committed successfully") + + return jsonify({'success': True}) + + except Exception as e: + logger.error(f"Error updating equipment: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + return jsonify({'error': str(e)}), 500 + +@api_bp.route('/data/details/', methods=['GET']) +def get_equipment_details(id): + """ + 获取装备详数据 + """ + try: + logger.info(f"Getting details for equipment ID: {id}") + + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + + # 先获取装备类型 + cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,)) + equipment = cursor.fetchone() + + if not equipment: + logger.warning(f"Equipment not found: {id}") + return jsonify({'error': 'Equipment not found'}), 404 + + equipment_type = equipment['type'] + logger.info(f"Equipment type: {equipment_type}") + + # 根据装备类型选择查询 + if equipment_type == '火箭炮': + query = """ + SELECT + e.*, + cp.*, + rap.*, + cd.actual_cost, + cd.prediction_date as cost_estimate_date, + cd.predicted_cost, + ( + SELECT JSON_ARRAYAGG( + CASE + WHEN csp.id IS NOT NULL THEN + JSON_OBJECT( + 'id', csp.id, + 'param_name', csp.param_name, + 'param_value', csp.param_value, + 'param_unit', csp.param_unit, + 'description', csp.description + ) + END + ) + FROM custom_params csp + WHERE csp.equipment_id = e.id 
+ AND csp.param_name IS NOT NULL + AND csp.param_value IS NOT NULL + ) as custom_params + FROM equipment e + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE e.id = %s + """ + else: + query = """ + SELECT + e.*, + cp.*, + lmp.*, + cd.actual_cost, + cd.prediction_date as cost_estimate_date, + cd.predicted_cost, + ( + SELECT JSON_ARRAYAGG( + CASE + WHEN csp.id IS NOT NULL THEN + JSON_OBJECT( + 'id', csp.id, + 'param_name', csp.param_name, + 'param_value', csp.param_value, + 'param_unit', csp.param_unit, + 'description', csp.description + ) + END + ) + FROM custom_params csp + WHERE csp.equipment_id = e.id + AND csp.param_name IS NOT NULL + AND csp.param_value IS NOT NULL + ) as custom_params + FROM equipment e + LEFT JOIN common_params cp ON e.id = cp.equipment_id + LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id + LEFT JOIN cost_data cd ON e.id = cd.equipment_id + WHERE e.id = %s + """ + + cursor.execute(query, (id,)) + result = cursor.fetchone() + + if result: + logger.info(f"Found equipment details: {result['name']}") + logger.info(f"Custom params: {result.get('custom_params')}") + + return jsonify(result) + + except Exception as e: + logger.error(f"Error getting equipment details: {str(e)}") + return jsonify({'error': str(e)}), 500 + +# 添加数据集相关的路由 +@api_bp.route('/datasets', methods=['GET']) +def get_datasets(): + """ + 获取数据集列表 + """ + try: + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + cursor.execute(""" + SELECT d.*, + COUNT(de.equipment_id) as equipment_count, + GROUP_CONCAT(e.name) as equipment_names + FROM datasets d + LEFT JOIN dataset_equipment de ON d.id = de.dataset_id + LEFT JOIN equipment e ON de.equipment_id = e.id + GROUP BY d.id + """) + datasets = cursor.fetchall() + + # 理装备名称列表 + for dataset in datasets: + if dataset['equipment_names']: + dataset['equipment_names'] = 
@api_bp.route('/datasets/<int:id>', methods=['GET'])
def get_dataset(id):
    """Return one dataset with its equipment rows and cost statistics.

    Args:
        id: Primary key of the ``datasets`` row, taken from the URL. The
            route needs a URL variable for Flask to supply this argument.

    Returns:
        The dataset row as JSON, extended with ``statistics``
        (``equipment_count``, ``total_cost``, ``average_cost``) and
        ``equipment`` (the member rows with ``actual_cost``). Responds 404
        when the dataset does not exist and 500 on any other failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Basic dataset information.
            cursor.execute("""
                SELECT d.*,
                       COUNT(de.equipment_id) as equipment_count
                FROM datasets d
                LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
                WHERE d.id = %s
                GROUP BY d.id
            """, (id,))
            dataset = cursor.fetchone()

            if not dataset:
                return jsonify({'error': 'Dataset not found'}), 404

            # Equipment belonging to the dataset.
            cursor.execute("""
                SELECT e.*, cd.actual_cost
                FROM equipment e
                JOIN dataset_equipment de ON e.id = de.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE de.dataset_id = %s
            """, (id,))
            equipment = cursor.fetchall()

            # Statistics. Rows without a cost count as 0. The sum is turned
            # into a float so the JSON value is a number whether or not the
            # dataset has members (the empty case already yields plain 0).
            count = len(equipment)
            total_cost = float(sum(item['actual_cost'] or 0 for item in equipment)) if count else 0
            average_cost = total_cost / count if count else 0
            dataset['statistics'] = {
                'equipment_count': count,
                'total_cost': total_cost,
                'average_cost': average_cost
            }

            dataset['equipment'] = equipment
            return jsonify(dataset)
    except Exception as e:
        logger.error(f"Error getting dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/datasets/<int:id>', methods=['PUT'])
def update_dataset(id):
    """Update a dataset's metadata and replace its equipment membership.

    Args:
        id: Primary key of the ``datasets`` row, taken from the URL. The
            route needs a URL variable for Flask to supply this argument.

    The JSON body must contain ``name``, ``description``, ``equipment_type``
    and ``purpose``. ``equipment_ids`` is optional; existing links are always
    removed and the supplied ids, if any, are inserted in their place.

    Returns:
        ``{'success': True}`` as JSON, or a 500 JSON error on failure.
    """
    try:
        data = request.get_json()
        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Update the dataset's basic information.
            cursor.execute("""
                UPDATE datasets
                SET name = %s, description = %s, equipment_type = %s, purpose = %s
                WHERE id = %s
            """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))

            # Remove the old equipment links.
            cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))

            # Insert the new links in one batch, as create_dataset does.
            # A missing or null list means "no members".
            links = [(id, equipment_id) for equipment_id in (data.get('equipment_ids') or [])]
            if links:
                cursor.executemany("""
                    INSERT INTO dataset_equipment (dataset_id, equipment_id)
                    VALUES (%s, %s)
                """, links)

            conn.commit()
            return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error updating dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
def activate_model(id):
    """Mark one trained model as active and deactivate the ones it replaces.

    Args:
        id: Primary key of the ``trained_models`` row, taken from the URL.
            The route needs a URL variable for Flask to supply this argument.

    Only models of the same equipment type AND the same family (PLS or
    non-PLS) are deactivated. The prediction endpoints in this module look up
    an active row separately for ``model_type = 'pls'`` and for
    ``model_type != 'pls'``, so clearing both families would leave one of
    them without an active model.
    NOTE(review): assumes one active PLS model and one active non-PLS model
    per equipment type is the intended state - confirm against ModelTrainer.

    Returns:
        ``{'success': True}`` as JSON, 404 when the model does not exist,
        or a 500 JSON error on failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Fetch the model's equipment type and family.
            cursor.execute("""
                SELECT equipment_type, model_type FROM trained_models
                WHERE id = %s
            """, (id,))
            model = cursor.fetchone()

            if not model:
                return jsonify({'error': 'Model not found'}), 404

            equipment_type, model_type = model

            # Deactivate the other models of the same type and family.
            if model_type == 'pls':
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND model_type = 'pls'
                """, (equipment_type,))
            else:
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND model_type != 'pls'
                """, (equipment_type,))

            # Activate the requested model.
            cursor.execute("""
                UPDATE trained_models
                SET is_active = TRUE
                WHERE id = %s
            """, (id,))

            conn.commit()
            return jsonify({'success': True})

    except Exception as e:
        logger.error(f"Error activating model: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models/<int:id>', methods=['DELETE'])
def delete_model(id):
    """Delete a trained model's database record and its files on disk.

    Args:
        id: Primary key of the ``trained_models`` row, taken from the URL.
            The route needs a URL variable for Flask to supply this argument.

    The database row is removed and committed before the files are touched,
    so a failure while deleting files can never leave a row that points at
    files which no longer exist. A file that cannot be removed is logged and
    left behind.

    Returns:
        ``{'success': True}`` as JSON, 404 when the model does not exist,
        or a 500 JSON error on failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Look up the paths of the model and scaler files.
            cursor.execute("""
                SELECT model_path, scaler_path
                FROM trained_models
                WHERE id = %s
            """, (id,))
            model = cursor.fetchone()

            if not model:
                return jsonify({'error': 'Model not found'}), 404

            # Remove the database record first.
            cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
            conn.commit()

        # Then remove the files; a file that is already gone is not an error.
        for path in model:
            if not path:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as file_error:
                logger.warning(f"Could not remove model file {path}: {str(file_error)}")

        return jsonify({'success': True})

    except Exception as e:
        logger.error(f"Error deleting model: {str(e)}")
        return jsonify({'error': str(e)}), 500
PRIMARY KEY, + equipment_id INT, + length_m FLOAT, -- 总长(m) + width_m FLOAT, -- 宽度(m) + height_m FLOAT, -- 高度(m) + weight_kg FLOAT, -- 重量(kg) + max_range_km FLOAT, -- 最大射程(km) + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 火箭炮特有参数表 +CREATE TABLE rocket_artillery_params ( + id INT AUTO_INCREMENT PRIMARY KEY, + equipment_id INT, + firing_angle_horizontal FLOAT, -- 方向射界(度) + firing_angle_vertical FLOAT, -- 高低射界(度) + rocket_length_m FLOAT, -- 火箭弹长度(m) + rocket_diameter_mm FLOAT, -- 弹体直径(mm) + rocket_weight_kg FLOAT, -- 火箭弹重量(kg) + rate_of_fire FLOAT, -- 射速(发/分钟) + combat_weight_kg FLOAT, -- 战斗重量(kg) + speed_kmh FLOAT, -- 速度(km/h) + min_range_km FLOAT, -- 最小射程(km) + mobility_type VARCHAR(50), -- 行走方式 + structure_layout VARCHAR(100), -- 结构布局 + engine_model VARCHAR(100), -- 发动机型号 + engine_params TEXT, -- 发动机参数 + power_hp FLOAT, -- 功率(hp) + travel_range_km FLOAT, -- 行程(km) + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 巡飞弹特有参数表 +CREATE TABLE loitering_munition_params ( + id INT AUTO_INCREMENT PRIMARY KEY, + equipment_id INT, + wingspan_m FLOAT, -- 翼展(m) + warhead_weight_kg FLOAT, -- 战斗部重量(kg) + max_speed_ms FLOAT, -- 最大速度(m/s) + cruise_speed_kmh FLOAT, -- 巡航速度(km/h) + flight_time_min FLOAT, -- 巡飞时间(min) + warhead_type VARCHAR(50), -- 战斗部类型 + launch_mode VARCHAR(50), -- 发射方式 + folded_length_mm FLOAT, -- 折叠长度(mm) + folded_width_mm FLOAT, -- 折叠宽度(mm) + folded_height_mm FLOAT, -- 折叠高度(mm) + power_system VARCHAR(100), -- 动力装置 + guidance_system VARCHAR(100), -- 制导体制 + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 成本数据表 +CREATE TABLE cost_data ( + id INT AUTO_INCREMENT PRIMARY KEY, + equipment_id INT, + actual_cost DECIMAL(15,2), -- 实际成本(元) + predicted_cost DECIMAL(15,2), -- 预测成本(元) + prediction_date TIMESTAMP, -- 预测日期 + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 
特殊参数表 +CREATE TABLE custom_params ( + id INT AUTO_INCREMENT PRIMARY KEY, + equipment_id INT, + param_name VARCHAR(100), -- 参数名称 + param_value VARCHAR(255), -- 参数值 + param_unit VARCHAR(50), -- 参数单位 + description TEXT, -- 参数说明 + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 添加索引 +CREATE INDEX idx_equipment_type ON equipment(type); +CREATE INDEX idx_equipment_name ON equipment(name); +CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id); + +-- 数据集表 +CREATE TABLE datasets ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, -- 数据集名称 + description TEXT, -- 数据集描述 + equipment_type VARCHAR(50) NOT NULL, -- 装备类型 + purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证) + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +); + +-- 数据集-装备关联表 +CREATE TABLE dataset_equipment ( + dataset_id INT NOT NULL, + equipment_id INT NOT NULL, + PRIMARY KEY (dataset_id, equipment_id), + FOREIGN KEY (dataset_id) REFERENCES datasets(id), + FOREIGN KEY (equipment_id) REFERENCES equipment(id) +); + +-- 训练模型表 +CREATE TABLE trained_models ( + id INT AUTO_INCREMENT PRIMARY KEY, + model_name VARCHAR(100) NOT NULL, -- 模型名称 + model_type VARCHAR(50) NOT NULL, -- 模型类型 + equipment_type VARCHAR(50) NOT NULL, -- 装备类型 + model_path VARCHAR(255) NOT NULL, -- 模型文件路径 + scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径 + r2_score FLOAT, -- R²分数 + mae FLOAT, -- 平均绝对误差 + rmse FLOAT, -- 均方根误差 + feature_importance JSON, -- 特征重要性 + training_data_size INT, -- 训练数据量 + training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间 + is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型 + created_by VARCHAR(50) -- 创建者 +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 添加索引 +CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type); +CREATE INDEX idx_model_active ON trained_models(is_active); \ No newline at end of file diff --git a/deploy/equipment_cost_system/src/test_api.py 
b/deploy/equipment_cost_system/src/test_api.py new file mode 100644 index 0000000..f68d88d --- /dev/null +++ b/deploy/equipment_cost_system/src/test_api.py @@ -0,0 +1,191 @@ +import requests +import json +import logging +from datetime import datetime +import os +import sys + +# 添加项目根目录到 Python 路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# 配置基本日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('logs/test_api.log'), + logging.StreamHandler() + ] +) + +logger = logging.getLogger(__name__) + +def test_api_endpoints(): + """ + 测试 API 各个端点 + """ + base_url = 'http://localhost:5001/api' + + try: + # 1. 测试根路由 + logger.info("\n1. 测试 API 根路由") + response = requests.get(f'{base_url}/') + print_response(response, "API 根路由") + + # 2. 测试机器学习预测接口 + logger.info("\n2. 测试机器学习预测接口") + predict_data = { + "type": "巡飞弹", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "folded_length_mm": 1300, + "folded_width_mm": 230, + "folded_height_mm": 230, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "凭自身动力起飞" + } + + response = requests.post( + f'{base_url}/predict', + json=predict_data + ) + print_response(response, "机器学习预测") + + # 3. 测试 PLS 预测接口 + logger.info("\n3. 测试 PLS 预测接口") + response = requests.post( + f'{base_url}/pls/predict', + json=predict_data + ) + print_response(response, "PLS 预测") + + # 4. 测试特征分析接口 + logger.info("\n4. 测试特征分析接口") + analysis_data = { + "dataset_id": 1, # 假设数据集 ID 为 1 + "equipment_type": "巡飞弹" + } + + response = requests.post( + f'{base_url}/analyze-features', + json=analysis_data + ) + print_response(response, "特征分析") + + # 5. 测试机器学习模型训练接口 + logger.info("\n5. 
测试机器学习模型训练接口") + training_data = { + "type": "巡飞弹", + "train_dataset_id": 1, # 训练数据集 ID + "validation_dataset_id": 2, # 验证数据集 ID + "models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型 + } + + response = requests.post( + f'{base_url}/train', + json=training_data + ) + print_response(response, "模型训练") + + # 6. 测试数据集相关接口 + logger.info("\n6. 测试数据集相关接口") + + # 6.1 获取数据集列表 + response = requests.get(f'{base_url}/datasets') + print_response(response, "获取数据集列表") + + # 6.2 获取可用的装备列表 + response = requests.get(f'{base_url}/data') + equipment_data = response.json() + + # 获取巡飞弹类型的装备ID + available_equipment_ids = [] + if 'loitering_munition' in equipment_data: + available_equipment_ids = [ + item['id'] + for item in equipment_data['loitering_munition'] + if item['id'] is not None + ][:3] # 取前3个可用的ID + + if not available_equipment_ids: + logger.warning("没有找到可用的装备ID,跳过创建数据集测试") + else: + # 6.3 创建新数据集 + new_dataset = { + "name": "测试数据集", + "description": "用于测试的数据集", + "equipment_type": "巡飞弹", + "purpose": "训练", + "equipment_ids": available_equipment_ids + } + + logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}") + + response = requests.post( + f'{base_url}/datasets', + json=new_dataset + ) + print_response(response, "创建数据集") + + # 7. 测试模型相关接口 + logger.info("\n7. 测试模型相关接口") + + # 7.1 获取模型列表 + response = requests.get(f'{base_url}/models') + print_response(response, "获取模型列表") + + # 7.2 获取最新模型 + response = requests.get(f'{base_url}/models/巡飞弹/latest') + print_response(response, "获取最新模型") + + # 8. 测试多模型预测接口 + logger.info("\n8. 
测试多模型预测接口") + response = requests.post( + f'{base_url}/predict/all', + json=predict_data + ) + print_response(response, "多模型预测") + + logger.info("所有测试完成") + + except requests.exceptions.RequestException as e: + logger.error(f"API 请求错误: {str(e)}") + except Exception as e: + logger.error(f"测试过程中出现错误: {str(e)}") + +def print_response(response, endpoint_name): + """ + 打印响应结果 + """ + try: + logger.info(f"\n=== {endpoint_name} 测试结果 ===") + logger.info(f"状态码: {response.status_code}") + + if response.status_code == 200: + result = response.json() + logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}") + else: + logger.error(f"错误响应:\n{response.text}") + + except Exception as e: + logger.error(f"处理响应时出错: {str(e)}") + +if __name__ == "__main__": + try: + # 确保日志目录存在 + os.makedirs('logs', exist_ok=True) + + logger.info(f"=== API 测试开始 - {datetime.now()} ===") + test_api_endpoints() + logger.info(f"=== API 测试结束 - {datetime.now()} ===") + except Exception as e: + logger.error(f"测试执行失败: {str(e)}") \ No newline at end of file diff --git a/deploy/equipment_cost_system/venv b/deploy/equipment_cost_system/venv new file mode 100644 index 0000000..e69de29 diff --git a/deploy/setup.md b/deploy/setup.md new file mode 100644 index 0000000..b905ab0 --- /dev/null +++ b/deploy/setup.md @@ -0,0 +1,34 @@ +# 安装说明 + +## 1. 系统要求 + +- Linux 服务器 (推荐 Ubuntu 20.04+) +- Python 3.8+ +- MySQL 8.0+ + +## 2. 安装部署 + +```bash +# 解压部署包 +tar -xzf equipment_cost_system.tar.gz +cd equipment_cost_system + +# 运行安装脚本 +bash scripts/install.sh + +# 修改配置文件 +vim config/.env + +# 启动服务 +bash scripts/start.sh +``` + +## 3. 
验证部署 + +```bash +# 检查服务状态 +curl http://localhost:5001/api/ + +# 检查日志 +tail -f logs/api.log +``` diff --git a/docs/deploy/api.md b/docs/deploy/api.md new file mode 100644 index 0000000..000bd35 --- /dev/null +++ b/docs/deploy/api.md @@ -0,0 +1,663 @@ +# 装备成本估算系统 API 文档 + +这个 API 文档提供了完整的接口说明,包括: + +- 每个端点的详细描述 +- 请求和响应的具体示例 +- 清晰的参数格式要求 +- 统一的错误处理说明 +- 重要的注意事项 + +文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。 + +## 基本信息 + +- 基础URL: `http://localhost:5001/api` +- 版本: 1.0.0 +- 响应格式: JSON + +## API 端点列表 + +### 1. 获取 API 信息 + +获取 API 版本信息和可用端点列表。 + +- **URL**: `/` +- **方法**: `GET` +- **响应示例**: +json +{ +"name": "装备成本估算系统 API", +"version": "1.0.0", +"endpoints": { +"predict": { +"url": "/api/predict", +"method": "POST", +"description": "成本预测" +}, +"analyze-features": { +"url": "/api/analyze-features", +"method": "POST", +"description": "特征分析" +}, +"train": { +"url": "/api/train", +"method": "POST", +"description": "模型训练" +}, +"evaluate": { +"url": "/api/evaluate", +"method": "POST", +"description": "模型评估" +} +} +} + +### 2. 单模型预测 + +使用当前激活的最优模型进行成本预测。 + +- **URL**: `/predict` +- **方法**: `POST` +- **请求体示例** (巡飞弹): + +```json +{ + "type": "巡飞弹", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "folded_length_mm": 1300, + "folded_width_mm": 230, + "folded_height_mm": 230, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "凭自身动力起飞" +} +``` + +- **响应示例**: + +```json +{ + "predicted_cost": 150000.0, + "model_info": { + "type": "xgboost", + "name": "巡飞弹_20241111_model", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "confidence_interval": { + "lower": 135000.0, + "upper": 165000.0 + } +} +``` + +### 3. PLS 模型预测 + +使用 PLS 回归模型进行预测。 + +- **URL**: `/pls/predict` +- **方法**: `POST` +- **请求体**: 与单模型预测相同 +- **响应示例**: + +```json +{ + "predicted_cost": 148000.0, + "confidence_interval": { + "lower": 133000.0, + "upper": 163000.0 + } +} +``` + +### 4. 
多模型预测 + +使用所有激活的模型进行预测并返回综合结果。 + +- **URL**: `/predict/all` +- **方法**: `POST` +- **请求体**: 与单模型预测相同 +- **响应示例**: + +```json +{ + "individual_predictions": { + "xgboost": { + "predicted_cost": 150000.0, + "model_info": { + "name": "巡飞弹_xgboost_model", + "type": "xgboost", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "confidence_interval": { + "lower": 135000.0, + "upper": 165000.0 + } + }, + "pls": { + "predicted_cost": 148000.0, + "model_info": { + "name": "巡飞弹_pls_model", + "type": "pls", + "r2_score": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + }, + "confidence_interval": { + "lower": 133000.0, + "upper": 163000.0 + } + } + }, + "ensemble_prediction": { + "predicted_cost": 149000.0, + "standard_deviation": 1414.21, + "confidence_interval": { + "lower": 146228.15, + "upper": 151771.85 + } + } +} +``` + +### 5. 特征分析 + +分析数据集中特征的重要性和相关性。 + +- **URL**: `/analyze-features` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "dataset_id": 1, + "equipment_type": "巡飞弹" +} +``` + +- **响应示例**: + +```json +{ + "important_features": [ + { + "name": "最大射程(km)", + "importance": 0.35 + }, + { + "name": "重量(kg)", + "importance": 0.25 + } + ], + "correlation_analysis": { + "features": ["最大射程(km)", "重量(kg)"], + "matrix": [[1.0, 0.8], [0.8, 1.0]] + } +} +``` + +### 6. 模型训练 + +训练新的模型。 + +- **URL**: `/train` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "type": "巡飞弹", + "train_dataset_id": 1, + "validation_dataset_id": 2, + "models": ["xgboost", "lightgbm", "rf"] +} +``` + +- **响应示例**: + +```json +{ + "metrics": { + "xgboost": { + "train": { + "r2": 0.95, + "mae": 5000.0, + "rmse": 7500.0 + }, + "validation": { + "r2": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + } + } + }, + "best_model": { + "type": "xgboost", + "r2": 0.92, + "mae": 5500.0, + "rmse": 8000.0 + } +} +``` + +### 7. 
数据集管理 + +#### 7.1 获取数据集列表 + +- **URL**: `/datasets` +- **方法**: `GET` +- **响应示例**: + +```json +[ + { + "id": 1, + "name": "训练数据集", + "description": "用于训练的数据集", + "equipment_type": "巡飞弹", + "equipment_count": 10, + "equipment_names": ["设备1", "设备2"], + "purpose": "训练", + "created_at": "2024-11-11T10:00:00" + } +] +``` + +#### 7.2 获取数据集详情 + +- **URL**: `/datasets/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 1, + "name": "训练数据集", + "description": "用于训练的数据集", + "equipment_type": "巡飞弹", + "purpose": "训练", + "created_at": "2024-11-11T10:00:00", + "equipment": [ + { + "id": 1, + "name": "设备1", + "type": "巡飞弹", + "manufacturer": "制造商1", + "actual_cost": 150000 + } + ], + "statistics": { + "equipment_count": 10, + "total_cost": 1500000, + "average_cost": 150000 + } +} +``` + +#### 7.3 创建数据集 + +- **URL**: `/datasets` +- **方法**: `POST` +- **请求体示例**: + +```json +{ + "name": "测试数据集", + "description": "用于测试的数据集", + "equipment_type": "巡飞弹", + "purpose": "训练", + "equipment_ids": [1, 2, 3] +} +``` + +- **响应示例**: + +```json +{ + "id": 2, + "message": "数据集创建成功" +} +``` + +#### 7.4 更新数据集 + +- **URL**: `/datasets/{id}` +- **方法**: `PUT` +- **请求体示例**: + +```json +{ + "name": "更新后的数据集名称", + "description": "更新后的描述", + "equipment_type": "巡飞弹", + "purpose": "验证", + "equipment_ids": [1, 2, 3, 4] +} +``` + +- **响应示例**: + +```json +{ + "success": true, + "message": "数据集更新成功" +} +``` + +#### 7.5 删除数据集 + +- **URL**: `/datasets/{id}` +- **方法**: `DELETE` +- **描述**: 删除指定的数据集及其关联关系 +- **响应示例**: + +```json +{ + "success": true, + "message": "数据集删除成功" +} +``` + +注意事项: + +1. 数据集删除后不会删除关联的装备数据 +2. 不能删除正在被模型使用的数据集 +3. 更新数据集时会重新计算统计信息 +4. 数据集的装备类型一旦创建后不能更改 + +### 8. 
模型管理 + +#### 8.1 获取模型列表 + +- **URL**: `/models` +- **方法**: `GET` +- **响应示例**: + +```json +[ + { + "id": 1, + "model_name": "巡飞弹_xgboost_model", + "model_type": "xgboost", + "equipment_type": "巡飞弹", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0, + "is_active": true, + "training_date": "2024-11-11T10:00:00" + } +] +``` + +#### 8.2 获取最新模型 + +- **URL**: `/models/{equipment_type}/latest` +- **方法**: `GET` +- **响应示例**: 与模型列表的单个模型格式相同 + +#### 8.3 获取模型详情 + +- **URL**: `/models/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 1, + "model_name": "巡飞弹_xgboost_model", + "model_type": "xgboost", + "equipment_type": "巡飞弹", + "r2_score": 0.95, + "mae": 5000.0, + "rmse": 7500.0, + "is_active": true, + "training_date": "2024-11-11T10:00:00", + "feature_importance": { + "max_range_km": 0.35, + "weight_kg": 0.25, + "length_m": 0.20 + }, + "training_data_size": 100, + "created_by": "system" +} +``` + +#### 8.4 激活模型 + +- **URL**: `/models/{id}/activate` +- **方法**: `POST` +- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态 +- **响应示例**: + +```json +{ + "success": true, + "message": "模型已激活" +} +``` + +#### 8.5 删除模型 + +- **URL**: `/models/{id}` +- **方法**: `DELETE` +- **描述**: 删除指定模型,包括模型文件和数据库记录 +- **响应示例**: + +```json +{ + "success": true, + "message": "模型已删除" +} +``` + +注意事项: + +1. 删除模型时会同时删除相关的文件和数据库记录 +2. 不能删除当前正在使用(已激活)的模型 +3. 激活模型时会自动取消同类型其他模型的激活状态 +4. 模型详情包含了更多的训练相关信息,如特征重要性等 + +### 9. 
数据管理 + +#### 9.1 获取装备数据列表 + +- **URL**: `/data` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "rocket_artillery": [ + { + "id": 1, + "name": "BM-21", + "type": "火箭炮", + "manufacturer": "俄罗斯", + "length_m": 7.35, + "width_m": 2.4, + "height_m": 3.1, + "weight_kg": 13700, + "max_range_km": 20.4, + "actual_cost": 800000 + } + ], + "loitering_munition": [ + { + "id": 8, + "name": "Hero-120", + "type": "巡飞弹", + "manufacturer": "以色列", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "actual_cost": 150000 + } + ] +} +``` + +#### 9.2 获取装备详情 + +- **URL**: `/data/details/{id}` +- **方法**: `GET` +- **响应示例**: + +```json +{ + "id": 8, + "name": "Hero-120", + "type": "巡飞弹", + "manufacturer": "以色列", + "common_params": { + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40 + }, + "specific_params": { + "wingspan_m": 2.1, + "warhead_weight_kg": 3.5, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "warhead_type": "破片杀伤战斗部", + "launch_mode": "箱式发射", + "power_system": "电动机", + "guidance_system": "GPS/INS" + }, + "cost_data": { + "actual_cost": 150000, + "prediction_date": "2024-11-11T10:00:00", + "predicted_cost": 148000 + }, + "custom_params": [ + { + "id": 1, + "param_name": "续航时间", + "param_value": "2小时", + "param_unit": "小时", + "description": "最大续航时间" + } + ] +} +``` + +#### 9.3 更新装备数据 + +- **URL**: `/data/{id}` +- **方法**: `PUT` +- **请求体示例**: + +```json +{ + "name": "Hero-120", + "manufacturer": "以色列", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40, + "wingspan_m": 2.1, + "warhead_weight_kg": 3.5, + "max_speed_ms": 50, + "cruise_speed_kmh": 100, + "flight_time_min": 60, + "actual_cost": 150000, + "custom_params": [ + { + "id": 1, + "param_value": "2.5小时" + } + ] +} +``` + +- **响应示例**: + +```json +{ + "success": true, + "message": "装备数据更新成功" +} +``` + +#### 9.4 删除装备数据 + +- **URL**: `/data/{id}` +- 
**方法**: `DELETE` +- **响应示例**: + +```json +{ + "success": true, + "message": "装备数据删除成功" +} +``` + +#### 9.5 下载数据模板 + +- **URL**: `/data/template` +- **方法**: `GET` +- **描述**: 下载Excel格式的数据导入模板 +- **响应**: Excel文件下载 + +#### 9.6 导入数据 + +- **URL**: `/data/import` +- **方法**: `POST` +- **请求体**: + - Content-Type: multipart/form-data + - 参数名: file + - 文件类型: .xlsx 或 .xls +- **响应示例**: + +```json +{ + "success": true, + "message": "数据导入成功", + "imported_count": { + "rocket_artillery": 3, + "loitering_munition": 5 + } +} +``` + +注意事项: + +1. 导入数据时必须使用系统提供的模板 +2. 更新装备数据时会同时更新关联的参数表 +3. 删除装备数据会同时删除相关的参数和成本数据 +4. 导入的Excel文件大小不应超过10MB +5. 所有数值字段必须符合指定的单位和范围要求 +6. 特殊参数的值必须包含单位信息 + +## 错误响应 + +所有接口在发生错误时都会返回以下格式的响应: + +```json +{ + "error": "错误描述信息" +} +``` + +## 注意事项 + +1. 所有数值参数必须大于0 +2. 所有单位必须按照参数名称中指定的单位提供 +3. 预测结果中的成本单位为元 +4. 置信区间表示预测结果的95%置信水平范围 +5. 所有请求和响应的编码均为 UTF-8 diff --git a/docs/deploy/deploy.md b/docs/deploy/deploy.md new file mode 100644 index 0000000..1f52359 --- /dev/null +++ b/docs/deploy/deploy.md @@ -0,0 +1,120 @@ +# 装备成本估算系统部署指南 + +## 一、系统要求 + +### 1. 基础软件 + +- Linux 操作系统 (推荐 Ubuntu 20.04+) +- Python 3.8+ 及相关组件 + + ```bash + sudo apt update + sudo apt install python3 python3-pip python3-venv + sudo apt install python3-dev build-essential + ``` + +- Node.js 14+ 及 npm + + ```bash + # 使用 nvm 安装 Node.js + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash + source ~/.bashrc + nvm install 14 + nvm use 14 + ``` + +### 2. 数据库 + +- MySQL 8.0+ + + ```bash + sudo apt install mysql-server mysql-client + sudo apt install libmysqlclient-dev + ``` + +### 3. Python包依赖 + +```bash +# 科学计算相关 +sudo apt install libatlas-base-dev # numpy依赖 +sudo apt install libopenblas-dev # 线性代数库 +sudo apt install liblapack-dev # 线性代数包 +sudo apt install gfortran # Fortran编译器(scipy依赖) + +# XML处理相关(用于Excel文件处理) +sudo apt install libxml2-dev +sudo apt install libxslt1-dev +``` + +## 二、部署运行 + +### 1. 安装服务 + +```bash +sh scripts/install.sh +``` + +### 2. 
启动服务 + +```bash +sh scripts/start.sh +``` + +### 3. 停止服务 + +```bash +sh scripts/stop.sh +``` + +## 三、维护说明 + +### 1. 日志管理 + +```bash +# 后端日志 +tail -f logs/api.log + +# 数据库日志 +tail -f /var/log/mysql/error.log +``` + +## 四、安全建议 + +1. 系统安全 + - 使用防火墙限制端口访问 + - 定期更新系统和依赖包 + +2. 数据安全 + - 定期备份数据库 + - 加密敏感信息 + - 限制数据库远程访问 + +3. 访问控制 + - 使用强密码 + - 配置适当的文件权限 + - 使用非root用户运行服务 + +## 五、监控方案 + +### 1. 系统监控 + +```bash +# 资源使用 +top -b -n 1 +df -h +free -m + +# 服务状态 +ps aux | grep gunicorn +ps aux | grep node +``` + +### 2. 应用监控 + +```bash +# API 响应时间 +curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/" + +# 错误日志 +grep "ERROR" logs/api.log +``` diff --git a/docs/debug.md b/docs/dev/debug.md similarity index 100% rename from docs/debug.md rename to docs/dev/debug.md diff --git a/docs/dev/deployment.md b/docs/dev/deployment.md new file mode 100644 index 0000000..462355e --- /dev/null +++ b/docs/dev/deployment.md @@ -0,0 +1,250 @@ +# 装备成本估算系统部署指南 + +## 一、系统打包 + +### 1. 创建部署包结构 + +```bash +mkdir -p deploy/equipment_cost_system/{backend,frontend,docs} +``` + +### 2. 准备部署文件 + +#### 2.1 后端文件 + +```bash +cd deploy/equipment_cost_system/backend +mkdir -p {src,scripts,config,data,logs,models} +cp -r ../../../src/* src/ +cp ../../../requirements.txt ./ +cp ../../../.env.example config/.env.template +``` + +#### 2.2 前端文件 + +```bash +cd ../frontend +mkdir -p {src,public,dist} +cp -r ../../../frontend/src/* src/ +cp -r ../../../frontend/public/* public/ +cp ../../../frontend/package.json ./ +cp ../../../frontend/vite.config.js ./ +cp ../../../frontend/.env.production ./ +``` + +#### 2.3 复制文档 + +```bash +cd ../docs +cp -r ../../../docs/deploy/* ./ +``` + +#### 2.4 创建部署脚本 + +```bash +touch scripts/{install.sh,start.sh,stop.sh} +``` + +### 3. 部署脚本内容 + +#### 3.1 安装脚本 (install.sh) + +```bash +#!/bin/bash + +# 安装 Python 依赖 +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt + +# 创建必要的目录 +mkdir -p logs +mkdir -p data +mkdir -p models + +# 配置文件 +if [ ! 
-f config/.env ]; then + cp config/.env.template config/.env + echo "请修改 config/.env 中的配置" +fi + +# 初始化数据库 +read -p "请输入MySQL root密码: " mysqlpass +mysql -u root -p$mysqlpass < src/schema.sql + +# 设置权限 +chmod +x scripts/*.sh +``` + +#### 3.2 启动脚本 (start.sh) + +```bash +#!/bin/bash + +# 激活虚拟环境 +source venv/bin/activate + +# 检查配置文件 +if [ ! -f config/.env ]; then + echo "错误: 配置文件不存在" + exit 1 +fi + +# 启动服务 +set -a; . config/.env; set +a +gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" --daemon + +echo "服务已启动,访问 http://localhost:5001" +``` + +#### 3.3 停止脚本 (stop.sh) + +```bash +#!/bin/bash + +# 查找并停止 gunicorn 进程 +pkill -f gunicorn + +echo "服务已停止" +``` + +### 4. 创建部署包 + +```bash +cd .. +tar -czf equipment_cost_system.tar.gz equipment_cost_system/ +``` + +## 二、部署步骤 + +### 1. 系统要求 + +- Linux 服务器 (推荐 Ubuntu 20.04+) +- Python 3.8+ +- MySQL 8.0+ + +### 2. 安装部署 + +```bash +# 解压部署包 +tar -xzf equipment_cost_system.tar.gz +cd equipment_cost_system + +# 运行安装脚本 +bash scripts/install.sh + +# 修改配置文件 +vim config/.env + +# 启动服务 +bash scripts/start.sh +``` + +### 3. 验证部署 + +```bash +# 检查服务状态 +curl http://localhost:5001/api/ + +# 检查日志 +tail -f logs/api.log +``` + +## 三、维护说明 + +### 1. 日常维护 + +```bash +# 查看日志 +tail -f logs/api.log + +# 备份数据库 +mysqldump -u root -p equipment_cost_db > backup/$(date +%Y%m%d).sql + +# 清理旧日志 +find logs/ -name "*.log" -mtime +30 -delete +``` + +### 2. 更新部署 + +```bash +# 停止服务 +bash scripts/stop.sh + +# 备份数据 +cp -r data data_backup_$(date +%Y%m%d) + +# 更新代码 +# ... 更新相关文件 ... + +# 重启服务 +bash scripts/start.sh +``` + +### 3. 故障处理 + +```bash +# 检查服务状态 +ps aux | grep gunicorn + +# 检查数据库连接 +mysql -u root -p -e "show databases;" + +# 重启服务 +bash scripts/stop.sh +bash scripts/start.sh +``` + +## 四、安全建议 + +1. 文件权限设置 + +```bash +# 设置适当的文件权限 +chmod 755 scripts/*.sh +chmod 600 config/.env +chmod 700 logs models data +``` + +2. 数据库安全 + +- 使用强密码 +- 限制数据库访问IP +- 定期备份数据 + +3. 服务器配置 + +- 配置防火墙规则 +- 启用 SSL/TLS +- 定期更新系统 + +## 五、监控方案 + +### 1. 
系统监控 + +```bash +# 检查CPU和内存使用 +top -b -n 1 + +# 检查磁盘使用 +df -h + +# 检查网络连接 +netstat -tunlp +``` + +### 2. 应用监控 + +```bash +# 检查API响应时间 +curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/" + +# 检查错误日志 +grep "ERROR" logs/api.log +``` + +### 3. 告警设置 + +- 配置日志告警 +- 设置资源使用阈值告警 +- 配置服务可用性监控 diff --git a/docs/design.md b/docs/dev/design.md similarity index 100% rename from docs/design.md rename to docs/dev/design.md diff --git a/docs/requirements.md b/docs/dev/requirements.md similarity index 91% rename from docs/requirements.md rename to docs/dev/requirements.md index a384157..2da3bb4 100644 --- a/docs/requirements.md +++ b/docs/dev/requirements.md @@ -16,11 +16,11 @@ (2)线性相关分析:对于特征和标签皆为连续值的回归问题,要检测二者的相关性,最直接的做法就是求相关系数rxy,本质是建立协方差矩阵,分析数据和成本之间相关关系的类型和程度,筛选出影响特征 (3)互信息 (mutual information): 用于特征选择,可以从两个角度进行解释:(1)、基于 KL 散度和 (2)、基于信息增益。 2. 数据一致性分析:对特征数据分层分组,计算组内一致性,目标是选择比较合适的一组数据,以此产生一个进行成本估算和分析的虚拟量.大部分的研究中报告的三个数据:rwg、ICC(1)、ICC(2),要符合3个条件,rwg>0.7、ICC(1)>0.05、ICC(2)>0.5 -RWG值:打分一致性; -ICC1:组内一致性; -ICC2:组间一致性。 + RWG值:打分一致性; + ICC1:组内一致性; + ICC2:组间一致性。 3. 回归模型:偏最小二乘回归(partial Least Squares,PLS) -4. 神经网络模型:采用 BP 网络 +4. 神经网络模型:采用适用的神经网络模型 ### 数据准备 diff --git a/docs/dev/run.md b/docs/dev/run.md new file mode 100644 index 0000000..046a422 --- /dev/null +++ b/docs/dev/run.md @@ -0,0 +1,164 @@ +# 装备成本估算系统运行说明 + +## 一、开发环境配置 + +### 1. 系统要求 + +- Linux/macOS/Windows +- Python 3.8+ +- MySQL 8.0+ + +### 2. 安装依赖 + +```bash +# 创建并激活虚拟环境 +python3 -m venv venv +source venv/bin/activate # Linux/macOS +# 或 +.\venv\Scripts\activate # Windows + +# 安装依赖包 +pip install -r requirements.txt +``` + +## 二、初始化系统 + +### 1. 创建必要目录 + +```bash +mkdir -p {logs,data,models} +``` + +### 2. 配置数据库 + +```bash +# 执行数据库初始化脚本 +mysql -u root -p < src/schema.sql + +# [可选] 导入测试数据(仅用于开发环境) +mysql -u root -p equipment_cost_db < src/init_data.sql +``` + +### 3. 
环境配置 + +创建 `.env` 文件: + +```ini +MYSQL_HOST=localhost +MYSQL_USER=root +MYSQL_PASSWORD=123456 +MYSQL_DATABASE=equipment_cost_db +``` + +## 三、启动服务 + +### 1. 开发模式 + +```bash +# 启动开发服务器 +python run.py +``` + +### 2. 测试 API + +```bash +# 运行 API 测试 +python src/test_api.py +``` + +## 四、开发调试 + +### 1. 日志查看 + +```bash +# API 日志 +tail -f logs/api.log + +# 测试日志 +tail -f logs/test_api.log + +# 训练日志 +tail -f logs/training.log +``` + +### 2. 数据库调试 + +```sql +-- 检查数据表 +SHOW TABLES; + +-- 查看示例数据 +SELECT * FROM equipment LIMIT 5; +``` + +### 3. API 测试 + +```bash +# 测试 API 根路由 +curl http://localhost:5001/api/ + +# 测试预测接口 +curl -X POST http://localhost:5001/api/predict \ + -H "Content-Type: application/json" \ + -d '{ + "type": "巡飞弹", + "length_m": 1.3, + "width_m": 0.23, + "height_m": 0.23, + "weight_kg": 12.5, + "max_range_km": 40 + }' +``` + +## 五、注意事项 + +1. 开发环境配置 + - 使用虚拟环境隔离依赖 + - 保持日志目录可写权限 + - 定期清理日志文件 + +2. 数据库使用 + - 使用 UTF-8 字符集 + - 定期备份数据 + - 避免直接修改生产数据 + +3. 代码调试 + - 查看详细日志输出 + - 使用测试数据验证功能 + - 遵循代码规范 + +## 六、常见问题 + +1. 数据库连接错误 + - 检查 MySQL 服务状态 + - 验证数据库用户名密码 + - 确认数据库字符集设置 + +2. API 访问问题 + - 检查服务是否正常运行 + - 验证请求格式是否正确 + - 查看错误日志信息 + +3. 模型相关问题 + - 确保训练数据完整性 + - 检查模型文件权限 + - 验证预测结果合理性 + +## 七、开发建议 + +1. 代码管理 + - 使用版本控制 + - 遵循项目结构 + - 及时更新文档 + +2. 测试规范 + - 运行完整测试套件 + - 验证各个功能模块 + - 记录测试结果 + +3. 安全注意 + - 使用安全的数据库密码 + - 避免敏感信息提交 + - 保护测试数据安全 + +注:生产环境部署请参考 `deploy.md` diff --git a/docs/nodejs_install.md b/docs/nodejs_install.md deleted file mode 100644 index 4986f32..0000000 --- a/docs/nodejs_install.md +++ /dev/null @@ -1,136 +0,0 @@ -# Node.js 安装指南 - -## Windows 安装方法 - -### 1. 使用安装包 - -1. 访问 Node.js 官网 -2. 下载 14.x LTS 版本安装包 -3. 运行安装包,按提示完成安装 -4. 验证安装: - -```bash -node --version -npm --version -``` - -### 2. 使用 nvm-windows(推荐) - -1. 下载 nvm-windows: -2. 安装 nvm-windows -3. 安装 Node.js: - -```bash -nvm install 14.21.3 -nvm use 14.21.3 -``` - -## Linux 安装方法 - -### 1. 
使用 nvm(推荐) - -```bash -# 安装 nvm -curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash - -# 重新加载配置 -source ~/.bashrc - -# 安装 Node.js 14 -nvm install 14 -nvm use 14 -``` - -### 2. 使用包管理器 - -#### Ubuntu/Debian - -```bash -# 添加 NodeSource 仓库 -curl -fsSL https://deb.nodesource.com/setup_14.x | sudo -E bash - - -# 安装 Node.js -sudo apt-get install -y nodejs -``` - -#### CentOS/RHEL - -```bash -# 添加 NodeSource 仓库 -curl -fsSL https://rpm.nodesource.com/setup_14.x | sudo bash - - -# 安装 Node.js -sudo yum install -y nodejs -``` - -## macOS 安装方法 - -### 1. 使用 Homebrew(推荐) - -```bash -# 安装 Homebrew(如果未安装) -/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" - -# 安装 Node.js 14 -brew install node@14 - -# 添加环境变量 -echo 'export PATH="/usr/local/opt/node@14/bin:$PATH"' >> ~/.zshrc -source ~/.zshrc -``` - -### 2. 使用 nvm - -```bash -# 安装 nvm -curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash - -# 重新加载配置 -source ~/.zshrc - -# 安装 Node.js 14 -nvm install 14 -nvm use 14 -``` - -## 验证安装 - -安装完成后,运行以下命令验证: - -```bash -# 检查 Node.js 版本 -node --version # 应显示 v14.x.x - -# 检查 npm 版本 -npm --version # 应显示 6.x.x 或更高 -``` - -## 常见问题 - -### 1. 权限问题 - -如果遇到权限错误,可以: - -```bash -# Linux/macOS -sudo chown -R $USER /usr/local/lib/node_modules -``` - -### 2. 版本切换 - -如果需要在不同版本间切换: - -```bash -# 使用 nvm -nvm list # 查看已安装版本 -nvm use 14 # 切换到 14.x 版本 -``` - -### 3. npm 配置 - -建议配置国内镜像源: - -```bash -# 使用淘宝镜像 -npm config set registry https://registry.npmmirror.com -``` diff --git a/docs/run.md b/docs/run.md deleted file mode 100644 index e92716b..0000000 --- a/docs/run.md +++ /dev/null @@ -1,163 +0,0 @@ -# 系统运行说明 - -## 一、环境准备 - -### 1. 安装必要软件 - -```bash -# 安装 Python 3.8+ -# 安装 MySQL 8.0+ -# 安装 Node.js 14+ -``` - -### 2. 安装 Python 依赖 - -```bash -pip install -r requirements.txt -``` - -### 3. 安装前端依赖 - -```bash -cd frontend -npm install -``` - -## 二、数据库配置 - -### 1. 
创建数据库 - -```sql -CREATE DATABASE equipment_cost_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; -``` - -### 2. 初始化数据库结构 - -```bash -# 执行数据库结构初始化脚本 -mysql -u username -p equipment_cost_db < src/schema.sql - -# 导入示例数据 -mysql -u username -p equipment_cost_db < src/init_data.sql - -# 导入真实数据 -mysql -u username -p equipment_cost_db < src/real_data.sql -``` - -## 三、配置文件 - -### 1. 后端配置 - -创建 `config.py` 文件: - -```python -# config.py -DATABASE_URI = "mysql+pymysql://username:password@localhost:3306/equipment_cost_db" -SECRET_KEY = "your-secret-key" -DEBUG = True -``` - -### 2. 前端配置 - -修改 `frontend/src/config.js`: - -```javascript -export const API_BASE_URL = 'http://localhost:5001/api'; -``` - -## 四、启动系统 - -### 1. 启动后端服务 - -```bash -# 开发环境 -python run.py # 服务将在 http://localhost:5001 启动 - -# 生产环境 -gunicorn -w 4 -b 0.0.0.0:5001 run:app -``` - -### 2. 启动前端服务 - -```bash -# 开发环境 -cd frontend -npm run serve # 前端将在 http://localhost:8080 启动 - -# 生产环境 -npm run build -``` - -## 五、访问系统 - -- 后端API: -- 前端界面: - -## 六、常见问题 - -### 1. 数据库连接问题 - -- 检查 MySQL 服务是否启动 -- 验证数据库用户名和密码 -- 确认数据库端口是否正确 - -### 2. 模型训练 - -```bash -# 训练模型 -python src/train_model.py - -# 查看训练日志 -tail -f logs/training.log -``` - -### 3. 系统监控 - -```bash -# 查看系统日志 -tail -f logs/app.log - -# 监控API请求 -tail -f logs/access.log -``` - -## 七、开发调试 - -### 1. 后端调试 - -```bash -# 启动调试模式 -python run.py --debug - -# 运行测试 -python -m pytest tests/ -``` - -### 2. 前端调试 - -```bash -# 启动开发服务器 -npm run serve - -# 运行测试 -npm run test -``` - -## 八、部署建议 - -### 1. 使用 Docker 部署 - -```bash -# 构建镜像 -docker-compose build - -# 启动服务 -docker-compose up -d -``` - -### 2. 
生产环境配置 - -- 使用 Nginx 作为反向代理 -- 配置 SSL 证书 -- 设置适当的防火墙规则 -- 启用数据库备份 diff --git a/frontend/README.md b/frontend/README.md index 576b980..f08b1b5 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -1,24 +1,29 @@ # frontend ## Project setup -``` + +```bash npm install ``` ### Compiles and hot-reloads for development -``` + +```bash npm run serve ``` ### Compiles and minifies for production -``` + +```bash npm run build ``` ### Lints and fixes files -``` + +```bash npm run lint ``` ### Customize configuration + See [Configuration Reference](https://cli.vuejs.org/config/). diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue index f8f8ba1..385b995 100644 --- a/frontend/src/views/AnalysisPage.vue +++ b/frontend/src/views/AnalysisPage.vue @@ -153,31 +153,93 @@ const startAnalysis = async () => { } } +// 窗口大小变化处理函数 +const resizeHandler = ref(null) + +// 创建 resize 处理函数 +const createResizeHandler = () => { + const handler = () => { + try { + if (importanceChart.value && !importanceChart.value.isDisposed()) { + importanceChart.value.resize() + } + if (correlationChart.value && !correlationChart.value.isDisposed()) { + correlationChart.value.resize() + } + } catch (error) { + console.error('Error in resize handler:', error) + } + } + + // 使用防抖包装 + return debounce(handler, 200) +} + +// 组件挂载时 +onMounted(() => { + // 创建并保存 resize 处理函数的引用 + resizeHandler.value = createResizeHandler() + window.addEventListener('resize', resizeHandler.value) +}) + +// 组件卸载时 +onUnmounted(() => { + // 移除事件监听 + if (resizeHandler.value) { + window.removeEventListener('resize', resizeHandler.value) + resizeHandler.value = null + } + + // 销毁图表实例 + try { + if (importanceChart.value) { + importanceChart.value.dispose() + importanceChart.value = null + } + if (correlationChart.value) { + correlationChart.value.dispose() + correlationChart.value = null + } + } catch (error) { + console.error('Error disposing charts:', error) + } +}) + // 渲染图表 const renderCharts = () 
=> { console.log('Starting to render charts') - // 销毁旧的图表实例 - if (importanceChart.value) { - importanceChart.value.dispose() - } - if (correlationChart.value) { - correlationChart.value.dispose() + // 检查分析结果是否存在且包含必要数据 + if (!analysisResult.value || + !analysisResult.value.important_features || + !analysisResult.value.correlation_analysis) { + console.log('Analysis result not ready') + return } - // 确保 DOM 元素存在 + // 检查DOM元素 if (!importanceChartRef.value || !correlationChartRef.value) { console.log('Chart DOM elements not ready') return } try { + // 销毁旧的图表实例 + if (importanceChart.value) { + importanceChart.value.dispose() + importanceChart.value = null + } + if (correlationChart.value) { + correlationChart.value.dispose() + correlationChart.value = null + } + // 创建新的图表实例 importanceChart.value = echarts.init(importanceChartRef.value) correlationChart.value = echarts.init(correlationChartRef.value) // 设置图表选项 - importanceChart.value.setOption({ + const importanceOption = { title: { text: '特征重要性排序' }, tooltip: {}, xAxis: { @@ -189,12 +251,13 @@ const renderCharts = () => { data: analysisResult.value.important_features.map(f => f.name) }, series: [{ + name: '重要性', type: 'bar', data: analysisResult.value.important_features.map(f => f.importance) }] - }) + } - correlationChart.value.setOption({ + const correlationOption = { title: { text: '特征相关性热力图' }, tooltip: { position: 'top', @@ -233,6 +296,7 @@ const renderCharts = () => { color: ['#cc3333', '#eeeeee', '#00007f'] }, series: [{ + name: '相关性', type: 'heatmap', data: analysisResult.value.correlation_analysis.matrix, label: { @@ -242,13 +306,11 @@ const renderCharts = () => { } } }] - }) + } - // 监听窗口大小变化 - window.addEventListener('resize', () => { - importanceChart.value?.resize() - correlationChart.value?.resize() - }) + // 设置图表选项 + importanceChart.value.setOption(importanceOption) + correlationChart.value.setOption(correlationOption) console.log('Charts rendered successfully') } catch (error) { @@ -256,16 +318,16 @@ 
const renderCharts = () => { } } -// 初始化 -onMounted(() => { - // 可以在这里加载默认数据 -}) - -// 组件卸载时清理图表实例 -onUnmounted(() => { - importanceChart.value?.dispose() - correlationChart.value?.dispose() -}) +// 防抖函数 +function debounce(fn, delay) { + let timer = null + return function (...args) { + if (timer) clearTimeout(timer) + timer = setTimeout(() => { + fn.apply(this, args) + }, delay) + } +}