将mysql改成sqlite,减少前端依赖

This commit is contained in:
tian 2026-04-28 23:30:48 +08:00
parent 137451ba7a
commit 48ba547c36
24 changed files with 2216 additions and 10338 deletions

File diff suppressed because one or more lines are too long

108
CLAUDE.md Normal file
View File

@ -0,0 +1,108 @@
# CLAUDE.md
本文件为 Claude Codeclaude.ai/code在此仓库中工作提供指引。
## 常用命令
```bash
# 安装后端依赖(核心:无需 MySQL、无需 PyTorch
pip install -e .
pip install -e ".[dev]" # 含开发工具pytest, black, mypy
pip install -e ".[torch]" # 安装可选 PyTorch神经网络训练需要
# 启动后端Flask0.0.0.0:5001
python run.py
# 启动前端开发服务器另开终端localhost:3000
cd frontend && npm install && npm run serve
# 构建前端生产版本
cd frontend && npm run build
# 运行测试
python -m pytest tests/
python -m pytest tests/test_demo_service.py -q # 单个测试文件
python -m pytest tests/test_demo_routes.py -q
# 代码格式化 / 类型检查
black src/ tests/ # 自动格式化line-length=88
mypy src/ # 类型检查
# 运行独立演示模式(仅需 5 个依赖)
cd demo_standalone && pip install -r requirements.txt && python server.py
```
## 项目概述
基于机器学习的装备成本预测系统。支持两种装备类型:**火箭炮**和**巡飞弹**。
### 技术栈
- **后端**Python 3.9-3.11, Flask 3.1+
- **数据库**SQLitePython 内置,零外部依赖),首次启动自动建表,无需手动安装配置
- **机器学习**scikit-learn核心、XGBoost、LightGBM
- **可选依赖**PyTorch仅神经网络训练需要约 800MB
- **前端**Vue 3 (Composition API) + Element Plus + ECharts, 使用 Vite 构建
### 关键文件
| 文件 | 用途 |
|---|---|
| `run.py` | 入口点,启动 Flask 服务器 |
| `src/app.py` | Flask 应用工厂 |
| `src/routes.py` | 所有 API 路由(约 1300 行) |
| `src/model_trainer.py` | ModelTrainer 类 + CostPredictionModelPyTorch NN |
| `src/data_preparation.py` | DataPreparation + EquipmentDataset |
| `src/cost_prediction.py` | CostPredictor预测编排 |
| `src/demo_service.py` | DemoModelService基于 CSV无需数据库 |
| `src/database/db_connection.py` | SQLite 数据库连接 + 建表 DDL内置 |
| `config.py` | 运行时配置 |
| `frontend/src/router/index.js` | Vue Router共 8 个路由 |
| `frontend/src/api/index.js` | Axios API 客户端 |
## 架构
### 与旧架构的核心变化
| 项目 | 旧架构MySQL | 新架构SQLite |
|------|----------------|-----------------|
| 数据库 | MySQL 8.0+,需单独安装运行 | SQLitePython 内置,零依赖 |
| 数据库依赖 | sqlalchemy, pymysql, cryptography, mysql-connector-python | 无(仅用 Python 标准库 sqlite3 |
| PyTorch | 硬依赖(顶层 import无则崩溃 | 可选依赖try/except 保护,无 PyTorch 可启动) |
| 前端构建 | Vue CLI + Babel + SCSS | Vite无需 Babel/SCSS |
| Vuex | 存在但完全空 | 已移除 |
| 依赖总数(核心)| 15+ | 5flask, numpy, pandas, scikit-learn, openpyxl |
### 训练数据流
```
用户界面 → POST /api/train → 查询 SQLite → DataPreparation特征提取 + 标准化)
→ ModelTrainer.fit_model()(训练 XGBoost, LightGBM, RF, GBM, PyTorch NN, PLS
→ 保存最优模型 → 写入 SQLite trained_models 表
```
### 预测数据流
```
用户界面 → POST /api/predict → 从 SQLite 加载最优模型 + 标准化器 → 提取特征
→ 特征标准化 → 预测 → ±20% 置信区间 → JSON 返回
```
### 机器学习模型
| 模型 | 标识 | 说明 |
|---|---|---|
| XGBoost | `xgboost` | 针对小数据量使用保守参数 |
| LightGBM | `lightgbm` | 针对小数据量使用保守参数 |
| Random Forest | `rf` | sklearn.ensemble |
| Gradient Boosting | `gbm` | sklearn.ensemble |
| PyTorch NN | `pytorch` | 可选(需安装 torch针对不同装备类型定制网络结构 |
| PLS 回归 | `pls` | 不参与最优模型评选 |
| Linear/Ridge/SVR/KNN | 仅演示 | 用于算法对比演示 |
## 重要注意事项
- **小数据量机器学习**:所有模型超参数均为保守设置(强正则化、浅树、低学习率、早停)
- **两种装备类型**火箭炮27+ 特征和巡飞弹24+ 特征),各自有独立数据库表和神经网络结构
- **生产商议价能力特征**:技术等级、规模、供应链、地区等信息被纳入特征,地区成本乘数(如美国 1.2 倍、中国 0.8 倍)
- **SQLite 首次使用**`data/equipment_cost.db` 文件在首次数据库操作时自动创建,无需手动初始化
- **编码规范**UTF-8 无 BOMLF 换行符,文件末尾保留换行符
- **中文内容**:绝不修改中文注释或文本。编辑中文附近代码时,完整保留原有中文内容

View File

@ -2,11 +2,8 @@ import os
class Config:
"""配置类"""
# 数据库配置
MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost')
MYSQL_USER = os.getenv('MYSQL_USER', 'root')
MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')
# 数据库配置(使用 SQLite
SQLITE_DB = os.getenv('SQLITE_DB', '') # 为空则使用默认路径 data/equipment_cost.db
# Flask配置
FLASK_HOST = '0.0.0.0'

View File

@ -1,5 +0,0 @@
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}

17
frontend/index.html Normal file
View File

@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="/favicon.ico">
<title>装备成本估算系统</title>
</head>
<body>
<noscript>
<strong>装备成本估算系统需要启用 JavaScript 才能运行。</strong>
</noscript>
<div id="app"></div>
<script type="module" src="/src/main.js"></script>
</body>
</html>

11134
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -7,33 +7,24 @@
"npm": ">=8"
},
"scripts": {
"serve": "vue-cli-service serve",
"build": "vue-cli-service build",
"lint": "vue-cli-service lint"
"serve": "vite",
"build": "vite build",
"lint": "eslint --ext .js,.vue src/"
},
"dependencies": {
"axios": "^1.6.0",
"core-js": "^3.8.3",
"echarts": "^5.4.3",
"element-plus": "^2.4.2",
"vue": "^3.2.13",
"vue-router": "^4.0.3",
"vuex": "^4.0.0"
"vue-router": "^4.0.3"
},
"devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@element-plus/icons-vue": "^2.3.1",
"@vue/cli-plugin-babel": "~5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-plugin-router": "~5.0.0",
"@vue/cli-plugin-vuex": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"@vue/compiler-sfc": "^3.2.13",
"@vitejs/plugin-vue": "^5.0.0",
"eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3",
"sass": "^1.32.7",
"sass-loader": "^12.0.0"
"sass-embedded": "^1.99.0",
"vite": "^5.0.0"
},
"eslintConfig": {
"root": true,
@ -45,17 +36,12 @@
"eslint:recommended"
],
"parserOptions": {
"parser": "@babel/eslint-parser"
"ecmaVersion": "latest",
"sourceType": "module"
},
"rules": {
"vue/multi-word-component-names": "off",
"no-unused-vars": "warn"
}
},
"browserslist": [
"> 1%",
"last 2 versions",
"not dead",
"not ie 11"
]
}
}

