更新预测模块,解决模型字符集问题

This commit is contained in:
Tian jianyong 2024-11-26 22:54:15 +08:00
parent dba9f2fcc9
commit e67da8eaed
10 changed files with 1128 additions and 315 deletions

View File

@ -359,7 +359,8 @@ const formatDateTime = (value) => {
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
hour12: false,
timeZone: 'Asia/Shanghai'
})
}

View File

@ -31,7 +31,7 @@
{{ scope.row.rmse.toFixed(2) }}
</template>
</el-table-column>
<el-table-column prop="training_date" label="训练时间">
<el-table-column prop="training_date" label="训练时间" width="180">
<template #default="scope">
{{ formatDateTime(scope.row.training_date) }}
</template>
@ -211,10 +211,12 @@ const renderImportanceChart = () => {
//
// Map an internal model-type key to its human-readable display name.
// Unknown keys fall back to the raw key so newly added model types still render.
// (The rendered diff had merged the old map ('gbdt') with the new one, leaving a
// duplicate 'rf' entry and a missing comma; this is the intended new map.)
const formatModelType = (type) => {
  const typeMap = {
    'pytorch': 'PyTorch',
    'xgboost': 'XGBoost',
    'lightgbm': 'LightGBM',
    'gbm': 'GBM',
    'rf': 'Random Forest',
    'pls': 'PLS回归'
  }
  return typeMap[type] || type
}
@ -230,7 +232,8 @@ const formatDateTime = (value) => {
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
hour12: false,
timeZone: 'Asia/Shanghai'
})
}

View File

