diff --git a/frontend/src/views/DatasetPage.vue b/frontend/src/views/DatasetPage.vue index e832f4e..60e94c1 100644 --- a/frontend/src/views/DatasetPage.vue +++ b/frontend/src/views/DatasetPage.vue @@ -91,6 +91,8 @@ + + @@ -100,6 +102,7 @@ +
+ 请先选择装备类型 +
- + - + @@ -99,12 +99,12 @@ {{ formatNumber(scope.row.validation.r2) }} - + - + diff --git a/src/manufacturer_data.sql b/src/manufacturer_data.sql new file mode 100644 index 0000000..01cf09b --- /dev/null +++ b/src/manufacturer_data.sql @@ -0,0 +1,78 @@ +-- 插入供应商数据 +INSERT INTO manufacturers ( + name, -- 供应商名称 + country, -- 所属国家 + tech_level, -- 技术水平评分(1-10) + scale_level, -- 规模评分(1-10) + supply_chain_level -- 供应链成熟度评分(1-10) +) VALUES +-- 美国供应商 +('美国洛克希德·马丁', '美国', 10, 10, 10), -- 全球最大军工企业 +('美国 AeroVironment', '美国', 9, 8, 9), -- 无人机和导弹领域领先 +('美国 Raytheon', '美国', 9, 9, 9), -- 导弹技术领先 +('美国 AEVEX', '美国', 8, 7, 8), -- 新兴军工企业 +('美国 AREA-I', '美国', 8, 7, 8), -- 专注无人机系统 +('美国 Northrop Grumman', '美国', 9, 9, 9), -- 大型军工企业 + +-- 欧洲供应商 +('英国 BAE Systems', '英国', 8, 9, 9), -- 欧洲最大军工企业 +('英国 MBDA', '英国', 8, 8, 8), -- 导弹系统专家 +('德国 KMW', '德国', 9, 8, 9), -- 陆军装备主要供应商 +('德国 MBDA', '德国', 8, 8, 8), -- 导弹系统制造商 +('德国 Rheinmetall', '德国', 8, 8, 8), -- 综合军工企业 +('法国 Nexter', '法国', 8, 8, 8), -- 陆军装备制造商 +('法国 MBDA', '法国', 8, 8, 8), -- 导弹系统制造商 +('法国 Safran', '法国', 8, 8, 8), -- 航空航天企业 +('意大利 Leonardo', '意大利', 7, 7, 7), -- 综合军工企业 +('意大利 OTO Melara', '意大利', 7, 7, 7), -- 火炮系统制造商 + +-- 以色列供应商 +('以色列军事工业', '以色列', 9, 7, 7), -- 技术先进 +('以色列 IAI', '以色列', 9, 7, 7), -- 航空航天领先 +('以色列 UVision', '以色列', 8, 6, 7), -- 无人机专家 + +-- 中国供应商 +('中国兵器工业集团', '中国', 8, 9, 8), -- 陆军装备制造商 +('中国航天科工', '中国', 8, 9, 8), -- 导弹制造商 + +-- 亚洲供应商 +('韩国韩华防务', '韩国', 7, 7, 7), -- 韩国主要军工企业 +('日本防卫装备厂', '日本', 7, 7, 7), -- 日本主要军工企业 + +-- 俄罗斯供应商 +('俄罗斯', '俄罗斯', 7, 8, 6), -- 技术成熟但供应链受限 +('俄罗斯 ZALA', '俄罗斯', 7, 6, 6), -- 无人机制造商 +('俄罗斯 UZGA', '俄罗斯', 7, 6, 6), -- 航空设备制造商 + +-- 其他欧洲供应商 +('波兰 WB Electronics', '波兰', 6, 6, 6), -- 电子系统制造商 +('波兰 WB Group', '波兰', 6, 6, 6), -- 军工集团 +('波兰胡塔斯塔洛瓦', '波兰', 6, 6, 6), -- 装备制造商 +('瑞典 UMS Skeldar', '瑞典', 7, 6, 7), -- 无人机系统 +('瑞典 Saab', '瑞典', 7, 7, 7), -- 综合军工企业 +('捷克 RETIA', '捷克', 6, 5, 6), -- 电子系统制造商 +('斯洛伐克 ZTS', '斯洛伐克', 5, 5, 5), -- 装备制造商 +('捷克 Excalibur Army', '捷克', 6, 5, 6), -- 陆军装备制造商 +('克罗地亚 RH ALAN', '克罗地亚', 5, 4, 5), -- 军工企业 +('塞尔维亚 Yugoimport', '塞尔维亚', 5, 4, 5), -- 军工出口企业 +('芬兰 Patria', '芬兰', 7, 6, 7), -- 装甲车辆制造商 +('奥地利 Hirtenberger', '奥地利', 7, 6, 7), -- 火炮系统制造商 + +-- 其他供应商 +('土耳其洛克特桑', '土耳其', 6, 6, 6), -- 新兴军工企业 +('土耳其 STM', '土耳其', 6, 6, 6), -- 防务技术公司 +('巴西航空工业', '巴西', 6, 6, 5), -- 南美最大军工企业 +('印度DRDO', '印度', 5, 5, 5), -- 国防研究机构 +('伊朗国防工业', '伊朗', 4, 4, 4), -- 受制裁影响 +('埃及 AOI', '埃及', 4, 4, 4), -- 军工企业 +('罗马尼亚 ROMARM', '罗马尼亚', 5, 4, 5), -- 国营军工企业 +('乌克兰尤日马什', '乌克兰', 6, 5, 5), -- 航天企业 +('白俄罗斯国家军工委员会', '白俄罗斯', 5, 5, 5), -- 国家军工管理机构 +('阿联酋国际金龙', '阿联酋', 6, 6, 6), -- 新兴军工企业 +('新加坡ST工程', '新加坡', 7, 6, 7); -- 技术领先的军工企业 + +-- 更新装备表中的供应商ID +UPDATE equipment e +SET manufacturer_id = m.id +FROM manufacturers m +WHERE e.manufacturer = m.name; \ No newline at end of file diff --git a/src/routes.py b/src/routes.py index 61ff782..04f60b7 100644 --- a/src/routes.py +++ b/src/routes.py @@ -149,7 +149,7 @@ def analyze_features(): cp.width_m, cp.height_m, cp.weight_kg, - cp.max_range_km, + rap.max_range_km, rap.firing_angle_horizontal, rap.firing_angle_vertical, rap.rocket_length_m, @@ -191,7 +191,7 @@ def analyze_features(): cp.width_m, cp.height_m, cp.weight_kg, - cp.max_range_km, + lmp.max_range_km, lmp.wingspan_m, lmp.warhead_weight_kg, lmp.max_speed_ms, @@ -390,7 +390,7 @@ def train_model(): # 准备训练数据 train_prepared = data_processor.prepare_training_data(train_data, equipment_type) - # 准备验证数据(如果有) + # 准备验数据(如果有) validation_prepared = None if validation_data: validation_prepared = data_processor.prepare_validation_data( @@ -498,9 +498,11 @@ def get_equipment_data(): with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - # 获取所有装备数据 + # 获取所有装备数据(使用equipment_id替代id) cursor.execute(""" - SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost, + SELECT e.id as equipment_id, e.name, e.type, + cp.length_m, cp.width_m, cp.height_m, cp.weight_kg, + cd.actual_cost, cd.predicted_cost, CASE WHEN e.type = '火箭炮' THEN ( SELECT CONCAT( @@ -636,7 +638,7 @@ def pls_predict(): if 'type' not in data: return jsonify({'error': 'Equipment type is required'}), 400 - # 使用 ModelTrainer 中的 PLS 模型进行预测 + # 使用 ModelTrainer 中的 PLS 模型行预测 trainer = ModelTrainer() if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型 return jsonify({'error': '未找到可用的模型'}), 404 @@ -720,66 +722,100 @@ def import_data(): @api_bp.route('/data/', methods=['PUT']) def update_equipment(id): - """ - 更新装备数据 - """ + """更新装备数据""" try: data = request.get_json() - logger.info(f"Updating equipment ID: {id}") - logger.info(f"Update data: {data}") + logger.info(f"Updating equipment with data: {data}") with get_db_connection() as conn: cursor = conn.cursor() - # 更新基本信息 + # 使用 equipment_id 而不是 id + equipment_id = data.get('equipment_id') + if not equipment_id: + raise ValueError("Missing equipment_id") + + # 更新装备基本信息 cursor.execute(""" UPDATE equipment SET name = %s, manufacturer = %s WHERE id = %s - """, (data['name'], data['manufacturer'], id)) - logger.info("Basic info updated") + """, (data['name'], data['manufacturer'], equipment_id)) # 更新通用参数 cursor.execute(""" UPDATE common_params - SET length_m = %s, width_m = %s, height_m = %s, - weight_kg = %s, max_range_km = %s + SET length_m = %s, width_m = %s, height_m = %s, weight_kg = %s WHERE equipment_id = %s - """, ( - data['length_m'], data['width_m'], data['height_m'], - data['weight_kg'], data['max_range_km'], id - )) - logger.info("Common params updated") + """, (data['length_m'], data['width_m'], data['height_m'], data['weight_kg'], equipment_id)) - # 根据备类型更新特有参数 + # 根据装备类型更新特有参数 if data['type'] == '火箭炮': cursor.execute(""" UPDATE rocket_artillery_params - SET firing_angle_horizontal = %s, firing_angle_vertical = %s, - rocket_length_m = %s, rocket_diameter_mm = %s, - rocket_weight_kg = %s, rate_of_fire = %s + SET firing_angle_horizontal = %s, + firing_angle_vertical = %s, + rocket_length_m = %s, + rocket_diameter_mm = %s, + rocket_weight_kg = %s, + rate_of_fire = %s, + combat_weight_kg = %s, + speed_kmh = %s, + min_range_km = %s, + max_range_km = %s, + mobility_type = %s, + structure_layout = %s, + engine_model = %s, + engine_params = %s, + power_hp = %s, + travel_range_km = %s WHERE equipment_id = %s """, ( - data['firing_angle_horizontal'], data['firing_angle_vertical'], - data['rocket_length_m'], data['rocket_diameter_mm'], - data['rocket_weight_kg'], data['rate_of_fire'], id + data['firing_angle_horizontal'], + data['firing_angle_vertical'], + data['rocket_length_m'], + data['rocket_diameter_mm'], + data['rocket_weight_kg'], + data['rate_of_fire'], + data['combat_weight_kg'], + data['speed_kmh'], + data['min_range_km'], + data['max_range_km'], + data['mobility_type'], + data['structure_layout'], + data['engine_model'], + data['engine_params'], + data['power_hp'], + data['travel_range_km'], + equipment_id )) - logger.info("Rocket artillery params updated") else: cursor.execute(""" UPDATE loitering_munition_params - SET max_speed_ms = %s, cruise_speed_kmh = %s, - flight_time_min = %s, warhead_type = %s, - launch_mode = %s, folded_length_mm = %s, - folded_width_mm = %s, folded_height_mm = %s + SET wingspan_m = %s, + warhead_weight_kg = %s, + max_speed_ms = %s, + cruise_speed_kmh = %s, + endurance_min = %s, + max_range_km = %s, + warhead_type = %s, + launch_mode = %s, + power_system = %s, + guidance_system = %s WHERE equipment_id = %s """, ( - data['max_speed_ms'], data['cruise_speed_kmh'], - data['flight_time_min'], data['warhead_type'], - data['launch_mode'], data['folded_length_mm'], - data['folded_width_mm'], data['folded_height_mm'], id + data['wingspan_m'], + data['warhead_weight_kg'], + data['max_speed_ms'], + data['cruise_speed_kmh'], + data['endurance_min'], + data['max_range_km'], + data['warhead_type'], + data['launch_mode'], + data['power_system'], + data['guidance_system'], + equipment_id )) - logger.info("Missile params updated") # 更新成本数据 if 'actual_cost' in data: @@ -787,23 +823,9 @@ def update_equipment(id): UPDATE cost_data SET actual_cost = %s WHERE equipment_id = %s - """, (data['actual_cost'], id)) - logger.info("Cost data updated") - - # 更新特殊参数 - if 'custom_params' in data and data['custom_params']: - logger.info(f"Updating custom params: {data['custom_params']}") - for param in data['custom_params']: - cursor.execute(""" - UPDATE custom_params - SET param_value = %s - WHERE id = %s AND equipment_id = %s - """, (param['param_value'], param['id'], id)) - logger.info("Custom params updated") + """, (data['actual_cost'], equipment_id)) conn.commit() - logger.info("All updates committed successfully") - return jsonify({'success': True}) except Exception as e: @@ -820,7 +842,7 @@ def get_equipment_details(id): with get_db_connection() as conn: cursor = conn.cursor(dictionary=True) - # 获取装备基本信息和类型 + # 获取装备基本信息类型 cursor.execute(""" SELECT e.*, cp.*, cd.actual_cost, cd.predicted_cost FROM equipment e @@ -919,9 +941,10 @@ def get_dataset(id): if not dataset: return jsonify({'error': 'Dataset not found'}), 404 - # 获取数据集中的装备 + # 获取数据集中的装备 - 修改查询,确保返回正确的ID cursor.execute(""" - SELECT e.*, cd.actual_cost + SELECT e.id as equipment_id, e.name, e.type, e.manufacturer, + cd.actual_cost FROM equipment e JOIN dataset_equipment de ON e.id = de.equipment_id LEFT JOIN cost_data cd ON e.id = cd.equipment_id @@ -953,66 +976,108 @@ def get_dataset(id): @api_bp.route('/datasets', methods=['POST']) def create_dataset(): - """ - 建数据集 - """ + """创建数据集""" try: data = request.get_json() + logger.info(f"Creating dataset with data: {data}") + with get_db_connection() as conn: cursor = conn.cursor() - # 创建数据集 + # 1. 验证装备ID是否存在 + if 'equipment_ids' in data and data['equipment_ids']: + # 直接从 equipment 表查询,不需要 JOIN + equipment_ids_str = ','.join(map(str, data['equipment_ids'])) + cursor.execute(f""" + SELECT DISTINCT id FROM equipment + WHERE id IN ({equipment_ids_str}) AND type = %s + """, (data['equipment_type'],)) + + valid_ids = [row[0] for row in cursor.fetchall()] + logger.info(f"Valid equipment IDs: {valid_ids}") + + # 如果有无效的ID,返回错误 + invalid_ids = set(data['equipment_ids']) - set(valid_ids) + if invalid_ids: + logger.error(f"Invalid equipment IDs: {invalid_ids}") + return jsonify({'error': f'无效的装备ID: {invalid_ids}'}), 400 + + # 2. 创建数据集 cursor.execute(""" INSERT INTO datasets (name, description, equipment_type, purpose) VALUES (%s, %s, %s, %s) """, (data['name'], data['description'], data['equipment_type'], data['purpose'])) dataset_id = cursor.lastrowid + logger.info(f"Created dataset with ID: {dataset_id}") - # 添加装备关联 + # 3. 添加装备关联 if 'equipment_ids' in data and data['equipment_ids']: - values = [(dataset_id, equipment_id) for equipment_id in data['equipment_ids']] + values = [(dataset_id, equipment_id) for equipment_id in valid_ids] cursor.executemany(""" INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES (%s, %s) """, values) + logger.info(f"Added {len(values)} equipment associations") conn.commit() return jsonify({'id': dataset_id, 'message': '数据集创建成功'}) + except Exception as e: logger.error(f"Error creating dataset: {str(e)}") return jsonify({'error': str(e)}), 500 @api_bp.route('/datasets/', methods=['PUT']) def update_dataset(id): - """ - 更新数据集 - """ + """更新数据集""" try: data = request.get_json() + logger.info(f"Updating dataset {id} with data: {data}") + with get_db_connection() as conn: cursor = conn.cursor() - # 更新数据集基本信息 + # 1. 验证装备ID是否存在 + if 'equipment_ids' in data: + equipment_ids_str = ','.join(map(str, data['equipment_ids'])) + cursor.execute(f""" + SELECT id FROM equipment + WHERE id IN ({equipment_ids_str}) AND type = %s + """, (data['equipment_type'],)) + + valid_ids = [row[0] for row in cursor.fetchall()] + logger.info(f"Valid equipment IDs: {valid_ids}") + + # 如果有无效的ID,返回错误 + invalid_ids = set(data['equipment_ids']) - set(valid_ids) + if invalid_ids: + logger.error(f"Invalid equipment IDs: {invalid_ids}") + return jsonify({'error': f'无效的装备ID: {invalid_ids}'}), 400 + + # 2. 更新数据集基本信息 cursor.execute(""" UPDATE datasets SET name = %s, description = %s, equipment_type = %s, purpose = %s WHERE id = %s """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id)) - # 删除旧的装备关联 - cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) - - # 加新的装备关��� + # 3. 更新装备关联 if 'equipment_ids' in data: - for equipment_id in data['equipment_ids']: - cursor.execute(""" + # 先删除旧的关联 + cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,)) + + # 添加新的关联 + if valid_ids: # 确保有有效的ID才执行插入 + values = [(id, equipment_id) for equipment_id in valid_ids] + cursor.executemany(""" INSERT INTO dataset_equipment (dataset_id, equipment_id) VALUES (%s, %s) - """, (id, equipment_id)) + """, values) + logger.info(f"Updated {len(values)} equipment associations") conn.commit() return jsonify({'success': True}) + except Exception as e: logger.error(f"Error updating dataset: {str(e)}") return jsonify({'error': str(e)}), 500 @@ -1153,7 +1218,7 @@ def delete_model(id): if not model: return jsonify({'error': 'Model not found'}), 404 - # 删除模型文件 + # 删除模型件 if os.path.exists(model[0]): os.remove(model[0]) if os.path.exists(model[1]): @@ -1172,7 +1237,7 @@ def delete_model(id): @api_bp.route('/predict/all', methods=['POST']) def predict_all(): """ - 获取所有机器学习模型的预测结果 + 获取所有机学习模型的预测结果 """ try: data = request.get_json()