diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..49204df
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,25 @@
+# 数据库配置
+MYSQL_HOST=localhost
+MYSQL_USER=root
+MYSQL_PASSWORD=your_password_here
+MYSQL_DATABASE=equipment_cost_db
+
+# 服务配置
+PORT=5001
+DEBUG=False
+
+# 日志配置
+LOG_LEVEL=INFO
+LOG_DIR=logs
+
+# 模型配置
+MODEL_DIR=models
+DATA_DIR=data
+
+# 安全配置
+SECRET_KEY=your_secret_key_here
+ALLOWED_HOSTS=localhost,127.0.0.1
+
+# 其他配置
+UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
+ALLOWED_FILE_TYPES=.xlsx,.xls
\ No newline at end of file
diff --git a/config.py b/config.py
index ee008e5..ec92cd1 100644
--- a/config.py
+++ b/config.py
@@ -8,8 +8,8 @@ DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db"
SECRET_KEY = secrets.token_hex(16)
# 环境配置
-DEBUG = True
-ENV = 'development'
+DEBUG = False
+ENV = 'production'
# 文件上传配置
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
diff --git a/data/equipment_data_20241108_training.xlsx b/data/equipment_data_20241108_training.xlsx
deleted file mode 100644
index 30f7dcf..0000000
Binary files a/data/equipment_data_20241108_training.xlsx and /dev/null differ
diff --git a/data/equipment_data_20241108_verify.xlsx b/data/equipment_data_20241108_verify.xlsx
deleted file mode 100644
index 92a527a..0000000
Binary files a/data/equipment_data_20241108_verify.xlsx and /dev/null differ
diff --git a/deploy/equipment_cost_system.tar.gz b/deploy/equipment_cost_system.tar.gz
new file mode 100644
index 0000000..c906e30
Binary files /dev/null and b/deploy/equipment_cost_system.tar.gz differ
diff --git a/deploy/equipment_cost_system/config/.env.template b/deploy/equipment_cost_system/config/.env.template
new file mode 100644
index 0000000..195987f
--- /dev/null
+++ b/deploy/equipment_cost_system/config/.env.template
@@ -0,0 +1,25 @@
+# 数据库配置
+MYSQL_HOST=localhost
+MYSQL_USER=root
+MYSQL_PASSWORD=your_password_here
+MYSQL_DATABASE=equipment_cost_db
+
+# 服务配置
+PORT=5001
+DEBUG=False
+
+# 日志配置
+LOG_LEVEL=INFO
+LOG_DIR=logs
+
+# 模型配置
+MODEL_DIR=models
+DATA_DIR=data
+
+# 安全配置
+SECRET_KEY=your_secret_key_here
+ALLOWED_HOSTS=localhost,127.0.0.1
+
+# 其他配置
+UPLOAD_MAX_SIZE=10485760 # 10MB in bytes
+ALLOWED_FILE_TYPES=.xlsx,.xls
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/docs/api.md b/deploy/equipment_cost_system/docs/api.md
new file mode 100644
index 0000000..000bd35
--- /dev/null
+++ b/deploy/equipment_cost_system/docs/api.md
@@ -0,0 +1,663 @@
+# 装备成本估算系统 API 文档
+
+这个 API 文档提供了完整的接口说明,包括:
+
+- 每个端点的详细描述
+- 请求和响应的具体示例
+- 清晰的参数格式要求
+- 统一的错误处理说明
+- 重要的注意事项
+
+文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
+
+## 基本信息
+
+- 基础URL: `http://localhost:5001/api`
+- 版本: 1.0.0
+- 响应格式: JSON
+
+## API 端点列表
+
+### 1. 获取 API 信息
+
+获取 API 版本信息和可用端点列表。
+
+- **URL**: `/`
+- **方法**: `GET`
+- **响应示例**:
+json
+{
+"name": "装备成本估算系统 API",
+"version": "1.0.0",
+"endpoints": {
+"predict": {
+"url": "/api/predict",
+"method": "POST",
+"description": "成本预测"
+},
+"analyze-features": {
+"url": "/api/analyze-features",
+"method": "POST",
+"description": "特征分析"
+},
+"train": {
+"url": "/api/train",
+"method": "POST",
+"description": "模型训练"
+},
+"evaluate": {
+"url": "/api/evaluate",
+"method": "POST",
+"description": "模型评估"
+}
+}
+}
+
+### 2. 单模型预测
+
+使用当前激活的最优模型进行成本预测。
+
+- **URL**: `/predict`
+- **方法**: `POST`
+- **请求体示例** (巡飞弹):
+
+```json
+{
+ "type": "巡飞弹",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "folded_length_mm": 1300,
+ "folded_width_mm": 230,
+ "folded_height_mm": 230,
+ "warhead_type": "破片杀伤战斗部",
+ "launch_mode": "凭自身动力起飞"
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "predicted_cost": 150000.0,
+ "model_info": {
+ "type": "xgboost",
+ "name": "巡飞弹_20241111_model",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "confidence_interval": {
+ "lower": 135000.0,
+ "upper": 165000.0
+ }
+}
+```
+
+### 3. PLS 模型预测
+
+使用 PLS 回归模型进行预测。
+
+- **URL**: `/pls/predict`
+- **方法**: `POST`
+- **请求体**: 与单模型预测相同
+- **响应示例**:
+
+```json
+{
+ "predicted_cost": 148000.0,
+ "confidence_interval": {
+ "lower": 133000.0,
+ "upper": 163000.0
+ }
+}
+```
+
+### 4. 多模型预测
+
+使用所有激活的模型进行预测并返回综合结果。
+
+- **URL**: `/predict/all`
+- **方法**: `POST`
+- **请求体**: 与单模型预测相同
+- **响应示例**:
+
+```json
+{
+ "individual_predictions": {
+ "xgboost": {
+ "predicted_cost": 150000.0,
+ "model_info": {
+ "name": "巡飞弹_xgboost_model",
+ "type": "xgboost",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "confidence_interval": {
+ "lower": 135000.0,
+ "upper": 165000.0
+ }
+ },
+ "pls": {
+ "predicted_cost": 148000.0,
+ "model_info": {
+ "name": "巡飞弹_pls_model",
+ "type": "pls",
+ "r2_score": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ },
+ "confidence_interval": {
+ "lower": 133000.0,
+ "upper": 163000.0
+ }
+ }
+ },
+ "ensemble_prediction": {
+ "predicted_cost": 149000.0,
+ "standard_deviation": 1414.21,
+ "confidence_interval": {
+ "lower": 146228.15,
+ "upper": 151771.85
+ }
+ }
+}
+```
+
+### 5. 特征分析
+
+分析数据集中特征的重要性和相关性。
+
+- **URL**: `/analyze-features`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "dataset_id": 1,
+ "equipment_type": "巡飞弹"
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "important_features": [
+ {
+ "name": "最大射程(km)",
+ "importance": 0.35
+ },
+ {
+ "name": "重量(kg)",
+ "importance": 0.25
+ }
+ ],
+ "correlation_analysis": {
+ "features": ["最大射程(km)", "重量(kg)"],
+ "matrix": [[1.0, 0.8], [0.8, 1.0]]
+ }
+}
+```
+
+### 6. 模型训练
+
+训练新的模型。
+
+- **URL**: `/train`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "type": "巡飞弹",
+ "train_dataset_id": 1,
+ "validation_dataset_id": 2,
+ "models": ["xgboost", "lightgbm", "rf"]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "metrics": {
+ "xgboost": {
+ "train": {
+ "r2": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "validation": {
+ "r2": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ }
+ }
+ },
+ "best_model": {
+ "type": "xgboost",
+ "r2": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ }
+}
+```
+
+### 7. 数据集管理
+
+#### 7.1 获取数据集列表
+
+- **URL**: `/datasets`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+[
+ {
+ "id": 1,
+ "name": "训练数据集",
+ "description": "用于训练的数据集",
+ "equipment_type": "巡飞弹",
+ "equipment_count": 10,
+ "equipment_names": ["设备1", "设备2"],
+ "purpose": "训练",
+ "created_at": "2024-11-11T10:00:00"
+ }
+]
+```
+
+#### 7.2 获取数据集详情
+
+- **URL**: `/datasets/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 1,
+ "name": "训练数据集",
+ "description": "用于训练的数据集",
+ "equipment_type": "巡飞弹",
+ "purpose": "训练",
+ "created_at": "2024-11-11T10:00:00",
+ "equipment": [
+ {
+ "id": 1,
+ "name": "设备1",
+ "type": "巡飞弹",
+ "manufacturer": "制造商1",
+ "actual_cost": 150000
+ }
+ ],
+ "statistics": {
+ "equipment_count": 10,
+ "total_cost": 1500000,
+ "average_cost": 150000
+ }
+}
+```
+
+#### 7.3 创建数据集
+
+- **URL**: `/datasets`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "name": "测试数据集",
+ "description": "用于测试的数据集",
+ "equipment_type": "巡飞弹",
+ "purpose": "训练",
+ "equipment_ids": [1, 2, 3]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "id": 2,
+ "message": "数据集创建成功"
+}
+```
+
+#### 7.4 更新数据集
+
+- **URL**: `/datasets/{id}`
+- **方法**: `PUT`
+- **请求体示例**:
+
+```json
+{
+ "name": "更新后的数据集名称",
+ "description": "更新后的描述",
+ "equipment_type": "巡飞弹",
+ "purpose": "验证",
+ "equipment_ids": [1, 2, 3, 4]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据集更新成功"
+}
+```
+
+#### 7.5 删除数据集
+
+- **URL**: `/datasets/{id}`
+- **方法**: `DELETE`
+- **描述**: 删除指定的数据集及其关联关系
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据集删除成功"
+}
+```
+
+注意事项:
+
+1. 数据集删除后不会删除关联的装备数据
+2. 不能删除正在被模型使用的数据集
+3. 更新数据集时会重新计算统计信息
+4. 数据集的装备类型一旦创建后不能更改
+
+### 8. 模型管理
+
+#### 8.1 获取模型列表
+
+- **URL**: `/models`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+[
+ {
+ "id": 1,
+ "model_name": "巡飞弹_xgboost_model",
+ "model_type": "xgboost",
+ "equipment_type": "巡飞弹",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0,
+ "is_active": true,
+ "training_date": "2024-11-11T10:00:00"
+ }
+]
+```
+
+#### 8.2 获取最新模型
+
+- **URL**: `/models/{equipment_type}/latest`
+- **方法**: `GET`
+- **响应示例**: 与模型列表的单个模型格式相同
+
+#### 8.3 获取模型详情
+
+- **URL**: `/models/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 1,
+ "model_name": "巡飞弹_xgboost_model",
+ "model_type": "xgboost",
+ "equipment_type": "巡飞弹",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0,
+ "is_active": true,
+ "training_date": "2024-11-11T10:00:00",
+ "feature_importance": {
+ "max_range_km": 0.35,
+ "weight_kg": 0.25,
+ "length_m": 0.20
+ },
+ "training_data_size": 100,
+ "created_by": "system"
+}
+```
+
+#### 8.4 激活模型
+
+- **URL**: `/models/{id}/activate`
+- **方法**: `POST`
+- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "模型已激活"
+}
+```
+
+#### 8.5 删除模型
+
+- **URL**: `/models/{id}`
+- **方法**: `DELETE`
+- **描述**: 删除指定模型,包括模型文件和数据库记录
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "模型已删除"
+}
+```
+
+注意事项:
+
+1. 删除模型时会同时删除相关的文件和数据库记录
+2. 不能删除当前正在使用(已激活)的模型
+3. 激活模型时会自动取消同类型其他模型的激活状态
+4. 模型详情包含了更多的训练相关信息,如特征重要性等
+
+### 9. 数据管理
+
+#### 9.1 获取装备数据列表
+
+- **URL**: `/data`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "rocket_artillery": [
+ {
+ "id": 1,
+ "name": "BM-21",
+ "type": "火箭炮",
+ "manufacturer": "俄罗斯",
+ "length_m": 7.35,
+ "width_m": 2.4,
+ "height_m": 3.1,
+ "weight_kg": 13700,
+ "max_range_km": 20.4,
+ "actual_cost": 800000
+ }
+ ],
+ "loitering_munition": [
+ {
+ "id": 8,
+ "name": "Hero-120",
+ "type": "巡飞弹",
+ "manufacturer": "以色列",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "actual_cost": 150000
+ }
+ ]
+}
+```
+
+#### 9.2 获取装备详情
+
+- **URL**: `/data/details/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 8,
+ "name": "Hero-120",
+ "type": "巡飞弹",
+ "manufacturer": "以色列",
+ "common_params": {
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40
+ },
+ "specific_params": {
+ "wingspan_m": 2.1,
+ "warhead_weight_kg": 3.5,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "warhead_type": "破片杀伤战斗部",
+ "launch_mode": "箱式发射",
+ "power_system": "电动机",
+ "guidance_system": "GPS/INS"
+ },
+ "cost_data": {
+ "actual_cost": 150000,
+ "prediction_date": "2024-11-11T10:00:00",
+ "predicted_cost": 148000
+ },
+ "custom_params": [
+ {
+ "id": 1,
+ "param_name": "续航时间",
+ "param_value": "2小时",
+ "param_unit": "小时",
+ "description": "最大续航时间"
+ }
+ ]
+}
+```
+
+#### 9.3 更新装备数据
+
+- **URL**: `/data/{id}`
+- **方法**: `PUT`
+- **请求体示例**:
+
+```json
+{
+ "name": "Hero-120",
+ "manufacturer": "以色列",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "wingspan_m": 2.1,
+ "warhead_weight_kg": 3.5,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "actual_cost": 150000,
+ "custom_params": [
+ {
+ "id": 1,
+ "param_value": "2.5小时"
+ }
+ ]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "装备数据更新成功"
+}
+```
+
+#### 9.4 删除装备数据
+
+- **URL**: `/data/{id}`
+- **方法**: `DELETE`
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "装备数据删除成功"
+}
+```
+
+#### 9.5 下载数据模板
+
+- **URL**: `/data/template`
+- **方法**: `GET`
+- **描述**: 下载Excel格式的数据导入模板
+- **响应**: Excel文件下载
+
+#### 9.6 导入数据
+
+- **URL**: `/data/import`
+- **方法**: `POST`
+- **请求体**:
+ - Content-Type: multipart/form-data
+ - 参数名: file
+ - 文件类型: .xlsx 或 .xls
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据导入成功",
+ "imported_count": {
+ "rocket_artillery": 3,
+ "loitering_munition": 5
+ }
+}
+```
+
+注意事项:
+
+1. 导入数据时必须使用系统提供的模板
+2. 更新装备数据时会同时更新关联的参数表
+3. 删除装备数据会同时删除相关的参数和成本数据
+4. 导入的Excel文件大小不应超过10MB
+5. 所有数值字段必须符合指定的单位和范围要求
+6. 特殊参数的值必须包含单位信息
+
+## 错误响应
+
+所有接口在发生错误时都会返回以下格式的响应:
+
+```json
+{
+ "error": "错误描述信息"
+}
+```
+
+## 注意事项
+
+1. 所有数值参数必须大于0
+2. 所有单位必须按照参数名称中指定的单位提供
+3. 预测结果中的成本单位为元
+4. 置信区间表示预测结果的95%置信水平范围
+5. 所有请求和响应的编码均为 UTF-8
diff --git a/deploy/equipment_cost_system/docs/deploy.md b/deploy/equipment_cost_system/docs/deploy.md
new file mode 100644
index 0000000..1f52359
--- /dev/null
+++ b/deploy/equipment_cost_system/docs/deploy.md
@@ -0,0 +1,120 @@
+# 装备成本估算系统部署指南
+
+## 一、系统要求
+
+### 1. 基础软件
+
+- Linux 操作系统 (推荐 Ubuntu 20.04+)
+- Python 3.8+ 及相关组件
+
+ ```bash
+ sudo apt update
+ sudo apt install python3 python3-pip python3-venv
+ sudo apt install python3-dev build-essential
+ ```
+
+- Node.js 14+ 及 npm
+
+ ```bash
+ # 使用 nvm 安装 Node.js
+ curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
+ source ~/.bashrc
+ nvm install 14
+ nvm use 14
+ ```
+
+### 2. 数据库
+
+- MySQL 8.0+
+
+ ```bash
+ sudo apt install mysql-server mysql-client
+ sudo apt install libmysqlclient-dev
+ ```
+
+### 3. Python包依赖
+
+```bash
+# 科学计算相关
+sudo apt install libatlas-base-dev # numpy依赖
+sudo apt install libopenblas-dev # 线性代数库
+sudo apt install liblapack-dev # 线性代数包
+sudo apt install gfortran # Fortran编译器(scipy依赖)
+
+# XML处理相关(用于Excel文件处理)
+sudo apt install libxml2-dev
+sudo apt install libxslt1-dev
+```
+
+## 二、部署运行
+
+### 1. 安装服务
+
+```bash
+sh scripts/install.sh
+```
+
+### 2. 启动服务
+
+```bash
+sh scripts/start.sh
+```
+
+### 3. 停止服务
+
+```bash
+sh scripts/stop.sh
+```
+
+## 三、维护说明
+
+### 1. 日志管理
+
+```bash
+# 后端日志
+tail -f logs/api.log
+
+# 数据库日志
+tail -f /var/log/mysql/error.log
+```
+
+## 四、安全建议
+
+1. 系统安全
+ - 使用防火墙限制端口访问
+ - 定期更新系统和依赖包
+
+2. 数据安全
+ - 定期备份数据库
+ - 加密敏感信息
+ - 限制数据库远程访问
+
+3. 访问控制
+ - 使用强密码
+ - 配置适当的文件权限
+ - 使用非root用户运行服务
+
+## 五、监控方案
+
+### 1. 系统监控
+
+```bash
+# 资源使用
+top -b -n 1
+df -h
+free -m
+
+# 服务状态
+ps aux | grep gunicorn
+ps aux | grep node
+```
+
+### 2. 应用监控
+
+```bash
+# API 响应时间
+curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
+
+# 错误日志
+grep "ERROR" logs/api.log
+```
diff --git a/deploy/equipment_cost_system/frontend/babel.config.js b/deploy/equipment_cost_system/frontend/babel.config.js
new file mode 100644
index 0000000..e955840
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/babel.config.js
@@ -0,0 +1,5 @@
+module.exports = {
+ presets: [
+ '@vue/cli-plugin-babel/preset'
+ ]
+}
diff --git a/deploy/equipment_cost_system/frontend/package.json b/deploy/equipment_cost_system/frontend/package.json
new file mode 100644
index 0000000..5e6409e
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/package.json
@@ -0,0 +1,61 @@
+{
+ "name": "frontend",
+ "version": "0.1.0",
+ "private": true,
+ "engines": {
+ "node": ">=16",
+ "npm": ">=8"
+ },
+ "scripts": {
+ "serve": "vue-cli-service serve",
+ "build": "vue-cli-service build",
+ "lint": "vue-cli-service lint"
+ },
+ "dependencies": {
+ "axios": "^1.6.0",
+ "core-js": "^3.8.3",
+ "echarts": "^5.4.3",
+ "element-plus": "^2.4.2",
+ "vue": "^3.2.13",
+ "vue-router": "^4.0.3",
+ "vuex": "^4.0.0"
+ },
+ "devDependencies": {
+ "@babel/core": "^7.12.16",
+ "@babel/eslint-parser": "^7.12.16",
+ "@element-plus/icons-vue": "^2.3.1",
+ "@vue/cli-plugin-babel": "~5.0.0",
+ "@vue/cli-plugin-eslint": "~5.0.0",
+ "@vue/cli-plugin-router": "~5.0.0",
+ "@vue/cli-plugin-vuex": "~5.0.0",
+ "@vue/cli-service": "~5.0.0",
+ "@vue/compiler-sfc": "^3.2.13",
+ "eslint": "^7.32.0",
+ "eslint-plugin-vue": "^8.0.3",
+ "sass": "^1.32.7",
+ "sass-loader": "^12.0.0"
+ },
+ "eslintConfig": {
+ "root": true,
+ "env": {
+ "node": true
+ },
+ "extends": [
+ "plugin:vue/vue3-essential",
+ "eslint:recommended"
+ ],
+ "parserOptions": {
+ "parser": "@babel/eslint-parser"
+ },
+ "rules": {
+ "vue/multi-word-component-names": "off",
+ "no-unused-vars": "warn"
+ }
+ },
+ "browserslist": [
+ "> 1%",
+ "last 2 versions",
+ "not dead",
+ "not ie 11"
+ ]
+}
diff --git a/deploy/equipment_cost_system/frontend/public/favicon.ico b/deploy/equipment_cost_system/frontend/public/favicon.ico
new file mode 100644
index 0000000..df36fcf
Binary files /dev/null and b/deploy/equipment_cost_system/frontend/public/favicon.ico differ
diff --git a/deploy/equipment_cost_system/frontend/public/index.html b/deploy/equipment_cost_system/frontend/public/index.html
new file mode 100644
index 0000000..3e5a139
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/public/index.html
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+ <%= htmlWebpackPlugin.options.title %>
+
+
+
+
+
+
+
diff --git a/deploy/equipment_cost_system/frontend/src/App.vue b/deploy/equipment_cost_system/frontend/src/App.vue
new file mode 100644
index 0000000..c4278c5
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/App.vue
@@ -0,0 +1,42 @@
+
+
+
+
+ 首页
+ 成本预测
+ 特征分析
+ 模型训练
+ 模型管理
+ 数据集管理
+ 数据管理
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/deploy/equipment_cost_system/frontend/src/api/index.js b/deploy/equipment_cost_system/frontend/src/api/index.js
new file mode 100644
index 0000000..1090d1a
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/api/index.js
@@ -0,0 +1,43 @@
+import axios from 'axios'
+import { API_BASE_URL } from '@/config'
+
+const api = axios.create({
+ baseURL: API_BASE_URL,
+ timeout: 10000
+})
+
+export const predict = (data) => {
+ return api.post('/predict', data)
+}
+
+export const analyzeFeatures = (data) => {
+ return api.post('/analyze-features', data)
+}
+
+export const trainModel = (data) => {
+ return api.post('/train', data)
+}
+
+export const evaluateModel = (data) => {
+ return api.post('/evaluate', data)
+}
+
+export const importData = (formData) => {
+ return api.post('/data/import', formData, {
+ headers: {
+ 'Content-Type': 'multipart/form-data'
+ }
+ })
+}
+
+export const getEquipmentData = () => {
+ return api.get('/data')
+}
+
+export const updateEquipment = (id, data) => {
+ return api.put(`/data/${id}`, data)
+}
+
+export const deleteEquipment = (id) => {
+ return api.delete(`/data/${id}`)
+}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/assets/logo.png b/deploy/equipment_cost_system/frontend/src/assets/logo.png
new file mode 100644
index 0000000..f3d2503
Binary files /dev/null and b/deploy/equipment_cost_system/frontend/src/assets/logo.png differ
diff --git a/deploy/equipment_cost_system/frontend/src/assets/styles/global.css b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css
new file mode 100644
index 0000000..a9e81e1
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/assets/styles/global.css
@@ -0,0 +1,39 @@
+/* 禁用 ResizeObserver 警告 */
+iframe[style*="position: fixed; top: 0px; left: 0px; width: 100%; height: 100%; border: none; z-index: 2147483647;"] {
+ display: none !important;
+}
+
+.el-overlay {
+ overflow: hidden !important;
+}
+
+/* 添加全局样式 */
+body {
+ margin: 0;
+ padding: 0;
+ overflow-x: hidden;
+}
+
+/* 修复 Element Plus 的一些已知问题 */
+.el-dialog__wrapper {
+ overflow: hidden !important;
+}
+
+.el-select-dropdown {
+ overflow: hidden !important;
+}
+
+/* 禁用 ResizeObserver 相关的警告样式 */
+.resize-observer-warning {
+ display: none !important;
+}
+
+/* 优化滚动行为 */
+* {
+ scroll-behavior: smooth;
+}
+
+/* 防止页面抖动 */
+.el-main {
+ overflow-x: hidden;
+}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/config.js b/deploy/equipment_cost_system/frontend/src/config.js
new file mode 100644
index 0000000..0a697ee
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/config.js
@@ -0,0 +1,8 @@
+export const API_BASE_URL = 'http://localhost:5001/api';
+
+export const DB_CONFIG = {
+ host: 'localhost',
+ user: 'root',
+ password: '123456',
+ database: 'equipment_cost_db'
+};
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/main.js b/deploy/equipment_cost_system/frontend/src/main.js
new file mode 100644
index 0000000..360a4a4
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/main.js
@@ -0,0 +1,55 @@
+import { createApp } from 'vue'
+import App from './App.vue'
+import router from './router'
+import store from './store'
+import ElementPlus from 'element-plus'
+import 'element-plus/dist/index.css'
+import './assets/styles/global.css'
+import * as ElementPlusIconsVue from '@element-plus/icons-vue'
+
+// 创建应用实例
+const app = createApp(App)
+
+// 注册插件
+app.use(ElementPlus, {
+ size: 'default'
+})
+app.use(router)
+app.use(store)
+
+// 注册图标
+for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
+ app.component(key, component)
+}
+
+// 全局错误处理
+app.config.errorHandler = (err) => {
+ if (err.message && err.message.includes('ResizeObserver')) {
+ return
+ }
+ console.error(err)
+}
+
+// 全局警告处理
+app.config.warnHandler = (msg, trace) => {
+ if (msg.includes('ResizeObserver')) {
+ return
+ }
+ console.warn(msg, trace)
+}
+
+// 挂载应用
+app.mount('#app')
+
+// 处理 ResizeObserver 错误
+const _ResizeObserver = window.ResizeObserver
+window.ResizeObserver = class ResizeObserver extends _ResizeObserver {
+ constructor(callback) {
+ super((entries, observer) => {
+ requestAnimationFrame(() => {
+ if (!Array.isArray(entries)) return
+ callback(entries, observer)
+ })
+ })
+ }
+}
diff --git a/deploy/equipment_cost_system/frontend/src/router/index.js b/deploy/equipment_cost_system/frontend/src/router/index.js
new file mode 100644
index 0000000..3dfcdee
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/router/index.js
@@ -0,0 +1,52 @@
+import { createRouter, createWebHistory } from 'vue-router'
+import HomePage from '@/views/HomePage.vue'
+import DataPage from '@/views/DataPage.vue'
+import DatasetPage from '@/views/DatasetPage.vue'
+import PredictPage from '@/views/PredictPage.vue'
+import AnalysisPage from '@/views/AnalysisPage.vue'
+import TrainingPage from '@/views/TrainingPage.vue'
+
+const routes = [
+ {
+ path: '/',
+ name: 'Home',
+ component: HomePage
+ },
+ {
+ path: '/data',
+ name: 'Data',
+ component: DataPage
+ },
+ {
+ path: '/datasets',
+ name: 'Datasets',
+ component: DatasetPage
+ },
+ {
+ path: '/predict',
+ name: 'Predict',
+ component: PredictPage
+ },
+ {
+ path: '/analysis',
+ name: 'Analysis',
+ component: AnalysisPage
+ },
+ {
+ path: '/training',
+ name: 'Training',
+ component: TrainingPage
+ },
+ {
+ path: '/models',
+ name: 'Models',
+ component: () => import('../views/ModelPage.vue')
+ }
+]
+
+const router = createRouter({
+ history: createWebHistory(),
+ routes
+})
+
+export default router
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/store/index.js b/deploy/equipment_cost_system/frontend/src/store/index.js
new file mode 100644
index 0000000..0faaec7
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/store/index.js
@@ -0,0 +1,14 @@
+import { createStore } from 'vuex'
+
+export default createStore({
+ state: {
+ },
+ getters: {
+ },
+ mutations: {
+ },
+ actions: {
+ },
+ modules: {
+ }
+})
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js
new file mode 100644
index 0000000..de54055
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/utils/errorHandler.js
@@ -0,0 +1,73 @@
+// 处理 ResizeObserver 错误
+const resizeHandler = () => {
+ const resizeObserverErrors = [
+ "ResizeObserver loop completed with undelivered notifications.",
+ "ResizeObserver loop limit exceeded",
+ "ResizeObserver loop completed"
+ ]
+
+ // 添加全局错误处理
+ const handler = (event) => {
+ if (event && event.message && resizeObserverErrors.includes(event.message)) {
+ event.stopPropagation()
+ event.preventDefault()
+ event.stopImmediatePropagation()
+ return false
+ }
+ }
+
+ // 添加多个事件监听器
+ window.addEventListener('error', handler, true)
+ window.addEventListener('unhandledrejection', handler, true)
+
+ // 添加 ResizeObserver 错误处理
+ if (window.ResizeObserver) {
+ const resizeObserverPrototype = ResizeObserver.prototype
+ const originalObserve = resizeObserverPrototype.observe
+
+ resizeObserverPrototype.observe = function (...args) {
+ try {
+ return originalObserve.apply(this, args)
+ } catch (e) {
+ if (resizeObserverErrors.includes(e.message)) {
+ return null
+ }
+ throw e
+ }
+ }
+ }
+}
+
+export default {
+ install: (app) => {
+ resizeHandler()
+
+ // 添加全局错误处理器
+ app.config.errorHandler = (err, vm, info) => {
+ const resizeObserverErrors = [
+ "ResizeObserver loop completed with undelivered notifications.",
+ "ResizeObserver loop limit exceeded",
+ "ResizeObserver loop completed"
+ ]
+
+ if (err && err.message && resizeObserverErrors.includes(err.message)) {
+ return
+ }
+ console.error('Vue Error:', err, info)
+ }
+
+ // 添加全局警告处理器
+ app.config.warnHandler = (msg, vm, trace) => {
+ const resizeObserverErrors = [
+ "ResizeObserver loop completed with undelivered notifications.",
+ "ResizeObserver loop limit exceeded",
+ "ResizeObserver loop completed"
+ ]
+
+ if (resizeObserverErrors.includes(msg)) {
+ return
+ }
+ console.warn('Vue Warning:', msg, trace)
+ }
+ }
+}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue
new file mode 100644
index 0000000..f8f8ba1
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/AnalysisPage.vue
@@ -0,0 +1,304 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ selectedDataset.name }}
+ {{ selectedDataset.equipment_count }}
+ {{ selectedDataset.description }}
+
+
+
+
+
+
+ {{ analyzing ? '分析中...' : '开始分析' }}
+
+
+
+
+
+
分析结果
+
+
+
特征重要性
+
+
+
+
相关性分析
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/DataPage.vue b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue
new file mode 100644
index 0000000..4689f70
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/DataPage.vue
@@ -0,0 +1,725 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 详情
+
+
+ 编辑
+
+
+ 删除
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 详情
+
+
+ 编辑
+
+
+ 删除
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 基本信息
+
+
+
+ {{ selectedData?.name }}
+ {{ selectedData?.manufacturer }}
+
+
+ {{ formatNumber(selectedData?.length_m) }}
+ {{ formatNumber(selectedData?.width_m) }}
+ {{ formatNumber(selectedData?.height_m) }}
+ {{ formatNumber(selectedData?.weight_kg) }}
+ {{ formatNumber(selectedData?.max_range_km) }}
+
+
+
+
+
+
+ 火箭炮参数
+
+
+
+ {{ formatNumber(selectedData?.rocket_diameter_mm) }}
+ {{ formatNumber(selectedData?.rate_of_fire) }}
+ {{ formatNumber(selectedData?.rocket_length_m) }}
+ {{ formatNumber(selectedData?.rocket_weight_kg) }}
+ {{ formatNumber(selectedData?.firing_angle_horizontal) }}
+ {{ formatNumber(selectedData?.firing_angle_vertical) }}
+ {{ selectedData?.mobility_type }}
+ {{ selectedData?.structure_layout }}
+ {{ selectedData?.engine_model }}
+ {{ selectedData?.engine_params }}
+ {{ formatNumber(selectedData?.power_hp) }}
+ {{ formatNumber(selectedData?.travel_range_km) }}
+
+
+
+
+
+
+ 巡飞弹参数
+
+
+
+ {{ formatNumber(selectedData?.max_speed_ms) }}
+ {{ formatNumber(selectedData?.cruise_speed_kmh) }}
+ {{ formatNumber(selectedData?.flight_time_min) }}
+ {{ selectedData?.warhead_type }}
+ {{ selectedData?.launch_mode }}
+ {{ selectedData?.power_system }}
+ {{ selectedData?.guidance_system }}
+
+
+
+
+ {{ console.log('Rendering custom params:', selectedData.custom_params) }}
+
+
+ 特殊参数
+
+
+
+
+ {{ formatCustomParamValue(param) }}
+
+
+
+
+
+
+
+ 成本信息
+
+
+
+
+ {{ formatMoney(selectedData?.actual_cost) }}
+
+
+ {{ formatMoney(selectedData?.predicted_cost) }}
+
+
+ {{ selectedData?.cost_estimate_date || '-' }}
+
+
+
+
+
+
+
+
+
+
+ 基本信息
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 火箭炮参数
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 巡飞弹参数
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 特殊参数
+
+
+
+ {{ param.param_unit }}
+
+
+
+
+ {{ param.param_unit }}
+
+
+
+
+
+
+ 成本信息
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue
new file mode 100644
index 0000000..1079611
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/DatasetPage.vue
@@ -0,0 +1,322 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ formatDateTime(scope.row.created_at) }}
+
+
+
+
+ {{ formatDateTime(scope.row.updated_at) }}
+
+
+
+
+ 查看
+ 编辑
+ 删除
+
+
+
+
+
+
+
+
+ {{ selectedDataset?.equipment_type }}
+ {{ selectedDataset?.purpose }}
+ {{ selectedDataset?.description }}
+
+
+
+
+ 统计信息
+
+ {{ selectedDataset?.statistics?.equipment_count || 0 }}
+ {{ formatMoney(selectedDataset?.statistics?.average_cost) }}
+ {{ formatMoney(selectedDataset?.statistics?.total_cost) }}
+
+
+
+
+
+ 包含装备
+
+
+
+
+
+ {{ formatMoney(scope.row.actual_cost) }}
+
+
+
+
+ 详情
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ formatMoney(scope.row.actual_cost) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/HomePage.vue b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue
new file mode 100644
index 0000000..60c0676
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/HomePage.vue
@@ -0,0 +1,101 @@
+
+
+
+
+ 装备成本估算系统
+
+
+
+
+
+ 成本预测
+ 基于机器学习和 PLS 回归模型的成本预测
+
+
+
+
+
+ 特征分析
+ 分析参数对成本的影响
+
+
+
+
+
+ 模型训练
+ 训练和优化预测模型
+
+
+
+
+
+ 模型管理
+ 管理训练好的模型
+
+
+
+
+
+ 数据集管理
+ 管理训练和验证数据集
+
+
+
+
+
+ 数据管理
+ 管理装备数据和成本数据
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue
new file mode 100644
index 0000000..e9632f9
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/ModelPage.vue
@@ -0,0 +1,279 @@
+
+
+
+
+
+
+
+
+
+
+
+ {{ getModelName(scope.row.model_type) }}
+
+
+
+
+
+
+ {{ scope.row.r2_score.toFixed(4) }}
+
+
+
+
+ {{ scope.row.mae.toFixed(2) }}
+
+
+
+
+ {{ scope.row.rmse.toFixed(2) }}
+
+
+
+
+ {{ formatDateTime(scope.row.training_date) }}
+
+
+
+
+
+ {{ scope.row.is_active ? '使用中' : '未使用' }}
+
+
+
+
+
+
+ 激活
+
+
+ 详情
+
+
+ 删除
+
+
+
+
+
+
+
+
+
+ {{ selectedModel?.model_name }}
+ {{ formatModelType(selectedModel?.model_type) }}
+ {{ selectedModel?.equipment_type }}
+ {{ selectedModel?.training_data_size }}
+ {{ formatDateTime(selectedModel?.training_date) }}
+
+
+ {{ selectedModel?.is_active ? '使用中' : '未使用' }}
+
+
+
+
+
+
+ 评估指标
+
+ {{ selectedModel?.r2_score.toFixed(4) }}
+ {{ selectedModel?.mae.toFixed(2) }} 元
+ {{ selectedModel?.rmse.toFixed(2) }} 元
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue
new file mode 100644
index 0000000..a42fa9f
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/PredictPage.vue
@@ -0,0 +1,321 @@
+
+
+
+
+ 成本预测
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 预测成本
+ 重置
+
+
+
+
+
+
预测结果
+
+
+
+
机器学习模型预测
+
+
+ {{ getModelName(mlPrediction.model_info.type) }}
+
+
+ {{ mlPrediction.model_info.name }}
+
+
+ {{ formatMoney(mlPrediction.predicted_cost) }}
+
+
+ {{ formatMoney(mlPrediction.confidence_interval.lower) }} ~
+ {{ formatMoney(mlPrediction.confidence_interval.upper) }}
+
+
+
+
+
+
+
PLS回归预测
+
+
+ {{ getModelName(plsPrediction.model_info.type) }}
+
+
+ {{ plsPrediction.model_info.name }}
+
+
+ {{ formatMoney(plsPrediction.predicted_cost) }}
+
+
+ {{ formatMoney(plsPrediction.confidence_interval.lower) }} ~
+ {{ formatMoney(plsPrediction.confidence_interval.upper) }}
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue
new file mode 100644
index 0000000..edd20c1
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/src/views/TrainingPage.vue
@@ -0,0 +1,370 @@
+
+
+
+
+ 模型训练
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ PLS回归
+ XGBoost
+ LightGBM
+ GBM
+ Random Forest
+
+
+
+
+
+ 开始训练
+
+
+
+
+
+
+
训练结果
+
+
+
+
最佳模型: {{ getModelName(trainingResult.best_model.type) }}
+
R²分数: {{ formatNumber(trainingResult.best_model.r2) }}
+
MAE: {{ formatNumber(trainingResult.best_model.mae) }} 元
+
RMSE: {{ formatNumber(trainingResult.best_model.rmse) }} 元
+
+
+
+
+
+
+ {{ getModelName(scope.row.model) }}
+
+
+
+
+
+
+
+ {{ formatNumber(scope.row.train.r2) }}
+
+
+
+
+ {{ formatNumber(scope.row.train.mae) }}
+
+
+
+
+ {{ formatNumber(scope.row.train.rmse) }}
+
+
+
+
+
+
+
+
+ {{ formatNumber(scope.row.validation.r2) }}
+
+
+
+
+ {{ formatNumber(scope.row.validation.mae) }}
+
+
+
+
+ {{ formatNumber(scope.row.validation.rmse) }}
+
+
+
+
+
+
+
+
特征重要性
+
+
+
+
+ {{ formatNumber(scope.row.importance) }}
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/frontend/vue.config.js b/deploy/equipment_cost_system/frontend/vue.config.js
new file mode 100644
index 0000000..b2a2970
--- /dev/null
+++ b/deploy/equipment_cost_system/frontend/vue.config.js
@@ -0,0 +1,5 @@
+module.exports = {
+ devServer: {
+ port: 8080
+ }
+}
diff --git a/deploy/equipment_cost_system/requirements.txt b/deploy/equipment_cost_system/requirements.txt
new file mode 100644
index 0000000..6eec783
--- /dev/null
+++ b/deploy/equipment_cost_system/requirements.txt
@@ -0,0 +1,12 @@
+flask==2.0.1
+flask-cors==3.0.10
+sqlalchemy==1.4.23
+pymysql==1.0.2
+cryptography==3.4.7 # MySQL 8.0+ 认证需要
+numpy==1.21.2
+pandas==1.3.3
+scikit-learn==0.24.2
+tensorflow==2.6.0
+urllib3<2.0.0 # 添加这一行,限制 urllib3 版本
+openpyxl==3.1.2 # 用于读取 .xlsx 文件
+xlrd==2.0.1 # 用于读取 .xls 文件
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/scripts/install.sh b/deploy/equipment_cost_system/scripts/install.sh
new file mode 100644
index 0000000..a7a80d6
--- /dev/null
+++ b/deploy/equipment_cost_system/scripts/install.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+echo "开始安装装备成本估算系统..."
+
+# 检查Python版本
+python3 -V || {
+ echo "错误: 需要 Python 3.8+"
+ exit 1
+}
+
+# 检查Node.js版本
+node -v || {
+ echo "错误: 需要 Node.js 14+"
+ exit 1
+}
+
+# 创建必要的目录
+echo "创建系统目录..."
+mkdir -p {logs,data,models}
+
+# 安装后端依赖
+echo "安装后端依赖..."
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+
+# 安装前端依赖
+echo "安装前端依赖..."
+cd frontend
+npm install
+npm run build
+cd ..
+
+# 配置文件
+if [ ! -f config/.env ]; then
+ echo "创建配置文件..."
+ cp config/.env.template config/.env
+ echo "请修改 config/.env 中的配置"
+fi
+
+# 初始化数据库
+echo "初始化数据库..."
+read -p "请输入MySQL root密码: " mysqlpass
+mysql -u root -p$mysqlpass < src/schema.sql
+
+# 导入测试数据(可选)
+read -p "是否导入测试数据?(y/n) " import_test_data
+if [ "$import_test_data" = "y" ]; then
+ mysql -u root -p$mysqlpass equipment_cost_db < src/init_data.sql
+fi
+
+# 设置权限
+echo "设置文件权限..."
+chmod +x scripts/*.sh
+chmod 755 logs models data
+chmod 600 config/.env
+
+echo "安装完成!"
+echo "请检查并修改 config/.env 中的配置"
+echo "使用 ./scripts/start.sh 启动服务"
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/scripts/start.sh b/deploy/equipment_cost_system/scripts/start.sh
new file mode 100644
index 0000000..5b28ac4
--- /dev/null
+++ b/deploy/equipment_cost_system/scripts/start.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+echo "启动装备成本估算系统..."
+
+# 检查配置文件
+if [ ! -f config/.env ]; then
+ echo "错误: 配置文件不存在"
+ echo "请先运行 install.sh"
+ exit 1
+fi
+
+# 检查日志目录
+if [ ! -d logs ]; then
+ mkdir -p logs
+fi
+
+# 激活虚拟环境
+source venv/bin/activate
+
+# 导出环境变量
+export $(cat config/.env | xargs)
+
+# 启动后端服务
+echo "启动后端服务..."
+gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" \
+ --daemon \
+ --pid gunicorn.pid \
+ --access-logfile logs/access.log \
+ --error-logfile logs/error.log
+
+# 等待后端服务启动
+sleep 2
+
+# 检查后端服务是否成功启动
+if ! curl -s http://localhost:5001/api/ > /dev/null; then
+ echo "错误: 后端服务启动失败"
+ exit 1
+fi
+
+# 启动前端服务
+echo "启动前端服务..."
+cd frontend
+npm run serve -- --port 8080 --host 0.0.0.0 &
+echo $! > ../frontend.pid
+cd ..
+
+echo "服务已启动!"
+echo "后端API: http://localhost:5001"
+echo "前端界面: http://localhost:8080"
+echo "查看后端日志: tail -f logs/access.log"
+echo "查看前端日志: tail -f logs/frontend.log"
diff --git a/deploy/equipment_cost_system/scripts/stop.sh b/deploy/equipment_cost_system/scripts/stop.sh
new file mode 100644
index 0000000..085d432
--- /dev/null
+++ b/deploy/equipment_cost_system/scripts/stop.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+echo "停止装备成本估算系统..."
+
+# 停止后端服务
+if [ -f gunicorn.pid ]; then
+ echo "停止后端服务..."
+ pid=$(cat gunicorn.pid)
+ kill -TERM $pid
+ rm gunicorn.pid
+ echo "后端服务已停止 (PID: $pid)"
+else
+ echo "未找到后端服务PID文件"
+ # 尝试查找并停止所有gunicorn进程
+ pkill -f gunicorn
+ echo "已尝试停止所有gunicorn进程"
+fi
+
+# 停止前端服务
+if [ -f frontend.pid ]; then
+ echo "停止前端服务..."
+ pid=$(cat frontend.pid)
+ kill -TERM $pid
+ rm frontend.pid
+ echo "前端服务已停止 (PID: $pid)"
+else
+ echo "未找到前端服务PID文件"
+ # 尝试查找并停止前端服务进程
+ pkill -f "vite preview"
+ echo "已尝试停止所有前端服务进程"
+fi
+
+# 检查是否还有相关进程在运行
+if pgrep -f gunicorn > /dev/null; then
+ echo "警告: 仍有gunicorn进程在运行"
+ ps aux | grep gunicorn | grep -v grep
+fi
+
+if pgrep -f "vite preview" > /dev/null; then
+ echo "警告: 仍有前端服务进程在运行"
+ ps aux | grep "vite preview" | grep -v grep
+fi
+
+echo "所有服务已停止"
diff --git a/deploy/equipment_cost_system/src/__init__.py b/deploy/equipment_cost_system/src/__init__.py
new file mode 100644
index 0000000..497b4a4
--- /dev/null
+++ b/deploy/equipment_cost_system/src/__init__.py
@@ -0,0 +1 @@
+# 这个文件可以为空,但必须存在
diff --git a/deploy/equipment_cost_system/src/app.py b/deploy/equipment_cost_system/src/app.py
new file mode 100644
index 0000000..037fb79
--- /dev/null
+++ b/deploy/equipment_cost_system/src/app.py
@@ -0,0 +1,50 @@
+from flask import Flask
+from flask_cors import CORS
+from .routes import api_bp
+from .logger import setup_logger
+import os
+
+# 获取logger
+logger = setup_logger(__name__)
+
+def create_app():
+ """
+ 创建并配置Flask应用
+ """
+ try:
+ # 创建必要的目录
+ os.makedirs('logs', exist_ok=True)
+ os.makedirs('data', exist_ok=True)
+ os.makedirs('models', exist_ok=True)
+
+ logger.info("=== Server Starting ===")
+ logger.info("Initializing directories...")
+
+ # 创建Flask应用
+ app = Flask(__name__)
+
+ # 配置CORS
+ CORS(app)
+ logger.info("CORS enabled")
+
+ # 注册API蓝图
+ app.register_blueprint(api_bp, url_prefix='/api')
+ logger.info("API blueprint registered")
+
+ # 配置数据库连接
+ app.config['MYSQL_HOST'] = 'localhost'
+ app.config['MYSQL_USER'] = 'root'
+ app.config['MYSQL_PASSWORD'] = '123456'
+ app.config['MYSQL_DB'] = 'equipment_cost_db'
+
+ logger.info("Starting server...")
+
+ return app
+
+ except Exception as e:
+ logger.error(f"Error creating app: {str(e)}")
+ raise
+
+if __name__ == '__main__':
+ app = create_app()
+ app.run(host='localhost', port=5001)
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/cost_prediction.py b/deploy/equipment_cost_system/src/cost_prediction.py
new file mode 100644
index 0000000..23ccce5
--- /dev/null
+++ b/deploy/equipment_cost_system/src/cost_prediction.py
@@ -0,0 +1,342 @@
+import numpy as np
+from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+from sklearn.preprocessing import StandardScaler
+import tensorflow as tf
+from scipy import stats
+import joblib
+import os
+import pandas as pd
+from .feature_analysis import FeatureAnalysis
+import logging
+from src.model_trainer import ModelTrainer
+from src.database import get_db_connection
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+class CostPredictor:
+ def __init__(self):
+ self.scaler_X = StandardScaler()
+ self.scaler_y = StandardScaler()
+ self.model = None
+ self.feature_analyzer = FeatureAnalysis()
+ self.equipment_type = None
+
+ # 添加 TensorFlow 配置
+ tf.config.run_functions_eagerly(False) # 启用图执行模式
+
+ # 创建预测函数
+ @tf.function(reduce_retracing=True, jit_compile=True)
+ def predict_fn(x):
+ return self.model(x, training=False)
+
+ self._predict_fn = predict_fn
+
+ self.load_model()
+
+ def load_model(self):
+ """
+ 加载预训练型和标准化器
+ """
+ try:
+ model_dir = 'models'
+ os.makedirs(model_dir, exist_ok=True)
+
+ # 创建默认模型
+ self._create_default_model()
+
+ # 创建预测函数
+ @tf.function(reduce_retracing=True, jit_compile=True)
+ def predict_fn(x):
+ return self.model(x, training=False)
+
+ self._predict_fn = predict_fn
+
+ except Exception as e:
+ logging.error(f"Error loading model: {str(e)}")
+ self._create_default_model()
+
+ def _create_default_model(self):
+ """
+ 创建默认模型并进行初始化训练
+ """
+ # 创建输入层
+ inputs = tf.keras.Input(shape=(11,))
+
+ # 创建隐藏层
+ x = tf.keras.layers.Dense(64, activation='relu')(inputs)
+ x = tf.keras.layers.Dense(32, activation='relu')(x)
+
+ # 创建输出层
+ outputs = tf.keras.layers.Dense(1)(x)
+
+ # 创建模型
+ self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
+
+ # 编译模型
+ self.model.compile(
+ optimizer='adam',
+ loss=tf.keras.losses.mean_squared_error,
+ metrics=[tf.keras.metrics.mean_absolute_error]
+ )
+
+ # 创建示例数据
+ example_data = pd.DataFrame({
+ 'length_m': [7.35, 10.2],
+ 'width_m': [2.4, 2.8],
+ 'height_m': [3.1, 3.2],
+ 'weight_kg': [13700, 28500],
+ 'max_range_km': [20.4, 70],
+ 'firing_angle_horizontal': [102, 110],
+ 'firing_angle_vertical': [55, 60],
+ 'rocket_length_m': [2.87, 4.1],
+ 'rocket_diameter_mm': [122, 220],
+ 'rocket_weight_kg': [66.6, 150],
+ 'rate_of_fire': [40, 60]
+ })
+
+ # 训练标准化器
+ self.scaler_X.fit(example_data)
+ self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围
+
+ # 设置默认装备类型
+ self.equipment_type = '火箭炮'
+
+ def _create_example_data(self):
+ """
+ 创建示例数据来训练标准化器
+ """
+ # 火箭炮示例数据
+ rocket_data = pd.DataFrame({
+ 'length_m': [7.35, 10.2],
+ 'width_m': [2.4, 2.8],
+ 'height_m': [3.1, 3.2],
+ 'weight_kg': [13700, 28500],
+ 'max_range_km': [20.4, 70],
+ 'firing_angle_horizontal': [102, 110],
+ 'firing_angle_vertical': [55, 60],
+ 'rocket_length_m': [2.87, 4.1],
+ 'rocket_diameter_mm': [122, 220],
+ 'rocket_weight_kg': [66.6, 150],
+ 'rate_of_fire': [40, 60]
+ })
+
+ # 巡飞弹示例数据
+ missile_data = pd.DataFrame({
+ 'length_m': [1.3, 2.5],
+ 'width_m': [0.23, 0.6],
+ 'height_m': [0.23, 0.6],
+ 'weight_kg': [12.5, 135],
+ 'max_range_km': [40, 250],
+ 'max_speed_kmh': [180, 185],
+ 'cruise_speed_kmh': [100, 110],
+ 'flight_time_min': [60, 120],
+ 'folded_length_mm': [1300, 2500],
+ 'folded_width_mm': [230, 600],
+ 'folded_height_mm': [230, 600]
+ })
+
+ # 训练标准化器
+ self.scaler_X.fit(rocket_data) # 使用火箭炮数据
+ self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据
+
+ # 设置默认装备类型
+ self.equipment_type = '火箭炮'
+
+ def predict(self, data):
+ """
+ 使用训练好的最优模型进行预测
+ """
+ try:
+ logger.info(f"Starting prediction for {data.get('type')}")
+ equipment_type = data.get('type')
+
+ # 加载已训练的最优模型
+ trainer = ModelTrainer()
+ if not trainer.load_model(equipment_type):
+ raise ValueError(f"No trained model found for {equipment_type}")
+
+ # 准备特征数据
+ features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+ X = np.array([[data.get(feature) for feature in features]])
+
+ # 预测
+ y_pred = trainer.predict(X)
+
+ # 计算置信区间
+ confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
+
+ # 获取模型类型
+ model_type = trainer.get_model_type()
+
+ return {
+ 'predicted_cost': float(y_pred[0]),
+ 'model_type': model_type, # 返回使用的模型类型
+ 'confidence_interval': {
+ 'lower': float(confidence_interval[0]),
+ 'upper': float(confidence_interval[1])
+ }
+ }
+
+ except Exception as e:
+ logger.error(f"Prediction error: {str(e)}")
+ raise
+
+ def _calculate_confidence_interval(self, prediction, confidence=0.95):
+ """
+ 计算预测值的置信区间
+ """
+ try:
+ # 使用预测值的20%作为标准差(增加不确定性)
+ std = abs(prediction) * 0.2
+
+ # 计算置信区间
+ from scipy import stats
+ interval = stats.norm.interval(confidence, loc=prediction, scale=std)
+
+ # 确保区间值为正数且合理
+ lower = max(1000, interval[0]) # 最小值设为1000元
+ upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
+
+ logging.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
+
+ return [lower, upper]
+
+ except Exception as e:
+ logging.error(f"Error calculating confidence interval: {str(e)}")
+ # 如果计算失败,返回基于20%的简单区间
+ lower = max(1000, prediction * 0.8)
+ upper = prediction * 1.2
+ return [lower, upper]
+
+ def evaluate(self, y_true, y_pred):
+ """
+ 模型评估
+ """
+ return {
+ 'mae': float(mean_absolute_error(y_true, y_pred)),
+ 'mse': float(mean_squared_error(y_true, y_pred)),
+ 'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))),
+ 'r2': float(r2_score(y_true, y_pred))
+ }
+
+ def predict_pls(self, data):
+ """
+ 使用 PLS ���型预测成本
+ """
+ try:
+ logger.info(f"Starting PLS prediction for {data.get('type')}")
+ equipment_type = data.get('type')
+
+ # 加载 PLS 模型
+ trainer = ModelTrainer()
+ if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型
+ raise ValueError(f"No trained PLS model found for {equipment_type}")
+
+ # 准备特征数据
+ features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+ X = np.array([[data.get(feature) for feature in features]])
+
+ # 预测
+ y_pred = trainer.predict(X)
+
+ # 计算置信区间
+ confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
+
+ return {
+ 'predicted_cost': float(y_pred[0]),
+ 'confidence_interval': {
+ 'lower': float(confidence_interval[0]),
+ 'upper': float(confidence_interval[1])
+ }
+ }
+
+ except Exception as e:
+ logger.error(f"PLS prediction error: {str(e)}")
+ raise
+
+ def predict_all(self, data):
+ """
+ 使用所有可用模型进行预测
+ """
+ try:
+ logger.info(f"Starting multi-model prediction for {data.get('type')}")
+ equipment_type = data.get('type')
+ results = {}
+
+ # 1. 获取所有激活的模型
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT id, model_type, model_name, r2_score, mae, rmse
+ FROM trained_models
+ WHERE equipment_type = %s AND is_active = TRUE
+ """, (equipment_type,))
+ active_models = cursor.fetchall()
+
+ if not active_models:
+ raise ValueError(f"No active models found for {equipment_type}")
+
+ # 2. 使用每个模型进行预测
+ trainer = ModelTrainer()
+ for model_info in active_models:
+ try:
+ # 加载特定模型
+ if not trainer.load_model(equipment_type, model_type=model_info['model_type']):
+ logger.warning(f"Failed to load model: {model_info['model_name']}")
+ continue
+
+ # 准备特征数据
+ features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+ X = np.array([[data.get(feature) for feature in features]])
+
+ # 预测
+ y_pred = trainer.predict(X)
+
+ # 计算置信区间
+ confidence_interval = trainer._calculate_confidence_interval(y_pred[0])
+
+ # 保存结果
+ results[model_info['model_type']] = {
+ 'predicted_cost': float(y_pred[0]),
+ 'model_info': {
+ 'name': model_info['model_name'],
+ 'type': model_info['model_type'],
+ 'r2_score': float(model_info['r2_score']),
+ 'mae': float(model_info['mae']),
+ 'rmse': float(model_info['rmse'])
+ },
+ 'confidence_interval': {
+ 'lower': float(confidence_interval[0]),
+ 'upper': float(confidence_interval[1])
+ }
+ }
+
+ except Exception as e:
+ logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}")
+ continue
+
+ if not results:
+ raise ValueError("No successful predictions from any model")
+
+ # 3. 计算综合预测结果
+ all_predictions = [result['predicted_cost'] for result in results.values()]
+ ensemble_prediction = float(np.mean(all_predictions))
+ prediction_std = float(np.std(all_predictions))
+
+ # 4. 返回所有结果
+ return {
+ 'individual_predictions': results,
+ 'ensemble_prediction': {
+ 'predicted_cost': ensemble_prediction,
+ 'standard_deviation': prediction_std,
+ 'confidence_interval': {
+ 'lower': float(ensemble_prediction - 1.96 * prediction_std),
+ 'upper': float(ensemble_prediction + 1.96 * prediction_std)
+ }
+ }
+ }
+
+ except Exception as e:
+ logger.error(f"Error in multi-model prediction: {str(e)}")
+ raise
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/create_template.py b/deploy/equipment_cost_system/src/create_template.py
new file mode 100644
index 0000000..0f7d545
--- /dev/null
+++ b/deploy/equipment_cost_system/src/create_template.py
@@ -0,0 +1,155 @@
+import pandas as pd
+import openpyxl
+from openpyxl.styles import PatternFill, Font, Alignment
+from openpyxl.worksheet.datavalidation import DataValidation
+import os
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+def create_excel_template():
+ """
+ 创建数据模板
+ """
+ try:
+ # 确保data目录存在
+ data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
+ os.makedirs(data_dir, exist_ok=True)
+
+ # 创建完整的文件路径
+ template_path = os.path.join(data_dir, 'equipment_data_template.xlsx')
+
+ # 创建 Excel 写入器
+ writer = pd.ExcelWriter(template_path, engine='openpyxl')
+
+ # 火箭炮基本参数表
+ rocket_artillery_columns = [
+ '名称', '类型', '制造商', '口径_mm',
+ '反射管数量', '乘员数', '总长_m',
+ '宽度_m', '高度_m', '重量_kg',
+ '战斗重_kg', '速度_km/h', '最大射程_km',
+ '最小射程_km', '方向射界_度', '高低射界_度',
+ '火箭弹长度_m', '火箭弹重量_kg',
+ '火箭弹最大速度_m/s', '射速_发',
+ '战斗部重量_kg', '行走方式',
+ '结构布局', '发动机型号', '发动机参数',
+ '功率_hp', '行程_km', '成本_元'
+ ]
+
+ # 巡飞弹基本参数表
+ loitering_munition_columns = [
+ '名称', '类型', '制造商', '目标类型',
+ '弹长_m', '弹径_mm', '翼展_m',
+ '重量_kg', '战斗部重量_kg',
+ '最大射程_km', '最大速度_m/s',
+ '巡航速度_kmh', '巡飞时间_min',
+ '战斗部类型', '发射方式',
+ '折叠长度_mm', '折叠宽度_mm',
+ '折叠高度_mm', '动力装置',
+ '制导体制', '成本_元'
+ ]
+
+ # 特殊参数表
+ special_params_columns = [
+ '装备名称', # 关联字段
+ '参数名称',
+ '参数值',
+ '参数单位',
+ '参数说明'
+ ]
+
+ # 创建工作表
+ pd.DataFrame(columns=rocket_artillery_columns).to_excel(
+ writer, sheet_name='火箭炮', index=False
+ )
+ pd.DataFrame(columns=loitering_munition_columns).to_excel(
+ writer, sheet_name='巡飞弹', index=False
+ )
+ pd.DataFrame(columns=special_params_columns).to_excel(
+ writer, sheet_name='特殊参数', index=False
+ )
+
+ # 获取工作簿
+ workbook = writer.book
+
+ # 设置火箭炮工作表格式
+ rocket_sheet = workbook['火箭炮']
+ for col in range(1, len(rocket_artillery_columns) + 1):
+ cell = rocket_sheet.cell(row=1, column=col)
+ cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
+ cell.font = Font(bold=True)
+ cell.alignment = Alignment(horizontal='center')
+
+ # 设置巡飞弹工作表格式
+ missile_sheet = workbook['巡飞弹']
+ for col in range(1, len(loitering_munition_columns) + 1):
+ cell = missile_sheet.cell(row=1, column=col)
+ cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
+ cell.font = Font(bold=True)
+ cell.alignment = Alignment(horizontal='center')
+
+ # 设置特殊参数工作表格式
+ special_sheet = workbook['特殊参数']
+ for col in range(1, len(special_params_columns) + 1):
+ cell = special_sheet.cell(row=1, column=col)
+ cell.fill = PatternFill(start_color='CCCCCC', end_color='CCCCCC', fill_type='solid')
+ cell.font = Font(bold=True)
+ cell.alignment = Alignment(horizontal='center')
+
+ # 添加数据验证
+ for sheet in [rocket_sheet, missile_sheet]:
+ # 数值验证
+ number_validation = DataValidation(type="decimal", operator="greaterThan", formula1="0")
+ number_validation.error = "请输入大于0的数值"
+ number_validation.errorTitle = "输入错误"
+
+ # 应用到相应列
+ for col in ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']:
+ number_validation.add(f"{col}2:{col}1000")
+
+ sheet.add_data_validation(number_validation)
+
+ # 添加说明
+ rocket_sheet['AD1'] = "填写说明:"
+ rocket_sheet['AD2'] = "1. 所有数值必须大于0"
+ rocket_sheet['AD3'] = "2. 单位必须按照表头要求填写"
+ rocket_sheet['AD4'] = "3. 成本单位为元"
+
+ missile_sheet['V1'] = "填写说明:"
+ missile_sheet['V2'] = "1. 所有数值必须大于0"
+ missile_sheet['V3'] = "2. 单位必须按照表头要求填写"
+ missile_sheet['V4'] = "3. 成本单位为元"
+
+ special_sheet['G1'] = "填写说明:"
+ special_sheet['G2'] = "1. 装备名称必须与基本参数表中的名称一致"
+ special_sheet['G3'] = "2. 参数值必须包含单位"
+ special_sheet['G4'] = "3. 参数说明应简明扼要"
+
+ # 调整列宽
+ for sheet in [rocket_sheet, missile_sheet, special_sheet]:
+ for col in sheet.columns:
+ max_length = 0
+ column = col[0].column_letter
+ for cell in col:
+ try:
+ if len(str(cell.value)) > max_length:
+ max_length = len(str(cell.value))
+ except:
+ pass
+ adjusted_width = (max_length + 2)
+ sheet.column_dimensions[column].width = adjusted_width
+
+ # 保存文件
+ writer.close()
+
+ return template_path
+
+ except Exception as e:
+ raise Exception(f"创建模板文件失败: {str(e)}")
+
+if __name__ == "__main__":
+ try:
+ template_path = create_excel_template()
+ print(f"模板文件已创建: {template_path}")
+ except Exception as e:
+ print(f"错误: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/data_preparation.py b/deploy/equipment_cost_system/src/data_preparation.py
new file mode 100644
index 0000000..6380647
--- /dev/null
+++ b/deploy/equipment_cost_system/src/data_preparation.py
@@ -0,0 +1,233 @@
+from sklearn.preprocessing import StandardScaler
+from datetime import datetime
+import os
+import joblib
+import pandas as pd
+import numpy as np
+from src.feature_analysis import FeatureAnalysis
+from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
+from xgboost import XGBRegressor
+from lightgbm import LGBMRegressor
+from sklearn.model_selection import cross_val_score, LeaveOneOut
+import json
+import logging
+from src.database.db_connection import get_db_connection
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+class DataPreparation:
+ def __init__(self):
+ self.feature_analyzer = FeatureAnalysis()
+ self.feature_scaler = StandardScaler()
+ self.target_scaler = StandardScaler() # 添加目标值标准化器
+
+ def prepare_training_data(self, equipment_data, equipment_type):
+ """
+ 准备训练数据
+ """
+ try:
+ logger.info(f"Preparing training data for {equipment_type}")
+ logger.info(f"Raw data size: {len(equipment_data)}")
+
+ # 如果输入已经是 numpy 数组,直接返回
+ if isinstance(equipment_data, np.ndarray):
+ X = equipment_data
+ logger.info(f"Input is already numpy array with shape: {X.shape}")
+
+ # 处理无效值
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
+
+ return {
+ 'X': X,
+ 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
+ 'feature_scaler': self.feature_scaler,
+ 'target_scaler': self.target_scaler
+ }
+
+ # 从原始数据中提取特征和目标值
+ feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+ features = []
+ targets = []
+
+ for item in equipment_data:
+ # 提取特征值
+ feature_values = []
+ for name in feature_names:
+ value = item.get(name)
+ try:
+ feature_values.append(float(value) if value is not None else 0.0)
+ except (ValueError, TypeError):
+ feature_values.append(0.0)
+ features.append(feature_values)
+
+ # 提取目标值(成本)
+ try:
+ cost = float(item['actual_cost'])
+ if cost > 0: # 只使用正数成本值
+ targets.append(cost)
+ else:
+ logger.warning(f"Skipping non-positive cost value: {cost}")
+ except (ValueError, TypeError, KeyError):
+ logger.error(f"Invalid cost value: {item.get('actual_cost')}")
+ continue
+
+ # 转换为numpy数组
+ X = np.array(features, dtype=float)
+ y = np.array(targets, dtype=float)
+
+ # 记录原始数据范围
+ logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
+ logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
+
+ # 标准化特征和目标值
+ X_scaled = self.feature_scaler.fit_transform(X)
+ y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
+
+ # 记录标准化后的数据范围
+ logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}")
+ logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}")
+
+ # 记录标准化器参数
+ logger.info("Feature scaler params:")
+ logger.info(f"Mean: {self.feature_scaler.mean_}")
+ logger.info(f"Scale: {self.feature_scaler.scale_}")
+
+ logger.info("Target scaler params:")
+ logger.info(f"Mean: {self.target_scaler.mean_}")
+ logger.info(f"Scale: {self.target_scaler.scale_}")
+
+ return {
+ 'X': X_scaled,
+ 'y': y_scaled,
+ 'feature_names': feature_names,
+ 'feature_scaler': self.feature_scaler,
+ 'target_scaler': self.target_scaler
+ }
+
+ except Exception as e:
+ logger.error(f"Error in data preparation: {str(e)}")
+ raise Exception(f"Training error: {str(e)}")
+
+ def prepare_validation_data(self, validation_data, equipment_type, feature_names=None, scalers=None):
+ """
+ 准备验证数据
+ """
+ try:
+ logger.info(f"Preparing validation data for {equipment_type}")
+ logger.info(f"Raw validation data size: {len(validation_data)}")
+
+ # 如果输入已经是 numpy 数组,直接使用
+ if isinstance(validation_data, np.ndarray):
+ X = validation_data
+ logger.info(f"Input is already numpy array with shape: {X.shape}")
+
+ # 处理无效值
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # 使用训练数据的标准化器
+ if scalers and 'feature_scaler' in scalers:
+ X_scaled = scalers['feature_scaler'].transform(X)
+ else:
+ X_scaled = X
+
+ logger.info(f"Preprocessed data shape: {X_scaled.shape}")
+ logger.info(f"Validation features shape: {X_scaled.shape}")
+ logger.info(f"Validation features type: {X_scaled.dtype}")
+
+ return {
+ 'X': X_scaled,
+ 'y': None # 验证数据可能没有标签
+ }
+
+ # 从原始数据中提取特征和目标值
+ if not feature_names:
+ feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+
+ features = []
+ targets = []
+
+ for item in validation_data:
+ # 提取特征值
+ feature_values = []
+ for name in feature_names:
+ value = item.get(name)
+ try:
+ feature_values.append(float(value) if value is not None else 0.0)
+ except (ValueError, TypeError):
+ feature_values.append(0.0)
+ features.append(feature_values)
+
+ # 提取目标值(成本)并验证范围
+ try:
+ cost = float(item['actual_cost'])
+ logger.info(f"Raw cost value: {cost}")
+ if cost > 0: # 只使用正数成本值
+ targets.append(cost)
+ else:
+ logger.warning(f"Skipping non-positive cost value: {cost}")
+ except (ValueError, TypeError):
+ logger.error(f"Invalid cost value: {item.get('actual_cost')}")
+ continue
+
+ # 转换为numpy数组
+ X = np.array(features, dtype=float)
+ y = np.array(targets, dtype=float)
+
+ # 记录数据范围
+ logger.info(f"Features range: min={X.min()}, max={X.max()}")
+ logger.info(f"Targets range: min={y.min()}, max={y.max()}")
+
+ # 处理无效值
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # 使用训练数据的标准化器
+ if scalers and 'feature_scaler' in scalers:
+ X_scaled = scalers['feature_scaler'].transform(X)
+ if 'target_scaler' in scalers:
+ y_scaled = scalers['target_scaler'].transform(y.reshape(-1, 1)).ravel()
+ else:
+ y_scaled = y
+ else:
+ X_scaled = X
+ y_scaled = y
+
+ logger.info(f"Preprocessed data shape: {X_scaled.shape}")
+ logger.info(f"Validation features shape: {X_scaled.shape}")
+ logger.info(f"Validation features type: {X_scaled.dtype}")
+
+ # 记录标准化后的数据范围
+ logger.info(f"Scaled validation X range: min={X_scaled.min()}, max={X_scaled.max()}")
+ logger.info(f"Scaled validation y range: min={y_scaled.min()}, max={y_scaled.max()}")
+
+ # 确保特征维度一致
+ if not feature_names:
+ feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
+
+ logger.info(f"Expected features: {len(feature_names)}")
+ logger.info(f"Actual features: {X_scaled.shape[1]}")
+
+ if X_scaled.shape[1] != len(feature_names):
+ raise ValueError(f"Feature dimension mismatch: expected {len(feature_names)}, got {X_scaled.shape[1]}")
+
+ return {
+ 'X': X_scaled,
+ 'y': y_scaled # 返回标准化后的成本值
+ }
+
+ except Exception as e:
+ logger.error(f"Error in validation data preparation: {str(e)}")
+ logger.error(f"Feature names: {feature_names}")
+ logger.error(f"Equipment type: {equipment_type}")
+ raise Exception(f"Validation error: {str(e)}")
+
+ def calculate_derived_features(self, data, equipment_type):
+ """
+ 计算衍生特征
+ """
+ try:
+ return self.feature_analyzer.calculate_derived_features(data, equipment_type)
+ except Exception as e:
+ logger.error(f"Error calculating derived features: {str(e)}")
+ raise Exception(f"Feature calculation error: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/database/__init__.py b/deploy/equipment_cost_system/src/database/__init__.py
new file mode 100644
index 0000000..6df49c9
--- /dev/null
+++ b/deploy/equipment_cost_system/src/database/__init__.py
@@ -0,0 +1 @@
+from .db_connection import get_db_connection
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/database/db_connection.py b/deploy/equipment_cost_system/src/database/db_connection.py
new file mode 100644
index 0000000..361a4d2
--- /dev/null
+++ b/deploy/equipment_cost_system/src/database/db_connection.py
@@ -0,0 +1,37 @@
+import mysql.connector
+from mysql.connector import Error
+from contextlib import contextmanager
+import os
+from dotenv import load_dotenv
+from ..logger import setup_logger
+
+# 获取logger
+logger = setup_logger(__name__)
+
+# 加载环境变量
+load_dotenv()
+
+@contextmanager
+def get_db_connection():
+ """
+ 数据库连接上下文管理器
+ """
+ connection = None
+ try:
+ connection = mysql.connector.connect(
+ host=os.getenv('MYSQL_HOST', 'localhost'),
+ user=os.getenv('MYSQL_USER', 'root'),
+ password=os.getenv('MYSQL_PASSWORD', '123456'),
+ database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
+ )
+ logger.info("Database connection established")
+ yield connection
+
+ except Error as e:
+ logger.error(f"Error connecting to MySQL: {str(e)}")
+ raise
+
+ finally:
+ if connection and connection.is_connected():
+ connection.close()
+ logger.info("Database connection closed")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/feature_analysis.py b/deploy/equipment_cost_system/src/feature_analysis.py
new file mode 100644
index 0000000..5b87972
--- /dev/null
+++ b/deploy/equipment_cost_system/src/feature_analysis.py
@@ -0,0 +1,269 @@
+import numpy as np
+import pandas as pd
+from scipy import stats
+from sklearn.preprocessing import StandardScaler
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.metrics import r2_score
+import logging
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+class FeatureAnalysis:
+ def __init__(self):
+ self.scaler = StandardScaler()
+ self.important_features = []
+
+ # 添加特征名称映射
+ self.feature_names_map = {
+ # 通用参数
+ 'length_m': '总长(m)',
+ 'width_m': '宽度(m)',
+ 'height_m': '高度(m)',
+ 'weight_kg': '重量(kg)',
+ 'max_range_km': '最大射程(km)',
+
+ # 火箭炮特有参数
+ 'firing_angle_horizontal': '方向射界(度)',
+ 'firing_angle_vertical': '高低射界(度)',
+ 'rocket_length_m': '火箭弹长度(m)',
+ 'rocket_diameter_mm': '口径(mm)',
+ 'rocket_weight_kg': '火箭弹重量(kg)',
+ 'rate_of_fire': '射速(发/分)',
+ 'combat_weight_kg': '战斗重量(kg)',
+ 'speed_kmh': '速度(km/h)',
+ 'min_range_km': '最小射程(km)',
+ 'power_hp': '功率(hp)',
+
+ # 火箭炮衍生特征
+ 'fire_density': '火力密度',
+ 'mobility_index': '机动性指标',
+ 'range_ratio': '射程比',
+ 'power_weight_ratio': '功重比',
+ 'volume_density': '体积密度',
+
+ # 巡飞弹特有参数
+ 'wingspan_m': '翼展(m)',
+ 'warhead_weight_kg': '战斗部重量(kg)',
+ 'max_speed_ms': '最大速度(m/s)',
+ 'cruise_speed_kmh': '巡航速度(km/h)',
+ 'flight_time_min': '巡飞时间(min)',
+ 'folded_length_mm': '折叠长度(mm)',
+ 'folded_width_mm': '折叠宽度(mm)',
+ 'folded_height_mm': '折叠高度(mm)',
+
+ # 巡飞弹衍生特征
+ 'warhead_ratio': '战斗部比重',
+ 'speed_ratio': '速度比',
+ 'range_time_ratio': '射程时间比',
+ 'aspect_ratio': '展弦比',
+ 'volume_density': '体积密度'
+ }
+
+ def get_equipment_specific_features(self, equipment_type):
+ """
+ 获取特定装备类型的特征列表
+ """
+ # 通用参数
+ common_features = [
+ 'length_m', # 总长(m)
+ 'width_m', # 宽度(m)
+ 'height_m', # 高度(m)
+ 'weight_kg', # 重量(kg)
+ 'max_range_km' # 最大射程(km)
+ ]
+
+ if equipment_type == '火箭炮':
+ # 火箭炮特有参数
+ specific_features = [
+ 'firing_angle_horizontal', # 方向射界(度)
+ 'firing_angle_vertical', # 高低射界(度)
+ 'rocket_length_m', # 火箭弹长度(m)
+ 'rocket_diameter_mm', # 口径(mm)
+ 'rocket_weight_kg', # 火箭弹重量(kg)
+ 'rate_of_fire', # 射速(发/分)
+ 'combat_weight_kg', # 战斗重量(kg)
+ 'speed_kmh', # 速度(km/h)
+ 'min_range_km', # 最小射程(km)
+ 'power_hp' # 功率(hp)
+ ]
+
+ # 火箭炮衍生特征
+ derived_features = [
+ 'fire_density', # 火力密度 = 射速 * 火箭弹重量
+ 'mobility_index', # 机动性指标 = 速度 / 战斗重量
+ 'range_ratio', # 射程比 = 最大射程 / 最小射程
+ 'power_weight_ratio', # 功重比 = 功率 / 战斗重量
+ 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
+ ]
+
+ return common_features + specific_features + derived_features
+
+ else: # 巡飞弹
+ # 巡飞弹特有参数
+ specific_features = [
+ 'wingspan_m', # 翼展(m)
+ 'warhead_weight_kg', # 战斗部重量(kg)
+ 'max_speed_ms', # 最大速度(m/s)
+ 'cruise_speed_kmh', # 巡航速度(km/h)
+ 'flight_time_min', # 巡飞时间(min)
+ 'folded_length_mm', # 折叠长度(mm)
+ 'folded_width_mm', # 折叠宽度(mm)
+ 'folded_height_mm' # 折叠高度(mm)
+ ]
+
+ # 巡飞弹衍生特征
+ derived_features = [
+ 'warhead_ratio', # 战斗部比重 = 战斗部重量 / 总重量
+ 'speed_ratio', # 速度比 = 巡航速度 / 最大速度
+ 'range_time_ratio', # 射程时间比 = 最大射程 / 巡飞时间
+ 'aspect_ratio', # 展弦比 = 翼展^2 / 参考面积
+ 'volume_density' # 体积密度 = 重量 / (长 * 宽 * 高)
+ ]
+
+ return common_features + specific_features + derived_features
+
+ def calculate_derived_features(self, data, equipment_type):
+ """
+ 计算衍生特征
+ """
+ try:
+ if equipment_type == '火箭炮':
+ # 火箭炮衍生特征计算
+ if 'rate_of_fire' in data.columns and 'rocket_weight_kg' in data.columns:
+ data['fire_density'] = data['rate_of_fire'] * data['rocket_weight_kg']
+ else:
+ data['fire_density'] = 0 # 或者其他默认值
+
+ if 'speed_kmh' in data.columns and 'combat_weight_kg' in data.columns:
+ data['mobility_index'] = data['speed_kmh'] / data['combat_weight_kg']
+ else:
+ data['mobility_index'] = 0
+
+ if 'max_range_km' in data.columns and 'min_range_km' in data.columns:
+ data['range_ratio'] = data['max_range_km'] / data['min_range_km']
+ else:
+ data['range_ratio'] = 0
+
+ if 'power_hp' in data.columns and 'combat_weight_kg' in data.columns:
+ data['power_weight_ratio'] = data['power_hp'] / data['combat_weight_kg']
+ else:
+ data['power_weight_ratio'] = 0
+
+ if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
+ data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
+ else:
+ data['volume_density'] = 0
+
+ else: # 巡飞弹
+ # 巡飞弹衍生特征计算
+ if 'warhead_weight_kg' in data.columns and 'weight_kg' in data.columns:
+ data['warhead_ratio'] = data['warhead_weight_kg'] / data['weight_kg']
+ else:
+ data['warhead_ratio'] = 0
+
+ if 'cruise_speed_kmh' in data.columns and 'max_speed_ms' in data.columns:
+ data['speed_ratio'] = data['cruise_speed_kmh'] / (data['max_speed_ms'] * 3.6)
+ else:
+ data['speed_ratio'] = 0
+
+ if 'max_range_km' in data.columns and 'flight_time_min' in data.columns:
+ data['range_time_ratio'] = data['max_range_km'] / data['flight_time_min']
+ else:
+ data['range_time_ratio'] = 0
+
+ if 'wingspan_m' in data.columns and 'length_m' in data.columns:
+ data['aspect_ratio'] = (data['wingspan_m'] ** 2) / data['length_m']
+ else:
+ data['aspect_ratio'] = 0
+
+ if all(col in data.columns for col in ['weight_kg', 'length_m', 'width_m', 'height_m']):
+ data['volume_density'] = data['weight_kg'] / (data['length_m'] * data['width_m'] * data['height_m'])
+ else:
+ data['volume_density'] = 0
+
+ return data
+
+ except Exception as e:
+ logger.error(f"Error calculating derived features: {str(e)}")
+ raise
+
+ def analyze_features(self, features, target, feature_names):
+ """
+ 分析特征重要性和相关性
+ """
+ try:
+ # 转换为numpy数组
+ X = np.array(features)
+ y = np.array(target)
+
+ # 数据标准化
+ X_scaled = self.scaler.fit_transform(X)
+
+ # 特征重要性分析
+ rf = RandomForestRegressor(n_estimators=100, random_state=42)
+ rf.fit(X_scaled, y)
+ importances = rf.feature_importances_
+
+ # 按重要性排序,使用中文特征名
+ importance_indices = np.argsort(importances)[::-1]
+ important_features = [
+ {
+ 'name': self.feature_names_map.get(feature_names[i], feature_names[i]),
+ 'importance': float(importances[i])
+ }
+ for i in importance_indices
+ ]
+
+ # 相关性分析
+ df = pd.DataFrame(X_scaled, columns=feature_names)
+ correlation_matrix = df.corr().values
+
+ # 生成相关性分析数据,保留2位小数
+ correlation_data = []
+ chinese_feature_names = [self.feature_names_map.get(name, name) for name in feature_names]
+ for i in range(len(feature_names)):
+ for j in range(len(feature_names)):
+ correlation_data.append([
+ i, j,
+ round(correlation_matrix[i][j], 2) # 修改为保留2位小数
+ ])
+
+ return {
+ 'important_features': important_features,
+ 'correlation_analysis': {
+ 'features': chinese_feature_names, # 使用中文特征名
+ 'matrix': correlation_data
+ }
+ }
+
+ except Exception as e:
+ logger.error(f"Error in feature analysis: {str(e)}")
+ raise
+
+ def preprocess_features(self, equipment_data, equipment_type):
+ """
+ 预处理特征数据
+ """
+ try:
+ # 转换为 DataFrame
+ df = pd.DataFrame(equipment_data)
+
+ # 计算衍生特征
+ df = self.calculate_derived_features(df, equipment_type)
+
+ # 处理缺失值
+ numeric_columns = df.select_dtypes(include=[np.number]).columns
+ for col in numeric_columns:
+ # 转换为数值类型
+ df[col] = pd.to_numeric(df[col], errors='coerce')
+ # 使用新的方式填充缺失值
+ mean_value = df[col].mean()
+ df[col] = df[col].fillna(mean_value)
+
+ logger.info(f"Preprocessed data shape: {df.shape}")
+ return df
+
+ except Exception as e:
+ logger.error(f"Error preprocessing features: {str(e)}")
+ raise Exception(f"Feature preprocessing error: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/import_data.py b/deploy/equipment_cost_system/src/import_data.py
new file mode 100644
index 0000000..03dfa36
--- /dev/null
+++ b/deploy/equipment_cost_system/src/import_data.py
@@ -0,0 +1,255 @@
+import pandas as pd
+from .logger import setup_logger
+from src.database.db_connection import get_db_connection
+
+logger = setup_logger(__name__)
+
+def import_training_data(excel_file):
+ """
+ 从Excel导入训练数据到数据库
+ """
+ try:
+ # 读取所有sheet
+ rocket_df = pd.read_excel(excel_file, sheet_name='火箭炮')
+ missile_df = pd.read_excel(excel_file, sheet_name='巡飞弹')
+ special_df = pd.read_excel(excel_file, sheet_name='特殊参数')
+
+ # 记录所有装备名称,用于后续检查
+ equipment_names = set()
+
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 1. 先导入火箭炮数据
+ logger.info("开始导入火箭炮数据...")
+ for _, row in rocket_df.iterrows():
+ equipment_names.add(row['名称'])
+ # 检查是否已存在相同名称的装备
+ cursor.execute("""
+ SELECT id FROM equipment
+ WHERE name = %s AND type = '火箭炮'
+ """, (row['名称'],))
+
+ existing_equipment = cursor.fetchone()
+ if existing_equipment:
+ logger.warning(f"火箭炮 '{row['名称']}' 已存在,跳过导入")
+ continue
+
+ # 插入基本信息
+ cursor.execute("""
+ INSERT INTO equipment (name, type, manufacturer)
+ VALUES (%s, %s, %s)
+ """, (row['名称'], '火箭炮', row['制造商']))
+
+ equipment_id = cursor.lastrowid
+
+ # 插入通用参数
+ cursor.execute("""
+ INSERT INTO common_params
+ (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
+ VALUES (%s, %s, %s, %s, %s, %s)
+ """, (
+ equipment_id,
+ row['总长_m'] if pd.notna(row['总长_m']) else None,
+ row['宽度_m'] if pd.notna(row['宽度_m']) else None,
+ row['高度_m'] if pd.notna(row['高度_m']) else None,
+ row['重量_kg'] if pd.notna(row['重量_kg']) else None,
+ row['最大射程_km'] if pd.notna(row['最大射程_km']) else None
+ ))
+
+ # 插入火箭炮特有参数
+ cursor.execute("""
+ INSERT INTO rocket_artillery_params
+ (equipment_id, firing_angle_horizontal, firing_angle_vertical,
+ rocket_length_m, rocket_diameter_mm, rocket_weight_kg, rate_of_fire,
+ combat_weight_kg, speed_kmh, min_range_km, mobility_type,
+ structure_layout, engine_model, engine_params, power_hp,
+ travel_range_km)
+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+ """, (
+ equipment_id,
+ row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
+ row['高低射界_度'] if pd.notna(row['高低射界_度']) else None,
+ row['火箭弹长度_m'] if pd.notna(row['火箭弹长度_m']) else None,
+ row['口径_mm'] if pd.notna(row['口径_mm']) else None,
+ row['火箭弹重量_kg'] if pd.notna(row['火箭弹重量_kg']) else None,
+ row['射速_发'] if pd.notna(row['射速_发']) else None,
+ row['战斗重_kg'] if pd.notna(row['战斗重_kg']) else None,
+ row['速度_km/h'] if pd.notna(row['速度_km/h']) else None,
+ row['最小射程_km'] if pd.notna(row['最小射程_km']) else None,
+ row['行走方式'] if pd.notna(row['行走方式']) else None,
+ row['结构布局'] if pd.notna(row['结构布局']) else None,
+ row['发动机型号'] if pd.notna(row['发动机型号']) else None,
+ row['发动机参数'] if pd.notna(row['发动机参数']) else None,
+ row['功率_hp'] if pd.notna(row['功率_hp']) else None,
+ row['行程_km'] if pd.notna(row['行程_km']) else None
+ ))
+
+ # 插入成本数据
+ if pd.notna(row['成本_元']):
+ cursor.execute("""
+ INSERT INTO cost_data (equipment_id, actual_cost)
+ VALUES (%s, %s)
+ """, (equipment_id, row['成本_元']))
+
+ logger.info("火箭炮数据导入完成")
+
+ # 2. 导入巡飞弹数据
+ logger.info("开始导入巡飞弹数据...")
+ for index, row in missile_df.iterrows():
+ # 记录每行数据的空值情况
+ null_values = row[row.isna()].index.tolist()
+ if null_values:
+ logger.info(f"行 {index + 2} 中的空值字段: {null_values}")
+
+ equipment_names.add(row['名称'])
+ # 检查是否已存在相同名称的装备
+ cursor.execute("""
+ SELECT id FROM equipment
+ WHERE name = %s AND type = '巡飞弹'
+ """, (row['名称'],))
+
+ existing_equipment = cursor.fetchone()
+ if existing_equipment:
+ logger.warning(f"巡飞弹 '{row['名称']}' 已存在,跳过导入")
+ continue
+
+ # 插入基本信息
+ cursor.execute("""
+ INSERT INTO equipment (name, type, manufacturer)
+ VALUES (%s, %s, %s)
+ """, (
+ row['名称'],
+ '巡飞弹',
+ row['制造商'] if pd.notna(row['制造商']) else None
+ ))
+
+ equipment_id = cursor.lastrowid
+
+ # 插入通用参数
+ cursor.execute("""
+ INSERT INTO common_params
+ (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
+ VALUES (%s, %s, %s, %s, %s, %s)
+ """, (
+ equipment_id,
+ float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
+ float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
+ float(row['弹径_mm'])/1000 if pd.notna(row['弹径_mm']) else None, # 转换为米
+ float(row['重量_kg']) if pd.notna(row['重量_kg']) else None,
+ float(row['最大射程_km']) if pd.notna(row['最大射程_km']) else None
+ ))
+
+ # 插入巡飞弹特有参数
+ cursor.execute("""
+ INSERT INTO loitering_munition_params
+ (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms,
+ cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
+ folded_length_mm, folded_width_mm, folded_height_mm,
+ power_system, guidance_system)
+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+ """, (
+ equipment_id,
+ float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
+ float(row['战斗部重量_kg']) if pd.notna(row['战斗部重量_kg']) else None,
+ float(row['最大速度_m/s']) if pd.notna(row['最大速度_m/s']) else None,
+ float(row['巡航速度_km/h']) if pd.notna(row['巡航速度_km/h']) else None,
+ float(row['巡飞时间_min']) if pd.notna(row['巡飞时间_min']) else None,
+ str(row['战斗部类型']) if pd.notna(row['战斗部类型']) else None,
+ str(row['发射方式']) if pd.notna(row['发射方式']) else None,
+ float(row['折叠长度_mm']) if pd.notna(row['折叠长度_mm']) else None,
+ float(row['折叠宽度_mm']) if pd.notna(row['折叠宽度_mm']) else None,
+ float(row['折叠高度_mm']) if pd.notna(row['折叠高度_mm']) else None,
+ str(row['动力装置']) if pd.notna(row['动力装置']) else None,
+ str(row['制导体制']) if pd.notna(row['制导体制']) else None
+ ))
+
+ # 插入成本数据
+ if pd.notna(row['成本_元']):
+ cursor.execute("""
+ INSERT INTO cost_data (equipment_id, actual_cost)
+ VALUES (%s, %s)
+ """, (equipment_id, float(row['成本_元'])))
+
+ logger.info("巡飞弹数据导入完成")
+
+ # 提交之前的更改并关闭原有游标
+ cursor.close()
+ conn.commit()
+
+ # 3. 导入特殊参数
+ logger.info("开始导入特殊参数...")
+ for index, row in special_df.iterrows():
+ equipment_name = row['装备名称']
+ param_name = row['参数名称']
+ logger.info(f"处理第 {index + 1} 条记录: 装备='{equipment_name}', 参数='{param_name}'")
+
+ if equipment_name not in equipment_names:
+ logger.warning(f"未找到装备: {equipment_name},请检查名称是否正确")
+ continue
+
+ # 获取装备ID - 使用新的游标
+ logger.debug(f"查询装备ID: {equipment_name}")
+ with conn.cursor() as id_cursor:
+ id_cursor.execute("""
+ SELECT id FROM equipment WHERE name = %s
+ """, (equipment_name,))
+ result = id_cursor.fetchone()
+
+ if not result:
+ logger.warning(f"未找到装备: {equipment_name}")
+ continue
+
+ equipment_id = result[0]
+ logger.debug(f"找到装备ID: {equipment_id}")
+
+ # 检查参数是否存在 - 使用新的游标
+ logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
+ with conn.cursor() as check_cursor:
+ check_cursor.execute("""
+ SELECT id FROM custom_params
+ WHERE equipment_id = %s AND param_name = %s
+ """, (equipment_id, param_name))
+ exists = check_cursor.fetchone()
+
+ if exists:
+ logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
+ continue
+
+ # 插入新的参数 - 使用新的游标
+ param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
+ param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
+ param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
+
+ logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
+ with conn.cursor() as insert_cursor:
+ insert_cursor.execute("""
+ INSERT INTO custom_params
+ (equipment_id, param_name, param_value, param_unit, description)
+ VALUES (%s, %s, %s, %s, %s)
+ """, (
+ equipment_id,
+ param_name,
+ param_value,
+ param_unit,
+ param_desc
+ ))
+ logger.debug(f"成功插入参数记录")
+
+ # 最终提交
+ conn.commit()
+ logger.info("特殊参数导入完成")
+ logger.info("所有数据导入成功")
+ return True
+
+ except Exception as e:
+ logger.error(f"Error importing data: {str(e)}")
+ raise
+
+if __name__ == "__main__":
+ try:
+ excel_file = 'data/equipment_data_20241108.xlsx'
+ import_training_data(excel_file)
+ logger.info("All data imported successfully")
+ except Exception as e:
+ logger.error(f"Import failed: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/init_data.sql b/deploy/equipment_cost_system/src/init_data.sql
new file mode 100644
index 0000000..ee99db3
--- /dev/null
+++ b/deploy/equipment_cost_system/src/init_data.sql
@@ -0,0 +1,319 @@
+/*
+这是用于开发和测试环境的示例数据。
+生产环境请使用系统的数据导入功能添加实际数据。
+
+主要用途:
+1. 提供开发测试数据
+2. 作为数据格式参考
+3. 用于系统功能验证
+*/
+
+-- 插入装备基本信息
+INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
+('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
+('胜利-2', '火箭炮', '伊朗', '地面固定目标');
+
+-- 插入巡飞弹技术参数
+INSERT INTO technical_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_speed_kmh,
+ cruise_speed_kmh,
+ max_range_km,
+ flight_time_min,
+ warhead_type,
+ launch_mode,
+ folded_length_mm,
+ folded_width_mm,
+ folded_height_mm
+) VALUES (
+ 1, -- 终结者巡飞弹
+ 0.56,
+ 0.15,
+ 0.20,
+ 2.72,
+ 160.93,
+ 96.56,
+ 24,
+ 15,
+ '破片杀伤战斗部',
+ '凭自身动力起飞',
+ 560,
+ 150,
+ 200
+);
+
+-- 插入火箭炮技术参数
+INSERT INTO technical_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_range_km
+) VALUES (
+ 2, -- 胜利-2火箭炮
+ 10,
+ 2.5,
+ 3.34,
+ 15000,
+ 23
+);
+
+-- 插入成本数据(示例数据)
+INSERT INTO cost_data (equipment_id, actual_cost) VALUES
+(1, 1000000), -- 终结者巡飞弹成本
+(2, 5000000); -- 胜利-2火箭炮成本
+
+-- 插入更多巡飞弹变体数据用于训练
+INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
+('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
+('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'),
+('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆');
+
+-- 插入变体技术参数
+INSERT INTO technical_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_speed_kmh,
+ cruise_speed_kmh,
+ max_range_km,
+ flight_time_min,
+ warhead_type,
+ launch_mode,
+ folded_length_mm,
+ folded_width_mm,
+ folded_height_mm
+) VALUES
+-- 终结者-A(稍大型号)
+(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210),
+-- 终结者-B(稍小型号)
+(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190),
+-- 终结者-C(标准型号的改进版)
+(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200);
+
+-- 插入变体成本数据
+INSERT INTO cost_data (equipment_id, actual_cost) VALUES
+(3, 1100000), -- 终结者-A成本(较高)
+(4, 900000), -- 终结者-B成本(较低)
+(5, 1050000); -- 终结者-C成本(中等)
+
+-- 添加更多巡飞弹数据
+INSERT INTO equipment (name, type, manufacturer, target_type) VALUES
+('哈比', '巡飞弹', '以色列', '防空系统和雷达站'),
+('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'),
+('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'),
+('弹簧刀', '巡飞弹', '波兰', '装甲目标'),
+('彩虹-4', '巡飞弹', '中国', '地面固定目标');
+
+-- 添加它们的技术参数
+INSERT INTO technical_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_speed_kmh,
+ cruise_speed_kmh,
+ max_range_km,
+ flight_time_min,
+ warhead_type,
+ launch_mode,
+ folded_length_mm,
+ folded_width_mm,
+ folded_height_mm
+) VALUES
+-- 哈比
+(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600),
+-- 游荡者
+(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400),
+-- 凤凰
+(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300),
+-- 弹簧刀
+(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350),
+-- 彩虹-4
+(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800);
+
+-- 添加成本数据
+INSERT INTO cost_data (equipment_id, actual_cost) VALUES
+(6, 800000), -- 哈比
+(7, 500000), -- 游荡者
+(8, 450000), -- 凤凰
+(9, 480000), -- 弹簧刀
+(10, 1500000); -- 彩虹-4
+
+-- 火箭炮数据
+INSERT INTO equipment (name, type, manufacturer) VALUES
+('BM-21', '火箭炮', '俄罗斯'),
+('SR5', '火箭炮', '中国'),
+('HIMARS', '火箭炮', '美国'),
+('LAR-160', '火箭炮', '以色列'),
+('T-122', '火箭炮', '土耳其'),
+('RM-70', '火箭炮', '捷克'),
+('ASTROS II', '火箭炮', '巴西');
+
+-- 火箭炮通用参数
+INSERT INTO common_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_range_km
+) VALUES
+-- BM-21
+(1, 7.35, 2.4, 3.1, 13700, 20.4),
+-- SR5
+(2, 10.2, 2.8, 3.2, 28500, 70),
+-- HIMARS
+(3, 7.0, 2.4, 3.2, 16250, 70),
+-- LAR-160
+(4, 6.7, 2.5, 2.8, 15000, 45),
+-- T-122
+(5, 7.2, 2.5, 2.9, 18000, 40),
+-- RM-70
+(6, 7.5, 2.5, 3.0, 17200, 20.3),
+-- ASTROS II
+(7, 8.0, 2.7, 3.1, 24500, 90);
+
+-- 火箭炮特有参数
+INSERT INTO rocket_artillery_params (
+ equipment_id,
+ firing_angle_horizontal,
+ firing_angle_vertical,
+ rocket_length_m,
+ rocket_diameter_mm,
+ rocket_weight_kg,
+ rate_of_fire,
+ combat_weight_kg,
+ speed_kmh,
+ min_range_km,
+ mobility_type,
+ structure_layout,
+ engine_model,
+ engine_params,
+ power_hp,
+ travel_range_km
+) VALUES
+-- BM-21
+(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500),
+-- SR5
+(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650),
+-- HIMARS
+(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480),
+-- LAR-160
+(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550),
+-- T-122
+(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600),
+-- RM-70
+(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520),
+-- ASTROS II
+(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700);
+
+-- 巡飞弹数据
+INSERT INTO equipment (name, type, manufacturer) VALUES
+('Hero-120', '巡飞弹', '以色列'),
+('Switchblade 600', '巡飞弹', '美国'),
+('Warmate', '巡飞弹', '波兰'),
+('CH-901', '巡飞弹', '中国'),
+('HAROP', '巡飞弹', '以色列'),
+('Coyote', '巡飞弹', '美国'),
+('WS-43', '巡飞弹', '中国');
+
+-- 巡飞弹通用参数
+INSERT INTO common_params (
+ equipment_id,
+ length_m,
+ width_m,
+ height_m,
+ weight_kg,
+ max_range_km
+) VALUES
+-- Hero-120
+(8, 1.3, 0.23, 0.23, 12.5, 40),
+-- Switchblade 600
+(9, 1.3, 0.22, 0.22, 15.0, 40),
+-- Warmate
+(10, 1.1, 0.15, 0.15, 5.7, 15),
+-- CH-901
+(11, 1.2, 0.18, 0.18, 9.0, 20),
+-- HAROP
+(12, 2.5, 0.43, 0.43, 135, 1000),
+-- Coyote
+(13, 0.9, 0.12, 0.12, 5.9, 20),
+-- WS-43
+(14, 1.8, 0.35, 0.35, 20, 60);
+
+-- 巡飞弹特有参数
+INSERT INTO loitering_munition_params (
+ equipment_id,
+ wingspan_m,
+ warhead_weight_kg,
+ max_speed_ms,
+ cruise_speed_kmh,
+ flight_time_min,
+ warhead_type,
+ launch_mode,
+ folded_length_mm,
+ folded_width_mm,
+ folded_height_mm,
+ power_system,
+ guidance_system
+) VALUES
+-- Hero-120
+(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'),
+-- Switchblade 600
+(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'),
+-- Warmate
+(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'),
+-- CH-901
+(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'),
+-- HAROP
+(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'),
+-- Coyote
+(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'),
+-- WS-43
+(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电');
+
+-- 插入成本数据(示例成本)
+INSERT INTO cost_data (equipment_id, actual_cost) VALUES
+-- 火箭炮
+(1, 800000), -- BM-21
+(2, 4500000), -- SR5
+(3, 5500000), -- HIMARS
+(4, 3500000), -- LAR-160
+(5, 2800000), -- T-122
+(6, 1500000), -- RM-70
+(7, 4800000), -- ASTROS II
+-- 巡飞弹
+(8, 150000), -- Hero-120
+(9, 180000), -- Switchblade 600
+(10, 80000), -- Warmate
+(11, 100000), -- CH-901
+(12, 850000), -- HAROP
+(13, 75000), -- Coyote
+(14, 120000); -- WS-43
+
+-- 创建初始数据集
+INSERT INTO datasets (name, description, equipment_type, purpose) VALUES
+('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'),
+('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'),
+('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'),
+('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证');
+
+-- 关联装备到数据集
+INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES
+-- 火箭炮训练集
+(1, 1), (1, 2), (1, 3), (1, 4),
+-- 巡飞弹训练集
+(2, 8), (2, 9), (2, 10), (2, 11), (2, 12),
+-- 火箭炮验证集
+(3, 5), (3, 6), (3, 7),
+-- 巡飞弹验证集
+(4, 13), (4, 14);
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/logger.py b/deploy/equipment_cost_system/src/logger.py
new file mode 100644
index 0000000..52bb879
--- /dev/null
+++ b/deploy/equipment_cost_system/src/logger.py
@@ -0,0 +1,33 @@
+import logging
+import os
+from datetime import datetime
+
+def setup_logger(name):
+ """
+ 创建并配置logger
+ """
+ # 创建logger
+ logger = logging.getLogger(name)
+
+ # 如果logger已经有处理器,直接返回
+ if logger.handlers:
+ return logger
+
+ # 设置日志级别
+ logger.setLevel(logging.INFO)
+
+ # 确保日志目录存在
+ os.makedirs('logs', exist_ok=True)
+
+ # 创建文件处理器
+ file_handler = logging.FileHandler('logs/api.log')
+ file_handler.setLevel(logging.INFO)
+
+ # 创建格式化器
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ file_handler.setFormatter(formatter)
+
+ # 添加处理器
+ logger.addHandler(file_handler)
+
+ return logger
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/model_trainer.py b/deploy/equipment_cost_system/src/model_trainer.py
new file mode 100644
index 0000000..660c38e
--- /dev/null
+++ b/deploy/equipment_cost_system/src/model_trainer.py
@@ -0,0 +1,612 @@
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+from sklearn.model_selection import cross_val_score
+from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
+from sklearn.impute import SimpleImputer
+import xgboost as xgb
+import lightgbm as lgb
+import logging
+import joblib
+import os
+from src.feature_analysis import FeatureAnalysis
+from datetime import datetime
+import json
+from src.database import get_db_connection
+from src.data_preparation import DataPreparation
+from sklearn.cross_decomposition import PLSRegression
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+class ModelTrainer:
+ def __init__(self):
+ """
+ 初始化 ModelTrainer
+ """
+ self.models = {
+ 'xgboost': self._create_xgboost_model(),
+ 'lightgbm': self._create_lightgbm_model(),
+ 'gbm': self._create_gbm_model(),
+ 'rf': self._create_rf_model(),
+ 'pls': self._create_pls_model()
+ }
+ self.best_model = None
+ self.imputer = SimpleImputer(strategy='mean')
+ self.feature_scaler = None
+ self.target_scaler = None
+ self.equipment_type = None
+ self.feature_analyzer = FeatureAnalysis()
+
+ def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None):
+ """
+ 训练模型并返回评估结果
+ """
+ try:
+ self.equipment_type = equipment_type
+ logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}")
+ logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}")
+
+ results = {}
+ best_score = -float('inf')
+ best_model_info = None
+
+ # 首先训练 PLS 模型
+ logger.info("Training pls...")
+ pls_model = self.models['pls']
+ pls_model.fit(X_train, y_train)
+ pls_metrics = self._calculate_metrics(
+ pls_model,
+ X_train, y_train,
+ X_val, y_val
+ )
+ results['pls'] = pls_metrics
+
+ # 训练其他机器学习模型
+ for model_name in model_names:
+ if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了
+ continue
+
+ if model_name not in self.models:
+ logger.warning(f"Unknown model: {model_name}")
+ continue
+
+ logger.info(f"Training {model_name}...")
+ model = self.models[model_name]
+
+ # 训练模型
+ model.fit(X_train, y_train)
+
+ # 计算评估指标
+ metrics = self._calculate_metrics(
+ model,
+ X_train, y_train,
+ X_val, y_val
+ )
+
+ results[model_name] = metrics
+
+ # 更新最佳模型(只在机器学习模型中比较)
+ if metrics['validation']['r2'] > best_score:
+ best_score = metrics['validation']['r2']
+ best_model_info = {
+ 'type': model_name,
+ 'r2': metrics['validation']['r2'],
+ 'mae': metrics['validation']['mae'],
+ 'rmse': metrics['validation']['rmse']
+ }
+ self.best_model = model
+
+ # 保存最佳模型和 PLS 模型
+ if equipment_type and best_model_info:
+ self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val)
+
+ return {
+ 'metrics': results,
+ 'best_model': best_model_info
+ }
+
+ except Exception as e:
+ logger.error(f"Error in model training: {str(e)}")
+ raise
+
+ def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None):
+ """
+ 计算模型评估指标
+ """
+ from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
+
+ # 训练集评估
+ train_pred = model.predict(X_train)
+
+ # 如果使用了标准化,需要转换回原始范围
+ if hasattr(self, 'target_scaler'):
+ train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel()
+ y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
+ else:
+ y_train_orig = y_train
+
+ # 记录预测范围
+ logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}")
+ logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}")
+
+ train_metrics = {
+ 'r2': r2_score(y_train_orig, train_pred),
+ 'mae': mean_absolute_error(y_train_orig, train_pred),
+ 'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred))
+ }
+
+ # 验证集评估
+ if X_val is not None and y_val is not None:
+ val_pred = model.predict(X_val)
+
+ # 如果使用了标准化,需要转换回原始范围
+ if hasattr(self, 'target_scaler'):
+ val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel()
+ y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
+ else:
+ y_val_orig = y_val
+
+ # 记录预测范围
+ logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}")
+ logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}")
+
+ val_metrics = {
+ 'r2': r2_score(y_val_orig, val_pred),
+ 'mae': mean_absolute_error(y_val_orig, val_pred),
+ 'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred))
+ }
+ else:
+ # 使用交叉验证
+ cv_scores = cross_val_score(model, X_train, y_train, cv=5)
+ val_metrics = {
+ 'r2': cv_scores.mean(),
+ 'mae': None,
+ 'rmse': None
+ }
+
+ return {
+ 'train': train_metrics,
+ 'validation': val_metrics
+ }
+
+ def _create_xgboost_model(self):
+ """
+ 创建 XGBoost 模型,增强正则化
+ """
+ return xgb.XGBRegressor(
+ n_estimators=50, # 减少树的数量
+ learning_rate=0.05, # 学习率
+ max_depth=3, # 减小树的深
+ min_child_weight=3, # 增加节点权重
+ subsample=0.7, # 减小样本采样比例
+ colsample_bytree=0.7, # 减小特征采样比例
+ reg_alpha=0.1, # L1 正则化
+ reg_lambda=1, # L2 正则化
+ random_state=42
+ )
+
+ def _create_lightgbm_model(self):
+ """
+ 创建 LightGBM 模型,增强正则化
+ """
+ return lgb.LGBMRegressor(
+ n_estimators=50,
+ learning_rate=0.05,
+ max_depth=3,
+ num_leaves=7,
+ min_data_in_leaf=3,
+ min_sum_hessian_in_leaf=1e-3,
+ subsample=0.7,
+ colsample_bytree=0.7,
+ reg_alpha=0.1,
+ reg_lambda=1,
+ random_state=42,
+ verbose=-1
+ )
+
+ def _create_gbm_model(self):
+ """
+ 创建 GBM 模型,增强正则化以减轻过拟合
+ """
+ return GradientBoostingRegressor(
+ n_estimators=100,
+ learning_rate=0.1,
+ max_depth=3,
+ random_state=42,
+ subsample=0.8,
+ min_samples_split=3,
+ min_samples_leaf=2
+ )
+
+ def _create_rf_model(self):
+ """
+ 创建随机森林模型,针对小样本数据调整参数
+ """
+ return RandomForestRegressor(
+ n_estimators=100,
+ max_depth=3,
+ random_state=42,
+ min_samples_split=3,
+ min_samples_leaf=2
+ )
+
+ def _create_pls_model(self):
+ """
+ 创建 PLS 模型,优化参数配置
+ """
+ return PLSRegression(
+ n_components=2, # 减少主成分数量,从5减到2
+ scale=True, # 保持数据标准化
+ max_iter=500, # 减少最大迭代次数,避免过拟合
+ tol=1e-6 # 降低收敛精度,避免过拟合
+ )
+
+ def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None):
+ """
+ 保存最佳模型和 PLS 模型
+ """
+ try:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ model_dir = 'models'
+ os.makedirs(model_dir, exist_ok=True)
+
+ # 1. 保存最佳机器学习模型
+ model_path = f'{model_dir}/{equipment_type}_{timestamp}'
+ if isinstance(self.best_model, xgb.XGBRegressor):
+ self.best_model.save_model(f'{model_path}.json')
+ model_format = 'json'
+ else:
+ joblib.dump(self.best_model, f'{model_path}.joblib')
+ model_format = 'joblib'
+
+ # 2. 保存 PLS 模型
+ pls_model = self.models['pls']
+ pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib'
+ joblib.dump(pls_model, pls_path)
+
+ # 3. 保存标准化器
+ scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib'
+ joblib.dump({
+ 'feature_scaler': self.feature_scaler,
+ 'target_scaler': self.target_scaler
+ }, scaler_path)
+
+ logger.info(f"Saved best model to {model_path}.{model_format}")
+ logger.info(f"Saved PLS model to {pls_path}")
+ logger.info(f"Saved scalers to {scaler_path}")
+
+ # 4. 更新数据库中的模型记录
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 将所有模型设置为非激活
+ cursor.execute("""
+ UPDATE trained_models
+ SET is_active = FALSE
+ WHERE equipment_type = %s
+ """, (equipment_type,))
+
+ # 获取 PLS 模型的评估指标
+ pls_metrics = self._calculate_metrics(
+ self.models['pls'],
+ X_train,
+ y_train,
+ X_val,
+ y_val
+ )
+
+ # 保存最佳机器学习模型记录
+ self.best_model.equipment_type = equipment_type # 设置装备类型
+ ml_feature_importance = self._get_feature_importance(self.best_model)
+
+ cursor.execute("""
+ INSERT INTO trained_models (
+ model_name, model_type, equipment_type, model_path, scaler_path,
+ r2_score, mae, rmse, feature_importance, training_data_size,
+ training_date, is_active, created_by
+ ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
+ """, (
+ f"{equipment_type}_{timestamp}", # model_name
+ best_model_info['type'], # model_type
+ equipment_type, # equipment_type
+ f"{model_path}.{model_format}", # model_path
+ scaler_path, # scaler_path
+ best_model_info['r2'], # r2_score
+ best_model_info['mae'], # mae
+ best_model_info['rmse'], # rmse
+ json.dumps(ml_feature_importance), # feature_importance
+ len(X_train), # training_data_size
+ 'system' # created_by
+ ))
+
+ # 保存 PLS 模型记录
+ pls_model.equipment_type = equipment_type # 设置装备类型
+ pls_feature_importance = self._get_feature_importance(pls_model)
+
+ cursor.execute("""
+ INSERT INTO trained_models (
+ model_name, model_type, equipment_type, model_path, scaler_path,
+ r2_score, mae, rmse, feature_importance, training_data_size,
+ training_date, is_active, created_by
+ ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s)
+ """, (
+ f"{equipment_type}_{timestamp}_pls", # model_name
+ 'pls', # model_type
+ equipment_type, # equipment_type
+ pls_path, # model_path
+ scaler_path, # scaler_path
+ float(pls_metrics['validation']['r2']), # r2_score
+ float(pls_metrics['validation']['mae']), # mae
+ float(pls_metrics['validation']['rmse']), # rmse
+ json.dumps(pls_feature_importance), # feature_importance
+ len(X_train), # training_data_size
+ 'system' # created_by
+ ))
+
+ conn.commit()
+
+ except Exception as e:
+ logger.error(f"Error saving models: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ raise
+
+ def load_model(self, equipment_type, model_type='ml'):
+ """
+ 加载已训练的模型
+ """
+ try:
+ logger.info(f"Loading {model_type} model for {equipment_type}")
+
+ # 从数据库获取激活的模型
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+
+ # 构建查询语句
+ if model_type == 'pls':
+ query = """
+ SELECT * FROM trained_models
+ WHERE equipment_type = %s
+ AND model_type = 'pls'
+ AND is_active = TRUE
+ LIMIT 1
+ """
+ params = (equipment_type,)
+ else:
+ query = """
+ SELECT * FROM trained_models
+ WHERE equipment_type = %s
+ AND model_type != 'pls'
+ AND is_active = TRUE
+ LIMIT 1
+ """
+ params = (equipment_type,)
+
+ # 记录查询信息
+ logger.info(f"Executing query: {query}")
+ logger.info(f"Query params: {params}")
+
+ cursor.execute(query, params)
+ model_record = cursor.fetchone()
+
+ # 记录查询结果
+ if model_record:
+ logger.info(f"Found model record: {model_record}")
+ else:
+ logger.warning(f"No active model found for type {model_type}")
+ return False
+
+ # 检查文件是否存在
+ logger.info(f"Checking model file: {model_record['model_path']}")
+ logger.info(f"Checking scaler file: {model_record['scaler_path']}")
+
+ if not os.path.exists(model_record['model_path']):
+ logger.error(f"Model file not found: {model_record['model_path']}")
+ raise FileNotFoundError(f"Model file not found: {model_record['model_path']}")
+
+ if not os.path.exists(model_record['scaler_path']):
+ logger.error(f"Scaler file not found: {model_record['scaler_path']}")
+ raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}")
+
+ # 加载模型文件
+ logger.info(f"Loading model from {model_record['model_path']}")
+ if model_type == 'pls':
+ self.best_model = joblib.load(model_record['model_path'])
+ logger.info("Loaded PLS model")
+ else:
+ if model_record['model_type'] == 'xgboost':
+ self.best_model = xgb.XGBRegressor()
+ self.best_model.load_model(model_record['model_path'])
+ logger.info("Loaded XGBoost model")
+ else:
+ self.best_model = joblib.load(model_record['model_path'])
+ logger.info(f"Loaded {model_record['model_type']} model")
+
+ # 加载标准化器
+ logger.info(f"Loading scalers from {model_record['scaler_path']}")
+ scalers = joblib.load(model_record['scaler_path'])
+ self.feature_scaler = scalers['feature_scaler']
+ self.target_scaler = scalers['target_scaler']
+ logger.info("Loaded scalers successfully")
+
+ return True
+
+ except Exception as e:
+ logger.error(f"Error loading model: {str(e)}")
+ logger.error(f"Detailed traceback:", exc_info=True)
+ return False
+
+ def predict(self, features):
+ """
+ 使用加载的模型进行预测
+ """
+ try:
+ if not self.best_model:
+ raise ValueError("No model loaded")
+
+ if not self.feature_scaler:
+ raise ValueError("Feature scaler not loaded")
+
+ if not self.target_scaler:
+ raise ValueError("Target scaler not loaded")
+
+ logger.info("Starting prediction")
+ logger.info(f"Input features shape: {features.shape}")
+ logger.info(f"Input features: \n{features}")
+
+ # 处理缺失值
+ features_filled = np.array(features, dtype=float)
+ features_filled[np.isnan(features_filled)] = 0
+ features_filled = np.nan_to_num(features_filled, 0)
+
+ logger.info(f"Filled features: \n{features_filled}")
+
+ # 标准化特征
+ X = self.feature_scaler.transform(features_filled)
+ logger.info(f"Transformed features shape: {X.shape}")
+ logger.info(f"Transformed features: \n{X}")
+
+ # 预测
+ y_pred_scaled = self.best_model.predict(X)
+ logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}")
+ logger.info(f"Scaled prediction: {y_pred_scaled}")
+
+ # ��标准化
+ y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1))
+ logger.info(f"Final prediction shape: {y_pred.shape}")
+ logger.info(f"Final prediction: {y_pred}")
+
+ # 记录标准化器的参数
+ logger.info("Target scaler params:")
+ logger.info(f"Mean: {self.target_scaler.mean_}")
+ logger.info(f"Scale: {self.target_scaler.scale_}")
+
+ return y_pred.ravel()
+
+ except Exception as e:
+ logger.error(f"Error in prediction: {str(e)}")
+ raise
+
+ def _get_feature_importance(self, model):
+ """
+ 获取特征重要性
+ """
+ try:
+ if not model:
+ return {}
+
+ # 获取特征名称
+ feature_analyzer = FeatureAnalysis()
+ feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
+
+ # 获取特���重要性
+ if hasattr(model, 'feature_importances_'):
+ importances = model.feature_importances_
+ elif hasattr(model, 'coef_'):
+ if len(model.coef_.shape) > 1: # 如果是二维数组
+ importances = np.abs(model.coef_[0]) # 取第一行
+ else:
+ importances = np.abs(model.coef_)
+ else:
+ return {}
+
+ # 创建特征重要性字典
+ importance_dict = {}
+ for name, importance in zip(feature_names, importances):
+ importance_dict[name] = float(importance) # 确保转换为 Python 标量
+
+ # 按重要性降序排序
+ sorted_dict = dict(sorted(
+ importance_dict.items(),
+ key=lambda x: x[1],
+ reverse=True
+ ))
+
+ # 过滤掉重要性为0的特征
+ return {k: v for k, v in sorted_dict.items() if v > 0}
+
+ except Exception as e:
+ logger.error(f"Error getting feature importance: {str(e)}")
+ return {}
+
+ def _calculate_confidence_interval(self, prediction, confidence=0.95):
+ """
+ 计算预测值的置信区间
+ """
+ try:
+ # 使用预测值的20%作为标准差(增加不确定性)
+ std = abs(prediction) * 0.2
+
+ # 计算置信区间
+ from scipy import stats
+ interval = stats.norm.interval(confidence, loc=prediction, scale=std)
+
+ # 确保区间值为正数且合理
+ lower = max(1000, interval[0]) # 最小值设为1000元
+ upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20%
+
+ logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]")
+
+ return [lower, upper]
+
+ except Exception as e:
+ logger.error(f"Error calculating confidence interval: {str(e)}")
+ # 如果计算失败,返回基于20%的简单区间
+ lower = max(1000, prediction * 0.8)
+ upper = prediction * 1.2
+ return [lower, upper]
+
+ def get_model_type(self):
+ """
+ 获取当前模型的类型
+ """
+ if isinstance(self.best_model, xgb.XGBRegressor):
+ return 'xgboost'
+ elif isinstance(self.best_model, lgb.LGBMRegressor):
+ return 'lightgbm'
+ elif isinstance(self.best_model, GradientBoostingRegressor):
+ return 'gbm'
+ elif isinstance(self.best_model, RandomForestRegressor):
+ return 'rf'
+ else:
+ return 'unknown'
+
+ def _get_pls_feature_importance(self):
+ """
+ 获取 PLS 模型的特征重要性
+ """
+ try:
+ if not self.models['pls']:
+ return {}
+
+ # 获取特征名称
+ feature_analyzer = FeatureAnalysis()
+ feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type)
+
+ # 获取 PLS 模型的系数作为特征重要性
+ pls_model = self.models['pls']
+ if hasattr(pls_model, 'coef_'):
+ # 使用绝对值作为重要性指标
+ importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组
+ else:
+ return {}
+
+ # 创建特征重要性字典
+ importance_dict = {}
+ for name, importance in zip(feature_names, importances):
+ importance_dict[name] = float(importance) # 确保转换为 Python 标量
+
+ # 按重要性降序排序
+ sorted_dict = dict(sorted(
+ importance_dict.items(),
+ key=lambda x: x[1],
+ reverse=True
+ ))
+
+ # 过滤掉重要性为0的特征
+ return {k: v for k, v in sorted_dict.items() if v > 0}
+
+ except Exception as e:
+ logger.error(f"Error getting PLS feature importance: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ return {}
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/routes.py b/deploy/equipment_cost_system/src/routes.py
new file mode 100644
index 0000000..df4e825
--- /dev/null
+++ b/deploy/equipment_cost_system/src/routes.py
@@ -0,0 +1,1254 @@
+from flask import Blueprint, request, jsonify, send_file
+from .cost_prediction import CostPredictor
+from .feature_analysis import FeatureAnalysis
+import pandas as pd
+from datetime import datetime
+import numpy as np
+import mysql.connector
+from sklearn.metrics import mean_absolute_error
+from .create_template import create_excel_template
+import json
+import os
+from .data_preparation import DataPreparation
+from .model_trainer import ModelTrainer
+from .logger import setup_logger
+
+# 创建蓝图
+api_bp = Blueprint('api', __name__)
+
+# 获取logger
+logger = setup_logger(__name__)
+
+@api_bp.route('/', methods=['GET'])
+def index():
+ """
+ API根路由
+ 返回API版本信息和可用端点列表
+ """
+ return jsonify({
+ 'name': '装备成本估算系统 API',
+ 'version': '1.0.0',
+ 'endpoints': {
+ 'predict': {
+ 'url': '/api/predict',
+ 'method': 'POST',
+ 'description': '成本预测'
+ },
+ 'analyze-features': {
+ 'url': '/api/analyze-features',
+ 'method': 'POST',
+ 'description': '特征分析'
+ },
+ 'train': {
+ 'url': '/api/train',
+ 'method': 'POST',
+ 'description': '模型训练'
+ },
+ 'evaluate': {
+ 'url': '/api/evaluate',
+ 'method': 'POST',
+ 'description': '模型评估'
+ }
+ }
+ })
+
+@api_bp.route('/predict', methods=['POST'])
+def predict_cost():
+ """
+ 成本预测接口
+ """
+ try:
+ data = request.get_json()
+ logger.info(f"Received prediction request for equipment type: {data.get('type')}")
+
+ # 验证装备类型
+ if 'type' not in data:
+ return jsonify({'error': 'Equipment type is required'}), 400
+
+ # 预测成本
+ predictor = CostPredictor()
+ result = predictor.predict(data)
+
+ # 获取当前使用的模型信息
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT model_type, model_name, r2_score, mae, rmse
+ FROM trained_models
+ WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
+ LIMIT 1
+ """, (data['type'],))
+ model_info = cursor.fetchone()
+
+ # 在结果中添加模型信息
+ result.update({
+ 'model_info': {
+ 'type': model_info['model_type'],
+ 'name': model_info['model_name'],
+ 'r2_score': float(model_info['r2_score']),
+ 'mae': float(model_info['mae']),
+ 'rmse': float(model_info['rmse'])
+ }
+ })
+
+ logger.info(f"Prediction completed: {result}")
+ return jsonify(result)
+
+ except Exception as e:
+ logger.error(f"Error in prediction: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/analyze-features', methods=['POST'])
+def analyze_features():
+ """
+ 基于数据集进行特征分析
+ """
+ try:
+ data = request.get_json()
+ dataset_id = data.get('dataset_id')
+
+ logger.info(f"Starting feature analysis for dataset {dataset_id}")
+
+ if not dataset_id:
+ logger.warning("No dataset_id provided")
+ return jsonify({'error': '请选择数据集'}), 400
+
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+
+ # 获取数据集信息
+ cursor.execute("""
+ SELECT d.*,
+ e.type as equipment_type
+ FROM datasets d
+ JOIN dataset_equipment de ON d.id = de.dataset_id
+ JOIN equipment e ON de.equipment_id = e.id
+ WHERE d.id = %s
+ LIMIT 1
+ """, (dataset_id,))
+ dataset = cursor.fetchone()
+
+ if not dataset:
+ logger.warning(f"Dataset {dataset_id} not found")
+ return jsonify({'error': '数据集不存在'}), 404
+
+ logger.info(f"Dataset info: {dataset}")
+
+ # 创建特征分析实例
+ from src.feature_analysis import FeatureAnalysis
+ analyzer = FeatureAnalysis()
+
+ # 获取特征列表
+ feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
+ logger.info(f"Feature names: {feature_names}")
+
+ # 获取数据集中的装备数据
+ if dataset['equipment_type'] == '火箭炮':
+ cursor.execute("""
+ SELECT e.*, cp.*, rap.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (dataset_id,))
+ else:
+ cursor.execute("""
+ SELECT e.*, cp.*, lmp.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (dataset_id,))
+
+ equipment_data = cursor.fetchall()
+ logger.info(f"Found {len(equipment_data)} equipment records")
+
+ if not equipment_data:
+ logger.warning("No valid equipment data found in dataset")
+ return jsonify({'error': '数据集没有有效的成本数据'}), 400
+
+ # 统计每个特征的缺失率
+ missing_rates = {}
+ for name in feature_names:
+ missing_count = sum(1 for item in equipment_data if item.get(name) is None)
+ missing_rate = missing_count / len(equipment_data)
+ missing_rates[name] = missing_rate
+ logger.info(f"Feature {name} missing rate: {missing_rate:.2%}")
+
+ # 过滤掉缺失率过高的特征
+ valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
+ logger.info(f"Valid features after filtering: {valid_features}")
+
+ if len(valid_features) < 3: # 至少需要3个特征
+ return jsonify({'error': '有效特征数量不足'}), 400
+
+ # 计算每个特征的均值
+ feature_means = {}
+ for name in valid_features:
+ values = [float(item[name]) for item in equipment_data if item.get(name) is not None]
+ feature_means[name] = sum(values) / len(values) if values else 0
+ logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")
+
+ # 准备特征和目标值
+ features = []
+ target = []
+
+ # 提取特征和目标值,使用均值填充缺失值
+ for item in equipment_data:
+ feature_values = []
+ for name in valid_features:
+ value = item.get(name)
+ try:
+ # 确保数值类型转换正确
+ feature_values.append(float(value) if value is not None else feature_means[name])
+ except (ValueError, TypeError) as e:
+ logger.error(f"Error converting value for feature {name}: {value}")
+ logger.error(f"Error details: {str(e)}")
+ return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
+ features.append(feature_values)
+
+ # 确保成本值是值类型
+ try:
+ target.append(float(item['actual_cost']))
+ except (ValueError, TypeError) as e:
+ logger.error(f"Error converting actual_cost: {item['actual_cost']}")
+ logger.error(f"Error details: {str(e)}")
+ return jsonify({'error': '成本值无法换为数值'}), 400
+
+ logger.info(f"Prepared {len(features)} feature vectors")
+ logger.info(f"First feature vector: {features[0] if features else None}")
+ logger.info(f"First target value: {target[0] if target else None}")
+
+ # 调用特征分析方法
+ result = analyzer.analyze_features(features, target, valid_features)
+ logger.info("Analysis completed successfully")
+
+ return jsonify(result)
+
+ except Exception as e:
+ logger.error(f"Error analyzing features: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/train', methods=['POST'])
+def train_model():
+ """
+ 训练模型
+ """
+ try:
+ data = request.get_json()
+ logger.info(f"Starting model training for {data.get('type')}")
+ equipment_type = data.get('type')
+ train_dataset_id = data.get('train_dataset_id')
+ validation_dataset_id = data.get('validation_dataset_id')
+ models = data.get('models', [])
+
+ logger.info(f"Training dataset: {train_dataset_id}")
+ logger.info(f"Validation dataset: {validation_dataset_id}")
+ logger.info(f"Selected models: {models}")
+
+ # 获取训练数据
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+
+ # 获取训练集数据
+ if equipment_type == '火箭炮':
+ cursor.execute("""
+ SELECT e.*, cp.*, rap.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (train_dataset_id,))
+ else:
+ cursor.execute("""
+ SELECT e.*, cp.*, lmp.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (train_dataset_id,))
+
+ train_data = cursor.fetchall()
+
+ # 获取验证集数据(如果有)
+ validation_data = None
+ if validation_dataset_id:
+ if equipment_type == '火箭炮':
+ cursor.execute("""
+ SELECT e.*, cp.*, rap.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (validation_dataset_id,))
+ else:
+ cursor.execute("""
+ SELECT e.*, cp.*, lmp.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ AND cd.actual_cost IS NOT NULL
+ """, (validation_dataset_id,))
+ validation_data = cursor.fetchall()
+
+ if not train_data:
+ return jsonify({'error': '训练数据集为空'}), 400
+
+ # 1. 准备数据
+ data_processor = DataPreparation()
+
+ # 准备训练数据
+ train_prepared = data_processor.prepare_training_data(train_data, equipment_type)
+
+ # 准备验证数据(如果有)
+ validation_prepared = None
+ if validation_data:
+ validation_prepared = data_processor.prepare_validation_data(
+ validation_data,
+ equipment_type,
+ train_prepared['feature_names'],
+ {
+ 'feature_scaler': train_prepared['feature_scaler'],
+ 'target_scaler': train_prepared['target_scaler']
+ }
+ )
+
+ # 2. 训练模型
+ model_trainer = ModelTrainer()
+ model_trainer.feature_scaler = train_prepared['feature_scaler']
+ model_trainer.target_scaler = train_prepared['target_scaler']
+
+ # 执行训练,传入 equipment_type 参数
+ training_result = model_trainer.fit_model(
+ train_prepared['X'],
+ train_prepared['y'],
+ models,
+ validation_prepared['X'] if validation_prepared else None,
+ validation_prepared['y'] if validation_prepared else None,
+ equipment_type=equipment_type
+ )
+
+ return jsonify(training_result)
+
+ except Exception as e:
+ logger.error(f"Error in model training: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/evaluate', methods=['POST'])
+def evaluate_model():
+ """
+ 模型评估接口
+ """
+ try:
+ data = request.get_json()
+ logger.info("Received model evaluation request")
+
+ if 'test_data' not in data:
+ return jsonify({'error': 'Test data is required'}), 400
+
+ predictor = CostPredictor()
+ evaluation_result = predictor.evaluate(
+ data['test_data']['actual'],
+ data['test_data']['predicted']
+ )
+
+ logger.info("Model evaluation completed")
+ return jsonify(evaluation_result)
+
+ except Exception as e:
+ logger.error(f"Error in model evaluation: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+def get_required_params(equipment_type):
+ """
+ 根据装备类型获取必要参数
+ """
+ common_params = [
+ 'length_m',
+ 'width_m',
+ 'height_m',
+ 'weight_kg',
+ 'max_range_km'
+ ]
+
+ if equipment_type == '火箭炮':
+ return common_params + [
+ 'firing_angle_horizontal',
+ 'firing_angle_vertical',
+ 'rocket_length_m',
+ 'rocket_diameter_mm',
+ 'rocket_weight_kg'
+ ]
+ elif equipment_type == '巡飞弹':
+ return common_params + [
+ 'max_speed_kmh',
+ 'cruise_speed_kmh',
+ 'flight_time_min',
+ 'folded_length_mm',
+ 'folded_width_mm',
+ 'folded_height_mm'
+ ]
+
+ return common_params
+
+@api_bp.errorhandler(404)
+def not_found(error):
+ return jsonify({'error': 'Not found'}), 404
+
+@api_bp.errorhandler(500)
+def internal_error(error):
+ logger.error(f"Internal server error: {str(error)}")
+ return jsonify({'error': 'Internal server error'}), 500
+
+@api_bp.route('/data', methods=['GET'])
+def get_equipment_data():
+ """
+ 获取装备数据
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute('SET SESSION group_concat_max_len = 1000000')
+
+ # 先测试特殊参数查询
+ cursor.execute("""
+ SELECT equipment_id, param_name, param_value, param_unit
+ FROM custom_params
+ WHERE param_name IS NOT NULL
+ AND param_value IS NOT NULL
+ LIMIT 5
+ """)
+ test_params = cursor.fetchall()
+ logger.info(f"Test custom params: {test_params}")
+
+ # 获取火箭炮数据
+ logger.info("Fetching rocket artillery data...")
+ cursor.execute("""
+ SELECT
+ e.id,
+ e.name,
+ e.type,
+ e.manufacturer,
+ e.created_at,
+ cp.length_m,
+ cp.width_m,
+ cp.height_m,
+ cp.weight_kg,
+ cp.max_range_km,
+ rap.firing_angle_horizontal,
+ rap.firing_angle_vertical,
+ rap.rocket_length_m,
+ rap.rocket_diameter_mm,
+ rap.rocket_weight_kg,
+ rap.rate_of_fire,
+ rap.combat_weight_kg,
+ rap.speed_kmh,
+ rap.min_range_km,
+ rap.mobility_type,
+ rap.structure_layout,
+ rap.engine_model,
+ rap.engine_params,
+ rap.power_hp,
+ rap.travel_range_km,
+ cd.actual_cost,
+ (
+ SELECT COALESCE(
+ JSON_ARRAYAGG(
+ JSON_OBJECT(
+ 'id', csp.id,
+ 'param_name', csp.param_name,
+ 'param_value', csp.param_value,
+ 'param_unit', csp.param_unit,
+ 'description', csp.description
+ )
+ ),
+ '[]'
+ )
+ FROM custom_params csp
+ WHERE csp.equipment_id = e.id
+ AND csp.param_name IS NOT NULL
+ AND csp.param_value IS NOT NULL
+ ) as custom_params
+ FROM equipment e
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE e.type = '火箭炮'
+ """)
+ rocket_artillery = cursor.fetchall()
+ logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
+ if rocket_artillery:
+ logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
+ logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")
+
+ # 获取巡飞弹数据
+ logger.info("Fetching missile data...")
+ cursor.execute("""
+ SELECT
+ e.id,
+ e.name,
+ e.type,
+ e.manufacturer,
+ e.created_at,
+ cp.length_m,
+ cp.width_m,
+ cp.height_m,
+ cp.weight_kg,
+ cp.max_range_km,
+ lmp.wingspan_m,
+ lmp.warhead_weight_kg,
+ lmp.max_speed_ms,
+ lmp.cruise_speed_kmh,
+ lmp.flight_time_min,
+ lmp.warhead_type,
+ lmp.launch_mode,
+ lmp.folded_length_mm,
+ lmp.folded_width_mm,
+ lmp.folded_height_mm,
+ lmp.power_system,
+ lmp.guidance_system,
+ cd.actual_cost,
+ (
+ SELECT COALESCE(
+ JSON_ARRAYAGG(
+ JSON_OBJECT(
+ 'id', csp.id,
+ 'param_name', csp.param_name,
+ 'param_value', csp.param_value,
+ 'param_unit', csp.param_unit,
+ 'description', csp.description
+ )
+ ),
+ '[]'
+ )
+ FROM custom_params csp
+ WHERE csp.equipment_id = e.id
+ AND csp.param_name IS NOT NULL
+ AND csp.param_value IS NOT NULL
+ ) as custom_params
+ FROM equipment e
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE e.type = '巡飞弹'
+ """)
+ loitering_munition = cursor.fetchall()
+ logger.info(f"Found {len(loitering_munition)} missile records")
+ if loitering_munition:
+ logger.info(f"First missile: {loitering_munition[0]['name']}")
+ logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")
+
+ # 处理 custom_params,保为 NULL
+ for item in rocket_artillery + loitering_munition:
+ if item['custom_params'] is None:
+ item['custom_params'] = []
+ logger.debug(f"Set empty custom_params for equipment {item['id']}")
+ else:
+ logger.debug(f"Equipment {item['id']} has {len(item['custom_params'])} custom params")
+
+ logger.info("Data fetching completed")
+ return jsonify({
+ 'rocket_artillery': rocket_artillery,
+ 'loitering_munition': loitering_munition
+ })
+
+ except Exception as e:
+ logger.error(f"Error getting equipment data: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/data/', methods=['DELETE'])
+def delete_equipment(id):
+ """
+ 删除装备数据
+ """
+ try:
+ db = get_db_connection()
+ cursor = db.cursor()
+
+ # 删除相关数据
+ cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,))
+ cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
+ cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
+ cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
+ cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
+
+ db.commit()
+ cursor.close()
+ db.close()
+
+ return jsonify({'status': 'success'})
+
+ except Exception as e:
+ logger.error(f"Error deleting equipment: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/data/template', methods=['GET'])
+def download_template():
+ """
+ 下载数据模板
+ """
+ try:
+ # 创建模板文件
+ from .create_template import create_excel_template
+ template_path = create_excel_template()
+
+ # 检查文件是否存
+ if not os.path.exists(template_path):
+ raise FileNotFoundError("模板文件不存在")
+
+ # 返回文件
+ return send_file(
+ template_path,
+ as_attachment=True,
+ download_name='equipment_data_template.xlsx',
+ mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ )
+
+ except Exception as e:
+ logger.error(f"Error creating template: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+def get_db_connection():
+ """
+ 获取数据库连接
+ """
+ return mysql.connector.connect(
+ host="localhost",
+ user="root",
+ password="123456",
+ database="equipment_cost_db"
+ )
+
+@api_bp.route('/pls/predict', methods=['POST'])
+def pls_predict():
+ """
+ PLS回归预测接口
+ """
+ try:
+ data = request.get_json()
+ logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")
+
+ # 验证装备类型
+ if 'type' not in data:
+ return jsonify({'error': 'Equipment type is required'}), 400
+
+ # 使用 ModelTrainer 中的 PLS 模型进行预测
+ trainer = ModelTrainer()
+ if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型
+ return jsonify({'error': '未找到可用的模型'}), 404
+
+ # 准备特征数据
+ feature_analyzer = FeatureAnalysis()
+ features = feature_analyzer.get_equipment_specific_features(data['type'])
+ X = np.array([[data.get(feature) for feature in features]])
+
+ # 预测
+ result = trainer.predict(X)
+
+ # 计算置信区间
+ confidence_interval = trainer._calculate_confidence_interval(result[0])
+
+ # 获取模型信息
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT model_type, model_name, r2_score, mae, rmse
+ FROM trained_models
+ WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
+ LIMIT 1
+ """, (data['type'],))
+ model_info = cursor.fetchone()
+
+ # 确保返回的数据可以序列化为JSON
+ response = {
+ 'predicted_cost': float(result[0]),
+ 'model_info': {
+ 'type': model_info['model_type'],
+ 'name': model_info['model_name'],
+ 'r2_score': model_info['r2_score'],
+ 'mae': model_info['mae'],
+ 'rmse': model_info['rmse']
+ },
+ 'confidence_interval': {
+ 'lower': float(confidence_interval[0]),
+ 'upper': float(confidence_interval[1])
+ }
+ }
+
+ logger.info(f"PLS prediction completed: {response}")
+ return jsonify(response)
+
+ except Exception as e:
+ logger.error(f"Error in PLS prediction: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/data/import', methods=['POST'])
+def import_data():
+ """
+ 导入数据接口
+ """
+ try:
+ if 'file' not in request.files:
+ return jsonify({'error': '没有上传文件'}), 400
+
+ file = request.files['file']
+ if not file.filename.endswith(('.xls', '.xlsx')):
+ return jsonify({'error': '请上传Excel文件'}), 400
+
+ # 保存上的文件
+ upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
+ os.makedirs(upload_dir, exist_ok=True)
+ file_path = os.path.join(upload_dir, file.filename)
+ file.save(file_path)
+
+ # 导入数据
+ from .import_data import import_training_data
+ import_training_data(file_path)
+
+ return jsonify({
+ 'success': True,
+ 'message': '数据导入成功'
+ })
+
+ except Exception as e:
+ logger.error(f"Error importing data: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/data/', methods=['PUT'])
+def update_equipment(id):
+ """
+ 更新装备数据
+ """
+ try:
+ data = request.get_json()
+ logger.info(f"Updating equipment ID: {id}")
+ logger.info(f"Update data: {data}")
+
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 更新基本信息
+ cursor.execute("""
+ UPDATE equipment
+ SET name = %s, manufacturer = %s
+ WHERE id = %s
+ """, (data['name'], data['manufacturer'], id))
+ logger.info("Basic info updated")
+
+ # 更新通用参数
+ cursor.execute("""
+ UPDATE common_params
+ SET length_m = %s, width_m = %s, height_m = %s,
+ weight_kg = %s, max_range_km = %s
+ WHERE equipment_id = %s
+ """, (
+ data['length_m'], data['width_m'], data['height_m'],
+ data['weight_kg'], data['max_range_km'], id
+ ))
+ logger.info("Common params updated")
+
+ # 根据备类型更新特有参数
+ if data['type'] == '火箭炮':
+ cursor.execute("""
+ UPDATE rocket_artillery_params
+ SET firing_angle_horizontal = %s, firing_angle_vertical = %s,
+ rocket_length_m = %s, rocket_diameter_mm = %s,
+ rocket_weight_kg = %s, rate_of_fire = %s
+ WHERE equipment_id = %s
+ """, (
+ data['firing_angle_horizontal'], data['firing_angle_vertical'],
+ data['rocket_length_m'], data['rocket_diameter_mm'],
+ data['rocket_weight_kg'], data['rate_of_fire'], id
+ ))
+ logger.info("Rocket artillery params updated")
+ else:
+ cursor.execute("""
+ UPDATE loitering_munition_params
+ SET max_speed_ms = %s, cruise_speed_kmh = %s,
+ flight_time_min = %s, warhead_type = %s,
+ launch_mode = %s, folded_length_mm = %s,
+ folded_width_mm = %s, folded_height_mm = %s
+ WHERE equipment_id = %s
+ """, (
+ data['max_speed_ms'], data['cruise_speed_kmh'],
+ data['flight_time_min'], data['warhead_type'],
+ data['launch_mode'], data['folded_length_mm'],
+ data['folded_width_mm'], data['folded_height_mm'], id
+ ))
+ logger.info("Missile params updated")
+
+ # 更新成本数据
+ if 'actual_cost' in data:
+ cursor.execute("""
+ UPDATE cost_data
+ SET actual_cost = %s
+ WHERE equipment_id = %s
+ """, (data['actual_cost'], id))
+ logger.info("Cost data updated")
+
+ # 更新特殊参数
+ if 'custom_params' in data and data['custom_params']:
+ logger.info(f"Updating custom params: {data['custom_params']}")
+ for param in data['custom_params']:
+ cursor.execute("""
+ UPDATE custom_params
+ SET param_value = %s
+ WHERE id = %s AND equipment_id = %s
+ """, (param['param_value'], param['id'], id))
+ logger.info("Custom params updated")
+
+ conn.commit()
+ logger.info("All updates committed successfully")
+
+ return jsonify({'success': True})
+
+ except Exception as e:
+ logger.error(f"Error updating equipment: {str(e)}")
+ logger.error("Detailed traceback:", exc_info=True)
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/data/details/', methods=['GET'])
+def get_equipment_details(id):
+ """
+ 获取装备详数据
+ """
+ try:
+ logger.info(f"Getting details for equipment ID: {id}")
+
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+
+ # 先获取装备类型
+ cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,))
+ equipment = cursor.fetchone()
+
+ if not equipment:
+ logger.warning(f"Equipment not found: {id}")
+ return jsonify({'error': 'Equipment not found'}), 404
+
+ equipment_type = equipment['type']
+ logger.info(f"Equipment type: {equipment_type}")
+
+ # 根据装备类型选择查询
+ if equipment_type == '火箭炮':
+ query = """
+ SELECT
+ e.*,
+ cp.*,
+ rap.*,
+ cd.actual_cost,
+ cd.prediction_date as cost_estimate_date,
+ cd.predicted_cost,
+ (
+ SELECT JSON_ARRAYAGG(
+ CASE
+ WHEN csp.id IS NOT NULL THEN
+ JSON_OBJECT(
+ 'id', csp.id,
+ 'param_name', csp.param_name,
+ 'param_value', csp.param_value,
+ 'param_unit', csp.param_unit,
+ 'description', csp.description
+ )
+ END
+ )
+ FROM custom_params csp
+ WHERE csp.equipment_id = e.id
+ AND csp.param_name IS NOT NULL
+ AND csp.param_value IS NOT NULL
+ ) as custom_params
+ FROM equipment e
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE e.id = %s
+ """
+ else:
+ query = """
+ SELECT
+ e.*,
+ cp.*,
+ lmp.*,
+ cd.actual_cost,
+ cd.prediction_date as cost_estimate_date,
+ cd.predicted_cost,
+ (
+ SELECT JSON_ARRAYAGG(
+ CASE
+ WHEN csp.id IS NOT NULL THEN
+ JSON_OBJECT(
+ 'id', csp.id,
+ 'param_name', csp.param_name,
+ 'param_value', csp.param_value,
+ 'param_unit', csp.param_unit,
+ 'description', csp.description
+ )
+ END
+ )
+ FROM custom_params csp
+ WHERE csp.equipment_id = e.id
+ AND csp.param_name IS NOT NULL
+ AND csp.param_value IS NOT NULL
+ ) as custom_params
+ FROM equipment e
+ LEFT JOIN common_params cp ON e.id = cp.equipment_id
+ LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE e.id = %s
+ """
+
+ cursor.execute(query, (id,))
+ result = cursor.fetchone()
+
+ if result:
+ logger.info(f"Found equipment details: {result['name']}")
+ logger.info(f"Custom params: {result.get('custom_params')}")
+
+ return jsonify(result)
+
+ except Exception as e:
+ logger.error(f"Error getting equipment details: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+# 添加数据集相关的路由
+@api_bp.route('/datasets', methods=['GET'])
+def get_datasets():
+ """
+ 获取数据集列表
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT d.*,
+ COUNT(de.equipment_id) as equipment_count,
+ GROUP_CONCAT(e.name) as equipment_names
+ FROM datasets d
+ LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
+ LEFT JOIN equipment e ON de.equipment_id = e.id
+ GROUP BY d.id
+ """)
+ datasets = cursor.fetchall()
+
+ # 理装备名称列表
+ for dataset in datasets:
+ if dataset['equipment_names']:
+ dataset['equipment_names'] = dataset['equipment_names'].split(',')
+ else:
+ dataset['equipment_names'] = []
+
+ return jsonify(datasets)
+ except Exception as e:
+ logger.error(f"Error getting datasets: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/datasets/', methods=['GET'])
+def get_dataset(id):
+ """
+ 获取数据集详情
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ # 获取数据集基本信息
+ cursor.execute("""
+ SELECT d.*,
+ COUNT(de.equipment_id) as equipment_count
+ FROM datasets d
+ LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
+ WHERE d.id = %s
+ GROUP BY d.id
+ """, (id,))
+ dataset = cursor.fetchone()
+
+ if not dataset:
+ return jsonify({'error': 'Dataset not found'}), 404
+
+ # 获取数据集中的装备
+ cursor.execute("""
+ SELECT e.*, cd.actual_cost
+ FROM equipment e
+ JOIN dataset_equipment de ON e.id = de.equipment_id
+ LEFT JOIN cost_data cd ON e.id = cd.equipment_id
+ WHERE de.dataset_id = %s
+ """, (id,))
+ equipment = cursor.fetchall()
+
+ # 计算统计信息
+ if equipment:
+ total_cost = sum(item['actual_cost'] or 0 for item in equipment)
+ avg_cost = total_cost / len(equipment)
+ dataset['statistics'] = {
+ 'equipment_count': len(equipment),
+ 'total_cost': total_cost,
+ 'average_cost': avg_cost
+ }
+ else:
+ dataset['statistics'] = {
+ 'equipment_count': 0,
+ 'total_cost': 0,
+ 'average_cost': 0
+ }
+
+ dataset['equipment'] = equipment
+ return jsonify(dataset)
+ except Exception as e:
+ logger.error(f"Error getting dataset: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/datasets', methods=['POST'])
+def create_dataset():
+ """
+ 建数据集
+ """
+ try:
+ data = request.get_json()
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 创建数据集
+ cursor.execute("""
+ INSERT INTO datasets (name, description, equipment_type, purpose)
+ VALUES (%s, %s, %s, %s)
+ """, (data['name'], data['description'], data['equipment_type'], data['purpose']))
+
+ dataset_id = cursor.lastrowid
+
+ # 添加装备关联
+ if 'equipment_ids' in data and data['equipment_ids']:
+ values = [(dataset_id, equipment_id) for equipment_id in data['equipment_ids']]
+ cursor.executemany("""
+ INSERT INTO dataset_equipment (dataset_id, equipment_id)
+ VALUES (%s, %s)
+ """, values)
+
+ conn.commit()
+ return jsonify({'id': dataset_id, 'message': '数据集创建成功'})
+ except Exception as e:
+ logger.error(f"Error creating dataset: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/datasets/', methods=['PUT'])
+def update_dataset(id):
+ """
+ 更新数据集
+ """
+ try:
+ data = request.get_json()
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 更新数据集基本信息
+ cursor.execute("""
+ UPDATE datasets
+ SET name = %s, description = %s, equipment_type = %s, purpose = %s
+ WHERE id = %s
+ """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
+
+ # 删除旧的装备关联
+ cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
+
+ # 加新的装备关联
+ if 'equipment_ids' in data:
+ for equipment_id in data['equipment_ids']:
+ cursor.execute("""
+ INSERT INTO dataset_equipment (dataset_id, equipment_id)
+ VALUES (%s, %s)
+ """, (id, equipment_id))
+
+ conn.commit()
+ return jsonify({'success': True})
+ except Exception as e:
+ logger.error(f"Error updating dataset: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/datasets/', methods=['DELETE'])
+def delete_dataset(id):
+ """
+ 删除数据集
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 删除装备关联
+ cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
+
+ # 删除数据集
+ cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
+
+ conn.commit()
+ return jsonify({'success': True})
+ except Exception as e:
+ logger.error(f"Error deleting dataset: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/models//latest', methods=['GET'])
+def get_latest_model(equipment_type):
+ """
+ 获取最新训练的型信息
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT * FROM trained_models
+ WHERE equipment_type = %s AND is_active = TRUE
+ ORDER BY training_date DESC LIMIT 1
+ """, (equipment_type,))
+
+ model = cursor.fetchone()
+ return jsonify(model)
+
+ except Exception as e:
+ logger.error(f"Error getting latest model: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/models', methods=['GET'])
+def get_models():
+ """
+ 获取模型列表
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor(dictionary=True)
+ cursor.execute("""
+ SELECT * FROM trained_models
+ ORDER BY training_date DESC
+ """)
+
+ models = cursor.fetchall()
+
+ # 确保数值类型字段是 float
+ for model in models:
+ if model['r2_score'] is not None:
+ model['r2_score'] = float(model['r2_score'])
+ if model['mae'] is not None:
+ model['mae'] = float(model['mae'])
+ if model['rmse'] is not None:
+ model['rmse'] = float(model['rmse'])
+
+ # 解析特征重要性
+ if model['feature_importance']:
+ model['feature_importance'] = json.loads(model['feature_importance'])
+
+ return jsonify(models)
+
+ except Exception as e:
+ logger.error(f"Error getting models: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/models//activate', methods=['POST'])
+def activate_model(id):
+ """
+ 激活指定的模型
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 获取模型信息
+ cursor.execute("""
+ SELECT equipment_type FROM trained_models
+ WHERE id = %s
+ """, (id,))
+ model = cursor.fetchone()
+
+ if not model:
+ return jsonify({'error': 'Model not found'}), 404
+
+ # 将同类型的其他模型设置为非激活
+ cursor.execute("""
+ UPDATE trained_models
+ SET is_active = FALSE
+ WHERE equipment_type = %s
+ """, (model[0],))
+
+ # 激活指定模型
+ cursor.execute("""
+ UPDATE trained_models
+ SET is_active = TRUE
+ WHERE id = %s
+ """, (id,))
+
+ conn.commit()
+ return jsonify({'success': True})
+
+ except Exception as e:
+ logger.error(f"Error activating model: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/models/', methods=['DELETE'])
+def delete_model(id):
+ """
+ 删除指定的模型
+ """
+ try:
+ with get_db_connection() as conn:
+ cursor = conn.cursor()
+
+ # 获取模型文件路径
+ cursor.execute("""
+ SELECT model_path, scaler_path
+ FROM trained_models
+ WHERE id = %s
+ """, (id,))
+ model = cursor.fetchone()
+
+ if not model:
+ return jsonify({'error': 'Model not found'}), 404
+
+ # 删除模型文件
+ if os.path.exists(model[0]):
+ os.remove(model[0])
+ if os.path.exists(model[1]):
+ os.remove(model[1])
+
+ # 删除数据库记录
+ cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
+ conn.commit()
+
+ return jsonify({'success': True})
+
+ except Exception as e:
+ logger.error(f"Error deleting model: {str(e)}")
+ return jsonify({'error': str(e)}), 500
+
+@api_bp.route('/predict/all', methods=['POST'])
+def predict_all():
+ """
+ 获取所有机器学习模型的预测结果
+ """
+ try:
+ data = request.get_json()
+ logger.info(f"Received prediction request for all models, equipment type: {data.get('type')}")
+
+ predictor = CostPredictor()
+ results = predictor.predict_all(data)
+
+ return jsonify(results)
+
+ except Exception as e:
+ logger.error(f"Error in prediction: {str(e)}")
+ return jsonify({'error': str(e)}), 500
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/schema.sql b/deploy/equipment_cost_system/src/schema.sql
new file mode 100644
index 0000000..a0abd10
--- /dev/null
+++ b/deploy/equipment_cost_system/src/schema.sql
@@ -0,0 +1,140 @@
+-- 如果数据库已存在则删除
+DROP DATABASE IF EXISTS equipment_cost_db;
+
+-- 创建数据库
+CREATE DATABASE equipment_cost_db
+DEFAULT CHARACTER SET utf8mb4
+COLLATE utf8mb4_unicode_ci;
+
+-- 使用数据库
+USE equipment_cost_db;
+
+-- 装备基本信息表
+CREATE TABLE equipment (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ name VARCHAR(100), -- 名称
+ type VARCHAR(50), -- 类型(火箭炮/巡飞弹)
+ manufacturer VARCHAR(100), -- 制造商
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 通用参数表
+CREATE TABLE common_params (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ equipment_id INT,
+ length_m FLOAT, -- 总长(m)
+ width_m FLOAT, -- 宽度(m)
+ height_m FLOAT, -- 高度(m)
+ weight_kg FLOAT, -- 重量(kg)
+ max_range_km FLOAT, -- 最大射程(km)
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 火箭炮特有参数表
+CREATE TABLE rocket_artillery_params (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ equipment_id INT,
+ firing_angle_horizontal FLOAT, -- 方向射界(度)
+ firing_angle_vertical FLOAT, -- 高低射界(度)
+ rocket_length_m FLOAT, -- 火箭弹长度(m)
+ rocket_diameter_mm FLOAT, -- 弹体直径(mm)
+ rocket_weight_kg FLOAT, -- 火箭弹重量(kg)
+ rate_of_fire FLOAT, -- 射速(发/分钟)
+ combat_weight_kg FLOAT, -- 战斗重量(kg)
+ speed_kmh FLOAT, -- 速度(km/h)
+ min_range_km FLOAT, -- 最小射程(km)
+ mobility_type VARCHAR(50), -- 行走方式
+ structure_layout VARCHAR(100), -- 结构布局
+ engine_model VARCHAR(100), -- 发动机型号
+ engine_params TEXT, -- 发动机参数
+ power_hp FLOAT, -- 功率(hp)
+ travel_range_km FLOAT, -- 行程(km)
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 巡飞弹特有参数表
+CREATE TABLE loitering_munition_params (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ equipment_id INT,
+ wingspan_m FLOAT, -- 翼展(m)
+ warhead_weight_kg FLOAT, -- 战斗部重量(kg)
+ max_speed_ms FLOAT, -- 最大速度(m/s)
+ cruise_speed_kmh FLOAT, -- 巡航速度(km/h)
+ flight_time_min FLOAT, -- 巡飞时间(min)
+ warhead_type VARCHAR(50), -- 战斗部类型
+ launch_mode VARCHAR(50), -- 发射方式
+ folded_length_mm FLOAT, -- 折叠长度(mm)
+ folded_width_mm FLOAT, -- 折叠宽度(mm)
+ folded_height_mm FLOAT, -- 折叠高度(mm)
+ power_system VARCHAR(100), -- 动力装置
+ guidance_system VARCHAR(100), -- 制导体制
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 成本数据表
+CREATE TABLE cost_data (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ equipment_id INT,
+ actual_cost DECIMAL(15,2), -- 实际成本(元)
+ predicted_cost DECIMAL(15,2), -- 预测成本(元)
+ prediction_date TIMESTAMP, -- 预测日期
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 特殊参数表
+CREATE TABLE custom_params (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ equipment_id INT,
+ param_name VARCHAR(100), -- 参数名称
+ param_value VARCHAR(255), -- 参数值
+ param_unit VARCHAR(50), -- 参数单位
+ description TEXT, -- 参数说明
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 添加索引
+CREATE INDEX idx_equipment_type ON equipment(type);
+CREATE INDEX idx_equipment_name ON equipment(name);
+CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id);
+
+-- 数据集表
+CREATE TABLE datasets (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ name VARCHAR(100) NOT NULL, -- 数据集名称
+ description TEXT, -- 数据集描述
+ equipment_type VARCHAR(50) NOT NULL, -- 装备类型
+ purpose VARCHAR(50) NOT NULL, -- 用途(训练/验证)
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+);
+
+-- 数据集-装备关联表
+CREATE TABLE dataset_equipment (
+ dataset_id INT NOT NULL,
+ equipment_id INT NOT NULL,
+ PRIMARY KEY (dataset_id, equipment_id),
+ FOREIGN KEY (dataset_id) REFERENCES datasets(id),
+ FOREIGN KEY (equipment_id) REFERENCES equipment(id)
+);
+
+-- 训练模型表
+CREATE TABLE trained_models (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ model_name VARCHAR(100) NOT NULL, -- 模型名称
+ model_type VARCHAR(50) NOT NULL, -- 模型类型
+ equipment_type VARCHAR(50) NOT NULL, -- 装备类型
+ model_path VARCHAR(255) NOT NULL, -- 模型文件路径
+ scaler_path VARCHAR(255) NOT NULL, -- 标准化器路径
+ r2_score FLOAT, -- R²分数
+ mae FLOAT, -- 平均绝对误差
+ rmse FLOAT, -- 均方根误差
+ feature_importance JSON, -- 特征重要性
+ training_data_size INT, -- 训练数据量
+ training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间
+ is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型
+ created_by VARCHAR(50) -- 创建者
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+-- 添加索引
+CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type);
+CREATE INDEX idx_model_active ON trained_models(is_active);
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/src/test_api.py b/deploy/equipment_cost_system/src/test_api.py
new file mode 100644
index 0000000..f68d88d
--- /dev/null
+++ b/deploy/equipment_cost_system/src/test_api.py
@@ -0,0 +1,191 @@
+import requests
+import json
+import logging
+from datetime import datetime
+import os
+import sys
+
+# 添加项目根目录到 Python 路径
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.dirname(current_dir)
+if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+# 配置基本日志
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.FileHandler('logs/test_api.log'),
+ logging.StreamHandler()
+ ]
+)
+
+logger = logging.getLogger(__name__)
+
+def test_api_endpoints():
+ """
+ 测试 API 各个端点
+ """
+ base_url = 'http://localhost:5001/api'
+
+ try:
+ # 1. 测试根路由
+ logger.info("\n1. 测试 API 根路由")
+ response = requests.get(f'{base_url}/')
+ print_response(response, "API 根路由")
+
+ # 2. 测试机器学习预测接口
+ logger.info("\n2. 测试机器学习预测接口")
+ predict_data = {
+ "type": "巡飞弹",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "folded_length_mm": 1300,
+ "folded_width_mm": 230,
+ "folded_height_mm": 230,
+ "warhead_type": "破片杀伤战斗部",
+ "launch_mode": "凭自身动力起飞"
+ }
+
+ response = requests.post(
+ f'{base_url}/predict',
+ json=predict_data
+ )
+ print_response(response, "机器学习预测")
+
+ # 3. 测试 PLS 预测接口
+ logger.info("\n3. 测试 PLS 预测接口")
+ response = requests.post(
+ f'{base_url}/pls/predict',
+ json=predict_data
+ )
+ print_response(response, "PLS 预测")
+
+ # 4. 测试特征分析接口
+ logger.info("\n4. 测试特征分析接口")
+ analysis_data = {
+ "dataset_id": 1, # 假设数据集 ID 为 1
+ "equipment_type": "巡飞弹"
+ }
+
+ response = requests.post(
+ f'{base_url}/analyze-features',
+ json=analysis_data
+ )
+ print_response(response, "特征分析")
+
+ # 5. 测试机器学习模型训练接口
+ logger.info("\n5. 测试机器学习模型训练接口")
+ training_data = {
+ "type": "巡飞弹",
+ "train_dataset_id": 1, # 训练数据集 ID
+ "validation_dataset_id": 2, # 验证数据集 ID
+ "models": ["xgboost", "lightgbm", "rf"] # 要训练的模型类型
+ }
+
+ response = requests.post(
+ f'{base_url}/train',
+ json=training_data
+ )
+ print_response(response, "模型训练")
+
+ # 6. 测试数据集相关接口
+ logger.info("\n6. 测试数据集相关接口")
+
+ # 6.1 获取数据集列表
+ response = requests.get(f'{base_url}/datasets')
+ print_response(response, "获取数据集列表")
+
+ # 6.2 获取可用的装备列表
+ response = requests.get(f'{base_url}/data')
+ equipment_data = response.json()
+
+ # 获取巡飞弹类型的装备ID
+ available_equipment_ids = []
+ if 'loitering_munition' in equipment_data:
+ available_equipment_ids = [
+ item['id']
+ for item in equipment_data['loitering_munition']
+ if item['id'] is not None
+ ][:3] # 取前3个可用的ID
+
+ if not available_equipment_ids:
+ logger.warning("没有找到可用的装备ID,跳过创建数据集测试")
+ else:
+ # 6.3 创建新数据集
+ new_dataset = {
+ "name": "测试数据集",
+ "description": "用于测试的数据集",
+ "equipment_type": "巡飞弹",
+ "purpose": "训练",
+ "equipment_ids": available_equipment_ids
+ }
+
+ logger.info(f"创建数据集使用的装备IDs: {available_equipment_ids}")
+
+ response = requests.post(
+ f'{base_url}/datasets',
+ json=new_dataset
+ )
+ print_response(response, "创建数据集")
+
+ # 7. 测试模型相关接口
+ logger.info("\n7. 测试模型相关接口")
+
+ # 7.1 获取模型列表
+ response = requests.get(f'{base_url}/models')
+ print_response(response, "获取模型列表")
+
+ # 7.2 获取最新模型
+ response = requests.get(f'{base_url}/models/巡飞弹/latest')
+ print_response(response, "获取最新模型")
+
+ # 8. 测试多模型预测接口
+ logger.info("\n8. 测试多模型预测接口")
+ response = requests.post(
+ f'{base_url}/predict/all',
+ json=predict_data
+ )
+ print_response(response, "多模型预测")
+
+ logger.info("所有测试完成")
+
+ except requests.exceptions.RequestException as e:
+ logger.error(f"API 请求错误: {str(e)}")
+ except Exception as e:
+ logger.error(f"测试过程中出现错误: {str(e)}")
+
+def print_response(response, endpoint_name):
+ """
+ 打印响应结果
+ """
+ try:
+ logger.info(f"\n=== {endpoint_name} 测试结果 ===")
+ logger.info(f"状态码: {response.status_code}")
+
+ if response.status_code == 200:
+ result = response.json()
+ logger.info(f"响应数据:\n{json.dumps(result, indent=2, ensure_ascii=False)}")
+ else:
+ logger.error(f"错误响应:\n{response.text}")
+
+ except Exception as e:
+ logger.error(f"处理响应时出错: {str(e)}")
+
+if __name__ == "__main__":
+ try:
+ # 确保日志目录存在
+ os.makedirs('logs', exist_ok=True)
+
+ logger.info(f"=== API 测试开始 - {datetime.now()} ===")
+ test_api_endpoints()
+ logger.info(f"=== API 测试结束 - {datetime.now()} ===")
+ except Exception as e:
+ logger.error(f"测试执行失败: {str(e)}")
\ No newline at end of file
diff --git a/deploy/equipment_cost_system/venv b/deploy/equipment_cost_system/venv
new file mode 100644
index 0000000..e69de29
diff --git a/deploy/setup.md b/deploy/setup.md
new file mode 100644
index 0000000..b905ab0
--- /dev/null
+++ b/deploy/setup.md
@@ -0,0 +1,34 @@
+# 安装说明
+
+## 1. 系统要求
+
+- Linux 服务器 (推荐 Ubuntu 20.04+)
+- Python 3.8+
+- MySQL 8.0+
+
+## 2. 安装部署
+
+```bash
+# 解压部署包
+tar -xzf equipment_cost_prediction.tar.gz
+cd equipment_cost_prediction
+
+# 运行安装脚本
+bash scripts/install.sh
+
+# 修改配置文件
+vim config/.env
+
+# 启动服务
+bash scripts/start.sh
+```
+
+## 3. 验证部署
+
+```bash
+# 检查服务状态
+curl http://localhost:5001/api/
+
+# 检查日志
+tail -f logs/api.log
+```
diff --git a/docs/deploy/api.md b/docs/deploy/api.md
new file mode 100644
index 0000000..000bd35
--- /dev/null
+++ b/docs/deploy/api.md
@@ -0,0 +1,663 @@
+# 装备成本估算系统 API 文档
+
+这个 API 文档提供了完整的接口说明,包括:
+
+- 每个端点的详细描述
+- 请求和响应的具体示例
+- 清晰的参数格式要求
+- 统一的错误处理说明
+- 重要的注意事项
+
+文档使用 Markdown 格式编写,请使用支持 Markdown 的工具查看。
+
+## 基本信息
+
+- 基础URL: `http://localhost:5001/api`
+- 版本: 1.0.0
+- 响应格式: JSON
+
+## API 端点列表
+
+### 1. 获取 API 信息
+
+获取 API 版本信息和可用端点列表。
+
+- **URL**: `/`
+- **方法**: `GET`
+- **响应示例**:
+json
+{
+"name": "装备成本估算系统 API",
+"version": "1.0.0",
+"endpoints": {
+"predict": {
+"url": "/api/predict",
+"method": "POST",
+"description": "成本预测"
+},
+"analyze-features": {
+"url": "/api/analyze-features",
+"method": "POST",
+"description": "特征分析"
+},
+"train": {
+"url": "/api/train",
+"method": "POST",
+"description": "模型训练"
+},
+"evaluate": {
+"url": "/api/evaluate",
+"method": "POST",
+"description": "模型评估"
+}
+}
+}
+
+### 2. 单模型预测
+
+使用当前激活的最优模型进行成本预测。
+
+- **URL**: `/predict`
+- **方法**: `POST`
+- **请求体示例** (巡飞弹):
+
+```json
+{
+ "type": "巡飞弹",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "folded_length_mm": 1300,
+ "folded_width_mm": 230,
+ "folded_height_mm": 230,
+ "warhead_type": "破片杀伤战斗部",
+ "launch_mode": "凭自身动力起飞"
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "predicted_cost": 150000.0,
+ "model_info": {
+ "type": "xgboost",
+ "name": "巡飞弹_20241111_model",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "confidence_interval": {
+ "lower": 135000.0,
+ "upper": 165000.0
+ }
+}
+```
+
+### 3. PLS 模型预测
+
+使用 PLS 回归模型进行预测。
+
+- **URL**: `/pls/predict`
+- **方法**: `POST`
+- **请求体**: 与单模型预测相同
+- **响应示例**:
+
+```json
+{
+ "predicted_cost": 148000.0,
+ "confidence_interval": {
+ "lower": 133000.0,
+ "upper": 163000.0
+ }
+}
+```
+
+### 4. 多模型预测
+
+使用所有激活的模型进行预测并返回综合结果。
+
+- **URL**: `/predict/all`
+- **方法**: `POST`
+- **请求体**: 与单模型预测相同
+- **响应示例**:
+
+```json
+{
+ "individual_predictions": {
+ "xgboost": {
+ "predicted_cost": 150000.0,
+ "model_info": {
+ "name": "巡飞弹_xgboost_model",
+ "type": "xgboost",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "confidence_interval": {
+ "lower": 135000.0,
+ "upper": 165000.0
+ }
+ },
+ "pls": {
+ "predicted_cost": 148000.0,
+ "model_info": {
+ "name": "巡飞弹_pls_model",
+ "type": "pls",
+ "r2_score": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ },
+ "confidence_interval": {
+ "lower": 133000.0,
+ "upper": 163000.0
+ }
+ }
+ },
+ "ensemble_prediction": {
+ "predicted_cost": 149000.0,
+ "standard_deviation": 1414.21,
+ "confidence_interval": {
+ "lower": 146228.15,
+ "upper": 151771.85
+ }
+ }
+}
+```
+
+### 5. 特征分析
+
+分析数据集中特征的重要性和相关性。
+
+- **URL**: `/analyze-features`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "dataset_id": 1,
+ "equipment_type": "巡飞弹"
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "important_features": [
+ {
+ "name": "最大射程(km)",
+ "importance": 0.35
+ },
+ {
+ "name": "重量(kg)",
+ "importance": 0.25
+ }
+ ],
+ "correlation_analysis": {
+ "features": ["最大射程(km)", "重量(kg)"],
+ "matrix": [[1.0, 0.8], [0.8, 1.0]]
+ }
+}
+```
+
+### 6. 模型训练
+
+训练新的模型。
+
+- **URL**: `/train`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "type": "巡飞弹",
+ "train_dataset_id": 1,
+ "validation_dataset_id": 2,
+ "models": ["xgboost", "lightgbm", "rf"]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "metrics": {
+ "xgboost": {
+ "train": {
+ "r2": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0
+ },
+ "validation": {
+ "r2": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ }
+ }
+ },
+ "best_model": {
+ "type": "xgboost",
+ "r2": 0.92,
+ "mae": 5500.0,
+ "rmse": 8000.0
+ }
+}
+```
+
+### 7. 数据集管理
+
+#### 7.1 获取数据集列表
+
+- **URL**: `/datasets`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+[
+ {
+ "id": 1,
+ "name": "训练数据集",
+ "description": "用于训练的数据集",
+ "equipment_type": "巡飞弹",
+ "equipment_count": 10,
+ "equipment_names": ["设备1", "设备2"],
+ "purpose": "训练",
+ "created_at": "2024-11-11T10:00:00"
+ }
+]
+```
+
+#### 7.2 获取数据集详情
+
+- **URL**: `/datasets/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 1,
+ "name": "训练数据集",
+ "description": "用于训练的数据集",
+ "equipment_type": "巡飞弹",
+ "purpose": "训练",
+ "created_at": "2024-11-11T10:00:00",
+ "equipment": [
+ {
+ "id": 1,
+ "name": "设备1",
+ "type": "巡飞弹",
+ "manufacturer": "制造商1",
+ "actual_cost": 150000
+ }
+ ],
+ "statistics": {
+ "equipment_count": 10,
+ "total_cost": 1500000,
+ "average_cost": 150000
+ }
+}
+```
+
+#### 7.3 创建数据集
+
+- **URL**: `/datasets`
+- **方法**: `POST`
+- **请求体示例**:
+
+```json
+{
+ "name": "测试数据集",
+ "description": "用于测试的数据集",
+ "equipment_type": "巡飞弹",
+ "purpose": "训练",
+ "equipment_ids": [1, 2, 3]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "id": 2,
+ "message": "数据集创建成功"
+}
+```
+
+#### 7.4 更新数据集
+
+- **URL**: `/datasets/{id}`
+- **方法**: `PUT`
+- **请求体示例**:
+
+```json
+{
+ "name": "更新后的数据集名称",
+ "description": "更新后的描述",
+ "equipment_type": "巡飞弹",
+ "purpose": "验证",
+ "equipment_ids": [1, 2, 3, 4]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据集更新成功"
+}
+```
+
+#### 7.5 删除数据集
+
+- **URL**: `/datasets/{id}`
+- **方法**: `DELETE`
+- **描述**: 删除指定的数据集及其关联关系
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据集删除成功"
+}
+```
+
+注意事项:
+
+1. 数据集删除后不会删除关联的装备数据
+2. 不能删除正在被模型使用的数据集
+3. 更新数据集时会重新计算统计信息
+4. 数据集的装备类型一旦创建后不能更改
+
+### 8. 模型管理
+
+#### 8.1 获取模型列表
+
+- **URL**: `/models`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+[
+ {
+ "id": 1,
+ "model_name": "巡飞弹_xgboost_model",
+ "model_type": "xgboost",
+ "equipment_type": "巡飞弹",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0,
+ "is_active": true,
+ "training_date": "2024-11-11T10:00:00"
+ }
+]
+```
+
+#### 8.2 获取最新模型
+
+- **URL**: `/models/{equipment_type}/latest`
+- **方法**: `GET`
+- **响应示例**: 与模型列表的单个模型格式相同
+
+#### 8.3 获取模型详情
+
+- **URL**: `/models/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 1,
+ "model_name": "巡飞弹_xgboost_model",
+ "model_type": "xgboost",
+ "equipment_type": "巡飞弹",
+ "r2_score": 0.95,
+ "mae": 5000.0,
+ "rmse": 7500.0,
+ "is_active": true,
+ "training_date": "2024-11-11T10:00:00",
+ "feature_importance": {
+ "max_range_km": 0.35,
+ "weight_kg": 0.25,
+ "length_m": 0.20
+ },
+ "training_data_size": 100,
+ "created_by": "system"
+}
+```
+
+#### 8.4 激活模型
+
+- **URL**: `/models/{id}/activate`
+- **方法**: `POST`
+- **描述**: 激活指定模型,同时会将同类型的其他模型设置为非激活状态
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "模型已激活"
+}
+```
+
+#### 8.5 删除模型
+
+- **URL**: `/models/{id}`
+- **方法**: `DELETE`
+- **描述**: 删除指定模型,包括模型文件和数据库记录
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "模型已删除"
+}
+```
+
+注意事项:
+
+1. 删除模型时会同时删除相关的文件和数据库记录
+2. 不能删除当前正在使用(已激活)的模型
+3. 激活模型时会自动取消同类型其他模型的激活状态
+4. 模型详情包含了更多的训练相关信息,如特征重要性等
+
+### 9. 数据管理
+
+#### 9.1 获取装备数据列表
+
+- **URL**: `/data`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "rocket_artillery": [
+ {
+ "id": 1,
+ "name": "BM-21",
+ "type": "火箭炮",
+ "manufacturer": "俄罗斯",
+ "length_m": 7.35,
+ "width_m": 2.4,
+ "height_m": 3.1,
+ "weight_kg": 13700,
+ "max_range_km": 20.4,
+ "actual_cost": 800000
+ }
+ ],
+ "loitering_munition": [
+ {
+ "id": 8,
+ "name": "Hero-120",
+ "type": "巡飞弹",
+ "manufacturer": "以色列",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "actual_cost": 150000
+ }
+ ]
+}
+```
+
+#### 9.2 获取装备详情
+
+- **URL**: `/data/details/{id}`
+- **方法**: `GET`
+- **响应示例**:
+
+```json
+{
+ "id": 8,
+ "name": "Hero-120",
+ "type": "巡飞弹",
+ "manufacturer": "以色列",
+ "common_params": {
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40
+ },
+ "specific_params": {
+ "wingspan_m": 2.1,
+ "warhead_weight_kg": 3.5,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "warhead_type": "破片杀伤战斗部",
+ "launch_mode": "箱式发射",
+ "power_system": "电动机",
+ "guidance_system": "GPS/INS"
+ },
+ "cost_data": {
+ "actual_cost": 150000,
+ "prediction_date": "2024-11-11T10:00:00",
+ "predicted_cost": 148000
+ },
+ "custom_params": [
+ {
+ "id": 1,
+ "param_name": "续航时间",
+ "param_value": "2小时",
+ "param_unit": "小时",
+ "description": "最大续航时间"
+ }
+ ]
+}
+```
+
+#### 9.3 更新装备数据
+
+- **URL**: `/data/{id}`
+- **方法**: `PUT`
+- **请求体示例**:
+
+```json
+{
+ "name": "Hero-120",
+ "manufacturer": "以色列",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40,
+ "wingspan_m": 2.1,
+ "warhead_weight_kg": 3.5,
+ "max_speed_ms": 50,
+ "cruise_speed_kmh": 100,
+ "flight_time_min": 60,
+ "actual_cost": 150000,
+ "custom_params": [
+ {
+ "id": 1,
+ "param_value": "2.5小时"
+ }
+ ]
+}
+```
+
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "装备数据更新成功"
+}
+```
+
+#### 9.4 删除装备数据
+
+- **URL**: `/data/{id}`
+- **方法**: `DELETE`
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "装备数据删除成功"
+}
+```
+
+#### 9.5 下载数据模板
+
+- **URL**: `/data/template`
+- **方法**: `GET`
+- **描述**: 下载Excel格式的数据导入模板
+- **响应**: Excel文件下载
+
+#### 9.6 导入数据
+
+- **URL**: `/data/import`
+- **方法**: `POST`
+- **请求体**:
+ - Content-Type: multipart/form-data
+ - 参数名: file
+ - 文件类型: .xlsx 或 .xls
+- **响应示例**:
+
+```json
+{
+ "success": true,
+ "message": "数据导入成功",
+ "imported_count": {
+ "rocket_artillery": 3,
+ "loitering_munition": 5
+ }
+}
+```
+
+注意事项:
+
+1. 导入数据时必须使用系统提供的模板
+2. 更新装备数据时会同时更新关联的参数表
+3. 删除装备数据会同时删除相关的参数和成本数据
+4. 导入的Excel文件大小不应超过10MB
+5. 所有数值字段必须符合指定的单位和范围要求
+6. 特殊参数的值必须包含单位信息
+
+## 错误响应
+
+所有接口在发生错误时都会返回以下格式的响应:
+
+```json
+{
+ "error": "错误描述信息"
+}
+```
+
+## 注意事项
+
+1. 所有数值参数必须大于0
+2. 所有单位必须按照参数名称中指定的单位提供
+3. 预测结果中的成本单位为元
+4. 置信区间表示预测结果的95%置信水平范围
+5. 所有请求和响应的编码均为 UTF-8
diff --git a/docs/deploy/deploy.md b/docs/deploy/deploy.md
new file mode 100644
index 0000000..1f52359
--- /dev/null
+++ b/docs/deploy/deploy.md
@@ -0,0 +1,120 @@
+# 装备成本估算系统部署指南
+
+## 一、系统要求
+
+### 1. 基础软件
+
+- Linux 操作系统 (推荐 Ubuntu 20.04+)
+- Python 3.8+ 及相关组件
+
+ ```bash
+ sudo apt update
+ sudo apt install python3 python3-pip python3-venv
+ sudo apt install python3-dev build-essential
+ ```
+
+- Node.js 14+ 及 npm
+
+ ```bash
+ # 使用 nvm 安装 Node.js
+ curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
+ source ~/.bashrc
+ nvm install 14
+ nvm use 14
+ ```
+
+### 2. 数据库
+
+- MySQL 8.0+
+
+ ```bash
+ sudo apt install mysql-server mysql-client
+ sudo apt install libmysqlclient-dev
+ ```
+
+### 3. Python包依赖
+
+```bash
+# 科学计算相关
+sudo apt install libatlas-base-dev # numpy依赖
+sudo apt install libopenblas-dev # 线性代数库
+sudo apt install liblapack-dev # 线性代数包
+sudo apt install gfortran # Fortran编译器(scipy依赖)
+
+# XML处理相关(用于Excel文件处理)
+sudo apt install libxml2-dev
+sudo apt install libxslt1-dev
+```
+
+## 二、部署运行
+
+### 1. 安装服务
+
+```bash
+sh scripts/install.sh
+```
+
+### 2. 启动服务
+
+```bash
+sh scripts/start.sh
+```
+
+### 3. 停止服务
+
+```bash
+sh scripts/stop.sh
+```
+
+## 三、维护说明
+
+### 1. 日志管理
+
+```bash
+# 后端日志
+tail -f logs/api.log
+
+# 数据库日志
+tail -f /var/log/mysql/error.log
+```
+
+## 四、安全建议
+
+1. 系统安全
+ - 使用防火墙限制端口访问
+ - 定期更新系统和依赖包
+
+2. 数据安全
+ - 定期备份数据库
+ - 加密敏感信息
+ - 限制数据库远程访问
+
+3. 访问控制
+ - 使用强密码
+ - 配置适当的文件权限
+ - 使用非root用户运行服务
+
+## 五、监控方案
+
+### 1. 系统监控
+
+```bash
+# 资源使用
+top -b -n 1
+df -h
+free -m
+
+# 服务状态
+ps aux | grep gunicorn
+ps aux | grep node
+```
+
+### 2. 应用监控
+
+```bash
+# API 响应时间
+curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
+
+# 错误日志
+grep "ERROR" logs/api.log
+```
diff --git a/docs/debug.md b/docs/dev/debug.md
similarity index 100%
rename from docs/debug.md
rename to docs/dev/debug.md
diff --git a/docs/dev/deployment.md b/docs/dev/deployment.md
new file mode 100644
index 0000000..462355e
--- /dev/null
+++ b/docs/dev/deployment.md
@@ -0,0 +1,250 @@
+# 装备成本估算系统部署指南
+
+## 一、系统打包
+
+### 1. 创建部署包结构
+
+```bash
+mkdir -p deploy/equipment_cost_system/{backend,frontend,docs}
+```
+
+### 2. 准备部署文件
+
+#### 2.1 后端文件
+
+```bash
+cd deploy/equipment_cost_system/backend
+mkdir -p {src,scripts,config,data,logs,models}
+cp -r ../../../src/* src/
+cp ../../../requirements.txt ./
+cp ../../../.env.example config/.env.template
+```
+
+#### 2.2 前端文件
+
+```bash
+cd ../frontend
+mkdir -p {src,public,dist}
+cp -r ../../frontend/src/* src/
+cp -r ../../frontend/public/* public/
+cp ../../../frontend/package.json ./
+cp ../../../frontend/vite.config.js ./
+cp ../../../frontend/.env.production ./
+```
+
+#### 2.3 复制文档
+
+```bash
+cd ../docs
+cp -r ../../docs/deploy/* ./
+```
+
+#### 2.4 创建部署脚本
+
+```bash
+touch scripts/{install.sh,start.sh,stop.sh}
+```
+
+### 3. 部署脚本内容
+
+#### 3.1 安装脚本 (install.sh)
+
+```bash
+#!/bin/bash
+
+# 安装 Python 依赖
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+
+# 创建必要的目录
+mkdir -p logs
+mkdir -p data
+mkdir -p models
+
+# 配置文件
+if [ ! -f config/.env ]; then
+ cp config/.env.template config/.env
+ echo "请修改 config/.env 中的配置"
+fi
+
+# 初始化数据库
+read -p "请输入MySQL root密码: " mysqlpass
+mysql -u root -p$mysqlpass < src/schema.sql
+
+# 设置权限
+chmod +x scripts/*.sh
+```
+
+#### 3.2 启动脚本 (start.sh)
+
+```bash
+#!/bin/bash
+
+# 激活虚拟环境
+source venv/bin/activate
+
+# 检查配置文件
+if [ ! -f config/.env ]; then
+ echo "错误: 配置文件不存在"
+ exit 1
+fi
+
+# 启动服务
+export $(cat config/.env | xargs)
+gunicorn -w 4 -b 0.0.0.0:5001 "src.app:create_app()" --daemon
+
+echo "服务已启动,访问 http://localhost:5001"
+```
+
+#### 3.3 停止脚本 (stop.sh)
+
+```bash
+#!/bin/bash
+
+# 查找并停止 gunicorn 进程
+pkill -f gunicorn
+
+echo "服务已停止"
+```
+
+### 4. 创建部署包
+
+```bash
+cd ..
+tar -czf equipment_cost_system.tar.gz equipment_cost_system/
+```
+
+## 二、部署步骤
+
+### 1. 系统要求
+
+- Linux 服务器 (推荐 Ubuntu 20.04+)
+- Python 3.8+
+- MySQL 8.0+
+
+### 2. 安装部署
+
+```bash
+# 解压部署包
+tar -xzf equipment_cost_system.tar.gz
+cd equipment_cost_system
+
+# 运行安装脚本
+bash scripts/install.sh
+
+# 修改配置文件
+vim config/.env
+
+# 启动服务
+bash scripts/start.sh
+```
+
+### 3. 验证部署
+
+```bash
+# 检查服务状态
+curl http://localhost:5001/api/
+
+# 检查日志
+tail -f logs/api.log
+```
+
+## 三、维护说明
+
+### 1. 日常维护
+
+```bash
+# 查看日志
+tail -f logs/api.log
+
+# 备份数据库
+mysqldump -u root -p equipment_cost_db > backup/$(date +%Y%m%d).sql
+
+# 清理旧日志
+find logs/ -name "*.log" -mtime +30 -delete
+```
+
+### 2. 更新部署
+
+```bash
+# 停止服务
+bash scripts/stop.sh
+
+# 备份数据
+cp -r data data_backup_$(date +%Y%m%d)
+
+# 更新代码
+# ... 更新相关文件 ...
+
+# 重启服务
+bash scripts/start.sh
+```
+
+### 3. 故障处理
+
+```bash
+# 检查服务状态
+ps aux | grep gunicorn
+
+# 检查数据库连接
+mysql -u root -p -e "show databases;"
+
+# 重启服务
+bash scripts/stop.sh
+bash scripts/start.sh
+```
+
+## 四、安全建议
+
+1. 文件权限设置
+
+```bash
+# 设置适当的文件权限
+chmod 755 scripts/*.sh
+chmod 600 config/.env
+chmod 700 logs models data
+```
+
+2. 数据库安全
+
+- 使用强密码
+- 限制数据库访问IP
+- 定期备份数据
+
+3. 服务器配置
+
+- 配置防火墙规则
+- 启用 SSL/TLS
+- 定期更新系统
+
+## 五、监控方案
+
+### 1. 系统监控
+
+```bash
+# 检查CPU和内存使用
+top -b -n 1
+
+# 检查磁盘使用
+df -h
+
+# 检查网络连接
+netstat -tunlp
+```
+
+### 2. 应用监控
+
+```bash
+# 检查API响应时间
+curl -w "@curl-format.txt" -o /dev/null -s "http://localhost:5001/api/"
+
+# 检查错误日志
+grep "ERROR" logs/api.log
+```
+
+### 3. 告警设置
+
+- 配置日志告警
+- 设置资源使用阈值告警
+- 配置服务可用性监控
diff --git a/docs/design.md b/docs/dev/design.md
similarity index 100%
rename from docs/design.md
rename to docs/dev/design.md
diff --git a/docs/requirements.md b/docs/dev/requirements.md
similarity index 91%
rename from docs/requirements.md
rename to docs/dev/requirements.md
index a384157..2da3bb4 100644
--- a/docs/requirements.md
+++ b/docs/dev/requirements.md
@@ -16,11 +16,11 @@
(2)线性相关分析:对于特征和标签皆为连续值的回归问题,要检测二者的相关性,最直接的做法就是求相关系数rxy,本质是建立协方差矩阵,分析数据和成本之间相关关系的类型和程度,筛选出影响特征
(3)互信息 (mutual information): 用于特征选择,可以从两个角度进行解释:(1)、基于 KL 散度和 (2)、基于信息增益。
2. 数据一致性分析:对特征数据分层分组,计算组内一致性,目标是选择比较合适的一组数据,以此产生一个进行成本估算和分析的虚拟量.大部分的研究中报告的三个数据:rwg、ICC(1)、ICC(2),要符合3个条件,rwg>0.7、ICC(1)>0.05、ICC(2)>0.5
-RWG值:打分一致性;
-ICC1:组内一致性;
-ICC2:组间一致性。
+ RWG值:打分一致性;
+ ICC1:组内一致性;
+ ICC2:组间一致性。
3. 回归模型:偏最小二乘回归(partial Least Squares,PLS)
-4. 神经网络模型:采用 BP 网络
+4. 神经网络模型:采用适用的神经网络模型
### 数据准备
diff --git a/docs/dev/run.md b/docs/dev/run.md
new file mode 100644
index 0000000..046a422
--- /dev/null
+++ b/docs/dev/run.md
@@ -0,0 +1,164 @@
+# 装备成本估算系统运行说明
+
+## 一、开发环境配置
+
+### 1. 系统要求
+
+- Linux/macOS/Windows
+- Python 3.8+
+- MySQL 8.0+
+
+### 2. 安装依赖
+
+```bash
+# 创建并激活虚拟环境
+python3 -m venv venv
+source venv/bin/activate # Linux/macOS
+# 或
+.\venv\Scripts\activate # Windows
+
+# 安装依赖包
+pip install -r requirements.txt
+```
+
+## 二、初始化系统
+
+### 1. 创建必要目录
+
+```bash
+mkdir -p {logs,data,models}
+```
+
+### 2. 配置数据库
+
+```bash
+# 执行数据库初始化脚本
+mysql -u root -p < src/schema.sql
+
+# [可选] 导入测试数据(仅用于开发环境)
+mysql -u root -p equipment_cost_db < src/init_data.sql
+```
+
+### 3. 环境配置
+
+创建 `.env` 文件:
+
+```ini
+MYSQL_HOST=localhost
+MYSQL_USER=root
+MYSQL_PASSWORD=123456
+MYSQL_DATABASE=equipment_cost_db
+```
+
+## 三、启动服务
+
+### 1. 开发模式
+
+```bash
+# 启动开发服务器
+python run.py
+```
+
+### 2. 测试 API
+
+```bash
+# 运行 API 测试
+python src/test_api.py
+```
+
+## 四、开发调试
+
+### 1. 日志查看
+
+```bash
+# API 日志
+tail -f logs/api.log
+
+# 测试日志
+tail -f logs/test_api.log
+
+# 训练日志
+tail -f logs/training.log
+```
+
+### 2. 数据库调试
+
+```sql
+-- 检查数据表
+SHOW TABLES;
+
+-- 查看示例数据
+SELECT * FROM equipment LIMIT 5;
+```
+
+### 3. API 测试
+
+```bash
+# 测试 API 根路由
+curl http://localhost:5001/api/
+
+# 测试预测接口
+curl -X POST http://localhost:5001/api/predict \
+ -H "Content-Type: application/json" \
+ -d '{
+ "type": "巡飞弹",
+ "length_m": 1.3,
+ "width_m": 0.23,
+ "height_m": 0.23,
+ "weight_kg": 12.5,
+ "max_range_km": 40
+ }'
+```
+
+## 五、注意事项
+
+1. 开发环境配置
+ - 使用虚拟环境隔离依赖
+ - 保持日志目录可写权限
+ - 定期清理日志文件
+
+2. 数据库使用
+ - 使用 UTF-8 字符集
+ - 定期备份数据
+ - 避免直接修改生产数据
+
+3. 代码调试
+ - 查看详细日志输出
+ - 使用测试数据验证功能
+ - 遵循代码规范
+
+## 六、常见问题
+
+1. 数据库连接错误
+ - 检查 MySQL 服务状态
+ - 验证数据库用户名密码
+ - 确认数据库字符集设置
+
+2. API 访问问题
+ - 检查服务是否正常运行
+ - 验证请求格式是否正确
+ - 查看错误日志信息
+
+3. 模型相关问题
+ - 确保训练数据完整性
+ - 检查模型文件权限
+ - 验证预测结果合理性
+
+## 七、开发建议
+
+1. 代码管理
+ - 使用版本控制
+ - 遵循项目结构
+ - 及时更新文档
+
+2. 测试规范
+ - 运行完整测试套件
+ - 验证各个功能模块
+ - 记录测试结果
+
+3. 安全注意
+ - 使用安全的数据库密码
+ - 避免敏感信息提交
+ - 保护测试数据安全
+
+注:生产环境部署请参考 `deploy.md`
diff --git a/docs/nodejs_install.md b/docs/nodejs_install.md
deleted file mode 100644
index 4986f32..0000000
--- a/docs/nodejs_install.md
+++ /dev/null
@@ -1,136 +0,0 @@
-# Node.js 安装指南
-
-## Windows 安装方法
-
-### 1. 使用安装包
-
-1. 访问 Node.js 官网
-2. 下载 14.x LTS 版本安装包
-3. 运行安装包,按提示完成安装
-4. 验证安装:
-
-```bash
-node --version
-npm --version
-```
-
-### 2. 使用 nvm-windows(推荐)
-
-1. 下载 nvm-windows:
-2. 安装 nvm-windows
-3. 安装 Node.js:
-
-```bash
-nvm install 14.21.3
-nvm use 14.21.3
-```
-
-## Linux 安装方法
-
-### 1. 使用 nvm(推荐)
-
-```bash
-# 安装 nvm
-curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
-
-# 重新加载配置
-source ~/.bashrc
-
-# 安装 Node.js 14
-nvm install 14
-nvm use 14
-```
-
-### 2. 使用包管理器
-
-#### Ubuntu/Debian
-
-```bash
-# 添加 NodeSource 仓库
-curl -fsSL https://deb.nodesource.com/setup_14.x | sudo -E bash -
-
-# 安装 Node.js
-sudo apt-get install -y nodejs
-```
-
-#### CentOS/RHEL
-
-```bash
-# 添加 NodeSource 仓库
-curl -fsSL https://rpm.nodesource.com/setup_14.x | sudo bash -
-
-# 安装 Node.js
-sudo yum install -y nodejs
-```
-
-## macOS 安装方法
-
-### 1. 使用 Homebrew(推荐)
-
-```bash
-# 安装 Homebrew(如果未安装)
-/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
-
-# 安装 Node.js 14
-brew install node@14
-
-# 添加环境变量
-echo 'export PATH="/usr/local/opt/node@14/bin:$PATH"' >> ~/.zshrc
-source ~/.zshrc
-```
-
-### 2. 使用 nvm
-
-```bash
-# 安装 nvm
-curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash
-
-# 重新加载配置
-source ~/.zshrc
-
-# 安装 Node.js 14
-nvm install 14
-nvm use 14
-```
-
-## 验证安装
-
-安装完成后,运行以下命令验证:
-
-```bash
-# 检查 Node.js 版本
-node --version # 应显示 v14.x.x
-
-# 检查 npm 版本
-npm --version # 应显示 6.x.x 或更高
-```
-
-## 常见问题
-
-### 1. 权限问题
-
-如果遇到权限错误,可以:
-
-```bash
-# Linux/macOS
-sudo chown -R $USER /usr/local/lib/node_modules
-```
-
-### 2. 版本切换
-
-如果需要在不同版本间切换:
-
-```bash
-# 使用 nvm
-nvm list # 查看已安装版本
-nvm use 14 # 切换到 14.x 版本
-```
-
-### 3. npm 配置
-
-建议配置国内镜像源:
-
-```bash
-# 使用淘宝镜像
-npm config set registry https://registry.npmmirror.com
-```
diff --git a/docs/run.md b/docs/run.md
deleted file mode 100644
index e92716b..0000000
--- a/docs/run.md
+++ /dev/null
@@ -1,163 +0,0 @@
-# 系统运行说明
-
-## 一、环境准备
-
-### 1. 安装必要软件
-
-```bash
-# 安装 Python 3.8+
-# 安装 MySQL 8.0+
-# 安装 Node.js 14+
-```
-
-### 2. 安装 Python 依赖
-
-```bash
-pip install -r requirements.txt
-```
-
-### 3. 安装前端依赖
-
-```bash
-cd frontend
-npm install
-```
-
-## 二、数据库配置
-
-### 1. 创建数据库
-
-```sql
-CREATE DATABASE equipment_cost_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
-```
-
-### 2. 初始化数据库结构
-
-```bash
-# 执行数据库结构初始化脚本
-mysql -u username -p equipment_cost_db < src/schema.sql
-
-# 导入示例数据
-mysql -u username -p equipment_cost_db < src/init_data.sql
-
-# 导入真实数据
-mysql -u username -p equipment_cost_db < src/real_data.sql
-```
-
-## 三、配置文件
-
-### 1. 后端配置
-
-创建 `config.py` 文件:
-
-```python
-# config.py
-DATABASE_URI = "mysql+pymysql://username:password@localhost:3306/equipment_cost_db"
-SECRET_KEY = "your-secret-key"
-DEBUG = True
-```
-
-### 2. 前端配置
-
-修改 `frontend/src/config.js`:
-
-```javascript
-export const API_BASE_URL = 'http://localhost:5001/api';
-```
-
-## 四、启动系统
-
-### 1. 启动后端服务
-
-```bash
-# 开发环境
-python run.py # 服务将在 http://localhost:5001 启动
-
-# 生产环境
-gunicorn -w 4 -b 0.0.0.0:5001 run:app
-```
-
-### 2. 启动前端服务
-
-```bash
-# 开发环境
-cd frontend
-npm run serve # 前端将在 http://localhost:8080 启动
-
-# 生产环境
-npm run build
-```
-
-## 五、访问系统
-
-- 后端API:
-- 前端界面:
-
-## 六、常见问题
-
-### 1. 数据库连接问题
-
-- 检查 MySQL 服务是否启动
-- 验证数据库用户名和密码
-- 确认数据库端口是否正确
-
-### 2. 模型训练
-
-```bash
-# 训练模型
-python src/train_model.py
-
-# 查看训练日志
-tail -f logs/training.log
-```
-
-### 3. 系统监控
-
-```bash
-# 查看系统日志
-tail -f logs/app.log
-
-# 监控API请求
-tail -f logs/access.log
-```
-
-## 七、开发调试
-
-### 1. 后端调试
-
-```bash
-# 启动调试模式
-python run.py --debug
-
-# 运行测试
-python -m pytest tests/
-```
-
-### 2. 前端调试
-
-```bash
-# 启动开发服务器
-npm run serve
-
-# 运行测试
-npm run test
-```
-
-## 八、部署建议
-
-### 1. 使用 Docker 部署
-
-```bash
-# 构建镜像
-docker-compose build
-
-# 启动服务
-docker-compose up -d
-```
-
-### 2. 生产环境配置
-
-- 使用 Nginx 作为反向代理
-- 配置 SSL 证书
-- 设置适当的防火墙规则
-- 启用数据库备份
diff --git a/frontend/README.md b/frontend/README.md
index 576b980..f08b1b5 100644
--- a/frontend/README.md
+++ b/frontend/README.md
@@ -1,24 +1,29 @@
# frontend
## Project setup
-```
+
+```bash
npm install
```
### Compiles and hot-reloads for development
-```
+
+```bash
npm run serve
```
### Compiles and minifies for production
-```
+
+```bash
npm run build
```
### Lints and fixes files
-```
+
+```bash
npm run lint
```
### Customize configuration
+
See [Configuration Reference](https://cli.vuejs.org/config/).
diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue
index f8f8ba1..385b995 100644
--- a/frontend/src/views/AnalysisPage.vue
+++ b/frontend/src/views/AnalysisPage.vue
@@ -153,31 +153,93 @@ const startAnalysis = async () => {
}
}
+// 窗口大小变化处理函数
+const resizeHandler = ref(null)
+
+// 创建 resize 处理函数
+const createResizeHandler = () => {
+ const handler = () => {
+ try {
+ if (importanceChart.value && !importanceChart.value.isDisposed()) {
+ importanceChart.value.resize()
+ }
+ if (correlationChart.value && !correlationChart.value.isDisposed()) {
+ correlationChart.value.resize()
+ }
+ } catch (error) {
+ console.error('Error in resize handler:', error)
+ }
+ }
+
+ // 使用防抖包装
+ return debounce(handler, 200)
+}
+
+// 组件挂载时
+onMounted(() => {
+ // 创建并保存 resize 处理函数的引用
+ resizeHandler.value = createResizeHandler()
+ window.addEventListener('resize', resizeHandler.value)
+})
+
+// 组件卸载时
+onUnmounted(() => {
+ // 移除事件监听
+ if (resizeHandler.value) {
+ window.removeEventListener('resize', resizeHandler.value)
+ resizeHandler.value = null
+ }
+
+ // 销毁图表实例
+ try {
+ if (importanceChart.value) {
+ importanceChart.value.dispose()
+ importanceChart.value = null
+ }
+ if (correlationChart.value) {
+ correlationChart.value.dispose()
+ correlationChart.value = null
+ }
+ } catch (error) {
+ console.error('Error disposing charts:', error)
+ }
+})
+
// 渲染图表
const renderCharts = () => {
console.log('Starting to render charts')
- // 销毁旧的图表实例
- if (importanceChart.value) {
- importanceChart.value.dispose()
- }
- if (correlationChart.value) {
- correlationChart.value.dispose()
+ // 检查分析结果是否存在且包含必要数据
+ if (!analysisResult.value ||
+ !analysisResult.value.important_features ||
+ !analysisResult.value.correlation_analysis) {
+ console.log('Analysis result not ready')
+ return
}
- // 确保 DOM 元素存在
+ // 检查DOM元素
if (!importanceChartRef.value || !correlationChartRef.value) {
console.log('Chart DOM elements not ready')
return
}
try {
+ // 销毁旧的图表实例
+ if (importanceChart.value) {
+ importanceChart.value.dispose()
+ importanceChart.value = null
+ }
+ if (correlationChart.value) {
+ correlationChart.value.dispose()
+ correlationChart.value = null
+ }
+
// 创建新的图表实例
importanceChart.value = echarts.init(importanceChartRef.value)
correlationChart.value = echarts.init(correlationChartRef.value)
// 设置图表选项
- importanceChart.value.setOption({
+ const importanceOption = {
title: { text: '特征重要性排序' },
tooltip: {},
xAxis: {
@@ -189,12 +251,13 @@ const renderCharts = () => {
data: analysisResult.value.important_features.map(f => f.name)
},
series: [{
+ name: '重要性',
type: 'bar',
data: analysisResult.value.important_features.map(f => f.importance)
}]
- })
+ }
- correlationChart.value.setOption({
+ const correlationOption = {
title: { text: '特征相关性热力图' },
tooltip: {
position: 'top',
@@ -233,6 +296,7 @@ const renderCharts = () => {
color: ['#cc3333', '#eeeeee', '#00007f']
},
series: [{
+ name: '相关性',
type: 'heatmap',
data: analysisResult.value.correlation_analysis.matrix,
label: {
@@ -242,13 +306,11 @@ const renderCharts = () => {
}
}
}]
- })
+ }
- // 监听窗口大小变化
- window.addEventListener('resize', () => {
- importanceChart.value?.resize()
- correlationChart.value?.resize()
- })
+ // 设置图表选项
+ importanceChart.value.setOption(importanceOption)
+ correlationChart.value.setOption(correlationOption)
console.log('Charts rendered successfully')
} catch (error) {
@@ -256,16 +318,16 @@ const renderCharts = () => {
}
}
-// 初始化
-onMounted(() => {
- // 可以在这里加载默认数据
-})
-
-// 组件卸载时清理图表实例
-onUnmounted(() => {
- importanceChart.value?.dispose()
- correlationChart.value?.dispose()
-})
+// 防抖函数
+function debounce(fn, delay) {
+ let timer = null
+ return function (...args) {
+ if (timer) clearTimeout(timer)
+ timer = setTimeout(() => {
+ fn.apply(this, args)
+ }, delay)
+ }
+}