From dba9f2fcc902293476ef608d0f270a29039c73c4 Mon Sep 17 00:00:00 2001 From: Tian jianyong <11429339@qq.com> Date: Mon, 25 Nov 2024 19:58:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=20tensor=20=E6=94=B9=E4=B8=BA=20torch?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=9B=B4=E6=96=B0=E4=BE=9D=E8=B5=96=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E7=94=9F=E4=BA=A7=E5=95=86=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=92=8C=E7=89=B9=E5=BE=81=E5=88=86=E6=9E=90?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 25 - .gitignore | 39 ++ .python-version | 1 + LICENSE | 21 + README.md | 54 ++ config.py | 129 +++-- frontend/src/views/AnalysisPage.vue | 269 +++++++--- pyproject.toml | 59 +++ requirements.txt | 13 +- run.py | 40 +- scripts/setup_env.ps1 | 121 +++++ scripts/setup_env.sh | 66 +++ src/__init__.py | 4 +- src/app.py | 46 +- src/cost_prediction.py | 271 ++-------- src/data_preparation.py | 109 ++-- src/feature_analysis.py | 84 +++- src/import_data.py | 8 +- src/init_data.sql | 319 ------------ src/loitering_munition_data.sql | 67 ++- src/manufacturer_data.sql | 12 +- src/model_trainer.py | 739 +++++----------------------- src/real_data.sql | 485 ------------------ src/rocket_artillery_data.sql | 24 +- src/routes.py | 415 +++++++++------- src/schema.sql | 50 +- 26 files changed, 1378 insertions(+), 2092 deletions(-) delete mode 100644 .env.example create mode 100644 .python-version create mode 100644 LICENSE create mode 100644 pyproject.toml create mode 100644 scripts/setup_env.ps1 create mode 100755 scripts/setup_env.sh delete mode 100644 src/init_data.sql delete mode 100644 src/real_data.sql diff --git a/.env.example b/.env.example deleted file mode 100644 index 49204df..0000000 --- a/.env.example +++ /dev/null @@ -1,25 +0,0 @@ -# 数据库配置 -MYSQL_HOST=localhost -MYSQL_USER=root -MYSQL_PASSWORD=your_password_here -MYSQL_DATABASE=equipment_cost_db - -# 服务配置 -PORT=5001 -DEBUG=False - -# 日志配置 -LOG_LEVEL=INFO -LOG_DIR=logs - -# 模型配置 -MODEL_DIR=models -DATA_DIR=data - -# 安全配置 -SECRET_KEY=your_secret_key_here -ALLOWED_HOSTS=localhost,127.0.0.1 - -# 其他配置 -UPLOAD_MAX_SIZE=10485760 # 10MB in bytes -ALLOWED_FILE_TYPES=.xlsx,.xls \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6c98c42..468c727 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,42 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +.env +.venv +env/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS .DS_Store +Thumbs.db + node_modules /dist /models @@ -9,6 +47,7 @@ node_modules # local env files .env.local .env.*.local +.venv # Log files npm-debug.log* diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..b6d8b76 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.8 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1def3fa --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +# MIT License + +Copyright (c) 2024 Your Name or Your Organization + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index b04ef07..1ecc6d2 100644 --- a/README.md +++ b/README.md @@ -21,3 +21,57 @@ collation-server=utf8mb4_unicode_ci [client] default-character-set=utf8mb4 ``` + +## 环境配置 + +本项目需要 Python 3.9-3.11 版本。推荐使用 Python 3.11.8。 + +### 使用脚本自动配置(推荐) + +Unix/macOS: + +```bash +chmod +x scripts/setup_env.sh +./scripts/setup_env.sh +``` + +Windows (PowerShell): + +```powershell +Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser +.\scripts\setup_env.ps1 +``` + +### 手动配置 + +1. 安装 pyenv +2. 安装 Python 3.11.8: + + ```bash + pyenv install 3.11.8 + ``` + +3. 设置本地 Python 版本: + + ```bash + pyenv local 3.11.8 + ``` + +4. 创建虚拟环境: + + ```bash + python -m venv .venv + ``` + +5. 激活虚拟环境: + + ```bash + source .venv/bin/activate # Unix + .venv\Scripts\activate # Windows + ``` + +6. 安装依赖: + + ```bash + pip install -e ".[dev]" + ``` diff --git a/config.py b/config.py index ec92cd1..fa09fcf 100644 --- a/config.py +++ b/config.py @@ -1,32 +1,103 @@ import os -import secrets -# 数据库配置 -DATABASE_URI = "mysql+pymysql://root:123456@localhost:3306/equipment_cost_db" +class Config: + """配置类""" + # 数据库配置 + MYSQL_HOST = 'localhost' + MYSQL_USER = 'root' + MYSQL_PASSWORD = '123456' + MYSQL_DB = 'equipment_cost_db' + + # Flask配置 + FLASK_HOST = '0.0.0.0' + FLASK_PORT = 5001 + FLASK_DEBUG = True + + # 目录配置 + MODEL_DIR = 'models' + DATA_DIR = 'data' + LOG_DIR = 'logs' + UPLOAD_DIR = 'uploads' + TEMPLATE_DIR = 'templates' + + # 文件上传配置 + ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'} + MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB + + # API配置 + API_VERSION = 'v1' + API_PREFIX = f'/api/{API_VERSION}' + + # 日志配置 + LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + LOG_LEVEL = 'INFO' + LOG_FILE = os.path.join(LOG_DIR, 'app.log') + LOG_MAX_SIZE = 10 * 1024 * 1024 # 10MB + LOG_BACKUP_COUNT = 5 + + # PyTorch配置 + DEVICE = 'cpu' # 或 'cuda' 如果要使用 GPU + BATCH_SIZE = 32 + LEARNING_RATE = 0.001 + NUM_EPOCHS = 100 + + # 模型训练配置 + TRAIN_TEST_SPLIT = 0.2 + RANDOM_SEED = 42 + EARLY_STOPPING_PATIENCE = 10 + MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints') + + # 缓存配置 + CACHE_TYPE = 'simple' + CACHE_DEFAULT_TIMEOUT = 300 + + # 安全配置 + SECRET_KEY = 'your-secret-key-here' + JWT_SECRET_KEY = 'your-jwt-secret-key-here' + JWT_ACCESS_TOKEN_EXPIRES = 3600 # 1小时 + + # 跨域配置 + CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080'] + + # 数据验证配置 + MAX_EQUIPMENT_NAME_LENGTH = 100 + MAX_MANUFACTURER_NAME_LENGTH = 100 + + @classmethod + def init_app(cls, app): + """初始化应用配置""" + # 创建必要的目录 + for directory in [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR, + cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]: + os.makedirs(directory, exist_ok=True) + + # 配置日志 + import logging + from logging.handlers import RotatingFileHandler + + formatter = logging.Formatter(cls.LOG_FORMAT) + file_handler = RotatingFileHandler( + cls.LOG_FILE, + maxBytes=cls.LOG_MAX_SIZE, + backupCount=cls.LOG_BACKUP_COUNT + ) + file_handler.setFormatter(formatter) + file_handler.setLevel(cls.LOG_LEVEL) + + app.logger.addHandler(file_handler) + app.logger.setLevel(cls.LOG_LEVEL) + + # 配置上传目录 + app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR + app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH + + # 配置跨域 + from flask_cors import CORS + CORS(app, resources={ + r"/api/*": {"origins": cls.CORS_ORIGINS} + }) + + return app -# 安全密钥配置(自动生成随机密钥) -SECRET_KEY = secrets.token_hex(16) - -# 环境配置 -DEBUG = False -ENV = 'production' - -# 文件上传配置 -UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads') -ALLOWED_EXTENSIONS = {'csv', 'xlsx', 'xls', 'json'} -MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB 最大上传限制 - -# API配置 -API_VERSION = 'v1' -API_PREFIX = f'/api/{API_VERSION}' - -# 跨域配置 -CORS_ORIGINS = [ - "http://localhost:8080", - "http://127.0.0.1:8080", -] - -# 日志配置 -LOG_LEVEL = 'DEBUG' -LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' -LOG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs/app.log') \ No newline at end of file +# 创建配置实例 +config = Config() \ No newline at end of file diff --git a/frontend/src/views/AnalysisPage.vue b/frontend/src/views/AnalysisPage.vue index 9a77831..ea7c201 100644 --- a/frontend/src/views/AnalysisPage.vue +++ b/frontend/src/views/AnalysisPage.vue @@ -103,7 +103,31 @@
+ + +

制导性能分析

+
+
+
+ + +

生产商分析

+
+
+
+ + +

生产商地区分布

+
+
+
+ + +

生产商综合评分