View File

@ -1,17 +0,0 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title><%= htmlWebpackPlugin.options.title %></title>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>

View File

@ -1,7 +1,6 @@
import { createApp } from 'vue'
import App from './App.vue'
import router from './router'
import store from './store'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import './assets/styles/global.css'
@ -15,7 +14,6 @@ app.use(ElementPlus, {
size: 'default'
})
app.use(router)
app.use(store)
// 注册图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {

View File

@ -1,14 +0,0 @@
import { createStore } from 'vuex'
export default createStore({
state: {
},
getters: {
},
mutations: {
},
actions: {
},
modules: {
}
})

View File

@ -19,7 +19,7 @@ export default defineConfig({
port: 3000,
proxy: {
'/api': {
target: 'http://localhost:5000',
target: 'http://localhost:5001',
changeOrigin: true
}
}

View File

@ -10,23 +10,16 @@ dependencies = [
# Web框架
"flask>=3.1.0",
"flask-cors>=5.0.0",
# 数据库
"sqlalchemy>=2.0.36",
"pymysql>=1.1.1",
"cryptography>=43.0.0",
"mysql-connector-python>=8.0.0",
# 数据处理
"numpy>=1.26.0,<2.0.0",
"pandas>=2.2.0",
# 机器学习
"scikit-learn>=1.5.2",
"xgboost>=2.1.0",
"lightgbm>=4.5.0",
"torch==2.5.1",
# 工具
"openpyxl>=3.1.5",
"python-dotenv>=1.0.0",
@ -34,6 +27,10 @@ dependencies = [
]
[project.optional-dependencies]
# PyTorch 为可选依赖(安装约 800MB仅训练神经网络时需要
torch = [
"torch==2.5.1",
]
dev = [
# 测试工具
"pytest>=7.0",

View File

@ -1,15 +1,12 @@
flask>=3.1.0
flask-cors>=5.0.0
sqlalchemy>=2.0.36
pymysql>=1.1.1
cryptography>=43.0.0 # MySQL 8.0+ 认证需要
mysql-connector-python>=8.0.0 # 添加这行
numpy>=1.26.0,<2.0.0
pandas>=2.2.0
xgboost>=2.1.0
lightgbm>=4.5.0
scikit-learn>=1.5.2
openpyxl>=3.1.5 # 用于读取 .xlsx 文件
python-dotenv>=1.0.0 # 环境变量
openpyxl>=3.1.5
python-dotenv>=1.0.0

View File

@ -34,7 +34,6 @@ pyinstaller --clean `
--add-data "src/loitering_munition_data.sql;data" `
--add-data "src/rocket_artillery_data.sql;data" `
--add-data "src/manufacturer_data.sql;data" `
--add-data "src/schema.sql;data" `
--add-data "config.py;." `
--add-data "src;src" `
--add-data "frontend;frontend" `
@ -43,18 +42,11 @@ pyinstaller --clean `
--add-data "models;models" `
--collect-all "xgboost" `
--collect-all "lightgbm" `
--collect-all "torch" `
--collect-all "sklearn" `
--collect-all "numpy" `
--collect-all "pandas" `
--collect-all "sqlalchemy" `
--collect-all "pymysql" `
--collect-all "cryptography" `
--collect-all "flask" `
--collect-all "flask_cors" `
--hidden-import "xgboost.testing" `
--hidden-import "torch.utils.tensorboard" `
--hidden-import "pytest" `
run.py
# Copy necessary files

View File

@ -15,12 +15,6 @@ def create_app():
app = Flask(__name__)
CORS(app)
# 配置数据库连接
app.config['MYSQL_HOST'] = config.MYSQL_HOST
app.config['MYSQL_USER'] = config.MYSQL_USER
app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD
app.config['MYSQL_DB'] = config.MYSQL_DB
# 注册路由
app.register_blueprint(api_bp, url_prefix='/api')
logger.info("API blueprint registered")

View File

@ -1,5 +1,4 @@
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from scipy import stats
@ -9,6 +8,16 @@ from src.database import get_db_connection
from src.feature_analysis import FeatureAnalysis
from .logger import setup_logger
# PyTorch 为可选依赖
try:
import torch
import torch.nn as nn
_HAS_TORCH = True
except ImportError:
torch = None
nn = None
_HAS_TORCH = False
logger = setup_logger(__name__)
class CostPredictor:
@ -18,66 +27,59 @@ class CostPredictor:
self.model = None
self.feature_analyzer = FeatureAnalysis()
self.equipment_type = None
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if _HAS_TORCH:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = None
self.load_model()
def load_model(self):
"""
加载预训练模型和标准化器
"""
try:
# 创建默认模型
self._create_default_model()
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
if _HAS_TORCH:
try:
self._create_default_model()
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
self._create_default_model()
def _create_default_model(self):
"""
创建默认模型并进行初始化训练
"""
import torch.nn as nn
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed.")
class DefaultModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(input_size, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(input_size, 64), nn.ReLU(),
nn.Linear(64, 32), nn.ReLU(),
nn.Linear(32, 1)
)
def forward(self, x):
return self.layers(x)
# 创建示例数据
example_features = {
'length_m': [7.35, 10.2],
'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2],
'weight_kg': [13700, 28500],
'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8],
'height_m': [3.1, 3.2], 'weight_kg': [13700, 28500],
'max_range_km': [20.4, 70],
'firing_angle_horizontal': [102, 110],
'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1],
'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150],
'rate_of_fire': [40, 60]
'firing_angle_horizontal': [102, 110], 'firing_angle_vertical': [55, 60],
'rocket_length_m': [2.87, 4.1], 'rocket_diameter_mm': [122, 220],
'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60]
}
# 转换为 tensor
X = torch.tensor(list(example_features.values()), dtype=torch.float32).t()
y = torch.tensor([[800000], [4500000]], dtype=torch.float32)
# 训练标准化器
self.scaler_X.fit(X.numpy())
self.scaler_y.fit(y.numpy())
# 创建模型
self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮'

View File

@ -1,7 +1,5 @@
from sklearn.preprocessing import StandardScaler
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import logging
from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
@ -9,9 +7,22 @@ from .logger import setup_logger
logger = setup_logger(__name__)
class EquipmentDataset(Dataset):
# PyTorch 为可选依赖
try:
import torch
from torch.utils.data import Dataset, DataLoader
_HAS_TORCH = True
except ImportError:
torch = None
Dataset = object
DataLoader = None
_HAS_TORCH = False
class EquipmentDataset(Dataset if _HAS_TORCH else object):
"""装备数据集类"""
def __init__(self, features, targets=None):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
self.features = torch.FloatTensor(features)
self.targets = torch.FloatTensor(targets) if targets is not None else None
@ -41,7 +52,7 @@ class DataPreparation:
# 获取数据库连接
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取所有生产商数据,用于计算特征
cursor.execute("""

View File

@ -1,37 +1,227 @@
import mysql.connector
from mysql.connector import Error
import sqlite3
from contextlib import contextmanager
import os
from dotenv import load_dotenv
from ..logger import setup_logger
# 获取logger
logger = setup_logger(__name__)
# 加载环境变量
load_dotenv()
# SQLite 数据库文件路径(相对于项目根目录)
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data')
DB_PATH = os.path.join(DB_DIR, 'equipment_cost.db')
# 建表 SQL
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS equipments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
type TEXT,
manufacturer TEXT,
manufacturer_id INTEGER,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS common_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
length_m REAL,
width_m REAL,
height_m REAL,
weight_kg REAL,
max_range_km REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS rocket_artillery_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
firing_angle_horizontal REAL,
firing_angle_vertical REAL,
rocket_length_m REAL,
rocket_diameter_mm REAL,
rocket_weight_kg REAL,
rate_of_fire REAL,
combat_weight_kg REAL,
speed_kmh REAL,
min_range_km REAL,
max_range_km REAL,
mobility_type TEXT,
structure_layout TEXT,
engine_model TEXT,
engine_params TEXT,
power_hp REAL,
travel_range_km REAL,
fire_density REAL,
range_ratio REAL,
mobility_score INTEGER,
combat_readiness_score INTEGER,
deployment_score INTEGER,
terrain_adaptability_score INTEGER,
rocket_power_ratio REAL,
platform_efficiency REAL,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS loitering_munition_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
wingspan_m REAL,
warhead_weight_kg REAL,
max_speed_ms REAL,
cruise_speed_kmh REAL,
endurance_min REAL,
flight_time_min REAL,
max_range_km REAL,
max_payload_kg REAL,
ceiling_altitude_m REAL,
combat_radius_km REAL,
folded_length_mm REAL,
folded_width_mm REAL,
folded_height_mm REAL,
warhead_type TEXT,
launch_mode TEXT,
power_system TEXT,
guidance_system TEXT,
engine_power_kw REAL,
engine_thrust_n REAL,
datalink_range_km REAL,
guidance_accuracy_m REAL,
min_altitude_m REAL,
max_altitude_m REAL,
length_width_ratio REAL,
weight_range_ratio REAL,
speed_weight_ratio REAL,
guidance_system_score INTEGER,
warhead_power_score INTEGER,
warhead_type_code INTEGER,
launch_mode_code INTEGER,
power_system_code INTEGER,
guidance_system_code INTEGER,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS feature_encoding (
id INTEGER PRIMARY KEY AUTOINCREMENT,
feature_type TEXT,
feature_value TEXT,
code INTEGER,
UNIQUE(feature_type, feature_value)
);
CREATE TABLE IF NOT EXISTS cost_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
actual_cost REAL,
predicted_cost REAL,
prediction_date TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS custom_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
equipment_id INTEGER,
param_name TEXT,
param_value TEXT,
param_unit TEXT,
description TEXT,
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS datasets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT,
equipment_type TEXT NOT NULL,
purpose TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE TABLE IF NOT EXISTS dataset_equipments (
dataset_id INTEGER NOT NULL,
equipment_id INTEGER NOT NULL,
PRIMARY KEY (dataset_id, equipment_id),
FOREIGN KEY (dataset_id) REFERENCES datasets(id),
FOREIGN KEY (equipment_id) REFERENCES equipments(id)
);
CREATE TABLE IF NOT EXISTS trained_models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
model_type TEXT NOT NULL,
equipment_type TEXT NOT NULL,
model_path TEXT NOT NULL,
scaler_path TEXT NOT NULL,
r2_score REAL,
mae REAL,
rmse REAL,
feature_importance TEXT,
training_data_size INTEGER,
training_date TEXT DEFAULT (datetime('now','localtime')),
is_active INTEGER DEFAULT 0,
created_by TEXT
);
CREATE TABLE IF NOT EXISTS manufacturers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
country TEXT NOT NULL,
tech_level INTEGER NOT NULL,
scale_level INTEGER NOT NULL,
supply_chain_level INTEGER NOT NULL,
created_at TEXT DEFAULT (datetime('now','localtime')),
updated_at TEXT DEFAULT (datetime('now','localtime')),
UNIQUE(name)
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_equipment_type ON equipments(type);
CREATE INDEX IF NOT EXISTS idx_equipment_name ON equipments(name);
CREATE INDEX IF NOT EXISTS idx_cost_data_equipment ON cost_data(equipment_id);
CREATE INDEX IF NOT EXISTS idx_model_equipment_type ON trained_models(equipment_type);
CREATE INDEX IF NOT EXISTS idx_model_active ON trained_models(is_active);
CREATE INDEX IF NOT EXISTS idx_manufacturer_country ON manufacturers(country);
CREATE INDEX IF NOT EXISTS idx_manufacturer_tech_level ON manufacturers(tech_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_scale_level ON manufacturers(scale_level);
CREATE INDEX IF NOT EXISTS idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level);
CREATE INDEX IF NOT EXISTS idx_equipment_manufacturer ON equipments(manufacturer_id);
"""
def init_db():
"""初始化数据库:确保数据库文件和表存在"""
os.makedirs(DB_DIR, exist_ok=True)
conn = sqlite3.connect(DB_PATH)
conn.executescript(SCHEMA_SQL)
conn.commit()
conn.close()
logger.info(f"Database initialized at {DB_PATH}")
@contextmanager
def get_db_connection():
"""
数据库连接上下文管理器
返回的 connection 已设置 dict row_factory
以便按列名访问
"""
connection = None
conn = None
try:
connection = mysql.connector.connect(
host=os.getenv('MYSQL_HOST', 'localhost'),
user=os.getenv('MYSQL_USER', 'root'),
password=os.getenv('MYSQL_PASSWORD', '123456'),
database=os.getenv('MYSQL_DATABASE', 'equipment_cost_db')
)
logger.info("Database connection established")
yield connection
except Error as e:
logger.error(f"Error connecting to MySQL: {str(e)}")
# 确保数据库已初始化
if not os.path.exists(DB_PATH):
logger.info("Database file not found, initializing...")
init_db()
conn = sqlite3.connect(DB_PATH)
conn.row_factory = lambda c, r: {col[0]: r[idx] for idx, col in enumerate(c.description)}
conn.execute("PRAGMA foreign_keys = ON")
logger.debug("Database connection established")
yield conn
except sqlite3.Error as e:
logger.error(f"Database error: {str(e)}")
raise
finally:
if connection and connection.is_connected():
connection.close()
logger.info("Database connection closed")
if conn:
conn.close()
logger.debug("Database connection closed")

View File

@ -267,19 +267,23 @@ class FeatureAnalysis:
def calculate_manufacturer_features(self, manufacturer_data):
"""计算生产商相关的特征"""
try:
# 确保所有必要的字段都存在,使用默认值处理缺失数据
tech_level = float(manufacturer_data.get('tech_level', 0))
scale_level = float(manufacturer_data.get('scale_level', 0))
supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0))
country = manufacturer_data.get('country', '未知')
# 处理 None 值(数据库 NULL使用默认值
raw_tech = manufacturer_data.get('tech_level')
raw_scale = manufacturer_data.get('scale_level')
raw_supply = manufacturer_data.get('supply_chain_level')
tech_level = float(raw_tech) if raw_tech is not None else 0
scale_level = float(raw_scale) if raw_scale is not None else 0
supply_chain_level = float(raw_supply) if raw_supply is not None else 0
country = manufacturer_data.get('country', '未知') or '未知'
# 计算综合得分
composite_score = (
tech_level * 0.4 + # 技术水平权重最高
scale_level * 0.3 + # 规模水平次之
supply_chain_level * 0.3 # 供应链水平
)
# 计算区域系数(基于不同地区的成本差异)
region_factors = {
'美国': 1.2,
@ -292,9 +296,9 @@ class FeatureAnalysis:
'韩国': 0.9,
'日本': 1.1
}
region_factor = region_factors.get(country, 1.0)
# 记录计算过程
logger.info(f"Manufacturer features calculation:")
logger.info(f"Tech level: {tech_level}")
@ -303,7 +307,7 @@ class FeatureAnalysis:
logger.info(f"Country: {country}")
logger.info(f"Composite score: {composite_score}")
logger.info(f"Region factor: {region_factor}")
return {
'manufacturer_tech_level': tech_level,
'manufacturer_scale_level': scale_level,
@ -311,7 +315,7 @@ class FeatureAnalysis:
'manufacturer_composite_score': composite_score,
'manufacturer_region_factor': region_factor
}
except Exception as e:
logger.error(f"Error calculating manufacturer features: {str(e)}")
# 返回默认值而不是抛出异常,确保分析过程可以继续

View File

@ -27,7 +27,7 @@ def import_training_data(excel_file):
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipments
WHERE name = %s AND type = '火箭炮'
WHERE name = ? AND type = '火箭炮'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
@ -38,7 +38,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
VALUES (?, ?, ?)
""", (row['名称'], '火箭炮', row['制造商']))
equipment_id = cursor.lastrowid
@ -47,7 +47,7 @@ def import_training_data(excel_file):
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?)
""", (
equipment_id,
row['总长_m'] if pd.notna(row['总长_m']) else None,
@ -65,7 +65,7 @@ def import_training_data(excel_file):
combat_weight_kg, speed_kmh, min_range_km, mobility_type,
structure_layout, engine_model, engine_params, power_hp,
travel_range_km)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
equipment_id,
row['方向射界_度'] if pd.notna(row['方向射界_度']) else None,
@ -89,7 +89,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
VALUES (?, ?)
""", (equipment_id, row['成本_元']))
logger.info("火箭炮数据导入完成")
@ -105,8 +105,8 @@ def import_training_data(excel_file):
equipment_names.add(row['名称'])
# 检查是否已存在相同名称的装备
cursor.execute("""
SELECT id FROM equipment
WHERE name = %s AND type = '巡飞弹'
SELECT id FROM equipments
WHERE name = ? AND type = '巡飞弹'
""", (row['名称'],))
existing_equipment = cursor.fetchone()
@ -117,7 +117,7 @@ def import_training_data(excel_file):
# 插入基本信息
cursor.execute("""
INSERT INTO equipments (name, type, manufacturer)
VALUES (%s, %s, %s)
VALUES (?, ?, ?)
""", (
row['名称'],
'巡飞弹',
@ -130,7 +130,7 @@ def import_training_data(excel_file):
cursor.execute("""
INSERT INTO common_params
(equipment_id, length_m, width_m, height_m, weight_kg, max_range_km)
VALUES (%s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?)
""", (
equipment_id,
float(row['弹长_m']) if pd.notna(row['弹长_m']) else None,
@ -147,7 +147,7 @@ def import_training_data(excel_file):
cruise_speed_kmh, flight_time_min, warhead_type, launch_mode,
folded_length_mm, folded_width_mm, folded_height_mm,
power_system, guidance_system)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
equipment_id,
float(row['翼展_m']) if pd.notna(row['翼展_m']) else None,
@ -168,7 +168,7 @@ def import_training_data(excel_file):
if pd.notna(row['成本_元']):
cursor.execute("""
INSERT INTO cost_data (equipment_id, actual_cost)
VALUES (%s, %s)
VALUES (?, ?)
""", (equipment_id, float(row['成本_元'])))
logger.info("巡飞弹数据导入完成")
@ -190,48 +190,48 @@ def import_training_data(excel_file):
# 获取装备ID - 使用新的游标
logger.debug(f"查询装备ID: {equipment_name}")
with conn.cursor() as id_cursor:
id_cursor.execute("""
SELECT id FROM equipments WHERE name = %s
""", (equipment_name,))
result = id_cursor.fetchone()
id_cursor = conn.cursor()
id_cursor.execute("""
SELECT id FROM equipments WHERE name = ?
""", (equipment_name,))
result = id_cursor.fetchone()
if not result:
logger.warning(f"未找到装备: {equipment_name}")
continue
equipment_id = result[0]
equipment_id = result['id']
logger.debug(f"找到装备ID: {equipment_id}")
# 检查参数是否存在 - 使用新的游标
logger.debug(f"检查参数是否存在: equipment_id={equipment_id}, param_name='{param_name}'")
with conn.cursor() as check_cursor:
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = %s AND param_name = %s
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
check_cursor = conn.cursor()
check_cursor.execute("""
SELECT id FROM custom_params
WHERE equipment_id = ? AND param_name = ?
""", (equipment_id, param_name))
exists = check_cursor.fetchone()
if exists:
logger.warning(f"装备 '{equipment_name}' 的参数 '{param_name}' 已存在,跳过导入")
continue
# 插入新的参数 - 使用新的游标
param_value = str(row['参数值']) if pd.notna(row['参数值']) else None
param_unit = row['参数单位'] if pd.notna(row['参数单位']) else None
param_desc = row['参数说明'] if pd.notna(row['参数说明']) else None
logger.debug(f"插入新参数: value='{param_value}', unit='{param_unit}', desc='{param_desc}'")
with conn.cursor() as insert_cursor:
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (%s, %s, %s, %s, %s)
""", (
equipment_id,
param_name,
param_value,
param_unit,
insert_cursor = conn.cursor()
insert_cursor.execute("""
INSERT INTO custom_params
(equipment_id, param_name, param_value, param_unit, description)
VALUES (?, ?, ?, ?, ?)
""", (
equipment_id,
param_name,
param_value,
param_unit,
param_desc
))
logger.debug(f"成功插入参数记录")

237
src/import_sql_data.py Normal file
View File

@ -0,0 +1,237 @@
"""
导入 SQL 数据文件到 SQLite 数据库
处理 src/ 目录下的 MySQL 格式 SQL 数据文件:
- rocket_artillery_data.sql (96 条火箭炮数据)
- loitering_munition_data.sql (100 条巡飞弹数据)
- manufacturer_data.sql (生产商数据)
"""
import re
import os
import sys
from src.database.db_connection import get_db_connection, DB_PATH
from src.logger import setup_logger
logger = setup_logger(__name__)
def parse_mysql_insert(sql_text):
"""
解析 MySQL INSERT 语句返回 (table_name, columns, values_list)
columns 是列名列表, values_list 是每行的值列表
"""
# 去掉块注释 /* ... */
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
# 去掉行注释 --
sql_text = re.sub(r'--.*', '', sql_text)
results = []
# 匹配 INSERT INTO table (columns) VALUES (values), (values), ...;
pattern = re.compile(
r"INSERT\s+INTO\s+(\w+)\s*\(([^)]+)\)\s*VALUES\s*(.+?);",
re.DOTALL | re.IGNORECASE
)
for match in pattern.finditer(sql_text):
table_name = match.group(1).strip()
columns_str = match.group(2).strip()
values_str = match.group(3).strip()
# 提取列名(去掉注释部分)
columns = []
for col in columns_str.split(','):
col = col.strip()
# 去掉行内注释
col = re.sub(r'\s*--.*', '', col).strip()
columns.append(col)
# 解析 values - 处理多个用逗号分隔的 ( ... ) 元组
values_list = []
current_tuple = []
depth = 0
current_val = ''
in_string = False
string_char = None
for ch in values_str:
if in_string:
if ch == string_char:
in_string = False
current_val += ch
continue
if ch in ("'", '"'):
in_string = True
string_char = ch
current_val += ch
continue
if ch == '(':
if depth > 0:
current_val += ch
depth += 1
if depth == 1:
current_val = ''
continue
if ch == ')':
depth -= 1
if depth == 0:
current_tuple.append(current_val.strip())
values_list.append(current_tuple)
current_tuple = []
else:
current_val += ch
continue
if ch == ',' and depth == 1:
current_tuple.append(current_val.strip())
current_val = ''
continue
if depth >= 1:
current_val += ch
results.append((table_name, columns, values_list))
return results
def convert_value(val):
"""将 MySQL 值字符串转为 Python 对象"""
val = val.strip()
if val.upper() == 'NULL' or val == '':
return None
if val.upper() == 'TRUE':
return 1
if val.upper() == 'FALSE':
return 0
# 字符串
if (val.startswith("'") and val.endswith("'")) or \
(val.startswith('"') and val.endswith('"')):
return val[1:-1]
# 数字
try:
if '.' in val:
return float(val)
return int(val)
except ValueError:
return val
def import_sql_file(filepath):
"""导入单个 SQL 文件"""
logger.info(f"Reading {filepath}...")
with open(filepath, 'r', encoding='utf-8') as f:
sql_text = f.read()
parsed = parse_mysql_insert(sql_text)
total_rows = 0
with get_db_connection() as conn:
cursor = conn.cursor()
for table_name, columns, values_list in parsed:
if not values_list:
continue
# 构建参数化 INSERT
placeholders = ','.join(['?'] * len(columns))
col_names = ','.join(columns)
sql = f"INSERT OR IGNORE INTO {table_name} ({col_names}) VALUES ({placeholders})"
row_count = 0
for values in values_list:
if len(values) != len(columns):
logger.warning(f"Column mismatch in {table_name}: "
f"expected {len(columns)} values, got {len(values)}: {values}")
continue
converted = [convert_value(v) for v in values]
try:
cursor.execute(sql, converted)
if cursor.rowcount > 0:
row_count += 1
except Exception as e:
logger.warning(f"Error inserting into {table_name}: {e}")
logger.warning(f" Values: {converted}")
if row_count > 0:
logger.info(f" {table_name}: inserted {row_count} rows")
total_rows += row_count
conn.commit()
return total_rows
def run_manufacturer_update():
"""执行 manufacturer_id 更新(把 equipments.manufacturer 名称映射为 manufacturers.id"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE equipments
SET manufacturer_id = (
SELECT id FROM manufacturers WHERE name = equipments.manufacturer
)
WHERE manufacturer_id IS NULL
AND manufacturer IS NOT NULL
AND EXISTS (SELECT 1 FROM manufacturers WHERE name = equipments.manufacturer)
""")
conn.commit()
updated = cursor.rowcount
if updated > 0:
logger.info(f"Updated manufacturer_id for {updated} equipment(s)")
else:
logger.info("No manufacturer_id updates needed")
def import_all():
"""导入所有 SQL 数据文件"""
base_dir = os.path.dirname(os.path.dirname(__file__))
# 清除现有数据库,重新开始
if os.path.exists(DB_PATH):
os.remove(DB_PATH)
logger.info(f"Removed existing database: {DB_PATH}")
files = [
(os.path.join(base_dir, 'src', 'manufacturer_data.sql'), '生产商数据'),
(os.path.join(base_dir, 'src', 'rocket_artillery_data.sql'), '火箭炮数据'),
(os.path.join(base_dir, 'src', 'loitering_munition_data.sql'), '巡飞弹数据'),
]
total = 0
for filepath, label in files:
if not os.path.exists(filepath):
logger.warning(f"File not found: {filepath}")
continue
logger.info(f"正在导入 {label}...")
rows = import_sql_file(filepath)
logger.info(f" {label}: 共导入 {rows}")
total += rows
# 更新 manufacturer_id
run_manufacturer_update()
# 统计结果
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as cnt FROM equipments")
eq_count = cursor.fetchone()['cnt']
cursor.execute("SELECT type, COUNT(*) as cnt FROM equipments GROUP BY type")
by_type = {r['type']: r['cnt'] for r in cursor.fetchall()}
cursor.execute("SELECT COUNT(*) as cnt FROM manufacturers")
mf_count = cursor.fetchone()['cnt']
logger.info("=" * 50)
logger.info(f"导入完成!")
logger.info(f" 装备总计: {eq_count}")
for t, c in by_type.items():
logger.info(f" {t}: {c}")
logger.info(f" 生产商: {mf_count}")
logger.info("=" * 50)
return eq_count
if __name__ == '__main__':
import_all()

View File

@ -1,7 +1,4 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
@ -11,40 +8,43 @@ from datetime import datetime
import json
from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
from src.data_preparation import DataPreparation, EquipmentDataset
from .logger import setup_logger
import math
# PyTorch 为可选依赖
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.data_preparation import EquipmentDataset
_HAS_TORCH = True
except ImportError:
torch = None
nn = None
DataLoader = None
EquipmentDataset = None
_HAS_TORCH = False
logger = setup_logger(__name__)
class CostPredictionModel(nn.Module):
# 条件基类:有 PyTorch 时继承 nn.Module否则继承 object
_PYTORCH_BASE = nn.Module if _HAS_TORCH else object
class CostPredictionModel(_PYTORCH_BASE):
def __init__(self, input_size, equipment_type):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
super().__init__()
self.equipment_type = equipment_type
if equipment_type == '火箭炮':
# 火箭炮使用更简单和稳定的网络结构
self.net = nn.Sequential(
# 第一层:特征映射
nn.Linear(input_size, 32),
nn.ReLU(),
nn.BatchNorm1d(32),
# 第二层:特征提取
nn.Linear(32, 16),
nn.ReLU(),
nn.BatchNorm1d(16),
# 第三层:特征整合
nn.Linear(16, 8),
nn.ReLU(),
nn.BatchNorm1d(8),
# 输出层
nn.Linear(input_size, 32), nn.ReLU(), nn.BatchNorm1d(32),
nn.Linear(32, 16), nn.ReLU(), nn.BatchNorm1d(16),
nn.Linear(16, 8), nn.ReLU(), nn.BatchNorm1d(8),
nn.Linear(8, 1)
)
# 使用正交初始化
def init_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=0.5)
@ -52,68 +52,39 @@ class CostPredictionModel(nn.Module):
elif isinstance(m, nn.BatchNorm1d):
torch.nn.init.constant_(m.weight, 0.5)
torch.nn.init.constant_(m.bias, 0.0)
self.net.apply(init_weights)
else: # 巡飞弹保持原有结构
# 生产商特征网络 - 更简单的结构
else:
self.manufacturer_net = nn.Sequential(
nn.Linear(5, 4),
nn.ReLU(),
nn.BatchNorm1d(4),
nn.Dropout(0.2)
nn.Linear(5, 4), nn.ReLU(), nn.BatchNorm1d(4), nn.Dropout(0.2)
)
# 巡飞弹特征网络 - 较深的结构
self.equipment_net = nn.Sequential(
nn.Linear(input_size - 5, 64),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(64),
nn.Dropout(0.2),
nn.Linear(64, 32),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(32),
nn.Dropout(0.2),
nn.Linear(32, 16),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(16),
nn.Dropout(0.2)
nn.Linear(input_size - 5, 64), nn.LeakyReLU(0.1), nn.BatchNorm1d(64), nn.Dropout(0.2),
nn.Linear(64, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2)
)
# 合并网络 - 较复杂的结构
self.combined_net = nn.Sequential(
nn.Linear(20, 32), # 4 + 16 = 20
nn.LeakyReLU(0.1),
nn.BatchNorm1d(32),
nn.Dropout(0.2),
nn.Linear(32, 16),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(16),
nn.Dropout(0.2),
nn.Linear(16, 8),
nn.LeakyReLU(0.1),
nn.BatchNorm1d(8),
nn.Linear(20, 32), nn.LeakyReLU(0.1), nn.BatchNorm1d(32), nn.Dropout(0.2),
nn.Linear(32, 16), nn.LeakyReLU(0.1), nn.BatchNorm1d(16), nn.Dropout(0.2),
nn.Linear(16, 8), nn.LeakyReLU(0.1), nn.BatchNorm1d(8),
nn.Linear(8, 1)
)
def forward(self, x):
if self.equipment_type == '火箭炮':
return self.net(x)
else:
# 分离特征
manufacturer_features = x[:, -5:]
equipment_features = x[:, :-5]
# 特征处理
manu_out = self.manufacturer_net(manufacturer_features)
equip_out = self.equipment_net(equipment_features)
# 特征融合
combined = torch.cat([equip_out, manu_out], dim=1)
return self.combined_net(combined)
class ModelTrainer:
def __init__(self):
if not _HAS_TORCH:
raise ImportError("PyTorch is not installed. Install with: pip install torch")
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None
self.feature_scaler = None
@ -307,7 +278,7 @@ class ModelTrainer:
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type != %s
WHERE equipment_type = ? AND model_type != ?
""", (equipment_type, 'pls'))
# 保存新模型记录
@ -316,7 +287,7 @@ class ModelTrainer:
model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
""", (
f"{equipment_type}_{timestamp}",
'pytorch',
@ -343,10 +314,10 @@ class ModelTrainer:
try:
# 从数据库获取最新的激活模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE
WHERE equipment_type = ? AND model_type = ? AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type, model_type))
model_record = cursor.fetchone()
@ -707,13 +678,13 @@ class ModelTrainer:
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type != %s
WHERE equipment_type = ? AND model_type != ?
""", (equipment_type, 'pls'))
else:
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type = %s
WHERE equipment_type = ? AND model_type = ?
""", (equipment_type, 'pls'))
# 保存新模型记录
@ -722,7 +693,7 @@ class ModelTrainer:
model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
) VALUES (?, ?, ?, ?, ?, datetime('now','localtime'), TRUE, ?, ?, ?, ?)
""", (
f"{equipment_type}_{model_type}_{timestamp}",
model_type,

View File

@ -4,16 +4,23 @@ from .feature_analysis import FeatureAnalysis
import pandas as pd
from datetime import datetime
import numpy as np
import mysql.connector
from sklearn.metrics import mean_absolute_error
from .create_template import create_excel_template
import json
import os
from .data_preparation import DataPreparation
from .model_trainer import ModelTrainer
from .database import get_db_connection
from .demo_service import DemoModelService
from .logger import setup_logger
import torch
# PyTorch 为可选导入
try:
import torch
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
# 创建蓝图
api_bp = Blueprint('api', __name__)
@ -71,7 +78,7 @@ def demo_algorithms():
@api_bp.route('/demo/dataset', methods=['GET'])
def demo_dataset():
"""Return the local demo dataset summary without using MySQL."""
"""Return the local demo dataset summary without using a database."""
try:
service = DemoModelService()
return jsonify(service.get_dataset_summary())
@ -101,11 +108,11 @@ def predict():
# 获取最新的激活模型非PLS模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type != 'pls' # 明确排除PLS模型
WHERE equipment_type = ?
AND model_type != 'pls'
AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
@ -147,14 +154,14 @@ def analyze_features():
return jsonify({'error': '请选择数据集'}), 400
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 首先获取数据集的装备类型
cursor.execute("""
SELECT DISTINCT e.type
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
LIMIT 1
""", (dataset_id,))
@ -181,7 +188,7 @@ def analyze_features():
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (dataset_id,))
else:
@ -205,7 +212,7 @@ def analyze_features():
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (dataset_id,))
@ -247,27 +254,27 @@ def analyze_features():
# 添加装备特有的分析数据
if equipment_data[0]['type'] == '火箭炮':
rocket_data = {
'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data],
'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data],
'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data],
'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data],
'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data],
'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data]
'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data],
'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data],
'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data],
'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data],
'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data],
'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data]
}
analysis_result.update(rocket_data)
else:
missile_data = {
'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data],
'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data],
'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data],
'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data],
'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data],
'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data],
'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data],
'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data],
'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data],
'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data],
'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data]
'length_width_ratio': [float(item['length_width_ratio']) if item['length_width_ratio'] is not None else 0 for item in equipment_data],
'weight_range_ratio': [float(item['weight_range_ratio']) if item['weight_range_ratio'] is not None else 0 for item in equipment_data],
'speed_weight_ratio': [float(item['speed_weight_ratio']) if item['speed_weight_ratio'] is not None else 0 for item in equipment_data],
'guidance_system_score': [float(item['guidance_system_score']) if item['guidance_system_score'] is not None else 0 for item in equipment_data],
'warhead_power_score': [float(item['warhead_power_score']) if item['warhead_power_score'] is not None else 0 for item in equipment_data],
'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item['guidance_accuracy_m'] is not None else 0 for item in equipment_data],
'datalink_range_km': [float(item['datalink_range_km']) if item['datalink_range_km'] is not None else 0 for item in equipment_data],
'max_altitude_m': [float(item['max_altitude_m']) if item['max_altitude_m'] is not None else 0 for item in equipment_data],
'min_altitude_m': [float(item['min_altitude_m']) if item['min_altitude_m'] is not None else 0 for item in equipment_data],
'engine_power_kw': [float(item['engine_power_kw']) if item['engine_power_kw'] is not None else 0 for item in equipment_data],
'engine_thrust_n': [float(item['engine_thrust_n']) if item['engine_thrust_n'] is not None else 0 for item in equipment_data]
}
analysis_result.update(missile_data)
@ -295,7 +302,7 @@ def train_model():
# 获取训练数据
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取训练集数据(包含生产商信息)
if equipment_type == '火箭炮':
@ -309,7 +316,7 @@ def train_model():
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,))
else:
@ -323,7 +330,7 @@ def train_model():
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,))
@ -342,7 +349,7 @@ def train_model():
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,))
else:
@ -356,7 +363,7 @@ def train_model():
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,))
validation_data = cursor.fetchall()
@ -476,52 +483,48 @@ def get_equipment_data():
"""获取装备数据列表"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取所有装备数据(使用equipment_id替代id)
cursor.execute("""
SELECT e.id as equipment_id, e.name, e.type, e.manufacturer,
cp.length_m, cp.width_m, cp.height_m, cp.weight_kg,
cd.actual_cost, cd.predicted_cost,
CASE
CASE
WHEN e.type = '火箭炮' THEN (
SELECT CONCAT(
firing_angle_horizontal, ',',
firing_angle_vertical, ',',
rocket_length_m, ',',
rocket_diameter_mm, ',',
rocket_weight_kg, ',',
rate_of_fire, ',',
combat_weight_kg, ',',
speed_kmh, ',',
min_range_km, ',',
max_range_km, ',',
mobility_type, ',',
structure_layout, ',',
engine_model, ',',
engine_params, ',',
power_hp, ',',
travel_range_km
)
SELECT firing_angle_horizontal || ',' ||
firing_angle_vertical || ',' ||
rocket_length_m || ',' ||
rocket_diameter_mm || ',' ||
rocket_weight_kg || ',' ||
rate_of_fire || ',' ||
combat_weight_kg || ',' ||
speed_kmh || ',' ||
min_range_km || ',' ||
max_range_km || ',' ||
mobility_type || ',' ||
structure_layout || ',' ||
engine_model || ',' ||
engine_params || ',' ||
power_hp || ',' ||
travel_range_km
FROM rocket_artillery_params
WHERE equipment_id = e.id
)
WHEN e.type = '巡飞弹' THEN (
SELECT CONCAT(
wingspan_m, ',',
warhead_weight_kg, ',',
max_speed_ms, ',',
cruise_speed_kmh, ',',
endurance_min, ',',
max_range_km, ',',
max_payload_kg, ',',
ceiling_altitude_m, ',',
combat_radius_km, ',',
warhead_type, ',',
launch_mode, ',',
power_system, ',',
guidance_system
)
SELECT wingspan_m || ',' ||
warhead_weight_kg || ',' ||
max_speed_ms || ',' ||
cruise_speed_kmh || ',' ||
endurance_min || ',' ||
max_range_km || ',' ||
max_payload_kg || ',' ||
ceiling_altitude_m || ',' ||
combat_radius_km || ',' ||
warhead_type || ',' ||
launch_mode || ',' ||
power_system || ',' ||
guidance_system
FROM loitering_munition_params
WHERE equipment_id = e.id
)
@ -548,19 +551,17 @@ def delete_equipment(id):
删除装备数据
"""
try:
db = get_db_connection()
cursor = db.cursor()
# 删除相关数据
cursor.execute("DELETE FROM cost_data WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,))
cursor.execute("DELETE FROM equipments WHERE id = %s", (id,))
db.commit()
cursor.close()
db.close()
with get_db_connection() as conn:
cursor = conn.cursor()
# 删除相关数据
cursor.execute("DELETE FROM cost_data WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM common_params WHERE equipment_id = ?", (id,))
cursor.execute("DELETE FROM equipments WHERE id = ?", (id,))
conn.commit()
return jsonify({'status': 'success'})
@ -594,16 +595,6 @@ def download_template():
logger.error(f"Error creating template: {str(e)}")
return jsonify({'error': str(e)}), 500
def get_db_connection():
"""
获取数据库连接
"""
return mysql.connector.connect(
host="localhost",
user="root",
password="123456",
database="equipment_cost_db"
)
@api_bp.route('/pls/predict', methods=['POST'])
def pls_predict():
@ -614,11 +605,11 @@ def pls_predict():
# 获取最新的PLS模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type = 'pls' # 只选择PLS模型
WHERE equipment_type = ?
AND model_type = 'pls'
AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
@ -696,38 +687,38 @@ def update_equipment(id):
# 更新装备基本信息
cursor.execute("""
UPDATE equipments
SET name = %s, manufacturer = %s
WHERE id = %s
SET name = ?, manufacturer = ?
WHERE id = ?
""", (data['name'], data['manufacturer'], equipment_id))
# 更新通用参数
cursor.execute("""
UPDATE common_params
SET length_m = %s, width_m = %s, height_m = %s, weight_kg = %s
WHERE equipment_id = %s
SET length_m = ?, width_m = ?, height_m = ?, weight_kg = ?
WHERE equipment_id = ?
""", (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id))
# 根据装备类型更新特有参数
if data['type'] == '火箭炮':
cursor.execute("""
UPDATE rocket_artillery_params
SET firing_angle_horizontal = %s,
firing_angle_vertical = %s,
rocket_length_m = %s,
rocket_diameter_mm = %s,
rocket_weight_kg = %s,
rate_of_fire = %s,
combat_weight_kg = %s,
speed_kmh = %s,
min_range_km = %s,
max_range_km = %s,
mobility_type = %s,
structure_layout = %s,
engine_model = %s,
engine_params = %s,
power_hp = %s,
travel_range_km = %s
WHERE equipment_id = %s
SET firing_angle_horizontal = ?,
firing_angle_vertical = ?,
rocket_length_m = ?,
rocket_diameter_mm = ?,
rocket_weight_kg = ?,
rate_of_fire = ?,
combat_weight_kg = ?,
speed_kmh = ?,
min_range_km = ?,
max_range_km = ?,
mobility_type = ?,
structure_layout = ?,
engine_model = ?,
engine_params = ?,
power_hp = ?,
travel_range_km = ?
WHERE equipment_id = ?
""", (
data['firing_angle_horizontal'],
data['firing_angle_vertical'],
@ -750,17 +741,17 @@ def update_equipment(id):
else:
cursor.execute("""
UPDATE loitering_munition_params
SET wingspan_m = %s,
warhead_weight_kg = %s,
max_speed_ms = %s,
cruise_speed_kmh = %s,
endurance_min = %s,
max_range_km = %s,
warhead_type = %s,
launch_mode = %s,
power_system = %s,
guidance_system = %s
WHERE equipment_id = %s
SET wingspan_m = ?,
warhead_weight_kg = ?,
max_speed_ms = ?,
cruise_speed_kmh = ?,
endurance_min = ?,
max_range_km = ?,
warhead_type = ?,
launch_mode = ?,
power_system = ?,
guidance_system = ?
WHERE equipment_id = ?
""", (
data['wingspan_m'],
data['warhead_weight_kg'],
@ -779,8 +770,8 @@ def update_equipment(id):
if 'actual_cost' in data:
cursor.execute("""
UPDATE cost_data
SET actual_cost = %s
WHERE equipment_id = %s
SET actual_cost = ?
WHERE equipment_id = ?
""", (data['actual_cost'], equipment_id))
conn.commit()
@ -798,7 +789,7 @@ def get_equipment_details(id):
logger.info(f"Getting details for equipment ID: {id}")
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取装备基本信息类型
cursor.execute("""
@ -806,7 +797,7 @@ def get_equipment_details(id):
FROM equipments e
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE e.id = %s
WHERE e.id = ?
""", (id,))
result = cursor.fetchone()
@ -822,14 +813,14 @@ def get_equipment_details(id):
cursor.execute("""
SELECT *
FROM rocket_artillery_params
WHERE equipment_id = %s
WHERE equipment_id = ?
""", (id,))
custom_params = cursor.fetchone()
else:
cursor.execute("""
SELECT *
FROM loitering_munition_params
WHERE equipment_id = %s
WHERE equipment_id = ?
""", (id,))
custom_params = cursor.fetchone()
@ -853,7 +844,7 @@ def get_datasets():
"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT d.*,
COUNT(de.equipment_id) as equipment_count,
@ -872,12 +863,7 @@ def get_datasets():
else:
dataset['equipment_names'] = []
# 格式化时间和数值字段
for dataset in datasets:
if dataset['created_at']:
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
if dataset['updated_at']:
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
# SQLite 已存储日期为文本格式,无需转换
return jsonify(datasets)
except Exception as e:
@ -891,14 +877,14 @@ def get_dataset(id):
"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取数据集基本信息
cursor.execute("""
SELECT d.*,
COUNT(de.equipment_id) as equipment_count
FROM datasets d
LEFT JOIN dataset_equipments de ON d.id = de.dataset_id
WHERE d.id = %s
WHERE d.id = ?
GROUP BY d.id
""", (id,))
dataset = cursor.fetchone()
@ -913,7 +899,7 @@ def get_dataset(id):
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
""", (id,))
equipment = cursor.fetchall()
@ -952,14 +938,14 @@ def create_dataset():
# 1. 验证装备ID是否存在
if 'equipment_ids' in data and data['equipment_ids']:
# 直接从 equipment 表查询,不需要 JOIN
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
ids = data['equipment_ids']
placeholders = ','.join(['?'] * len(ids))
cursor.execute(f"""
SELECT DISTINCT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
WHERE id IN ({placeholders}) AND type = ?
""", (*ids, data['equipment_type']))
valid_ids = [row[0] for row in cursor.fetchall()]
valid_ids = [row['id'] for row in cursor.fetchall()]
logger.info(f"Valid equipment IDs: {valid_ids}")
# 如果有无效的ID返回错误
@ -971,7 +957,7 @@ def create_dataset():
# 2. 创建数据集
cursor.execute("""
INSERT INTO datasets (name, description, equipment_type, purpose)
VALUES (%s, %s, %s, %s)
VALUES (?, ?, ?, ?)
""", (data['name'], data['description'], data['equipment_type'], data['purpose']))
dataset_id = cursor.lastrowid
@ -982,7 +968,7 @@ def create_dataset():
values = [(dataset_id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
VALUES (?, ?)
""", values)
logger.info(f"Added {len(values)} equipment associations")
@ -1004,14 +990,15 @@ def update_dataset(id):
cursor = conn.cursor()
# 1. 验证装备ID是否存在
if 'equipment_ids' in data:
equipment_ids_str = ','.join(map(str, data['equipment_ids']))
if 'equipment_ids' in data and data['equipment_ids']:
ids = data['equipment_ids']
placeholders = ','.join(['?'] * len(ids))
cursor.execute(f"""
SELECT id FROM equipments
WHERE id IN ({equipment_ids_str}) AND type = %s
""", (data['equipment_type'],))
WHERE id IN ({placeholders}) AND type = ?
""", (*ids, data['equipment_type']))
valid_ids = [row[0] for row in cursor.fetchall()]
valid_ids = [row['id'] for row in cursor.fetchall()]
logger.info(f"Valid equipment IDs: {valid_ids}")
# 如果有无效的ID返回错误
@ -1023,21 +1010,21 @@ def update_dataset(id):
# 2. 更新数据集基本信息
cursor.execute("""
UPDATE datasets
SET name = %s, description = %s, equipment_type = %s, purpose = %s
WHERE id = %s
SET name = ?, description = ?, equipment_type = ?, purpose = ?
WHERE id = ?
""", (data['name'], data['description'], data['equipment_type'], data['purpose'], id))
# 3. 更新装备关联
if 'equipment_ids' in data:
# 先删除旧的关联
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
# 添加新的关联
if valid_ids: # 确保有有效的ID才执行插入
values = [(id, equipment_id) for equipment_id in valid_ids]
cursor.executemany("""
INSERT INTO dataset_equipments (dataset_id, equipment_id)
VALUES (%s, %s)
VALUES (?, ?)
""", values)
logger.info(f"Updated {len(values)} equipment associations")
@ -1058,10 +1045,10 @@ def delete_dataset(id):
cursor = conn.cursor()
# 删除装备关联
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,))
cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = ?", (id,))
# 删除数据集
cursor.execute("DELETE FROM datasets WHERE id = %s", (id,))
cursor.execute("DELETE FROM datasets WHERE id = ?", (id,))
conn.commit()
return jsonify({'success': True})
@ -1076,10 +1063,10 @@ def get_latest_model(equipment_type):
"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
WHERE equipment_type = ? AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
@ -1095,7 +1082,7 @@ def get_models():
"""获取模型列表"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
cursor.execute("""
SELECT * FROM trained_models
ORDER BY training_date DESC
@ -1105,9 +1092,7 @@ def get_models():
# 格式化时间和数值字段
for model in models:
# 将数据库中的datetime转换为ISO格式字符串
if model['training_date']:
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
# SQLite 已存储日期为文本格式,无需转换
if model['r2_score'] is not None:
model['r2_score'] = float(model['r2_score'])
if model['mae'] is not None:
@ -1130,12 +1115,12 @@ def activate_model(id):
"""激活指定的模型"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True) # 使用字典游标
cursor = conn.cursor() # 使用字典游标
# 获取模型信息
cursor.execute("""
SELECT equipment_type, model_type FROM trained_models
WHERE id = %s
WHERE id = ?
""", (id,))
model = cursor.fetchone()
@ -1146,14 +1131,14 @@ def activate_model(id):
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type = %s
WHERE equipment_type = ? AND model_type = ?
""", (model['equipment_type'], model['model_type']))
# 激活指定模型
cursor.execute("""
UPDATE trained_models
SET is_active = TRUE
WHERE id = %s
WHERE id = ?
""", (id,))
conn.commit()
@ -1176,21 +1161,21 @@ def delete_model(id):
cursor.execute("""
SELECT model_path, scaler_path
FROM trained_models
WHERE id = %s
WHERE id = ?
""", (id,))
model = cursor.fetchone()
if not model:
return jsonify({'error': 'Model not found'}), 404
# 删除模型
if os.path.exists(model[0]):
os.remove(model[0])
if os.path.exists(model[1]):
os.remove(model[1])
# 删除模型
if os.path.exists(model['model_path']):
os.remove(model['model_path'])
if os.path.exists(model['scaler_path']):
os.remove(model['scaler_path'])
# 删除数据库记录
cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
cursor.execute("DELETE FROM trained_models WHERE id = ?", (id,))
conn.commit()
return jsonify({'success': True})
@ -1231,7 +1216,7 @@ def analyze_manufacturers():
return jsonify({'error': '请选择数据集'}), 400
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor = conn.cursor()
# 获取数据集中的装备和生产商数据
cursor.execute("""
@ -1239,13 +1224,22 @@ def analyze_manufacturers():
FROM manufacturers m
JOIN equipments e ON e.manufacturer_id = m.id
JOIN dataset_equipments de ON e.id = de.equipment_id
WHERE de.dataset_id = %s
WHERE de.dataset_id = ?
""", (dataset_id,))
manufacturers = cursor.fetchall()
if not manufacturers:
return jsonify({'error': '数据集中没有生产商数据'}), 404
logger.info("No manufacturer data in dataset, returning empty result")
return jsonify({
'manufacturer_names': [],
'manufacturer_tech_levels': [],
'manufacturer_scale_levels': [],
'manufacturer_supply_chain_levels': [],
'manufacturer_composite_scores': [],
'region_distribution': [],
'manufacturer_scores': []
})
# 准备分析数据
manufacturer_names = []

View File

@ -1,9 +1,5 @@
@echo off
set FLASK_DEBUG=false
set MYSQL_HOST=localhost
set MYSQL_USER=root
set MYSQL_PASSWORD=123456
set MYSQL_DB=equipment_cost_db
echo Starting Cost Prediction System...
start /B run.exe
start http://localhost:5001