@ -223,72 +223,46 @@ import { API_BASE_URL } from '@/config'
// Form state, pre-filled with rocket-launcher (火箭炮) example defaults so the
// form is immediately submittable; handleTypeChange swaps values per type.
// NOTE(review): the rendered diff had merged the old null-initialised object
// with the new default-filled one (duplicate keys, missing comma after
// warhead_power_score) — this is the reconstructed new initializer.
const formData = reactive({
  type: '',
  // 基本参数
  length_m: 7.35,
  width_m: 2.4,
  height_m: 3.1,
  weight_kg: 13700,
  max_range_km: 20.4,
  // 火箭炮参数
  firing_angle_horizontal: 102,
  firing_angle_vertical: 55,
  rocket_length_m: 2.87,
  rocket_diameter_mm: 122,
  rocket_weight_kg: 66.6,
  rate_of_fire: 40,
  combat_weight_kg: 15000,
  speed_kmh: 60,
  min_range_km: 5,
  mobility_type: '轮式',
  structure_layout: '6x6轮式底盘',
  engine_model: 'WD615',
  engine_params: '6缸直列柴油机',
  power_hp: 280,
  travel_range_km: 600,
  // 巡飞弹参数
  wingspan_m: 2.5,
  warhead_weight_kg: 20,
  max_speed_ms: 200,
  cruise_speed_kmh: 720,
  endurance_min: 30,
  warhead_type: '破片杀伤战斗部',
  launch_mode: '箱式发射',
  power_system: '电动机',
  guidance_system: 'GPS/INS/光电',
  max_payload_kg: 25,
  ceiling_altitude_m: 5000,
  combat_radius_km: 100,
  datalink_range_km: 150,
  guidance_accuracy_m: 3
})
// Prediction result state; reset to null whenever the equipment type changes.
const predictionResults = ref(null) // overall prediction payload shown in the results panel
const mlPrediction = ref(null) // presumably the `/predict` (ML) response — TODO confirm against submitForm
const plsPrediction = ref(null) // presumably the `/pls/predict` response — TODO confirm against submitForm
// Blank out the type-specific inputs when the equipment type changes, so the
// user re-enters values belonging to the newly selected type.
const handleTypeChange = () => {
  const numericResets = {
    '火箭炮': [
      'firing_angle_horizontal',
      'firing_angle_vertical',
      'rocket_length_m',
      'rocket_diameter_mm',
      'rocket_weight_kg',
      'rate_of_fire'
    ],
    '巡飞弹': [
      'wingspan_m',
      'warhead_weight_kg',
      'max_speed_ms',
      'cruise_speed_kmh',
      'endurance_min'
    ]
  }
  const fields = numericResets[formData.type]
  if (!fields) {
    return
  }
  for (const field of fields) {
    formData[field] = null
  }
  if (formData.type === '巡飞弹') {
    // String-valued loitering-munition fields reset to empty strings.
    for (const field of ['warhead_type', 'launch_mode', 'power_system', 'guidance_system']) {
      formData[field] = ''
    }
  }
}
const submitForm = async () => {
try {
//
@ -327,7 +301,7 @@ const submitForm = async () => {
}
}
//
//
const [mlResponse, plsResponse] = await Promise.all([
axios.post(`${API_BASE_URL}/predict`, formData),
axios.post(`${API_BASE_URL}/pls/predict`, formData)
@ -396,9 +370,96 @@ const getModelName = (modelType) => {
}
return modelNames[modelType] || modelType
}
// Switch the form to the selected equipment type: clear stale prediction
// output, fill that type's example defaults, and null/blank the other type's
// fields so they are not submitted with leftover values.
const handleTypeChange = () => {
  // Reset previous prediction output.
  // Fix: was `predictionResults.value = false`; use null to match the
  // ref(null) initialisation and the sibling resets below.
  predictionResults.value = null
  mlPrediction.value = null
  plsPrediction.value = null
  if (formData.type === '火箭炮') {
    // Rocket-launcher example defaults
    formData.length_m = 7.35
    formData.width_m = 2.4
    formData.height_m = 3.1
    formData.weight_kg = 13700
    formData.max_range_km = 20.4
    formData.firing_angle_horizontal = 102
    formData.firing_angle_vertical = 55
    formData.rocket_length_m = 2.87
    formData.rocket_diameter_mm = 122
    formData.rocket_weight_kg = 66.6
    formData.rate_of_fire = 40
    formData.combat_weight_kg = 15000
    formData.speed_kmh = 60
    formData.min_range_km = 5
    formData.mobility_type = '轮式'
    formData.structure_layout = '6x6轮式底盘'
    formData.engine_model = 'WD615'
    formData.engine_params = '6缸直列柴油机'
    formData.power_hp = 280
    formData.travel_range_km = 600
    // Clear loitering-munition-only fields
    formData.wingspan_m = null
    formData.warhead_weight_kg = null
    formData.max_speed_ms = null
    formData.cruise_speed_kmh = null
    formData.endurance_min = null
    formData.warhead_type = ''
    formData.launch_mode = ''
    formData.power_system = ''
    formData.guidance_system = ''
    formData.max_payload_kg = null
    formData.ceiling_altitude_m = null
    formData.combat_radius_km = null
    formData.datalink_range_km = null
    formData.guidance_accuracy_m = null
  } else if (formData.type === '巡飞弹') {
    // Loitering-munition example defaults
    formData.length_m = 2.5
    formData.width_m = 0.4
    formData.height_m = 0.4
    formData.weight_kg = 120
    formData.max_range_km = 100
    formData.wingspan_m = 2.5
    formData.warhead_weight_kg = 20
    formData.max_speed_ms = 200
    formData.cruise_speed_kmh = 720
    formData.endurance_min = 30
    formData.warhead_type = '破片杀伤战斗部'
    formData.launch_mode = '箱式发射'
    formData.power_system = '电动机'
    formData.guidance_system = 'GPS/INS/光电'
    formData.max_payload_kg = 25
    formData.ceiling_altitude_m = 5000
    formData.combat_radius_km = 100
    formData.datalink_range_km = 150
    formData.guidance_accuracy_m = 3
    // Clear rocket-launcher-only fields
    formData.firing_angle_horizontal = null
    formData.firing_angle_vertical = null
    formData.rocket_length_m = null
    formData.rocket_diameter_mm = null
    formData.rocket_weight_kg = null
    formData.rate_of_fire = null
    formData.combat_weight_kg = null
    formData.speed_kmh = null
    formData.min_range_km = null
    formData.mobility_type = ''
    formData.structure_layout = ''
    formData.engine_model = ''
    formData.engine_params = ''
    formData.power_hp = null
    formData.travel_range_km = null
  }
}
</script>
<style scoped>
<style lang="scss" scoped>
.predict-page {
padding: 20px;
}

View File

@ -1,5 +1,6 @@
<template>
<div class="training-page">
<!-- 上部分模型训练区域 -->
<el-card class="training-card">
<template #header>
<h2>模型训练</h2>
@ -38,11 +39,12 @@
<el-form-item label="选择模型">
<el-checkbox-group v-model="trainingConfig.models">
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
<el-checkbox value="pytorch" checked>PyTorch</el-checkbox>
<el-checkbox value="xgboost" checked>XGBoost</el-checkbox>
<el-checkbox value="lightgbm" checked>LightGBM</el-checkbox>
<el-checkbox value="gbm" checked>GBM</el-checkbox>
<el-checkbox value="rf" checked>Random Forest</el-checkbox>
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
</el-checkbox-group>
</el-form-item>
@ -130,6 +132,159 @@
</div>
</div>
</el-card>
<!-- 下部分模型简介区域 -->
<el-card class="model-intro-card">
<template #header>
<h2>模型简介</h2>
</template>
<el-collapse>
<el-collapse-item name="pytorch">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">PyTorch</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>深度学习框架可以构建复杂的神经网络结构</li>
<li>分别处理装备特征和生产商特征然后合并进行预测</li>
<li>使用批量归一化和Dropout防止过拟合</li>
<li>适合处理非线性关系和复杂特征交互</li>
</ul>
<h4>优势</h4>
<ul>
<li>强大的特征学习能力</li>
<li>可以自动学习特征之间的复杂关系</li>
<li>灵活的网络结构设计</li>
<li>支持GPU加速训练</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="xgboost">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">XGBoost</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>基于梯度提升树的集成学习算法</li>
<li>使用二阶导数进行优化</li>
<li>内置正则化机制防止过拟合</li>
<li>支持特征重要性评估</li>
</ul>
<h4>优势</h4>
<ul>
<li>优秀的预测性能</li>
<li>处理缺失值的能力强</li>
<li>训练速度快</li>
<li>可解释性好</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="lightgbm">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">LightGBM</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>微软开发的轻量级梯度提升框架</li>
<li>使用直方图算法优化训练速度</li>
<li>支持类别特征的高效处理</li>
<li>叶子优先的生长策略</li>
</ul>
<h4>优势</h4>
<ul>
<li>训练速度非常快</li>
<li>内存占用低</li>
<li>支持大规模数据训练</li>
<li>准确率高</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="gbm">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">Gradient Boosting (GBM)</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>经典的梯度提升算法</li>
<li>逐步减少残差的思想</li>
<li>可以使用不同的损失函数</li>
<li>支持特征重要性分析</li>
</ul>
<h4>优势</h4>
<ul>
<li>稳定的性能表现</li>
<li>较好的可解释性</li>
<li>对异常值不敏感</li>
<li>适合各种回归问题</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="rf">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">Random Forest</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>基于决策树的集成学习方法</li>
<li>使用随机采样和特征选择</li>
<li>多个决策树投票或平均</li>
<li>自带特征重要性评估</li>
</ul>
<h4>优势</h4>
<ul>
<li>不易过拟合</li>
<li>训练过程可并行化</li>
<li>对噪声数据鲁棒</li>
<li>较少的参数调整</li>
</ul>
</div>
</el-collapse-item>
<el-collapse-item name="pls">
<template #title>
<span class="model-title">
<el-link type="primary" :underline="false">PLS回归</el-link>
</span>
</template>
<div class="model-intro">
<h4>特点</h4>
<ul>
<li>偏最小二乘回归</li>
<li>同时考虑自变量和因变量的变异</li>
<li>处理多重共线性问题</li>
<li>降维和回归的结合</li>
</ul>
<h4>优势</h4>
<ul>
<li>适合小样本数据</li>
<li>处理变量间相关性强的数据</li>
<li>计算效率高</li>
<li>结果稳定可靠</li>
</ul>
</div>
</el-collapse-item>
</el-collapse>
</el-card>
</div>
</template>
@ -144,7 +299,7 @@ const trainingConfig = ref({
type: '',
train_dataset_id: null,
validation_dataset_id: null,
models: ['xgboost', 'lightgbm', 'gbm', 'rf']
models: ['pytorch', 'xgboost', 'lightgbm', 'gbm', 'rf']
})
//
@ -241,10 +396,12 @@ const formatNumber = (value) => {
//
// Map an internal model-type key to its human-readable display name.
// Unknown keys fall back to the raw key.
// (The rendered diff had merged old and new maps, leaving a duplicate 'rf'
// entry and a missing comma; this is the intended new map.)
const getModelName = (modelType) => {
  const modelNames = {
    'pytorch': 'PyTorch',
    'xgboost': 'XGBoost',
    'lightgbm': 'LightGBM',
    'gbm': 'GBM',
    'rf': 'Random Forest',
    'pls': 'PLS回归'
  }
  return modelNames[modelType] || modelType
}
@ -334,8 +491,13 @@ onMounted(() => {
<style lang="scss" scoped>
.training-page {
padding: 20px;
display: flex;
flex-direction: column;
gap: 20px;
.training-card {
width: 100%;
.training-result {
margin-top: 20px;
@ -366,5 +528,62 @@ onMounted(() => {
}
}
}
.model-intro-card {
width: 100%;
.model-title {
.el-link {
font-size: 16px;
font-weight: 500;
&:hover {
opacity: 0.8;
}
&:active {
opacity: 0.6;
}
}
}
.model-intro {
padding: 10px;
h4 {
margin: 10px 0;
color: #409EFF;
font-size: 15px;
}
ul {
padding-left: 20px;
margin: 5px 0;
li {
line-height: 1.8;
color: #606266;
font-size: 14px;
&:hover {
color: #409EFF;
}
}
}
}
:deep(.el-collapse-item__header) {
padding: 12px 0;
font-size: 16px;
&:hover {
background-color: #f5f7fa;
}
}
:deep(.el-collapse-item__content) {
padding: 10px 20px;
}
}
}
</style>

View File

@ -26,6 +26,8 @@ dependencies = [
"torch==2.5.1",
"torchvision==0.20.1",
"torchaudio==2.5.1",
"xgboost>=2.1.0", # 添加 XGBoost
"lightgbm>=4.5.0", # 添加 LightGBM
# 工具
"openpyxl>=3.1.5", # Excel支持

View File

@ -6,6 +6,8 @@ cryptography>=43.0.0 # MySQL 8.0+ 认证需要
mysql-connector-python>=8.0.0 # 添加这行
numpy>=1.26.0,<2.0.0
pandas>=2.2.0
xgboost>=2.1.0
lightgbm>=4.5.0
scikit-learn>=1.5.2

View File

@ -81,50 +81,64 @@ class CostPredictor:
self.model = DefaultModel(X.shape[1]).to(self.device)
self.equipment_type = '火箭炮'
def predict(self, data):
"""
使用训练好的模型进行预测
"""
def predict(self, data, model_record):
"""使用训练好的模型进行预测"""
try:
logger.info(f"Starting prediction for {data.get('type')}")
logger.info(f"Starting prediction for {data.get('type')} using {model_record['model_type']}")
equipment_type = data.get('type')
# 加载已训练的最优模型
trainer = ModelTrainer()
if not trainer.load_model(equipment_type):
raise ValueError(f"No trained model found for {equipment_type}")
# 使用ModelTrainer加载模型
model_trainer = ModelTrainer()
success = model_trainer.load_model(equipment_type, model_record['model_type'])
if not success:
raise ValueError(f"Failed to load model for {equipment_type}")
# 从ModelTrainer获取模型和标准化器
model = model_trainer.model
feature_scaler = model_trainer.feature_scaler
target_scaler = model_trainer.target_scaler
# 准备特征数据
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
feature_analyzer = FeatureAnalysis()
features = feature_analyzer.get_equipment_specific_features(equipment_type)
X = []
for feature in features:
value = data.get(feature, 0.0)
X.append(float(value))
# 转换为 tensor
X = torch.tensor([X], dtype=torch.float32).to(self.device)
# 转换为numpy数组并标准化
X = np.array([X])
X_scaled = feature_scaler.transform(X)
# 预测
with torch.no_grad():
trainer.model.eval() # 设置为评估模式
y_pred = trainer.model(X)
# 根据模型类型进行预测
if isinstance(model, torch.nn.Module):
# PyTorch模型预测
model.eval()
with torch.no_grad():
X_tensor = torch.FloatTensor(X_scaled).to(self.device)
y_pred = model(X_tensor)
y_pred = y_pred.cpu().numpy()
elif model_record['model_type'] == 'pls':
# PLS模型预测
y_pred = model.predict(X_scaled).reshape(-1, 1)
else:
# 其他sklearn模型预测
y_pred = model.predict(X_scaled).reshape(-1, 1)
# 转回 numpy
y_pred = y_pred.cpu().numpy()
# 转换回原始尺度并确保为正数
y_pred_original = target_scaler.inverse_transform(y_pred)
predicted_cost = abs(float(y_pred_original[0][0])) # 确保预测值为正数
# 计算置信区间
confidence_interval = self._calculate_confidence_interval(y_pred[0])
# 获取模型类型
model_type = trainer.get_model_type()
std = predicted_cost * 0.2 # 使用预测值的20%作为标准差
confidence_interval = {
'lower': max(predicted_cost - std, predicted_cost * 0.5), # 至少是预测值的50%
'upper': predicted_cost + std
}
return {
'predicted_cost': float(y_pred[0]),
'model_type': model_type,
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
'predicted_cost': predicted_cost,
'confidence_interval': confidence_interval
}
except Exception as e:

View File

@ -30,34 +30,11 @@ class DataPreparation:
self.target_scaler = StandardScaler()
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
"""
准备训练数据
"""
"""准备训练数据"""
try:
logger.info(f"Preparing training data for {equipment_type}")
logger.info(f"Raw data size: {len(equipment_data)}")
# 如果输入已经是 numpy 数组,转换为 torch.Tensor
if isinstance(equipment_data, np.ndarray):
X = equipment_data
logger.info(f"Input is numpy array with shape: {X.shape}")
# 处理无效值
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
# 转换为 PyTorch 数据集
dataset = EquipmentDataset(X)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return {
'dataloader': dataloader,
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler,
'raw_shape': X.shape
}
# 从原始数据中提取特征和目标值
# 获取特征名称(包含生产商特征)
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
features = []
targets = []
@ -66,18 +43,24 @@ class DataPreparation:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 获取所有生产商数据,用于计算特征
cursor.execute("""
SELECT * FROM manufacturers
""")
manufacturers = {row['id']: row for row in cursor.fetchall()}
for item in equipment_data:
# 获取该装备的生产商数据
manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor)
# 获取生产商数据
manufacturer = manufacturers.get(item['manufacturer_id'], {})
# 计算生产商特征
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data)
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer)
# 合并装备特征和生产商特征
feature_values = []
for name in feature_names:
if name in manufacturer_features:
value = manufacturer_features[name]
if name.startswith('manufacturer_'):
value = manufacturer_features.get(name, 0.0)
else:
value = item.get(name)
feature_values.append(float(value) if value is not None else 0.0)
@ -89,7 +72,7 @@ class DataPreparation:
X = np.array(features, dtype=float)
y = np.array(targets, dtype=float)
# 记录原始数据范围
# 记录数据范围
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
@ -97,7 +80,7 @@ class DataPreparation:
X_scaled = self.feature_scaler.fit_transform(X)
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
# 创建 PyTorch 数据集和数据加载器
# 创建数据集和数据加载器
dataset = EquipmentDataset(X_scaled, y_scaled)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
@ -106,7 +89,9 @@ class DataPreparation:
'feature_names': feature_names,
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler,
'raw_shape': X.shape
'raw_shape': X.shape,
'X': X_scaled,
'y': y_scaled
}
except Exception as e:
@ -165,7 +150,6 @@ class DataPreparation:
# 提取目标值(成本)并验证范围
try:
cost = float(item['actual_cost'])
logger.info(f"Raw cost value: {cost}")
if cost > 0: # 只使用正数成本值
targets.append(cost)
else:

View File

@ -4,6 +4,7 @@ import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import logging
import os
from datetime import datetime
@ -12,26 +13,104 @@ from src.feature_analysis import FeatureAnalysis
from src.database import get_db_connection
from src.data_preparation import DataPreparation, EquipmentDataset
from .logger import setup_logger
import math
logger = setup_logger(__name__)
class CostPredictionModel(nn.Module):
    """Cost-prediction network with an equipment-type-specific architecture.

    火箭炮 (rocket launcher): a small plain MLP (32→16→8→1) with BatchNorm and
    orthogonal weight initialisation.
    Other types (巡飞弹 / loitering munition): the last 5 input columns are
    treated as manufacturer features and the rest as equipment features; two
    branch networks process them separately and a fusion head combines them.

    NOTE(review): the rendered diff had merged the removed old implementation
    (a single ``self.layers`` MLP plus an unconditional ``return self.layers(x)``
    that made the new branching unreachable) into this class and stripped all
    indentation; this is the reconstructed new version.
    """

    def __init__(self, input_size, equipment_type):
        super().__init__()
        self.equipment_type = equipment_type
        if equipment_type == '火箭炮':
            # 火箭炮使用更简单和稳定的网络结构 — simpler, more stable net
            self.net = nn.Sequential(
                # Feature mapping
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.BatchNorm1d(32),
                # Feature extraction
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.BatchNorm1d(16),
                # Feature integration
                nn.Linear(16, 8),
                nn.ReLU(),
                nn.BatchNorm1d(8),
                # Output
                nn.Linear(8, 1)
            )

            # Orthogonal init with small gain for stability on small datasets
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.orthogonal_(m.weight, gain=0.5)
                    torch.nn.init.constant_(m.bias, 0.0)
                elif isinstance(m, nn.BatchNorm1d):
                    torch.nn.init.constant_(m.weight, 0.5)
                    torch.nn.init.constant_(m.bias, 0.0)

            self.net.apply(init_weights)
        else:
            # Manufacturer-feature branch (last 5 input columns) — shallow
            self.manufacturer_net = nn.Sequential(
                nn.Linear(5, 4),
                nn.ReLU(),
                nn.BatchNorm1d(4),
                nn.Dropout(0.2)
            )
            # Equipment-feature branch — deeper
            self.equipment_net = nn.Sequential(
                nn.Linear(input_size - 5, 64),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(64),
                nn.Dropout(0.2),
                nn.Linear(64, 32),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(32),
                nn.Dropout(0.2),
                nn.Linear(32, 16),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(16),
                nn.Dropout(0.2)
            )
            # Fusion head: 4 (manufacturer) + 16 (equipment) = 20 inputs
            self.combined_net = nn.Sequential(
                nn.Linear(20, 32),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(32),
                nn.Dropout(0.2),
                nn.Linear(32, 16),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(16),
                nn.Dropout(0.2),
                nn.Linear(16, 8),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(8),
                nn.Linear(8, 1)
            )

    def forward(self, x):
        """Return a (batch, 1) cost prediction for feature matrix ``x``."""
        if self.equipment_type == '火箭炮':
            return self.net(x)
        # Split off the manufacturer features (last 5 columns)
        manufacturer_features = x[:, -5:]
        equipment_features = x[:, :-5]
        manu_out = self.manufacturer_net(manufacturer_features)
        equip_out = self.equipment_net(equipment_features)
        # Feature fusion
        combined = torch.cat([equip_out, manu_out], dim=1)
        return self.combined_net(combined)
class ModelTrainer:
def __init__(self):
@ -42,24 +121,86 @@ class ModelTrainer:
self.equipment_type = None
self.feature_analyzer = FeatureAnalysis()
def train_model(self, dataloader, epochs=100, learning_rate=0.001):
def train_model(self, dataloader, epochs=100, learning_rate=0.001, equipment_type=None):
"""训练模型"""
try:
# 获取输入特征维度
sample_features, _ = next(iter(dataloader))
input_size = sample_features.shape[1]
# 创建模型
self.model = CostPredictionModel(input_size).to(self.device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
# 设置确定性
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
np.random.seed(42)
self.model = CostPredictionModel(input_size, equipment_type).to(self.device)
if equipment_type == '火箭炮':
# 火箭炮使用更保守和稳定的训练设置
criterion = nn.SmoothL1Loss(beta=0.1) # 使用Huber损失beta值较小
learning_rate = 0.0003 # 更小的学习率
weight_decay = 0.001 # 适中的权重衰减
# 使用AdamW优化器更小的beta值
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.8, 0.9), # 更小的动量值
eps=1e-8
)
# 使用带预热的学习率调度
num_steps = len(dataloader) * epochs
warmup_steps = num_steps // 10 # 10%的预热步数
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 0.5 * (1.0 + math.cos(
math.pi * (current_step - warmup_steps) / float(max(1, num_steps - warmup_steps))
))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
else: # 巡飞弹保持原有设置
# 巡飞弹使用更激进的训练设置
criterion = nn.MSELoss()
learning_rate = 0.001 # 较大的学习率
weight_decay = 0.001 # 较小的权重衰减
patience = 20 # 较短的耐心值
# 使用Adam优化器
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
# 使用余弦退火学习率调度
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=epochs,
eta_min=learning_rate * 0.01
)
# 训练循环
best_loss = float('inf')
patience = 30
patience_counter = 0
best_model_state = None
moving_avg_loss = None
alpha = 0.9 # 移动平均系数
for epoch in range(epochs):
self.model.train()
total_loss = 0
batch_count = 0
for batch_features, batch_targets in dataloader:
# 移动数据到设备
batch_features = batch_features.to(self.device)
batch_targets = batch_targets.to(self.device)
@ -68,16 +209,57 @@ class ModelTrainer:
loss = criterion(outputs, batch_targets.view(-1, 1))
# 反向传播
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=True) # 更高效的梯度清零
loss.backward()
# 梯度裁剪
if equipment_type == '火箭炮':
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=0.1
)
optimizer.step()
if equipment_type == '火箭炮':
scheduler.step()
total_loss += loss.item()
batch_count += 1
avg_loss = total_loss / batch_count
# 使用移动平均计算损失
if moving_avg_loss is None:
moving_avg_loss = avg_loss
else:
moving_avg_loss = alpha * moving_avg_loss + (1 - alpha) * avg_loss
# 早停检查使用移动平均损失
if moving_avg_loss < best_loss:
best_loss = moving_avg_loss
patience_counter = 0
best_model_state = {
'state_dict': self.model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict() if equipment_type == '火箭炮' else None
}
else:
patience_counter += 1
# 记录训练进度
if (epoch + 1) % 10 == 0:
avg_loss = total_loss / len(dataloader)
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {moving_avg_loss:.4f}, '
f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
if patience_counter >= patience:
logger.info(f"Early stopping triggered at epoch {epoch+1}")
break
# 恢复最佳模型
if best_model_state is not None:
self.model.load_state_dict(best_model_state['state_dict'])
optimizer.load_state_dict(best_model_state['optimizer'])
if equipment_type == '火箭炮' and best_model_state['scheduler']:
scheduler.load_state_dict(best_model_state['scheduler'])
return True
@ -85,62 +267,78 @@ class ModelTrainer:
logger.error(f"Error in model training: {str(e)}")
raise
def save_model(self, equipment_type):
def save_model(self, equipment_type, metrics=None):
"""保存模型"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 转换装备类型为英文
equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile'
# 保存模型
model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth'
model_path = f'{model_dir}/{equipment_type_en}_{timestamp}.pth'
torch.save({
'model_state_dict': self.model.state_dict(),
'input_size': self.model.layers[0].in_features
'input_size': self.model.equipment_net[0].in_features + 5,
'manufacturer_net_state': self.model.manufacturer_net.state_dict(),
'equipment_net_state': self.model.equipment_net.state_dict(),
'combined_net_state': self.model.combined_net.state_dict()
}, model_path)
# 保存标准化器
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth'
scaler_path = f'{model_dir}/{equipment_type_en}_{timestamp}_scaler.pth'
torch.save({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
# 获取评估指标
r2 = metrics['validation']['r2'] if metrics and metrics.get('validation') else metrics['train']['r2']
mae = metrics['validation']['mae'] if metrics and metrics.get('validation') else metrics['train']['mae']
rmse = metrics['validation']['rmse'] if metrics and metrics.get('validation') else metrics['train']['rmse']
# 更新数据库
with get_db_connection() as conn:
cursor = conn.cursor()
# 将所有模型设置为非激活
# 将所有同类型模型设置为非激活, 除了 PLS 模型
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s
""", (equipment_type,))
WHERE equipment_type = %s AND model_type != %s
""", (equipment_type, 'pls'))
# 保存新模型记录
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s)
scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
""", (
f"{equipment_type}_{timestamp}",
'pytorch',
equipment_type,
model_path,
scaler_path,
'system'
'system',
r2,
mae,
rmse
))
conn.commit()
logger.info(f"Model saved successfully: {model_path}")
return True
except Exception as e:
logger.error(f"Error saving model: {str(e)}")
return False
def load_model(self, equipment_type):
def load_model(self, equipment_type, model_type):
"""加载模型"""
try:
# 从数据库获取最新的激活模型
@ -148,25 +346,39 @@ class ModelTrainer:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT * FROM trained_models
WHERE equipment_type = %s AND is_active = TRUE
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
""", (equipment_type, model_type))
model_record = cursor.fetchone()
if not model_record:
return False
raise ValueError(f"No trained model found for {equipment_type}")
# 加载模型
checkpoint = torch.load(model_record['model_path'])
input_size = checkpoint['input_size']
self.model = CostPredictionModel(input_size).to(self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
# 加载标准化器
scalers = torch.load(model_record['scaler_path'])
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
if model_record['model_type'] == 'pytorch':
# 加载PyTorch模型
checkpoint = torch.load(model_record['model_path'], encoding='latin1')
input_size = checkpoint['input_size']
# 创建新模型实例
self.model = CostPredictionModel(input_size, equipment_type).to(self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
# 加载标准化器
scalers = torch.load(model_record['scaler_path'], encoding='latin1')
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
else:
# 加载sklearn模型
from joblib import load
with open(model_record['model_path'], 'rb') as f:
self.model = load(f)
with open(model_record['scaler_path'], 'rb') as f:
scalers = load(f)
self.feature_scaler = scalers['feature_scaler']
self.target_scaler = scalers['target_scaler']
logger.info(f"Model loaded successfully from {model_record['model_path']}")
return True
except Exception as e:
@ -178,12 +390,356 @@ class ModelTrainer:
try:
self.model.eval()
with torch.no_grad():
# 转换为tensor并移动到正确的设备
features_tensor = torch.FloatTensor(features).to(self.device)
# 进行预测
predictions = self.model(features_tensor)
# 移回CPU并转换为numpy数组
predictions = self.model(features_tensor) # 直接返回预测值
return predictions.cpu().numpy()
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise
raise
def fit_model(self, X_train, y_train, models, X_val=None, y_val=None, equipment_type=None):
"""训练模型并返回评估结果"""
try:
logger.info(f"Starting model training for {equipment_type}")
logger.info(f"Selected models: {models}")
logger.info(f"Training data shape: {X_train.shape}")
all_metrics = {}
best_model = None
best_score = float('-inf')
best_model_type = None # 添加变量记录最佳模型类型
# 训练所有选择的模型
for model_type in models:
logger.info(f"Training {model_type} model...")
if model_type == 'pls':
# PLS模型单独处理不参与最优模型评选
from sklearn.cross_decomposition import PLSRegression
# 使用较少的组件数来避免过拟合
n_components = min(3, X_train.shape[1] // 5)
model = PLSRegression(
n_components=n_components,
scale=True,
max_iter=500,
tol=1e-6
)
model.fit(X_train, y_train)
# 评估PLS模型
y_train_pred = model.predict(X_train).ravel()
if X_val is not None:
y_val_pred = model.predict(X_val).ravel()
# 将预测值转换回原始尺度
y_train_pred_original = self.target_scaler.inverse_transform(y_train_pred.reshape(-1, 1)).ravel()
y_train_original = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
train_metrics = {
'r2': float(r2_score(y_train_original, y_train_pred_original)),
'mae': float(mean_absolute_error(y_train_original, y_train_pred_original)),
'rmse': float(np.sqrt(mean_squared_error(y_train_original, y_train_pred_original)))
}
val_metrics = None
if X_val is not None:
y_val_pred_original = self.target_scaler.inverse_transform(y_val_pred.reshape(-1, 1)).ravel()
y_val_original = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
val_metrics = {
'r2': float(r2_score(y_val_original, y_val_pred_original)),
'mae': float(mean_absolute_error(y_val_original, y_val_pred_original)),
'rmse': float(np.sqrt(mean_squared_error(y_val_original, y_val_pred_original)))
}
all_metrics[model_type] = {
'train': train_metrics,
'validation': val_metrics
}
# 保存PLS模型但不参与最优模型评选
if equipment_type:
self._save_sklearn_model(equipment_type, model_type, model, all_metrics[model_type])
continue # 跳过后续的最优模型评选
elif model_type == 'xgboost':
import xgboost as xgb
model = xgb.XGBRegressor(
n_estimators=50,
learning_rate=0.03,
max_depth=3,
min_child_weight=5,
subsample=0.6,
colsample_bytree=0.6,
reg_alpha=0.5,
reg_lambda=2.0,
gamma=1,
random_state=42
)
# 训练模型
model.fit(X_train, y_train)
elif model_type == 'lightgbm':
import lightgbm as lgb
model = lgb.LGBMRegressor(
n_estimators=50,
learning_rate=0.03,
max_depth=3,
num_leaves=8,
subsample=0.6,
colsample_bytree=0.6,
reg_alpha=0.5,
reg_lambda=2.0,
min_child_samples=10,
min_split_gain=1.0,
random_state=42
)
# 训练模型
model.fit(X_train, y_train)
elif model_type == 'gbm':
from sklearn.ensemble import GradientBoostingRegressor
model = GradientBoostingRegressor(
n_estimators=50,
learning_rate=0.03,
max_depth=3,
min_samples_split=10,
min_samples_leaf=5,
subsample=0.6,
min_impurity_decrease=0.01,
random_state=42
)
# 训练模型
model.fit(X_train, y_train)
elif model_type == 'rf':
from sklearn.ensemble import RandomForestRegressor
model = RandomForestRegressor(
n_estimators=100,
max_depth=4,
min_samples_split=5,
min_samples_leaf=3,
max_features='sqrt',
bootstrap=True,
random_state=42
)
# 训练模型
model.fit(X_train, y_train)
elif model_type == 'pytorch':
# 训练PyTorch模型
train_dataset = EquipmentDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
training_success = self.train_model(train_loader)
if not training_success:
continue
# 评估模型性能
if model_type == 'pytorch':
with torch.no_grad():
X_train_tensor = torch.FloatTensor(X_train).to(self.device)
y_train_pred = self.model(X_train_tensor).cpu().numpy() # 直接获取输出
if X_val is not None:
X_val_tensor = torch.FloatTensor(X_val).to(self.device)
y_val_pred = self.model(X_val_tensor).cpu().numpy() # 直接获取输出
else:
# 使用训练好的模型进行预测
y_train_pred = model.predict(X_train)
if X_val is not None:
if model_type == 'pls':
y_val_pred = model.predict(X_val).ravel()
# 记录PLS的一些额外信息
if hasattr(model, 'score'):
train_r2 = model.score(X_train, y_train)
val_r2 = model.score(X_val, y_val)
logger.info(f"PLS built-in R² - Train: {train_r2:.4f}, Val: {val_r2:.4f}")
# 记录每个组件解释的方差比例
if hasattr(model, 'explained_variance_ratio_'):
logger.info("PLS explained variance ratios: " +
", ".join([f"{v:.4f}" for v in model.explained_variance_ratio_]))
else:
y_val_pred = model.predict(X_val)
# 将测值转换回始尺度
y_train_pred_original = self.target_scaler.inverse_transform(y_train_pred.reshape(-1, 1)).ravel()
y_train_original = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel()
train_metrics = {
'r2': float(r2_score(y_train_original, y_train_pred_original)),
'mae': float(mean_absolute_error(y_train_original, y_train_pred_original)),
'rmse': float(np.sqrt(mean_squared_error(y_train_original, y_train_pred_original)))
}
val_metrics = None
if X_val is not None and y_val is not None:
y_val_pred_original = self.target_scaler.inverse_transform(y_val_pred.reshape(-1, 1)).ravel()
y_val_original = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel()
val_metrics = {
'r2': float(r2_score(y_val_original, y_val_pred_original)),
'mae': float(mean_absolute_error(y_val_original, y_val_pred_original)),
'rmse': float(np.sqrt(mean_squared_error(y_val_original, y_val_pred_original)))
}
all_metrics[model_type] = {
'train': train_metrics,
'validation': val_metrics
}
# 更新最佳模型不包括PLS
current_score = val_metrics['r2'] if val_metrics else train_metrics['r2']
if model_type != 'pls' and current_score > best_score:
best_score = current_score
best_model = {
'type': model_type,
'r2': current_score,
'mae': val_metrics['mae'] if val_metrics else train_metrics['mae'],
'rmse': val_metrics['rmse'] if val_metrics else train_metrics['rmse']
}
best_model_type = model_type # 记录最佳模型类型
# 保存最佳模型实例(但不立即写入数据库)
if model_type == 'pytorch':
self.best_pytorch_model = self.model.state_dict() # 保存模型状态
self.best_pytorch_metrics = all_metrics[model_type] # 保存指标
else:
self.best_model = model
self.best_model_metrics = all_metrics[model_type]
# 单独保存PLS模型不参与最佳模型评选
if model_type == 'pls' and equipment_type:
self._save_sklearn_model(equipment_type, model_type, model, all_metrics[model_type])
# 在所有模型训练完成后,只保存最佳模型
if best_model_type and equipment_type:
if best_model_type == 'pytorch':
# 恢复最佳PyTorch模型状态并保存
self.model.load_state_dict(self.best_pytorch_model)
self.save_model(equipment_type, self.best_pytorch_metrics)
else:
# 保存最佳sklearn模型
self._save_sklearn_model(equipment_type, best_model_type, self.best_model, self.best_model_metrics)
return {
'metrics': all_metrics,
'feature_importance': None,
'best_model': best_model
}
except Exception as e:
logger.error(f"Error in model fitting: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
raise
def _calculate_feature_importance(self, X):
"""计算特征重要性"""
try:
if self.model is None:
return None
self.model.eval()
feature_importance = np.zeros(X.shape[1])
# 使用特征扰动计算重要性
with torch.no_grad():
X_tensor = torch.FloatTensor(X).to(self.device)
baseline_pred = self.model(X_tensor).cpu().numpy() # 直接获取预测值
for i in range(X.shape[1]):
# 创建动后的特征
X_perturbed = X.copy()
X_perturbed[:, i] = np.random.permutation(X_perturbed[:, i])
# 预测并计算影响
X_perturbed_tensor = torch.FloatTensor(X_perturbed).to(self.device)
perturbed_pred = self.model(X_perturbed_tensor).cpu().numpy() # 直接获取预测值
# 特征重要性为预变化的平均绝对值
feature_importance[i] = np.mean(np.abs(baseline_pred - perturbed_pred))
# 归一化特征重要性
feature_importance = feature_importance / np.sum(feature_importance)
return feature_importance
except Exception as e:
logger.error(f"Error calculating feature importance: {str(e)}")
return None
def _save_sklearn_model(self, equipment_type, model_type, model, metrics=None):
"""保存sklearn类型的模型"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
# 转换装备类型为英文
equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile'
# 保存模型
model_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}.joblib'
from joblib import dump
dump(model, model_path)
# 保存标准化器
scaler_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}_scaler.joblib'
dump({
'feature_scaler': self.feature_scaler,
'target_scaler': self.target_scaler
}, scaler_path)
# 获取评估指标
r2 = metrics['validation']['r2'] if metrics and metrics.get('validation') else metrics['train']['r2']
mae = metrics['validation']['mae'] if metrics and metrics.get('validation') else metrics['train']['mae']
rmse = metrics['validation']['rmse'] if metrics and metrics.get('validation') else metrics['train']['rmse']
# 更新数据库
with get_db_connection() as conn:
cursor = conn.cursor()
# 将同类型的其他模型设置为非激活
if model_type != 'pls':
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type != %s
""", (equipment_type, 'pls'))
else:
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s AND model_type = %s
""", (equipment_type, 'pls'))
# 保存新模型记录
cursor.execute("""
INSERT INTO trained_models (
model_name, model_type, equipment_type, model_path,
scaler_path, training_date, is_active, created_by,
r2_score, mae, rmse
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
""", (
f"{equipment_type}_{model_type}_{timestamp}",
model_type,
equipment_type,
model_path,
scaler_path,
'system',
r2,
mae,
rmse
))
conn.commit()
logger.info(f"Model {model_type} saved successfully: {model_path}")
return True
except Exception as e:
logger.error(f"Error saving sklearn model: {str(e)}")
return False

