diff --git a/frontend/src/views/DatasetPage.vue b/frontend/src/views/DatasetPage.vue index 60e94c1..1d6ba7d 100644 --- a/frontend/src/views/DatasetPage.vue +++ b/frontend/src/views/DatasetPage.vue @@ -359,7 +359,8 @@ const formatDateTime = (value) => { hour: '2-digit', minute: '2-digit', second: '2-digit', - hour12: false + hour12: false, + timeZone: 'Asia/Shanghai' }) } diff --git a/frontend/src/views/ModelPage.vue b/frontend/src/views/ModelPage.vue index ba7910e..98551af 100644 --- a/frontend/src/views/ModelPage.vue +++ b/frontend/src/views/ModelPage.vue @@ -31,7 +31,7 @@ {{ scope.row.rmse.toFixed(2) }} - + @@ -211,10 +211,12 @@ const renderImportanceChart = () => { // 格式化模型类型 const formatModelType = (type) => { const typeMap = { + 'pytorch': 'PyTorch', 'xgboost': 'XGBoost', 'lightgbm': 'LightGBM', - 'gbdt': 'GBDT', - 'rf': 'Random Forest' + 'gbm': 'GBM', + 'rf': 'Random Forest', + 'pls': 'PLS回归' } return typeMap[type] || type } @@ -230,7 +232,8 @@ const formatDateTime = (value) => { hour: '2-digit', minute: '2-digit', second: '2-digit', - hour12: false + hour12: false, + timeZone: 'Asia/Shanghai' }) } diff --git a/frontend/src/views/PredictPage.vue b/frontend/src/views/PredictPage.vue index 631139b..6cd4f38 100644 --- a/frontend/src/views/PredictPage.vue +++ b/frontend/src/views/PredictPage.vue @@ -223,72 +223,46 @@ import { API_BASE_URL } from '@/config' const formData = reactive({ type: '', - length_m: null, - width_m: null, - height_m: null, - weight_kg: null, - max_range_km: null, - // 火箭炮特有参数 - firing_angle_horizontal: null, - firing_angle_vertical: null, - rocket_length_m: null, - rocket_diameter_mm: null, - rocket_weight_kg: null, - rate_of_fire: null, - combat_weight_kg: null, - speed_kmh: null, - min_range_km: null, - mobility_type: '', - structure_layout: '', - engine_model: '', - engine_params: '', - power_hp: null, - travel_range_km: null, - // 巡飞弹特有参数 - 补充 - max_payload_kg: null, // 最大载荷 - ceiling_altitude_m: null, // 升限 - combat_radius_km: null, // 作战半径 - engine_power_kw: null, // 发动机功率 - engine_thrust_n: null, // 发动机推力 - datalink_range_km: null, // 通信链路距离 - guidance_accuracy_m: null, // 制导精度 - min_altitude_m: null, // 最小作战高度 - max_altitude_m: null, // 最大作战高度 - - // 特征工程参数 - length_width_ratio: null, // 长宽比 - weight_range_ratio: null, // 重量/射程比 - speed_weight_ratio: null, // 速度/重量比 - guidance_system_score: null, // 制导系统复杂度评分 - warhead_power_score: null // 战斗部威力评分 + length_m: 7.35, + width_m: 2.4, + height_m: 3.1, + weight_kg: 13700, + max_range_km: 20.4, + firing_angle_horizontal: 102, + firing_angle_vertical: 55, + rocket_length_m: 2.87, + rocket_diameter_mm: 122, + rocket_weight_kg: 66.6, + rate_of_fire: 40, + combat_weight_kg: 15000, + speed_kmh: 60, + min_range_km: 5, + mobility_type: '轮式', + structure_layout: '6x6轮式底盘', + engine_model: 'WD615', + engine_params: '6缸直列柴油机', + power_hp: 280, + travel_range_km: 600, + wingspan_m: 2.5, + warhead_weight_kg: 20, + max_speed_ms: 200, + cruise_speed_kmh: 720, + endurance_min: 30, + warhead_type: '破片杀伤战斗部', + launch_mode: '箱式发射', + power_system: '电动机', + guidance_system: 'GPS/INS/光电', + max_payload_kg: 25, + ceiling_altitude_m: 5000, + combat_radius_km: 100, + datalink_range_km: 150, + guidance_accuracy_m: 3 }) const predictionResults = ref(null) const mlPrediction = ref(null) const plsPrediction = ref(null) -const handleTypeChange = () => { - // 重置特有参数 - if (formData.type === '火箭炮') { - formData.firing_angle_horizontal = null - formData.firing_angle_vertical = null - formData.rocket_length_m = null - formData.rocket_diameter_mm = null - formData.rocket_weight_kg = null - formData.rate_of_fire = null - } else if (formData.type === '巡飞弹') { - formData.wingspan_m = null - formData.warhead_weight_kg = null - formData.max_speed_ms = null - formData.cruise_speed_kmh = null - formData.endurance_min = null - formData.warhead_type = '' - formData.launch_mode = '' - formData.power_system = '' - formData.guidance_system = '' - } -} - const submitForm = async () => { try { // 验证必填字段 @@ -327,7 +301,7 @@ const submitForm = async () => { } } - // 获取预测结果 + // 同时调用两个预测接口 const [mlResponse, plsResponse] = await Promise.all([ axios.post(`${API_BASE_URL}/predict`, formData), axios.post(`${API_BASE_URL}/pls/predict`, formData) @@ -396,9 +370,96 @@ const getModelName = (modelType) => { } return modelNames[modelType] || modelType } + +const handleTypeChange = () => { + // 清空预测结果 + predictionResults.value = false + mlPrediction.value = null + plsPrediction.value = null + + // 重置特有参数 + if (formData.type === '火箭炮') { + // 设置火箭炮的默认值 + formData.length_m = 7.35 + formData.width_m = 2.4 + formData.height_m = 3.1 + formData.weight_kg = 13700 + formData.max_range_km = 20.4 + formData.firing_angle_horizontal = 102 + formData.firing_angle_vertical = 55 + formData.rocket_length_m = 2.87 + formData.rocket_diameter_mm = 122 + formData.rocket_weight_kg = 66.6 + formData.rate_of_fire = 40 + formData.combat_weight_kg = 15000 + formData.speed_kmh = 60 + formData.min_range_km = 5 + formData.mobility_type = '轮式' + formData.structure_layout = '6x6轮式底盘' + formData.engine_model = 'WD615' + formData.engine_params = '6缸直列柴油机' + formData.power_hp = 280 + formData.travel_range_km = 600 + + // 清空巡飞弹参数 + formData.wingspan_m = null + formData.warhead_weight_kg = null + formData.max_speed_ms = null + formData.cruise_speed_kmh = null + formData.endurance_min = null + formData.warhead_type = '' + formData.launch_mode = '' + formData.power_system = '' + formData.guidance_system = '' + formData.max_payload_kg = null + formData.ceiling_altitude_m = null + formData.combat_radius_km = null + formData.datalink_range_km = null + formData.guidance_accuracy_m = null + + } else if (formData.type === '巡飞弹') { + // 设置巡飞弹的默认值 + formData.length_m = 2.5 + formData.width_m = 0.4 + formData.height_m = 0.4 + formData.weight_kg = 120 + formData.max_range_km = 100 + formData.wingspan_m = 2.5 + formData.warhead_weight_kg = 20 + formData.max_speed_ms = 200 + formData.cruise_speed_kmh = 720 + formData.endurance_min = 30 + formData.warhead_type = '破片杀伤战斗部' + formData.launch_mode = '箱式发射' + formData.power_system = '电动机' + formData.guidance_system = 'GPS/INS/光电' + formData.max_payload_kg = 25 + formData.ceiling_altitude_m = 5000 + formData.combat_radius_km = 100 + formData.datalink_range_km = 150 + formData.guidance_accuracy_m = 3 + + // 清空火箭炮参数 + formData.firing_angle_horizontal = null + formData.firing_angle_vertical = null + formData.rocket_length_m = null + formData.rocket_diameter_mm = null + formData.rocket_weight_kg = null + formData.rate_of_fire = null + formData.combat_weight_kg = null + formData.speed_kmh = null + formData.min_range_km = null + formData.mobility_type = '' + formData.structure_layout = '' + formData.engine_model = '' + formData.engine_params = '' + formData.power_hp = null + formData.travel_range_km = null + } +} - \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7e6e325..7ca60a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "torch==2.5.1", "torchvision==0.20.1", "torchaudio==2.5.1", + "xgboost>=2.1.0", # 添加 XGBoost + "lightgbm>=4.5.0", # 添加 LightGBM # 工具 "openpyxl>=3.1.5", # Excel支持 diff --git a/requirements.txt b/requirements.txt index f168f07..1a55f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,8 @@ cryptography>=43.0.0 # MySQL 8.0+ 认证需要 mysql-connector-python>=8.0.0 # 添加这行 numpy>=1.26.0,<2.0.0 pandas>=2.2.0 +xgboost>=2.1.0 +lightgbm>=4.5.0 scikit-learn>=1.5.2 diff --git a/src/cost_prediction.py b/src/cost_prediction.py index c4f00fc..07eb378 100644 --- a/src/cost_prediction.py +++ b/src/cost_prediction.py @@ -81,50 +81,64 @@ class CostPredictor: self.model = DefaultModel(X.shape[1]).to(self.device) self.equipment_type = '火箭炮' - def predict(self, data): - """ - 使用训练好的模型进行预测 - """ + def predict(self, data, model_record): + """使用训练好的模型进行预测""" try: - logger.info(f"Starting prediction for {data.get('type')}") + logger.info(f"Starting prediction for {data.get('type')} using {model_record['model_type']}") equipment_type = data.get('type') - # 加载已训练的最优模型 - trainer = ModelTrainer() - if not trainer.load_model(equipment_type): - raise ValueError(f"No trained model found for {equipment_type}") + # 使用ModelTrainer加载模型 + model_trainer = ModelTrainer() + success = model_trainer.load_model(equipment_type, model_record['model_type']) + if not success: + raise ValueError(f"Failed to load model for {equipment_type}") + + # 从ModelTrainer获取模型和标准化器 + model = model_trainer.model + feature_scaler = model_trainer.feature_scaler + target_scaler = model_trainer.target_scaler # 准备特征数据 - features = self.feature_analyzer.get_equipment_specific_features(equipment_type) + feature_analyzer = FeatureAnalysis() + features = feature_analyzer.get_equipment_specific_features(equipment_type) X = [] for feature in features: value = data.get(feature, 0.0) X.append(float(value)) - # 转换为 tensor - X = torch.tensor([X], dtype=torch.float32).to(self.device) + # 转换为numpy数组并标准化 + X = np.array([X]) + X_scaled = feature_scaler.transform(X) - # 预测 - with torch.no_grad(): - trainer.model.eval() # 设置为评估模式 - y_pred = trainer.model(X) + # 根据模型类型进行预测 + if isinstance(model, torch.nn.Module): + # PyTorch模型预测 + model.eval() + with torch.no_grad(): + X_tensor = torch.FloatTensor(X_scaled).to(self.device) + y_pred = model(X_tensor) + y_pred = y_pred.cpu().numpy() + elif model_record['model_type'] == 'pls': + # PLS模型预测 + y_pred = model.predict(X_scaled).reshape(-1, 1) + else: + # 其他sklearn模型预测 + y_pred = model.predict(X_scaled).reshape(-1, 1) - # 转回 numpy - y_pred = y_pred.cpu().numpy() + # 转换回原始尺度并确保为正数 + y_pred_original = target_scaler.inverse_transform(y_pred) + predicted_cost = abs(float(y_pred_original[0][0])) # 确保预测值为正数 # 计算置信区间 - confidence_interval = self._calculate_confidence_interval(y_pred[0]) - - # 获取模型类型 - model_type = trainer.get_model_type() + std = predicted_cost * 0.2 # 使用预测值的20%作为标准差 + confidence_interval = { + 'lower': max(predicted_cost - std, predicted_cost * 0.5), # 至少是预测值的50% + 'upper': predicted_cost + std + } return { - 'predicted_cost': float(y_pred[0]), - 'model_type': model_type, - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } + 'predicted_cost': predicted_cost, + 'confidence_interval': confidence_interval } except Exception as e: diff --git a/src/data_preparation.py b/src/data_preparation.py index 8b83330..b0f6d8e 100644 --- a/src/data_preparation.py +++ b/src/data_preparation.py @@ -30,34 +30,11 @@ class DataPreparation: self.target_scaler = StandardScaler() def prepare_training_data(self, equipment_data, equipment_type, batch_size=32): - """ - 准备训练数据 - """ + """准备训练数据""" try: logger.info(f"Preparing training data for {equipment_type}") - logger.info(f"Raw data size: {len(equipment_data)}") - # 如果输入已经是 numpy 数组,转换为 torch.Tensor - if isinstance(equipment_data, np.ndarray): - X = equipment_data - logger.info(f"Input is numpy array with shape: {X.shape}") - - # 处理无效值 - X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) - - # 转换为 PyTorch 数据集 - dataset = EquipmentDataset(X) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) - - return { - 'dataloader': dataloader, - 'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type), - 'feature_scaler': self.feature_scaler, - 'target_scaler': self.target_scaler, - 'raw_shape': X.shape - } - - # 从原始数据中提取特征和目标值 + # 获取特征名称(包含生产商特征) feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type) features = [] targets = [] @@ -66,18 +43,24 @@ class DataPreparation: with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) + # 获取所有生产商数据,用于计算特征 + cursor.execute(""" + SELECT * FROM manufacturers + """) + manufacturers = {row['id']: row for row in cursor.fetchall()} + for item in equipment_data: - # 获取该装备的生产商数据 - manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor) + # 获取生产商数据 + manufacturer = manufacturers.get(item['manufacturer_id'], {}) # 计算生产商特征 - manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data) + manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer) # 合并装备特征和生产商特征 feature_values = [] for name in feature_names: - if name in manufacturer_features: - value = manufacturer_features[name] + if name.startswith('manufacturer_'): + value = manufacturer_features.get(name, 0.0) else: value = item.get(name) feature_values.append(float(value) if value is not None else 0.0) @@ -89,7 +72,7 @@ class DataPreparation: X = np.array(features, dtype=float) y = np.array(targets, dtype=float) - # 记录原始数据范围 + # 记录数据范围 logger.info(f"Raw X range: min={X.min()}, max={X.max()}") logger.info(f"Raw y range: min={y.min()}, max={y.max()}") @@ -97,7 +80,7 @@ class DataPreparation: X_scaled = self.feature_scaler.fit_transform(X) y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel() - # 创建 PyTorch 数据集和数据加载器 + # 创建数据集和数据加载器 dataset = EquipmentDataset(X_scaled, y_scaled) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) @@ -106,7 +89,9 @@ class DataPreparation: 'feature_names': feature_names, 'feature_scaler': self.feature_scaler, 'target_scaler': self.target_scaler, - 'raw_shape': X.shape + 'raw_shape': X.shape, + 'X': X_scaled, + 'y': y_scaled } except Exception as e: @@ -165,7 +150,6 @@ class DataPreparation: # 提取目标值(成本)并验证范围 try: cost = float(item['actual_cost']) - logger.info(f"Raw cost value: {cost}") if cost > 0: # 只使用正数成本值 targets.append(cost) else: diff --git a/src/model_trainer.py b/src/model_trainer.py index 9da7c39..aac18de 100644 --- a/src/model_trainer.py +++ b/src/model_trainer.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.utils.data import DataLoader from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split +from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error import logging import os from datetime import datetime @@ -12,26 +13,104 @@ from src.feature_analysis import FeatureAnalysis from src.database import get_db_connection from src.data_preparation import DataPreparation, EquipmentDataset from .logger import setup_logger +import math logger = setup_logger(__name__) class CostPredictionModel(nn.Module): - def __init__(self, input_size): + def __init__(self, input_size, equipment_type): super().__init__() - self.layers = nn.Sequential( - nn.Linear(input_size, 128), - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(128, 64), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(64, 32), - nn.ReLU(), - nn.Linear(32, 1) - ) + self.equipment_type = equipment_type + if equipment_type == '火箭炮': + # 火箭炮使用更简单和稳定的网络结构 + self.net = nn.Sequential( + # 第一层:特征映射 + nn.Linear(input_size, 32), + nn.ReLU(), + nn.BatchNorm1d(32), + + # 第二层:特征提取 + nn.Linear(32, 16), + nn.ReLU(), + nn.BatchNorm1d(16), + + # 第三层:特征整合 + nn.Linear(16, 8), + nn.ReLU(), + nn.BatchNorm1d(8), + + # 输出层 + nn.Linear(8, 1) + ) + + # 使用正交初始化 + def init_weights(m): + if isinstance(m, nn.Linear): + torch.nn.init.orthogonal_(m.weight, gain=0.5) + torch.nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm1d): + torch.nn.init.constant_(m.weight, 0.5) + torch.nn.init.constant_(m.bias, 0.0) + + self.net.apply(init_weights) + + else: # 巡飞弹保持原有结构 + # 生产商特征网络 - 更简单的结构 + self.manufacturer_net = nn.Sequential( + nn.Linear(5, 4), + nn.ReLU(), + nn.BatchNorm1d(4), + nn.Dropout(0.2) + ) + + # 巡飞弹特征网络 - 较深的结构 + self.equipment_net = nn.Sequential( + nn.Linear(input_size - 5, 64), + nn.LeakyReLU(0.1), + nn.BatchNorm1d(64), + nn.Dropout(0.2), + nn.Linear(64, 32), + nn.LeakyReLU(0.1), + nn.BatchNorm1d(32), + nn.Dropout(0.2), + nn.Linear(32, 16), + nn.LeakyReLU(0.1), + nn.BatchNorm1d(16), + nn.Dropout(0.2) + ) + + # 合并网络 - 较复杂的结构 + self.combined_net = nn.Sequential( + nn.Linear(20, 32), # 4 + 16 = 20 + nn.LeakyReLU(0.1), + nn.BatchNorm1d(32), + nn.Dropout(0.2), + nn.Linear(32, 16), + nn.LeakyReLU(0.1), + nn.BatchNorm1d(16), + nn.Dropout(0.2), + nn.Linear(16, 8), + nn.LeakyReLU(0.1), + nn.BatchNorm1d(8), + nn.Linear(8, 1) + ) + def forward(self, x): - return self.layers(x) + if self.equipment_type == '火箭炮': + return self.net(x) + else: + # 分离特征 + manufacturer_features = x[:, -5:] + equipment_features = x[:, :-5] + + # 特征处理 + manu_out = self.manufacturer_net(manufacturer_features) + equip_out = self.equipment_net(equipment_features) + + # 特征融合 + combined = torch.cat([equip_out, manu_out], dim=1) + return self.combined_net(combined) class ModelTrainer: def __init__(self): @@ -42,24 +121,86 @@ class ModelTrainer: self.equipment_type = None self.feature_analyzer = FeatureAnalysis() - def train_model(self, dataloader, epochs=100, learning_rate=0.001): + def train_model(self, dataloader, epochs=100, learning_rate=0.001, equipment_type=None): """训练模型""" try: - # 获取输入特征维度 sample_features, _ = next(iter(dataloader)) input_size = sample_features.shape[1] - # 创建模型 - self.model = CostPredictionModel(input_size).to(self.device) - criterion = nn.MSELoss() - optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) + # 设置确定性 + torch.manual_seed(42) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + np.random.seed(42) + + self.model = CostPredictionModel(input_size, equipment_type).to(self.device) + + if equipment_type == '火箭炮': + # 火箭炮使用更保守和稳定的训练设置 + criterion = nn.SmoothL1Loss(beta=0.1) # 使用Huber损失,beta值较小 + learning_rate = 0.0003 # 更小的学习率 + weight_decay = 0.001 # 适中的权重衰减 + + # 使用AdamW优化器,更小的beta值 + optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.8, 0.9), # 更小的动量值 + eps=1e-8 + ) + + # 使用带预热的学习率调度 + num_steps = len(dataloader) * epochs + warmup_steps = num_steps // 10 # 10%的预热步数 + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return 0.5 * (1.0 + math.cos( + math.pi * (current_step - warmup_steps) / float(max(1, num_steps - warmup_steps)) + )) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + else: # 巡飞弹保持原有设置 + # 巡飞弹使用更激进的训练设置 + criterion = nn.MSELoss() + learning_rate = 0.001 # 较大的学习率 + weight_decay = 0.001 # 较小的权重衰减 + patience = 20 # 较短的耐心值 + + # 使用Adam优化器 + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.999) + ) + + # 使用余弦退火学习率调度 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=epochs, + eta_min=learning_rate * 0.01 + ) # 训练循环 + best_loss = float('inf') + patience = 30 + patience_counter = 0 + best_model_state = None + moving_avg_loss = None + alpha = 0.9 # 移动平均系数 + for epoch in range(epochs): self.model.train() total_loss = 0 + batch_count = 0 + for batch_features, batch_targets in dataloader: - # 移动数据到设备 batch_features = batch_features.to(self.device) batch_targets = batch_targets.to(self.device) @@ -68,16 +209,57 @@ class ModelTrainer: loss = criterion(outputs, batch_targets.view(-1, 1)) # 反向传播 - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=True) # 更高效的梯度清零 loss.backward() + + # 梯度裁剪 + if equipment_type == '火箭炮': + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + max_norm=0.1 + ) + optimizer.step() + if equipment_type == '火箭炮': + scheduler.step() total_loss += loss.item() + batch_count += 1 + + avg_loss = total_loss / batch_count + + # 使用移动平均计算损失 + if moving_avg_loss is None: + moving_avg_loss = avg_loss + else: + moving_avg_loss = alpha * moving_avg_loss + (1 - alpha) * avg_loss + + # 早停检查使用移动平均损失 + if moving_avg_loss < best_loss: + best_loss = moving_avg_loss + patience_counter = 0 + best_model_state = { + 'state_dict': self.model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict() if equipment_type == '火箭炮' else None + } + else: + patience_counter += 1 - # 记录训练进度 if (epoch + 1) % 10 == 0: - avg_loss = total_loss / len(dataloader) - logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}') + logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {moving_avg_loss:.4f}, ' + f'LR: {optimizer.param_groups[0]["lr"]:.6f}') + + if patience_counter >= patience: + logger.info(f"Early stopping triggered at epoch {epoch+1}") + break + + # 恢复最佳模型 + if best_model_state is not None: + self.model.load_state_dict(best_model_state['state_dict']) + optimizer.load_state_dict(best_model_state['optimizer']) + if equipment_type == '火箭炮' and best_model_state['scheduler']: + scheduler.load_state_dict(best_model_state['scheduler']) return True @@ -85,62 +267,78 @@ class ModelTrainer: logger.error(f"Error in model training: {str(e)}") raise - def save_model(self, equipment_type): + def save_model(self, equipment_type, metrics=None): """保存模型""" try: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") model_dir = 'models' os.makedirs(model_dir, exist_ok=True) + # 转换装备类型为英文 + equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile' + # 保存模型 - model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth' + model_path = f'{model_dir}/{equipment_type_en}_{timestamp}.pth' torch.save({ 'model_state_dict': self.model.state_dict(), - 'input_size': self.model.layers[0].in_features + 'input_size': self.model.equipment_net[0].in_features + 5, + 'manufacturer_net_state': self.model.manufacturer_net.state_dict(), + 'equipment_net_state': self.model.equipment_net.state_dict(), + 'combined_net_state': self.model.combined_net.state_dict() }, model_path) # 保存标准化器 - scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth' + scaler_path = f'{model_dir}/{equipment_type_en}_{timestamp}_scaler.pth' torch.save({ 'feature_scaler': self.feature_scaler, 'target_scaler': self.target_scaler }, scaler_path) + # 获取评估指标 + r2 = metrics['validation']['r2'] if metrics and metrics.get('validation') else metrics['train']['r2'] + mae = metrics['validation']['mae'] if metrics and metrics.get('validation') else metrics['train']['mae'] + rmse = metrics['validation']['rmse'] if metrics and metrics.get('validation') else metrics['train']['rmse'] + # 更新数据库 with get_db_connection() as conn: cursor = conn.cursor() - # 将所有模型设置为非激活 + # 将所有同类型模型设置为非激活, 除了 PLS 模型 cursor.execute(""" UPDATE trained_models SET is_active = FALSE - WHERE equipment_type = %s - """, (equipment_type,)) + WHERE equipment_type = %s AND model_type != %s + """, (equipment_type, 'pls')) # 保存新模型记录 cursor.execute(""" INSERT INTO trained_models ( model_name, model_type, equipment_type, model_path, - scaler_path, training_date, is_active, created_by - ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s) + scaler_path, training_date, is_active, created_by, + r2_score, mae, rmse + ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s) """, ( f"{equipment_type}_{timestamp}", 'pytorch', equipment_type, model_path, scaler_path, - 'system' + 'system', + r2, + mae, + rmse )) conn.commit() - + + logger.info(f"Model saved successfully: {model_path}") return True except Exception as e: logger.error(f"Error saving model: {str(e)}") return False - def load_model(self, equipment_type): + def load_model(self, equipment_type, model_type): """加载模型""" try: # 从数据库获取最新的激活模型 @@ -148,25 +346,39 @@ class ModelTrainer: cursor = conn.cursor(dictionary=True) cursor.execute(""" SELECT * FROM trained_models - WHERE equipment_type = %s AND is_active = TRUE + WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE ORDER BY training_date DESC LIMIT 1 - """, (equipment_type,)) + """, (equipment_type, model_type)) model_record = cursor.fetchone() if not model_record: - return False + raise ValueError(f"No trained model found for {equipment_type}") # 加载模型 - checkpoint = torch.load(model_record['model_path']) - input_size = checkpoint['input_size'] - self.model = CostPredictionModel(input_size).to(self.device) - self.model.load_state_dict(checkpoint['model_state_dict']) - - # 加载标准化器 - scalers = torch.load(model_record['scaler_path']) - self.feature_scaler = scalers['feature_scaler'] - self.target_scaler = scalers['target_scaler'] + if model_record['model_type'] == 'pytorch': + # 加载PyTorch模型 + checkpoint = torch.load(model_record['model_path'], encoding='latin1') + input_size = checkpoint['input_size'] + + # 创建新模型实例 + self.model = CostPredictionModel(input_size, equipment_type).to(self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + + # 加载标准化器 + scalers = torch.load(model_record['scaler_path'], encoding='latin1') + self.feature_scaler = scalers['feature_scaler'] + self.target_scaler = scalers['target_scaler'] + else: + # 加载sklearn模型 + from joblib import load + with open(model_record['model_path'], 'rb') as f: + self.model = load(f) + with open(model_record['scaler_path'], 'rb') as f: + scalers = load(f) + self.feature_scaler = scalers['feature_scaler'] + self.target_scaler = scalers['target_scaler'] + logger.info(f"Model loaded successfully from {model_record['model_path']}") return True except Exception as e: @@ -178,12 +390,356 @@ class ModelTrainer: try: self.model.eval() with torch.no_grad(): - # 转换为tensor并移动到正确的设备 features_tensor = torch.FloatTensor(features).to(self.device) - # 进行预测 - predictions = self.model(features_tensor) - # 移回CPU并转换为numpy数组 + predictions = self.model(features_tensor) # 直接返回预测值 return predictions.cpu().numpy() except Exception as e: logger.error(f"Error in prediction: {str(e)}") - raise \ No newline at end of file + raise + + def fit_model(self, X_train, y_train, models, X_val=None, y_val=None, equipment_type=None): + """训练模型并返回评估结果""" + try: + logger.info(f"Starting model training for {equipment_type}") + logger.info(f"Selected models: {models}") + logger.info(f"Training data shape: {X_train.shape}") + + all_metrics = {} + best_model = None + best_score = float('-inf') + best_model_type = None # 添加变量记录最佳模型类型 + + # 训练所有选择的模型 + for model_type in models: + logger.info(f"Training {model_type} model...") + + if model_type == 'pls': + # PLS模型单独处理,不参与最优模型评选 + from sklearn.cross_decomposition import PLSRegression + + # 使用较少的组件数来避免过拟合 + n_components = min(3, X_train.shape[1] // 5) + model = PLSRegression( + n_components=n_components, + scale=True, + max_iter=500, + tol=1e-6 + ) + model.fit(X_train, y_train) + + # 评估PLS模型 + y_train_pred = model.predict(X_train).ravel() + if X_val is not None: + y_val_pred = model.predict(X_val).ravel() + + # 将预测值转换回原始尺度 + y_train_pred_original = self.target_scaler.inverse_transform(y_train_pred.reshape(-1, 1)).ravel() + y_train_original = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel() + + train_metrics = { + 'r2': float(r2_score(y_train_original, y_train_pred_original)), + 'mae': float(mean_absolute_error(y_train_original, y_train_pred_original)), + 'rmse': float(np.sqrt(mean_squared_error(y_train_original, y_train_pred_original))) + } + + val_metrics = None + if X_val is not None: + y_val_pred_original = self.target_scaler.inverse_transform(y_val_pred.reshape(-1, 1)).ravel() + y_val_original = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel() + + val_metrics = { + 'r2': float(r2_score(y_val_original, y_val_pred_original)), + 'mae': float(mean_absolute_error(y_val_original, y_val_pred_original)), + 'rmse': float(np.sqrt(mean_squared_error(y_val_original, y_val_pred_original))) + } + + all_metrics[model_type] = { + 'train': train_metrics, + 'validation': val_metrics + } + + # 保存PLS模型,但不参与最优模型评选 + if equipment_type: + self._save_sklearn_model(equipment_type, model_type, model, all_metrics[model_type]) + + continue # 跳过后续的最优模型评选 + + elif model_type == 'xgboost': + import xgboost as xgb + model = xgb.XGBRegressor( + n_estimators=50, + learning_rate=0.03, + max_depth=3, + min_child_weight=5, + subsample=0.6, + colsample_bytree=0.6, + reg_alpha=0.5, + reg_lambda=2.0, + gamma=1, + random_state=42 + ) + # 训练模型 + model.fit(X_train, y_train) + + elif model_type == 'lightgbm': + import lightgbm as lgb + model = lgb.LGBMRegressor( + n_estimators=50, + learning_rate=0.03, + max_depth=3, + num_leaves=8, + subsample=0.6, + colsample_bytree=0.6, + reg_alpha=0.5, + reg_lambda=2.0, + min_child_samples=10, + min_split_gain=1.0, + random_state=42 + ) + # 训练模型 + model.fit(X_train, y_train) + + elif model_type == 'gbm': + from sklearn.ensemble import GradientBoostingRegressor + model = GradientBoostingRegressor( + n_estimators=50, + learning_rate=0.03, + max_depth=3, + min_samples_split=10, + min_samples_leaf=5, + subsample=0.6, + min_impurity_decrease=0.01, + random_state=42 + ) + # 训练模型 + model.fit(X_train, y_train) + + elif model_type == 'rf': + from sklearn.ensemble import RandomForestRegressor + model = RandomForestRegressor( + n_estimators=100, + max_depth=4, + min_samples_split=5, + min_samples_leaf=3, + max_features='sqrt', + bootstrap=True, + random_state=42 + ) + # 训练模型 + model.fit(X_train, y_train) + + elif model_type == 'pytorch': + # 训练PyTorch模型 + train_dataset = EquipmentDataset(X_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + training_success = self.train_model(train_loader) + if not training_success: + continue + + # 评估模型性能 + if model_type == 'pytorch': + with torch.no_grad(): + X_train_tensor = torch.FloatTensor(X_train).to(self.device) + y_train_pred = self.model(X_train_tensor).cpu().numpy() # 直接获取输出 + + if X_val is not None: + X_val_tensor = torch.FloatTensor(X_val).to(self.device) + y_val_pred = self.model(X_val_tensor).cpu().numpy() # 直接获取输出 + else: + # 使用训练好的模型进行预测 + y_train_pred = model.predict(X_train) + if X_val is not None: + if model_type == 'pls': + y_val_pred = model.predict(X_val).ravel() + + # 记录PLS的一些额外信息 + if hasattr(model, 'score'): + train_r2 = model.score(X_train, y_train) + val_r2 = model.score(X_val, y_val) + logger.info(f"PLS built-in R² - Train: {train_r2:.4f}, Val: {val_r2:.4f}") + + # 记录每个组件解释的方差比例 + if hasattr(model, 'explained_variance_ratio_'): + logger.info("PLS explained variance ratios: " + + ", ".join([f"{v:.4f}" for v in model.explained_variance_ratio_])) + else: + y_val_pred = model.predict(X_val) + + # 将测值转换回始尺度 + y_train_pred_original = self.target_scaler.inverse_transform(y_train_pred.reshape(-1, 1)).ravel() + y_train_original = self.target_scaler.inverse_transform(y_train.reshape(-1, 1)).ravel() + + train_metrics = { + 'r2': float(r2_score(y_train_original, y_train_pred_original)), + 'mae': float(mean_absolute_error(y_train_original, y_train_pred_original)), + 'rmse': float(np.sqrt(mean_squared_error(y_train_original, y_train_pred_original))) + } + + val_metrics = None + if X_val is not None and y_val is not None: + y_val_pred_original = self.target_scaler.inverse_transform(y_val_pred.reshape(-1, 1)).ravel() + y_val_original = self.target_scaler.inverse_transform(y_val.reshape(-1, 1)).ravel() + + val_metrics = { + 'r2': float(r2_score(y_val_original, y_val_pred_original)), + 'mae': float(mean_absolute_error(y_val_original, y_val_pred_original)), + 'rmse': float(np.sqrt(mean_squared_error(y_val_original, y_val_pred_original))) + } + + all_metrics[model_type] = { + 'train': train_metrics, + 'validation': val_metrics + } + + # 更新最佳模型(不包括PLS) + current_score = val_metrics['r2'] if val_metrics else train_metrics['r2'] + if model_type != 'pls' and current_score > best_score: + best_score = current_score + best_model = { + 'type': model_type, + 'r2': current_score, + 'mae': val_metrics['mae'] if val_metrics else train_metrics['mae'], + 'rmse': val_metrics['rmse'] if val_metrics else train_metrics['rmse'] + } + best_model_type = model_type # 记录最佳模型类型 + + # 保存最佳模型实例(但不立即写入数据库) + if model_type == 'pytorch': + self.best_pytorch_model = self.model.state_dict() # 保存模型状态 + self.best_pytorch_metrics = all_metrics[model_type] # 保存指标 + else: + self.best_model = model + self.best_model_metrics = all_metrics[model_type] + + # 单独保存PLS模型(不参与最佳模型评选) + if model_type == 'pls' and equipment_type: + self._save_sklearn_model(equipment_type, model_type, model, all_metrics[model_type]) + + # 在所有模型训练完成后,只保存最佳模型 + if best_model_type and equipment_type: + if best_model_type == 'pytorch': + # 恢复最佳PyTorch模型状态并保存 + self.model.load_state_dict(self.best_pytorch_model) + self.save_model(equipment_type, self.best_pytorch_metrics) + else: + # 保存最佳sklearn模型 + self._save_sklearn_model(equipment_type, best_model_type, self.best_model, self.best_model_metrics) + + return { + 'metrics': all_metrics, + 'feature_importance': None, + 'best_model': best_model + } + + except Exception as e: + logger.error(f"Error in model fitting: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + raise + + def _calculate_feature_importance(self, X): + """计算特征重要性""" + try: + if self.model is None: + return None + + self.model.eval() + feature_importance = np.zeros(X.shape[1]) + + # 使用特征扰动计算重要性 + with torch.no_grad(): + X_tensor = torch.FloatTensor(X).to(self.device) + baseline_pred = self.model(X_tensor).cpu().numpy() # 直接获取预测值 + + for i in range(X.shape[1]): + # 创建动后的特征 + X_perturbed = X.copy() + X_perturbed[:, i] = np.random.permutation(X_perturbed[:, i]) + + # 预测并计算影响 + X_perturbed_tensor = torch.FloatTensor(X_perturbed).to(self.device) + perturbed_pred = self.model(X_perturbed_tensor).cpu().numpy() # 直接获取预测值 + + # 特征重要性为预变化的平均绝对值 + feature_importance[i] = np.mean(np.abs(baseline_pred - perturbed_pred)) + + # 归一化特征重要性 + feature_importance = feature_importance / np.sum(feature_importance) + + return feature_importance + + except Exception as e: + logger.error(f"Error calculating feature importance: {str(e)}") + return None + + def _save_sklearn_model(self, equipment_type, model_type, model, metrics=None): + """保存sklearn类型的模型""" + try: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_dir = 'models' + os.makedirs(model_dir, exist_ok=True) + + # 转换装备类型为英文 + equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile' + + # 保存模型 + model_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}.joblib' + from joblib import dump + dump(model, model_path) + + # 保存标准化器 + scaler_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}_scaler.joblib' + dump({ + 'feature_scaler': self.feature_scaler, + 'target_scaler': self.target_scaler + }, scaler_path) + + # 获取评估指标 + r2 = metrics['validation']['r2'] if metrics and metrics.get('validation') else metrics['train']['r2'] + mae = metrics['validation']['mae'] if metrics and metrics.get('validation') else metrics['train']['mae'] + rmse = metrics['validation']['rmse'] if metrics and metrics.get('validation') else metrics['train']['rmse'] + + # 更新数据库 + with get_db_connection() as conn: + cursor = conn.cursor() + + # 将同类型的其他模型设置为非激活 + if model_type != 'pls': + cursor.execute(""" + UPDATE trained_models + SET is_active = FALSE + WHERE equipment_type = %s AND model_type != %s + """, (equipment_type, 'pls')) + else: + cursor.execute(""" + UPDATE trained_models + SET is_active = FALSE + WHERE equipment_type = %s AND model_type = %s + """, (equipment_type, 'pls')) + + # 保存新模型记录 + cursor.execute(""" + INSERT INTO trained_models ( + model_name, model_type, equipment_type, model_path, + scaler_path, training_date, is_active, created_by, + r2_score, mae, rmse + ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s) + """, ( + f"{equipment_type}_{model_type}_{timestamp}", + model_type, + equipment_type, + model_path, + scaler_path, + 'system', + r2, + mae, + rmse + )) + + conn.commit() + + logger.info(f"Model {model_type} saved successfully: {model_path}") + return True + + except Exception as e: + logger.error(f"Error saving sklearn model: {str(e)}") + return False \ No newline at end of file diff --git a/src/routes.py b/src/routes.py index ffcf544..fbdaa84 100644 --- a/src/routes.py +++ b/src/routes.py @@ -12,6 +12,7 @@ import os from .data_preparation import DataPreparation from .model_trainer import ModelTrainer from .logger import setup_logger +import torch # 创建蓝图 api_bp = Blueprint('api', __name__) @@ -58,49 +59,45 @@ def index(): }) @api_bp.route('/predict', methods=['POST']) -def predict_cost(): - """ - 成本预测接口 - """ +def predict(): + """使用最优机器学习模型进行预测""" try: data = request.get_json() - logger.info(f"Received prediction request for equipment type: {data.get('type')}") + equipment_type = data.get('type') + logger.info(f"Received prediction request for equipment type: {equipment_type}") - # 验证装备类型 - if 'type' not in data: - return jsonify({'error': 'Equipment type is required'}), 400 - - # 预测成本 - predictor = CostPredictor() - result = predictor.predict(data) - - # 获取当前使用的模型信息 + # 获取最新的激活模型(非PLS模型) with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) cursor.execute(""" - SELECT model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE - LIMIT 1 - """, (data['type'],)) - model_info = cursor.fetchone() - - # 在结果中添加模型信息 - result.update({ - 'model_info': { - 'type': model_info['model_type'], - 'name': model_info['model_name'], - 'r2_score': float(model_info['r2_score']), - 'mae': float(model_info['mae']), - 'rmse': float(model_info['rmse']) - } - }) + SELECT * FROM trained_models + WHERE equipment_type = %s + AND model_type != 'pls' # 明确排除PLS模型 + AND is_active = TRUE + ORDER BY training_date DESC LIMIT 1 + """, (equipment_type,)) + model = cursor.fetchone() - logger.info(f"Prediction completed: {result}") - return jsonify(result) + if not model: + return jsonify({'error': '未找到可用的模型'}), 404 + + # 使用普通预测方法 + predictor = CostPredictor() + prediction = predictor.predict(data, model) # 使用普通predict方法 + + # 返回预测结果 + return jsonify({ + 'model_info': { + 'type': model['model_type'], # 使用数据库中的模型类型 + 'name': model['model_name'] + }, + 'predicted_cost': prediction['predicted_cost'], + 'confidence_interval': prediction['confidence_interval'] + }) except Exception as e: logger.error(f"Error in prediction: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) return jsonify({'error': str(e)}), 500 @api_bp.route('/analyze-features', methods=['POST']) @@ -180,27 +177,7 @@ def analyze_features(): """, (dataset_id,)) equipment_data = cursor.fetchall() - - # 添加数据检查日志 - logger.info(f"Total records found: {len(equipment_data)}") - if equipment_data: - # 检查第一条记录的所有字段 - first_record = equipment_data[0] - logger.info("First record details:") - for key, value in first_record.items(): - logger.info(f"{key}: {value}") - - # 检查所有记录的 max_range_km 字段 - logger.info("Checking max_range_km for all records:") - for item in equipment_data: - logger.info(f"Equipment: {item['name']}") - logger.info(f" max_range_km: {item.get('max_range_km')}") - logger.info(f" type: {item['type']}") - if item['type'] == '火箭炮': - logger.info(f" rocket_artillery_params fields:") - for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']: - logger.info(f" {key}: {item.get(key)}") - + # 提取特征和目标值 analyzer = FeatureAnalysis() feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type']) @@ -270,9 +247,7 @@ def analyze_features(): @api_bp.route('/train', methods=['POST']) def train_model(): - """ - 训练模型 - """ + """训练模型""" try: data = request.get_json() logger.info(f"Starting model training for {data.get('type')}") @@ -289,54 +264,65 @@ def train_model(): with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - # 获取训练集数据 + # 获取训练集数据(包含生产商信息) if equipment_type == '火箭炮': cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost + SELECT e.*, cp.*, rap.*, cd.actual_cost, + m.tech_level, m.scale_level, m.supply_chain_level, + m.id as manufacturer_id FROM equipments e JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL """, (train_dataset_id,)) else: cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost + SELECT e.*, cp.*, lmp.*, cd.actual_cost, + m.tech_level, m.scale_level, m.supply_chain_level, + m.id as manufacturer_id FROM equipments e JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL """, (train_dataset_id,)) train_data = cursor.fetchall() - # 获取验证集据(如果有) - validation_data = None + # 获取验证集数据 if validation_dataset_id: if equipment_type == '火箭炮': cursor.execute(""" - SELECT e.*, cp.*, rap.*, cd.actual_cost + SELECT e.*, cp.*, rap.*, cd.actual_cost, + m.tech_level, m.scale_level, m.supply_chain_level, + m.id as manufacturer_id FROM equipments e JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL """, (validation_dataset_id,)) else: cursor.execute(""" - SELECT e.*, cp.*, lmp.*, cd.actual_cost + SELECT e.*, cp.*, lmp.*, cd.actual_cost, + m.tech_level, m.scale_level, m.supply_chain_level, + m.id as manufacturer_id FROM equipments e JOIN dataset_equipments de ON e.id = de.equipment_id LEFT JOIN common_params cp ON e.id = cp.equipment_id LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id + LEFT JOIN manufacturers m ON e.manufacturer_id = m.id WHERE de.dataset_id = %s AND cd.actual_cost IS NOT NULL """, (validation_dataset_id,)) @@ -351,7 +337,7 @@ def train_model(): # 准备训练数据 train_prepared = data_processor.prepare_training_data(train_data, equipment_type) - # 准备验数据(如果有) + # 准备验证数据(如果有) validation_prepared = None if validation_data: validation_prepared = data_processor.prepare_validation_data( @@ -369,14 +355,14 @@ def train_model(): model_trainer.feature_scaler = train_prepared['feature_scaler'] model_trainer.target_scaler = train_prepared['target_scaler'] - # 执行训练,传入 equipment_type 参数 + # 执行训练,传入equipment_type参数 training_result = model_trainer.fit_model( train_prepared['X'], train_prepared['y'], models, validation_prepared['X'] if validation_prepared else None, validation_prepared['y'] if validation_prepared else None, - equipment_type=equipment_type + equipment_type=equipment_type # 添加这个参数 ) return jsonify(training_result) @@ -588,65 +574,43 @@ def get_db_connection(): @api_bp.route('/pls/predict', methods=['POST']) def pls_predict(): - """ - PLS回归预测接口 - """ + """使用PLS模型进行预测""" try: data = request.get_json() - logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}") + equipment_type = data.get('type') - # 验证装备类型 - if 'type' not in data: - return jsonify({'error': 'Equipment type is required'}), 400 - - # 使用 ModelTrainer 中的 PLS 模型行预测 - trainer = ModelTrainer() - if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型 - return jsonify({'error': '未找到可用的模型'}), 404 - - # 准备特征数据 - feature_analyzer = FeatureAnalysis() - features = feature_analyzer.get_equipment_specific_features(data['type']) - X = np.array([[data.get(feature) for feature in features]]) - - # 预测 - result = trainer.predict(X) - - # 计算置信区间 - confidence_interval = trainer._calculate_confidence_interval(result[0]) - - # 获取模型信息 + # 获取最新的PLS模型 with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) cursor.execute(""" - SELECT model_type, model_name, r2_score, mae, rmse - FROM trained_models - WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE - LIMIT 1 - """, (data['type'],)) - model_info = cursor.fetchone() - - # 确保返回的数据可以序列化为JSON - response = { - 'predicted_cost': float(result[0]), - 'model_info': { - 'type': model_info['model_type'], - 'name': model_info['model_name'], - 'r2_score': model_info['r2_score'], - 'mae': model_info['mae'], - 'rmse': model_info['rmse'] - }, - 'confidence_interval': { - 'lower': float(confidence_interval[0]), - 'upper': float(confidence_interval[1]) - } - } + SELECT * FROM trained_models + WHERE equipment_type = %s + AND model_type = 'pls' # 只选择PLS模型 + AND is_active = TRUE + ORDER BY training_date DESC LIMIT 1 + """, (equipment_type,)) + model = cursor.fetchone() - logger.info(f"PLS prediction completed: {response}") - return jsonify(response) + if not model: + return jsonify({'error': '未找到可用的PLS模型'}), 404 + + # 使用普通predict方法,传入模型信息 + predictor = CostPredictor() + prediction = predictor.predict(data, model) # 传入model参数 + + # 返回预测结果 + return jsonify({ + 'model_info': { + 'type': 'pls', + 'name': model['model_name'] + }, + 'predicted_cost': prediction['predicted_cost'], + 'confidence_interval': prediction['confidence_interval'] + }) except Exception as e: logger.error(f"Error in PLS prediction: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) return jsonify({'error': str(e)}), 500 @api_bp.route('/data/import', methods=['POST']) @@ -852,7 +816,7 @@ def get_equipment_details(id): @api_bp.route('/datasets', methods=['GET']) def get_datasets(): """ - 获取数集列表 + 获取数据集列表 """ try: with get_db_connection() as conn: @@ -868,13 +832,20 @@ def get_datasets(): """) datasets = cursor.fetchall() - # 理装备名称列表 + # 整理装备名称列表 for dataset in datasets: if dataset['equipment_names']: dataset['equipment_names'] = dataset['equipment_names'].split(',') else: dataset['equipment_names'] = [] + # 格式化时间和数值字段 + for dataset in datasets: + if dataset['created_at']: + dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S') + if dataset['updated_at']: + dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S') + return jsonify(datasets) except Exception as e: logger.error(f"Error getting datasets: {str(e)}") @@ -930,6 +901,7 @@ def get_dataset(id): } dataset['equipment'] = equipment + return jsonify(dataset) except Exception as e: logger.error(f"Error getting dataset: {str(e)}") @@ -1067,7 +1039,7 @@ def delete_dataset(id): @api_bp.route('/models//latest', methods=['GET']) def get_latest_model(equipment_type): """ - 获取最新训练的型信息 + 获取最新训练的模型信息 """ try: with get_db_connection() as conn: @@ -1087,9 +1059,7 @@ def get_latest_model(equipment_type): @api_bp.route('/models', methods=['GET']) def get_models(): - """ - 获取模型列表 - """ + """获取模型列表""" try: with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) @@ -1100,8 +1070,11 @@ def get_models(): models = cursor.fetchall() - # 确保数值型字段是 float + # 格式化时间和数值字段 for model in models: + # 将数据库中的datetime转换为ISO格式字符串 + if model['training_date']: + model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S') if model['r2_score'] is not None: model['r2_score'] = float(model['r2_score']) if model['mae'] is not None: @@ -1121,16 +1094,14 @@ def get_models(): @api_bp.route('/models//activate', methods=['POST']) def activate_model(id): - """ - 激活定的模型 - """ + """激活指定的模型""" try: with get_db_connection() as conn: - cursor = conn.cursor() + cursor = conn.cursor(dictionary=True) # 使用字典游标 # 获取模型信息 cursor.execute(""" - SELECT equipment_type FROM trained_models + SELECT equipment_type, model_type FROM trained_models WHERE id = %s """, (id,)) model = cursor.fetchone() @@ -1142,8 +1113,8 @@ def activate_model(id): cursor.execute(""" UPDATE trained_models SET is_active = FALSE - WHERE equipment_type = %s - """, (model[0],)) + WHERE equipment_type = %s AND model_type = %s + """, (model['equipment_type'], model['model_type'])) # 激活指定模型 cursor.execute("""