将mysql改成sqlite,减少前端依赖
This commit is contained in:
parent
137451ba7a
commit
48ba547c36
29
.claude/settings.local.json
Normal file
29
.claude/settings.local.json
Normal file
File diff suppressed because one or more lines are too long
108
CLAUDE.md
Normal file
108
CLAUDE.md
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
本文件为 Claude Code(claude.ai/code)在此仓库中工作提供指引。
|
||||||
|
|
||||||
|
## 常用命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch)
|
||||||
|
pip install -e .
|
||||||
|
pip install -e ".[dev]" # 含开发工具(pytest, black, mypy)
|
||||||
|
pip install -e ".[torch]" # 安装可选 PyTorch(神经网络训练需要)
|
||||||
|
|
||||||
|
# 启动后端(Flask,0.0.0.0:5001)
|
||||||
|
python run.py
|
||||||
|
|
||||||
|
# 启动前端开发服务器(另开终端,localhost:3000)
|
||||||
|
cd frontend && npm install && npm run serve
|
||||||
|
|
||||||
|
# 构建前端生产版本
|
||||||
|
cd frontend && npm run build
|
||||||
|
|
||||||
|
# 运行测试
|
||||||
|
python -m pytest tests/
|
||||||
|
python -m pytest tests/test_demo_service.py -q # 单个测试文件
|
||||||
|
python -m pytest tests/test_demo_routes.py -q
|
||||||
|
|
||||||
|
# 代码格式化 / 类型检查
|
||||||
|
black src/ tests/ # 自动格式化(line-length=88)
|
||||||
|
mypy src/ # 类型检查
|
||||||
|
|
||||||
|
# 运行独立演示模式(仅需 5 个依赖)
|
||||||
|
cd demo_standalone && pip install -r requirements.txt && python server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目概述
|
||||||
|
|
||||||
|
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
|
||||||
|
|
||||||
|
### 技术栈
|
||||||
|
|
||||||
|
- **后端**:Python 3.9-3.11, Flask 3.1+
|
||||||
|
- **数据库**:SQLite(Python 内置,零外部依赖),首次启动自动建表,无需手动安装配置
|
||||||
|
- **机器学习**:scikit-learn(核心)、XGBoost、LightGBM
|
||||||
|
- **可选依赖**:PyTorch(仅神经网络训练需要,约 800MB)
|
||||||
|
- **前端**:Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
|
||||||
|
|
||||||
|
### 关键文件
|
||||||
|
|
||||||
|
| 文件 | 用途 |
|
||||||
|
|---|---|
|
||||||
|
| `run.py` | 入口点,启动 Flask 服务器 |
|
||||||
|
| `src/app.py` | Flask 应用工厂 |
|
||||||
|
| `src/routes.py` | 所有 API 路由(约 1300 行) |
|
||||||
|
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModel(PyTorch NN) |
|
||||||
|
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
|
||||||
|
| `src/cost_prediction.py` | CostPredictor(预测编排) |
|
||||||
|
| `src/demo_service.py` | DemoModelService(基于 CSV,无需数据库) |
|
||||||
|
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL(内置) |
|
||||||
|
| `config.py` | 运行时配置 |
|
||||||
|
| `frontend/src/router/index.js` | Vue Router,共 8 个路由 |
|
||||||
|
| `frontend/src/api/index.js` | Axios API 客户端 |
|
||||||
|
|
||||||
|
## 架构
|
||||||
|
|
||||||
|
### 与旧架构的核心变化
|
||||||
|
|
||||||
|
| 项目 | 旧架构(MySQL) | 新架构(SQLite) |
|
||||||
|
|------|----------------|-----------------|
|
||||||
|
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLite,Python 内置,零依赖 |
|
||||||
|
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3) |
|
||||||
|
| PyTorch | 硬依赖(顶层 import,无则崩溃) | 可选依赖(try/except 保护,无 PyTorch 可启动) |
|
||||||
|
| 前端构建 | Vue CLI + Babel + SCSS | Vite(无需 Babel/SCSS) |
|
||||||
|
| Vuex | 存在但完全空 | 已移除 |
|
||||||
|
| 依赖总数(核心)| 15+ | 5(flask, numpy, pandas, scikit-learn, openpyxl) |
|
||||||
|
|
||||||
|
### 训练数据流
|
||||||
|
```
|
||||||
|
用户界面 → POST /api/train → 查询 SQLite → DataPreparation(特征提取 + 标准化)
|
||||||
|
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS)
|
||||||
|
→ 保存最优模型 → 写入 SQLite trained_models 表
|
||||||
|
```
|
||||||
|
|
||||||
|
### 预测数据流
|
||||||
|
```
|
||||||
|
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
|
||||||
|
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
|
||||||
|
```
|
||||||
|
|
||||||
|
### 机器学习模型
|
||||||
|
|
||||||
|
| 模型 | 标识 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
|
||||||
|
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
|
||||||
|
| Random Forest | `rf` | sklearn.ensemble |
|
||||||
|
| Gradient Boosting | `gbm` | sklearn.ensemble |
|
||||||
|
| PyTorch NN | `pytorch` | 可选(需安装 torch),针对不同装备类型定制网络结构 |
|
||||||
|
| PLS 回归 | `pls` | 不参与最优模型评选 |
|
||||||
|
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
|
||||||
|
|
||||||
|
## 重要注意事项
|
||||||
|
|
||||||
|
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
|
||||||
|
- **两种装备类型**:火箭炮(27+ 特征)和巡飞弹(24+ 特征),各自有独立数据库表和神经网络结构
|
||||||
|
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
|
||||||
|
- **SQLite 首次使用**:`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
|
||||||
|
- **编码规范**:UTF-8 无 BOM,LF 换行符,文件末尾保留换行符
|
||||||
|
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容
|
||||||
@ -2,11 +2,8 @@ import os
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""配置类"""
|
"""配置类"""
|
||||||
# 数据库配置
|
# 数据库配置(使用 SQLite)
|
||||||
MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost')
|
SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
|
||||||
MYSQL_USER = os.getenv('MYSQL_USER', 'root')
|
|
||||||
MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
|
|
||||||
MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')
|
|
||||||
|
|
||||||
# Flask配置
|
# Flask配置
|
||||||
FLASK_HOST = '0.0.0.0'
|
FLASK_HOST = '0.0.0.0'
|
||||||
|
|||||||
@ -1,5 +0,0 @@
|
|||||||
module.exports = {
|
|
||||||
presets: [
|
|
||||||
'@vue/cli-plugin-babel/preset'
|
|
||||||
]
|
|
||||||
}
|
|
||||||
17
frontend/index.html
Normal file
17
frontend/index.html
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||||
|
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||||
|
<link rel="icon" href="/favicon.ico">
|
||||||
|
<title>装备成本估算系统</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<noscript>
|
||||||
|
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
|
||||||
|
</noscript>
|
||||||
|
<div id="app"></div>
|
||||||
|
<script type="module" src="/src/main.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
11130
frontend/package-lock.json
generated
11130
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -7,33 +7,24 @@
|
|||||||
"npm": ">=8"
|
"npm": ">=8"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"serve": "vue-cli-service serve",
|
"serve": "vite",
|
||||||
"build": "vue-cli-service build",
|
"build": "vite build",
|
||||||
"lint": "vue-cli-service lint"
|
"lint": "eslint --ext .js,.vue src/"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"axios": "^1.6.0",
|
"axios": "^1.6.0",
|
||||||
"core-js": "^3.8.3",
|
|
||||||
"echarts": "^5.4.3",
|
"echarts": "^5.4.3",
|
||||||
"element-plus": "^2.4.2",
|
"element-plus": "^2.4.2",
|
||||||
"vue": "^3.2.13",
|
"vue": "^3.2.13",
|
||||||
"vue-router": "^4.0.3",
|
"vue-router": "^4.0.3"
|
||||||
"vuex": "^4.0.0"
|
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@babel/core": "^7.12.16",
|
|
||||||
"@babel/eslint-parser": "^7.12.16",
|
|
||||||
"@element-plus/icons-vue": "^2.3.1",
|
"@element-plus/icons-vue": "^2.3.1",
|
||||||
"@vue/cli-plugin-babel": "~5.0.0",
|
"@vitejs/plugin-vue": "^5.0.0",
|
||||||
"@vue/cli-plugin-eslint": "~5.0.0",
|
|
||||||
"@vue/cli-plugin-router": "~5.0.0",
|
|
||||||
"@vue/cli-plugin-vuex": "~5.0.0",
|
|
||||||
"@vue/cli-service": "~5.0.0",
|
|
||||||
"@vue/compiler-sfc": "^3.2.13",
|
|
||||||
"eslint": "^7.32.0",
|
"eslint": "^7.32.0",
|
||||||
"eslint-plugin-vue": "^8.0.3",
|
"eslint-plugin-vue": "^8.0.3",
|
||||||
"sass": "^1.32.7",
|
"sass-embedded": "^1.99.0",
|
||||||
"sass-loader": "^12.0.0"
|
"vite": "^5.0.0"
|
||||||
},
|
},
|
||||||
"eslintConfig": {
|
"eslintConfig": {
|
||||||
"root": true,
|
"root": true,
|
||||||
@ -45,17 +36,12 @@
|
|||||||
"eslint:recommended"
|
"eslint:recommended"
|
||||||
],
|
],
|
||||||
"parserOptions": {
|
"parserOptions": {
|
||||||
"parser": "@babel/eslint-parser"
|
"ecmaVersion": "latest",
|
||||||
|
"sourceType": "module"
|
||||||
},
|
},
|
||||||
"rules": {
|
"rules": {
|
||||||
"vue/multi-word-component-names": "off",
|
"vue/multi-word-component-names": "off",
|
||||||
"no-unused-vars": "warn"
|
"no-unused-vars": "warn"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"browserslist": [
|
|
||||||
"> 1%",
|
|
||||||
"last 2 versions",
|
|
||||||
"not dead",
|
|
||||||
"not ie 11"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,17 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="">
|
|
||||||
<head>
|
|
||||||
<meta charset="utf-8">
|
|
||||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
|
||||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
|
||||||
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
|
|
||||||
<title><%= htmlWebpackPlugin.options.title %></title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<noscript>
|
|
||||||
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
|
|
||||||
</noscript>
|
|
||||||
<div id="app"></div>
|
|
||||||
<!-- built files will be auto injected -->
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@ -1,7 +1,6 @@
|
|||||||
import { createApp } from 'vue'
|
import { createApp } from 'vue'
|
||||||
import App from './App.vue'
|
import App from './App.vue'
|
||||||
import router from './router'
|
import router from './router'
|
||||||
import store from './store'
|
|
||||||
import ElementPlus from 'element-plus'
|
import ElementPlus from 'element-plus'
|
||||||
import 'element-plus/dist/index.css'
|
import 'element-plus/dist/index.css'
|
||||||
import './assets/styles/global.css'
|
import './assets/styles/global.css'
|
||||||
@ -15,7 +14,6 @@ app.use(ElementPlus, {
|
|||||||
size: 'default'
|
size: 'default'
|
||||||
})
|
})
|
||||||
app.use(router)
|
app.use(router)
|
||||||
app.use(store)
|
|
||||||
|
|
||||||
// 注册图标
|
// 注册图标
|
||||||
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
|
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
import { createStore } from 'vuex'
|
|
||||||
|
|
||||||
export default createStore({
|
|
||||||
state: {
|
|
||||||
},
|
|
||||||
getters: {
|
|
||||||
},
|
|
||||||
mutations: {
|
|
||||||
},
|
|
||||||
actions: {
|
|
||||||
},
|
|
||||||
modules: {
|
|
||||||
}
|
|
||||||
})
|
|
||||||
@ -19,7 +19,7 @@ export default defineConfig({
|
|||||||
port: 3000,
|
port: 3000,
|
||||||
proxy: {
|
proxy: {
|
||||||
'/api': {
|
'/api': {
|
||||||
target: 'http://localhost:5000',
|
target: 'http://localhost:5001',
|
||||||
changeOrigin: true
|
changeOrigin: true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,12 +11,6 @@ dependencies = [
|
|||||||
"flask>=3.1.0",
|
"flask>=3.1.0",
|
||||||
"flask-cors>=5.0.0",
|
"flask-cors>=5.0.0",
|
||||||
|
|
||||||
# 数据库
|
|
||||||
"sqlalchemy>=2.0.36",
|
|
||||||
"pymysql>=1.1.1",
|
|
||||||
"cryptography>=43.0.0",
|
|
||||||
"mysql-connector-python>=8.0.0",
|
|
||||||
|
|
||||||
# 数据处理
|
# 数据处理
|
||||||
"numpy>=1.26.0,<2.0.0",
|
"numpy>=1.26.0,<2.0.0",
|
||||||
"pandas>=2.2.0",
|
"pandas>=2.2.0",
|
||||||
@ -25,7 +19,6 @@ dependencies = [
|
|||||||
"scikit-learn>=1.5.2",
|
"scikit-learn>=1.5.2",
|
||||||
"xgboost>=2.1.0",
|
"xgboost>=2.1.0",
|
||||||
"lightgbm>=4.5.0",
|
"lightgbm>=4.5.0",
|
||||||
"torch==2.5.1",
|
|
||||||
|
|
||||||
# 工具
|
# 工具
|
||||||
"openpyxl>=3.1.5",
|
"openpyxl>=3.1.5",
|
||||||
@ -34,6 +27,10 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
# PyTorch 为可选依赖(安装约 800MB,仅训练神经网络时需要)
|
||||||
|
torch = [
|
||||||
|
"torch==2.5.1",
|
||||||
|
]
|
||||||
dev = [
|
dev = [
|
||||||
# 测试工具
|
# 测试工具
|
||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
|
|||||||
@ -1,15 +1,12 @@
|
|||||||
flask>=3.1.0
|
flask>=3.1.0
|
||||||
flask-cors>=5.0.0
|
flask-cors>=5.0.0
|
||||||
sqlalchemy>=2.0.36
|
|
||||||
pymysql>=1.1.1
|
|
||||||
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
|
|
||||||
mysql-connector-python>=8.0.0 # 添加这行
|
|
||||||
numpy>=1.26.0,<2.0.0
|
numpy>=1.26.0,<2.0.0
|
||||||
pandas>=2.2.0
|
pandas>=2.2.0
|
||||||
|
|
||||||
xgboost>=2.1.0
|
xgboost>=2.1.0
|
||||||
lightgbm>=4.5.0
|
lightgbm>=4.5.0
|
||||||
|
|
||||||
scikit-learn>=1.5.2
|
scikit-learn>=1.5.2
|
||||||
|
|
||||||
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
|
openpyxl>=3.1.5
|
||||||
python-dotenv>=1.0.0 # 环境变量
|
python-dotenv>=1.0.0
|
||||||
|
|||||||
@ -34,7 +34,6 @@ pyinstaller --clean `
|
|||||||
--add-data "src/loitering_munition_data.sql;data" `
|
--add-data "src/loitering_munition_data.sql;data" `
|
||||||
--add-data "src/rocket_artillery_data.sql;data" `
|
--add-data "src/rocket_artillery_data.sql;data" `
|
||||||
--add-data "src/manufacturer_data.sql;data" `
|
--add-data "src/manufacturer_data.sql;data" `
|
||||||
--add-data "src/schema.sql;data" `
|
|
||||||
--add-data "config.py;." `
|
--add-data "config.py;." `
|
||||||
--add-data "src;src" `
|
--add-data "src;src" `
|
||||||
--add-data "frontend;frontend" `
|
--add-data "frontend;frontend" `
|
||||||
@ -43,18 +42,11 @@ pyinstaller --clean `
|
|||||||
--add-data "models;models" `
|
--add-data "models;models" `
|
||||||
--collect-all "xgboost" `
|
--collect-all "xgboost" `
|
||||||
--collect-all "lightgbm" `
|
--collect-all "lightgbm" `
|
||||||
--collect-all "torch" `
|
|
||||||
--collect-all "sklearn" `
|
--collect-all "sklearn" `
|
||||||
--collect-all "numpy" `
|
--collect-all "numpy" `
|
||||||
--collect-all "pandas" `
|
--collect-all "pandas" `
|
||||||
--collect-all "sqlalchemy" `
|
|
||||||
--collect-all "pymysql" `
|
|
||||||
--collect-all "cryptography" `
|
|
||||||
--collect-all "flask" `
|
--collect-all "flask" `
|
||||||
--collect-all "flask_cors" `
|
--collect-all "flask_cors" `
|
||||||
--hidden-import "xgboost.testing" `
|
|
||||||
--hidden-import "torch.utils.tensorboard" `
|
|
||||||
--hidden-import "pytest" `
|
|
||||||
run.py
|
run.py
|
||||||
|
|
||||||
# Copy necessary files
|
# Copy necessary files
|
||||||
|
|||||||
@ -15,12 +15,6 @@ def create_app():
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app)
|
CORS(app)
|
||||||
|
|
||||||
# 配置数据库连接
|
|
||||||
app.config['MYSQL_HOST'] = config.MYSQL_HOST
|
|
||||||
app.config['MYSQL_USER'] = config.MYSQL_USER
|
|
||||||
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
|
|
||||||
app.config['MYSQL_DB'] = config.MYSQL_DB
|
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
app.register_blueprint(api_bp, url_prefix='/api')
|
app.register_blueprint(api_bp, url_prefix='/api')
|
||||||
logger.info("API blueprint registered")
|
logger.info("API blueprint registered")
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
@ -9,6 +8,16 @@ from src.database import get_db_connection
|
|||||||
from src.feature_analysis import FeatureAnalysis
|
from src.feature_analysis import FeatureAnalysis
|
||||||
from .logger import setup_logger
|
from .logger import setup_logger
|
||||||
|
|
||||||
|
# PyTorch 为可选依赖
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
_HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
torch = None
|
||||||
|
nn = None
|
||||||
|
_HAS_TORCH = False
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
class CostPredictor:
|
class CostPredictor:
|
||||||
@ -18,7 +27,11 @@ class CostPredictor:
|
|||||||
self.model = None
|
self.model = None
|
||||||
self.feature_analyzer = FeatureAnalysis()
|
self.feature_analyzer = FeatureAnalysis()
|
||||||
self.equipment_type = None
|
self.equipment_type = None
|
||||||
|
|
||||||
|
if _HAS_TORCH:
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
else:
|
||||||
|
self.device = None
|
||||||
|
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
@ -26,10 +39,9 @@ class CostPredictor:
|
|||||||
"""
|
"""
|
||||||
加载预训练模型和标准化器
|
加载预训练模型和标准化器
|
||||||
"""
|
"""
|
||||||
|
if _HAS_TORCH:
|
||||||
try:
|
try:
|
||||||
# 创建默认模型
|
|
||||||
self._create_default_model()
|
self._create_default_model()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error loading model: {str(e)}")
|
logging.error(f"Error loading model: {str(e)}")
|
||||||
self._create_default_model()
|
self._create_default_model()
|
||||||
@ -38,46 +50,36 @@ class CostPredictor:
|
|||||||
"""
|
"""
|
||||||
创建默认模型并进行初始化训练
|
创建默认模型并进行初始化训练
|
||||||
"""
|
"""
|
||||||
import torch.nn as nn
|
if not _HAS_TORCH:
|
||||||
|
raise ImportError("PyTorch is not installed.")
|
||||||
|
|
||||||
class DefaultModel(nn.Module):
|
class DefaultModel(nn.Module):
|
||||||
def __init__(self, input_size):
|
def __init__(self, input_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.Sequential(
|
self.layers = nn.Sequential(
|
||||||
nn.Linear(input_size, 64),
|
nn.Linear(input_size, 64), nn.ReLU(),
|
||||||
nn.ReLU(),
|
nn.Linear(64, 32), nn.ReLU(),
|
||||||
nn.Linear(64, 32),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(32, 1)
|
nn.Linear(32, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
# 创建示例数据
|
|
||||||
example_features = {
|
example_features = {
|
||||||
'length_m': [7.35, 10.2],
|
'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
|
||||||
'width_m': [2.4, 2.8],
|
'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
|
||||||
'height_m': [3.1, 3.2],
|
|
||||||
'weight_kg': [13700, 28500],
|
|
||||||
'max_range_km': [20.4, 70],
|
'max_range_km': [20.4, 70],
|
||||||
'firing_angle_horizontal': [102, 110],
|
'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
|
||||||
'firing_angle_vertical': [55, 60],
|
'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
|
||||||
'rocket_length_m': [2.87, 4.1],
|
'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
|
||||||
'rocket_diameter_mm': [122, 220],
|
|
||||||
'rocket_weight_kg': [66.6, 150],
|
|
||||||
'rate_of_fire': [40, 60]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 转换为 tensor
|
|
||||||
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
|
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
|
||||||
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
|
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
|
||||||
|
|
||||||
# 训练标准化器
|
|
||||||
self.scaler_X.fit(X.numpy())
|
self.scaler_X.fit(X.numpy())
|
||||||
self.scaler_y.fit(y.numpy())
|
self.scaler_y.fit(y.numpy())
|
||||||
|
|
||||||
# 创建模型
|
|
||||||
self.model = DefaultModel(X.shape[1]).to(self.device)
|
self.model = DefaultModel(X.shape[1]).to(self.device)
|
||||||
self.equipment_type = '火箭炮'
|
self.equipment_type = '火箭炮'
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import logging
|
import logging
|
||||||
from src.feature_analysis import FeatureAnalysis
|
from src.feature_analysis import FeatureAnalysis
|
||||||
from src.database import get_db_connection
|
from src.database import get_db_connection
|
||||||
@ -9,9 +7,22 @@ from .logger import setup_logger
|
|||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
class EquipmentDataset(Dataset):
|
# PyTorch 为可选依赖
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
_HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
torch = None
|
||||||
|
Dataset = object
|
||||||
|
DataLoader = None
|
||||||
|
_HAS_TORCH = False
|
||||||
|
|
||||||
|
class EquipmentDataset(Dataset if _HAS_TORCH else object):
|
||||||
"""装备数据集类"""
|
"""装备数据集类"""
|
||||||
def __init__(self, features, targets=None):
|
def __init__(self, features, targets=None):
|
||||||
|
if not _HAS_TORCH:
|
||||||
|
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||||
self.features = torch.FloatTensor(features)
|
self.features = torch.FloatTensor(features)
|
||||||
self.targets = torch.FloatTensor(targets) if targets is not None else None
|
self.targets = torch.FloatTensor(targets) if targets is not None else None
|
||||||
|
|
||||||
@ -41,7 +52,7 @@ class DataPreparation:
|
|||||||
|
|
||||||
# 获取数据库连接
|
# 获取数据库连接
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 获取所有生产商数据,用于计算特征
|
# 获取所有生产商数据,用于计算特征
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
|
|||||||
@ -1,37 +1,227 @@
|
|||||||
import mysql.connector
|
import sqlite3
|
||||||
from mysql.connector import Error
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
|
||||||
from ..logger import setup_logger
|
from ..logger import setup_logger
|
||||||
|
|
||||||
# 获取logger
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
# 加载环境变量
|
# SQLite 数据库文件路径(相对于项目根目录)
|
||||||
load_dotenv()
|
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
|
||||||
|
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
|
||||||
|
|
||||||
|
# 建表 SQL
|
||||||
|
SCHEMA_SQL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS equipments (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT,
|
||||||
|
type TEXT,
|
||||||
|
manufacturer TEXT,
|
||||||
|
manufacturer_id INTEGER,
|
||||||
|
created_at TEXT DEFAULT (datetime('now','localtime'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS common_params (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
equipment_id INTEGER,
|
||||||
|
length_m REAL,
|
||||||
|
width_m REAL,
|
||||||
|
height_m REAL,
|
||||||
|
weight_kg REAL,
|
||||||
|
max_range_km REAL,
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
equipment_id INTEGER,
|
||||||
|
firing_angle_horizontal REAL,
|
||||||
|
firing_angle_vertical REAL,
|
||||||
|
rocket_length_m REAL,
|
||||||
|
rocket_diameter_mm REAL,
|
||||||
|
rocket_weight_kg REAL,
|
||||||
|
rate_of_fire REAL,
|
||||||
|
combat_weight_kg REAL,
|
||||||
|
speed_kmh REAL,
|
||||||
|
min_range_km REAL,
|
||||||
|
max_range_km REAL,
|
||||||
|
mobility_type TEXT,
|
||||||
|
structure_layout TEXT,
|
||||||
|
engine_model TEXT,
|
||||||
|
engine_params TEXT,
|
||||||
|
power_hp REAL,
|
||||||
|
travel_range_km REAL,
|
||||||
|
fire_density REAL,
|
||||||
|
range_ratio REAL,
|
||||||
|
mobility_score INTEGER,
|
||||||
|
combat_readiness_score INTEGER,
|
||||||
|
deployment_score INTEGER,
|
||||||
|
terrain_adaptability_score INTEGER,
|
||||||
|
rocket_power_ratio REAL,
|
||||||
|
platform_efficiency REAL,
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS loitering_munition_params (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
equipment_id INTEGER,
|
||||||
|
wingspan_m REAL,
|
||||||
|
warhead_weight_kg REAL,
|
||||||
|
max_speed_ms REAL,
|
||||||
|
cruise_speed_kmh REAL,
|
||||||
|
endurance_min REAL,
|
||||||
|
flight_time_min REAL,
|
||||||
|
max_range_km REAL,
|
||||||
|
max_payload_kg REAL,
|
||||||
|
ceiling_altitude_m REAL,
|
||||||
|
combat_radius_km REAL,
|
||||||
|
folded_length_mm REAL,
|
||||||
|
folded_width_mm REAL,
|
||||||
|
folded_height_mm REAL,
|
||||||
|
warhead_type TEXT,
|
||||||
|
launch_mode TEXT,
|
||||||
|
power_system TEXT,
|
||||||
|
guidance_system TEXT,
|
||||||
|
engine_power_kw REAL,
|
||||||
|
engine_thrust_n REAL,
|
||||||
|
datalink_range_km REAL,
|
||||||
|
guidance_accuracy_m REAL,
|
||||||
|
min_altitude_m REAL,
|
||||||
|
max_altitude_m REAL,
|
||||||
|
length_width_ratio REAL,
|
||||||
|
weight_range_ratio REAL,
|
||||||
|
speed_weight_ratio REAL,
|
||||||
|
guidance_system_score INTEGER,
|
||||||
|
warhead_power_score INTEGER,
|
||||||
|
warhead_type_code INTEGER,
|
||||||
|
launch_mode_code INTEGER,
|
||||||
|
power_system_code INTEGER,
|
||||||
|
guidance_system_code INTEGER,
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS feature_encoding (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
feature_type TEXT,
|
||||||
|
feature_value TEXT,
|
||||||
|
code INTEGER,
|
||||||
|
UNIQUE(feature_type, feature_value)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS cost_data (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
equipment_id INTEGER,
|
||||||
|
actual_cost REAL,
|
||||||
|
predicted_cost REAL,
|
||||||
|
prediction_date TEXT,
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS custom_params (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
equipment_id INTEGER,
|
||||||
|
param_name TEXT,
|
||||||
|
param_value TEXT,
|
||||||
|
param_unit TEXT,
|
||||||
|
description TEXT,
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS datasets (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
equipment_type TEXT NOT NULL,
|
||||||
|
purpose TEXT NOT NULL,
|
||||||
|
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||||
|
updated_at TEXT DEFAULT (datetime('now','localtime'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS dataset_equipments (
|
||||||
|
dataset_id INTEGER NOT NULL,
|
||||||
|
equipment_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (dataset_id, equipment_id),
|
||||||
|
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
|
||||||
|
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS trained_models (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
model_name TEXT NOT NULL,
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
equipment_type TEXT NOT NULL,
|
||||||
|
model_path TEXT NOT NULL,
|
||||||
|
scaler_path TEXT NOT NULL,
|
||||||
|
r2_score REAL,
|
||||||
|
mae REAL,
|
||||||
|
rmse REAL,
|
||||||
|
feature_importance TEXT,
|
||||||
|
training_data_size INTEGER,
|
||||||
|
training_date TEXT DEFAULT (datetime('now','localtime')),
|
||||||
|
is_active INTEGER DEFAULT 0,
|
||||||
|
created_by TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS manufacturers (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
country TEXT NOT NULL,
|
||||||
|
tech_level INTEGER NOT NULL,
|
||||||
|
scale_level INTEGER NOT NULL,
|
||||||
|
supply_chain_level INTEGER NOT NULL,
|
||||||
|
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||||
|
updated_at TEXT DEFAULT (datetime('now','localtime')),
|
||||||
|
UNIQUE(name)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 索引
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""初始化数据库:确保数据库文件和表存在"""
|
||||||
|
os.makedirs(DB_DIR, exist_ok=True)
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.executescript(SCHEMA_SQL)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
logger.info(f"Database initialized at {DB_PATH}")
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_connection():
|
def get_db_connection():
|
||||||
"""
|
"""
|
||||||
数据库连接上下文管理器
|
数据库连接上下文管理器
|
||||||
|
返回的 connection 已设置 dict row_factory,
|
||||||
|
以便按列名访问。
|
||||||
"""
|
"""
|
||||||
connection = None
|
conn = None
|
||||||
try:
|
try:
|
||||||
connection = mysql.connector.connect(
|
# 确保数据库已初始化
|
||||||
host=os.getenv('MYSQL_HOST', 'localhost'),
|
if not os.path.exists(DB_PATH):
|
||||||
user=os.getenv('MYSQL_USER', 'root'),
|
logger.info("Database file not found, initializing...")
|
||||||
password=os.getenv('MYSQL_PASSWORD', '123456'),
|
init_db()
|
||||||
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
|
|
||||||
)
|
|
||||||
logger.info("Database connection established")
|
|
||||||
yield connection
|
|
||||||
|
|
||||||
except Error as e:
|
conn = sqlite3.connect(DB_PATH)
|
||||||
logger.error(f"Error connecting to MySQL: {str(e)}")
|
conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
logger.debug("Database connection established")
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
logger.error(f"Database error: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if connection and connection.is_connected():
|
if conn:
|
||||||
connection.close()
|
conn.close()
|
||||||
logger.info("Database connection closed")
|
logger.debug("Database connection closed")
|
||||||
|
|||||||
@ -267,11 +267,15 @@ class FeatureAnalysis:
|
|||||||
def calculate_manufacturer_features(self, manufacturer_data):
|
def calculate_manufacturer_features(self, manufacturer_data):
|
||||||
"""计算生产商相关的特征"""
|
"""计算生产商相关的特征"""
|
||||||
try:
|
try:
|
||||||
# 确保所有必要的字段都存在,使用默认值处理缺失数据
|
# 处理 None 值(数据库 NULL),使用默认值
|
||||||
tech_level = float(manufacturer_data.get('tech_level', 0))
|
raw_tech = manufacturer_data.get('tech_level')
|
||||||
scale_level = float(manufacturer_data.get('scale_level', 0))
|
raw_scale = manufacturer_data.get('scale_level')
|
||||||
supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0))
|
raw_supply = manufacturer_data.get('supply_chain_level')
|
||||||
country = manufacturer_data.get('country', '未知')
|
|
||||||
|
tech_level = float(raw_tech) if raw_tech is not None else 0
|
||||||
|
scale_level = float(raw_scale) if raw_scale is not None else 0
|
||||||
|
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
|
||||||
|
country = manufacturer_data.get('country', '未知') or '未知'
|
||||||
|
|
||||||
# 计算综合得分
|
# 计算综合得分
|
||||||
composite_score = (
|
composite_score = (
|
||||||
|
|||||||
@ -27,7 +27,7 @@ def import_training_data(excel_file):
|
|||||||
# 检查是否已存在相同名称的装备
|
# 检查是否已存在相同名称的装备
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT id FROM equipments
|
SELECT id FROM equipments
|
||||||
WHERE name = %s AND type = '火箭炮'
|
WHERE name = ? AND type = '火箭炮'
|
||||||
""", (row['名称'],))
|
""", (row['名称'],))
|
||||||
|
|
||||||
existing_equipment = cursor.fetchone()
|
existing_equipment = cursor.fetchone()
|
||||||
@ -38,7 +38,7 @@ def import_training_data(excel_file):
|
|||||||
# 插入基本信息
|
# 插入基本信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO equipments (name, type, manufacturer)
|
INSERT INTO equipments (name, type, manufacturer)
|
||||||
VALUES (%s, %s, %s)
|
VALUES (?, ?, ?)
|
||||||
""", (row['名称'], '火箭炮', row['制造商']))
|
""", (row['名称'], '火箭炮', row['制造商']))
|
||||||
|
|
||||||
equipment_id = cursor.lastrowid
|
equipment_id = cursor.lastrowid
|
||||||
@ -47,7 +47,7 @@ def import_training_data(excel_file):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO common_params
|
INSERT INTO common_params
|
||||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
equipment_id,
|
equipment_id,
|
||||||
row['总长_m'] if pd.notna(row['总长_m']) else None,
|
row['总长_m'] if pd.notna(row['总长_m']) else None,
|
||||||
@ -65,7 +65,7 @@ def import_training_data(excel_file):
|
|||||||
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
|
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
|
||||||
structure_layout, engine_model, engine_params, power_hp,
|
structure_layout, engine_model, engine_params, power_hp,
|
||||||
travel_range_km)
|
travel_range_km)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
equipment_id,
|
equipment_id,
|
||||||
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
|
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
|
||||||
@ -89,7 +89,7 @@ def import_training_data(excel_file):
|
|||||||
if pd.notna(row['成本_元']):
|
if pd.notna(row['成本_元']):
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||||
VALUES (%s, %s)
|
VALUES (?, ?)
|
||||||
""", (equipment_id, row['成本_元']))
|
""", (equipment_id, row['成本_元']))
|
||||||
|
|
||||||
logger.info("火箭炮数据导入完成")
|
logger.info("火箭炮数据导入完成")
|
||||||
@ -105,8 +105,8 @@ def import_training_data(excel_file):
|
|||||||
equipment_names.add(row['名称'])
|
equipment_names.add(row['名称'])
|
||||||
# 检查是否已存在相同名称的装备
|
# 检查是否已存在相同名称的装备
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT id FROM equipment
|
SELECT id FROM equipments
|
||||||
WHERE name = %s AND type = '巡飞弹'
|
WHERE name = ? AND type = '巡飞弹'
|
||||||
""", (row['名称'],))
|
""", (row['名称'],))
|
||||||
|
|
||||||
existing_equipment = cursor.fetchone()
|
existing_equipment = cursor.fetchone()
|
||||||
@ -117,7 +117,7 @@ def import_training_data(excel_file):
|
|||||||
# 插入基本信息
|
# 插入基本信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO equipments (name, type, manufacturer)
|
INSERT INTO equipments (name, type, manufacturer)
|
||||||
VALUES (%s, %s, %s)
|
VALUES (?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
row['名称'],
|
row['名称'],
|
||||||
'巡飞弹',
|
'巡飞弹',
|
||||||
@ -130,7 +130,7 @@ def import_training_data(excel_file):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO common_params
|
INSERT INTO common_params
|
||||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
equipment_id,
|
equipment_id,
|
||||||
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
|
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
|
||||||
@ -147,7 +147,7 @@ def import_training_data(excel_file):
|
|||||||
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
|
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
|
||||||
folded_length_mm, folded_width_mm, folded_height_mm,
|
folded_length_mm, folded_width_mm, folded_height_mm,
|
||||||
power_system, guidance_system)
|
power_system, guidance_system)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
equipment_id,
|
equipment_id,
|
||||||
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
|
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
|
||||||
@ -168,7 +168,7 @@ def import_training_data(excel_file):
|
|||||||
if pd.notna(row['成本_元']):
|
if pd.notna(row['成本_元']):
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||||
VALUES (%s, %s)
|
VALUES (?, ?)
|
||||||
""", (equipment_id, float(row['成本_元'])))
|
""", (equipment_id, float(row['成本_元'])))
|
||||||
|
|
||||||
logger.info("巡飞弹数据导入完成")
|
logger.info("巡飞弹数据导入完成")
|
||||||
@ -190,9 +190,9 @@ def import_training_data(excel_file):
|
|||||||
|
|
||||||
# 获取装备ID - 使用新的游标
|
# 获取装备ID - 使用新的游标
|
||||||
logger.debug(f"查询装备ID: {equipment_name}")
|
logger.debug(f"查询装备ID: {equipment_name}")
|
||||||
with conn.cursor() as id_cursor:
|
id_cursor = conn.cursor()
|
||||||
id_cursor.execute("""
|
id_cursor.execute("""
|
||||||
SELECT id FROM equipments WHERE name = %s
|
SELECT id FROM equipments WHERE name = ?
|
||||||
""", (equipment_name,))
|
""", (equipment_name,))
|
||||||
result = id_cursor.fetchone()
|
result = id_cursor.fetchone()
|
||||||
|
|
||||||
@ -200,15 +200,15 @@ def import_training_data(excel_file):
|
|||||||
logger.warning(f"未找到装备: {equipment_name}")
|
logger.warning(f"未找到装备: {equipment_name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
equipment_id = result[0]
|
equipment_id = result['id']
|
||||||
logger.debug(f"找到装备ID: {equipment_id}")
|
logger.debug(f"找到装备ID: {equipment_id}")
|
||||||
|
|
||||||
# 检查参数是否存在 - 使用新的游标
|
# 检查参数是否存在 - 使用新的游标
|
||||||
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
|
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
|
||||||
with conn.cursor() as check_cursor:
|
check_cursor = conn.cursor()
|
||||||
check_cursor.execute("""
|
check_cursor.execute("""
|
||||||
SELECT id FROM custom_params
|
SELECT id FROM custom_params
|
||||||
WHERE equipment_id = %s AND param_name = %s
|
WHERE equipment_id = ? AND param_name = ?
|
||||||
""", (equipment_id, param_name))
|
""", (equipment_id, param_name))
|
||||||
exists = check_cursor.fetchone()
|
exists = check_cursor.fetchone()
|
||||||
|
|
||||||
@ -222,11 +222,11 @@ def import_training_data(excel_file):
|
|||||||
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
|
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
|
||||||
|
|
||||||
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
|
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
|
||||||
with conn.cursor() as insert_cursor:
|
insert_cursor = conn.cursor()
|
||||||
insert_cursor.execute("""
|
insert_cursor.execute("""
|
||||||
INSERT INTO custom_params
|
INSERT INTO custom_params
|
||||||
(equipment_id, param_name, param_value, param_unit, description)
|
(equipment_id, param_name, param_value, param_unit, description)
|
||||||
VALUES (%s, %s, %s, %s, %s)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
equipment_id,
|
equipment_id,
|
||||||
param_name,
|
param_name,
|
||||||
|
|||||||
237
src/import_sql_data.py
Normal file
237
src/import_sql_data.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
"""
|
||||||
|
导入 SQL 数据文件到 SQLite 数据库
|
||||||
|
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
|
||||||
|
- rocket_artillery_data.sql (96 条火箭炮数据)
|
||||||
|
- loitering_munition_data.sql (100 条巡飞弹数据)
|
||||||
|
- manufacturer_data.sql (生产商数据)
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from src.database.db_connection import get_db_connection, DB_PATH
|
||||||
|
from src.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_mysql_insert(sql_text):
|
||||||
|
"""
|
||||||
|
解析 MySQL INSERT 语句,返回 (table_name, columns, values_list)
|
||||||
|
columns 是列名列表, values_list 是每行的值列表
|
||||||
|
"""
|
||||||
|
# 去掉块注释 /* ... */
|
||||||
|
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
|
||||||
|
# 去掉行注释 --
|
||||||
|
sql_text = re.sub(r'--.*', '', sql_text)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
|
||||||
|
pattern = re.compile(
|
||||||
|
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
|
||||||
|
re.DOTALL | re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
for match in pattern.finditer(sql_text):
|
||||||
|
table_name = match.group(1).strip()
|
||||||
|
columns_str = match.group(2).strip()
|
||||||
|
values_str = match.group(3).strip()
|
||||||
|
|
||||||
|
# 提取列名(去掉注释部分)
|
||||||
|
columns = []
|
||||||
|
for col in columns_str.split(','):
|
||||||
|
col = col.strip()
|
||||||
|
# 去掉行内注释
|
||||||
|
col = re.sub(r'\s*--.*', '', col).strip()
|
||||||
|
columns.append(col)
|
||||||
|
|
||||||
|
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
|
||||||
|
values_list = []
|
||||||
|
current_tuple = []
|
||||||
|
depth = 0
|
||||||
|
current_val = ''
|
||||||
|
in_string = False
|
||||||
|
string_char = None
|
||||||
|
|
||||||
|
for ch in values_str:
|
||||||
|
if in_string:
|
||||||
|
if ch == string_char:
|
||||||
|
in_string = False
|
||||||
|
current_val += ch
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch in ("'", '"'):
|
||||||
|
in_string = True
|
||||||
|
string_char = ch
|
||||||
|
current_val += ch
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == '(':
|
||||||
|
if depth > 0:
|
||||||
|
current_val += ch
|
||||||
|
depth += 1
|
||||||
|
if depth == 1:
|
||||||
|
current_val = ''
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == ')':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
current_tuple.append(current_val.strip())
|
||||||
|
values_list.append(current_tuple)
|
||||||
|
current_tuple = []
|
||||||
|
else:
|
||||||
|
current_val += ch
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == ',' and depth == 1:
|
||||||
|
current_tuple.append(current_val.strip())
|
||||||
|
current_val = ''
|
||||||
|
continue
|
||||||
|
|
||||||
|
if depth >= 1:
|
||||||
|
current_val += ch
|
||||||
|
|
||||||
|
results.append((table_name, columns, values_list))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def convert_value(val):
|
||||||
|
"""将 MySQL 值字符串转为 Python 对象"""
|
||||||
|
val = val.strip()
|
||||||
|
if val.upper() == 'NULL' or val == '':
|
||||||
|
return None
|
||||||
|
if val.upper() == 'TRUE':
|
||||||
|
return 1
|
||||||
|
if val.upper() == 'FALSE':
|
||||||
|
return 0
|
||||||
|
# 字符串
|
||||||
|
if (val.startswith("'") and val.endswith("'")) or \
|
||||||
|
(val.startswith('"') and val.endswith('"')):
|
||||||
|
return val[1:-1]
|
||||||
|
# 数字
|
||||||
|
try:
|
||||||
|
if '.' in val:
|
||||||
|
return float(val)
|
||||||
|
return int(val)
|
||||||
|
except ValueError:
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def import_sql_file(filepath):
|
||||||
|
"""导入单个 SQL 文件"""
|
||||||
|
logger.info(f"Reading {filepath}...")
|
||||||
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
sql_text = f.read()
|
||||||
|
|
||||||
|
parsed = parse_mysql_insert(sql_text)
|
||||||
|
total_rows = 0
|
||||||
|
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
for table_name, columns, values_list in parsed:
|
||||||
|
if not values_list:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 构建参数化 INSERT
|
||||||
|
placeholders = ','.join(['?'] * len(columns))
|
||||||
|
col_names = ','.join(columns)
|
||||||
|
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
|
||||||
|
|
||||||
|
row_count = 0
|
||||||
|
for values in values_list:
|
||||||
|
if len(values) != len(columns):
|
||||||
|
logger.warning(f"Column mismatch in {table_name}: "
|
||||||
|
f"expected {len(columns)} values, got {len(values)}: {values}")
|
||||||
|
continue
|
||||||
|
converted = [convert_value(v) for v in values]
|
||||||
|
try:
|
||||||
|
cursor.execute(sql, converted)
|
||||||
|
if cursor.rowcount > 0:
|
||||||
|
row_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error inserting into {table_name}: {e}")
|
||||||
|
logger.warning(f" Values: {converted}")
|
||||||
|
|
||||||
|
if row_count > 0:
|
||||||
|
logger.info(f" {table_name}: inserted {row_count} rows")
|
||||||
|
total_rows += row_count
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
return total_rows
|
||||||
|
|
||||||
|
|
||||||
|
def run_manufacturer_update():
|
||||||
|
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id)"""
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("""
|
||||||
|
UPDATE equipments
|
||||||
|
SET manufacturer_id = (
|
||||||
|
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
|
||||||
|
)
|
||||||
|
WHERE manufacturer_id IS NULL
|
||||||
|
AND manufacturer IS NOT NULL
|
||||||
|
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
updated = cursor.rowcount
|
||||||
|
if updated > 0:
|
||||||
|
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
|
||||||
|
else:
|
||||||
|
logger.info("No manufacturer_id updates needed")
|
||||||
|
|
||||||
|
|
||||||
|
def import_all():
|
||||||
|
"""导入所有 SQL 数据文件"""
|
||||||
|
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
# 清除现有数据库,重新开始
|
||||||
|
if os.path.exists(DB_PATH):
|
||||||
|
os.remove(DB_PATH)
|
||||||
|
logger.info(f"Removed existing database: {DB_PATH}")
|
||||||
|
|
||||||
|
files = [
|
||||||
|
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
|
||||||
|
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
|
||||||
|
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
|
||||||
|
]
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
for filepath, label in files:
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
logger.warning(f"File not found: {filepath}")
|
||||||
|
continue
|
||||||
|
logger.info(f"正在导入 {label}...")
|
||||||
|
rows = import_sql_file(filepath)
|
||||||
|
logger.info(f" {label}: 共导入 {rows} 行")
|
||||||
|
total += rows
|
||||||
|
|
||||||
|
# 更新 manufacturer_id
|
||||||
|
run_manufacturer_update()
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
|
||||||
|
eq_count = cursor.fetchone()['cnt']
|
||||||
|
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
|
||||||
|
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
|
||||||
|
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
|
||||||
|
mf_count = cursor.fetchone()['cnt']
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info(f"导入完成!")
|
||||||
|
logger.info(f" 装备总计: {eq_count}")
|
||||||
|
for t, c in by_type.items():
|
||||||
|
logger.info(f" {t}: {c}")
|
||||||
|
logger.info(f" 生产商: {mf_count}")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
return eq_count
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import_all()
|
||||||
@ -1,7 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
|
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
|
||||||
@ -11,40 +8,43 @@ from datetime import datetime
|
|||||||
import json
|
import json
|
||||||
from src.feature_analysis import FeatureAnalysis
|
from src.feature_analysis import FeatureAnalysis
|
||||||
from src.database import get_db_connection
|
from src.database import get_db_connection
|
||||||
from src.data_preparation import DataPreparation, EquipmentDataset
|
|
||||||
from .logger import setup_logger
|
from .logger import setup_logger
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
# PyTorch 为可选依赖
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from src.data_preparation import EquipmentDataset
|
||||||
|
_HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
torch = None
|
||||||
|
nn = None
|
||||||
|
DataLoader = None
|
||||||
|
EquipmentDataset = None
|
||||||
|
_HAS_TORCH = False
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
class CostPredictionModel(nn.Module):
|
# 条件基类:有 PyTorch 时继承 nn.Module,否则继承 object
|
||||||
|
_PYTORCH_BASE = nn.Module if _HAS_TORCH else object
|
||||||
|
|
||||||
|
class CostPredictionModel(_PYTORCH_BASE):
|
||||||
def __init__(self, input_size, equipment_type):
|
def __init__(self, input_size, equipment_type):
|
||||||
|
if not _HAS_TORCH:
|
||||||
|
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.equipment_type = equipment_type
|
self.equipment_type = equipment_type
|
||||||
|
|
||||||
if equipment_type == '火箭炮':
|
if equipment_type == '火箭炮':
|
||||||
# 火箭炮使用更简单和稳定的网络结构
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
# 第一层:特征映射
|
nn.Linear(input_size, 32), nn.ReLU(), nn.BatchNorm1d(32),
|
||||||
nn.Linear(input_size, 32),
|
nn.Linear(32, 16), nn.ReLU(), nn.BatchNorm1d(16),
|
||||||
nn.ReLU(),
|
nn.Linear(16, 8), nn.ReLU(), nn.BatchNorm1d(8),
|
||||||
nn.BatchNorm1d(32),
|
|
||||||
|
|
||||||
# 第二层:特征提取
|
|
||||||
nn.Linear(32, 16),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(16),
|
|
||||||
|
|
||||||
# 第三层:特征整合
|
|
||||||
nn.Linear(16, 8),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(8),
|
|
||||||
|
|
||||||
# 输出层
|
|
||||||
nn.Linear(8, 1)
|
nn.Linear(8, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用正交初始化
|
|
||||||
def init_weights(m):
|
def init_weights(m):
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
torch.nn.init.orthogonal_(m.weight, gain=0.5)
|
torch.nn.init.orthogonal_(m.weight, gain=0.5)
|
||||||
@ -54,45 +54,19 @@ class CostPredictionModel(nn.Module):
|
|||||||
torch.nn.init.constant_(m.bias, 0.0)
|
torch.nn.init.constant_(m.bias, 0.0)
|
||||||
|
|
||||||
self.net.apply(init_weights)
|
self.net.apply(init_weights)
|
||||||
|
else:
|
||||||
else: # 巡飞弹保持原有结构
|
|
||||||
# 生产商特征网络 - 更简单的结构
|
|
||||||
self.manufacturer_net = nn.Sequential(
|
self.manufacturer_net = nn.Sequential(
|
||||||
nn.Linear(5, 4),
|
nn.Linear(5, 4), nn.ReLU(), nn.BatchNorm1d(4), nn.Dropout(0.2)
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(4),
|
|
||||||
nn.Dropout(0.2)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 巡飞弹特征网络 - 较深的结构
|
|
||||||
self.equipment_net = nn.Sequential(
|
self.equipment_net = nn.Sequential(
|
||||||
nn.Linear(input_size - 5, 64),
|
nn.Linear(input_size - 5, 64), nn.LeakyReLU(0.1), nn.BatchNorm1d(64), nn.Dropout(0.2),
|
||||||
nn.LeakyReLU(0.1),
|
nn.Linear(64, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
|
||||||
nn.BatchNorm1d(64),
|
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2)
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(64, 32),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.BatchNorm1d(32),
|
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(32, 16),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.BatchNorm1d(16),
|
|
||||||
nn.Dropout(0.2)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 合并网络 - 较复杂的结构
|
|
||||||
self.combined_net = nn.Sequential(
|
self.combined_net = nn.Sequential(
|
||||||
nn.Linear(20, 32), # 4 + 16 = 20
|
nn.Linear(20, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
|
||||||
nn.LeakyReLU(0.1),
|
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2),
|
||||||
nn.BatchNorm1d(32),
|
nn.Linear(16, 8), nn.LeakyReLU(0.1), nn.BatchNorm1d(8),
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(32, 16),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.BatchNorm1d(16),
|
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(16, 8),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.BatchNorm1d(8),
|
|
||||||
nn.Linear(8, 1)
|
nn.Linear(8, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,20 +74,17 @@ class CostPredictionModel(nn.Module):
|
|||||||
if self.equipment_type == '火箭炮':
|
if self.equipment_type == '火箭炮':
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
else:
|
else:
|
||||||
# 分离特征
|
|
||||||
manufacturer_features = x[:, -5:]
|
manufacturer_features = x[:, -5:]
|
||||||
equipment_features = x[:, :-5]
|
equipment_features = x[:, :-5]
|
||||||
|
|
||||||
# 特征处理
|
|
||||||
manu_out = self.manufacturer_net(manufacturer_features)
|
manu_out = self.manufacturer_net(manufacturer_features)
|
||||||
equip_out = self.equipment_net(equipment_features)
|
equip_out = self.equipment_net(equipment_features)
|
||||||
|
|
||||||
# 特征融合
|
|
||||||
combined = torch.cat([equip_out, manu_out], dim=1)
|
combined = torch.cat([equip_out, manu_out], dim=1)
|
||||||
return self.combined_net(combined)
|
return self.combined_net(combined)
|
||||||
|
|
||||||
class ModelTrainer:
|
class ModelTrainer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
if not _HAS_TORCH:
|
||||||
|
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.model = None
|
self.model = None
|
||||||
self.feature_scaler = None
|
self.feature_scaler = None
|
||||||
@ -307,7 +278,7 @@ class ModelTrainer:
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE trained_models
|
UPDATE trained_models
|
||||||
SET is_active = FALSE
|
SET is_active = FALSE
|
||||||
WHERE equipment_type = %s AND model_type != %s
|
WHERE equipment_type = ? AND model_type != ?
|
||||||
""", (equipment_type, 'pls'))
|
""", (equipment_type, 'pls'))
|
||||||
|
|
||||||
# 保存新模型记录
|
# 保存新模型记录
|
||||||
@ -316,7 +287,7 @@ class ModelTrainer:
|
|||||||
model_name, model_type, equipment_type, model_path,
|
model_name, model_type, equipment_type, model_path,
|
||||||
scaler_path, training_date, is_active, created_by,
|
scaler_path, training_date, is_active, created_by,
|
||||||
r2_score, mae, rmse
|
r2_score, mae, rmse
|
||||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
|
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
f"{equipment_type}_{timestamp}",
|
f"{equipment_type}_{timestamp}",
|
||||||
'pytorch',
|
'pytorch',
|
||||||
@ -343,10 +314,10 @@ class ModelTrainer:
|
|||||||
try:
|
try:
|
||||||
# 从数据库获取最新的激活模型
|
# 从数据库获取最新的激活模型
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM trained_models
|
SELECT * FROM trained_models
|
||||||
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE
|
WHERE equipment_type = ? AND model_type = ? AND is_active = TRUE
|
||||||
ORDER BY training_date DESC LIMIT 1
|
ORDER BY training_date DESC LIMIT 1
|
||||||
""", (equipment_type, model_type))
|
""", (equipment_type, model_type))
|
||||||
model_record = cursor.fetchone()
|
model_record = cursor.fetchone()
|
||||||
@ -707,13 +678,13 @@ class ModelTrainer:
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE trained_models
|
UPDATE trained_models
|
||||||
SET is_active = FALSE
|
SET is_active = FALSE
|
||||||
WHERE equipment_type = %s AND model_type != %s
|
WHERE equipment_type = ? AND model_type != ?
|
||||||
""", (equipment_type, 'pls'))
|
""", (equipment_type, 'pls'))
|
||||||
else:
|
else:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE trained_models
|
UPDATE trained_models
|
||||||
SET is_active = FALSE
|
SET is_active = FALSE
|
||||||
WHERE equipment_type = %s AND model_type = %s
|
WHERE equipment_type = ? AND model_type = ?
|
||||||
""", (equipment_type, 'pls'))
|
""", (equipment_type, 'pls'))
|
||||||
|
|
||||||
# 保存新模型记录
|
# 保存新模型记录
|
||||||
@ -722,7 +693,7 @@ class ModelTrainer:
|
|||||||
model_name, model_type, equipment_type, model_path,
|
model_name, model_type, equipment_type, model_path,
|
||||||
scaler_path, training_date, is_active, created_by,
|
scaler_path, training_date, is_active, created_by,
|
||||||
r2_score, mae, rmse
|
r2_score, mae, rmse
|
||||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
|
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
|
||||||
""", (
|
""", (
|
||||||
f"{equipment_type}_{model_type}_{timestamp}",
|
f"{equipment_type}_{model_type}_{timestamp}",
|
||||||
model_type,
|
model_type,
|
||||||
|
|||||||
338
src/routes.py
338
src/routes.py
@ -4,16 +4,23 @@ from .feature_analysis import FeatureAnalysis
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mysql.connector
|
|
||||||
from sklearn.metrics import mean_absolute_error
|
from sklearn.metrics import mean_absolute_error
|
||||||
from .create_template import create_excel_template
|
from .create_template import create_excel_template
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from .data_preparation import DataPreparation
|
from .data_preparation import DataPreparation
|
||||||
from .model_trainer import ModelTrainer
|
from .model_trainer import ModelTrainer
|
||||||
|
from .database import get_db_connection
|
||||||
from .demo_service import DemoModelService
|
from .demo_service import DemoModelService
|
||||||
from .logger import setup_logger
|
from .logger import setup_logger
|
||||||
|
|
||||||
|
# PyTorch 为可选导入
|
||||||
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
_HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_TORCH = False
|
||||||
|
|
||||||
# 创建蓝图
|
# 创建蓝图
|
||||||
api_bp = Blueprint('api', __name__)
|
api_bp = Blueprint('api', __name__)
|
||||||
@ -71,7 +78,7 @@ def demo_algorithms():
|
|||||||
|
|
||||||
@api_bp.route('/demo/dataset', methods=['GET'])
|
@api_bp.route('/demo/dataset', methods=['GET'])
|
||||||
def demo_dataset():
|
def demo_dataset():
|
||||||
"""Return the local demo dataset summary without using MySQL."""
|
"""Return the local demo dataset summary without using a database."""
|
||||||
try:
|
try:
|
||||||
service = DemoModelService()
|
service = DemoModelService()
|
||||||
return jsonify(service.get_dataset_summary())
|
return jsonify(service.get_dataset_summary())
|
||||||
@ -101,11 +108,11 @@ def predict():
|
|||||||
|
|
||||||
# 获取最新的激活模型(非PLS模型)
|
# 获取最新的激活模型(非PLS模型)
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM trained_models
|
SELECT * FROM trained_models
|
||||||
WHERE equipment_type = %s
|
WHERE equipment_type = ?
|
||||||
AND model_type != 'pls' # 明确排除PLS模型
|
AND model_type != 'pls'
|
||||||
AND is_active = TRUE
|
AND is_active = TRUE
|
||||||
ORDER BY training_date DESC LIMIT 1
|
ORDER BY training_date DESC LIMIT 1
|
||||||
""", (equipment_type,))
|
""", (equipment_type,))
|
||||||
@ -147,14 +154,14 @@ def analyze_features():
|
|||||||
return jsonify({'error': '请选择数据集'}), 400
|
return jsonify({'error': '请选择数据集'}), 400
|
||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 首先获取数据集的装备类型
|
# 首先获取数据集的装备类型
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT DISTINCT e.type
|
SELECT DISTINCT e.type
|
||||||
FROM equipments e
|
FROM equipments e
|
||||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""", (dataset_id,))
|
""", (dataset_id,))
|
||||||
|
|
||||||
@ -181,7 +188,7 @@ def analyze_features():
|
|||||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (dataset_id,))
|
""", (dataset_id,))
|
||||||
else:
|
else:
|
||||||
@ -205,7 +212,7 @@ def analyze_features():
|
|||||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (dataset_id,))
|
""", (dataset_id,))
|
||||||
|
|
||||||
@ -247,27 +254,27 @@ def analyze_features():
|
|||||||
# 添加装备特有的分析数据
|
# 添加装备特有的分析数据
|
||||||
if equipment_data[0]['type'] == '火箭炮':
|
if equipment_data[0]['type'] == '火箭炮':
|
||||||
rocket_data = {
|
rocket_data = {
|
||||||
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
|
'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
|
||||||
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
|
'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
|
||||||
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
|
'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
|
||||||
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
|
'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
|
||||||
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
|
'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
|
||||||
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
|
'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data]
|
||||||
}
|
}
|
||||||
analysis_result.update(rocket_data)
|
analysis_result.update(rocket_data)
|
||||||
else:
|
else:
|
||||||
missile_data = {
|
missile_data = {
|
||||||
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
|
'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
|
||||||
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
|
'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
|
||||||
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
|
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
|
||||||
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
|
'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
|
||||||
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
|
'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
|
||||||
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
|
'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item['guidance_accuracy_m'] is not None else 0 for item in equipment_data],
|
||||||
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
|
'datalink_range_km': [float(item['datalink_range_km']) if item['datalink_range_km'] is not None else 0 for item in equipment_data],
|
||||||
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
|
'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data],
|
||||||
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
|
'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
|
||||||
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
|
'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
|
||||||
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
|
'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data]
|
||||||
}
|
}
|
||||||
analysis_result.update(missile_data)
|
analysis_result.update(missile_data)
|
||||||
|
|
||||||
@ -295,7 +302,7 @@ def train_model():
|
|||||||
|
|
||||||
# 获取训练数据
|
# 获取训练数据
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 获取训练集数据(包含生产商信息)
|
# 获取训练集数据(包含生产商信息)
|
||||||
if equipment_type == '火箭炮':
|
if equipment_type == '火箭炮':
|
||||||
@ -309,7 +316,7 @@ def train_model():
|
|||||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (train_dataset_id,))
|
""", (train_dataset_id,))
|
||||||
else:
|
else:
|
||||||
@ -323,7 +330,7 @@ def train_model():
|
|||||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (train_dataset_id,))
|
""", (train_dataset_id,))
|
||||||
|
|
||||||
@ -342,7 +349,7 @@ def train_model():
|
|||||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (validation_dataset_id,))
|
""", (validation_dataset_id,))
|
||||||
else:
|
else:
|
||||||
@ -356,7 +363,7 @@ def train_model():
|
|||||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
AND cd.actual_cost IS NOT NULL
|
AND cd.actual_cost IS NOT NULL
|
||||||
""", (validation_dataset_id,))
|
""", (validation_dataset_id,))
|
||||||
validation_data = cursor.fetchall()
|
validation_data = cursor.fetchall()
|
||||||
@ -476,7 +483,7 @@ def get_equipment_data():
|
|||||||
"""获取装备数据列表"""
|
"""获取装备数据列表"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 获取所有装备数据(使用equipment_id替代id)
|
# 获取所有装备数据(使用equipment_id替代id)
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
@ -485,43 +492,39 @@ def get_equipment_data():
|
|||||||
cd.actual_cost, cd.predicted_cost,
|
cd.actual_cost, cd.predicted_cost,
|
||||||
CASE
|
CASE
|
||||||
WHEN e.type = '火箭炮' THEN (
|
WHEN e.type = '火箭炮' THEN (
|
||||||
SELECT CONCAT(
|
SELECT firing_angle_horizontal || ',' ||
|
||||||
firing_angle_horizontal, ',',
|
firing_angle_vertical || ',' ||
|
||||||
firing_angle_vertical, ',',
|
rocket_length_m || ',' ||
|
||||||
rocket_length_m, ',',
|
rocket_diameter_mm || ',' ||
|
||||||
rocket_diameter_mm, ',',
|
rocket_weight_kg || ',' ||
|
||||||
rocket_weight_kg, ',',
|
rate_of_fire || ',' ||
|
||||||
rate_of_fire, ',',
|
combat_weight_kg || ',' ||
|
||||||
combat_weight_kg, ',',
|
speed_kmh || ',' ||
|
||||||
speed_kmh, ',',
|
min_range_km || ',' ||
|
||||||
min_range_km, ',',
|
max_range_km || ',' ||
|
||||||
max_range_km, ',',
|
mobility_type || ',' ||
|
||||||
mobility_type, ',',
|
structure_layout || ',' ||
|
||||||
structure_layout, ',',
|
engine_model || ',' ||
|
||||||
engine_model, ',',
|
engine_params || ',' ||
|
||||||
engine_params, ',',
|
power_hp || ',' ||
|
||||||
power_hp, ',',
|
|
||||||
travel_range_km
|
travel_range_km
|
||||||
)
|
|
||||||
FROM rocket_artillery_params
|
FROM rocket_artillery_params
|
||||||
WHERE equipment_id = e.id
|
WHERE equipment_id = e.id
|
||||||
)
|
)
|
||||||
WHEN e.type = '巡飞弹' THEN (
|
WHEN e.type = '巡飞弹' THEN (
|
||||||
SELECT CONCAT(
|
SELECT wingspan_m || ',' ||
|
||||||
wingspan_m, ',',
|
warhead_weight_kg || ',' ||
|
||||||
warhead_weight_kg, ',',
|
max_speed_ms || ',' ||
|
||||||
max_speed_ms, ',',
|
cruise_speed_kmh || ',' ||
|
||||||
cruise_speed_kmh, ',',
|
endurance_min || ',' ||
|
||||||
endurance_min, ',',
|
max_range_km || ',' ||
|
||||||
max_range_km, ',',
|
max_payload_kg || ',' ||
|
||||||
max_payload_kg, ',',
|
ceiling_altitude_m || ',' ||
|
||||||
ceiling_altitude_m, ',',
|
combat_radius_km || ',' ||
|
||||||
combat_radius_km, ',',
|
warhead_type || ',' ||
|
||||||
warhead_type, ',',
|
launch_mode || ',' ||
|
||||||
launch_mode, ',',
|
power_system || ',' ||
|
||||||
power_system, ',',
|
|
||||||
guidance_system
|
guidance_system
|
||||||
)
|
|
||||||
FROM loitering_munition_params
|
FROM loitering_munition_params
|
||||||
WHERE equipment_id = e.id
|
WHERE equipment_id = e.id
|
||||||
)
|
)
|
||||||
@ -548,19 +551,17 @@ def delete_equipment(id):
|
|||||||
删除装备数据
|
删除装备数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = get_db_connection()
|
with get_db_connection() as conn:
|
||||||
cursor = db.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 删除相关数据
|
# 删除相关数据
|
||||||
cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,))
|
cursor.execute("DELETE FROM cost_data WHERE equipment_id = ?", (id,))
|
||||||
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
|
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = ?", (id,))
|
||||||
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
|
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = ?", (id,))
|
||||||
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
|
cursor.execute("DELETE FROM common_params WHERE equipment_id = ?", (id,))
|
||||||
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
|
cursor.execute("DELETE FROM equipments WHERE id = ?", (id,))
|
||||||
|
|
||||||
db.commit()
|
conn.commit()
|
||||||
cursor.close()
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
return jsonify({'status': 'success'})
|
return jsonify({'status': 'success'})
|
||||||
|
|
||||||
@ -594,16 +595,6 @@ def download_template():
|
|||||||
logger.error(f"Error creating template: {str(e)}")
|
logger.error(f"Error creating template: {str(e)}")
|
||||||
return jsonify({'error': str(e)}), 500
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
def get_db_connection():
|
|
||||||
"""
|
|
||||||
获取数据库连接
|
|
||||||
"""
|
|
||||||
return mysql.connector.connect(
|
|
||||||
host="localhost",
|
|
||||||
user="root",
|
|
||||||
password="123456",
|
|
||||||
database="equipment_cost_db"
|
|
||||||
)
|
|
||||||
|
|
||||||
@api_bp.route('/pls/predict', methods=['POST'])
|
@api_bp.route('/pls/predict', methods=['POST'])
|
||||||
def pls_predict():
|
def pls_predict():
|
||||||
@ -614,11 +605,11 @@ def pls_predict():
|
|||||||
|
|
||||||
# 获取最新的PLS模型
|
# 获取最新的PLS模型
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM trained_models
|
SELECT * FROM trained_models
|
||||||
WHERE equipment_type = %s
|
WHERE equipment_type = ?
|
||||||
AND model_type = 'pls' # 只选择PLS模型
|
AND model_type = 'pls'
|
||||||
AND is_active = TRUE
|
AND is_active = TRUE
|
||||||
ORDER BY training_date DESC LIMIT 1
|
ORDER BY training_date DESC LIMIT 1
|
||||||
""", (equipment_type,))
|
""", (equipment_type,))
|
||||||
@ -696,38 +687,38 @@ def update_equipment(id):
|
|||||||
# 更新装备基本信息
|
# 更新装备基本信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE equipments
|
UPDATE equipments
|
||||||
SET name = %s, manufacturer = %s
|
SET name = ?, manufacturer = ?
|
||||||
WHERE id = %s
|
WHERE id = ?
|
||||||
""", (data['name'], data['manufacturer'], equipment_id))
|
""", (data['name'], data['manufacturer'], equipment_id))
|
||||||
|
|
||||||
# 更新通用参数
|
# 更新通用参数
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE common_params
|
UPDATE common_params
|
||||||
SET length_m = %s, width_m = %s, height_m = %s, weight_kg = %s
|
SET length_m = ?, width_m = ?, height_m = ?, weight_kg = ?
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id))
|
""", (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id))
|
||||||
|
|
||||||
# 根据装备类型更新特有参数
|
# 根据装备类型更新特有参数
|
||||||
if data['type'] == '火箭炮':
|
if data['type'] == '火箭炮':
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE rocket_artillery_params
|
UPDATE rocket_artillery_params
|
||||||
SET firing_angle_horizontal = %s,
|
SET firing_angle_horizontal = ?,
|
||||||
firing_angle_vertical = %s,
|
firing_angle_vertical = ?,
|
||||||
rocket_length_m = %s,
|
rocket_length_m = ?,
|
||||||
rocket_diameter_mm = %s,
|
rocket_diameter_mm = ?,
|
||||||
rocket_weight_kg = %s,
|
rocket_weight_kg = ?,
|
||||||
rate_of_fire = %s,
|
rate_of_fire = ?,
|
||||||
combat_weight_kg = %s,
|
combat_weight_kg = ?,
|
||||||
speed_kmh = %s,
|
speed_kmh = ?,
|
||||||
min_range_km = %s,
|
min_range_km = ?,
|
||||||
max_range_km = %s,
|
max_range_km = ?,
|
||||||
mobility_type = %s,
|
mobility_type = ?,
|
||||||
structure_layout = %s,
|
structure_layout = ?,
|
||||||
engine_model = %s,
|
engine_model = ?,
|
||||||
engine_params = %s,
|
engine_params = ?,
|
||||||
power_hp = %s,
|
power_hp = ?,
|
||||||
travel_range_km = %s
|
travel_range_km = ?
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (
|
""", (
|
||||||
data['firing_angle_horizontal'],
|
data['firing_angle_horizontal'],
|
||||||
data['firing_angle_vertical'],
|
data['firing_angle_vertical'],
|
||||||
@ -750,17 +741,17 @@ def update_equipment(id):
|
|||||||
else:
|
else:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE loitering_munition_params
|
UPDATE loitering_munition_params
|
||||||
SET wingspan_m = %s,
|
SET wingspan_m = ?,
|
||||||
warhead_weight_kg = %s,
|
warhead_weight_kg = ?,
|
||||||
max_speed_ms = %s,
|
max_speed_ms = ?,
|
||||||
cruise_speed_kmh = %s,
|
cruise_speed_kmh = ?,
|
||||||
endurance_min = %s,
|
endurance_min = ?,
|
||||||
max_range_km = %s,
|
max_range_km = ?,
|
||||||
warhead_type = %s,
|
warhead_type = ?,
|
||||||
launch_mode = %s,
|
launch_mode = ?,
|
||||||
power_system = %s,
|
power_system = ?,
|
||||||
guidance_system = %s
|
guidance_system = ?
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (
|
""", (
|
||||||
data['wingspan_m'],
|
data['wingspan_m'],
|
||||||
data['warhead_weight_kg'],
|
data['warhead_weight_kg'],
|
||||||
@ -779,8 +770,8 @@ def update_equipment(id):
|
|||||||
if 'actual_cost' in data:
|
if 'actual_cost' in data:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE cost_data
|
UPDATE cost_data
|
||||||
SET actual_cost = %s
|
SET actual_cost = ?
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (data['actual_cost'], equipment_id))
|
""", (data['actual_cost'], equipment_id))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@ -798,7 +789,7 @@ def get_equipment_details(id):
|
|||||||
logger.info(f"Getting details for equipment ID: {id}")
|
logger.info(f"Getting details for equipment ID: {id}")
|
||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 获取装备基本信息类型
|
# 获取装备基本信息类型
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
@ -806,7 +797,7 @@ def get_equipment_details(id):
|
|||||||
FROM equipments e
|
FROM equipments e
|
||||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
WHERE e.id = %s
|
WHERE e.id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
|
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
@ -822,14 +813,14 @@ def get_equipment_details(id):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM rocket_artillery_params
|
FROM rocket_artillery_params
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
custom_params = cursor.fetchone()
|
custom_params = cursor.fetchone()
|
||||||
else:
|
else:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM loitering_munition_params
|
FROM loitering_munition_params
|
||||||
WHERE equipment_id = %s
|
WHERE equipment_id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
custom_params = cursor.fetchone()
|
custom_params = cursor.fetchone()
|
||||||
|
|
||||||
@ -853,7 +844,7 @@ def get_datasets():
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT d.*,
|
SELECT d.*,
|
||||||
COUNT(de.equipment_id) as equipment_count,
|
COUNT(de.equipment_id) as equipment_count,
|
||||||
@ -872,12 +863,7 @@ def get_datasets():
|
|||||||
else:
|
else:
|
||||||
dataset['equipment_names'] = []
|
dataset['equipment_names'] = []
|
||||||
|
|
||||||
# 格式化时间和数值字段
|
# SQLite 已存储日期为文本格式,无需转换
|
||||||
for dataset in datasets:
|
|
||||||
if dataset['created_at']:
|
|
||||||
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
if dataset['updated_at']:
|
|
||||||
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
|
|
||||||
return jsonify(datasets)
|
return jsonify(datasets)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -891,14 +877,14 @@ def get_dataset(id):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
# 获取数据集基本信息
|
# 获取数据集基本信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT d.*,
|
SELECT d.*,
|
||||||
COUNT(de.equipment_id) as equipment_count
|
COUNT(de.equipment_id) as equipment_count
|
||||||
FROM datasets d
|
FROM datasets d
|
||||||
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
|
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
|
||||||
WHERE d.id = %s
|
WHERE d.id = ?
|
||||||
GROUP BY d.id
|
GROUP BY d.id
|
||||||
""", (id,))
|
""", (id,))
|
||||||
dataset = cursor.fetchone()
|
dataset = cursor.fetchone()
|
||||||
@ -913,7 +899,7 @@ def get_dataset(id):
|
|||||||
FROM equipments e
|
FROM equipments e
|
||||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
equipment = cursor.fetchall()
|
equipment = cursor.fetchall()
|
||||||
|
|
||||||
@ -952,14 +938,14 @@ def create_dataset():
|
|||||||
|
|
||||||
# 1. 验证装备ID是否存在
|
# 1. 验证装备ID是否存在
|
||||||
if 'equipment_ids' in data and data['equipment_ids']:
|
if 'equipment_ids' in data and data['equipment_ids']:
|
||||||
# 直接从 equipment 表查询,不需要 JOIN
|
ids = data['equipment_ids']
|
||||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
placeholders = ','.join(['?'] * len(ids))
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
SELECT DISTINCT id FROM equipments
|
SELECT DISTINCT id FROM equipments
|
||||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
WHERE id IN ({placeholders}) AND type = ?
|
||||||
""", (data['equipment_type'],))
|
""", (*ids, data['equipment_type']))
|
||||||
|
|
||||||
valid_ids = [row[0] for row in cursor.fetchall()]
|
valid_ids = [row['id'] for row in cursor.fetchall()]
|
||||||
logger.info(f"Valid equipment IDs: {valid_ids}")
|
logger.info(f"Valid equipment IDs: {valid_ids}")
|
||||||
|
|
||||||
# 如果有无效的ID,返回错误
|
# 如果有无效的ID,返回错误
|
||||||
@ -971,7 +957,7 @@ def create_dataset():
|
|||||||
# 2. 创建数据集
|
# 2. 创建数据集
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO datasets (name, description, equipment_type, purpose)
|
INSERT INTO datasets (name, description, equipment_type, purpose)
|
||||||
VALUES (%s, %s, %s, %s)
|
VALUES (?, ?, ?, ?)
|
||||||
""", (data['name'], data['description'], data['equipment_type'], data['purpose']))
|
""", (data['name'], data['description'], data['equipment_type'], data['purpose']))
|
||||||
|
|
||||||
dataset_id = cursor.lastrowid
|
dataset_id = cursor.lastrowid
|
||||||
@ -982,7 +968,7 @@ def create_dataset():
|
|||||||
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
|
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
|
||||||
cursor.executemany("""
|
cursor.executemany("""
|
||||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||||
VALUES (%s, %s)
|
VALUES (?, ?)
|
||||||
""", values)
|
""", values)
|
||||||
logger.info(f"Added {len(values)} equipment associations")
|
logger.info(f"Added {len(values)} equipment associations")
|
||||||
|
|
||||||
@ -1004,14 +990,15 @@ def update_dataset(id):
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 1. 验证装备ID是否存在
|
# 1. 验证装备ID是否存在
|
||||||
if 'equipment_ids' in data:
|
if 'equipment_ids' in data and data['equipment_ids']:
|
||||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
ids = data['equipment_ids']
|
||||||
|
placeholders = ','.join(['?'] * len(ids))
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
SELECT id FROM equipments
|
SELECT id FROM equipments
|
||||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
WHERE id IN ({placeholders}) AND type = ?
|
||||||
""", (data['equipment_type'],))
|
""", (*ids, data['equipment_type']))
|
||||||
|
|
||||||
valid_ids = [row[0] for row in cursor.fetchall()]
|
valid_ids = [row['id'] for row in cursor.fetchall()]
|
||||||
logger.info(f"Valid equipment IDs: {valid_ids}")
|
logger.info(f"Valid equipment IDs: {valid_ids}")
|
||||||
|
|
||||||
# 如果有无效的ID,返回错误
|
# 如果有无效的ID,返回错误
|
||||||
@ -1023,21 +1010,21 @@ def update_dataset(id):
|
|||||||
# 2. 更新数据集基本信息
|
# 2. 更新数据集基本信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE datasets
|
UPDATE datasets
|
||||||
SET name = %s, description = %s, equipment_type = %s, purpose = %s
|
SET name = ?, description = ?, equipment_type = ?, purpose = ?
|
||||||
WHERE id = %s
|
WHERE id = ?
|
||||||
""", (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
|
""", (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
|
||||||
|
|
||||||
# 3. 更新装备关联
|
# 3. 更新装备关联
|
||||||
if 'equipment_ids' in data:
|
if 'equipment_ids' in data:
|
||||||
# 先删除旧的关联
|
# 先删除旧的关联
|
||||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
|
||||||
|
|
||||||
# 添加新的关联
|
# 添加新的关联
|
||||||
if valid_ids: # 确保有有效的ID才执行插入
|
if valid_ids: # 确保有有效的ID才执行插入
|
||||||
values = [(id, equipment_id) for equipment_id in valid_ids]
|
values = [(id, equipment_id) for equipment_id in valid_ids]
|
||||||
cursor.executemany("""
|
cursor.executemany("""
|
||||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||||
VALUES (%s, %s)
|
VALUES (?, ?)
|
||||||
""", values)
|
""", values)
|
||||||
logger.info(f"Updated {len(values)} equipment associations")
|
logger.info(f"Updated {len(values)} equipment associations")
|
||||||
|
|
||||||
@ -1058,10 +1045,10 @@ def delete_dataset(id):
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 删除装备关联
|
# 删除装备关联
|
||||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
|
||||||
|
|
||||||
# 删除数据集
|
# 删除数据集
|
||||||
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
|
cursor.execute("DELETE FROM datasets WHERE id = ?", (id,))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return jsonify({'success': True})
|
return jsonify({'success': True})
|
||||||
@ -1076,10 +1063,10 @@ def get_latest_model(equipment_type):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM trained_models
|
SELECT * FROM trained_models
|
||||||
WHERE equipment_type = %s AND is_active = TRUE
|
WHERE equipment_type = ? AND is_active = TRUE
|
||||||
ORDER BY training_date DESC LIMIT 1
|
ORDER BY training_date DESC LIMIT 1
|
||||||
""", (equipment_type,))
|
""", (equipment_type,))
|
||||||
|
|
||||||
@ -1095,7 +1082,7 @@ def get_models():
|
|||||||
"""获取模型列表"""
|
"""获取模型列表"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT * FROM trained_models
|
SELECT * FROM trained_models
|
||||||
ORDER BY training_date DESC
|
ORDER BY training_date DESC
|
||||||
@ -1105,9 +1092,7 @@ def get_models():
|
|||||||
|
|
||||||
# 格式化时间和数值字段
|
# 格式化时间和数值字段
|
||||||
for model in models:
|
for model in models:
|
||||||
# 将数据库中的datetime转换为ISO格式字符串
|
# SQLite 已存储日期为文本格式,无需转换
|
||||||
if model['training_date']:
|
|
||||||
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
if model['r2_score'] is not None:
|
if model['r2_score'] is not None:
|
||||||
model['r2_score'] = float(model['r2_score'])
|
model['r2_score'] = float(model['r2_score'])
|
||||||
if model['mae'] is not None:
|
if model['mae'] is not None:
|
||||||
@ -1130,12 +1115,12 @@ def activate_model(id):
|
|||||||
"""激活指定的模型"""
|
"""激活指定的模型"""
|
||||||
try:
|
try:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True) # 使用字典游标
|
cursor = conn.cursor() # 使用字典游标
|
||||||
|
|
||||||
# 获取模型信息
|
# 获取模型信息
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT equipment_type, model_type FROM trained_models
|
SELECT equipment_type, model_type FROM trained_models
|
||||||
WHERE id = %s
|
WHERE id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
model = cursor.fetchone()
|
model = cursor.fetchone()
|
||||||
|
|
||||||
@ -1146,14 +1131,14 @@ def activate_model(id):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE trained_models
|
UPDATE trained_models
|
||||||
SET is_active = FALSE
|
SET is_active = FALSE
|
||||||
WHERE equipment_type = %s AND model_type = %s
|
WHERE equipment_type = ? AND model_type = ?
|
||||||
""", (model['equipment_type'], model['model_type']))
|
""", (model['equipment_type'], model['model_type']))
|
||||||
|
|
||||||
# 激活指定模型
|
# 激活指定模型
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
UPDATE trained_models
|
UPDATE trained_models
|
||||||
SET is_active = TRUE
|
SET is_active = TRUE
|
||||||
WHERE id = %s
|
WHERE id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@ -1176,21 +1161,21 @@ def delete_model(id):
|
|||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT model_path, scaler_path
|
SELECT model_path, scaler_path
|
||||||
FROM trained_models
|
FROM trained_models
|
||||||
WHERE id = %s
|
WHERE id = ?
|
||||||
""", (id,))
|
""", (id,))
|
||||||
model = cursor.fetchone()
|
model = cursor.fetchone()
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
return jsonify({'error': 'Model not found'}), 404
|
return jsonify({'error': 'Model not found'}), 404
|
||||||
|
|
||||||
# 删除模型件
|
# 删除模型文件
|
||||||
if os.path.exists(model[0]):
|
if os.path.exists(model['model_path']):
|
||||||
os.remove(model[0])
|
os.remove(model['model_path'])
|
||||||
if os.path.exists(model[1]):
|
if os.path.exists(model['scaler_path']):
|
||||||
os.remove(model[1])
|
os.remove(model['scaler_path'])
|
||||||
|
|
||||||
# 删除数据库记录
|
# 删除数据库记录
|
||||||
cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
|
cursor.execute("DELETE FROM trained_models WHERE id = ?", (id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
return jsonify({'success': True})
|
return jsonify({'success': True})
|
||||||
@ -1231,7 +1216,7 @@ def analyze_manufacturers():
|
|||||||
return jsonify({'error': '请选择数据集'}), 400
|
return jsonify({'error': '请选择数据集'}), 400
|
||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 获取数据集中的装备和生产商数据
|
# 获取数据集中的装备和生产商数据
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
@ -1239,13 +1224,22 @@ def analyze_manufacturers():
|
|||||||
FROM manufacturers m
|
FROM manufacturers m
|
||||||
JOIN equipments e ON e.manufacturer_id = m.id
|
JOIN equipments e ON e.manufacturer_id = m.id
|
||||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||||
WHERE de.dataset_id = %s
|
WHERE de.dataset_id = ?
|
||||||
""", (dataset_id,))
|
""", (dataset_id,))
|
||||||
|
|
||||||
manufacturers = cursor.fetchall()
|
manufacturers = cursor.fetchall()
|
||||||
|
|
||||||
if not manufacturers:
|
if not manufacturers:
|
||||||
return jsonify({'error': '数据集中没有生产商数据'}), 404
|
logger.info("No manufacturer data in dataset, returning empty result")
|
||||||
|
return jsonify({
|
||||||
|
'manufacturer_names': [],
|
||||||
|
'manufacturer_tech_levels': [],
|
||||||
|
'manufacturer_scale_levels': [],
|
||||||
|
'manufacturer_supply_chain_levels': [],
|
||||||
|
'manufacturer_composite_scores': [],
|
||||||
|
'region_distribution': [],
|
||||||
|
'manufacturer_scores': []
|
||||||
|
})
|
||||||
|
|
||||||
# 准备分析数据
|
# 准备分析数据
|
||||||
manufacturer_names = []
|
manufacturer_names = []
|
||||||
|
|||||||
@ -1,9 +1,5 @@
|
|||||||
@echo off
|
@echo off
|
||||||
set FLASK_DEBUG=false
|
set FLASK_DEBUG=false
|
||||||
set MYSQL_HOST=localhost
|
|
||||||
set MYSQL_USER=root
|
|
||||||
set MYSQL_PASSWORD=123456
|
|
||||||
set MYSQL_DB=equipment_cost_db
|
|
||||||
echo Starting Cost Prediction System...
|
echo Starting Cost Prediction System...
|
||||||
start /B run.exe
|
start /B run.exe
|
||||||
start http://localhost:5001
|
start http://localhost:5001
|
||||||
Loading…
Reference in New Issue
Block a user