+
+
+
@@ -131,6 +155,10 @@ const newFeatureChartRef = ref(null) const engineChartRef = ref(null) const fireChartRef = ref(null) const mobilityChartRef = ref(null) +const manufacturerChartRef = ref(null) +const regionChartRef = ref(null) +const scoreChartRef = ref(null) +const guidanceChartRef = ref(null) // 图表实例引用 const importanceChart = ref(null) @@ -139,6 +167,10 @@ const newFeatureChart = ref(null) const engineChart = ref(null) const fireChart = ref(null) const mobilityChart = ref(null) +const manufacturerChart = ref(null) +const regionChart = ref(null) +const scoreChart = ref(null) +const guidanceChart = ref(null) // 监听分析结果变化 watch(() => analysisResult.value, async (newResult) => { @@ -236,69 +268,28 @@ const startAnalysis = async () => { analyzing.value = true try { - // 打印请求参数 - console.log('Analysis request params:', { - dataset_id: analysisForm.value.dataset_id, - equipment_type: analysisForm.value.equipment_type - }) - - const response = await axios.post(`${API_BASE_URL}/analyze-features`, { + // 调用特征分析接口 + const featureResponse = await axios.post(`${API_BASE_URL}/analyze-features`, { dataset_id: analysisForm.value.dataset_id }) - // 打印原始响应数据 - console.log('Raw API response:', response) - console.log('Response data type:', typeof response.data) - console.log('Response data:', response.data) - - // 检查响应数据的结构 - if (!response.data) { - throw new Error('API返回的数据为空') - } - - // 确保数据正确赋值 - analysisResult.value = response.data - - // 验证数赋值是否成功 - console.log('Analysis result after assignment:', { - value: analysisResult.value, - important_features: analysisResult.value?.important_features, - correlation_analysis: analysisResult.value?.correlation_analysis, - equipment_names: analysisResult.value?.equipment_names, - length_width_ratio: analysisResult.value?.length_width_ratio + // 调用生产商分析接口 + const manufacturerResponse = await axios.post(`${API_BASE_URL}/analyze-manufacturers`, { + dataset_id: analysisForm.value.dataset_id }) - - // 如果是巡飞弹类型,检查特定数据 - if (analysisForm.value.equipment_type === '巡飞弹') { - const missileData = { - equipment_names: analysisResult.value?.equipment_names || [], - length_width_ratio: analysisResult.value?.length_width_ratio || [], - engine_power_kw: analysisResult.value?.engine_power_kw || [], - guidance_system_score: analysisResult.value?.guidance_system_score || [], - warhead_power_score: analysisResult.value?.warhead_power_score || [] - } - - console.log('Missile specific data:', missileData) - - // 验证数据完整性 - const missingFields = Object.entries(missileData) - .filter(([key, value]) => !Array.isArray(value) || value.length === 0) - .map(([key]) => key) - - if (missingFields.length > 0) { - console.warn('Missing or empty missile data fields:', missingFields) - ElMessage.warning(`数据不完整,缺少字段: ${missingFields.join(', ')}`) - } + + // 合并两个接口的结果 + analysisResult.value = { + ...featureResponse.data, + ...manufacturerResponse.data } - + + // 验证数据 + console.log('Combined analysis result:', analysisResult.value) + } catch (error) { console.error('Analysis error:', error) - console.error('Error details:', { - message: error.message, - response: error.response?.data, - status: error.response?.status - }) - ElMessage.error(error.message || '特征析失败') + ElMessage.error(error.message || '分析失败') } finally { analyzing.value = false } @@ -329,6 +320,18 @@ const createResizeHandler = () => { if (mobilityChart.value && !mobilityChart.value.isDisposed()) { mobilityChart.value.resize() } + if (manufacturerChart.value && !manufacturerChart.value.isDisposed()) { + manufacturerChart.value.resize() + } + if (regionChart.value && !regionChart.value.isDisposed()) { + regionChart.value.resize() + } + if (scoreChart.value && !scoreChart.value.isDisposed()) { + scoreChart.value.resize() + } + if (guidanceChart.value && !guidanceChart.value.isDisposed()) { + guidanceChart.value.resize() + } } catch (error) { console.error('Error in resize handler:', error) } @@ -360,7 +363,7 @@ onUnmounted(() => { // 销毁所有图表实例 [importanceChart, correlationChart, newFeatureChart, engineChart, - fireChart, mobilityChart].forEach(chart => { + fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => { if (chart.value && !chart.value.isDisposed()) { try { chart.value.dispose() @@ -384,7 +387,7 @@ const renderCharts = () => { try { // 先销毁所有现有的图表实例 [importanceChart, correlationChart, newFeatureChart, engineChart, - fireChart, mobilityChart].forEach(chart => { + fireChart, mobilityChart, manufacturerChart, regionChart, scoreChart, guidanceChart].forEach(chart => { if (chart.value && !chart.value.isDisposed()) { chart.value.dispose() chart.value = null @@ -899,6 +902,156 @@ const renderCharts = () => { mobilityChart.value.setOption(mobilityOption, { notMerge: true }) } + // 渲染生产商分析图表 + if (manufacturerChartRef.value) { + manufacturerChart.value = echarts.init(manufacturerChartRef.value) + const manufacturerOption = { + title: { text: '生产商特征影响分析' }, + tooltip: { + trigger: 'axis', + axisPointer: { type: 'shadow' } + }, + legend: { + data: ['技术水平', '规模水平', '供应链水平', '综合得分'] + }, + xAxis: { + type: 'category', + data: analysisResult.value.manufacturer_names || [] + }, + yAxis: { + type: 'value', + name: '评分', + min: 0, + max: 10 + }, + series: [ + { + name: '技术水平', + type: 'bar', + data: analysisResult.value.manufacturer_tech_levels || [] + }, + { + name: '规模水平', + type: 'bar', + data: analysisResult.value.manufacturer_scale_levels || [] + }, + { + name: '供应链水平', + type: 'bar', + data: analysisResult.value.manufacturer_supply_chain_levels || [] + }, + { + name: '综合得分', + type: 'line', + data: analysisResult.value.manufacturer_composite_scores || [] + } + ] + } + manufacturerChart.value.setOption(manufacturerOption) + } + + // 渲染地区分布图表 + if (regionChartRef.value) { + regionChart.value = echarts.init(regionChartRef.value) + const regionOption = { + title: { text: '生产商地区分布' }, + tooltip: { + trigger: 'item', + formatter: '{b}: {c} ({d}%)' + }, + series: [ + { + type: 'pie', + radius: '65%', + data: analysisResult.value.region_distribution || [], + emphasis: { + itemStyle: { + shadowBlur: 10, + shadowOffsetX: 0, + shadowColor: 'rgba(0, 0, 0, 0.5)' + } + } + } + ] + } + regionChart.value.setOption(regionOption) + } + + // 渲染综合评分图表 + if (scoreChartRef.value) { + scoreChart.value = echarts.init(scoreChartRef.value) + const scoreOption = { + title: { text: '生产商综合评分雷达图' }, + tooltip: {}, + radar: { + indicator: [ + { name: '技术水平', max: 10 }, + { name: '规模水平', max: 10 }, + { name: '供应链水平', max: 10 }, + { name: '区域系数', max: 1.5 }, + { name: '综合得分', max: 10 } + ] + }, + series: [ + { + type: 'radar', + data: analysisResult.value.manufacturer_scores || [] + } + ] + } + scoreChart.value.setOption(scoreOption) + } + + // 渲染制导性能分析图表 + if (guidanceChartRef.value && analysisForm.value.equipment_type === '巡飞弹') { + guidanceChart.value = echarts.init(guidanceChartRef.value) + const guidanceOption = { + title: { text: '制导性能分析' }, + tooltip: { + trigger: 'axis', + axisPointer: { type: 'cross' } + }, + legend: { + data: ['制导精度(m)', '数据链距离(km)', '制导系统评分'] + }, + xAxis: { + type: 'category', + data: analysisResult.value.equipment_names || [] + }, + yAxis: [ + { + type: 'value', + name: '制导精度(m)', + position: 'left' + }, + { + type: 'value', + name: '距离(km)', + position: 'right' + } + ], + series: [ + { + name: '制导精度(m)', + type: 'bar', + data: analysisResult.value.guidance_accuracy_m || [] + }, + { + name: '数据链距离(km)', + type: 'line', + yAxisIndex: 1, + data: analysisResult.value.datalink_range_km || [] + }, + { + name: '制导系统评分', + type: 'line', + data: analysisResult.value.guidance_system_score || [] + } + ] + } + guidanceChart.value.setOption(guidanceOption) + } + console.log('Charts rendered successfully') } catch (error) { console.error('Error in chart rendering:', error) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7e6e325 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[project] +name = "cost-prediction" +version = "0.1.0" +description = "装备成本预测系统" +requires-python = ">=3.9,<3.12" +readme = "README.md" +license = {file = "LICENSE"} + +dependencies = [ + # Web框架 + "flask>=3.1.0", + "flask-cors>=5.0.0", + + # 数据库 + "sqlalchemy>=2.0.36", + "pymysql>=1.1.1", + "cryptography>=43.0.0", + "mysql-connector-python>=8.0.0", + + # 数据处理 + "numpy>=1.26.0,<2.0.0", + "pandas>=2.2.0", + + # 机器学习 + "scikit-learn>=1.5.2", + "torch==2.5.1", + "torchvision==0.20.1", + "torchaudio==2.5.1", + + # 工具 + "openpyxl>=3.1.5", # Excel支持 + "python-dotenv>=1.0.0", # 环境变量 +] + +[project.optional-dependencies] +dev = [ + # 测试工具 + "pytest>=7.0", + "black>=22.0", # 代码格式化 + "mypy>=1.0", # 类型检查 +] + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] + +[tool.black] +line-length = 88 +target-version = ["py39", "py310", "py311"] + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a3b4ab7..f168f07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,12 +3,11 @@ flask-cors>=5.0.0 sqlalchemy>=2.0.36 pymysql>=1.1.1 cryptography>=43.0.0 # MySQL 8.0+ 认证需要 -numpy>=2.0.2 -pandas>=2.2.3 - -urllib3>=2.2.3 -openpyxl>=3.1.5 # 用于读取 .xlsx 文件 -xlrd>=2.0.1 # 用于读取 .xls 文件 +mysql-connector-python>=8.0.0 # 添加这行 +numpy>=1.26.0,<2.0.0 +pandas>=2.2.0 scikit-learn>=1.5.2 -tensorflow>=2.18.0 \ No newline at end of file + +openpyxl>=3.1.5 # 用于读取 .xlsx 文件 +python-dotenv>=1.0.0 # 环境变量 \ No newline at end of file diff --git a/run.py b/run.py index 5bca7f6..4a56ece 100644 --- a/run.py +++ b/run.py @@ -1,13 +1,33 @@ -from src.app import create_app -import logging +from src import create_app +from src.logger import setup_logger +from config import config +import os -# 创建应用实例 -app = create_app() +logger = setup_logger(__name__) + +def main(): + try: + # 创建必要的目录 + os.makedirs(config.MODEL_DIR, exist_ok=True) + os.makedirs(config.LOG_DIR, exist_ok=True) + os.makedirs(config.DATA_DIR, exist_ok=True) + + # 创建并运行应用 + app = create_app() + + logger.info(f"Starting server in {'debug' if config.FLASK_DEBUG else 'production'} mode") + logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}") + + app.run( + host=config.FLASK_HOST, + port=config.FLASK_PORT, + debug=config.FLASK_DEBUG + ) + + except Exception as e: + logger.error(f"Error starting application: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + raise if __name__ == '__main__': - # 设置日志 - logging.basicConfig(level=logging.INFO) - logging.info('=== Server Starting ===') - logging.info('Initializing directories...') - - app.run(host='0.0.0.0', port=5001, debug=True) \ No newline at end of file + main() \ No newline at end of file diff --git a/scripts/setup_env.ps1 b/scripts/setup_env.ps1 new file mode 100644 index 0000000..2807d0f --- /dev/null +++ b/scripts/setup_env.ps1 @@ -0,0 +1,121 @@ +# 设置错误操作首选项 +$ErrorActionPreference = "Stop" + +# 检查管理员权限 +$isAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator) +if (-not $isAdmin) { + Write-Warning "建议使用管理员权限运行此脚本" + Start-Sleep -Seconds 3 +} + +# 检查 pyenv-win 是否安装 +if (!(Get-Command pyenv -ErrorAction SilentlyContinue)) { + Write-Host "pyenv not found. Installing..." + try { + # 下载并安装 pyenv-win + Invoke-WebRequest -UseBasicParsing -Uri "https://raw.githubusercontent.com/pyenv-win/pyenv-win/master/pyenv-win/install-pyenv-win.ps1" -OutFile "./install-pyenv-win.ps1" + & ./install-pyenv-win.ps1 + + # 添加环境变量 + $env:PYENV = "$env:USERPROFILE\.pyenv\pyenv-win" + $env:Path = "$env:PYENV\bin;$env:PYENV\shims;$env:Path" + + # 刷新环境变量 + $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + } + catch { + Write-Error "Failed to install pyenv: $_" + exit 1 + } +} + +try { + # 安装指定版本的 Python + Write-Host "Installing Python 3.11.8..." + pyenv install 3.11.8 + if ($LASTEXITCODE -ne 0) { + throw "Failed to install Python 3.11.8" + } + + # 设置本地 Python 版本 + Write-Host "Setting local Python version..." + pyenv local 3.11.8 + if ($LASTEXITCODE -ne 0) { + throw "Failed to set local Python version" + } + + # 验证 Python 版本 + $pythonVersion = python -V + if (-not $pythonVersion.Contains("3.11.8")) { + throw "Wrong Python version: $pythonVersion" + } + Write-Host "Using Python version: $pythonVersion" + + # 创建虚拟环境 + Write-Host "Creating virtual environment..." + python -m venv .venv + + # 激活虚拟环境 + Write-Host "Activating virtual environment..." + .\.venv\Scripts\Activate.ps1 + + # 升级 pip 和构建工具 + Write-Host "Upgrading pip and build tools..." + python -m pip install --upgrade pip setuptools wheel + + # 分步安装依赖以确保正确的顺序和版本 + Write-Host "Installing database dependencies..." + pip install mysql-connector-python==8.0.33 + + Write-Host "Installing PyTorch and related packages..." + pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu + + Write-Host "Installing basic dependencies..." + pip install numpy==1.26.4 pandas==2.2.1 + + Write-Host "Installing machine learning packages..." + pip install scikit-learn==1.5.2 + + # 安装开发依赖 + Write-Host "Installing development dependencies..." + pip install -e ".[dev]" + if ($LASTEXITCODE -ne 0) { + Write-Warning "Failed to install development dependencies. Installing core package..." + pip install -e . + } + + # 验证安装 + Write-Host "Verifying installations..." + python -c "import torch; print(f'PyTorch version: {torch.__version__}')" + python -c "import numpy; print(f'NumPy version: {numpy.__version__}')" + python -c "import pandas; print(f'Pandas version: {pandas.__version__}')" + python -c "import sklearn; print(f'Scikit-learn version: {sklearn.__version__}')" + + Write-Host "Environment setup complete!" -ForegroundColor Green +} +catch { + Write-Error "An error occurred: $_" + exit 1 +} +finally { + # 清理临时文件 + if (Test-Path "./install-pyenv-win.ps1") { + Remove-Item "./install-pyenv-win.ps1" + } +} + +# 显示使用说明 +Write-Host @" + +环境设置完成!使用说明: +1. 虚拟环境已激活,命令提示符前应该显示 (.venv) +2. 要退出虚拟环境,运行: deactivate +3. 要重新激活虚拟环境,运行: .\.venv\Scripts\Activate.ps1 +4. 项目依赖已安装,可以开始开发了 + +如果遇到问题,请检查: +- Python 版本: python -V +- PyTorch 安装: python -c "import torch; print(torch.__version__)" +- 虚拟环境状态: 确保看到 (.venv) 前缀 + +"@ -ForegroundColor Cyan \ No newline at end of file diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 0000000..159a2db --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# 检查 pyenv 是否安装 +if ! command -v pyenv &> /dev/null; then + echo "pyenv not found. Installing..." + if [[ "$OSTYPE" == "darwin"* ]]; then + brew install pyenv + else + curl https://pyenv.run | bash + fi +fi + +# 安装指定版本的 Python +pyenv install 3.11.8 || true + +# 设置本地 Python 版本 +pyenv local 3.11.8 + +# 确保使用正确的 Python 版本 +eval "$(pyenv init -)" +pyenv shell 3.11.8 + +# 验证 Python 版本 +python_version=$(python -V 2>&1) +if [[ $python_version != *"3.11.8"* ]]; then + echo "Error: Wrong Python version: $python_version" + echo "Please ensure pyenv is properly configured in your shell" + exit 1 +fi + +# 创建虚拟环境 +python -m venv .venv + +# 激活虚拟环境 +source .venv/bin/activate + +# 升级 pip 和构建工具 +python -m pip install --upgrade pip setuptools wheel + +# 分步安装依赖以确保正确的顺序和版本 +echo "Installing database dependencies..." +pip install mysql-connector-python==8.0.33 + +echo "Installing PyTorch and related packages..." +pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu + +echo "Installing basic dependencies..." +pip install numpy==1.26.4 pandas==2.2.1 + +echo "Installing machine learning packages..." +pip install scikit-learn==1.5.2 + +# 安装开发依赖 +if ! pip install -e ".[dev]"; then + echo "Warning: Failed to install development dependencies. Installing core package..." + pip install -e . +fi + +# 验证安装 +echo "Verifying Python version..." +python --version + +echo "Verifying PyTorch installation..." +python -c "import torch; print(f'PyTorch version: {torch.__version__}')" + +echo "Environment setup complete!" \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 497b4a4..4ea24f5 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1 +1,3 @@ -# 这个文件可以为空,但必须存在 +from .app import create_app + +__all__ = ['create_app'] diff --git a/src/app.py b/src/app.py index 037fb79..e9fcbea 100644 --- a/src/app.py +++ b/src/app.py @@ -2,49 +2,35 @@ from flask import Flask from flask_cors import CORS from .routes import api_bp from .logger import setup_logger +from config import config import os -# 获取logger logger = setup_logger(__name__) def create_app(): - """ - 创建并配置Flask应用 - """ + """创建并配置 Flask 应用""" try: - # 创建必要的目录 - os.makedirs('logs', exist_ok=True) - os.makedirs('data', exist_ok=True) - os.makedirs('models', exist_ok=True) - - logger.info("=== Server Starting ===") - logger.info("Initializing directories...") - - # 创建Flask应用 app = Flask(__name__) - - # 配置CORS CORS(app) - logger.info("CORS enabled") - # 注册API蓝图 + # 配置数据库连接 + app.config['MYSQL_HOST'] = config.MYSQL_HOST + app.config['MYSQL_USER'] = config.MYSQL_USER + app.config['MYSQL_PASSWORD'] = config.MYSQL_PASSWORD + app.config['MYSQL_DB'] = config.MYSQL_DB + + # 注册路由 app.register_blueprint(api_bp, url_prefix='/api') logger.info("API blueprint registered") - # 配置数据库连接 - app.config['MYSQL_HOST'] = 'localhost' - app.config['MYSQL_USER'] = 'root' - app.config['MYSQL_PASSWORD'] = '123456' - app.config['MYSQL_DB'] = 'equipment_cost_db' - - logger.info("Starting server...") + # 记录配置信息 + logger.info(f"Database: {app.config['MYSQL_DB']} on {app.config['MYSQL_HOST']}") + logger.info(f"Server will run on {config.FLASK_HOST}:{config.FLASK_PORT}") + logger.info(f"Debug mode: {config.FLASK_DEBUG}") return app except Exception as e: - logger.error(f"Error creating app: {str(e)}") - raise - -if __name__ == '__main__': - app = create_app() - app.run(host='localhost', port=5001) \ No newline at end of file + logger.error(f"Error creating application: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + raise \ No newline at end of file diff --git a/src/cost_prediction.py b/src/cost_prediction.py index 2c68e26..c4f00fc 100644 --- a/src/cost_prediction.py +++ b/src/cost_prediction.py @@ -1,15 +1,12 @@ import numpy as np +import torch from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from sklearn.preprocessing import StandardScaler -import tensorflow as tf from scipy import stats -import joblib -import os -import pandas as pd -from .feature_analysis import FeatureAnalysis import logging from src.model_trainer import ModelTrainer from src.database import get_db_connection +from src.feature_analysis import FeatureAnalysis from .logger import setup_logger logger = setup_logger(__name__) @@ -21,37 +18,18 @@ class CostPredictor: self.model = None self.feature_analyzer = FeatureAnalysis() self.equipment_type = None - - # 添加 TensorFlow 配置 - tf.config.run_functions_eagerly(False) # 启用图执行模式 - - # 创建预测函数 - @tf.function(reduce_retracing=True, jit_compile=True) - def predict_fn(x): - return self.model(x, training=False) - - self._predict_fn = predict_fn + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.load_model() def load_model(self): """ - 加载预训练型和标准化器 + 加载预训练模型和标准化器 """ try: - model_dir = 'models' - os.makedirs(model_dir, exist_ok=True) - # 创建默认模型 self._create_default_model() - # 创建预测函数 - @tf.function(reduce_retracing=True, jit_compile=True) - def predict_fn(x): - return self.model(x, training=False) - - self._predict_fn = predict_fn - except Exception as e: logging.error(f"Error loading model: {str(e)}") self._create_default_model() @@ -60,28 +38,24 @@ class CostPredictor: """ 创建默认模型并进行初始化训练 """ - # 创建输入层 - inputs = tf.keras.Input(shape=(11,)) + import torch.nn as nn - # 创建隐藏层 - x = tf.keras.layers.Dense(64, activation='relu')(inputs) - x = tf.keras.layers.Dense(32, activation='relu')(x) - - # 创建输出层 - outputs = tf.keras.layers.Dense(1)(x) - - # 创建模型 - self.model = tf.keras.Model(inputs=inputs, outputs=outputs) - - # 编译模型 - self.model.compile( - optimizer='adam', - loss=tf.keras.losses.mean_squared_error, - metrics=[tf.keras.metrics.mean_absolute_error] - ) + class DefaultModel(nn.Module): + def __init__(self, input_size): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_size, 64), + nn.ReLU(), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 1) + ) + + def forward(self, x): + return self.layers(x) # 创建示例数据 - example_data = pd.DataFrame({ + example_features = { 'length_m': [7.35, 10.2], 'width_m': [2.4, 2.8], 'height_m': [3.1, 3.2], @@ -93,59 +67,23 @@ class CostPredictor: 'rocket_diameter_mm': [122, 220], 'rocket_weight_kg': [66.6, 150], 'rate_of_fire': [40, 60] - }) + } + + # 转换为 tensor + X = torch.tensor(list(example_features.values()), dtype=torch.float32).t() + y = torch.tensor([[800000], [4500000]], dtype=torch.float32) # 训练标准化器 - self.scaler_X.fit(example_data) - self.scaler_y.fit(np.array([[800000], [4500000]])) # 使用正数成本范围 + self.scaler_X.fit(X.numpy()) + self.scaler_y.fit(y.numpy()) - # 设置默认装备类型 - self.equipment_type = '火箭炮' - - def _create_example_data(self): - """ - 创建示例数据来训练标准化器 - """ - # 火箭炮示例数据 - rocket_data = pd.DataFrame({ - 'length_m': [7.35, 10.2], - 'width_m': [2.4, 2.8], - 'height_m': [3.1, 3.2], - 'weight_kg': [13700, 28500], - 'max_range_km': [20.4, 70], - 'firing_angle_horizontal': [102, 110], - 'firing_angle_vertical': [55, 60], - 'rocket_length_m': [2.87, 4.1], - 'rocket_diameter_mm': [122, 220], - 'rocket_weight_kg': [66.6, 150], - 'rate_of_fire': [40, 60] - }) - - # 巡飞弹示例数据 - missile_data = pd.DataFrame({ - 'length_m': [1.3, 2.5], - 'width_m': [0.23, 0.6], - 'height_m': [0.23, 0.6], - 'weight_kg': [12.5, 135], - 'max_range_km': [40, 250], - 'max_speed_kmh': [180, 185], - 'cruise_speed_kmh': [100, 110], - 'flight_time_min': [60, 120], - 'folded_length_mm': [1300, 2500], - 'folded_width_mm': [230, 600], - 'folded_height_mm': [230, 600] - }) - - # 训练标准化器 - self.scaler_X.fit(rocket_data) # 使用火箭炮数据 - self.scaler_y.fit(np.array([[800000], [4500000]])) # 示例成本数据 - - # 设置默认装备类型 + # 创建模型 + self.model = DefaultModel(X.shape[1]).to(self.device) self.equipment_type = '火箭炮' def predict(self, data): """ - 使用训练好的最优模型进行预测 + 使用训练好的模型进行预测 """ try: logger.info(f"Starting prediction for {data.get('type')}") @@ -158,20 +96,31 @@ class CostPredictor: # 准备特征数据 features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) + X = [] + for feature in features: + value = data.get(feature, 0.0) + X.append(float(value)) + + # 转换为 tensor + X = torch.tensor([X], dtype=torch.float32).to(self.device) # 预测 - y_pred = trainer.predict(X) + with torch.no_grad(): + trainer.model.eval() # 设置为评估模式 + y_pred = trainer.model(X) + + # 转回 numpy + y_pred = y_pred.cpu().numpy() # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) + confidence_interval = self._calculate_confidence_interval(y_pred[0]) # 获取模型类型 model_type = trainer.get_model_type() return { 'predicted_cost': float(y_pred[0]), - 'model_type': model_type, # 返回使用的模型类型 + 'model_type': model_type, 'confidence_interval': { 'lower': float(confidence_interval[0]), 'upper': float(confidence_interval[1]) @@ -187,11 +136,10 @@ class CostPredictor: 计算预测值的置信区间 """ try: - # 使用预测值的20%作为标准差(增加不确定性) + # 使用预测值的20%作为标准差 std = abs(prediction) * 0.2 # 计算置信区间 - from scipy import stats interval = stats.norm.interval(confidence, loc=prediction, scale=std) # 确保区间值为正数且合理 @@ -213,130 +161,15 @@ class CostPredictor: """ 模型评估 """ + # 确保输入是 numpy 数组 + if torch.is_tensor(y_true): + y_true = y_true.cpu().numpy() + if torch.is_tensor(y_pred): + y_pred = y_pred.cpu().numpy() + return { 'mae': float(mean_absolute_error(y_true, y_pred)), 'mse': float(mean_squared_error(y_true, y_pred)), 'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred))), 'r2': float(r2_score(y_true, y_pred)) - } - - def predict_pls(self, data): - """ - 使用 PLS 型预测成本 - """ - try: - logger.info(f"Starting PLS prediction for {data.get('type')}") - equipment_type = data.get('type') - - # 加载 PLS 模型 - trainer = ModelTrainer() - if not trainer.load_model(equipment_type, model_type='pls'): # 指定加载 PLS 模型 - raise ValueError(f"No trained PLS model found for {equipment_type}") - - # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - y_pred = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) - - return { - 'predicted_cost': float(y_pred[0]), - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - except Exception as e: - logger.error(f"PLS prediction error: {str(e)}") - raise - - def predict_all(self, data): - """ - 使用所有可用模型进行预测 - """ - try: - logger.info(f"Starting multi-model prediction for {data.get('type')}") - equipment_type = data.get('type') - results = {} - - # 1. 获取所有激活的模型 - with get_db_connection() as conn: - cursor = conn.cursor(dictionary=True) - cursor.execute(""" - SELECT id, model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND is_active = TRUE - """, (equipment_type,)) - active_models = cursor.fetchall() - - if not active_models: - raise ValueError(f"No active models found for {equipment_type}") - - # 2. 使用每个模型进行预测 - trainer = ModelTrainer() - for model_info in active_models: - try: - # 加载特定模型 - if not trainer.load_model(equipment_type, model_type=model_info['model_type']): - logger.warning(f"Failed to load model: {model_info['model_name']}") - continue - - # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - y_pred = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(y_pred[0]) - - # 保存结果 - results[model_info['model_type']] = { - 'predicted_cost': float(y_pred[0]), - 'model_info': { - 'name': model_info['model_name'], - 'type': model_info['model_type'], - 'r2_score': float(model_info['r2_score']), - 'mae': float(model_info['mae']), - 'rmse': float(model_info['rmse']) - }, - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } - - except Exception as e: - logger.error(f"Error predicting with model {model_info['model_name']}: {str(e)}") - continue - - if not results: - raise ValueError("No successful predictions from any model") - - # 3. 计算综合预测结果 - all_predictions = [result['predicted_cost'] for result in results.values()] - ensemble_prediction = float(np.mean(all_predictions)) - prediction_std = float(np.std(all_predictions)) - - # 4. 返回所有结果 - return { - 'individual_predictions': results, - 'ensemble_prediction': { - 'predicted_cost': ensemble_prediction, - 'standard_deviation': prediction_std, - 'confidence_interval': { - 'lower': float(ensemble_prediction - 1.96 * prediction_std), - 'upper': float(ensemble_prediction + 1.96 * prediction_std) - } - } - } - - except Exception as e: - logger.error(f"Error in multi-model prediction: {str(e)}") - raise \ No newline at end of file + } \ No newline at end of file diff --git a/src/data_preparation.py b/src/data_preparation.py index 6380647..8b83330 100644 --- a/src/data_preparation.py +++ b/src/data_preparation.py @@ -1,29 +1,35 @@ from sklearn.preprocessing import StandardScaler -from datetime import datetime -import os -import joblib -import pandas as pd import numpy as np -from src.feature_analysis import FeatureAnalysis -from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor -from xgboost import XGBRegressor -from lightgbm import LGBMRegressor -from sklearn.model_selection import cross_val_score, LeaveOneOut -import json +import torch +from torch.utils.data import Dataset, DataLoader import logging -from src.database.db_connection import get_db_connection -from sklearn.metrics import mean_absolute_error, mean_squared_error +from src.feature_analysis import FeatureAnalysis +from src.database import get_db_connection from .logger import setup_logger logger = setup_logger(__name__) +class EquipmentDataset(Dataset): + """装备数据集类""" + def __init__(self, features, targets=None): + self.features = torch.FloatTensor(features) + self.targets = torch.FloatTensor(targets) if targets is not None else None + + def __len__(self): + return len(self.features) + + def __getitem__(self, idx): + if self.targets is not None: + return self.features[idx], self.targets[idx] + return self.features[idx] + class DataPreparation: def __init__(self): self.feature_analyzer = FeatureAnalysis() self.feature_scaler = StandardScaler() - self.target_scaler = StandardScaler() # 添加目标值标准化器 + self.target_scaler = StandardScaler() - def prepare_training_data(self, equipment_data, equipment_type): + def prepare_training_data(self, equipment_data, equipment_type, batch_size=32): """ 准备训练数据 """ @@ -31,19 +37,24 @@ class DataPreparation: logger.info(f"Preparing training data for {equipment_type}") logger.info(f"Raw data size: {len(equipment_data)}") - # 如果输入已经是 numpy 数组,直接返回 + # 如果输入已经是 numpy 数组,转换为 torch.Tensor if isinstance(equipment_data, np.ndarray): X = equipment_data - logger.info(f"Input is already numpy array with shape: {X.shape}") + logger.info(f"Input is numpy array with shape: {X.shape}") # 处理无效值 X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) + # 转换为 PyTorch 数据集 + dataset = EquipmentDataset(X) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + return { - 'X': X, + 'dataloader': dataloader, 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type), 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler + 'target_scaler': self.target_scaler, + 'raw_shape': X.shape } # 从原始数据中提取特征和目标值 @@ -51,27 +62,28 @@ class DataPreparation: features = [] targets = [] - for item in equipment_data: - # 提取特征值 - feature_values = [] - for name in feature_names: - value = item.get(name) - try: - feature_values.append(float(value) if value is not None else 0.0) - except (ValueError, TypeError): - feature_values.append(0.0) - features.append(feature_values) + # 获取数据库连接 + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) - # 提取目标值(成本) - try: - cost = float(item['actual_cost']) - if cost > 0: # 只使用正数成本值 - targets.append(cost) - else: - logger.warning(f"Skipping non-positive cost value: {cost}") - except (ValueError, TypeError, KeyError): - logger.error(f"Invalid cost value: {item.get('actual_cost')}") - continue + for item in equipment_data: + # 获取该装备的生产商数据 + manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor) + + # 计算生产商特征 + manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data) + + # 合并装备特征和生产商特征 + feature_values = [] + for name in feature_names: + if name in manufacturer_features: + value = manufacturer_features[name] + else: + value = item.get(name) + feature_values.append(float(value) if value is not None else 0.0) + + features.append(feature_values) + targets.append(float(item['actual_cost'])) # 转换为numpy数组 X = np.array(features, dtype=float) @@ -85,25 +97,16 @@ class DataPreparation: X_scaled = self.feature_scaler.fit_transform(X) y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel() - # 记录标准化后的数据范围 - logger.info(f"Scaled X range: min={X_scaled.min()}, max={X_scaled.max()}") - logger.info(f"Scaled y range: min={y_scaled.min()}, max={y_scaled.max()}") - - # 记录标准化器参数 - logger.info("Feature scaler params:") - logger.info(f"Mean: {self.feature_scaler.mean_}") - logger.info(f"Scale: {self.feature_scaler.scale_}") - - logger.info("Target scaler params:") - logger.info(f"Mean: {self.target_scaler.mean_}") - logger.info(f"Scale: {self.target_scaler.scale_}") + # 创建 PyTorch 数据集和数据加载器 + dataset = EquipmentDataset(X_scaled, y_scaled) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) return { - 'X': X_scaled, - 'y': y_scaled, + 'dataloader': dataloader, 'feature_names': feature_names, 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler + 'target_scaler': self.target_scaler, + 'raw_shape': X.shape } except Exception as e: diff --git a/src/feature_analysis.py b/src/feature_analysis.py index 1b7f8e9..4a45ea7 100644 --- a/src/feature_analysis.py +++ b/src/feature_analysis.py @@ -15,9 +15,9 @@ class FeatureAnalysis: 'width_m': '宽度(m)', 'height_m': '高度(m)', 'weight_kg': '重量(kg)', - 'max_range_km': '最大射程(km)', # 火箭炮特有参数 + 'max_range_km': '最大射程(km)', 'firing_angle_horizontal': '方向射界(度)', 'firing_angle_vertical': '高低射界(度)', 'rocket_length_m': '火箭弹长度(m)', @@ -39,6 +39,7 @@ class FeatureAnalysis: 'terrain_adaptability_score': '地形适应性评分', # 巡飞弹特有参数 + 'max_range_km': '最大射程(km)', 'wingspan_m': '翼展(m)', 'warhead_weight_kg': '战斗部重量(kg)', 'max_speed_ms': '最大速度(m/s)', @@ -57,7 +58,14 @@ class FeatureAnalysis: 'weight_range_ratio': '重量射程比', 'speed_weight_ratio': '速度重量比', 'guidance_system_score': '制导系统评分', - 'warhead_power_score': '战斗部威力评分' + 'warhead_power_score': '战斗部威力评分', + + # 添加生产商特征映射 + 'manufacturer_tech_level': '生产商技术水平', + 'manufacturer_scale_level': '生产商规模水平', + 'manufacturer_supply_chain_level': '生产商供应链水平', + 'manufacturer_composite_score': '生产商综合得分', + 'manufacturer_region_factor': '生产商区域系数' } def get_equipment_specific_features(self, equipment_type): @@ -121,6 +129,17 @@ class FeatureAnalysis: 'guidance_system_score', 'warhead_power_score' ]) + + # 添加生产商特征 + manufacturer_features = [ + 'manufacturer_tech_level', + 'manufacturer_scale_level', + 'manufacturer_supply_chain_level', + 'manufacturer_composite_score', + 'manufacturer_region_factor' + ] + + numeric_features.extend(manufacturer_features) return numeric_features def analyze_features(self, features, target, feature_names): @@ -234,4 +253,63 @@ class FeatureAnalysis: except Exception as e: logger.error(f"Error in analyze_features: {str(e)}") logger.error("Detailed traceback:", exc_info=True) - raise \ No newline at end of file + raise + + def calculate_manufacturer_features(self, manufacturer_data): + """计算生产商相关的特征""" + try: + # 确保所有必要的字段都存在,使用默认值处理缺失数据 + tech_level = float(manufacturer_data.get('tech_level', 0)) + scale_level = float(manufacturer_data.get('scale_level', 0)) + supply_chain_level = float(manufacturer_data.get('supply_chain_level', 0)) + country = manufacturer_data.get('country', '未知') + + # 计算综合得分 + composite_score = ( + tech_level * 0.4 + # 技术水平权重最高 + scale_level * 0.3 + # 规模水平次之 + supply_chain_level * 0.3 # 供应链水平 + ) + + # 计算区域系数(基于不同地区的成本差异) + region_factors = { + '美国': 1.2, + '英国': 1.15, + '德国': 1.15, + '法国': 1.15, + '以色列': 1.1, + '中国': 0.8, + '俄罗斯': 0.85, + '韩国': 0.9, + '日本': 1.1 + } + + region_factor = region_factors.get(country, 1.0) + + # 记录计算过程 + logger.info(f"Manufacturer features calculation:") + logger.info(f"Tech level: {tech_level}") + logger.info(f"Scale level: {scale_level}") + logger.info(f"Supply chain level: {supply_chain_level}") + logger.info(f"Country: {country}") + logger.info(f"Composite score: {composite_score}") + logger.info(f"Region factor: {region_factor}") + + return { + 'manufacturer_tech_level': tech_level, + 'manufacturer_scale_level': scale_level, + 'manufacturer_supply_chain_level': supply_chain_level, + 'manufacturer_composite_score': composite_score, + 'manufacturer_region_factor': region_factor + } + + except Exception as e: + logger.error(f"Error calculating manufacturer features: {str(e)}") + # 返回默认值而不是抛出异常,确保分析过程可以继续 + return { + 'manufacturer_tech_level': 0, + 'manufacturer_scale_level': 0, + 'manufacturer_supply_chain_level': 0, + 'manufacturer_composite_score': 0, + 'manufacturer_region_factor': 1.0 + } \ No newline at end of file diff --git a/src/import_data.py b/src/import_data.py index 03dfa36..57286c1 100644 --- a/src/import_data.py +++ b/src/import_data.py @@ -26,7 +26,7 @@ def import_training_data(excel_file): equipment_names.add(row['名称']) # 检查是否已存在相同名称的装备 cursor.execute(""" - SELECT id FROM equipment + SELECT id FROM equipments WHERE name = %s AND type = '火箭炮' """, (row['名称'],)) @@ -37,7 +37,7 @@ def import_training_data(excel_file): # 插入基本信息 cursor.execute(""" - INSERT INTO equipment (name, type, manufacturer) + INSERT INTO equipments (name, type, manufacturer) VALUES (%s, %s, %s) """, (row['名称'], '火箭炮', row['制造商'])) @@ -116,7 +116,7 @@ def import_training_data(excel_file): # 插入基本信息 cursor.execute(""" - INSERT INTO equipment (name, type, manufacturer) + INSERT INTO equipments (name, type, manufacturer) VALUES (%s, %s, %s) """, ( row['名称'], @@ -192,7 +192,7 @@ def import_training_data(excel_file): logger.debug(f"查询装备ID: {equipment_name}") with conn.cursor() as id_cursor: id_cursor.execute(""" - SELECT id FROM equipment WHERE name = %s + SELECT id FROM equipments WHERE name = %s """, (equipment_name,)) result = id_cursor.fetchone() diff --git a/src/init_data.sql b/src/init_data.sql deleted file mode 100644 index ee99db3..0000000 --- a/src/init_data.sql +++ /dev/null @@ -1,319 +0,0 @@ -/* -这是用于开发和测试环境的示例数据。 -生产环境请使用系统的数据导入功能添加实际数据。 - -主要用途: -1. 提供开发测试数据 -2. 作为数据格式参考 -3. 用于系统功能验证 -*/ - --- 插入装备基本信息 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('终结者', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('胜利-2', '火箭炮', '伊朗', '地面固定目标'); - --- 插入巡飞弹技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES ( - 1, -- 终结者巡飞弹 - 0.56, - 0.15, - 0.20, - 2.72, - 160.93, - 96.56, - 24, - 15, - '破片杀伤战斗部', - '凭自身动力起飞', - 560, - 150, - 200 -); - --- 插入火箭炮技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES ( - 2, -- 胜利-2火箭炮 - 10, - 2.5, - 3.34, - 15000, - 23 -); - --- 插入成本数据(示例数据) -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(1, 1000000), -- 终结者巡飞弹成本 -(2, 5000000); -- 胜利-2火箭炮成本 - --- 插入更多巡飞弹变体数据用于训练 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('终结者-A', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('终结者-B', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'), -('终结者-C', '巡飞弹', '美国', '静止和移动的人员和轻型装甲车辆'); - --- 插入变体技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES --- 终结者-A(稍大型号) -(3, 0.58, 0.16, 0.21, 2.85, 170, 100, 26, 16, '破片杀伤战斗部', '凭自身动力起飞', 580, 160, 210), --- 终结者-B(稍小型号) -(4, 0.54, 0.14, 0.19, 2.60, 155, 93, 22, 14, '破片杀伤战斗部', '凭自身动力起飞', 540, 140, 190), --- 终结者-C(标准型号的改进版) -(5, 0.56, 0.15, 0.20, 2.70, 165, 98, 25, 15, '破片杀伤战斗部', '凭自身动力起飞', 560, 150, 200); - --- 插入变体成本数据 -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(3, 1100000), -- 终结者-A成本(较高) -(4, 900000), -- 终结者-B成本(较低) -(5, 1050000); -- 终结者-C成本(中等) - --- 添加更多巡飞弹数据 -INSERT INTO equipment (name, type, manufacturer, target_type) VALUES -('哈比', '巡飞弹', '以色列', '防空系统和雷达站'), -('游荡者', '巡飞弹', '以色列', '装甲车辆和防空系统'), -('凤凰', '巡飞弹', '土耳其', '固定目标和装甲车辆'), -('弹簧刀', '巡飞弹', '波兰', '装甲目标'), -('彩虹-4', '巡飞弹', '中国', '地面固定目标'); - --- 添加它们的技术参数 -INSERT INTO technical_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_speed_kmh, - cruise_speed_kmh, - max_range_km, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm -) VALUES --- 哈比 -(6, 2.5, 0.6, 0.6, 135, 185, 110, 250, 120, '高爆战斗部', '箱式发射', 2500, 600, 600), --- 游荡者 -(7, 2.3, 0.4, 0.4, 30, 190, 120, 30, 30, '破片杀伤战斗部', '箱式发射', 2300, 400, 400), --- 凤凰 -(8, 2.0, 0.3, 0.3, 25, 170, 100, 20, 25, '破片杀伤战斗部', '箱式发射', 2000, 300, 300), --- 弹簧刀 -(9, 1.8, 0.35, 0.35, 28, 180, 110, 25, 30, '破片杀伤战斗部', '箱式发射', 1800, 350, 350), --- 彩虹-4 -(10, 3.5, 0.8, 0.8, 345, 210, 130, 300, 180, '高爆战斗部', '箱式发射', 3500, 800, 800); - --- 添加成本数据 -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(6, 800000), -- 哈比 -(7, 500000), -- 游荡者 -(8, 450000), -- 凤凰 -(9, 480000), -- 弹簧刀 -(10, 1500000); -- 彩虹-4 - --- 火箭炮数据 -INSERT INTO equipment (name, type, manufacturer) VALUES -('BM-21', '火箭炮', '俄罗斯'), -('SR5', '火箭炮', '中国'), -('HIMARS', '火箭炮', '美国'), -('LAR-160', '火箭炮', '以色列'), -('T-122', '火箭炮', '土耳其'), -('RM-70', '火箭炮', '捷克'), -('ASTROS II', '火箭炮', '巴西'); - --- 火箭炮通用参数 -INSERT INTO common_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES --- BM-21 -(1, 7.35, 2.4, 3.1, 13700, 20.4), --- SR5 -(2, 10.2, 2.8, 3.2, 28500, 70), --- HIMARS -(3, 7.0, 2.4, 3.2, 16250, 70), --- LAR-160 -(4, 6.7, 2.5, 2.8, 15000, 45), --- T-122 -(5, 7.2, 2.5, 2.9, 18000, 40), --- RM-70 -(6, 7.5, 2.5, 3.0, 17200, 20.3), --- ASTROS II -(7, 8.0, 2.7, 3.1, 24500, 90); - --- 火箭炮特有参数 -INSERT INTO rocket_artillery_params ( - equipment_id, - firing_angle_horizontal, - firing_angle_vertical, - rocket_length_m, - rocket_diameter_mm, - rocket_weight_kg, - rate_of_fire, - combat_weight_kg, - speed_kmh, - min_range_km, - mobility_type, - structure_layout, - engine_model, - engine_params, - power_hp, - travel_range_km -) VALUES --- BM-21 -(1, 102, 55, 2.87, 122, 66.6, 40, 13700, 75, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '240马力', 240, 500), --- SR5 -(2, 110, 60, 4.1, 220, 150, 60, 28500, 90, 2.0, '轮式', '前置驾驶舱', 'V6柴油', '320马力', 320, 650), --- HIMARS -(3, 90, 65, 3.94, 227, 301, 6, 16250, 85, 2.0, '轮式', '前置驾驶舱', 'V8柴油', '290马力', 290, 480), --- LAR-160 -(4, 100, 58, 3.3, 160, 110, 18, 15000, 80, 1.8, '轮式', '前置驾驶舱', 'V6柴油', '260马力', 260, 550), --- T-122 -(5, 110, 65, 2.95, 122, 65.5, 40, 18000, 85, 1.5, '轮式', '前置驾驶舱', 'V8柴油', '280马力', 280, 600), --- RM-70 -(6, 100, 50, 2.87, 122, 66.6, 40, 17200, 70, 1.6, '轮式', '前置驾驶舱', 'V8柴油', '250马力', 250, 520), --- ASTROS II -(7, 90, 65, 4.3, 300, 550, 30, 24500, 80, 2.2, '轮式', '前置驾驶舱', 'V8柴油', '350马力', 350, 700); - --- 巡飞弹数据 -INSERT INTO equipment (name, type, manufacturer) VALUES -('Hero-120', '巡飞弹', '以色列'), -('Switchblade 600', '巡飞弹', '美国'), -('Warmate', '巡飞弹', '波兰'), -('CH-901', '巡飞弹', '中国'), -('HAROP', '巡飞弹', '以色列'), -('Coyote', '巡飞弹', '美国'), -('WS-43', '巡飞弹', '中国'); - --- 巡飞弹通用参数 -INSERT INTO common_params ( - equipment_id, - length_m, - width_m, - height_m, - weight_kg, - max_range_km -) VALUES --- Hero-120 -(8, 1.3, 0.23, 0.23, 12.5, 40), --- Switchblade 600 -(9, 1.3, 0.22, 0.22, 15.0, 40), --- Warmate -(10, 1.1, 0.15, 0.15, 5.7, 15), --- CH-901 -(11, 1.2, 0.18, 0.18, 9.0, 20), --- HAROP -(12, 2.5, 0.43, 0.43, 135, 1000), --- Coyote -(13, 0.9, 0.12, 0.12, 5.9, 20), --- WS-43 -(14, 1.8, 0.35, 0.35, 20, 60); - --- 巡飞弹特有参数 -INSERT INTO loitering_munition_params ( - equipment_id, - wingspan_m, - warhead_weight_kg, - max_speed_ms, - cruise_speed_kmh, - flight_time_min, - warhead_type, - launch_mode, - folded_length_mm, - folded_width_mm, - folded_height_mm, - power_system, - guidance_system -) VALUES --- Hero-120 -(8, 2.1, 3.5, 50, 100, 60, '破片杀伤战斗部', '箱式发射', 1300, 230, 230, '电动机', 'GPS/INS'), --- Switchblade 600 -(9, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', 1300, 220, 220, '电动机', 'GPS/INS/光电'), --- Warmate -(10, 1.4, 1.4, 41.7, 90, 30, '破片杀伤战斗部', '箱式发射', 1100, 150, 150, '电动机', 'GPS/INS'), --- CH-901 -(11, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射', 1200, 180, 180, '电动机', 'GPS/INS'), --- HAROP -(12, 3.0, 23, 51.4, 110, 360, '高爆战斗部', '箱式发射', 2500, 430, 430, '活塞发动机', 'GPS/INS/光电/数据链'), --- Coyote -(13, 1.2, 1.8, 41.7, 95, 30, '破片杀伤战斗部', '箱式发射', 900, 120, 120, '电动机', 'GPS/INS'), --- WS-43 -(14, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射', 1800, 350, 350, '电动机', 'GPS/INS/光电'); - --- 插入成本数据(示例成本) -INSERT INTO cost_data (equipment_id, actual_cost) VALUES --- 火箭炮 -(1, 800000), -- BM-21 -(2, 4500000), -- SR5 -(3, 5500000), -- HIMARS -(4, 3500000), -- LAR-160 -(5, 2800000), -- T-122 -(6, 1500000), -- RM-70 -(7, 4800000), -- ASTROS II --- 巡飞弹 -(8, 150000), -- Hero-120 -(9, 180000), -- Switchblade 600 -(10, 80000), -- Warmate -(11, 100000), -- CH-901 -(12, 850000), -- HAROP -(13, 75000), -- Coyote -(14, 120000); -- WS-43 - --- 创建初始数据集 -INSERT INTO datasets (name, description, equipment_type, purpose) VALUES -('火箭炮训练集', '用于训练火箭炮成本预测模型的数据集', '火箭炮', '训练'), -('巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'), -('火箭炮验证集', '用于验证火箭炮成本预测模型的数据集', '火箭炮', '验证'), -('巡飞弹验证集', '用于验证巡飞弹成本预测模型的数据集', '巡飞弹', '验证'); - --- 关联装备到数据集 -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES --- 火箭炮训练集 -(1, 1), (1, 2), (1, 3), (1, 4), --- 巡飞弹训练集 -(2, 8), (2, 9), (2, 10), (2, 11), (2, 12), --- 火箭炮验证集 -(3, 5), (3, 6), (3, 7), --- 巡飞弹验证集 -(4, 13), (4, 14); \ No newline at end of file diff --git a/src/loitering_munition_data.sql b/src/loitering_munition_data.sql index 350dc0a..e14ca4e 100644 --- a/src/loitering_munition_data.sql +++ b/src/loitering_munition_data.sql @@ -26,15 +26,15 @@ */ -- 插入装备基本信息 -INSERT INTO equipment ( +INSERT INTO equipments ( id, -- 装备ID name, -- 装备名称 type, -- 装备类型 manufacturer -- 制造商 ) VALUES -(1, 'IAI Harop', '巡飞弹', '以色列'), -(2, 'IAI Harpy', '巡飞弹', '以色列'), -(3, 'IAI Mini Harpy', '巡飞弹', '以色列'), +(1, 'IAI Harop', '巡飞弹', '以色列 IAI'), +(2, 'IAI Harpy', '巡飞弹', '以色列 IAI'), +(3, 'IAI Mini Harpy', '巡飞弹', '以色列 IAI'), (4, 'Hero-30', '巡飞弹', '以色列 UVision'), (5, 'Hero-70', '巡飞弹', '以色列 UVision'), (6, 'Hero-120', '巡飞弹', '以色列 UVision'), @@ -65,11 +65,11 @@ INSERT INTO equipment ( (31, 'Alpagu', '巡飞弹', '土耳其 STM'), (32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'), (33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'), -(34, 'Shahed-131', '巡飞弹', '伊朗'), -(35, 'Shahed-131B', '巡飞弹', '伊朗'), -(36, 'Shahed-136', '巡飞弹', '伊朗'), -(37, 'Shahed-136B', '巡飞弹', '伊朗'), -(38, 'Shahed-136C', '巡飞弹', '伊朗'), +(34, 'Shahed-131', '巡飞弹', '伊朗国防工业'), +(35, 'Shahed-131B', '巡飞弹', '伊朗国防工业'), +(36, 'Shahed-136', '巡飞弹', '伊朗国防工业'), +(37, 'Shahed-136B', '巡飞弹', '伊朗国防工业'), +(38, 'Shahed-136C', '巡飞弹', '伊朗国防工业'), (39, 'Green Dragon', '巡飞弹', '以色列 IAI'), (40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'), (41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'), @@ -285,7 +285,7 @@ INSERT INTO loitering_munition_params ( (24, 2.8, 8.0, 70, 180, 240, 50, 10.0, 4000, 25, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), (25, 3.0, 9.0, 75, 190, 270, 60, 11.0, 4500, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), (26, 3.2, 10.0, 80, 200, 300, 70, 12.0, 5000, 35, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), +(27, 3.5, 15.0, 85, 220, 360, 100, 18.0, 6000, 50, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红���'), (28, 3.6, 16.0, 90, 230, 400, 120, 20.0, 6500, 60, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外/卫通'), (29, 1.2, 1.0, 40, 90, 30, 5, 1.5, 1500, 3, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'), (30, 1.3, 1.2, 45, 100, 40, 8, 2.0, 2000, 4, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI辅助'), @@ -338,7 +338,7 @@ INSERT INTO loitering_munition_params ( (77, 2.8, 40.0, 250, 200, 120, 180, 50.0, 5500, 90, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Light (78, 3.2, 80.0, 230, 180, 150, 200, 100.0, 6000, 100, '破甲战斗部', '空中发射', '涡轮喷气', 'GPS/INS/光电/数据链/AI辅助'), -- SmartGlider Heavy (79, 1.5, 3.5, 160, 140, 60, 50, 5.0, 3500, 25, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助'), -- Taifun -(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K +(80, 1.8, 4.5, 180, 150, 80, 70, 6.0, 4000, 35, '破片杀伤战斗部', '箱式发射', '电������', 'GPS/INS/光电/AI辅助/红外'), -- Taifun-K (81, 1.5, 3.0, 120, 100, 60, 40, 4.0, 3000, 20, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ES (82, 2.0, 5.0, 140, 120, 90, 60, 6.0, 4000, 30, '破片杀伤/破甲双用战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -- HERO-ER (83, 2.5, 8.0, 160, 140, 120, 80, 10.0, 5000, 40, '破片杀伤/破甲双用战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -- HERO-XL @@ -469,7 +469,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES (2, '巡飞弹验证集 2024', '包含20个巡飞弹型号,用于验证模型性能', '巡飞弹', '验证'); -- 训练集(80个型号) -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES -- 以色列系列(8/10) (1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列 (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), -- Hero系列 @@ -520,7 +520,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES (1, 96), (1, 97), (1, 98), (1, 99); -- Shadow/Argus系列 -- 验证集(20个型号) -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES -- 以色列系列(2/10) (2, 9), -- Hero-900 (2, 48), -- Rotem L @@ -666,4 +666,43 @@ SET WHEN l.max_range_km > 500 THEN 5000 WHEN l.max_range_km > 100 THEN 3000 ELSE 1500 - END; \ No newline at end of file + END; + + + +-- 更新巡飞弹的制导精度 +UPDATE loitering_munition_params l +SET guidance_accuracy_m = + CASE + -- 基础精度(根据制导系统类型) + WHEN guidance_system LIKE '%GPS/INS%' AND guidance_system LIKE '%AI辅助%' THEN 2.0 + WHEN guidance_system LIKE '%GPS/INS%' THEN 3.0 + WHEN guidance_system LIKE '%激光制导%' THEN 1.0 + WHEN guidance_system LIKE '%红外制导%' THEN 2.0 + WHEN guidance_system LIKE '%卫星制导%' THEN 2.5 + ELSE 5.0 + END * + -- 速度影响因子(速度越快,精度略微降低) + CASE + WHEN max_speed_ms > 200 THEN 1.2 + WHEN max_speed_ms > 150 THEN 1.1 + WHEN max_speed_ms > 100 THEN 1.0 + ELSE 0.9 + END * + -- 重量影响因子(重量越大,精度略微降低) + CASE + WHEN warhead_weight_kg > 100 THEN 1.2 + WHEN warhead_weight_kg > 50 THEN 1.1 + WHEN warhead_weight_kg > 20 THEN 1.0 + ELSE 0.9 + END * + -- 飞行高度影响因子(高度越高,精度略微降低) + CASE + WHEN ceiling_altitude_m > 5000 THEN 1.2 + WHEN ceiling_altitude_m > 3000 THEN 1.1 + WHEN ceiling_altitude_m > 1000 THEN 1.0 + ELSE 0.9 + END +WHERE equipment_id IN ( + SELECT id FROM equipments WHERE type = '巡飞弹' +); \ No newline at end of file diff --git a/src/manufacturer_data.sql b/src/manufacturer_data.sql index 01cf09b..f80321a 100644 --- a/src/manufacturer_data.sql +++ b/src/manufacturer_data.sql @@ -40,7 +40,7 @@ INSERT INTO manufacturers ( ('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业 -- 俄罗斯供应商 -('俄罗斯', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限 +('俄罗斯 Rostec', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限 ('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商 ('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商 @@ -72,7 +72,9 @@ INSERT INTO manufacturers ( ('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业 -- 更新装备表中的供应商ID -UPDATE equipment e -SET manufacturer_id = m.id -FROM manufacturers m -WHERE e.manufacturer = m.name; \ No newline at end of file +UPDATE equipments e +SET manufacturer_id = ( + SELECT id + FROM manufacturers m + WHERE m.name = e.manufacturer +); \ No newline at end of file diff --git a/src/model_trainer.py b/src/model_trainer.py index 8792f90..9da7c39 100644 --- a/src/model_trainer.py +++ b/src/model_trainer.py @@ -1,282 +1,112 @@ import numpy as np -import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data import DataLoader from sklearn.preprocessing import StandardScaler -from sklearn.model_selection import cross_val_score -from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor -from sklearn.impute import SimpleImputer -import xgboost as xgb -import lightgbm as lgb +from sklearn.model_selection import train_test_split import logging -import joblib import os -from src.feature_analysis import FeatureAnalysis from datetime import datetime import json +from src.feature_analysis import FeatureAnalysis from src.database import get_db_connection -from src.data_preparation import DataPreparation -from sklearn.cross_decomposition import PLSRegression +from src.data_preparation import DataPreparation, EquipmentDataset from .logger import setup_logger logger = setup_logger(__name__) +class CostPredictionModel(nn.Module): + def __init__(self, input_size): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_size, 128), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(128, 64), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 1) + ) + + def forward(self, x): + return self.layers(x) + class ModelTrainer: def __init__(self): - """ - 初始化 ModelTrainer - """ - self.models = { - 'xgboost': self._create_xgboost_model(), - 'lightgbm': self._create_lightgbm_model(), - 'gbm': self._create_gbm_model(), - 'rf': self._create_rf_model(), - 'pls': self._create_pls_model() - } - self.best_model = None - self.imputer = SimpleImputer(strategy='mean') + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model = None self.feature_scaler = None self.target_scaler = None self.equipment_type = None self.feature_analyzer = FeatureAnalysis() - def fit_model(self, X_train, y_train, model_names, X_val=None, y_val=None, equipment_type=None): - """ - 训练模型并返回评估结果 - """ + def train_model(self, dataloader, epochs=100, learning_rate=0.001): + """训练模型""" try: - self.equipment_type = equipment_type - logger.info(f"Training data range - X: min={X_train.min()}, max={X_train.max()}") - logger.info(f"Training data range - y: min={y_train.min()}, max={y_train.max()}") + # 获取输入特征维度 + sample_features, _ = next(iter(dataloader)) + input_size = sample_features.shape[1] - results = {} - best_score = -float('inf') - best_model_info = None + # 创建模型 + self.model = CostPredictionModel(input_size).to(self.device) + criterion = nn.MSELoss() + optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) - # 首先训练 PLS 模型 - logger.info("Training pls...") - pls_model = self.models['pls'] - pls_model.fit(X_train, y_train) - pls_metrics = self._calculate_metrics( - pls_model, - X_train, y_train, - X_val, y_val - ) - results['pls'] = pls_metrics - - # 训练其他机器学习模型 - for model_name in model_names: - if model_name == 'pls': # 跳过 PLS 模型,因为已经训练过了 - continue + # 训练循环 + for epoch in range(epochs): + self.model.train() + total_loss = 0 + for batch_features, batch_targets in dataloader: + # 移动数据到设备 + batch_features = batch_features.to(self.device) + batch_targets = batch_targets.to(self.device) - if model_name not in self.models: - logger.warning(f"Unknown model: {model_name}") - continue + # 前向传播 + outputs = self.model(batch_features) + loss = criterion(outputs, batch_targets.view(-1, 1)) - logger.info(f"Training {model_name}...") - model = self.models[model_name] + # 反向传播 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() - # 训练模型 - model.fit(X_train, y_train) - - # 计算评估指标 - metrics = self._calculate_metrics( - model, - X_train, y_train, - X_val, y_val - ) - - results[model_name] = metrics - - # 更新最佳模型(只在机器学习模型中比较) - if metrics['validation']['r2'] > best_score: - best_score = metrics['validation']['r2'] - best_model_info = { - 'type': model_name, - 'r2': metrics['validation']['r2'], - 'mae': metrics['validation']['mae'], - 'rmse': metrics['validation']['rmse'] - } - self.best_model = model + # 记录训练进度 + if (epoch + 1) % 10 == 0: + avg_loss = total_loss / len(dataloader) + logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}') - # 保存最佳模型和 PLS 模型 - if equipment_type and best_model_info: - self._save_best_model(equipment_type, best_model_info, X_train, y_train, X_val, y_val) - - return { - 'metrics': results, - 'best_model': best_model_info - } + return True except Exception as e: logger.error(f"Error in model training: {str(e)}") raise - - def _calculate_metrics(self, model, X_train, y_train, X_val=None, y_val=None): - """ - 计算模型评估指标 - """ - from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error - - # 训练集评估 - train_pred = model.predict(X_train) - - # 如果使用了标准化,需要转换回原始范围 - if hasattr(self, 'target_scaler'): - train_pred = self.target_scaler.inverse_transform(train_pred.reshape(-1, 1)).ravel() - y_train_orig = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel() - else: - y_train_orig = y_train - - # 记录预测范围 - logger.info(f"Train predictions range: min={train_pred.min()}, max={train_pred.max()}") - logger.info(f"Train actual range: min={y_train_orig.min()}, max={y_train_orig.max()}") - - train_metrics = { - 'r2': r2_score(y_train_orig, train_pred), - 'mae': mean_absolute_error(y_train_orig, train_pred), - 'rmse': np.sqrt(mean_squared_error(y_train_orig, train_pred)) - } - - # 验证集评估 - if X_val is not None and y_val is not None: - val_pred = model.predict(X_val) - - # 如果使用了标准化,需要转换回原始范围 - if hasattr(self, 'target_scaler'): - val_pred = self.target_scaler.inverse_transform(val_pred.reshape(-1, 1)).ravel() - y_val_orig = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel() - else: - y_val_orig = y_val - - # 记录预测范围 - logger.info(f"Validation predictions range: min={val_pred.min()}, max={val_pred.max()}") - logger.info(f"Validation actual range: min={y_val_orig.min()}, max={y_val_orig.max()}") - - val_metrics = { - 'r2': r2_score(y_val_orig, val_pred), - 'mae': mean_absolute_error(y_val_orig, val_pred), - 'rmse': np.sqrt(mean_squared_error(y_val_orig, val_pred)) - } - else: - # 使用交叉验证 - cv_scores = cross_val_score(model, X_train, y_train, cv=5) - val_metrics = { - 'r2': cv_scores.mean(), - 'mae': None, - 'rmse': None - } - - return { - 'train': train_metrics, - 'validation': val_metrics - } - def _create_xgboost_model(self): - """ - 创建 XGBoost 模型,增强正则化 - """ - return xgb.XGBRegressor( - n_estimators=50, # 减少树的数量 - learning_rate=0.05, # 学习率 - max_depth=3, # 减小树的深 - min_child_weight=3, # 增加节点权重 - subsample=0.7, # 减小样本采样比例 - colsample_bytree=0.7, # 减小特征采样比例 - reg_alpha=0.1, # L1 正则化 - reg_lambda=1, # L2 正则化 - random_state=42 - ) - - def _create_lightgbm_model(self): - """ - 创建 LightGBM 模型,增强正则化 - """ - return lgb.LGBMRegressor( - n_estimators=50, - learning_rate=0.05, - max_depth=3, - num_leaves=7, - min_data_in_leaf=3, - min_sum_hessian_in_leaf=1e-3, - subsample=0.7, - colsample_bytree=0.7, - reg_alpha=0.1, - reg_lambda=1, - random_state=42, - verbose=-1 - ) - - def _create_gbm_model(self): - """ - 创建 GBM 模型,增强正则化以减轻过拟合 - """ - return GradientBoostingRegressor( - n_estimators=100, - learning_rate=0.1, - max_depth=3, - random_state=42, - subsample=0.8, - min_samples_split=3, - min_samples_leaf=2 - ) - - def _create_rf_model(self): - """ - 创建随机森林模型,针对小样本数据调整参数 - """ - return RandomForestRegressor( - n_estimators=100, - max_depth=3, - random_state=42, - min_samples_split=3, - min_samples_leaf=2 - ) - - def _create_pls_model(self): - """ - 创建 PLS 模型,优化参数配置 - """ - return PLSRegression( - n_components=2, # 减少主成分数量,从5减到2 - scale=True, # 保持数据标准化 - max_iter=500, # 减少最大迭代次数,避免过拟合 - tol=1e-6 # 降低收敛精度,避免过拟合 - ) - - def _save_best_model(self, equipment_type, best_model_info, X_train, y_train, X_val=None, y_val=None): - """ - 保存最佳模型和 PLS 模型 - """ + def save_model(self, equipment_type): + """保存模型""" try: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") model_dir = 'models' os.makedirs(model_dir, exist_ok=True) - # 1. 保存最佳机器学习模型 - model_path = f'{model_dir}/{equipment_type}_{timestamp}' - if isinstance(self.best_model, xgb.XGBRegressor): - self.best_model.save_model(f'{model_path}.json') - model_format = 'json' - else: - joblib.dump(self.best_model, f'{model_path}.joblib') - model_format = 'joblib' - - # 2. 保存 PLS 模型 - pls_model = self.models['pls'] - pls_path = f'{model_dir}/{equipment_type}_{timestamp}_pls.joblib' - joblib.dump(pls_model, pls_path) + # 保存模型 + model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth' + torch.save({ + 'model_state_dict': self.model.state_dict(), + 'input_size': self.model.layers[0].in_features + }, model_path) - # 3. 保存标准化器 - scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.joblib' - joblib.dump({ + # 保存标准化器 + scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth' + torch.save({ 'feature_scaler': self.feature_scaler, 'target_scaler': self.target_scaler }, scaler_path) - logger.info(f"Saved best model to {model_path}.{model_format}") - logger.info(f"Saved PLS model to {pls_path}") - logger.info(f"Saved scalers to {scaler_path}") - - # 4. 更新数据库中的模型记录 + # 更新数据库 with get_db_connection() as conn: cursor = conn.cursor() @@ -287,420 +117,73 @@ class ModelTrainer: WHERE equipment_type = %s """, (equipment_type,)) - # 获取 PLS 模型的评估指标 - pls_metrics = self._calculate_metrics( - self.models['pls'], - X_train, - y_train, - X_val, - y_val - ) - - # 保存最佳机器学习模型记录 - self.best_model.equipment_type = equipment_type # 设置装备类型 - ml_feature_importance = self._get_feature_importance(self.best_model) - + # 保存新模型记录 cursor.execute(""" INSERT INTO trained_models ( - model_name, model_type, equipment_type, model_path, scaler_path, - r2_score, mae, rmse, feature_importance, training_data_size, - training_date, is_active, created_by - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s) + model_name, model_type, equipment_type, model_path, + scaler_path, training_date, is_active, created_by + ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s) """, ( - f"{equipment_type}_{timestamp}", # model_name - best_model_info['type'], # model_type - equipment_type, # equipment_type - f"{model_path}.{model_format}", # model_path - scaler_path, # scaler_path - best_model_info['r2'], # r2_score - best_model_info['mae'], # mae - best_model_info['rmse'], # rmse - json.dumps(ml_feature_importance), # feature_importance - len(X_train), # training_data_size - 'system' # created_by - )) - - # 保存 PLS 模型记录 - pls_model.equipment_type = equipment_type # 设置装备类型 - pls_feature_importance = self._get_feature_importance(pls_model) - - cursor.execute(""" - INSERT INTO trained_models ( - model_name, model_type, equipment_type, model_path, scaler_path, - r2_score, mae, rmse, feature_importance, training_data_size, - training_date, is_active, created_by - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), TRUE, %s) - """, ( - f"{equipment_type}_{timestamp}_pls", # model_name - 'pls', # model_type - equipment_type, # equipment_type - pls_path, # model_path - scaler_path, # scaler_path - float(pls_metrics['validation']['r2']), # r2_score - float(pls_metrics['validation']['mae']), # mae - float(pls_metrics['validation']['rmse']), # rmse - json.dumps(pls_feature_importance), # feature_importance - len(X_train), # training_data_size - 'system' # created_by + f"{equipment_type}_{timestamp}", + 'pytorch', + equipment_type, + model_path, + scaler_path, + 'system' )) conn.commit() - - except Exception as e: - logger.error(f"Error saving models: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - raise - - def load_model(self, equipment_type, model_type='ml'): - """ - 加载已训练的模型 - """ - try: - logger.info(f"Loading {model_type} model for {equipment_type}") - # 从数据库获取激活的模型 + return True + + except Exception as e: + logger.error(f"Error saving model: {str(e)}") + return False + + def load_model(self, equipment_type): + """加载模型""" + try: + # 从数据库获取最新的激活模型 with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - - # 构建查询语句 - if model_type == 'pls': - query = """ - SELECT * FROM trained_models - WHERE equipment_type = %s - AND model_type = 'pls' - AND is_active = TRUE - LIMIT 1 - """ - params = (equipment_type,) - else: - query = """ - SELECT * FROM trained_models - WHERE equipment_type = %s - AND model_type != 'pls' - AND is_active = TRUE - LIMIT 1 - """ - params = (equipment_type,) - - # 记录查询信息 - logger.info(f"Executing query: {query}") - logger.info(f"Query params: {params}") - - cursor.execute(query, params) + cursor.execute(""" + SELECT * FROM trained_models + WHERE equipment_type = %s AND is_active = TRUE + ORDER BY training_date DESC LIMIT 1 + """, (equipment_type,)) model_record = cursor.fetchone() - # 记录查询结果 - if model_record: - logger.info(f"Found model record: {model_record}") - else: - logger.warning(f"No active model found for type {model_type}") + if not model_record: return False - # 检查文件是否存在 - logger.info(f"Checking model file: {model_record['model_path']}") - logger.info(f"Checking scaler file: {model_record['scaler_path']}") - - if not os.path.exists(model_record['model_path']): - logger.error(f"Model file not found: {model_record['model_path']}") - raise FileNotFoundError(f"Model file not found: {model_record['model_path']}") - - if not os.path.exists(model_record['scaler_path']): - logger.error(f"Scaler file not found: {model_record['scaler_path']}") - raise FileNotFoundError(f"Scaler file not found: {model_record['scaler_path']}") - - # 加载模型文件 - logger.info(f"Loading model from {model_record['model_path']}") - if model_type == 'pls': - self.best_model = joblib.load(model_record['model_path']) - logger.info("Loaded PLS model") - else: - if model_record['model_type'] == 'xgboost': - self.best_model = xgb.XGBRegressor() - self.best_model.load_model(model_record['model_path']) - logger.info("Loaded XGBoost model") - else: - self.best_model = joblib.load(model_record['model_path']) - logger.info(f"Loaded {model_record['model_type']} model") + # 加载模型 + checkpoint = torch.load(model_record['model_path']) + input_size = checkpoint['input_size'] + self.model = CostPredictionModel(input_size).to(self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) # 加载标准化器 - logger.info(f"Loading scalers from {model_record['scaler_path']}") - scalers = joblib.load(model_record['scaler_path']) + scalers = torch.load(model_record['scaler_path']) self.feature_scaler = scalers['feature_scaler'] self.target_scaler = scalers['target_scaler'] - logger.info("Loaded scalers successfully") return True except Exception as e: logger.error(f"Error loading model: {str(e)}") - logger.error(f"Detailed traceback:", exc_info=True) return False - - def get_missile_features(self): - """获取巡飞弹的特征列表""" - return [ - # 基本参数 - 'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km', - - # 性能参数 - 新增和修改的参数 - 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh', - 'endurance_min', 'max_payload_kg', 'min_combat_radius_km', - - # 动力系统参数 - 新增参数 - 'engine_power_kw', 'engine_thrust_n', - - # 制导与控制参数 - 新增参数 - 'datalink_range_km', 'guidance_accuracy_m', - 'min_altitude_m', 'max_altitude_m', - - # 特征工程参数 - 新增评分指标 - 'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio', - 'guidance_system_score', 'warhead_power_score' - ] - - def get_rocket_features(self): - """获取火箭炮的特征列表""" - return [ - # 基本参数 - 'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km', - - # 火箭炮特有参数 - 新增和修改的参数 - 'firing_angle_horizontal', 'firing_angle_vertical', - 'rocket_length_m', 'rocket_diameter_mm', 'rocket_weight_kg', - 'rate_of_fire', 'combat_weight_kg', 'speed_kmh', - 'min_range_km', 'max_range_km', 'mobility_type', 'structure_layout', - 'engine_model', 'power_hp', 'travel_range_km', - - # 特征工程参数 - 新增评分指标 - 'fire_density', 'range_ratio', 'mobility_score', - 'combat_readiness_score', 'rocket_power_ratio', - 'platform_efficiency', 'deployment_score', - 'terrain_adaptability_score' - ] - + def predict(self, features): - """使用加载的模型进行预测""" + """使用模型进行预测""" try: - if not self.best_model: - raise ValueError("No model loaded") - - if not self.feature_scaler: - raise ValueError("Feature scaler not loaded") - - if not self.target_scaler: - raise ValueError("Target scaler not loaded") - - logger.info("Starting prediction") - logger.info(f"Input features shape: {features.shape}") - logger.info(f"Input features: \n{features}") - - # 获取正确的特征列表 - if self.equipment_type == '巡飞弹': - feature_list = self.get_missile_features() - logger.info(f"Using missile features: {feature_list}") - - # 确保特征顺序一致 - features_ordered = np.zeros((features.shape[0], len(feature_list))) - for i, feature_name in enumerate(feature_list): - if feature_name in features: - features_ordered[:, i] = features[feature_name] - features = features_ordered - - # 处理缺失值 - features_filled = np.array(features, dtype=float) - features_filled[np.isnan(features_filled)] = 0 - features_filled = np.nan_to_num(features_filled, 0) - - logger.info(f"Filled features: \n{features_filled}") - - # 标准化特征 - X = self.feature_scaler.transform(features_filled) - logger.info(f"Transformed features shape: {X.shape}") - logger.info(f"Transformed features: \n{X}") - - # 预测 - y_pred_scaled = self.best_model.predict(X) - logger.info(f"Scaled prediction shape: {y_pred_scaled.shape}") - logger.info(f"Scaled prediction: {y_pred_scaled}") - - # 反标准化 - y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)) - logger.info(f"Final prediction shape: {y_pred.shape}") - logger.info(f"Final prediction: {y_pred}") - - return y_pred.ravel() - + self.model.eval() + with torch.no_grad(): + # 转换为tensor并移动到正确的设备 + features_tensor = torch.FloatTensor(features).to(self.device) + # 进行预测 + predictions = self.model(features_tensor) + # 移回CPU并转换为numpy数组 + return predictions.cpu().numpy() except Exception as e: logger.error(f"Error in prediction: {str(e)}") - raise - - def _get_feature_importance(self, model): - """ - 获取特征重要性 - """ - try: - if not model: - return {} - - # 获取征名称 - if self.equipment_type == '巡飞弹': - feature_names = [ - # 基本参数 - 'length_m', 'width_m', 'height_m', 'weight_kg', - - # 性能参数 - 'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh', - 'endurance_min', 'max_range_km','max_payload_kg', 'ceiling_altitude_m', - 'combat_radius_km', - - # 动力系统参数 - 'engine_power_kw', 'engine_thrust_n', - - # 制导与控制参数 - 'datalink_range_km', 'guidance_accuracy_m', - 'min_altitude_m', 'max_altitude_m', - - # 特征工程参数 - 'length_width_ratio', 'weight_range_ratio', 'speed_weight_ratio', - 'guidance_system_score', 'warhead_power_score' - ] - else: - # 其他装备类型使用原有的特征获取逻辑 - feature_analyzer = FeatureAnalysis() - feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type) - - # 获取特征重要性 - if hasattr(model, 'feature_importances_'): - importances = model.feature_importances_ - elif hasattr(model, 'coef_'): - if len(model.coef_.shape) > 1: # 如果是二维数组 - importances = np.abs(model.coef_[0]) # 取第一行 - else: - importances = np.abs(model.coef_) - else: - return {} - - # 创建特征重要性字典 - importance_dict = {} - for name, importance in zip(feature_names, importances): - importance_dict[name] = float(importance) # 确保转换为 Python 标量 - - # 按重要性降序排序 - sorted_dict = dict(sorted( - importance_dict.items(), - key=lambda x: x[1], - reverse=True - )) - - # 过滤掉重要性为0的特征 - return {k: v for k, v in sorted_dict.items() if v > 0} - - except Exception as e: - logger.error(f"Error getting feature importance: {str(e)}") - return {} - - def _calculate_confidence_interval(self, prediction, confidence=0.95): - """ - 计算预测值的置信区间 - """ - try: - # 使用预测值的20%作为标准差(增加不确定性) - std = abs(prediction) * 0.2 - - # 计算置信区间 - from scipy import stats - interval = stats.norm.interval(confidence, loc=prediction, scale=std) - - # 确保区间值为正数且合理 - lower = max(1000, interval[0]) # 最小值设为1000元 - upper = max(prediction * 1.2, interval[1]) # 至少比预测值大20% - - logger.info(f"Calculated confidence interval: [{lower:.2f}, {upper:.2f}]") - - return [lower, upper] - - except Exception as e: - logger.error(f"Error calculating confidence interval: {str(e)}") - # 如果计算失败,返回基于20%的简单区间 - lower = max(1000, prediction * 0.8) - upper = prediction * 1.2 - return [lower, upper] - - def get_model_type(self): - """ - 获取当前模型的类型 - """ - if isinstance(self.best_model, xgb.XGBRegressor): - return 'xgboost' - elif isinstance(self.best_model, lgb.LGBMRegressor): - return 'lightgbm' - elif isinstance(self.best_model, GradientBoostingRegressor): - return 'gbm' - elif isinstance(self.best_model, RandomForestRegressor): - return 'rf' - else: - return 'unknown' - - def _get_pls_feature_importance(self): - """ - 获取 PLS 模型的特征重要性 - """ - try: - if not self.models['pls']: - return {} - - # 获取特征名称 - feature_analyzer = FeatureAnalysis() - feature_names = feature_analyzer.get_equipment_specific_features(self.equipment_type) - - # 获取 PLS 模型的系数作为特征重要性 - pls_model = self.models['pls'] - if hasattr(pls_model, 'coef_'): - # 使用绝对值作为重要性指标 - importances = np.abs(pls_model.coef_.ravel()) # 使用 ravel() 展平数组 - else: - return {} - - # 创建特征重要性字典 - importance_dict = {} - for name, importance in zip(feature_names, importances): - importance_dict[name] = float(importance) # 确保转换为 Python 标量 - - # 按重要性降序排序 - sorted_dict = dict(sorted( - importance_dict.items(), - key=lambda x: x[1], - reverse=True - )) - - # 过滤掉重要性为0的特征 - return {k: v for k, v in sorted_dict.items() if v > 0} - - except Exception as e: - logger.error(f"Error getting PLS feature importance: {str(e)}") - logger.error("Detailed traceback:", exc_info=True) - return {} - - def _preprocess_data(self, data): - """数据预处理""" - try: - # 获取正确的特征列表 - if self.equipment_type == '巡飞弹': - feature_list = self.get_missile_features() - else: - feature_list = self.get_rocket_features() - - logger.info(f"Using features: {feature_list}") - - # 处理缺失值 - features_filled = np.array(data, dtype=float) - features_filled[np.isnan(features_filled)] = 0 - features_filled = np.nan_to_num(features_filled, 0) - - logger.info(f"Filled features: \n{features_filled}") - - return features_filled - - except Exception as e: - logger.error(f"Error in data preprocessing: {str(e)}") raise \ No newline at end of file diff --git a/src/real_data.sql b/src/real_data.sql deleted file mode 100644 index d00ab59..0000000 --- a/src/real_data.sql +++ /dev/null @@ -1,485 +0,0 @@ --- 清空现有数据 -SET FOREIGN_KEY_CHECKS=0; -TRUNCATE TABLE dataset_equipment; -TRUNCATE TABLE datasets; -TRUNCATE TABLE cost_data; -TRUNCATE TABLE loitering_munition_params; -TRUNCATE TABLE common_params; -TRUNCATE TABLE equipment; -SET FOREIGN_KEY_CHECKS=1; - --- 按系列插入装备数据,确保ID连续 --- 1. HAROP/Harpy 系列 (ID: 1-3) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(1, 'IAI Harop', '巡飞弹', '以色列'), -(2, 'IAI Harpy', '巡飞弹', '以色列'), -(3, 'IAI Mini Harpy', '巡飞弹', '以色列'); - --- 2. Hero 系列 (ID: 4-9) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(4, 'Hero-30', '巡飞弹', '以色列 UVision'), -(5, 'Hero-70', '巡飞弹', '以色列 UVision'), -(6, 'Hero-120', '巡飞弹', '以色列 UVision'), -(7, 'Hero-250', '巡飞弹', '以色列 UVision'), -(8, 'Hero-400EC', '巡飞弹', '以色列 UVision'), -(9, 'Hero-900', '巡飞弹', '以色列 UVision'); - --- 3. Switchblade 系列 (ID: 10-13) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(10, 'Switchblade 300', '巡飞弹', '美国 AeroVironment'), -(11, 'Switchblade 600', '巡飞弹', '美国 AeroVironment'), -(12, 'Switchblade 300 Block 10', '巡飞弹', '美国 AeroVironment'), -(13, 'Switchblade 600 Extended Range', '巡飞弹', '美国 AeroVironment'); - --- 4. Warmate 系列 (ID: 14-18) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(14, 'Warmate 1.0', '巡飞弹', '波兰 WB Electronics'), -(15, 'Warmate 2.0', '巡飞弹', '波兰 WB Electronics'), -(16, 'Warmate-V', '巡飞弹', '波兰 WB Electronics'), -(17, 'Warmate-L', '巡飞弹', '波兰 WB Electronics'), -(18, 'Warmate 3.0', '巡飞弹', '波兰 WB Electronics'); - --- 5. CH-901/902 系列 (ID: 19-23) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(19, 'CH-901', '巡飞弹', '中国航天科工'), -(20, 'CH-901A', '巡飞弹', '中国航天科工'), -(21, 'CH-901H', '巡飞弹', '中国航天科工'), -(22, 'CH-902', '巡飞弹', '中国航天科工'), -(23, 'CH-902A', '巡飞弹', '中国航天科工'); - --- 6. WS-43/61 系列 (ID: 24-28) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(24, 'WS-43', '巡飞弹', '中国航天科工'), -(25, 'WS-43A', '巡飞弹', '中国航天科工'), -(26, 'WS-43B', '巡飞弹', '中国航天科工'), -(27, 'WS-61', '巡飞弹', '中国航天科工'), -(28, 'WS-61A', '巡飞弹', '中国航天科工'); - --- 7. Kargu/Alpagu 系列 (ID: 29-33) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(29, 'Kargu', '巡飞弹', '土耳其 STM'), -(30, 'Kargu-2', '巡飞弹', '土耳其 STM'), -(31, 'Alpagu', '巡飞弹', '土耳其 STM'), -(32, 'Alpagu Block-II', '巡飞弹', '土耳其 STM'), -(33, 'Kargu Autonomous', '巡飞弹', '土耳其 STM'); - --- 8. Shahed 系列 (ID: 34-38) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(34, 'Shahed-131', '巡飞弹', '伊朗'), -(35, 'Shahed-131B', '巡飞弹', '伊朗'), -(36, 'Shahed-136', '巡飞弹', '伊朗'), -(37, 'Shahed-136B', '巡飞弹', '伊朗'), -(38, 'Shahed-136C', '巡飞弹', '伊朗'); - --- 9. Green Dragon 系列 (ID: 39-43) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(39, 'Green Dragon', '巡飞弹', '以色列 IAI'), -(40, 'Green Dragon Extended Range', '巡飞弹', '以色列 IAI'), -(41, 'Green Dragon Block 2', '巡飞弹', '以色列 IAI'), -(42, 'Green Dragon Maritime', '巡飞弹', '以色列 IAI'), -(43, 'Green Dragon-S', '巡飞弹', '以色列 IAI'); - --- 10. Phoenix Ghost 系列 (ID: 44-48) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(44, 'Phoenix Ghost', '巡飞弹', '美国 AEVEX Aerospace'), -(45, 'Phoenix Ghost Block I', '巡飞弹', '美国 AEVEX Aerospace'), -(46, 'Phoenix Ghost Block II', '巡飞弹', '美国 AEVEX Aerospace'), -(47, 'Phoenix Ghost Maritime', '巡飞弹', '美国 AEVEX Aerospace'), -(48, 'Phoenix Ghost-ER', '巡飞弹', '美国 AEVEX Aerospace'); - --- 11. ZALA Lancet 系列 (ID: 49-52) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(49, 'Lancet-1', '巡飞弹', '俄罗斯 ZALA'), -(50, 'Lancet-3', '巡飞弹', '俄罗斯 ZALA'), -(51, 'Lancet-3M', '巡飞弹', '俄罗斯 ZALA'), -(52, 'Lancet-4', '巡飞弹', '俄罗斯 ZALA'); - --- 12. Rotem L 系列 (ID: 53-56) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(53, 'Rotem L', '巡飞弹', '以色列 IAI'), -(54, 'Rotem L-X', '巡飞弹', '以色列 IAI'), -(55, 'Rotem L-M', '巡飞弹', '以色列 IAI'), -(56, 'Rotem L-ER', '巡飞弹', '以色列 IAI'); - --- 13. KUB-BLA 系列 (ID: 57-60) -INSERT INTO equipment (id, name, type, manufacturer) VALUES -(57, 'KUB-BLA', '巡飞弹', '俄罗斯 ZALA'), -(58, 'KUB-BLA-E', '巡飞弹', '俄罗斯 ZALA'), -(59, 'KUB-BLA-M', '巡飞弹', '俄罗斯 ZALA'), -(60, 'KUB-BLA-ER', '巡飞弹', '俄罗斯 ZALA'); - --- 插入通用参数 -INSERT INTO common_params (equipment_id, length_m, width_m, height_m, weight_kg, max_range_km) VALUES -(1, 2.5, 0.43, 0.43, 135, 1000), -- IAI Harop -(2, 2.7, 0.35, 0.35, 125, 500), -- IAI Harpy -(3, 2.1, 0.30, 0.30, 45, 100), -- IAI Mini Harpy -(4, 0.76, 0.17, 0.17, 3.0, 15), -- Hero-30 -(5, 0.87, 0.18, 0.18, 6.5, 25), -- Hero-70 -(6, 1.3, 0.23, 0.23, 12.5, 40), -- Hero-120 -(7, 2.1, 0.30, 0.30, 35, 150), -- Hero-250 -(8, 2.4, 0.35, 0.35, 40, 150), -- Hero-400EC -(9, 2.9, 0.40, 0.40, 90, 250), -- Hero-900 -(10, 0.58, 0.12, 0.12, 2.5, 10), -(11, 1.30, 0.22, 0.22, 15.0, 40), -(12, 0.60, 0.12, 0.12, 2.7, 15), -- Switchblade 300 Block 10 -(13, 1.35, 0.22, 0.22, 16.0, 50), -- Switchblade 600 Extended Range -(14, 0.68, 0.12, 0.12, 2.5, 10), -(15, 1.30, 0.22, 0.22, 15.0, 40), -(16, 0.68, 0.12, 0.12, 2.5, 10), -(17, 1.30, 0.22, 0.22, 15.0, 40), -(18, 0.68, 0.12, 0.12, 2.5, 10), -(19, 1.2, 0.18, 0.18, 9.0, 20), -(20, 1.2, 0.18, 0.18, 9.3, 25), -(21, 1.2, 0.18, 0.18, 9.5, 20), -(22, 1.4, 0.22, 0.22, 15.0, 30), -(23, 1.4, 0.22, 0.22, 15.5, 35), -(24, 1.8, 0.35, 0.35, 20, 60), -(25, 1.8, 0.35, 0.35, 21, 70), -(26, 1.9, 0.35, 0.35, 22, 80), -(27, 2.2, 0.40, 0.40, 35, 100), -(28, 2.2, 0.40, 0.40, 37, 120), -(29, 0.6, 0.35, 0.35, 7.0, 10), -(30, 0.6, 0.35, 0.35, 7.2, 15), -(31, 1.0, 0.23, 0.23, 3.7, 5), -(32, 1.0, 0.23, 0.23, 3.9, 8), -(33, 0.6, 0.35, 0.35, 7.5, 15), -(34, 2.6, 0.34, 0.34, 135, 900), -(35, 2.6, 0.34, 0.34, 140, 1000), -(36, 3.5, 0.42, 0.42, 200, 2000), -(37, 3.5, 0.42, 0.42, 210, 2200), -(38, 3.5, 0.42, 0.42, 215, 2500), -(39, 1.5, 0.20, 0.20, 15, 40), -(40, 1.6, 0.20, 0.20, 16, 50), -(41, 1.5, 0.20, 0.20, 15.5, 45), -(42, 1.5, 0.20, 0.20, 15.8, 40), -(43, 1.2, 0.18, 0.18, 12, 30), -(44, 1.5, 0.25, 0.25, 14.0, 30), -(45, 1.5, 0.25, 0.25, 14.5, 35), -(46, 1.6, 0.26, 0.26, 15.0, 40), -(47, 1.5, 0.25, 0.25, 14.8, 30), -(48, 1.7, 0.27, 0.27, 16.0, 50), -(49, 1.0, 0.20, 0.20, 5.0, 40), -(50, 1.65, 0.35, 0.35, 12.0, 70), -(51, 1.65, 0.35, 0.35, 12.5, 80), -(52, 1.80, 0.40, 0.40, 15.0, 100), -(53, 0.8, 0.25, 0.25, 4.5, 10), -- Rotem L -(54, 0.8, 0.25, 0.25, 4.8, 15), -- Rotem L-X -(55, 0.8, 0.25, 0.25, 4.7, 10), -- Rotem L-M -(56, 0.9, 0.27, 0.27, 5.2, 20), -- Rotem L-ER -(57, 1.21, 0.95, 0.165, 3.0, 40), -- KUB-BLA -(58, 1.21, 0.95, 0.165, 3.2, 50), -- KUB-BLA-E -(59, 1.21, 0.95, 0.165, 3.3, 45), -- KUB-BLA-M -(60, 1.25, 1.0, 0.17, 3.5, 60); -- KUB-BLA-ER - --- 插入特有参数 -INSERT INTO loitering_munition_params (equipment_id, wingspan_m, warhead_weight_kg, max_speed_ms, cruise_speed_kmh, - endurance_min, - warhead_type, - launch_mode, - power_system, - guidance_system -) VALUES --- HAROP/Harpy系列 -(1, 3.0, 23, 51.4, 185, 360, '高爆战斗部', '箱式发射/空中发射', '活塞发动机', 'GPS/INS/光电/数据链'), -(2, 2.1, 32, 51.4, 148, 120, '高爆战斗部', '箱式发射', '活塞发动机', 'GPS/INS/被动雷达'), -(3, 1.8, 8, 47.2, 130, 120, '高爆战斗部', '箱式发射', '电动机', 'GPS/INS/光电/被动雷达'), - --- Hero系列 -(4, 1.0, 0.5, 36.1, 100, 30, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电'), -(5, 1.5, 1.2, 38.9, 105, 45, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电'), -(6, 2.1, 3.5, 41.7, 100, 60, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(7, 2.5, 10.0, 47.2, 130, 120, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(8, 2.8, 8.0, 47.2, 130, 240, '破片杀伤战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(9, 3.0, 20.0, 51.4, 150, 360, '破片杀伤战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链'), - --- Switchblade系列 -(10, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'), -(11, 2.2, 4.0, 51.4, 115, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(12, 0.70, 0.25, 41.7, 100, 20, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), -(13, 2.3, 4.1, 51.4, 115, 50, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链/AI辅助'), - --- Warmate系列 -(14, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电'), -(15, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(16, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), -(17, 1.30, 0.22, 0.22, 15.0, 40, '破甲战斗部', '箱式发射', '电动机', 'GPS/INS/光电/数据链'), -(18, 0.68, 0.2, 38.9, 98, 15, '破片杀伤战斗部', '单兵发射管', '电动机', 'GPS/INS/光电/数据链'), - --- CH-901/902系列 -(19, 1.8, 2.0, 44.4, 95, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(20, 1.8, 2.2, 47.2, 100, 140, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(21, 1.8, 3.0, 44.4, 95, 120, '破甲战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(22, 2.2, 3.5, 50.0, 110, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(23, 2.2, 3.5, 50.0, 110, 200, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), -(24, 2.4, 3.8, 47.2, 100, 45, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(25, 2.4, 4.0, 50.0, 110, 60, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(26, 2.5, 4.0, 50.0, 110, 80, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(27, 3.0, 8.0, 55.6, 120, 120, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), -(28, 3.0, 8.5, 55.6, 120, 150, '模块化战斗部', '箱式发射', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), -(29, 0.7, 1.0, 36.1, 72, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'), -(30, 0.7, 1.1, 38.9, 75, 40, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), -(31, 1.3, 0.8, 41.7, 80, 20, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电'), -(32, 1.3, 0.9, 44.4, 85, 25, '破片杀伤战斗部', '弹射式', '电动机', 'GPS/INS/光电/AI识别'), -(33, 0.7, 1.2, 38.9, 75, 45, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/自主决策'), -(34, 2.2, 15, 55.6, 150, 180, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电'), -(35, 2.2, 15, 58.3, 160, 200, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'), -(36, 2.5, 30, 61.1, 180, 240, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链'), -(37, 2.5, 35, 63.9, 185, 260, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助'), -(38, 2.5, 40, 66.7, 190, 300, '高爆战斗部', '箱式发射/弹射式', '活塞发动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), -(39, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(40, 2.2, 3.0, 50.0, 115, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(41, 2.0, 3.5, 47.2, 110, 90, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(42, 2.0, 3.0, 47.2, 110, 90, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'), -(43, 1.8, 2.5, 44.4, 100, 60, '破片杀伤战斗部', '箱式发射/单兵发射', '电动机', 'GPS/INS/光电/数据链'), -(44, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链'), -(45, 2.2, 3.8, 50.0, 115, 140, '破片杀伤/破甲双用战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助'), -(46, 2.3, 4.0, 52.8, 120, 160, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/红外'), -(47, 2.2, 3.5, 47.2, 110, 120, '破片杀伤战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/抗盐雾'), -(48, 2.4, 4.2, 55.6, 125, 180, '模块化战斗部', '箱式发射/弹射式', '电动机', 'GPS/INS/光电/数据链/AI辅助/卫通'), -(49, 1.2, 1.0, 44.4, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'), -(50, 2.0, 3.0, 50.0, 110, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'), -(51, 2.0, 3.5, 52.8, 120, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'), -(52, 2.3, 5.0, 55.6, 130, 60, '模块化战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外/卫通'), -(53, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别'), -(54, 0.9, 1.2, 38.9, 85, 45, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), -(55, 0.9, 1.0, 36.1, 80, 30, '破片杀伤战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/抗盐雾'), -(56, 1.0, 1.3, 41.7, 90, 60, '破片杀伤/破甲双用战斗部', '垂直起降', '电动机', 'GPS/INS/光电/AI识别/数据链'), -(57, 1.2, 1.0, 41.7, 80, 30, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别'), -(58, 1.2, 1.2, 44.4, 85, 40, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链'), -(59, 1.2, 1.3, 44.4, 85, 35, '破片杀伤战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/红外'), -(60, 1.3, 1.5, 47.2, 90, 50, '破片杀伤/破甲双用战斗部', '弹射式发射', '电动机', 'GPS/INS/光电/AI识别/数据链/红外'); - --- 插入成本数据 -INSERT INTO cost_data (equipment_id, actual_cost) VALUES -(1, 800000), -- IAI Harop -(2, 700000), -- IAI Harpy -(3, 350000), -- IAI Mini Harpy -(4, 70000), -- Hero-30 -(5, 120000), -- Hero-70 -(6, 150000), -- Hero-120 -(7, 300000), -- Hero-250 -(8, 400000), -- Hero-400EC -(9, 650000), -- Hero-900 -(10, 60000), -- Switchblade 300 -(11, 180000), -- Switchblade 600 -(12, 75000), -- Switchblade 300 Block 10 -(13, 200000), -- Switchblade 600 Extended Range -(14, 60000), -- Warmate 1.0 -(15, 180000), -- Warmate 2.0 -(16, 60000), -- Warmate-V -(17, 180000), -- Warmate-L -(18, 60000), -- Warmate 3.0 -(19, 100000), -- CH-901 -(20, 120000), -- CH-901A -(21, 130000), -- CH-901H -(22, 180000), -- CH-902 -(23, 200000), -- CH-902A -(24, 120000), -- WS-43 -(25, 150000), -- WS-43A -(26, 180000), -- WS-43B -(27, 300000), -- WS-61 -(28, 350000), -- WS-61A -(29, 70000), -- Kargu -(30, 85000), -- Kargu-2 -(31, 45000), -- Alpagu -(32, 55000), -- Alpagu Block-II -(33, 95000), -- Kargu Autonomous -(34, 20000), -- Shahed-131 -(35, 25000), -- Shahed-131B -(36, 40000), -- Shahed-136 -(37, 45000), -- Shahed-136B -(38, 50000), -- Shahed-136C -(39, 160000), -- Green Dragon -(40, 200000), -- Green Dragon Extended Range -(41, 180000), -- Green Dragon Block 2 -(42, 190000), -- Green Dragon Maritime -(43, 140000), -- Green Dragon-S -(44, 150000), -- Phoenix Ghost -(45, 180000), -- Phoenix Ghost Block I -(46, 220000), -- Phoenix Ghost Block II -(47, 190000), -- Phoenix Ghost Maritime -(48, 250000), -- Phoenix Ghost-ER -(49, 80000), -- Lancet-1 -(50, 150000), -- Lancet-3 -(51, 180000), -- Lancet-3M -(52, 250000), -- Lancet-4 -(53, 65000), -- Rotem L -(54, 85000), -- Rotem L-X -(55, 75000), -- Rotem L-M -(56, 95000), -- Rotem L-ER -(57, 95000), -- KUB-BLA -(58, 120000), -- KUB-BLA-E -(59, 110000), -- KUB-BLA-M -(60, 150000); -- KUB-BLA-ER - --- 创建数据集 -INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES -(1, '巡飞弹训练集', '用于训练巡飞弹成本预测模型的数据集', '巡飞弹', '训练'), -(2, '巡飞弹验证集', '用于验证模型效果的数据集', '巡飞弹', '验证'); - --- 关联装备到数据集(按照制造商和型号分配) -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES --- 训练集(约80%的数据,48个型号) --- 以色列系列 -(1, 1), (1, 2), (1, 3), -- HAROP/Harpy系列 -(1, 4), (1, 5), (1, 6), -- Hero系列基础型号 -(1, 39), (1, 40), (1, 41), (1, 42), (1, 43), -- Green Dragon系列 -(1, 53), (1, 54), (1, 55), (1, 56), -- Rotem L系列 - --- 美国系列 -(1, 10), (1, 11), (1, 12), (1, 13), -- Switchblade系列 -(1, 44), (1, 45), (1, 46), (1, 47), (1, 48), -- Phoenix Ghost系列 - --- 中国系列 -(1, 19), (1, 20), (1, 21), (1, 22), (1, 23), -- CH-901/902系列 -(1, 24), (1, 25), (1, 26), (1, 27), (1, 28), -- WS-43/61系列 - --- 波兰和土耳其系列 -(1, 14), (1, 15), (1, 16), (1, 17), (1, 18), -- Warmate系列 -(1, 29), (1, 30), (1, 31), (1, 32), (1, 33), -- Kargu/Alpagu系列 - --- 俄罗斯系列 -(1, 57), (1, 58), (1, 59), (1, 60), -- KUB-BLA系列 - --- 验证集(约20%的数据,12个型号) --- 混合系列 -(2, 7), (2, 8), (2, 9), -- Hero系列高级型号 -(2, 34), (2, 35), (2, 36), (2, 37), (2, 38), -- Shahed系列 -(2, 49), (2, 50), (2, 51), (2, 52); -- ZALA Lancet系列 - --- 添加分类特征编码 -INSERT INTO feature_encoding (feature_type, feature_value, code) VALUES --- 战斗部类型编码 -('warhead_type', '破片杀伤战斗部', 1), -('warhead_type', '破甲战斗部', 2), -('warhead_type', '高爆战斗部', 3), -('warhead_type', '破片杀伤/破甲双用战斗部', 4), -('warhead_type', '模块化战斗部', 5), - --- 发射方式编码 -('launch_mode', '箱式发射', 1), -('launch_mode', '弹射式发射', 2), -('launch_mode', '垂直起降', 3), -('launch_mode', '单兵发射管', 4), -('launch_mode', '箱式发射/弹射式', 5), -('launch_mode', '箱式发射/空中发射', 6), - --- 动力装置编码(按复杂度递增) -('power_system', '电动机', 1), -('power_system', '活塞发动机', 2), - --- 制导系统编码(按复杂度递增) -('guidance_system', 'GPS/INS', 1), -('guidance_system', 'GPS/INS/光电', 2), -('guidance_system', 'GPS/INS/光电/数据链', 3), -('guidance_system', 'GPS/INS/光电/AI识别', 4), -('guidance_system', 'GPS/INS/光电/数据链/AI辅助', 5), -('guidance_system', 'GPS/INS/光电/数据链/AI辅助/红外', 6), -('guidance_system', 'GPS/INS/光电/数据链/AI辅助/卫通', 7); - --- 更新巡飞弹特有参数表,添加新的关键参数和特征工程字段 -UPDATE loitering_munition_params l -JOIN common_params c ON l.equipment_id = c.equipment_id -SET - -- 新增关键参数 - l.payload_weight_kg = l.warhead_weight_kg * 1.2, -- 有效载荷通常比战斗部重量大20% - l.min_combat_radius_km = c.max_range_km * 0.1, -- 最小作战半径约为最大航程的10% - l.engine_power_kw = - CASE - WHEN l.power_system = '电动机' THEN c.weight_kg * 0.15 - WHEN l.power_system = '活塞发动机' THEN c.weight_kg * 0.25 - END, - l.engine_thrust_n = c.weight_kg * 9.8 * 0.3, -- 推力约为重量的30% - l.datalink_range_km = c.max_range_km * 0.8, -- 通信链路距离约为最大航程的80% - l.guidance_accuracy_m = - CASE - WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 1.0 - WHEN INSTR(l.guidance_system, '光电') > 0 THEN 2.0 - ELSE 3.0 - END, - l.min_altitude_m = -- 最小作战高度 - CASE - -- 大型巡飞弹(体型大、重量大) - WHEN equipment_id IN (1, 2, 34, 35, 36, 37, 38) THEN 150 -- HAROP/Harpy系列和 Shahed系列 - - -- 中型巡飞弹 - WHEN equipment_id IN (3, 7, 8, 9, 27, 28) THEN 100 -- Mini Harpy和高端Hero系列, WS-61系列 - - -- 中小型巡飞弹 - WHEN equipment_id IN (6, 11, 13, 15, 17, 22, 23, 24, 25, 26) THEN 80 -- Hero-120, Switchblade 600系列等 - - -- 小型巡飞弹 - WHEN equipment_id IN (4, 5, 10, 12, 14, 16, 18, 19, 20, 21) THEN 50 -- Hero-30/70, Switchblade 300系列等 - - -- 超小型巡飞弹 - WHEN equipment_id IN (29, 30, 31, 32, 33, 53, 54, 55, 56, 57, 58, 59, 60) THEN 30 -- Kargu/Alpagu系列, Rotem系列, KUB-BLA系列 - - -- 其他型号使用默认值 - ELSE 50 - END, - l.max_altitude_m = - CASE - WHEN c.max_range_km > 500 THEN 5000 - WHEN c.max_range_km > 100 THEN 3000 - ELSE 1500 - END, - - -- 特征工程字段 - l.length_width_ratio = c.length_m / c.width_m, - l.weight_range_ratio = c.weight_kg / c.max_range_km, - l.speed_weight_ratio = l.max_speed_ms / c.weight_kg, - l.guidance_system_score = - CASE - WHEN INSTR(l.guidance_system, 'AI') > 0 AND INSTR(l.guidance_system, '卫通') > 0 THEN 10 - WHEN INSTR(l.guidance_system, 'AI') > 0 THEN 8 - WHEN INSTR(l.guidance_system, '数据链') > 0 THEN 6 - WHEN INSTR(l.guidance_system, '光电') > 0 THEN 4 - ELSE 2 - END, - l.warhead_power_score = - CASE - WHEN l.warhead_type = '模块化战斗部' THEN 10 - WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 8 - WHEN l.warhead_type = '高爆战斗部' THEN 7 - WHEN l.warhead_type = '破甲战斗部' THEN 6 - WHEN l.warhead_type = '破片杀伤战斗部' THEN 5 - ELSE 4 - END, - - -- 分类特征编码 - l.warhead_type_code = - CASE - WHEN l.warhead_type = '破片杀伤战斗部' THEN 1 - WHEN l.warhead_type = '破甲战斗部' THEN 2 - WHEN l.warhead_type = '高爆战斗部' THEN 3 - WHEN l.warhead_type = '破片杀伤/破甲双用战斗部' THEN 4 - WHEN l.warhead_type = '模块化战斗部' THEN 5 - ELSE 0 - END, - l.launch_mode_code = - CASE - WHEN l.launch_mode = '箱式发射' THEN 1 - WHEN l.launch_mode = '弹射式发射' THEN 2 - WHEN l.launch_mode = '垂直起降' THEN 3 - WHEN l.launch_mode = '单兵发射管' THEN 4 - WHEN l.launch_mode = '箱式发射/弹射式' THEN 5 - WHEN l.launch_mode = '箱式发射/空中发射' THEN 6 - ELSE 0 - END, - l.power_system_code = - CASE - WHEN l.power_system = '电动机' THEN 1 - WHEN l.power_system = '活塞发动机' THEN 2 - ELSE 0 - END, - l.guidance_system_code = - CASE - WHEN l.guidance_system = 'GPS/INS' THEN 1 - WHEN l.guidance_system = 'GPS/INS/光电' THEN 2 - WHEN l.guidance_system = 'GPS/INS/光电/数据链' THEN 3 - WHEN l.guidance_system = 'GPS/INS/光电/AI识别' THEN 4 - WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助' THEN 5 - WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/红外' THEN 6 - WHEN l.guidance_system = 'GPS/INS/光电/数据链/AI辅助/卫通' THEN 7 - ELSE 0 - END; diff --git a/src/rocket_artillery_data.sql b/src/rocket_artillery_data.sql index e81f8a1..ba84945 100644 --- a/src/rocket_artillery_data.sql +++ b/src/rocket_artillery_data.sql @@ -29,7 +29,7 @@ */ -- 中国系列火箭炮数据 -INSERT INTO equipment (id, name, type, manufacturer) VALUES +INSERT INTO equipments (id, name, type, manufacturer) VALUES (1001, 'PCL-191', '火箭炮', '中国兵器工业集团'), (1002, 'PHL-03', '火箭炮', '中国兵器工业集团'), (1003, 'AR-3', '火箭炮', '中国航天科工'), @@ -39,11 +39,11 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES (1007, 'WS-2', '火箭炮', '中国航天科工'), (1008, 'WS-3', '火箭炮', '中国航天科工'), (1009, 'Type 63', '火箭炮', '中国兵器工业集团'), -(1010, 'BM-21 Grad', '火箭炮', '俄罗斯'), -(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯'), -(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯'), -(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯'), -(1014, 'TOS-1A', '火箭炮', '俄罗斯'), +(1010, 'BM-21 Grad', '火箭炮', '俄罗斯 Rostec'), +(1011, 'BM-27 Uragan', '火箭炮', '俄罗斯 Rostec'), +(1012, 'BM-30 Smerch', '火箭炮', '俄罗斯 Rostec'), +(1013, '9A52-4 Tornado', '火箭炮', '俄罗斯 Rostec'), +(1014, 'TOS-1A', '火箭炮', '俄罗斯 Rostec'), (1015, 'M142 HIMARS', '火箭炮', '美国洛克希德·马丁'), (1016, 'M270 MLRS', '火箭炮', '美国洛克希德·马丁'), (1017, 'M270A1', '火箭炮', '美国洛克希德·马丁'), @@ -62,10 +62,10 @@ INSERT INTO equipment (id, name, type, manufacturer) VALUES (1030, 'ASTROS 2020', '火箭炮', '巴西航空工业'), (1031, 'ASTROS II Mk3', '火箭炮', '巴西航空工业'), (1032, 'ASTROS II Mk6', '火箭炮', '巴西航空工业'), -(1033, 'Pinaka', '火箭炮', '印度DRDO'), -(1034, 'Pinaka Mk-II', '火箭炮', '印度DRDO'), -(1035, 'Pinaka Mk-III', '火箭炮', '印度DRDO'), -(1036, 'Pinaka-ER', '火箭炮', '印度DRDO'), +(1033, 'Pinaka', '火箭炮', '印度 DRDO'), +(1034, 'Pinaka Mk-II', '火箭炮', '印度 DRDO'), +(1035, 'Pinaka Mk-III', '火箭炮', '印度 DRDO'), +(1036, 'Pinaka-ER', '火箭炮', '印度 DRDO'), (1037, 'WR-40 Langusta', '火箭炮', '波兰胡塔斯塔洛瓦'), (1038, 'RM-70', '火箭炮', '波兰胡塔斯塔洛瓦'), (1039, 'BM-21M', '火箭炮', '波兰胡塔斯塔洛瓦'), @@ -485,7 +485,7 @@ INSERT INTO datasets (id, name, description, equipment_type, purpose) VALUES (4, '火箭炮验证集 2024', '包含19个火箭炮型号,用于验证模型性能', '火箭炮', '验证'); -- 训练集(约80%的数据,77个型号) -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES -- 中国系列(7/9) (3, 1001), (3, 1002), (3, 1003), (3, 1004), (3, 1005), (3, 1006), (3, 1007), @@ -565,7 +565,7 @@ INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES (3, 1094), (3, 1095); -- 验证集(约20%的数据,19个型号) -INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES +INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES -- 中国系列(2/9) (4, 1008), (4, 1009), diff --git a/src/routes.py b/src/routes.py index 04f60b7..ffcf544 100644 --- a/src/routes.py +++ b/src/routes.py @@ -48,6 +48,11 @@ def index(): 'url': '/api/evaluate', 'method': 'POST', 'description': '模型评估' + }, + 'analyze-manufacturers': { + 'url': '/api/analyze-manufacturers', + 'method': 'POST', + 'description': '供应商分析' } } }) @@ -114,193 +119,149 @@ def analyze_features(): with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - # 获取数据集信息 + # 首先获取数据集的装备类型 cursor.execute(""" - SELECT d.*, - e.type as equipment_type - FROM datasets d - JOIN dataset_equipment de ON d.id = de.dataset_id - JOIN equipment e ON de.equipment_id = e.id - WHERE d.id = %s + SELECT DISTINCT e.type + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id + WHERE de.dataset_id = %s LIMIT 1 """, (dataset_id,)) - dataset = cursor.fetchone() - if not dataset: - logger.warning(f"Dataset {dataset_id} not found") - return jsonify({'error': '数据集不存在'}), 404 + equipment_type = cursor.fetchone()['type'] + logger.info(f"Equipment type: {equipment_type}") - logger.info(f"Dataset info: {dataset}") - - # 创建特征分析实例 - analyzer = FeatureAnalysis() - - # 获取特征列表 - feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type']) - logger.info(f"Feature names: {feature_names}") - - # 获取数据集中的装备数据 - if dataset['equipment_type'] == '火箭炮': + # 根据装备类型选择查询 + if equipment_type == '火箭炮': cursor.execute(""" - SELECT - e.name, - e.id, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - rap.max_range_km, - rap.firing_angle_horizontal, - rap.firing_angle_vertical, - rap.rocket_length_m, - rap.rocket_diameter_mm, - rap.rocket_weight_kg, - rap.rate_of_fire, - rap.combat_weight_kg, - rap.speed_kmh, - rap.min_range_km, - rap.mobility_type, - rap.structure_layout, - rap.engine_model, - rap.power_hp, - rap.travel_range_km, - rap.fire_density, - rap.range_ratio, - rap.mobility_score, - rap.combat_readiness_score, - rap.rocket_power_ratio, - rap.platform_efficiency, - rap.deployment_score, - rap.terrain_adaptability_score, - cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id, + m.tech_level, m.scale_level, m.supply_chain_level, m.country, + cp.length_m, cp.width_m, cp.height_m, cp.weight_kg, + cd.actual_cost, + rap.max_range_km, rap.firing_angle_horizontal, rap.firing_angle_vertical, + rap.rocket_length_m, rap.rocket_diameter_mm, rap.rocket_weight_kg, + rap.rate_of_fire, rap.combat_weight_kg, rap.speed_kmh, + rap.min_range_km, rap.power_hp, rap.travel_range_km, + rap.fire_density, rap.range_ratio, rap.mobility_score, + rap.combat_readiness_score, rap.deployment_score, rap.terrain_adaptability_score, + rap.rocket_power_ratio, rap.platform_efficiency + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL - ORDER BY e.id """, (dataset_id,)) else: cursor.execute(""" - SELECT - e.name, - e.id, - cp.length_m, - cp.width_m, - cp.height_m, - cp.weight_kg, - lmp.max_range_km, - lmp.wingspan_m, - lmp.warhead_weight_kg, - lmp.max_speed_ms, - lmp.cruise_speed_kmh, - lmp.endurance_min, - lmp.max_payload_kg, - lmp.ceiling_altitude_m, - lmp.combat_radius_km, - lmp.engine_power_kw, - lmp.engine_thrust_n, - lmp.datalink_range_km, - lmp.guidance_accuracy_m, - lmp.min_altitude_m, - lmp.max_altitude_m, - lmp.length_width_ratio, - lmp.weight_range_ratio, - lmp.speed_weight_ratio, - lmp.guidance_system_score, - lmp.warhead_power_score, - cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + SELECT e.id, e.name, e.type, e.manufacturer, e.manufacturer_id, + m.tech_level, m.scale_level, m.supply_chain_level, m.country, + cp.length_m, cp.width_m, cp.height_m, cp.weight_kg, + cd.actual_cost, + lmp.max_range_km, lmp.wingspan_m, lmp.warhead_weight_kg, + lmp.max_speed_ms, lmp.cruise_speed_kmh, lmp.endurance_min, + lmp.length_width_ratio, lmp.weight_range_ratio, + lmp.speed_weight_ratio, lmp.ceiling_altitude_m, + lmp.guidance_system_score, lmp.warhead_power_score, + lmp.engine_power_kw, lmp.engine_thrust_n, + lmp.min_altitude_m, lmp.max_altitude_m, + lmp.max_payload_kg, lmp.combat_radius_km, + lmp.datalink_range_km, lmp.guidance_accuracy_m + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL - ORDER BY e.id """, (dataset_id,)) equipment_data = cursor.fetchall() - logger.info(f"Found {len(equipment_data)} equipment records") - # 提取装备名称列表 - equipment_names = [item['name'] for item in equipment_data] + # 添加数据检查日志 + logger.info(f"Total records found: {len(equipment_data)}") + if equipment_data: + # 检查第一条记录的所有字段 + first_record = equipment_data[0] + logger.info("First record details:") + for key, value in first_record.items(): + logger.info(f"{key}: {value}") + + # 检查所有记录的 max_range_km 字段 + logger.info("Checking max_range_km for all records:") + for item in equipment_data: + logger.info(f"Equipment: {item['name']}") + logger.info(f" max_range_km: {item.get('max_range_km')}") + logger.info(f" type: {item['type']}") + if item['type'] == '火箭炮': + logger.info(f" rocket_artillery_params fields:") + for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']: + logger.info(f" {key}: {item.get(key)}") - # 提取特征数据和目标值 + # 提取特征和目标值 + analyzer = FeatureAnalysis() + feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type']) features = [] targets = [] + for item in equipment_data: + # 计算生产商特征 + manufacturer_features = analyzer.calculate_manufacturer_features({ + 'tech_level': item['tech_level'], + 'scale_level': item['scale_level'], + 'supply_chain_level': item['supply_chain_level'], + 'country': item['country'] + }) + + # 获取装备特征 feature_values = [] - for feature in feature_names: - value = item.get(feature) - feature_values.append(float(value) if value is not None else 0) + for name in feature_names: + if name in manufacturer_features: + value = manufacturer_features[name] + else: + value = item.get(name) + feature_values.append(float(value) if value is not None else 0.0) + features.append(feature_values) targets.append(float(item['actual_cost'])) - # 进行特征分析 - result = analyzer.analyze_features(features, targets, feature_names) + # 执行特征分析 + analysis_result = analyzer.analyze_features(features, targets, feature_names) - # 添加装备名称列表到结果中 - result['equipment_names'] = equipment_names + # 添加装备名称列表 + analysis_result['equipment_names'] = [item['name'] for item in equipment_data] - # 如果是火箭炮,添加额外的分析数据 - if dataset['equipment_type'] == '火箭炮': + # 添加装备特有的分析数据 + if equipment_data[0]['type'] == '火箭炮': rocket_data = { - 'fire_density': [float(item['fire_density']) if item['fire_density'] is not None else 0 for item in equipment_data], - 'range_ratio': [float(item['range_ratio']) if item['range_ratio'] is not None else 0 for item in equipment_data], - 'rate_of_fire': [float(item['rate_of_fire']) if item['rate_of_fire'] is not None else 0 for item in equipment_data], - 'max_range_km': [float(item['max_range_km']) if item['max_range_km'] is not None else 0 for item in equipment_data], - 'rocket_weight_kg': [float(item['rocket_weight_kg']) if item['rocket_weight_kg'] is not None else 0 for item in equipment_data], - 'rocket_diameter_mm': [float(item['rocket_diameter_mm']) if item['rocket_diameter_mm'] is not None else 0 for item in equipment_data], - 'rocket_length_m': [float(item['rocket_length_m']) if item['rocket_length_m'] is not None else 0 for item in equipment_data], - 'mobility_score': [float(item['mobility_score']) if item['mobility_score'] is not None else 0 for item in equipment_data], - 'deployment_score': [float(item['deployment_score']) if item['deployment_score'] is not None else 0 for item in equipment_data], - 'terrain_adaptability_score': [float(item['terrain_adaptability_score']) if item['terrain_adaptability_score'] is not None else 0 for item in equipment_data], - 'combat_readiness_score': [float(item['combat_readiness_score']) if item['combat_readiness_score'] is not None else 0 for item in equipment_data], - 'speed_kmh': [float(item['speed_kmh']) if item['speed_kmh'] is not None else 0 for item in equipment_data], - 'power_hp': [float(item['power_hp']) if item['power_hp'] is not None else 0 for item in equipment_data], - 'travel_range_km': [float(item['travel_range_km']) if item['travel_range_km'] is not None else 0 for item in equipment_data] + 'fire_density': [float(item.get('fire_density', 0)) for item in equipment_data], + 'range_ratio': [float(item.get('range_ratio', 0)) for item in equipment_data], + 'mobility_score': [float(item.get('mobility_score', 0)) for item in equipment_data], + 'combat_readiness_score': [float(item.get('combat_readiness_score', 0)) for item in equipment_data], + 'deployment_score': [float(item.get('deployment_score', 0)) for item in equipment_data], + 'terrain_adaptability_score': [float(item.get('terrain_adaptability_score', 0)) for item in equipment_data] } - result.update(rocket_data) - - # 如果是巡飞弹,添加额外的分析数据 - if dataset['equipment_type'] == '巡飞弹': + analysis_result.update(rocket_data) + else: missile_data = { - 'equipment_names': equipment_names, - # 特征工程参数 - 'length_width_ratio': [float(item['length_width_ratio']) if item.get('length_width_ratio') is not None else 0 for item in equipment_data], - 'weight_range_ratio': [float(item['weight_range_ratio']) if item.get('weight_range_ratio') is not None else 0 for item in equipment_data], - 'speed_weight_ratio': [float(item['speed_weight_ratio']) if item.get('speed_weight_ratio') is not None else 0 for item in equipment_data], - 'guidance_system_score': [float(item['guidance_system_score']) if item.get('guidance_system_score') is not None else 0 for item in equipment_data], - 'warhead_power_score': [float(item['warhead_power_score']) if item.get('warhead_power_score') is not None else 0 for item in equipment_data], - - # 动力系统参数 - 'engine_power_kw': [float(item['engine_power_kw']) if item.get('engine_power_kw') is not None else 0 for item in equipment_data], - 'engine_thrust_n': [float(item['engine_thrust_n']) if item.get('engine_thrust_n') is not None else 0 for item in equipment_data], - - # 作战参数 - 'min_altitude_m': [float(item['min_altitude_m']) if item.get('min_altitude_m') is not None else 0 for item in equipment_data], - 'max_altitude_m': [float(item['max_altitude_m']) if item.get('max_altitude_m') is not None else 0 for item in equipment_data], - 'max_range_km': [float(item['max_range_km']) if item.get('max_range_km') is not None else 0 for item in equipment_data], - 'max_payload_kg': [float(item['max_payload_kg']) if item.get('max_payload_kg') is not None else 0 for item in equipment_data], - 'combat_radius_km': [float(item['combat_radius_km']) if item.get('combat_radius_km') is not None else 0 for item in equipment_data], - 'datalink_range_km': [float(item['datalink_range_km']) if item.get('datalink_range_km') is not None else 0 for item in equipment_data], - 'guidance_accuracy_m': [float(item['guidance_accuracy_m']) if item.get('guidance_accuracy_m') is not None else 0 for item in equipment_data] + 'length_width_ratio': [float(item.get('length_width_ratio', 0)) for item in equipment_data], + 'weight_range_ratio': [float(item.get('weight_range_ratio', 0)) for item in equipment_data], + 'speed_weight_ratio': [float(item.get('speed_weight_ratio', 0)) for item in equipment_data], + 'guidance_system_score': [float(item.get('guidance_system_score', 0)) for item in equipment_data], + 'warhead_power_score': [float(item.get('warhead_power_score', 0)) for item in equipment_data], + 'guidance_accuracy_m': [float(item.get('guidance_accuracy_m', 0)) for item in equipment_data], + 'datalink_range_km': [float(item.get('datalink_range_km', 0)) for item in equipment_data], + 'max_altitude_m': [float(item.get('max_altitude_m', 0)) for item in equipment_data], + 'min_altitude_m': [float(item.get('min_altitude_m', 0)) for item in equipment_data], + 'engine_power_kw': [float(item.get('engine_power_kw', 0)) for item in equipment_data], + 'engine_thrust_n': [float(item.get('engine_thrust_n', 0)) for item in equipment_data] } - - # 验证数据完整性 - for key, value in missile_data.items(): - logger.info(f"{key} data length: {len(value)}") - logger.info(f"{key} sample data: {value[:3]}") - if not any(value): # 检查是否所有值都为0 - logger.warning(f"All values are 0 for {key}") - - # 更新结果 - result.update(missile_data) + analysis_result.update(missile_data) - return jsonify(result) + return jsonify(analysis_result) except Exception as e: logger.error(f"Error analyzing features: {str(e)}") @@ -332,8 +293,8 @@ def train_model(): if equipment_type == '火箭炮': cursor.execute(""" SELECT e.*, cp.*, rap.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id @@ -343,8 +304,8 @@ def train_model(): else: cursor.execute(""" SELECT e.*, cp.*, lmp.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id @@ -360,8 +321,8 @@ def train_model(): if equipment_type == '火箭炮': cursor.execute(""" SELECT e.*, cp.*, rap.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id @@ -371,8 +332,8 @@ def train_model(): else: cursor.execute(""" SELECT e.*, cp.*, lmp.*, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id @@ -500,7 +461,7 @@ def get_equipment_data(): # 获取所有装备数据(使用equipment_id替代id) cursor.execute(""" - SELECT e.id as equipment_id, e.name, e.type, + SELECT e.id as equipment_id, e.name, e.type, e.manufacturer, cp.length_m, cp.width_m, cp.height_m, cp.weight_kg, cd.actual_cost, cd.predicted_cost, CASE @@ -546,7 +507,7 @@ def get_equipment_data(): WHERE equipment_id = e.id ) END as specific_params - FROM equipment e + FROM equipments e LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id ORDER BY e.id @@ -576,7 +537,7 @@ def delete_equipment(id): cursor.execute("DELETE FROM rocket_artillery_params WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM loitering_munition_params WHERE equipment_id = %s", (id,)) cursor.execute("DELETE FROM common_params WHERE equipment_id = %s", (id,)) - cursor.execute("DELETE FROM equipment WHERE id = %s", (id,)) + cursor.execute("DELETE FROM equipments WHERE id = %s", (id,)) db.commit() cursor.close() @@ -737,7 +698,7 @@ def update_equipment(id): # 更新装备基本信息 cursor.execute(""" - UPDATE equipment + UPDATE equipments SET name = %s, manufacturer = %s WHERE id = %s """, (data['name'], data['manufacturer'], equipment_id)) @@ -845,7 +806,7 @@ def get_equipment_details(id): # 获取装备基本信息类型 cursor.execute(""" SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost - FROM equipment e + FROM equipments e LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE e.id = %s @@ -854,7 +815,7 @@ def get_equipment_details(id): result = cursor.fetchone() if not result: logger.warning(f"Equipment with ID {id} not found") - return jsonify({'error': '装备不存在'}), 404 + return jsonify({'error': '装备不在'}), 404 logger.info(f"Equipment type: {result['type']}") logger.info(f"Found equipment details: {result['name']}") @@ -901,8 +862,8 @@ def get_datasets(): COUNT(de.equipment_id) as equipment_count, GROUP_CONCAT(e.name) as equipment_names FROM datasets d - LEFT JOIN dataset_equipment de ON d.id = de.dataset_id - LEFT JOIN equipment e ON de.equipment_id = e.id + LEFT JOIN dataset_equipments de ON d.id = de.dataset_id + LEFT JOIN equipments e ON de.equipment_id = e.id GROUP BY d.id """) datasets = cursor.fetchall() @@ -932,7 +893,7 @@ def get_dataset(id): SELECT d.*, COUNT(de.equipment_id) as equipment_count FROM datasets d - LEFT JOIN dataset_equipment de ON d.id = de.dataset_id + LEFT JOIN dataset_equipments de ON d.id = de.dataset_id WHERE d.id = %s GROUP BY d.id """, (id,)) @@ -945,14 +906,14 @@ def get_dataset(id): cursor.execute(""" SELECT e.id as equipment_id, e.name, e.type, e.manufacturer, cd.actual_cost - FROM equipment e - JOIN dataset_equipment de ON e.id = de.equipment_id + FROM equipments e + JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id WHERE de.dataset_id = %s """, (id,)) equipment = cursor.fetchall() - # 计算统计信息 + # 算统计信 if equipment: total_cost = sum(item['actual_cost'] or 0 for item in equipment) avg_cost = total_cost / len(equipment) @@ -989,7 +950,7 @@ def create_dataset(): # 直接从 equipment 表查询,不需要 JOIN equipment_ids_str = ','.join(map(str, data['equipment_ids'])) cursor.execute(f""" - SELECT DISTINCT id FROM equipment + SELECT DISTINCT id FROM equipments WHERE id IN ({equipment_ids_str}) AND type = %s """, (data['equipment_type'],)) @@ -1015,7 +976,7 @@ def create_dataset(): if 'equipment_ids' in data and data['equipment_ids']: values = [(dataset_id, equipment_id) for equipment_id in valid_ids] cursor.executemany(""" - INSERT INTO dataset_equipment (dataset_id, equipment_id) + INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES (%s, %s) """, values) logger.info(f"Added {len(values)} equipment associations") @@ -1041,7 +1002,7 @@ def update_dataset(id): if 'equipment_ids' in data: equipment_ids_str = ','.join(map(str, data['equipment_ids'])) cursor.execute(f""" - SELECT id FROM equipment + SELECT id FROM equipments WHERE id IN ({equipment_ids_str}) AND type = %s """, (data['equipment_type'],)) @@ -1064,13 +1025,13 @@ def update_dataset(id): # 3. 更新装备关联 if 'equipment_ids' in data: # 先删除旧的关联 - cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) + cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,)) # 添加新的关联 if valid_ids: # 确保有有效的ID才执行插入 values = [(id, equipment_id) for equipment_id in valid_ids] cursor.executemany(""" - INSERT INTO dataset_equipment (dataset_id, equipment_id) + INSERT INTO dataset_equipments (dataset_id, equipment_id) VALUES (%s, %s) """, values) logger.info(f"Updated {len(values)} equipment associations") @@ -1092,7 +1053,7 @@ def delete_dataset(id): cursor = conn.cursor() # 删除装备关联 - cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) + cursor.execute("DELETE FROM dataset_equipments WHERE dataset_id = %s", (id,)) # 删除数据集 cursor.execute("DELETE FROM datasets WHERE id = %s", (id,)) @@ -1139,7 +1100,7 @@ def get_models(): models = cursor.fetchall() - # 确保数值类型字段是 float + # 确保数值型字段是 float for model in models: if model['r2_score'] is not None: model['r2_score'] = float(model['r2_score']) @@ -1161,7 +1122,7 @@ def get_models(): @api_bp.route('/models//activate', methods=['POST']) def activate_model(id): """ - 激活指定的模型 + 激活定的模型 """ try: with get_db_connection() as conn: @@ -1250,4 +1211,104 @@ def predict_all(): except Exception as e: logger.error(f"Error in prediction: {str(e)}") + return jsonify({'error': str(e)}), 500 + +@api_bp.route('/analyze-manufacturers', methods=['POST']) +def analyze_manufacturers(): + """分析生产商数据""" + try: + data = request.get_json() + dataset_id = data.get('dataset_id') + + logger.info(f"Starting manufacturer analysis for dataset {dataset_id}") + + if not dataset_id: + logger.warning("No dataset_id provided") + return jsonify({'error': '请选择数据集'}), 400 + + with get_db_connection() as conn: + cursor = conn.cursor(dictionary=True) + + # 获取数据集中的装备和生产商数据 + cursor.execute(""" + SELECT DISTINCT m.*, e.type as equipment_type + FROM manufacturers m + JOIN equipments e ON e.manufacturer_id = m.id + JOIN dataset_equipments de ON e.id = de.equipment_id + WHERE de.dataset_id = %s + """, (dataset_id,)) + + manufacturers = cursor.fetchall() + + if not manufacturers: + return jsonify({'error': '数据集中没有生产商数据'}), 404 + + # 准备分析数据 + manufacturer_names = [] + tech_levels = [] + scale_levels = [] + supply_chain_levels = [] + composite_scores = [] + region_count = {} + manufacturer_scores = [] + + for manu in manufacturers: + manufacturer_names.append(manu['name']) + tech_levels.append(manu['tech_level']) + scale_levels.append(manu['scale_level']) + supply_chain_levels.append(manu['supply_chain_level']) + + # 计算综合得分 + composite_score = ( + manu['tech_level'] * 0.4 + + manu['scale_level'] * 0.3 + + manu['supply_chain_level'] * 0.3 + ) + composite_scores.append(composite_score) + + # 统计地区分布 + region_count[manu['country']] = region_count.get(manu['country'], 0) + 1 + + # 计算区域系数 + region_factors = { + '美国': 1.2, '英国': 1.15, '德国': 1.15, + '法国': 1.15, '以色列': 1.1, '中国': 0.8, + '俄罗斯': 0.85, '韩国': 0.9, '日本': 1.1 + } + region_factor = region_factors.get(manu['country'], 1.0) + + # 添加雷达图数据 + manufacturer_scores.append({ + 'name': manu['name'], + 'value': [ + manu['tech_level'], + manu['scale_level'], + manu['supply_chain_level'], + region_factor, + composite_score + ] + }) + + # 准备地区分布数据 + region_distribution = [ + {'name': country, 'value': count} + for country, count in region_count.items() + ] + + # 返回分析结果 + result = { + 'manufacturer_names': manufacturer_names, + 'manufacturer_tech_levels': tech_levels, + 'manufacturer_scale_levels': scale_levels, + 'manufacturer_supply_chain_levels': supply_chain_levels, + 'manufacturer_composite_scores': composite_scores, + 'region_distribution': region_distribution, + 'manufacturer_scores': manufacturer_scores + } + + return jsonify(result) + + except Exception as e: + logger.error(f"Error analyzing manufacturers: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) return jsonify({'error': str(e)}), 500 \ No newline at end of file diff --git a/src/schema.sql b/src/schema.sql index e5e7cf6..3ff6dde 100644 --- a/src/schema.sql +++ b/src/schema.sql @@ -10,11 +10,12 @@ COLLATE utf8mb4_unicode_ci; USE equipment_cost_db; -- 装备基本信息表 -CREATE TABLE equipment ( +CREATE TABLE equipments ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100), -- 名称 type VARCHAR(50), -- 类型(火箭炮/巡飞弹) manufacturer VARCHAR(100), -- 制造商 + manufacturer_id INT, -- 制造商ID created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; @@ -26,8 +27,7 @@ CREATE TABLE common_params ( width_m FLOAT, -- 宽度(m) height_m FLOAT, -- 高度(m) weight_kg FLOAT, -- 重量(kg) - max_range_km FLOAT, -- 最大射程(km) - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 火箭炮特有参数表 @@ -61,7 +61,7 @@ CREATE TABLE rocket_artillery_params ( deployment_score INT, -- 部署评分(1-10) terrain_adaptability_score INT, -- 地形适应性评分(1-10) - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 巡飞弹特有参数表 @@ -103,7 +103,7 @@ CREATE TABLE loitering_munition_params ( power_system_code INT, -- 动力装置编码 guidance_system_code INT, -- 制导系统编码 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 分类特征编码表 @@ -122,7 +122,7 @@ CREATE TABLE cost_data ( actual_cost DECIMAL(15,2), -- 实际成本(元) predicted_cost DECIMAL(15,2), -- 预测成本(元) prediction_date TIMESTAMP, -- 预测日期 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 特殊参数表 @@ -133,12 +133,12 @@ CREATE TABLE custom_params ( param_value VARCHAR(255), -- 参数值 param_unit VARCHAR(50), -- 参数单位 description TEXT, -- 参数说明 - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 添加索引 -CREATE INDEX idx_equipment_type ON equipment(type); -CREATE INDEX idx_equipment_name ON equipment(name); +CREATE INDEX idx_equipment_type ON equipments(type); +CREATE INDEX idx_equipment_name ON equipments(name); CREATE INDEX idx_cost_data_equipment ON cost_data(equipment_id); -- 数据集表 @@ -153,12 +153,12 @@ CREATE TABLE datasets ( ); -- 数据集-装备关联表 -CREATE TABLE dataset_equipment ( +CREATE TABLE dataset_equipments ( dataset_id INT NOT NULL, equipment_id INT NOT NULL, PRIMARY KEY (dataset_id, equipment_id), FOREIGN KEY (dataset_id) REFERENCES datasets(id), - FOREIGN KEY (equipment_id) REFERENCES equipment(id) + FOREIGN KEY (equipment_id) REFERENCES equipments(id) ); -- 训练模型表 @@ -175,10 +175,34 @@ CREATE TABLE trained_models ( feature_importance JSON, -- 特征重要性 training_data_size INT, -- 训练数据量 training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 训练时间 - is_active BOOLEAN DEFAULT FALSE, -- 是否为当前激活模型 + is_active BOOLEAN DEFAULT FALSE, -- 是否为当前活模型 created_by VARCHAR(50) -- 创建者 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; -- 添加索引 CREATE INDEX idx_model_equipment_type ON trained_models(equipment_type); -CREATE INDEX idx_model_active ON trained_models(is_active); \ No newline at end of file +CREATE INDEX idx_model_active ON trained_models(is_active); + +-- 生产商表 +CREATE TABLE manufacturers ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, -- 生产商名称 + country VARCHAR(50) NOT NULL, -- 所属国家 + tech_level INT NOT NULL, -- 技术水平评分(1-10) + scale_level INT NOT NULL, -- 规模评分(1-10) + supply_chain_level INT NOT NULL, -- 供应链成熟度评分(1-10) + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY unique_name (name) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- 添加生产商外键 +ALTER TABLE equipments ADD FOREIGN KEY (manufacturer_id) REFERENCES manufacturers(id); + +-- 添加索引 +CREATE INDEX idx_manufacturer_country ON manufacturers(country); +CREATE INDEX idx_manufacturer_tech_level ON manufacturers(tech_level); +CREATE INDEX idx_manufacturer_scale_level ON manufacturers(scale_level); +CREATE INDEX idx_manufacturer_supply_chain_level ON manufacturers(supply_chain_level); +CREATE INDEX idx_equipment_manufacturer ON equipments(manufacturer_id); +