将mysql改成sqlite,减少前端依赖

This commit is contained in:
tian 2026-04-28 23:30:48 +08:00
parent 137451ba7a
commit 48ba547c36
24 changed files with 2216 additions and 10338 deletions

File diff suppressed because one or more lines are too long

108
CLAUDE.md Normal file
View File

@ -0,0 +1,108 @@
# CLAUDE.md
本文件为 Claude Codeclaude.ai/code在此仓库中工作提供指引。
## 常用命令
```bash
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch
pip install -e .
pip install -e ".[dev]" # 含开发工具pytest, black, mypy
pip install -e ".[torch]" # 安装可选 PyTorch神经网络训练需要
# 启动后端Flask0.0.0.0:5001
python run.py
# 启动前端开发服务器另开终端localhost:3000
cd frontend && npm install && npm run serve
# 构建前端生产版本
cd frontend && npm run build
# 运行测试
python -m pytest tests/
python -m pytest tests/test_demo_service.py -q # 单个测试文件
python -m pytest tests/test_demo_routes.py -q
# 代码格式化 / 类型检查
black src/ tests/ # 自动格式化line-length=88
mypy src/ # 类型检查
# 运行独立演示模式(仅需 5 个依赖)
cd demo_standalone && pip install -r requirements.txt && python server.py
```
## 项目概述
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
### 技术栈
- **后端**Python 3.9-3.11, Flask 3.1+
- **数据库**SQLitePython 内置,零外部依赖),首次启动自动建表,无需手动安装配置
- **机器学习**scikit-learn核心、XGBoost、LightGBM
- **可选依赖**PyTorch仅神经网络训练需要约 800MB
- **前端**Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
### 关键文件
| 文件 | 用途 |
|---|---|
| `run.py` | 入口点,启动 Flask 服务器 |
| `src/app.py` | Flask 应用工厂 |
| `src/routes.py` | 所有 API 路由(约 1300 行) |
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModelPyTorch NN |
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
| `src/cost_prediction.py` | CostPredictor预测编排 |
| `src/demo_service.py` | DemoModelService基于 CSV无需数据库 |
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL内置 |
| `config.py` | 运行时配置 |
| `frontend/src/router/index.js` | Vue Router共 8 个路由 |
| `frontend/src/api/index.js` | Axios API 客户端 |
## 架构
### 与旧架构的核心变化
| 项目 | 旧架构MySQL | 新架构SQLite |
|------|----------------|-----------------|
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLitePython 内置,零依赖 |
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3 |
| PyTorch | 硬依赖(顶层 import无则崩溃 | 可选依赖try/except 保护,无 PyTorch 可启动) |
| 前端构建 | Vue CLI + Babel + SCSS | Vite无需 Babel/SCSS |
| Vuex | 存在但完全空 | 已移除 |
| 依赖总数(核心)| 15+ | 5flask, numpy, pandas, scikit-learn, openpyxl |
### 训练数据流
```
用户界面 → POST /api/train → 查询 SQLite → DataPreparation特征提取 + 标准化)
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS
→ 保存最优模型 → 写入 SQLite trained_models 表
```
### 预测数据流
```
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
```
### 机器学习模型
| 模型 | 标识 | 说明 |
|---|---|---|
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
| Random Forest | `rf` | sklearn.ensemble |
| Gradient Boosting | `gbm` | sklearn.ensemble |
| PyTorch NN | `pytorch` | 可选(需安装 torch针对不同装备类型定制网络结构 |
| PLS 回归 | `pls` | 不参与最优模型评选 |
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
## 重要注意事项
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
- **两种装备类型**火箭炮27+ 特征和巡飞弹24+ 特征),各自有独立数据库表和神经网络结构
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
- **SQLite 首次使用**`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
- **编码规范**UTF-8 无 BOMLF 换行符,文件末尾保留换行符
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容

View File

@ -2,11 +2,8 @@ import os
class Config: class Config:
"""配置类""" """配置类"""
# 数据库配置 # 数据库配置(使用 SQLite
MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost') SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
MYSQL_USER = os.getenv('MYSQL_USER', 'root')
MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')
# Flask配置 # Flask配置
FLASK_HOST = '0.0.0.0' FLASK_HOST = '0.0.0.0'

View File

@ -1,5 +0,0 @@
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}

17
frontend/index.html Normal file
View File

@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="/favicon.ico">
<title>装备成本估算系统</title>
</head>
<body>
<noscript>
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
</noscript>
<div id="app"></div>
<script type="module" src="/src/main.js"></script>
</body>
</html>

11130
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -7,33 +7,24 @@
"npm": ">=8" "npm": ">=8"
}, },
"scripts": { "scripts": {
"serve": "vue-cli-service serve", "serve": "vite",
"build": "vue-cli-service build", "build": "vite build",
"lint": "vue-cli-service lint" "lint": "eslint --ext .js,.vue src/"
}, },
"dependencies": { "dependencies": {
"axios": "^1.6.0", "axios": "^1.6.0",
"core-js": "^3.8.3",
"echarts": "^5.4.3", "echarts": "^5.4.3",
"element-plus": "^2.4.2", "element-plus": "^2.4.2",
"vue": "^3.2.13", "vue": "^3.2.13",
"vue-router": "^4.0.3", "vue-router": "^4.0.3"
"vuex": "^4.0.0"
}, },
"devDependencies": { "devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@element-plus/icons-vue": "^2.3.1", "@element-plus/icons-vue": "^2.3.1",
"@vue/cli-plugin-babel": "~5.0.0", "@vitejs/plugin-vue": "^5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-plugin-router": "~5.0.0",
"@vue/cli-plugin-vuex": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"@vue/compiler-sfc": "^3.2.13",
"eslint": "^7.32.0", "eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3", "eslint-plugin-vue": "^8.0.3",
"sass": "^1.32.7", "sass-embedded": "^1.99.0",
"sass-loader": "^12.0.0" "vite": "^5.0.0"
}, },
"eslintConfig": { "eslintConfig": {
"root": true, "root": true,
@ -45,17 +36,12 @@
"eslint:recommended" "eslint:recommended"
], ],
"parserOptions": { "parserOptions": {
"parser": "@babel/eslint-parser" "ecmaVersion": "latest",
"sourceType": "module"
}, },
"rules": { "rules": {
"vue/multi-word-component-names": "off", "vue/multi-word-component-names": "off",
"no-unused-vars": "warn" "no-unused-vars": "warn"
} }
}, }
"browserslist": [
"> 1%",
"last 2 versions",
"not dead",
"not ie 11"
]
} }

View File

@ -1,17 +0,0 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title><%= htmlWebpackPlugin.options.title %></title>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>

View File

@ -1,7 +1,6 @@
import { createApp } from 'vue' import { createApp } from 'vue'
import App from './App.vue' import App from './App.vue'
import router from './router' import router from './router'
import store from './store'
import ElementPlus from 'element-plus' import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css' import 'element-plus/dist/index.css'
import './assets/styles/global.css' import './assets/styles/global.css'
@ -15,7 +14,6 @@ app.use(ElementPlus, {
size: 'default' size: 'default'
}) })
app.use(router) app.use(router)
app.use(store)
// 注册图标 // 注册图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) { for (const [key, component] of Object.entries(ElementPlusIconsVue)) {

View File

@ -1,14 +0,0 @@
import { createStore } from 'vuex'
export default createStore({
state: {
},
getters: {
},
mutations: {
},
actions: {
},
modules: {
}
})

View File

@ -19,7 +19,7 @@ export default defineConfig({
port: 3000, port: 3000,
proxy: { proxy: {
'/api': { '/api': {
target: 'http://localhost:5000', target: 'http://localhost:5001',
changeOrigin: true changeOrigin: true
} }
} }

View File

@ -11,12 +11,6 @@ dependencies = [
"flask>=3.1.0", "flask>=3.1.0",
"flask-cors>=5.0.0", "flask-cors>=5.0.0",
# 数据库
"sqlalchemy>=2.0.36",
"pymysql>=1.1.1",
"cryptography>=43.0.0",
"mysql-connector-python>=8.0.0",
# 数据处理 # 数据处理
"numpy>=1.26.0,<2.0.0", "numpy>=1.26.0,<2.0.0",
"pandas>=2.2.0", "pandas>=2.2.0",
@ -25,7 +19,6 @@ dependencies = [
"scikit-learn>=1.5.2", "scikit-learn>=1.5.2",
"xgboost>=2.1.0", "xgboost>=2.1.0",
"lightgbm>=4.5.0", "lightgbm>=4.5.0",
"torch==2.5.1",
# 工具 # 工具
"openpyxl>=3.1.5", "openpyxl>=3.1.5",
@ -34,6 +27,10 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
# PyTorch 为可选依赖(安装约 800MB仅训练神经网络时需要
torch = [
"torch==2.5.1",
]
dev = [ dev = [
# 测试工具 # 测试工具
"pytest>=7.0", "pytest>=7.0",

View File

@ -1,15 +1,12 @@
flask>=3.1.0 flask>=3.1.0
flask-cors>=5.0.0 flask-cors>=5.0.0
sqlalchemy>=2.0.36
pymysql>=1.1.1
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
mysql-connector-python>=8.0.0 # 添加这行
numpy>=1.26.0,<2.0.0 numpy>=1.26.0,<2.0.0
pandas>=2.2.0 pandas>=2.2.0
xgboost>=2.1.0 xgboost>=2.1.0
lightgbm>=4.5.0 lightgbm>=4.5.0
scikit-learn>=1.5.2 scikit-learn>=1.5.2
openpyxl>=3.1.5 # 用于读取 .xlsx 文件 openpyxl>=3.1.5
python-dotenv>=1.0.0 # 环境变量 python-dotenv>=1.0.0

View File

@ -34,7 +34,6 @@ pyinstaller --clean `
--add-data "src/loitering_munition_data.sql;data" ` --add-data "src/loitering_munition_data.sql;data" `
--add-data "src/rocket_artillery_data.sql;data" ` --add-data "src/rocket_artillery_data.sql;data" `
--add-data "src/manufacturer_data.sql;data" ` --add-data "src/manufacturer_data.sql;data" `
--add-data "src/schema.sql;data" `
--add-data "config.py;." ` --add-data "config.py;." `
--add-data "src;src" ` --add-data "src;src" `
--add-data "frontend;frontend" ` --add-data "frontend;frontend" `
@ -43,18 +42,11 @@ pyinstaller --clean `
--add-data "models;models" ` --add-data "models;models" `
--collect-all "xgboost" ` --collect-all "xgboost" `
--collect-all "lightgbm" ` --collect-all "lightgbm" `
--collect-all "torch" `
--collect-all "sklearn" ` --collect-all "sklearn" `
--collect-all "numpy" ` --collect-all "numpy" `
--collect-all "pandas" ` --collect-all "pandas" `
--collect-all "sqlalchemy" `
--collect-all "pymysql" `
--collect-all "cryptography" `
--collect-all "flask" ` --collect-all "flask" `
--collect-all "flask_cors" ` --collect-all "flask_cors" `
--hidden-import "xgboost.testing" `
--hidden-import "torch.utils.tensorboard" `
--hidden-import "pytest" `
run.py run.py
# Copy necessary files # Copy necessary files

View File

@ -15,12 +15,6 @@ def create_app():
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
# 配置数据库连接
app.config['MYSQL_HOST'] = config.MYSQL_HOST
app.config['MYSQL_USER'] = config.MYSQL_USER
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
app.config['MYSQL_DB'] = config.MYSQL_DB
# 注册路由 # 注册路由
app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered") logger.info("API blueprint registered")

View File

@ -1,5 +1,4 @@
import numpy as np import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from scipy import stats from scipy import stats
@ -9,6 +8,16 @@ from src.database import get_db_connection
from src.feature_analysis import FeatureAnalysis from src.feature_analysis import FeatureAnalysis
from .logger import setup_logger from .logger import setup_logger
# PyTorch 为可选依赖
try:
import torch
import torch.nn as nn
_HAS_TORCH = True
except ImportError:
torch = None
nn = None
_HAS_TORCH = False
logger = setup_logger(__name__) logger = setup_logger(__name__)
class CostPredictor: class CostPredictor:
@ -18,7 +27,11 @@ class CostPredictor:
self.model = None self.model = None
self.feature_analyzer = FeatureAnalysis() self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None self.equipment_type = None
if _HAS_TORCH:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = None
self.load_model() self.load_model()
@ -26,10 +39,9 @@ class CostPredictor:
""" """
加载预训练模型和标准化器 加载预训练模型和标准化器
""" """
if _HAS_TORCH:
try: try:
# 创建默认模型
self._create_default_model() self._create_default_model()
except Exception as e: except Exception as e:
logging.error(f"Error loading model: {str(e)}") logging.error(f"Error loading model: {str(e)}")
self._create_default_model() self._create_default_model()
@ -38,46 +50,36 @@ class CostPredictor:
""" """
创建默认模型并进行初始化训练 创建默认模型并进行初始化训练
""" """
import torch.nn as nn if not _HAS_TORCH:
raise ImportError("PyTorch is not installed.")
class DefaultModel(nn.Module): class DefaultModel(nn.Module):
def __init__(self, input_size): def __init__(self, input_size):
super().__init__() super().__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
nn.Linear(input_size, 64), nn.Linear(input_size, 64), nn.ReLU(),
nn.ReLU(), nn.Linear(64, 32), nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 1) nn.Linear(32, 1)
) )
def forward(self, x): def forward(self, x):
return self.layers(x) return self.layers(x)
# 创建示例数据
example_features = { example_features = {
'length_m': [7.35, 10.2], 'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
'width_m': [2.4, 2.8], 'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70], 'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110], 'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
'firing_angle_vertical': [55, 60], 'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
'rocket_length_m': [2.87, 4.1], 'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
} }
# 转换为 tensor
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t() X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
y = torch.tensor([[800000], [4500000]], dtype=torch.float32) y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
# 训练标准化器
self.scaler_X.fit(X.numpy()) self.scaler_X.fit(X.numpy())
self.scaler_y.fit(y.numpy()) self.scaler_y.fit(y.numpy())
# 创建模型
self.model = DefaultModel(X.shape[1]).to(self.device) self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮' self.equipment_type = '火箭炮'

View File

@ -1,7 +1,5 @@
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
import numpy as np import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import logging import logging
from src.feature_analysis import FeatureAnalysis from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection from src.database import get_db_connection
@ -9,9 +7,22 @@ from .logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
class EquipmentDataset(Dataset): # PyTorch 为可选依赖
try:
import torch
from torch.utils.data import Dataset, DataLoader
_HAS_TORCH = True
except ImportError:
torch = None
Dataset = object
DataLoader = None
_HAS_TORCH = False
class EquipmentDataset(Dataset if _HAS_TORCH else object):
"""装备数据集类""" """装备数据集类"""
def __init__(self, features, targets=None): def __init__(self, features, targets=None):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
self.features = torch.FloatTensor(features) self.features = torch.FloatTensor(features)
self.targets = torch.FloatTensor(targets) if targets is not None else None self.targets = torch.FloatTensor(targets) if targets is not None else None
@ -41,7 +52,7 @@ class DataPreparation:
# 获取数据库连接 # 获取数据库连接
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取所有生产商数据,用于计算特征 # 获取所有生产商数据,用于计算特征
cursor.execute(""" cursor.execute("""

View File

@ -1,37 +1,227 @@
import mysql.connector import sqlite3
from mysql.connector import Error
from contextlib import contextmanager from contextlib import contextmanager
import os import os
from dotenv import load_dotenv
from ..logger import setup_logger from ..logger import setup_logger
# 获取logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
# 加载环境变量 # SQLite 数据库文件路径(相对于项目根目录)
load_dotenv() DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
# 建表 SQL
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS equipments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
type TEXT,
manufacturer TEXT,
manufacturer_id INTEGER,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS common_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
length_m REAL,
width_m REAL,
height_m REAL,
weight_kg REAL,
max_range_km REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
firing_angle_horizontal REAL,
firing_angle_vertical REAL,
rocket_length_m REAL,
rocket_diameter_mm REAL,
rocket_weight_kg REAL,
rate_of_fire REAL,
combat_weight_kg REAL,
speed_kmh REAL,
min_range_km REAL,
max_range_km REAL,
mobility_type TEXT,
structure_layout TEXT,
engine_model TEXT,
engine_params TEXT,
power_hp REAL,
travel_range_km REAL,
fire_density REAL,
range_ratio REAL,
mobility_score INTEGER,
combat_readiness_score INTEGER,
deployment_score INTEGER,
terrain_adaptability_score INTEGER,
rocket_power_ratio REAL,
platform_efficiency REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS loitering_munition_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
wingspan_m REAL,
warhead_weight_kg REAL,
max_speed_ms REAL,
cruise_speed_kmh REAL,
endurance_min REAL,
flight_time_min REAL,
max_range_km REAL,
max_payload_kg REAL,
ceiling_altitude_m REAL,
combat_radius_km REAL,
folded_length_mm REAL,
folded_width_mm REAL,
folded_height_mm REAL,
warhead_type TEXT,
launch_mode TEXT,
power_system TEXT,
guidance_system TEXT,
engine_power_kw REAL,
engine_thrust_n REAL,
datalink_range_km REAL,
guidance_accuracy_m REAL,
min_altitude_m REAL,
max_altitude_m REAL,
length_width_ratio REAL,
weight_range_ratio REAL,
speed_weight_ratio REAL,
guidance_system_score INTEGER,
warhead_power_score INTEGER,
warhead_type_code INTEGER,
launch_mode_code INTEGER,
power_system_code INTEGER,
guidance_system_code INTEGER,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS feature_encoding (
id INTEGER PRIMARY KEY AUTOINCREMENT,
feature_type TEXT,
feature_value TEXT,
code INTEGER,
UNIQUE(feature_type, feature_value)
);
CREATE TABLE IF NOT EXISTS cost_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
actual_cost REAL,
predicted_cost REAL,
prediction_date TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS custom_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
param_name TEXT,
param_value TEXT,
param_unit TEXT,
description TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS datasets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT,
equipment_type TEXT NOT NULL,
purpose TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS dataset_equipments (
dataset_id INTEGER NOT NULL,
equipment_id INTEGER NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS trained_models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
model_type TEXT NOT NULL,
equipment_type TEXT NOT NULL,
model_path TEXT NOT NULL,
scaler_path TEXT NOT NULL,
r2_score REAL,
mae REAL,
rmse REAL,
feature_importance TEXT,
training_data_size INTEGER,
training_date TEXT DEFAULT (datetime('now','localtime')),
is_active INTEGER DEFAULT 0,
created_by TEXT
);
CREATE TABLE IF NOT EXISTS manufacturers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
country TEXT NOT NULL,
tech_level INTEGER NOT NULL,
scale_level INTEGER NOT NULL,
supply_chain_level INTEGER NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime')),
UNIQUE(name)
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
"""
def init_db():
"""初始化数据库:确保数据库文件和表存在"""
os.makedirs(DB_DIR, exist_ok=True)
conn = sqlite3.connect(DB_PATH)
conn.executescript(SCHEMA_SQL)
conn.commit()
conn.close()
logger.info(f"Database initialized at {DB_PATH}")
@contextmanager @contextmanager
def get_db_connection(): def get_db_connection():
""" """
数据库连接上下文管理器 数据库连接上下文管理器
返回的 connection 已设置 dict row_factory
以便按列名访问
""" """
connection = None conn = None
try: try:
connection = mysql.connector.connect( # 确保数据库已初始化
host=os.getenv('MYSQL_HOST', 'localhost'), if not os.path.exists(DB_PATH):
user=os.getenv('MYSQL_USER', 'root'), logger.info("Database file not found, initializing...")
password=os.getenv('MYSQL_PASSWORD', '123456'), init_db()
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
)
logger.info("Database connection established")
yield connection
except Error as e: conn = sqlite3.connect(DB_PATH)
logger.error(f"Error connecting to MySQL: {str(e)}") conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
conn.execute("PRAGMA foreign_keys = ON")
logger.debug("Database connection established")
yield conn
except sqlite3.Error as e:
logger.error(f"Database error: {str(e)}")
raise raise
finally: finally:
if connection and connection.is_connected(): if conn:
connection.close() conn.close()
logger.info("Database connection closed") logger.debug("Database connection closed")

View File

@ -267,11 +267,15 @@ class FeatureAnalysis:
def calculate_manufacturer_features(self, manufacturer_data): def calculate_manufacturer_features(self, manufacturer_data):
"""计算生产商相关的特征""" """计算生产商相关的特征"""
try: try:
# 确保所有必要的字段都存在,使用默认值处理缺失数据 # 处理 None 值(数据库 NULL使用默认值
tech_level = float(manufacturer_data.get('tech_level', 0)) raw_tech = manufacturer_data.get('tech_level')
scale_level = float(manufacturer_data.get('scale_level', 0)) raw_scale = manufacturer_data.get('scale_level')
supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0)) raw_supply = manufacturer_data.get('supply_chain_level')
country = manufacturer_data.get('country', '未知')
tech_level = float(raw_tech) if raw_tech is not None else 0
scale_level = float(raw_scale) if raw_scale is not None else 0
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
country = manufacturer_data.get('country', '未知') or '未知'
# 计算综合得分 # 计算综合得分
composite_score = ( composite_score = (

View File

@ -27,7 +27,7 @@ def import_training_data(excel_file):
# 检查是否已存在相同名称的装备 # 检查是否已存在相同名称的装备
cursor.execute(""" cursor.execute("""
SELECT id FROM equipments SELECT id FROM equipments
WHERE name = %s AND type = '火箭炮' WHERE name = ? AND type = '火箭炮'
""", (row['名称'],)) """, (row['名称'],))
existing_equipment = cursor.fetchone() existing_equipment = cursor.fetchone()
@ -38,7 +38,7 @@ def import_training_data(excel_file):
# 插入基本信息 # 插入基本信息
cursor.execute(""" cursor.execute("""
INSERT INTO equipments (name, type, manufacturer) INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s) VALUES (?, ?, ?)
""", (row['名称'], '火箭炮', row['制造商'])) """, (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid equipment_id = cursor.lastrowid
@ -47,7 +47,7 @@ def import_training_data(excel_file):
cursor.execute(""" cursor.execute("""
INSERT INTO common_params INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?, ?)
""", ( """, (
equipment_id, equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None, row['总长_m'] if pd.notna(row['总长_m']) else None,
@ -65,7 +65,7 @@ def import_training_data(excel_file):
combat_weight_kg, speed_kmh, min_range_km, mobility_type, combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp, structure_layout, engine_model, engine_params, power_hp,
travel_range_km) travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """, (
equipment_id, equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None, row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
@ -89,7 +89,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']): if pd.notna(row['成本_元']):
cursor.execute(""" cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost) INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s) VALUES (?, ?)
""", (equipment_id, row['成本_元'])) """, (equipment_id, row['成本_元']))
logger.info("火箭炮数据导入完成") logger.info("火箭炮数据导入完成")
@ -105,8 +105,8 @@ def import_training_data(excel_file):
equipment_names.add(row['名称']) equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备 # 检查是否已存在相同名称的装备
cursor.execute(""" cursor.execute("""
SELECT id FROM equipment SELECT id FROM equipments
WHERE name = %s AND type = '巡飞弹' WHERE name = ? AND type = '巡飞弹'
""", (row['名称'],)) """, (row['名称'],))
existing_equipment = cursor.fetchone() existing_equipment = cursor.fetchone()
@ -117,7 +117,7 @@ def import_training_data(excel_file):
# 插入基本信息 # 插入基本信息
cursor.execute(""" cursor.execute("""
INSERT INTO equipments (name, type, manufacturer) INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s) VALUES (?, ?, ?)
""", ( """, (
row['名称'], row['名称'],
'巡飞弹', '巡飞弹',
@ -130,7 +130,7 @@ def import_training_data(excel_file):
cursor.execute(""" cursor.execute("""
INSERT INTO common_params INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?, ?)
""", ( """, (
equipment_id, equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None, float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
@ -147,7 +147,7 @@ def import_training_data(excel_file):
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode, cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm, folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system) power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """, (
equipment_id, equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None, float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
@ -168,7 +168,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']): if pd.notna(row['成本_元']):
cursor.execute(""" cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost) INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s) VALUES (?, ?)
""", (equipment_id, float(row['成本_元']))) """, (equipment_id, float(row['成本_元'])))
logger.info("巡飞弹数据导入完成") logger.info("巡飞弹数据导入完成")
@ -190,9 +190,9 @@ def import_training_data(excel_file):
# 获取装备ID - 使用新的游标 # 获取装备ID - 使用新的游标
logger.debug(f"查询装备ID: {equipment_name}") logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor: id_cursor = conn.cursor()
id_cursor.execute(""" id_cursor.execute("""
SELECT id FROM equipments WHERE name = %s SELECT id FROM equipments WHERE name = ?
""", (equipment_name,)) """, (equipment_name,))
result = id_cursor.fetchone() result = id_cursor.fetchone()
@ -200,15 +200,15 @@ def import_training_data(excel_file):
logger.warning(f"未找到装备: {equipment_name}") logger.warning(f"未找到装备: {equipment_name}")
continue continue
equipment_id = result[0] equipment_id = result['id']
logger.debug(f"找到装备ID: {equipment_id}") logger.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标 # 检查参数是否存在 - 使用新的游标
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'") logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor: check_cursor = conn.cursor()
check_cursor.execute(""" check_cursor.execute("""
SELECT id FROM custom_params SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s WHERE equipment_id = ? AND param_name = ?
""", (equipment_id, param_name)) """, (equipment_id, param_name))
exists = check_cursor.fetchone() exists = check_cursor.fetchone()
@ -222,11 +222,11 @@ def import_training_data(excel_file):
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'") logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor: insert_cursor = conn.cursor()
insert_cursor.execute(""" insert_cursor.execute("""
INSERT INTO custom_params INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description) (equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?)
""", ( """, (
equipment_id, equipment_id,
param_name, param_name,

237
src/import_sql_data.py Normal file
View File

@ -0,0 +1,237 @@
"""
导入 SQL 数据文件到 SQLite 数据库
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
- rocket_artillery_data.sql (96 条火箭炮数据)
- loitering_munition_data.sql (100 条巡飞弹数据)
- manufacturer_data.sql (生产商数据)
"""
import re
import os
import sys
from src.database.db_connection import get_db_connection, DB_PATH
from src.logger import setup_logger
logger = setup_logger(__name__)
def parse_mysql_insert(sql_text):
"""
解析 MySQL INSERT 语句返回 (table_name, columns, values_list)
columns 是列名列表, values_list 是每行的值列表
"""
# 去掉块注释 /* ... */
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
# 去掉行注释 --
sql_text = re.sub(r'--.*', '', sql_text)
results = []
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
pattern = re.compile(
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
re.DOTALL | re.IGNORECASE
)
for match in pattern.finditer(sql_text):
table_name = match.group(1).strip()
columns_str = match.group(2).strip()
values_str = match.group(3).strip()
# 提取列名(去掉注释部分)
columns = []
for col in columns_str.split(','):
col = col.strip()
# 去掉行内注释
col = re.sub(r'\s*--.*', '', col).strip()
columns.append(col)
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
values_list = []
current_tuple = []
depth = 0
current_val = ''
in_string = False
string_char = None
for ch in values_str:
if in_string:
if ch == string_char:
in_string = False
current_val += ch
continue
if ch in ("'", '"'):
in_string = True
string_char = ch
current_val += ch
continue
if ch == '(':
if depth > 0:
current_val += ch
depth += 1
if depth == 1:
current_val = ''
continue
if ch == ')':
depth -= 1
if depth == 0:
current_tuple.append(current_val.strip())
values_list.append(current_tuple)
current_tuple = []
else:
current_val += ch
continue
if ch == ',' and depth == 1:
current_tuple.append(current_val.strip())
current_val = ''
continue
if depth >= 1:
current_val += ch
results.append((table_name, columns, values_list))
return results
def convert_value(val):
"""将 MySQL 值字符串转为 Python 对象"""
val = val.strip()
if val.upper() == 'NULL' or val == '':
return None
if val.upper() == 'TRUE':
return 1
if val.upper() == 'FALSE':
return 0
# 字符串
if (val.startswith("'") and val.endswith("'")) or \
(val.startswith('"') and val.endswith('"')):
return val[1:-1]
# 数字
try:
if '.' in val:
return float(val)
return int(val)
except ValueError:
return val
def import_sql_file(filepath):
"""导入单个 SQL 文件"""
logger.info(f"Reading {filepath}...")
with open(filepath, 'r', encoding='utf-8') as f:
sql_text = f.read()
parsed = parse_mysql_insert(sql_text)
total_rows = 0
with get_db_connection() as conn:
cursor = conn.cursor()
for table_name, columns, values_list in parsed:
if not values_list:
continue
# 构建参数化 INSERT
placeholders = ','.join(['?'] * len(columns))
col_names = ','.join(columns)
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
row_count = 0
for values in values_list:
if len(values) != len(columns):
logger.warning(f"Column mismatch in {table_name}: "
f"expected {len(columns)} values, got {len(values)}: {values}")
continue
converted = [convert_value(v) for v in values]
try:
cursor.execute(sql, converted)
if cursor.rowcount > 0:
row_count += 1
except Exception as e:
logger.warning(f"Error inserting into {table_name}: {e}")
logger.warning(f" Values: {converted}")
if row_count > 0:
logger.info(f" {table_name}: inserted {row_count} rows")
total_rows += row_count
conn.commit()
return total_rows
def run_manufacturer_update():
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE equipments
SET manufacturer_id = (
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
)
WHERE manufacturer_id IS NULL
AND manufacturer IS NOT NULL
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
""")
conn.commit()
updated = cursor.rowcount
if updated > 0:
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
else:
logger.info("No manufacturer_id updates needed")
def import_all():
"""导入所有 SQL 数据文件"""
base_dir = os.path.dirname(os.path.dirname(__file__))
# 清除现有数据库,重新开始
if os.path.exists(DB_PATH):
os.remove(DB_PATH)
logger.info(f"Removed existing database: {DB_PATH}")
files = [
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
]
total = 0
for filepath, label in files:
if not os.path.exists(filepath):
logger.warning(f"File not found: {filepath}")
continue
logger.info(f"正在导入 {label}...")
rows = import_sql_file(filepath)
logger.info(f" {label}: 共导入 {rows}")
total += rows
# 更新 manufacturer_id
run_manufacturer_update()
# 统计结果
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
eq_count = cursor.fetchone()['cnt']
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
mf_count = cursor.fetchone()['cnt']
logger.info("=" * 50)
logger.info(f"导入完成!")
logger.info(f" 装备总计: {eq_count}")
for t, c in by_type.items():
logger.info(f" {t}: {c}")
logger.info(f" 生产商: {mf_count}")
logger.info("=" * 50)
return eq_count
if __name__ == '__main__':
import_all()

View File

@ -1,7 +1,4 @@
import numpy as np import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
@ -11,40 +8,43 @@ from datetime import datetime
import json import json
from src.feature_analysis import FeatureAnalysis from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection from src.database import get_db_connection
from src.data_preparation import DataPreparation, EquipmentDataset
from .logger import setup_logger from .logger import setup_logger
import math import math
# PyTorch 为可选依赖
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.data_preparation import EquipmentDataset
_HAS_TORCH = True
except ImportError:
torch = None
nn = None
DataLoader = None
EquipmentDataset = None
_HAS_TORCH = False
logger = setup_logger(__name__) logger = setup_logger(__name__)
class CostPredictionModel(nn.Module): # 条件基类:有 PyTorch 时继承 nn.Module否则继承 object
_PYTORCH_BASE = nn.Module if _HAS_TORCH else object
class CostPredictionModel(_PYTORCH_BASE):
def __init__(self, input_size, equipment_type): def __init__(self, input_size, equipment_type):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
super().__init__() super().__init__()
self.equipment_type = equipment_type self.equipment_type = equipment_type
if equipment_type == '火箭炮': if equipment_type == '火箭炮':
# 火箭炮使用更简单和稳定的网络结构
self.net = nn.Sequential( self.net = nn.Sequential(
# 第一层:特征映射 nn.Linear(input_size, 32), nn.ReLU(), nn.BatchNorm1d(32),
nn.Linear(input_size, 32), nn.Linear(32, 16), nn.ReLU(), nn.BatchNorm1d(16),
nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.BatchNorm1d(8),
nn.BatchNorm1d(32),
# 第二层:特征提取
nn.Linear(32, 16),
nn.ReLU(),
nn.BatchNorm1d(16),
# 第三层:特征整合
nn.Linear(16, 8),
nn.ReLU(),
nn.BatchNorm1d(8),
# 输出层
nn.Linear(8, 1) nn.Linear(8, 1)
) )
# 使用正交初始化
def init_weights(m): def init_weights(m):
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=0.5) torch.nn.init.orthogonal_(m.weight, gain=0.5)
@ -54,45 +54,19 @@ class CostPredictionModel(nn.Module):
torch.nn.init.constant_(m.bias, 0.0) torch.nn.init.constant_(m.bias, 0.0)
self.net.apply(init_weights) self.net.apply(init_weights)
else:
else: # 巡飞弹保持原有结构
# 生产商特征网络 - 更简单的结构
self.manufacturer_net = nn.Sequential( self.manufacturer_net = nn.Sequential(
nn.Linear(5, 4), nn.Linear(5, 4), nn.ReLU(), nn.BatchNorm1d(4), nn.Dropout(0.2)
nn.ReLU(),
nn.BatchNorm1d(4),
nn.Dropout(0.2)
) )
# 巡飞弹特征网络 - 较深的结构
self.equipment_net = nn.Sequential( self.equipment_net = nn.Sequential(
nn.Linear(input_size - 5, 64), nn.Linear(input_size - 5, 64), nn.LeakyReLU(0.1), nn.BatchNorm1d(64), nn.Dropout(0.2),
nn.LeakyReLU(0.1), nn.Linear(64, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
nn.BatchNorm1d(64), nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2)
nn.Dropout(0.2),
nn.Linear(64, 32),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(32),
nn.Dropout(0.2),
nn.Linear(32, 16),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(16),
nn.Dropout(0.2)
) )
# 合并网络 - 较复杂的结构
self.combined_net = nn.Sequential( self.combined_net = nn.Sequential(
nn.Linear(20, 32), # 4 + 16 = 20 nn.Linear(20, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
nn.LeakyReLU(0.1), nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2),
nn.BatchNorm1d(32), nn.Linear(16, 8), nn.LeakyReLU(0.1), nn.BatchNorm1d(8),
nn.Dropout(0.2),
nn.Linear(32, 16),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(16),
nn.Dropout(0.2),
nn.Linear(16, 8),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(8),
nn.Linear(8, 1) nn.Linear(8, 1)
) )
@ -100,20 +74,17 @@ class CostPredictionModel(nn.Module):
if self.equipment_type == '火箭炮': if self.equipment_type == '火箭炮':
return self.net(x) return self.net(x)
else: else:
# 分离特征
manufacturer_features = x[:, -5:] manufacturer_features = x[:, -5:]
equipment_features = x[:, :-5] equipment_features = x[:, :-5]
# 特征处理
manu_out = self.manufacturer_net(manufacturer_features) manu_out = self.manufacturer_net(manufacturer_features)
equip_out = self.equipment_net(equipment_features) equip_out = self.equipment_net(equipment_features)
# 特征融合
combined = torch.cat([equip_out, manu_out], dim=1) combined = torch.cat([equip_out, manu_out], dim=1)
return self.combined_net(combined) return self.combined_net(combined)
class ModelTrainer: class ModelTrainer:
def __init__(self): def __init__(self):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None self.model = None
self.feature_scaler = None self.feature_scaler = None
@ -307,7 +278,7 @@ class ModelTrainer:
cursor.execute(""" cursor.execute("""
UPDATE trained_models UPDATE trained_models
SET is_active = FALSE SET is_active = FALSE
WHERE equipment_type = %s AND model_type != %s WHERE equipment_type = ? AND model_type != ?
""", (equipment_type, 'pls')) """, (equipment_type, 'pls'))
# 保存新模型记录 # 保存新模型记录
@ -316,7 +287,7 @@ class ModelTrainer:
model_name, model_type, equipment_type, model_path, model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by, scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s) ) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
""", ( """, (
f"{equipment_type}_{timestamp}", f"{equipment_type}_{timestamp}",
'pytorch', 'pytorch',
@ -343,10 +314,10 @@ class ModelTrainer:
try: try:
# 从数据库获取最新的激活模型 # 从数据库获取最新的激活模型
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT * FROM trained_models SELECT * FROM trained_models
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE WHERE equipment_type = ? AND model_type = ? AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1 ORDER BY training_date DESC LIMIT 1
""", (equipment_type, model_type)) """, (equipment_type, model_type))
model_record = cursor.fetchone() model_record = cursor.fetchone()
@ -707,13 +678,13 @@ class ModelTrainer:
cursor.execute(""" cursor.execute("""
UPDATE trained_models UPDATE trained_models
SET is_active = FALSE SET is_active = FALSE
WHERE equipment_type = %s AND model_type != %s WHERE equipment_type = ? AND model_type != ?
""", (equipment_type, 'pls')) """, (equipment_type, 'pls'))
else: else:
cursor.execute(""" cursor.execute("""
UPDATE trained_models UPDATE trained_models
SET is_active = FALSE SET is_active = FALSE
WHERE equipment_type = %s AND model_type = %s WHERE equipment_type = ? AND model_type = ?
""", (equipment_type, 'pls')) """, (equipment_type, 'pls'))
# 保存新模型记录 # 保存新模型记录
@ -722,7 +693,7 @@ class ModelTrainer:
model_name, model_type, equipment_type, model_path, model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by, scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s) ) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
""", ( """, (
f"{equipment_type}_{model_type}_{timestamp}", f"{equipment_type}_{model_type}_{timestamp}",
model_type, model_type,

View File

@ -4,16 +4,23 @@ from .feature_analysis import FeatureAnalysis
import pandas as pd import pandas as pd
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
import mysql.connector
from sklearn.metrics import mean_absolute_error from sklearn.metrics import mean_absolute_error
from .create_template import create_excel_template from .create_template import create_excel_template
import json import json
import os import os
from .data_preparation import DataPreparation from .data_preparation import DataPreparation
from .model_trainer import ModelTrainer from .model_trainer import ModelTrainer
from .database import get_db_connection
from .demo_service import DemoModelService from .demo_service import DemoModelService
from .logger import setup_logger from .logger import setup_logger
# PyTorch 为可选导入
try:
import torch import torch
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
# 创建蓝图 # 创建蓝图
api_bp = Blueprint('api', __name__) api_bp = Blueprint('api', __name__)
@ -71,7 +78,7 @@ def demo_algorithms():
@api_bp.route('/demo/dataset', methods=['GET']) @api_bp.route('/demo/dataset', methods=['GET'])
def demo_dataset(): def demo_dataset():
"""Return the local demo dataset summary without using MySQL.""" """Return the local demo dataset summary without using a database."""
try: try:
service = DemoModelService() service = DemoModelService()
return jsonify(service.get_dataset_summary()) return jsonify(service.get_dataset_summary())
@ -101,11 +108,11 @@ def predict():
# 获取最新的激活模型非PLS模型 # 获取最新的激活模型非PLS模型
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT * FROM trained_models SELECT * FROM trained_models
WHERE equipment_type = %s WHERE equipment_type = ?
AND model_type != 'pls' # 明确排除PLS模型 AND model_type != 'pls'
AND is_active = TRUE AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1 ORDER BY training_date DESC LIMIT 1
""", (equipment_type,)) """, (equipment_type,))
@ -147,14 +154,14 @@ def analyze_features():
return jsonify({'error': '请选择数据集'}), 400 return jsonify({'error': '请选择数据集'}), 400
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 首先获取数据集的装备类型 # 首先获取数据集的装备类型
cursor.execute(""" cursor.execute("""
SELECT DISTINCT e.type SELECT DISTINCT e.type
FROM equipments e FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id JOIN dataset_equipments de ON e.id = de.equipment_id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
LIMIT 1 LIMIT 1
""", (dataset_id,)) """, (dataset_id,))
@ -181,7 +188,7 @@ def analyze_features():
LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (dataset_id,)) """, (dataset_id,))
else: else:
@ -205,7 +212,7 @@ def analyze_features():
LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (dataset_id,)) """, (dataset_id,))
@ -247,27 +254,27 @@ def analyze_features():
# 添加装备特有的分析数据 # 添加装备特有的分析数据
if equipment_data[0]['type'] == '火箭炮': if equipment_data[0]['type'] == '火箭炮':
rocket_data = { rocket_data = {
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data], 'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data], 'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data], 'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data], 'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data], 'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data] 'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data]
} }
analysis_result.update(rocket_data) analysis_result.update(rocket_data)
else: else:
missile_data = { missile_data = {
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data], 'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data], 'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data], 'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data], 'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data], 'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data], 'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item['guidance_accuracy_m'] is not None else 0 for item in equipment_data],
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data], 'datalink_range_km': [float(item['datalink_range_km']) if item['datalink_range_km'] is not None else 0 for item in equipment_data],
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data], 'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data],
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data], 'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data], 'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data] 'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data]
} }
analysis_result.update(missile_data) analysis_result.update(missile_data)
@ -295,7 +302,7 @@ def train_model():
# 获取训练数据 # 获取训练数据
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取训练集数据(包含生产商信息) # 获取训练集数据(包含生产商信息)
if equipment_type == '火箭炮': if equipment_type == '火箭炮':
@ -309,7 +316,7 @@ def train_model():
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,)) """, (train_dataset_id,))
else: else:
@ -323,7 +330,7 @@ def train_model():
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,)) """, (train_dataset_id,))
@ -342,7 +349,7 @@ def train_model():
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,)) """, (validation_dataset_id,))
else: else:
@ -356,7 +363,7 @@ def train_model():
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,)) """, (validation_dataset_id,))
validation_data = cursor.fetchall() validation_data = cursor.fetchall()
@ -476,7 +483,7 @@ def get_equipment_data():
"""获取装备数据列表""" """获取装备数据列表"""
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取所有装备数据(使用equipment_id替代id) # 获取所有装备数据(使用equipment_id替代id)
cursor.execute(""" cursor.execute("""
@ -485,43 +492,39 @@ def get_equipment_data():
cd.actual_cost, cd.predicted_cost, cd.actual_cost, cd.predicted_cost,
CASE CASE
WHEN e.type = '火箭炮' THEN ( WHEN e.type = '火箭炮' THEN (
SELECT CONCAT( SELECT firing_angle_horizontal || ',' ||
firing_angle_horizontal, ',', firing_angle_vertical || ',' ||
firing_angle_vertical, ',', rocket_length_m || ',' ||
rocket_length_m, ',', rocket_diameter_mm || ',' ||
rocket_diameter_mm, ',', rocket_weight_kg || ',' ||
rocket_weight_kg, ',', rate_of_fire || ',' ||
rate_of_fire, ',', combat_weight_kg || ',' ||
combat_weight_kg, ',', speed_kmh || ',' ||
speed_kmh, ',', min_range_km || ',' ||
min_range_km, ',', max_range_km || ',' ||
max_range_km, ',', mobility_type || ',' ||
mobility_type, ',', structure_layout || ',' ||
structure_layout, ',', engine_model || ',' ||
engine_model, ',', engine_params || ',' ||
engine_params, ',', power_hp || ',' ||
power_hp, ',',
travel_range_km travel_range_km
)
FROM rocket_artillery_params FROM rocket_artillery_params
WHERE equipment_id = e.id WHERE equipment_id = e.id
) )
WHEN e.type = '巡飞弹' THEN ( WHEN e.type = '巡飞弹' THEN (
SELECT CONCAT( SELECT wingspan_m || ',' ||
wingspan_m, ',', warhead_weight_kg || ',' ||
warhead_weight_kg, ',', max_speed_ms || ',' ||
max_speed_ms, ',', cruise_speed_kmh || ',' ||
cruise_speed_kmh, ',', endurance_min || ',' ||
endurance_min, ',', max_range_km || ',' ||
max_range_km, ',', max_payload_kg || ',' ||
max_payload_kg, ',', ceiling_altitude_m || ',' ||
ceiling_altitude_m, ',', combat_radius_km || ',' ||
combat_radius_km, ',', warhead_type || ',' ||
warhead_type, ',', launch_mode || ',' ||
launch_mode, ',', power_system || ',' ||
power_system, ',',
guidance_system guidance_system
)
FROM loitering_munition_params FROM loitering_munition_params
WHERE equipment_id = e.id WHERE equipment_id = e.id
) )
@ -548,19 +551,17 @@ def delete_equipment(id):
删除装备数据 删除装备数据
""" """
try: try:
db = get_db_connection() with get_db_connection() as conn:
cursor = db.cursor() cursor = conn.cursor()
# 删除相关数据 # 删除相关数据
cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM cost_data WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM common_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,)) cursor.execute("DELETE FROM equipments WHERE id = ?", (id,))
db.commit() conn.commit()
cursor.close()
db.close()
return jsonify({'status': 'success'}) return jsonify({'status': 'success'})
@ -594,16 +595,6 @@ def download_template():
logger.error(f"Error creating template: {str(e)}") logger.error(f"Error creating template: {str(e)}")
return jsonify({'error': str(e)}), 500 return jsonify({'error': str(e)}), 500
def get_db_connection():
"""
获取数据库连接
"""
return mysql.connector.connect(
host="localhost",
user="root",
password="123456",
database="equipment_cost_db"
)
@api_bp.route('/pls/predict', methods=['POST']) @api_bp.route('/pls/predict', methods=['POST'])
def pls_predict(): def pls_predict():
@ -614,11 +605,11 @@ def pls_predict():
# 获取最新的PLS模型 # 获取最新的PLS模型
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT * FROM trained_models SELECT * FROM trained_models
WHERE equipment_type = %s WHERE equipment_type = ?
AND model_type = 'pls' # 只选择PLS模型 AND model_type = 'pls'
AND is_active = TRUE AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1 ORDER BY training_date DESC LIMIT 1
""", (equipment_type,)) """, (equipment_type,))
@ -696,38 +687,38 @@ def update_equipment(id):
# 更新装备基本信息 # 更新装备基本信息
cursor.execute(""" cursor.execute("""
UPDATE equipments UPDATE equipments
SET name = %s, manufacturer = %s SET name = ?, manufacturer = ?
WHERE id = %s WHERE id = ?
""", (data['name'], data['manufacturer'], equipment_id)) """, (data['name'], data['manufacturer'], equipment_id))
# 更新通用参数 # 更新通用参数
cursor.execute(""" cursor.execute("""
UPDATE common_params UPDATE common_params
SET length_m = %s, width_m = %s, height_m = %s, weight_kg = %s SET length_m = ?, width_m = ?, height_m = ?, weight_kg = ?
WHERE equipment_id = %s WHERE equipment_id = ?
""", (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id)) """, (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id))
# 根据装备类型更新特有参数 # 根据装备类型更新特有参数
if data['type'] == '火箭炮': if data['type'] == '火箭炮':
cursor.execute(""" cursor.execute("""
UPDATE rocket_artillery_params UPDATE rocket_artillery_params
SET firing_angle_horizontal = %s, SET firing_angle_horizontal = ?,
firing_angle_vertical = %s, firing_angle_vertical = ?,
rocket_length_m = %s, rocket_length_m = ?,
rocket_diameter_mm = %s, rocket_diameter_mm = ?,
rocket_weight_kg = %s, rocket_weight_kg = ?,
rate_of_fire = %s, rate_of_fire = ?,
combat_weight_kg = %s, combat_weight_kg = ?,
speed_kmh = %s, speed_kmh = ?,
min_range_km = %s, min_range_km = ?,
max_range_km = %s, max_range_km = ?,
mobility_type = %s, mobility_type = ?,
structure_layout = %s, structure_layout = ?,
engine_model = %s, engine_model = ?,
engine_params = %s, engine_params = ?,
power_hp = %s, power_hp = ?,
travel_range_km = %s travel_range_km = ?
WHERE equipment_id = %s WHERE equipment_id = ?
""", ( """, (
data['firing_angle_horizontal'], data['firing_angle_horizontal'],
data['firing_angle_vertical'], data['firing_angle_vertical'],
@ -750,17 +741,17 @@ def update_equipment(id):
else: else:
cursor.execute(""" cursor.execute("""
UPDATE loitering_munition_params UPDATE loitering_munition_params
SET wingspan_m = %s, SET wingspan_m = ?,
warhead_weight_kg = %s, warhead_weight_kg = ?,
max_speed_ms = %s, max_speed_ms = ?,
cruise_speed_kmh = %s, cruise_speed_kmh = ?,
endurance_min = %s, endurance_min = ?,
max_range_km = %s, max_range_km = ?,
warhead_type = %s, warhead_type = ?,
launch_mode = %s, launch_mode = ?,
power_system = %s, power_system = ?,
guidance_system = %s guidance_system = ?
WHERE equipment_id = %s WHERE equipment_id = ?
""", ( """, (
data['wingspan_m'], data['wingspan_m'],
data['warhead_weight_kg'], data['warhead_weight_kg'],
@ -779,8 +770,8 @@ def update_equipment(id):
if 'actual_cost' in data: if 'actual_cost' in data:
cursor.execute(""" cursor.execute("""
UPDATE cost_data UPDATE cost_data
SET actual_cost = %s SET actual_cost = ?
WHERE equipment_id = %s WHERE equipment_id = ?
""", (data['actual_cost'], equipment_id)) """, (data['actual_cost'], equipment_id))
conn.commit() conn.commit()
@ -798,7 +789,7 @@ def get_equipment_details(id):
logger.info(f"Getting details for equipment ID: {id}") logger.info(f"Getting details for equipment ID: {id}")
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取装备基本信息类型 # 获取装备基本信息类型
cursor.execute(""" cursor.execute("""
@ -806,7 +797,7 @@ def get_equipment_details(id):
FROM equipments e FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.id = %s WHERE e.id = ?
""", (id,)) """, (id,))
result = cursor.fetchone() result = cursor.fetchone()
@ -822,14 +813,14 @@ def get_equipment_details(id):
cursor.execute(""" cursor.execute("""
SELECT * SELECT *
FROM rocket_artillery_params FROM rocket_artillery_params
WHERE equipment_id = %s WHERE equipment_id = ?
""", (id,)) """, (id,))
custom_params = cursor.fetchone() custom_params = cursor.fetchone()
else: else:
cursor.execute(""" cursor.execute("""
SELECT * SELECT *
FROM loitering_munition_params FROM loitering_munition_params
WHERE equipment_id = %s WHERE equipment_id = ?
""", (id,)) """, (id,))
custom_params = cursor.fetchone() custom_params = cursor.fetchone()
@ -853,7 +844,7 @@ def get_datasets():
""" """
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT d.*, SELECT d.*,
COUNT(de.equipment_id) as equipment_count, COUNT(de.equipment_id) as equipment_count,
@ -872,12 +863,7 @@ def get_datasets():
else: else:
dataset['equipment_names'] = [] dataset['equipment_names'] = []
# 格式化时间和数值字段 # SQLite 已存储日期为文本格式,无需转换
for dataset in datasets:
if dataset['created_at']:
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
if dataset['updated_at']:
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
return jsonify(datasets) return jsonify(datasets)
except Exception as e: except Exception as e:
@ -891,14 +877,14 @@ def get_dataset(id):
""" """
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取数据集基本信息 # 获取数据集基本信息
cursor.execute(""" cursor.execute("""
SELECT d.*, SELECT d.*,
COUNT(de.equipment_id) as equipment_count COUNT(de.equipment_id) as equipment_count
FROM datasets d FROM datasets d
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
WHERE d.id = %s WHERE d.id = ?
GROUP BY d.id GROUP BY d.id
""", (id,)) """, (id,))
dataset = cursor.fetchone() dataset = cursor.fetchone()
@ -913,7 +899,7 @@ def get_dataset(id):
FROM equipments e FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
""", (id,)) """, (id,))
equipment = cursor.fetchall() equipment = cursor.fetchall()
@ -952,14 +938,14 @@ def create_dataset():
# 1. 验证装备ID是否存在 # 1. 验证装备ID是否存在
if 'equipment_ids' in data and data['equipment_ids']: if 'equipment_ids' in data and data['equipment_ids']:
# 直接从 equipment 表查询,不需要 JOIN ids = data['equipment_ids']
equipment_ids_str = ','.join(map(str, data['equipment_ids'])) placeholders = ','.join(['?'] * len(ids))
cursor.execute(f""" cursor.execute(f"""
SELECT DISTINCT id FROM equipments SELECT DISTINCT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s WHERE id IN ({placeholders}) AND type = ?
""", (data['equipment_type'],)) """, (*ids, data['equipment_type']))
valid_ids = [row[0] for row in cursor.fetchall()] valid_ids = [row['id'] for row in cursor.fetchall()]
logger.info(f"Valid equipment IDs: {valid_ids}") logger.info(f"Valid equipment IDs: {valid_ids}")
# 如果有无效的ID返回错误 # 如果有无效的ID返回错误
@ -971,7 +957,7 @@ def create_dataset():
# 2. 创建数据集 # 2. 创建数据集
cursor.execute(""" cursor.execute("""
INSERT INTO datasets (name, description, equipment_type, purpose) INSERT INTO datasets (name, description, equipment_type, purpose)
VALUES (%s, %s, %s, %s) VALUES (?, ?, ?, ?)
""", (data['name'], data['description'], data['equipment_type'], data['purpose'])) """, (data['name'], data['description'], data['equipment_type'], data['purpose']))
dataset_id = cursor.lastrowid dataset_id = cursor.lastrowid
@ -982,7 +968,7 @@ def create_dataset():
values = [(dataset_id, equipment_id) for equipment_id in valid_ids] values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
cursor.executemany(""" cursor.executemany("""
INSERT INTO dataset_equipments (dataset_id, equipment_id) INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s) VALUES (?, ?)
""", values) """, values)
logger.info(f"Added {len(values)} equipment associations") logger.info(f"Added {len(values)} equipment associations")
@ -1004,14 +990,15 @@ def update_dataset(id):
cursor = conn.cursor() cursor = conn.cursor()
# 1. 验证装备ID是否存在 # 1. 验证装备ID是否存在
if 'equipment_ids' in data: if 'equipment_ids' in data and data['equipment_ids']:
equipment_ids_str = ','.join(map(str, data['equipment_ids'])) ids = data['equipment_ids']
placeholders = ','.join(['?'] * len(ids))
cursor.execute(f""" cursor.execute(f"""
SELECT id FROM equipments SELECT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s WHERE id IN ({placeholders}) AND type = ?
""", (data['equipment_type'],)) """, (*ids, data['equipment_type']))
valid_ids = [row[0] for row in cursor.fetchall()] valid_ids = [row['id'] for row in cursor.fetchall()]
logger.info(f"Valid equipment IDs: {valid_ids}") logger.info(f"Valid equipment IDs: {valid_ids}")
# 如果有无效的ID返回错误 # 如果有无效的ID返回错误
@ -1023,21 +1010,21 @@ def update_dataset(id):
# 2. 更新数据集基本信息 # 2. 更新数据集基本信息
cursor.execute(""" cursor.execute("""
UPDATE datasets UPDATE datasets
SET name = %s, description = %s, equipment_type = %s, purpose = %s SET name = ?, description = ?, equipment_type = ?, purpose = ?
WHERE id = %s WHERE id = ?
""", (data['name'], data['description'], data['equipment_type'], data['purpose'], id)) """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
# 3. 更新装备关联 # 3. 更新装备关联
if 'equipment_ids' in data: if 'equipment_ids' in data:
# 先删除旧的关联 # 先删除旧的关联
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,)) cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
# 添加新的关联 # 添加新的关联
if valid_ids: # 确保有有效的ID才执行插入 if valid_ids: # 确保有有效的ID才执行插入
values = [(id, equipment_id) for equipment_id in valid_ids] values = [(id, equipment_id) for equipment_id in valid_ids]
cursor.executemany(""" cursor.executemany("""
INSERT INTO dataset_equipments (dataset_id, equipment_id) INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s) VALUES (?, ?)
""", values) """, values)
logger.info(f"Updated {len(values)} equipment associations") logger.info(f"Updated {len(values)} equipment associations")
@ -1058,10 +1045,10 @@ def delete_dataset(id):
cursor = conn.cursor() cursor = conn.cursor()
# 删除装备关联 # 删除装备关联
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,)) cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
# 删除数据集 # 删除数据集
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,)) cursor.execute("DELETE FROM datasets WHERE id = ?", (id,))
conn.commit() conn.commit()
return jsonify({'success': True}) return jsonify({'success': True})
@ -1076,10 +1063,10 @@ def get_latest_model(equipment_type):
""" """
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT * FROM trained_models SELECT * FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE WHERE equipment_type = ? AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1 ORDER BY training_date DESC LIMIT 1
""", (equipment_type,)) """, (equipment_type,))
@ -1095,7 +1082,7 @@ def get_models():
"""获取模型列表""" """获取模型列表"""
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT * FROM trained_models SELECT * FROM trained_models
ORDER BY training_date DESC ORDER BY training_date DESC
@ -1105,9 +1092,7 @@ def get_models():
# 格式化时间和数值字段 # 格式化时间和数值字段
for model in models: for model in models:
# 将数据库中的datetime转换为ISO格式字符串 # SQLite 已存储日期为文本格式,无需转换
if model['training_date']:
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
if model['r2_score'] is not None: if model['r2_score'] is not None:
model['r2_score'] = float(model['r2_score']) model['r2_score'] = float(model['r2_score'])
if model['mae'] is not None: if model['mae'] is not None:
@ -1130,12 +1115,12 @@ def activate_model(id):
"""激活指定的模型""" """激活指定的模型"""
try: try:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) # 使用字典游标 cursor = conn.cursor() # 使用字典游标
# 获取模型信息 # 获取模型信息
cursor.execute(""" cursor.execute("""
SELECT equipment_type, model_type FROM trained_models SELECT equipment_type, model_type FROM trained_models
WHERE id = %s WHERE id = ?
""", (id,)) """, (id,))
model = cursor.fetchone() model = cursor.fetchone()
@ -1146,14 +1131,14 @@ def activate_model(id):
cursor.execute(""" cursor.execute("""
UPDATE trained_models UPDATE trained_models
SET is_active = FALSE SET is_active = FALSE
WHERE equipment_type = %s AND model_type = %s WHERE equipment_type = ? AND model_type = ?
""", (model['equipment_type'], model['model_type'])) """, (model['equipment_type'], model['model_type']))
# 激活指定模型 # 激活指定模型
cursor.execute(""" cursor.execute("""
UPDATE trained_models UPDATE trained_models
SET is_active = TRUE SET is_active = TRUE
WHERE id = %s WHERE id = ?
""", (id,)) """, (id,))
conn.commit() conn.commit()
@ -1176,21 +1161,21 @@ def delete_model(id):
cursor.execute(""" cursor.execute("""
SELECT model_path, scaler_path SELECT model_path, scaler_path
FROM trained_models FROM trained_models
WHERE id = %s WHERE id = ?
""", (id,)) """, (id,))
model = cursor.fetchone() model = cursor.fetchone()
if not model: if not model:
return jsonify({'error': 'Model not found'}), 404 return jsonify({'error': 'Model not found'}), 404
# 删除模型 # 删除模型
if os.path.exists(model[0]): if os.path.exists(model['model_path']):
os.remove(model[0]) os.remove(model['model_path'])
if os.path.exists(model[1]): if os.path.exists(model['scaler_path']):
os.remove(model[1]) os.remove(model['scaler_path'])
# 删除数据库记录 # 删除数据库记录
cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,)) cursor.execute("DELETE FROM trained_models WHERE id = ?", (id,))
conn.commit() conn.commit()
return jsonify({'success': True}) return jsonify({'success': True})
@ -1231,7 +1216,7 @@ def analyze_manufacturers():
return jsonify({'error': '请选择数据集'}), 400 return jsonify({'error': '请选择数据集'}), 400
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor()
# 获取数据集中的装备和生产商数据 # 获取数据集中的装备和生产商数据
cursor.execute(""" cursor.execute("""
@ -1239,13 +1224,22 @@ def analyze_manufacturers():
FROM manufacturers m FROM manufacturers m
JOIN equipments e ON e.manufacturer_id = m.id JOIN equipments e ON e.manufacturer_id = m.id
JOIN dataset_equipments de ON e.id = de.equipment_id JOIN dataset_equipments de ON e.id = de.equipment_id
WHERE de.dataset_id = %s WHERE de.dataset_id = ?
""", (dataset_id,)) """, (dataset_id,))
manufacturers = cursor.fetchall() manufacturers = cursor.fetchall()
if not manufacturers: if not manufacturers:
return jsonify({'error': '数据集中没有生产商数据'}), 404 logger.info("No manufacturer data in dataset, returning empty result")
return jsonify({
'manufacturer_names': [],
'manufacturer_tech_levels': [],
'manufacturer_scale_levels': [],
'manufacturer_supply_chain_levels': [],
'manufacturer_composite_scores': [],
'region_distribution': [],
'manufacturer_scores': []
})
# 准备分析数据 # 准备分析数据
manufacturer_names = [] manufacturer_names = []

View File

@ -1,9 +1,5 @@
@echo off @echo off
set FLASK_DEBUG=false set FLASK_DEBUG=false
set MYSQL_HOST=localhost
set MYSQL_USER=root
set MYSQL_PASSWORD=123456
set MYSQL_DB=equipment_cost_db
echo Starting Cost Prediction System... echo Starting Cost Prediction System...
start /B run.exe start /B run.exe
start http://localhost:5001 start http://localhost:5001