View File

@ -12,6 +12,7 @@ import os
from .data_preparation import DataPreparation
from .model_trainer import ModelTrainer
from .logger import setup_logger
import torch
# 创建蓝图
api_bp = Blueprint('api', __name__)
@ -58,49 +59,45 @@ def index():
})
@api_bp.route('/predict', methods=['POST'])
def predict_cost():
"""
成本预测接口
"""
def predict():
"""使用最优机器学习模型进行预测"""
try:
data = request.get_json()
logger.info(f"Received prediction request for equipment type: {data.get('type')}")
equipment_type = data.get('type')
logger.info(f"Received prediction request for equipment type: {equipment_type}")
# 验证装备类型
if 'type' not in data:
return jsonify({'error': 'Equipment type is required'}), 400
# 预测成本
predictor = CostPredictor()
result = predictor.predict(data)
# 获取当前使用的模型信息
# 获取最新的激活模型非PLS模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
LIMIT 1
""", (data['type'],))
model_info = cursor.fetchone()
# 在结果中添加模型信息
result.update({
'model_info': {
'type': model_info['model_type'],
'name': model_info['model_name'],
'r2_score': float(model_info['r2_score']),
'mae': float(model_info['mae']),
'rmse': float(model_info['rmse'])
}
})
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type != 'pls' # 明确排除PLS模型
AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
model = cursor.fetchone()
logger.info(f"Prediction completed: {result}")
return jsonify(result)
if not model:
return jsonify({'error': '未找到可用的模型'}), 404
# 使用普通预测方法
predictor = CostPredictor()
prediction = predictor.predict(data, model) # 使用普通predict方法
# 返回预测结果
return jsonify({
'model_info': {
'type': model['model_type'], # 使用数据库中的模型类型
'name': model['model_name']
},
'predicted_cost': prediction['predicted_cost'],
'confidence_interval': prediction['confidence_interval']
})
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
return jsonify({'error': str(e)}), 500
@api_bp.route('/analyze-features', methods=['POST'])
@ -180,27 +177,7 @@ def analyze_features():
""", (dataset_id,))
equipment_data = cursor.fetchall()
# 添加数据检查日志
logger.info(f"Total records found: {len(equipment_data)}")
if equipment_data:
# 检查第一条记录的所有字段
first_record = equipment_data[0]
logger.info("First record details:")
for key, value in first_record.items():
logger.info(f"{key}: {value}")
# 检查所有记录的 max_range_km 字段
logger.info("Checking max_range_km for all records:")
for item in equipment_data:
logger.info(f"Equipment: {item['name']}")
logger.info(f" max_range_km: {item.get('max_range_km')}")
logger.info(f" type: {item['type']}")
if item['type'] == '火箭炮':
logger.info(f" rocket_artillery_params fields:")
for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']:
logger.info(f" {key}: {item.get(key)}")
# 提取特征和目标值
analyzer = FeatureAnalysis()
feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type'])
@ -270,9 +247,7 @@ def analyze_features():
@api_bp.route('/train', methods=['POST'])
def train_model():
"""
训练模型
"""
"""训练模型"""
try:
data = request.get_json()
logger.info(f"Starting model training for {data.get('type')}")
@ -289,54 +264,65 @@ def train_model():
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
# 获取训练集数据
# 获取训练集数据(包含生产商信息)
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
SELECT e.*, cp.*, rap.*, cd.actual_cost,
m.tech_level, m.scale_level, m.supply_chain_level,
m.id as manufacturer_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,))
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
SELECT e.*, cp.*, lmp.*, cd.actual_cost,
m.tech_level, m.scale_level, m.supply_chain_level,
m.id as manufacturer_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (train_dataset_id,))
train_data = cursor.fetchall()
# 获取验证集数据(如果有)
validation_data = None
# 获取验证集数据
if validation_dataset_id:
if equipment_type == '火箭炮':
cursor.execute("""
SELECT e.*, cp.*, rap.*, cd.actual_cost
SELECT e.*, cp.*, rap.*, cd.actual_cost,
m.tech_level, m.scale_level, m.supply_chain_level,
m.id as manufacturer_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,))
else:
cursor.execute("""
SELECT e.*, cp.*, lmp.*, cd.actual_cost
SELECT e.*, cp.*, lmp.*, cd.actual_cost,
m.tech_level, m.scale_level, m.supply_chain_level,
m.id as manufacturer_id
FROM equipments e
JOIN dataset_equipments de ON e.id = de.equipment_id
LEFT JOIN common_params cp ON e.id = cp.equipment_id
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
WHERE de.dataset_id = %s
AND cd.actual_cost IS NOT NULL
""", (validation_dataset_id,))
@ -351,7 +337,7 @@ def train_model():
# 准备训练数据
train_prepared = data_processor.prepare_training_data(train_data, equipment_type)
# 准备验证数据(如果有)
# 准备验证数据(如果有)
validation_prepared = None
if validation_data:
validation_prepared = data_processor.prepare_validation_data(
@ -369,14 +355,14 @@ def train_model():
model_trainer.feature_scaler = train_prepared['feature_scaler']
model_trainer.target_scaler = train_prepared['target_scaler']
# 执行训练,传入 equipment_type 参数
# 执行训练,传入equipment_type参数
training_result = model_trainer.fit_model(
train_prepared['X'],
train_prepared['y'],
models,
validation_prepared['X'] if validation_prepared else None,
validation_prepared['y'] if validation_prepared else None,
equipment_type=equipment_type
equipment_type=equipment_type # 添加这个参数
)
return jsonify(training_result)
@ -588,65 +574,43 @@ def get_db_connection():
@api_bp.route('/pls/predict', methods=['POST'])
def pls_predict():
"""
PLS回归预测接口
"""
"""使用PLS模型进行预测"""
try:
data = request.get_json()
logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")
equipment_type = data.get('type')
# 验证装备类型
if 'type' not in data:
return jsonify({'error': 'Equipment type is required'}), 400
# 使用 ModelTrainer 中的 PLS 模型进行预测
trainer = ModelTrainer()
if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型
return jsonify({'error': '未找到可用的模型'}), 404
# 准备特征数据
feature_analyzer = FeatureAnalysis()
features = feature_analyzer.get_equipment_specific_features(data['type'])
X = np.array([[data.get(feature) for feature in features]])
# 预测
result = trainer.predict(X)
# 计算置信区间
confidence_interval = trainer._calculate_confidence_interval(result[0])
# 获取模型信息
# 获取最新的PLS模型
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
cursor.execute("""
SELECT model_type, model_name, r2_score, mae, rmse
FROM trained_models
WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
LIMIT 1
""", (data['type'],))
model_info = cursor.fetchone()
# 确保返回的数据可以序列化为JSON
response = {
'predicted_cost': float(result[0]),
'model_info': {
'type': model_info['model_type'],
'name': model_info['model_name'],
'r2_score': model_info['r2_score'],
'mae': model_info['mae'],
'rmse': model_info['rmse']
},
'confidence_interval': {
'lower': float(confidence_interval[0]),
'upper': float(confidence_interval[1])
}
}
SELECT * FROM trained_models
WHERE equipment_type = %s
AND model_type = 'pls' # 只选择PLS模型
AND is_active = TRUE
ORDER BY training_date DESC LIMIT 1
""", (equipment_type,))
model = cursor.fetchone()
logger.info(f"PLS prediction completed: {response}")
return jsonify(response)
if not model:
return jsonify({'error': '未找到可用的PLS模型'}), 404
# 使用普通predict方法传入模型信息
predictor = CostPredictor()
prediction = predictor.predict(data, model) # 传入model参数
# 返回预测结果
return jsonify({
'model_info': {
'type': 'pls',
'name': model['model_name']
},
'predicted_cost': prediction['predicted_cost'],
'confidence_interval': prediction['confidence_interval']
})
except Exception as e:
logger.error(f"Error in PLS prediction: {str(e)}")
logger.error("Detailed traceback:", exc_info=True)
return jsonify({'error': str(e)}), 500
@api_bp.route('/data/import', methods=['POST'])
@ -852,7 +816,7 @@ def get_equipment_details(id):
@api_bp.route('/datasets', methods=['GET'])
def get_datasets():
"""
获取数据集列表
获取数据集列表
"""
try:
with get_db_connection() as conn:
@ -868,13 +832,20 @@ def get_datasets():
""")
datasets = cursor.fetchall()
# 处理装备名称列表
# 处理装备名称列表
for dataset in datasets:
if dataset['equipment_names']:
dataset['equipment_names'] = dataset['equipment_names'].split(',')
else:
dataset['equipment_names'] = []
# 格式化时间和数值字段
for dataset in datasets:
if dataset['created_at']:
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
if dataset['updated_at']:
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
return jsonify(datasets)
except Exception as e:
logger.error(f"Error getting datasets: {str(e)}")
@ -930,6 +901,7 @@ def get_dataset(id):
}
dataset['equipment'] = equipment
return jsonify(dataset)
except Exception as e:
logger.error(f"Error getting dataset: {str(e)}")
@ -1067,7 +1039,7 @@ def delete_dataset(id):
@api_bp.route('/models/<equipment_type>/latest', methods=['GET'])
def get_latest_model(equipment_type):
"""
获取最新训练的模型信息
获取最新训练的模型信息
"""
try:
with get_db_connection() as conn:
@ -1087,9 +1059,7 @@ def get_latest_model(equipment_type):
@api_bp.route('/models', methods=['GET'])
def get_models():
"""
获取模型列表
"""
"""获取模型列表"""
try:
with get_db_connection() as conn:
cursor = conn.cursor(dictionary=True)
@ -1100,8 +1070,11 @@ def get_models():
models = cursor.fetchall()
# 确保数值型字段是 float
# 格式化时间和数值字段
for model in models:
# 将数据库中的datetime转换为ISO格式字符串
if model['training_date']:
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
if model['r2_score'] is not None:
model['r2_score'] = float(model['r2_score'])
if model['mae'] is not None:
@ -1121,16 +1094,14 @@ def get_models():
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
def activate_model(id):
"""
激活指定的模型
"""
"""激活指定的模型"""
try:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor = conn.cursor(dictionary=True) # 使用字典游标
# 获取模型信息
cursor.execute("""
SELECT equipment_type FROM trained_models
SELECT equipment_type, model_type FROM trained_models
WHERE id = %s
""", (id,))
model = cursor.fetchone()
@ -1142,8 +1113,8 @@ def activate_model(id):
cursor.execute("""
UPDATE trained_models
SET is_active = FALSE
WHERE equipment_type = %s
""", (model[0],))
WHERE equipment_type = %s AND model_type = %s
""", (model['equipment_type'], model['model_type']))
# 激活指定模型
cursor.execute("""