将mysql改成sqlite,减少前端依赖
This commit is contained in:
parent
137451ba7a
commit
48ba547c36
29
.claude/settings.local.json
Normal file
29
.claude/settings.local.json
Normal file
File diff suppressed because one or more lines are too long
108
CLAUDE.md
Normal file
108
CLAUDE.md
Normal file
@ -0,0 +1,108 @@
|
||||
# CLAUDE.md
|
||||
|
||||
本文件为 Claude Code(claude.ai/code)在此仓库中工作提供指引。
|
||||
|
||||
## 常用命令
|
||||
|
||||
```bash
|
||||
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch)
|
||||
pip install -e .
|
||||
pip install -e ".[dev]" # 含开发工具(pytest, black, mypy)
|
||||
pip install -e ".[torch]" # 安装可选 PyTorch(神经网络训练需要)
|
||||
|
||||
# 启动后端(Flask,0.0.0.0:5001)
|
||||
python run.py
|
||||
|
||||
# 启动前端开发服务器(另开终端,localhost:3000)
|
||||
cd frontend && npm install && npm run serve
|
||||
|
||||
# 构建前端生产版本
|
||||
cd frontend && npm run build
|
||||
|
||||
# 运行测试
|
||||
python -m pytest tests/
|
||||
python -m pytest tests/test_demo_service.py -q # 单个测试文件
|
||||
python -m pytest tests/test_demo_routes.py -q
|
||||
|
||||
# 代码格式化 / 类型检查
|
||||
black src/ tests/ # 自动格式化(line-length=88)
|
||||
mypy src/ # 类型检查
|
||||
|
||||
# 运行独立演示模式(仅需 5 个依赖)
|
||||
cd demo_standalone && pip install -r requirements.txt && python server.py
|
||||
```
|
||||
|
||||
## 项目概述
|
||||
|
||||
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
|
||||
|
||||
### 技术栈
|
||||
|
||||
- **后端**:Python 3.9-3.11, Flask 3.1+
|
||||
- **数据库**:SQLite(Python 内置,零外部依赖),首次启动自动建表,无需手动安装配置
|
||||
- **机器学习**:scikit-learn(核心)、XGBoost、LightGBM
|
||||
- **可选依赖**:PyTorch(仅神经网络训练需要,约 800MB)
|
||||
- **前端**:Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
|
||||
|
||||
### 关键文件
|
||||
|
||||
| 文件 | 用途 |
|
||||
|---|---|
|
||||
| `run.py` | 入口点,启动 Flask 服务器 |
|
||||
| `src/app.py` | Flask 应用工厂 |
|
||||
| `src/routes.py` | 所有 API 路由(约 1300 行) |
|
||||
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModel(PyTorch NN) |
|
||||
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
|
||||
| `src/cost_prediction.py` | CostPredictor(预测编排) |
|
||||
| `src/demo_service.py` | DemoModelService(基于 CSV,无需数据库) |
|
||||
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL(内置) |
|
||||
| `config.py` | 运行时配置 |
|
||||
| `frontend/src/router/index.js` | Vue Router,共 8 个路由 |
|
||||
| `frontend/src/api/index.js` | Axios API 客户端 |
|
||||
|
||||
## 架构
|
||||
|
||||
### 与旧架构的核心变化
|
||||
|
||||
| 项目 | 旧架构(MySQL) | 新架构(SQLite) |
|
||||
|------|----------------|-----------------|
|
||||
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLite,Python 内置,零依赖 |
|
||||
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3) |
|
||||
| PyTorch | 硬依赖(顶层 import,无则崩溃) | 可选依赖(try/except 保护,无 PyTorch 可启动) |
|
||||
| 前端构建 | Vue CLI + Babel + SCSS | Vite(无需 Babel/SCSS) |
|
||||
| Vuex | 存在但完全空 | 已移除 |
|
||||
| 依赖总数(核心)| 15+ | 5(flask, numpy, pandas, scikit-learn, openpyxl) |
|
||||
|
||||
### 训练数据流
|
||||
```
|
||||
用户界面 → POST /api/train → 查询 SQLite → DataPreparation(特征提取 + 标准化)
|
||||
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS)
|
||||
→ 保存最优模型 → 写入 SQLite trained_models 表
|
||||
```
|
||||
|
||||
### 预测数据流
|
||||
```
|
||||
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
|
||||
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
|
||||
```
|
||||
|
||||
### 机器学习模型
|
||||
|
||||
| 模型 | 标识 | 说明 |
|
||||
|---|---|---|
|
||||
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
|
||||
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
|
||||
| Random Forest | `rf` | sklearn.ensemble |
|
||||
| Gradient Boosting | `gbm` | sklearn.ensemble |
|
||||
| PyTorch NN | `pytorch` | 可选(需安装 torch),针对不同装备类型定制网络结构 |
|
||||
| PLS 回归 | `pls` | 不参与最优模型评选 |
|
||||
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
|
||||
|
||||
## 重要注意事项
|
||||
|
||||
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
|
||||
- **两种装备类型**:火箭炮(27+ 特征)和巡飞弹(24+ 特征),各自有独立数据库表和神经网络结构
|
||||
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
|
||||
- **SQLite 首次使用**:`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
|
||||
- **编码规范**:UTF-8 无 BOM,LF 换行符,文件末尾保留换行符
|
||||
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容
|
||||
@ -2,11 +2,8 @@ import os
|
||||
|
||||
class Config:
|
||||
"""配置类"""
|
||||
# 数据库配置
|
||||
MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost')
|
||||
MYSQL_USER = os.getenv('MYSQL_USER', 'root')
|
||||
MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
|
||||
MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')
|
||||
# 数据库配置(使用 SQLite)
|
||||
SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
|
||||
|
||||
# Flask配置
|
||||
FLASK_HOST = '0.0.0.0'
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
module.exports = {
|
||||
presets: [
|
||||
'@vue/cli-plugin-babel/preset'
|
||||
]
|
||||
}
|
||||
17
frontend/index.html
Normal file
17
frontend/index.html
Normal file
@ -0,0 +1,17 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<link rel="icon" href="/favicon.ico">
|
||||
<title>装备成本估算系统</title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>
|
||||
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
|
||||
</noscript>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
11130
frontend/package-lock.json
generated
11130
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -7,33 +7,24 @@
|
||||
"npm": ">=8"
|
||||
},
|
||||
"scripts": {
|
||||
"serve": "vue-cli-service serve",
|
||||
"build": "vue-cli-service build",
|
||||
"lint": "vue-cli-service lint"
|
||||
"serve": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint --ext .js,.vue src/"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.6.0",
|
||||
"core-js": "^3.8.3",
|
||||
"echarts": "^5.4.3",
|
||||
"element-plus": "^2.4.2",
|
||||
"vue": "^3.2.13",
|
||||
"vue-router": "^4.0.3",
|
||||
"vuex": "^4.0.0"
|
||||
"vue-router": "^4.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.12.16",
|
||||
"@babel/eslint-parser": "^7.12.16",
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"@vue/cli-plugin-babel": "~5.0.0",
|
||||
"@vue/cli-plugin-eslint": "~5.0.0",
|
||||
"@vue/cli-plugin-router": "~5.0.0",
|
||||
"@vue/cli-plugin-vuex": "~5.0.0",
|
||||
"@vue/cli-service": "~5.0.0",
|
||||
"@vue/compiler-sfc": "^3.2.13",
|
||||
"@vitejs/plugin-vue": "^5.0.0",
|
||||
"eslint": "^7.32.0",
|
||||
"eslint-plugin-vue": "^8.0.3",
|
||||
"sass": "^1.32.7",
|
||||
"sass-loader": "^12.0.0"
|
||||
"sass-embedded": "^1.99.0",
|
||||
"vite": "^5.0.0"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"root": true,
|
||||
@ -45,17 +36,12 @@
|
||||
"eslint:recommended"
|
||||
],
|
||||
"parserOptions": {
|
||||
"parser": "@babel/eslint-parser"
|
||||
"ecmaVersion": "latest",
|
||||
"sourceType": "module"
|
||||
},
|
||||
"rules": {
|
||||
"vue/multi-word-component-names": "off",
|
||||
"no-unused-vars": "warn"
|
||||
}
|
||||
},
|
||||
"browserslist": [
|
||||
"> 1%",
|
||||
"last 2 versions",
|
||||
"not dead",
|
||||
"not ie 11"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
|
||||
<title><%= htmlWebpackPlugin.options.title %></title>
|
||||
</head>
|
||||
<body>
|
||||
<noscript>
|
||||
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
|
||||
</noscript>
|
||||
<div id="app"></div>
|
||||
<!-- built files will be auto injected -->
|
||||
</body>
|
||||
</html>
|
||||
@ -1,7 +1,6 @@
|
||||
import { createApp } from 'vue'
|
||||
import App from './App.vue'
|
||||
import router from './router'
|
||||
import store from './store'
|
||||
import ElementPlus from 'element-plus'
|
||||
import 'element-plus/dist/index.css'
|
||||
import './assets/styles/global.css'
|
||||
@ -15,7 +14,6 @@ app.use(ElementPlus, {
|
||||
size: 'default'
|
||||
})
|
||||
app.use(router)
|
||||
app.use(store)
|
||||
|
||||
// 注册图标
|
||||
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
import { createStore } from 'vuex'
|
||||
|
||||
export default createStore({
|
||||
state: {
|
||||
},
|
||||
getters: {
|
||||
},
|
||||
mutations: {
|
||||
},
|
||||
actions: {
|
||||
},
|
||||
modules: {
|
||||
}
|
||||
})
|
||||
@ -19,7 +19,7 @@ export default defineConfig({
|
||||
port: 3000,
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:5000',
|
||||
target: 'http://localhost:5001',
|
||||
changeOrigin: true
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,12 +11,6 @@ dependencies = [
|
||||
"flask>=3.1.0",
|
||||
"flask-cors>=5.0.0",
|
||||
|
||||
# 数据库
|
||||
"sqlalchemy>=2.0.36",
|
||||
"pymysql>=1.1.1",
|
||||
"cryptography>=43.0.0",
|
||||
"mysql-connector-python>=8.0.0",
|
||||
|
||||
# 数据处理
|
||||
"numpy>=1.26.0,<2.0.0",
|
||||
"pandas>=2.2.0",
|
||||
@ -25,7 +19,6 @@ dependencies = [
|
||||
"scikit-learn>=1.5.2",
|
||||
"xgboost>=2.1.0",
|
||||
"lightgbm>=4.5.0",
|
||||
"torch==2.5.1",
|
||||
|
||||
# 工具
|
||||
"openpyxl>=3.1.5",
|
||||
@ -34,6 +27,10 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# PyTorch 为可选依赖(安装约 800MB,仅训练神经网络时需要)
|
||||
torch = [
|
||||
"torch==2.5.1",
|
||||
]
|
||||
dev = [
|
||||
# 测试工具
|
||||
"pytest>=7.0",
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
flask>=3.1.0
|
||||
flask-cors>=5.0.0
|
||||
sqlalchemy>=2.0.36
|
||||
pymysql>=1.1.1
|
||||
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
|
||||
mysql-connector-python>=8.0.0 # 添加这行
|
||||
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
|
||||
xgboost>=2.1.0
|
||||
lightgbm>=4.5.0
|
||||
|
||||
scikit-learn>=1.5.2
|
||||
|
||||
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
|
||||
python-dotenv>=1.0.0 # 环境变量
|
||||
openpyxl>=3.1.5
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
@ -34,7 +34,6 @@ pyinstaller --clean `
|
||||
--add-data "src/loitering_munition_data.sql;data" `
|
||||
--add-data "src/rocket_artillery_data.sql;data" `
|
||||
--add-data "src/manufacturer_data.sql;data" `
|
||||
--add-data "src/schema.sql;data" `
|
||||
--add-data "config.py;." `
|
||||
--add-data "src;src" `
|
||||
--add-data "frontend;frontend" `
|
||||
@ -43,18 +42,11 @@ pyinstaller --clean `
|
||||
--add-data "models;models" `
|
||||
--collect-all "xgboost" `
|
||||
--collect-all "lightgbm" `
|
||||
--collect-all "torch" `
|
||||
--collect-all "sklearn" `
|
||||
--collect-all "numpy" `
|
||||
--collect-all "pandas" `
|
||||
--collect-all "sqlalchemy" `
|
||||
--collect-all "pymysql" `
|
||||
--collect-all "cryptography" `
|
||||
--collect-all "flask" `
|
||||
--collect-all "flask_cors" `
|
||||
--hidden-import "xgboost.testing" `
|
||||
--hidden-import "torch.utils.tensorboard" `
|
||||
--hidden-import "pytest" `
|
||||
run.py
|
||||
|
||||
# Copy necessary files
|
||||
|
||||
@ -15,12 +15,6 @@ def create_app():
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# 配置数据库连接
|
||||
app.config['MYSQL_HOST'] = config.MYSQL_HOST
|
||||
app.config['MYSQL_USER'] = config.MYSQL_USER
|
||||
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
|
||||
app.config['MYSQL_DB'] = config.MYSQL_DB
|
||||
|
||||
# 注册路由
|
||||
app.register_blueprint(api_bp, url_prefix='/api')
|
||||
logger.info("API blueprint registered")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from scipy import stats
|
||||
@ -9,6 +8,16 @@ from src.database import get_db_connection
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from .logger import setup_logger
|
||||
|
||||
# PyTorch 为可选依赖
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
_HAS_TORCH = False
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class CostPredictor:
|
||||
@ -18,7 +27,11 @@ class CostPredictor:
|
||||
self.model = None
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
self.equipment_type = None
|
||||
|
||||
if _HAS_TORCH:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = None
|
||||
|
||||
self.load_model()
|
||||
|
||||
@ -26,10 +39,9 @@ class CostPredictor:
|
||||
"""
|
||||
加载预训练模型和标准化器
|
||||
"""
|
||||
if _HAS_TORCH:
|
||||
try:
|
||||
# 创建默认模型
|
||||
self._create_default_model()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading model: {str(e)}")
|
||||
self._create_default_model()
|
||||
@ -38,46 +50,36 @@ class CostPredictor:
|
||||
"""
|
||||
创建默认模型并进行初始化训练
|
||||
"""
|
||||
import torch.nn as nn
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed.")
|
||||
|
||||
class DefaultModel(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(input_size, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(input_size, 64), nn.ReLU(),
|
||||
nn.Linear(64, 32), nn.ReLU(),
|
||||
nn.Linear(32, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
# 创建示例数据
|
||||
example_features = {
|
||||
'length_m': [7.35, 10.2],
|
||||
'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2],
|
||||
'weight_kg': [13700, 28500],
|
||||
'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
|
||||
'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
|
||||
'max_range_km': [20.4, 70],
|
||||
'firing_angle_horizontal': [102, 110],
|
||||
'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1],
|
||||
'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150],
|
||||
'rate_of_fire': [40, 60]
|
||||
'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
|
||||
'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
|
||||
'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
|
||||
}
|
||||
|
||||
# 转换为 tensor
|
||||
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
|
||||
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
|
||||
|
||||
# 训练标准化器
|
||||
self.scaler_X.fit(X.numpy())
|
||||
self.scaler_y.fit(y.numpy())
|
||||
|
||||
# 创建模型
|
||||
self.model = DefaultModel(X.shape[1]).to(self.device)
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import logging
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
@ -9,9 +7,22 @@ from .logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class EquipmentDataset(Dataset):
|
||||
# PyTorch 为可选依赖
|
||||
try:
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
Dataset = object
|
||||
DataLoader = None
|
||||
_HAS_TORCH = False
|
||||
|
||||
class EquipmentDataset(Dataset if _HAS_TORCH else object):
|
||||
"""装备数据集类"""
|
||||
def __init__(self, features, targets=None):
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||
self.features = torch.FloatTensor(features)
|
||||
self.targets = torch.FloatTensor(targets) if targets is not None else None
|
||||
|
||||
@ -41,7 +52,7 @@ class DataPreparation:
|
||||
|
||||
# 获取数据库连接
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取所有生产商数据,用于计算特征
|
||||
cursor.execute("""
|
||||
|
||||
@ -1,37 +1,227 @@
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from ..logger import setup_logger
|
||||
|
||||
# 获取logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# SQLite 数据库文件路径(相对于项目根目录)
|
||||
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
|
||||
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
|
||||
|
||||
# 建表 SQL
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS equipments (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
type TEXT,
|
||||
manufacturer TEXT,
|
||||
manufacturer_id INTEGER,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS common_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
length_m REAL,
|
||||
width_m REAL,
|
||||
height_m REAL,
|
||||
weight_kg REAL,
|
||||
max_range_km REAL,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
firing_angle_horizontal REAL,
|
||||
firing_angle_vertical REAL,
|
||||
rocket_length_m REAL,
|
||||
rocket_diameter_mm REAL,
|
||||
rocket_weight_kg REAL,
|
||||
rate_of_fire REAL,
|
||||
combat_weight_kg REAL,
|
||||
speed_kmh REAL,
|
||||
min_range_km REAL,
|
||||
max_range_km REAL,
|
||||
mobility_type TEXT,
|
||||
structure_layout TEXT,
|
||||
engine_model TEXT,
|
||||
engine_params TEXT,
|
||||
power_hp REAL,
|
||||
travel_range_km REAL,
|
||||
fire_density REAL,
|
||||
range_ratio REAL,
|
||||
mobility_score INTEGER,
|
||||
combat_readiness_score INTEGER,
|
||||
deployment_score INTEGER,
|
||||
terrain_adaptability_score INTEGER,
|
||||
rocket_power_ratio REAL,
|
||||
platform_efficiency REAL,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS loitering_munition_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
wingspan_m REAL,
|
||||
warhead_weight_kg REAL,
|
||||
max_speed_ms REAL,
|
||||
cruise_speed_kmh REAL,
|
||||
endurance_min REAL,
|
||||
flight_time_min REAL,
|
||||
max_range_km REAL,
|
||||
max_payload_kg REAL,
|
||||
ceiling_altitude_m REAL,
|
||||
combat_radius_km REAL,
|
||||
folded_length_mm REAL,
|
||||
folded_width_mm REAL,
|
||||
folded_height_mm REAL,
|
||||
warhead_type TEXT,
|
||||
launch_mode TEXT,
|
||||
power_system TEXT,
|
||||
guidance_system TEXT,
|
||||
engine_power_kw REAL,
|
||||
engine_thrust_n REAL,
|
||||
datalink_range_km REAL,
|
||||
guidance_accuracy_m REAL,
|
||||
min_altitude_m REAL,
|
||||
max_altitude_m REAL,
|
||||
length_width_ratio REAL,
|
||||
weight_range_ratio REAL,
|
||||
speed_weight_ratio REAL,
|
||||
guidance_system_score INTEGER,
|
||||
warhead_power_score INTEGER,
|
||||
warhead_type_code INTEGER,
|
||||
launch_mode_code INTEGER,
|
||||
power_system_code INTEGER,
|
||||
guidance_system_code INTEGER,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS feature_encoding (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
feature_type TEXT,
|
||||
feature_value TEXT,
|
||||
code INTEGER,
|
||||
UNIQUE(feature_type, feature_value)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS cost_data (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
actual_cost REAL,
|
||||
predicted_cost REAL,
|
||||
prediction_date TEXT,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS custom_params (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
equipment_id INTEGER,
|
||||
param_name TEXT,
|
||||
param_value TEXT,
|
||||
param_unit TEXT,
|
||||
description TEXT,
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS datasets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
equipment_type TEXT NOT NULL,
|
||||
purpose TEXT NOT NULL,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
updated_at TEXT DEFAULT (datetime('now','localtime'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dataset_equipments (
|
||||
dataset_id INTEGER NOT NULL,
|
||||
equipment_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (dataset_id, equipment_id),
|
||||
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
|
||||
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trained_models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
model_type TEXT NOT NULL,
|
||||
equipment_type TEXT NOT NULL,
|
||||
model_path TEXT NOT NULL,
|
||||
scaler_path TEXT NOT NULL,
|
||||
r2_score REAL,
|
||||
mae REAL,
|
||||
rmse REAL,
|
||||
feature_importance TEXT,
|
||||
training_data_size INTEGER,
|
||||
training_date TEXT DEFAULT (datetime('now','localtime')),
|
||||
is_active INTEGER DEFAULT 0,
|
||||
created_by TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS manufacturers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
country TEXT NOT NULL,
|
||||
tech_level INTEGER NOT NULL,
|
||||
scale_level INTEGER NOT NULL,
|
||||
supply_chain_level INTEGER NOT NULL,
|
||||
created_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
updated_at TEXT DEFAULT (datetime('now','localtime')),
|
||||
UNIQUE(name)
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
|
||||
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
|
||||
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
|
||||
"""
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库:确保数据库文件和表存在"""
|
||||
os.makedirs(DB_DIR, exist_ok=True)
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.executescript(SCHEMA_SQL)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Database initialized at {DB_PATH}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection():
|
||||
"""
|
||||
数据库连接上下文管理器
|
||||
返回的 connection 已设置 dict row_factory,
|
||||
以便按列名访问。
|
||||
"""
|
||||
connection = None
|
||||
conn = None
|
||||
try:
|
||||
connection = mysql.connector.connect(
|
||||
host=os.getenv('MYSQL_HOST', 'localhost'),
|
||||
user=os.getenv('MYSQL_USER', 'root'),
|
||||
password=os.getenv('MYSQL_PASSWORD', '123456'),
|
||||
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
|
||||
)
|
||||
logger.info("Database connection established")
|
||||
yield connection
|
||||
# 确保数据库已初始化
|
||||
if not os.path.exists(DB_PATH):
|
||||
logger.info("Database file not found, initializing...")
|
||||
init_db()
|
||||
|
||||
except Error as e:
|
||||
logger.error(f"Error connecting to MySQL: {str(e)}")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
logger.debug("Database connection established")
|
||||
yield conn
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"Database error: {str(e)}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
if connection and connection.is_connected():
|
||||
connection.close()
|
||||
logger.info("Database connection closed")
|
||||
if conn:
|
||||
conn.close()
|
||||
logger.debug("Database connection closed")
|
||||
|
||||
@ -267,11 +267,15 @@ class FeatureAnalysis:
|
||||
def calculate_manufacturer_features(self, manufacturer_data):
|
||||
"""计算生产商相关的特征"""
|
||||
try:
|
||||
# 确保所有必要的字段都存在,使用默认值处理缺失数据
|
||||
tech_level = float(manufacturer_data.get('tech_level', 0))
|
||||
scale_level = float(manufacturer_data.get('scale_level', 0))
|
||||
supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0))
|
||||
country = manufacturer_data.get('country', '未知')
|
||||
# 处理 None 值(数据库 NULL),使用默认值
|
||||
raw_tech = manufacturer_data.get('tech_level')
|
||||
raw_scale = manufacturer_data.get('scale_level')
|
||||
raw_supply = manufacturer_data.get('supply_chain_level')
|
||||
|
||||
tech_level = float(raw_tech) if raw_tech is not None else 0
|
||||
scale_level = float(raw_scale) if raw_scale is not None else 0
|
||||
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
|
||||
country = manufacturer_data.get('country', '未知') or '未知'
|
||||
|
||||
# 计算综合得分
|
||||
composite_score = (
|
||||
|
||||
@ -27,7 +27,7 @@ def import_training_data(excel_file):
|
||||
# 检查是否已存在相同名称的装备
|
||||
cursor.execute("""
|
||||
SELECT id FROM equipments
|
||||
WHERE name = %s AND type = '火箭炮'
|
||||
WHERE name = ? AND type = '火箭炮'
|
||||
""", (row['名称'],))
|
||||
|
||||
existing_equipment = cursor.fetchone()
|
||||
@ -38,7 +38,7 @@ def import_training_data(excel_file):
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
VALUES (?, ?, ?)
|
||||
""", (row['名称'], '火箭炮', row['制造商']))
|
||||
|
||||
equipment_id = cursor.lastrowid
|
||||
@ -47,7 +47,7 @@ def import_training_data(excel_file):
|
||||
cursor.execute("""
|
||||
INSERT INTO common_params
|
||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
row['总长_m'] if pd.notna(row['总长_m']) else None,
|
||||
@ -65,7 +65,7 @@ def import_training_data(excel_file):
|
||||
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
|
||||
structure_layout, engine_model, engine_params, power_hp,
|
||||
travel_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
|
||||
@ -89,7 +89,7 @@ def import_training_data(excel_file):
|
||||
if pd.notna(row['成本_元']):
|
||||
cursor.execute("""
|
||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", (equipment_id, row['成本_元']))
|
||||
|
||||
logger.info("火箭炮数据导入完成")
|
||||
@ -105,8 +105,8 @@ def import_training_data(excel_file):
|
||||
equipment_names.add(row['名称'])
|
||||
# 检查是否已存在相同名称的装备
|
||||
cursor.execute("""
|
||||
SELECT id FROM equipment
|
||||
WHERE name = %s AND type = '巡飞弹'
|
||||
SELECT id FROM equipments
|
||||
WHERE name = ? AND type = '巡飞弹'
|
||||
""", (row['名称'],))
|
||||
|
||||
existing_equipment = cursor.fetchone()
|
||||
@ -117,7 +117,7 @@ def import_training_data(excel_file):
|
||||
# 插入基本信息
|
||||
cursor.execute("""
|
||||
INSERT INTO equipments (name, type, manufacturer)
|
||||
VALUES (%s, %s, %s)
|
||||
VALUES (?, ?, ?)
|
||||
""", (
|
||||
row['名称'],
|
||||
'巡飞弹',
|
||||
@ -130,7 +130,7 @@ def import_training_data(excel_file):
|
||||
cursor.execute("""
|
||||
INSERT INTO common_params
|
||||
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
|
||||
@ -147,7 +147,7 @@ def import_training_data(excel_file):
|
||||
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
|
||||
folded_length_mm, folded_width_mm, folded_height_mm,
|
||||
power_system, guidance_system)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
|
||||
@ -168,7 +168,7 @@ def import_training_data(excel_file):
|
||||
if pd.notna(row['成本_元']):
|
||||
cursor.execute("""
|
||||
INSERT INTO cost_data (equipment_id, actual_cost)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", (equipment_id, float(row['成本_元'])))
|
||||
|
||||
logger.info("巡飞弹数据导入完成")
|
||||
@ -190,9 +190,9 @@ def import_training_data(excel_file):
|
||||
|
||||
# 获取装备ID - 使用新的游标
|
||||
logger.debug(f"查询装备ID: {equipment_name}")
|
||||
with conn.cursor() as id_cursor:
|
||||
id_cursor = conn.cursor()
|
||||
id_cursor.execute("""
|
||||
SELECT id FROM equipments WHERE name = %s
|
||||
SELECT id FROM equipments WHERE name = ?
|
||||
""", (equipment_name,))
|
||||
result = id_cursor.fetchone()
|
||||
|
||||
@ -200,15 +200,15 @@ def import_training_data(excel_file):
|
||||
logger.warning(f"未找到装备: {equipment_name}")
|
||||
continue
|
||||
|
||||
equipment_id = result[0]
|
||||
equipment_id = result['id']
|
||||
logger.debug(f"找到装备ID: {equipment_id}")
|
||||
|
||||
# 检查参数是否存在 - 使用新的游标
|
||||
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
|
||||
with conn.cursor() as check_cursor:
|
||||
check_cursor = conn.cursor()
|
||||
check_cursor.execute("""
|
||||
SELECT id FROM custom_params
|
||||
WHERE equipment_id = %s AND param_name = %s
|
||||
WHERE equipment_id = ? AND param_name = ?
|
||||
""", (equipment_id, param_name))
|
||||
exists = check_cursor.fetchone()
|
||||
|
||||
@ -222,11 +222,11 @@ def import_training_data(excel_file):
|
||||
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
|
||||
|
||||
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
|
||||
with conn.cursor() as insert_cursor:
|
||||
insert_cursor = conn.cursor()
|
||||
insert_cursor.execute("""
|
||||
INSERT INTO custom_params
|
||||
(equipment_id, param_name, param_value, param_unit, description)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
equipment_id,
|
||||
param_name,
|
||||
|
||||
237
src/import_sql_data.py
Normal file
237
src/import_sql_data.py
Normal file
@ -0,0 +1,237 @@
|
||||
"""
|
||||
导入 SQL 数据文件到 SQLite 数据库
|
||||
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
|
||||
- rocket_artillery_data.sql (96 条火箭炮数据)
|
||||
- loitering_munition_data.sql (100 条巡飞弹数据)
|
||||
- manufacturer_data.sql (生产商数据)
|
||||
"""
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
from src.database.db_connection import get_db_connection, DB_PATH
|
||||
from src.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def parse_mysql_insert(sql_text):
|
||||
"""
|
||||
解析 MySQL INSERT 语句,返回 (table_name, columns, values_list)
|
||||
columns 是列名列表, values_list 是每行的值列表
|
||||
"""
|
||||
# 去掉块注释 /* ... */
|
||||
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
|
||||
# 去掉行注释 --
|
||||
sql_text = re.sub(r'--.*', '', sql_text)
|
||||
|
||||
results = []
|
||||
|
||||
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
|
||||
pattern = re.compile(
|
||||
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
for match in pattern.finditer(sql_text):
|
||||
table_name = match.group(1).strip()
|
||||
columns_str = match.group(2).strip()
|
||||
values_str = match.group(3).strip()
|
||||
|
||||
# 提取列名(去掉注释部分)
|
||||
columns = []
|
||||
for col in columns_str.split(','):
|
||||
col = col.strip()
|
||||
# 去掉行内注释
|
||||
col = re.sub(r'\s*--.*', '', col).strip()
|
||||
columns.append(col)
|
||||
|
||||
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
|
||||
values_list = []
|
||||
current_tuple = []
|
||||
depth = 0
|
||||
current_val = ''
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
for ch in values_str:
|
||||
if in_string:
|
||||
if ch == string_char:
|
||||
in_string = False
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch in ("'", '"'):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch == '(':
|
||||
if depth > 0:
|
||||
current_val += ch
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
current_val = ''
|
||||
continue
|
||||
|
||||
if ch == ')':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
current_tuple.append(current_val.strip())
|
||||
values_list.append(current_tuple)
|
||||
current_tuple = []
|
||||
else:
|
||||
current_val += ch
|
||||
continue
|
||||
|
||||
if ch == ',' and depth == 1:
|
||||
current_tuple.append(current_val.strip())
|
||||
current_val = ''
|
||||
continue
|
||||
|
||||
if depth >= 1:
|
||||
current_val += ch
|
||||
|
||||
results.append((table_name, columns, values_list))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def convert_value(val):
|
||||
"""将 MySQL 值字符串转为 Python 对象"""
|
||||
val = val.strip()
|
||||
if val.upper() == 'NULL' or val == '':
|
||||
return None
|
||||
if val.upper() == 'TRUE':
|
||||
return 1
|
||||
if val.upper() == 'FALSE':
|
||||
return 0
|
||||
# 字符串
|
||||
if (val.startswith("'") and val.endswith("'")) or \
|
||||
(val.startswith('"') and val.endswith('"')):
|
||||
return val[1:-1]
|
||||
# 数字
|
||||
try:
|
||||
if '.' in val:
|
||||
return float(val)
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
|
||||
def import_sql_file(filepath):
|
||||
"""导入单个 SQL 文件"""
|
||||
logger.info(f"Reading {filepath}...")
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
sql_text = f.read()
|
||||
|
||||
parsed = parse_mysql_insert(sql_text)
|
||||
total_rows = 0
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
for table_name, columns, values_list in parsed:
|
||||
if not values_list:
|
||||
continue
|
||||
|
||||
# 构建参数化 INSERT
|
||||
placeholders = ','.join(['?'] * len(columns))
|
||||
col_names = ','.join(columns)
|
||||
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
|
||||
|
||||
row_count = 0
|
||||
for values in values_list:
|
||||
if len(values) != len(columns):
|
||||
logger.warning(f"Column mismatch in {table_name}: "
|
||||
f"expected {len(columns)} values, got {len(values)}: {values}")
|
||||
continue
|
||||
converted = [convert_value(v) for v in values]
|
||||
try:
|
||||
cursor.execute(sql, converted)
|
||||
if cursor.rowcount > 0:
|
||||
row_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error inserting into {table_name}: {e}")
|
||||
logger.warning(f" Values: {converted}")
|
||||
|
||||
if row_count > 0:
|
||||
logger.info(f" {table_name}: inserted {row_count} rows")
|
||||
total_rows += row_count
|
||||
|
||||
conn.commit()
|
||||
|
||||
return total_rows
|
||||
|
||||
|
||||
def run_manufacturer_update():
|
||||
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id)"""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE equipments
|
||||
SET manufacturer_id = (
|
||||
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
|
||||
)
|
||||
WHERE manufacturer_id IS NULL
|
||||
AND manufacturer IS NOT NULL
|
||||
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
|
||||
""")
|
||||
conn.commit()
|
||||
updated = cursor.rowcount
|
||||
if updated > 0:
|
||||
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
|
||||
else:
|
||||
logger.info("No manufacturer_id updates needed")
|
||||
|
||||
|
||||
def import_all():
|
||||
"""导入所有 SQL 数据文件"""
|
||||
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
# 清除现有数据库,重新开始
|
||||
if os.path.exists(DB_PATH):
|
||||
os.remove(DB_PATH)
|
||||
logger.info(f"Removed existing database: {DB_PATH}")
|
||||
|
||||
files = [
|
||||
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
|
||||
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
|
||||
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
|
||||
]
|
||||
|
||||
total = 0
|
||||
for filepath, label in files:
|
||||
if not os.path.exists(filepath):
|
||||
logger.warning(f"File not found: {filepath}")
|
||||
continue
|
||||
logger.info(f"正在导入 {label}...")
|
||||
rows = import_sql_file(filepath)
|
||||
logger.info(f" {label}: 共导入 {rows} 行")
|
||||
total += rows
|
||||
|
||||
# 更新 manufacturer_id
|
||||
run_manufacturer_update()
|
||||
|
||||
# 统计结果
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
|
||||
eq_count = cursor.fetchone()['cnt']
|
||||
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
|
||||
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
|
||||
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
|
||||
mf_count = cursor.fetchone()['cnt']
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info(f"导入完成!")
|
||||
logger.info(f" 装备总计: {eq_count}")
|
||||
for t, c in by_type.items():
|
||||
logger.info(f" {t}: {c}")
|
||||
logger.info(f" 生产商: {mf_count}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
return eq_count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import_all()
|
||||
@ -1,7 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
|
||||
@ -11,40 +8,43 @@ from datetime import datetime
|
||||
import json
|
||||
from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
from src.data_preparation import DataPreparation, EquipmentDataset
|
||||
from .logger import setup_logger
|
||||
import math
|
||||
|
||||
# PyTorch 为可选依赖
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from src.data_preparation import EquipmentDataset
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
DataLoader = None
|
||||
EquipmentDataset = None
|
||||
_HAS_TORCH = False
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class CostPredictionModel(nn.Module):
|
||||
# 条件基类:有 PyTorch 时继承 nn.Module,否则继承 object
|
||||
_PYTORCH_BASE = nn.Module if _HAS_TORCH else object
|
||||
|
||||
class CostPredictionModel(_PYTORCH_BASE):
|
||||
def __init__(self, input_size, equipment_type):
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||
super().__init__()
|
||||
self.equipment_type = equipment_type
|
||||
|
||||
if equipment_type == '火箭炮':
|
||||
# 火箭炮使用更简单和稳定的网络结构
|
||||
self.net = nn.Sequential(
|
||||
# 第一层:特征映射
|
||||
nn.Linear(input_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(32),
|
||||
|
||||
# 第二层:特征提取
|
||||
nn.Linear(32, 16),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(16),
|
||||
|
||||
# 第三层:特征整合
|
||||
nn.Linear(16, 8),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(8),
|
||||
|
||||
# 输出层
|
||||
nn.Linear(input_size, 32), nn.ReLU(), nn.BatchNorm1d(32),
|
||||
nn.Linear(32, 16), nn.ReLU(), nn.BatchNorm1d(16),
|
||||
nn.Linear(16, 8), nn.ReLU(), nn.BatchNorm1d(8),
|
||||
nn.Linear(8, 1)
|
||||
)
|
||||
|
||||
# 使用正交初始化
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight, gain=0.5)
|
||||
@ -54,45 +54,19 @@ class CostPredictionModel(nn.Module):
|
||||
torch.nn.init.constant_(m.bias, 0.0)
|
||||
|
||||
self.net.apply(init_weights)
|
||||
|
||||
else: # 巡飞弹保持原有结构
|
||||
# 生产商特征网络 - 更简单的结构
|
||||
else:
|
||||
self.manufacturer_net = nn.Sequential(
|
||||
nn.Linear(5, 4),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(4),
|
||||
nn.Dropout(0.2)
|
||||
nn.Linear(5, 4), nn.ReLU(), nn.BatchNorm1d(4), nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
# 巡飞弹特征网络 - 较深的结构
|
||||
self.equipment_net = nn.Sequential(
|
||||
nn.Linear(input_size - 5, 64),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(64, 32),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(32, 16),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(16),
|
||||
nn.Dropout(0.2)
|
||||
nn.Linear(input_size - 5, 64), nn.LeakyReLU(0.1), nn.BatchNorm1d(64), nn.Dropout(0.2),
|
||||
nn.Linear(64, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
|
||||
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
# 合并网络 - 较复杂的结构
|
||||
self.combined_net = nn.Sequential(
|
||||
nn.Linear(20, 32), # 4 + 16 = 20
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(32, 16),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(16),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(16, 8),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.BatchNorm1d(8),
|
||||
nn.Linear(20, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
|
||||
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2),
|
||||
nn.Linear(16, 8), nn.LeakyReLU(0.1), nn.BatchNorm1d(8),
|
||||
nn.Linear(8, 1)
|
||||
)
|
||||
|
||||
@ -100,20 +74,17 @@ class CostPredictionModel(nn.Module):
|
||||
if self.equipment_type == '火箭炮':
|
||||
return self.net(x)
|
||||
else:
|
||||
# 分离特征
|
||||
manufacturer_features = x[:, -5:]
|
||||
equipment_features = x[:, :-5]
|
||||
|
||||
# 特征处理
|
||||
manu_out = self.manufacturer_net(manufacturer_features)
|
||||
equip_out = self.equipment_net(equipment_features)
|
||||
|
||||
# 特征融合
|
||||
combined = torch.cat([equip_out, manu_out], dim=1)
|
||||
return self.combined_net(combined)
|
||||
|
||||
class ModelTrainer:
|
||||
def __init__(self):
|
||||
if not _HAS_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Install with: pip install torch")
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model = None
|
||||
self.feature_scaler = None
|
||||
@ -307,7 +278,7 @@ class ModelTrainer:
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s AND model_type != %s
|
||||
WHERE equipment_type = ? AND model_type != ?
|
||||
""", (equipment_type, 'pls'))
|
||||
|
||||
# 保存新模型记录
|
||||
@ -316,7 +287,7 @@ class ModelTrainer:
|
||||
model_name, model_type, equipment_type, model_path,
|
||||
scaler_path, training_date, is_active, created_by,
|
||||
r2_score, mae, rmse
|
||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
|
||||
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
|
||||
""", (
|
||||
f"{equipment_type}_{timestamp}",
|
||||
'pytorch',
|
||||
@ -343,10 +314,10 @@ class ModelTrainer:
|
||||
try:
|
||||
# 从数据库获取最新的激活模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE
|
||||
WHERE equipment_type = ? AND model_type = ? AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type, model_type))
|
||||
model_record = cursor.fetchone()
|
||||
@ -707,13 +678,13 @@ class ModelTrainer:
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s AND model_type != %s
|
||||
WHERE equipment_type = ? AND model_type != ?
|
||||
""", (equipment_type, 'pls'))
|
||||
else:
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s AND model_type = %s
|
||||
WHERE equipment_type = ? AND model_type = ?
|
||||
""", (equipment_type, 'pls'))
|
||||
|
||||
# 保存新模型记录
|
||||
@ -722,7 +693,7 @@ class ModelTrainer:
|
||||
model_name, model_type, equipment_type, model_path,
|
||||
scaler_path, training_date, is_active, created_by,
|
||||
r2_score, mae, rmse
|
||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
|
||||
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
|
||||
""", (
|
||||
f"{equipment_type}_{model_type}_{timestamp}",
|
||||
model_type,
|
||||
|
||||
340
src/routes.py
340
src/routes.py
@ -4,16 +4,23 @@ from .feature_analysis import FeatureAnalysis
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import mysql.connector
|
||||
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
from .create_template import create_excel_template
|
||||
import json
|
||||
import os
|
||||
from .data_preparation import DataPreparation
|
||||
from .model_trainer import ModelTrainer
|
||||
from .database import get_db_connection
|
||||
from .demo_service import DemoModelService
|
||||
from .logger import setup_logger
|
||||
import torch
|
||||
|
||||
# PyTorch 为可选导入
|
||||
try:
|
||||
import torch
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
_HAS_TORCH = False
|
||||
|
||||
# 创建蓝图
|
||||
api_bp = Blueprint('api', __name__)
|
||||
@ -71,7 +78,7 @@ def demo_algorithms():
|
||||
|
||||
@api_bp.route('/demo/dataset', methods=['GET'])
|
||||
def demo_dataset():
|
||||
"""Return the local demo dataset summary without using MySQL."""
|
||||
"""Return the local demo dataset summary without using a database."""
|
||||
try:
|
||||
service = DemoModelService()
|
||||
return jsonify(service.get_dataset_summary())
|
||||
@ -101,11 +108,11 @@ def predict():
|
||||
|
||||
# 获取最新的激活模型(非PLS模型)
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type != 'pls' # 明确排除PLS模型
|
||||
WHERE equipment_type = ?
|
||||
AND model_type != 'pls'
|
||||
AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
@ -147,14 +154,14 @@ def analyze_features():
|
||||
return jsonify({'error': '请选择数据集'}), 400
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 首先获取数据集的装备类型
|
||||
cursor.execute("""
|
||||
SELECT DISTINCT e.type
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
LIMIT 1
|
||||
""", (dataset_id,))
|
||||
|
||||
@ -181,7 +188,7 @@ def analyze_features():
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (dataset_id,))
|
||||
else:
|
||||
@ -205,7 +212,7 @@ def analyze_features():
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (dataset_id,))
|
||||
|
||||
@ -247,27 +254,27 @@ def analyze_features():
|
||||
# 添加装备特有的分析数据
|
||||
if equipment_data[0]['type'] == '火箭炮':
|
||||
rocket_data = {
|
||||
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
|
||||
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
|
||||
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
|
||||
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
|
||||
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
|
||||
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
|
||||
'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
|
||||
'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
|
||||
'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
|
||||
'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
|
||||
'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
|
||||
'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data]
|
||||
}
|
||||
analysis_result.update(rocket_data)
|
||||
else:
|
||||
missile_data = {
|
||||
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
|
||||
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
|
||||
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
|
||||
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
|
||||
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
|
||||
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
|
||||
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
|
||||
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
|
||||
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
|
||||
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
|
||||
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
|
||||
'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
|
||||
'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
|
||||
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
|
||||
'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
|
||||
'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
|
||||
'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item['guidance_accuracy_m'] is not None else 0 for item in equipment_data],
|
||||
'datalink_range_km': [float(item['datalink_range_km']) if item['datalink_range_km'] is not None else 0 for item in equipment_data],
|
||||
'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data],
|
||||
'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
|
||||
'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
|
||||
'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data]
|
||||
}
|
||||
analysis_result.update(missile_data)
|
||||
|
||||
@ -295,7 +302,7 @@ def train_model():
|
||||
|
||||
# 获取训练数据
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取训练集数据(包含生产商信息)
|
||||
if equipment_type == '火箭炮':
|
||||
@ -309,7 +316,7 @@ def train_model():
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (train_dataset_id,))
|
||||
else:
|
||||
@ -323,7 +330,7 @@ def train_model():
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (train_dataset_id,))
|
||||
|
||||
@ -342,7 +349,7 @@ def train_model():
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (validation_dataset_id,))
|
||||
else:
|
||||
@ -356,7 +363,7 @@ def train_model():
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (validation_dataset_id,))
|
||||
validation_data = cursor.fetchall()
|
||||
@ -476,7 +483,7 @@ def get_equipment_data():
|
||||
"""获取装备数据列表"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取所有装备数据(使用equipment_id替代id)
|
||||
cursor.execute("""
|
||||
@ -485,43 +492,39 @@ def get_equipment_data():
|
||||
cd.actual_cost, cd.predicted_cost,
|
||||
CASE
|
||||
WHEN e.type = '火箭炮' THEN (
|
||||
SELECT CONCAT(
|
||||
firing_angle_horizontal, ',',
|
||||
firing_angle_vertical, ',',
|
||||
rocket_length_m, ',',
|
||||
rocket_diameter_mm, ',',
|
||||
rocket_weight_kg, ',',
|
||||
rate_of_fire, ',',
|
||||
combat_weight_kg, ',',
|
||||
speed_kmh, ',',
|
||||
min_range_km, ',',
|
||||
max_range_km, ',',
|
||||
mobility_type, ',',
|
||||
structure_layout, ',',
|
||||
engine_model, ',',
|
||||
engine_params, ',',
|
||||
power_hp, ',',
|
||||
SELECT firing_angle_horizontal || ',' ||
|
||||
firing_angle_vertical || ',' ||
|
||||
rocket_length_m || ',' ||
|
||||
rocket_diameter_mm || ',' ||
|
||||
rocket_weight_kg || ',' ||
|
||||
rate_of_fire || ',' ||
|
||||
combat_weight_kg || ',' ||
|
||||
speed_kmh || ',' ||
|
||||
min_range_km || ',' ||
|
||||
max_range_km || ',' ||
|
||||
mobility_type || ',' ||
|
||||
structure_layout || ',' ||
|
||||
engine_model || ',' ||
|
||||
engine_params || ',' ||
|
||||
power_hp || ',' ||
|
||||
travel_range_km
|
||||
)
|
||||
FROM rocket_artillery_params
|
||||
WHERE equipment_id = e.id
|
||||
)
|
||||
WHEN e.type = '巡飞弹' THEN (
|
||||
SELECT CONCAT(
|
||||
wingspan_m, ',',
|
||||
warhead_weight_kg, ',',
|
||||
max_speed_ms, ',',
|
||||
cruise_speed_kmh, ',',
|
||||
endurance_min, ',',
|
||||
max_range_km, ',',
|
||||
max_payload_kg, ',',
|
||||
ceiling_altitude_m, ',',
|
||||
combat_radius_km, ',',
|
||||
warhead_type, ',',
|
||||
launch_mode, ',',
|
||||
power_system, ',',
|
||||
SELECT wingspan_m || ',' ||
|
||||
warhead_weight_kg || ',' ||
|
||||
max_speed_ms || ',' ||
|
||||
cruise_speed_kmh || ',' ||
|
||||
endurance_min || ',' ||
|
||||
max_range_km || ',' ||
|
||||
max_payload_kg || ',' ||
|
||||
ceiling_altitude_m || ',' ||
|
||||
combat_radius_km || ',' ||
|
||||
warhead_type || ',' ||
|
||||
launch_mode || ',' ||
|
||||
power_system || ',' ||
|
||||
guidance_system
|
||||
)
|
||||
FROM loitering_munition_params
|
||||
WHERE equipment_id = e.id
|
||||
)
|
||||
@ -548,19 +551,17 @@ def delete_equipment(id):
|
||||
删除装备数据
|
||||
"""
|
||||
try:
|
||||
db = get_db_connection()
|
||||
cursor = db.cursor()
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 删除相关数据
|
||||
cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
|
||||
cursor.execute("DELETE FROM cost_data WHERE equipment_id = ?", (id,))
|
||||
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = ?", (id,))
|
||||
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = ?", (id,))
|
||||
cursor.execute("DELETE FROM common_params WHERE equipment_id = ?", (id,))
|
||||
cursor.execute("DELETE FROM equipments WHERE id = ?", (id,))
|
||||
|
||||
db.commit()
|
||||
cursor.close()
|
||||
db.close()
|
||||
conn.commit()
|
||||
|
||||
return jsonify({'status': 'success'})
|
||||
|
||||
@ -594,16 +595,6 @@ def download_template():
|
||||
logger.error(f"Error creating template: {str(e)}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
def get_db_connection():
|
||||
"""
|
||||
获取数据库连接
|
||||
"""
|
||||
return mysql.connector.connect(
|
||||
host="localhost",
|
||||
user="root",
|
||||
password="123456",
|
||||
database="equipment_cost_db"
|
||||
)
|
||||
|
||||
@api_bp.route('/pls/predict', methods=['POST'])
|
||||
def pls_predict():
|
||||
@ -614,11 +605,11 @@ def pls_predict():
|
||||
|
||||
# 获取最新的PLS模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type = 'pls' # 只选择PLS模型
|
||||
WHERE equipment_type = ?
|
||||
AND model_type = 'pls'
|
||||
AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
@ -696,38 +687,38 @@ def update_equipment(id):
|
||||
# 更新装备基本信息
|
||||
cursor.execute("""
|
||||
UPDATE equipments
|
||||
SET name = %s, manufacturer = %s
|
||||
WHERE id = %s
|
||||
SET name = ?, manufacturer = ?
|
||||
WHERE id = ?
|
||||
""", (data['name'], data['manufacturer'], equipment_id))
|
||||
|
||||
# 更新通用参数
|
||||
cursor.execute("""
|
||||
UPDATE common_params
|
||||
SET length_m = %s, width_m = %s, height_m = %s, weight_kg = %s
|
||||
WHERE equipment_id = %s
|
||||
SET length_m = ?, width_m = ?, height_m = ?, weight_kg = ?
|
||||
WHERE equipment_id = ?
|
||||
""", (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id))
|
||||
|
||||
# 根据装备类型更新特有参数
|
||||
if data['type'] == '火箭炮':
|
||||
cursor.execute("""
|
||||
UPDATE rocket_artillery_params
|
||||
SET firing_angle_horizontal = %s,
|
||||
firing_angle_vertical = %s,
|
||||
rocket_length_m = %s,
|
||||
rocket_diameter_mm = %s,
|
||||
rocket_weight_kg = %s,
|
||||
rate_of_fire = %s,
|
||||
combat_weight_kg = %s,
|
||||
speed_kmh = %s,
|
||||
min_range_km = %s,
|
||||
max_range_km = %s,
|
||||
mobility_type = %s,
|
||||
structure_layout = %s,
|
||||
engine_model = %s,
|
||||
engine_params = %s,
|
||||
power_hp = %s,
|
||||
travel_range_km = %s
|
||||
WHERE equipment_id = %s
|
||||
SET firing_angle_horizontal = ?,
|
||||
firing_angle_vertical = ?,
|
||||
rocket_length_m = ?,
|
||||
rocket_diameter_mm = ?,
|
||||
rocket_weight_kg = ?,
|
||||
rate_of_fire = ?,
|
||||
combat_weight_kg = ?,
|
||||
speed_kmh = ?,
|
||||
min_range_km = ?,
|
||||
max_range_km = ?,
|
||||
mobility_type = ?,
|
||||
structure_layout = ?,
|
||||
engine_model = ?,
|
||||
engine_params = ?,
|
||||
power_hp = ?,
|
||||
travel_range_km = ?
|
||||
WHERE equipment_id = ?
|
||||
""", (
|
||||
data['firing_angle_horizontal'],
|
||||
data['firing_angle_vertical'],
|
||||
@ -750,17 +741,17 @@ def update_equipment(id):
|
||||
else:
|
||||
cursor.execute("""
|
||||
UPDATE loitering_munition_params
|
||||
SET wingspan_m = %s,
|
||||
warhead_weight_kg = %s,
|
||||
max_speed_ms = %s,
|
||||
cruise_speed_kmh = %s,
|
||||
endurance_min = %s,
|
||||
max_range_km = %s,
|
||||
warhead_type = %s,
|
||||
launch_mode = %s,
|
||||
power_system = %s,
|
||||
guidance_system = %s
|
||||
WHERE equipment_id = %s
|
||||
SET wingspan_m = ?,
|
||||
warhead_weight_kg = ?,
|
||||
max_speed_ms = ?,
|
||||
cruise_speed_kmh = ?,
|
||||
endurance_min = ?,
|
||||
max_range_km = ?,
|
||||
warhead_type = ?,
|
||||
launch_mode = ?,
|
||||
power_system = ?,
|
||||
guidance_system = ?
|
||||
WHERE equipment_id = ?
|
||||
""", (
|
||||
data['wingspan_m'],
|
||||
data['warhead_weight_kg'],
|
||||
@ -779,8 +770,8 @@ def update_equipment(id):
|
||||
if 'actual_cost' in data:
|
||||
cursor.execute("""
|
||||
UPDATE cost_data
|
||||
SET actual_cost = %s
|
||||
WHERE equipment_id = %s
|
||||
SET actual_cost = ?
|
||||
WHERE equipment_id = ?
|
||||
""", (data['actual_cost'], equipment_id))
|
||||
|
||||
conn.commit()
|
||||
@ -798,7 +789,7 @@ def get_equipment_details(id):
|
||||
logger.info(f"Getting details for equipment ID: {id}")
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取装备基本信息类型
|
||||
cursor.execute("""
|
||||
@ -806,7 +797,7 @@ def get_equipment_details(id):
|
||||
FROM equipments e
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE e.id = %s
|
||||
WHERE e.id = ?
|
||||
""", (id,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
@ -822,14 +813,14 @@ def get_equipment_details(id):
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
FROM rocket_artillery_params
|
||||
WHERE equipment_id = %s
|
||||
WHERE equipment_id = ?
|
||||
""", (id,))
|
||||
custom_params = cursor.fetchone()
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
FROM loitering_munition_params
|
||||
WHERE equipment_id = %s
|
||||
WHERE equipment_id = ?
|
||||
""", (id,))
|
||||
custom_params = cursor.fetchone()
|
||||
|
||||
@ -853,7 +844,7 @@ def get_datasets():
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT d.*,
|
||||
COUNT(de.equipment_id) as equipment_count,
|
||||
@ -872,12 +863,7 @@ def get_datasets():
|
||||
else:
|
||||
dataset['equipment_names'] = []
|
||||
|
||||
# 格式化时间和数值字段
|
||||
for dataset in datasets:
|
||||
if dataset['created_at']:
|
||||
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
if dataset['updated_at']:
|
||||
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
# SQLite 已存储日期为文本格式,无需转换
|
||||
|
||||
return jsonify(datasets)
|
||||
except Exception as e:
|
||||
@ -891,14 +877,14 @@ def get_dataset(id):
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
# 获取数据集基本信息
|
||||
cursor.execute("""
|
||||
SELECT d.*,
|
||||
COUNT(de.equipment_id) as equipment_count
|
||||
FROM datasets d
|
||||
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
|
||||
WHERE d.id = %s
|
||||
WHERE d.id = ?
|
||||
GROUP BY d.id
|
||||
""", (id,))
|
||||
dataset = cursor.fetchone()
|
||||
@ -913,7 +899,7 @@ def get_dataset(id):
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
""", (id,))
|
||||
equipment = cursor.fetchall()
|
||||
|
||||
@ -952,14 +938,14 @@ def create_dataset():
|
||||
|
||||
# 1. 验证装备ID是否存在
|
||||
if 'equipment_ids' in data and data['equipment_ids']:
|
||||
# 直接从 equipment 表查询,不需要 JOIN
|
||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
||||
ids = data['equipment_ids']
|
||||
placeholders = ','.join(['?'] * len(ids))
|
||||
cursor.execute(f"""
|
||||
SELECT DISTINCT id FROM equipments
|
||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
||||
""", (data['equipment_type'],))
|
||||
WHERE id IN ({placeholders}) AND type = ?
|
||||
""", (*ids, data['equipment_type']))
|
||||
|
||||
valid_ids = [row[0] for row in cursor.fetchall()]
|
||||
valid_ids = [row['id'] for row in cursor.fetchall()]
|
||||
logger.info(f"Valid equipment IDs: {valid_ids}")
|
||||
|
||||
# 如果有无效的ID,返回错误
|
||||
@ -971,7 +957,7 @@ def create_dataset():
|
||||
# 2. 创建数据集
|
||||
cursor.execute("""
|
||||
INSERT INTO datasets (name, description, equipment_type, purpose)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (data['name'], data['description'], data['equipment_type'], data['purpose']))
|
||||
|
||||
dataset_id = cursor.lastrowid
|
||||
@ -982,7 +968,7 @@ def create_dataset():
|
||||
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
|
||||
cursor.executemany("""
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", values)
|
||||
logger.info(f"Added {len(values)} equipment associations")
|
||||
|
||||
@ -1004,14 +990,15 @@ def update_dataset(id):
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 1. 验证装备ID是否存在
|
||||
if 'equipment_ids' in data:
|
||||
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
|
||||
if 'equipment_ids' in data and data['equipment_ids']:
|
||||
ids = data['equipment_ids']
|
||||
placeholders = ','.join(['?'] * len(ids))
|
||||
cursor.execute(f"""
|
||||
SELECT id FROM equipments
|
||||
WHERE id IN ({equipment_ids_str}) AND type = %s
|
||||
""", (data['equipment_type'],))
|
||||
WHERE id IN ({placeholders}) AND type = ?
|
||||
""", (*ids, data['equipment_type']))
|
||||
|
||||
valid_ids = [row[0] for row in cursor.fetchall()]
|
||||
valid_ids = [row['id'] for row in cursor.fetchall()]
|
||||
logger.info(f"Valid equipment IDs: {valid_ids}")
|
||||
|
||||
# 如果有无效的ID,返回错误
|
||||
@ -1023,21 +1010,21 @@ def update_dataset(id):
|
||||
# 2. 更新数据集基本信息
|
||||
cursor.execute("""
|
||||
UPDATE datasets
|
||||
SET name = %s, description = %s, equipment_type = %s, purpose = %s
|
||||
WHERE id = %s
|
||||
SET name = ?, description = ?, equipment_type = ?, purpose = ?
|
||||
WHERE id = ?
|
||||
""", (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
|
||||
|
||||
# 3. 更新装备关联
|
||||
if 'equipment_ids' in data:
|
||||
# 先删除旧的关联
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
|
||||
|
||||
# 添加新的关联
|
||||
if valid_ids: # 确保有有效的ID才执行插入
|
||||
values = [(id, equipment_id) for equipment_id in valid_ids]
|
||||
cursor.executemany("""
|
||||
INSERT INTO dataset_equipments (dataset_id, equipment_id)
|
||||
VALUES (%s, %s)
|
||||
VALUES (?, ?)
|
||||
""", values)
|
||||
logger.info(f"Updated {len(values)} equipment associations")
|
||||
|
||||
@ -1058,10 +1045,10 @@ def delete_dataset(id):
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 删除装备关联
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
|
||||
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
|
||||
|
||||
# 删除数据集
|
||||
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
|
||||
cursor.execute("DELETE FROM datasets WHERE id = ?", (id,))
|
||||
|
||||
conn.commit()
|
||||
return jsonify({'success': True})
|
||||
@ -1076,10 +1063,10 @@ def get_latest_model(equipment_type):
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s AND is_active = TRUE
|
||||
WHERE equipment_type = ? AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
|
||||
@ -1095,7 +1082,7 @@ def get_models():
|
||||
"""获取模型列表"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
ORDER BY training_date DESC
|
||||
@ -1105,9 +1092,7 @@ def get_models():
|
||||
|
||||
# 格式化时间和数值字段
|
||||
for model in models:
|
||||
# 将数据库中的datetime转换为ISO格式字符串
|
||||
if model['training_date']:
|
||||
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
# SQLite 已存储日期为文本格式,无需转换
|
||||
if model['r2_score'] is not None:
|
||||
model['r2_score'] = float(model['r2_score'])
|
||||
if model['mae'] is not None:
|
||||
@ -1130,12 +1115,12 @@ def activate_model(id):
|
||||
"""激活指定的模型"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True) # 使用字典游标
|
||||
cursor = conn.cursor() # 使用字典游标
|
||||
|
||||
# 获取模型信息
|
||||
cursor.execute("""
|
||||
SELECT equipment_type, model_type FROM trained_models
|
||||
WHERE id = %s
|
||||
WHERE id = ?
|
||||
""", (id,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
@ -1146,14 +1131,14 @@ def activate_model(id):
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s AND model_type = %s
|
||||
WHERE equipment_type = ? AND model_type = ?
|
||||
""", (model['equipment_type'], model['model_type']))
|
||||
|
||||
# 激活指定模型
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = TRUE
|
||||
WHERE id = %s
|
||||
WHERE id = ?
|
||||
""", (id,))
|
||||
|
||||
conn.commit()
|
||||
@ -1176,21 +1161,21 @@ def delete_model(id):
|
||||
cursor.execute("""
|
||||
SELECT model_path, scaler_path
|
||||
FROM trained_models
|
||||
WHERE id = %s
|
||||
WHERE id = ?
|
||||
""", (id,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
if not model:
|
||||
return jsonify({'error': 'Model not found'}), 404
|
||||
|
||||
# 删除模型件
|
||||
if os.path.exists(model[0]):
|
||||
os.remove(model[0])
|
||||
if os.path.exists(model[1]):
|
||||
os.remove(model[1])
|
||||
# 删除模型文件
|
||||
if os.path.exists(model['model_path']):
|
||||
os.remove(model['model_path'])
|
||||
if os.path.exists(model['scaler_path']):
|
||||
os.remove(model['scaler_path'])
|
||||
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
|
||||
cursor.execute("DELETE FROM trained_models WHERE id = ?", (id,))
|
||||
conn.commit()
|
||||
|
||||
return jsonify({'success': True})
|
||||
@ -1231,7 +1216,7 @@ def analyze_manufacturers():
|
||||
return jsonify({'error': '请选择数据集'}), 400
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取数据集中的装备和生产商数据
|
||||
cursor.execute("""
|
||||
@ -1239,13 +1224,22 @@ def analyze_manufacturers():
|
||||
FROM manufacturers m
|
||||
JOIN equipments e ON e.manufacturer_id = m.id
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
WHERE de.dataset_id = %s
|
||||
WHERE de.dataset_id = ?
|
||||
""", (dataset_id,))
|
||||
|
||||
manufacturers = cursor.fetchall()
|
||||
|
||||
if not manufacturers:
|
||||
return jsonify({'error': '数据集中没有生产商数据'}), 404
|
||||
logger.info("No manufacturer data in dataset, returning empty result")
|
||||
return jsonify({
|
||||
'manufacturer_names': [],
|
||||
'manufacturer_tech_levels': [],
|
||||
'manufacturer_scale_levels': [],
|
||||
'manufacturer_supply_chain_levels': [],
|
||||
'manufacturer_composite_scores': [],
|
||||
'region_distribution': [],
|
||||
'manufacturer_scores': []
|
||||
})
|
||||
|
||||
# 准备分析数据
|
||||
manufacturer_names = []
|
||||
|
||||
@ -1,9 +1,5 @@
|
||||
@echo off
|
||||
set FLASK_DEBUG=false
|
||||
set MYSQL_HOST=localhost
|
||||
set MYSQL_USER=root
|
||||
set MYSQL_PASSWORD=123456
|
||||
set MYSQL_DB=equipment_cost_db
|
||||
echo Starting Cost Prediction System...
|
||||
start /B run.exe
|
||||
start http://localhost:5001
|
||||
Loading…
Reference in New Issue
Block a user