CostPrediction/src/routes.py

1254 lines
47 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Blueprint, request, jsonify, send_file
from .cost_prediction import CostPredictor
from .feature_analysis import FeatureAnalysis
import pandas as pd
from datetime import datetime
import numpy as np
import mysql.connector
from sklearn.metrics import mean_absolute_error
from .create_template import create_excel_template
import json
import os
from .data_preparation import DataPreparation
from .model_trainer import ModelTrainer
from .logger import setup_logger
# Create the blueprint on which every route in this module is registered.
api_bp = Blueprint('api', __name__)
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
@api_bp.route('/', methods=['GET'])
def index():
    """Describe the API: its name, version and the endpoints it advertises.

    Returns:
        JSON object with ``name``, ``version`` and ``endpoints`` — a mapping
        from endpoint key to its ``url``, HTTP ``method`` and ``description``.
    """
    # (key, url, method, description) for every advertised endpoint.
    advertised = (
        ('predict', '/api/predict', 'POST', '成本预测'),
        ('analyze-features', '/api/analyze-features', 'POST', '特征分析'),
        ('train', '/api/train', 'POST', '模型训练'),
        ('evaluate', '/api/evaluate', 'POST', '模型评估'),
    )
    endpoints = {
        key: {'url': url, 'method': method, 'description': description}
        for key, url, method, description in advertised
    }
    payload = {
        'name': '装备成本估算系统 API',
        'version': '1.0.0',
        'endpoints': endpoints,
    }
    return jsonify(payload)
@api_bp.route('/predict', methods=['POST'])
def predict_cost():
    """Predict the cost of one piece of equipment with the active ML model.

    Request JSON: an object that must contain ``type`` (the equipment type);
    the whole object is passed to ``CostPredictor.predict``.

    Returns:
        200 with the predictor's result extended by ``model_info`` (type, name
        and r2_score/mae/rmse of the active non-PLS model for the type);
        400 when the body is not a JSON object or lacks ``type``;
        404 when no active non-PLS model is registered for the type;
        500 on any other error.
    """
    try:
        # silent=True turns a missing/malformed body into None so it can be
        # rejected with a 400 instead of crashing on data.get below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        logger.info(f"Received prediction request for equipment type: {data.get('type')}")
        if 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        predictor = CostPredictor()
        result = predictor.predict(data)

        # Look up the model currently active for this equipment type (PLS
        # models are served by /pls/predict and excluded here).
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT model_type, model_name, r2_score, mae, rmse
                FROM trained_models
                WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
                LIMIT 1
            """, (data['type'],))
            model_info = cursor.fetchone()

        if model_info is None:
            logger.warning(f"No active model registered for equipment type: {data['type']}")
            return jsonify({'error': '未找到可用的模型'}), 404

        # Cast metrics to float for JSON; a NULL metric becomes null instead
        # of raising TypeError from float(None).
        metrics = {
            key: float(model_info[key]) if model_info[key] is not None else None
            for key in ('r2_score', 'mae', 'rmse')
        }
        result.update({
            'model_info': {
                'type': model_info['model_type'],
                'name': model_info['model_name'],
                **metrics,
            }
        })
        logger.info(f"Prediction completed: {result}")
        return jsonify(result)
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/analyze-features', methods=['POST'])
def analyze_features():
    """Run feature analysis on the costed equipment records of one dataset.

    Request JSON: ``{"dataset_id": <id>}``.

    Steps: look up the dataset and the type of its equipment, load every
    equipment row of the dataset that has an ``actual_cost``, drop features
    that are missing in 70% or more of the rows, mean-impute the remaining
    gaps, and pass the feature matrix and cost vector to
    ``FeatureAnalysis.analyze_features``.

    Returns:
        200 with the analyzer's result;
        400 for a missing dataset id, no costed rows, fewer than 3 usable
        features, or a feature/cost value that cannot be converted to float;
        404 when the dataset does not exist or has no linked equipment (the
        lookup uses inner joins);
        500 on any other error.
    """
    try:
        data = request.get_json(silent=True)
        dataset_id = data.get('dataset_id') if isinstance(data, dict) else None
        logger.info(f"Starting feature analysis for dataset {dataset_id}")
        if not dataset_id:
            logger.warning("No dataset_id provided")
            return jsonify({'error': '请选择数据集'}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            # Dataset row plus the type of (one of) its equipment.
            cursor.execute("""
                SELECT d.*,
                e.type as equipment_type
                FROM datasets d
                JOIN dataset_equipment de ON d.id = de.dataset_id
                JOIN equipment e ON de.equipment_id = e.id
                WHERE d.id = %s
                LIMIT 1
            """, (dataset_id,))
            dataset = cursor.fetchone()
            if not dataset:
                logger.warning(f"Dataset {dataset_id} not found")
                return jsonify({'error': '数据集不存在'}), 404
            logger.info(f"Dataset info: {dataset}")

            # FeatureAnalysis is the module-level (package-relative) import;
            # no second import through an absolute `src.` path is needed.
            analyzer = FeatureAnalysis()
            feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
            logger.info(f"Feature names: {feature_names}")

            # Load the dataset's costed equipment with the params table that
            # matches its type.
            if dataset['equipment_type'] == '火箭炮':
                cursor.execute("""
                    SELECT e.*, cp.*, rap.*, cd.actual_cost
                    FROM equipment e
                    JOIN dataset_equipment de ON e.id = de.equipment_id
                    LEFT JOIN common_params cp ON e.id = cp.equipment_id
                    LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
                    LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                    WHERE de.dataset_id = %s
                    AND cd.actual_cost IS NOT NULL
                """, (dataset_id,))
            else:
                cursor.execute("""
                    SELECT e.*, cp.*, lmp.*, cd.actual_cost
                    FROM equipment e
                    JOIN dataset_equipment de ON e.id = de.equipment_id
                    LEFT JOIN common_params cp ON e.id = cp.equipment_id
                    LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
                    LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                    WHERE de.dataset_id = %s
                    AND cd.actual_cost IS NOT NULL
                """, (dataset_id,))
            equipment_data = cursor.fetchall()

        # Everything below is pure computation; the connection is already released.
        logger.info(f"Found {len(equipment_data)} equipment records")
        if not equipment_data:
            logger.warning("No valid equipment data found in dataset")
            return jsonify({'error': '数据集没有有效的成本数据'}), 400

        # Share of rows in which each feature is NULL or absent.
        missing_rates = {}
        for name in feature_names:
            missing_count = sum(1 for item in equipment_data if item.get(name) is None)
            missing_rates[name] = missing_count / len(equipment_data)
            logger.info(f"Feature {name} missing rate: {missing_rates[name]:.2%}")

        # Drop features missing in 70% or more of the rows.
        valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
        logger.info(f"Valid features after filtering: {valid_features}")
        if len(valid_features) < 3:  # at least 3 features are required
            return jsonify({'error': '有效特征数量不足'}), 400

        # Convert every present value before computing means, so a
        # non-numeric value produces the intended 400 instead of escaping
        # from the mean computation as a 500.
        columns = {}
        for name in valid_features:
            column = []
            for item in equipment_data:
                value = item.get(name)
                if value is None:
                    column.append(None)
                    continue
                try:
                    column.append(float(value))
                except (ValueError, TypeError) as e:
                    logger.error(f"Error converting value for feature {name}: {value}")
                    logger.error(f"Error details: {str(e)}")
                    return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
            columns[name] = column

        # Per-feature mean over present values, used to fill the gaps.
        feature_means = {}
        for name in valid_features:
            present = [v for v in columns[name] if v is not None]
            feature_means[name] = sum(present) / len(present) if present else 0
            logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")

        # One row per equipment record, missing values replaced by the mean.
        features = [
            [value if value is not None else feature_means[name]
             for name, value in zip(valid_features, row)]
            for row in zip(*(columns[name] for name in valid_features))
        ]

        target = []
        for item in equipment_data:
            try:
                target.append(float(item['actual_cost']))
            except (ValueError, TypeError) as e:
                logger.error(f"Error converting actual_cost: {item['actual_cost']}")
                logger.error(f"Error details: {str(e)}")
                return jsonify({'error': '成本值无法转换为数值'}), 400

        logger.info(f"Prepared {len(features)} feature vectors")
        logger.info(f"First feature vector: {features[0] if features else None}")
        logger.info(f"First target value: {target[0] if target else None}")

        result = analyzer.analyze_features(features, target, valid_features)
        logger.info("Analysis completed successfully")
        return jsonify(result)
    except Exception as e:
        logger.error(f"Error analyzing features: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
def _fetch_costed_equipment(cursor, equipment_type, dataset_id):
    """Return every equipment row of ``dataset_id`` that has an actual cost.

    Each row joins ``equipment``, ``common_params``, the type-specific params
    table (``rocket_artillery_params`` for '火箭炮', ``loitering_munition_params``
    for any other type) and ``cost_data.actual_cost``. Rows are returned in
    whatever form ``cursor`` produces (dicts for a ``dictionary=True`` cursor).
    """
    if equipment_type == '火箭炮':
        query = """
            SELECT e.*, cp.*, rap.*, cd.actual_cost
            FROM equipment e
            JOIN dataset_equipment de ON e.id = de.equipment_id
            LEFT JOIN common_params cp ON e.id = cp.equipment_id
            LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
            LEFT JOIN cost_data cd ON e.id = cd.equipment_id
            WHERE de.dataset_id = %s
            AND cd.actual_cost IS NOT NULL
        """
    else:
        query = """
            SELECT e.*, cp.*, lmp.*, cd.actual_cost
            FROM equipment e
            JOIN dataset_equipment de ON e.id = de.equipment_id
            LEFT JOIN common_params cp ON e.id = cp.equipment_id
            LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
            LEFT JOIN cost_data cd ON e.id = cd.equipment_id
            WHERE de.dataset_id = %s
            AND cd.actual_cost IS NOT NULL
        """
    cursor.execute(query, (dataset_id,))
    return cursor.fetchall()


@api_bp.route('/train', methods=['POST'])
def train_model():
    """Train the requested models on one dataset, optionally validating on another.

    Request JSON:
        type: equipment type (required); selects the type-specific params table.
        train_dataset_id: dataset whose costed rows form the training set.
        validation_dataset_id: optional dataset used as the validation set.
        models: list of model identifiers forwarded to ``ModelTrainer.fit_model``.

    Returns:
        200 with the trainer's result;
        400 when the body is not a JSON object, ``type`` is missing, or the
        training dataset has no costed rows;
        500 on any other error.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        logger.info(f"Starting model training for {data.get('type')}")
        equipment_type = data.get('type')
        train_dataset_id = data.get('train_dataset_id')
        validation_dataset_id = data.get('validation_dataset_id')
        models = data.get('models', [])
        logger.info(f"Training dataset: {train_dataset_id}")
        logger.info(f"Validation dataset: {validation_dataset_id}")
        logger.info(f"Selected models: {models}")

        # Without a type the queries would silently use the loitering-munition table.
        if not equipment_type:
            return jsonify({'error': 'Equipment type is required'}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            train_data = _fetch_costed_equipment(cursor, equipment_type, train_dataset_id)
            if not train_data:
                return jsonify({'error': '训练数据集为空'}), 400
            validation_data = (
                _fetch_costed_equipment(cursor, equipment_type, validation_dataset_id)
                if validation_dataset_id else None
            )

        # 1. Prepare data; validation data reuses the training feature list and
        #    scalers so both sets live in the same feature space.
        data_processor = DataPreparation()
        train_prepared = data_processor.prepare_training_data(train_data, equipment_type)
        validation_prepared = None
        if validation_data:
            validation_prepared = data_processor.prepare_validation_data(
                validation_data,
                equipment_type,
                train_prepared['feature_names'],
                {
                    'feature_scaler': train_prepared['feature_scaler'],
                    'target_scaler': train_prepared['target_scaler']
                }
            )

        # 2. Train the models with the scalers fitted on the training data.
        model_trainer = ModelTrainer()
        model_trainer.feature_scaler = train_prepared['feature_scaler']
        model_trainer.target_scaler = train_prepared['target_scaler']
        training_result = model_trainer.fit_model(
            train_prepared['X'],
            train_prepared['y'],
            models,
            validation_prepared['X'] if validation_prepared else None,
            validation_prepared['y'] if validation_prepared else None,
            equipment_type=equipment_type
        )
        return jsonify(training_result)
    except Exception as e:
        logger.error(f"Error in model training: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
@api_bp.route('/evaluate', methods=['POST'])
def evaluate_model():
    """Score a set of predictions against the actual values.

    Request JSON: ``{"test_data": {"actual": [...], "predicted": [...]}}``
    where both lists are non-empty and of equal length.

    Returns:
        200 with ``CostPredictor.evaluate``'s result;
        400 when ``test_data`` is missing or either list is missing, empty or
        of a different length than the other;
        500 on any other error.
    """
    try:
        data = request.get_json(silent=True)
        logger.info("Received model evaluation request")
        if not isinstance(data, dict) or 'test_data' not in data:
            return jsonify({'error': 'Test data is required'}), 400

        test_data = data['test_data']
        if not isinstance(test_data, dict):
            test_data = {}
        actual = test_data.get('actual')
        predicted = test_data.get('predicted')
        # Validate here so malformed input is a client error (400) instead of
        # a KeyError / length-mismatch error surfacing as a 500.
        if not isinstance(actual, list) or not isinstance(predicted, list):
            return jsonify({'error': "test_data must contain 'actual' and 'predicted' lists"}), 400
        if not actual or len(actual) != len(predicted):
            return jsonify({'error': "'actual' and 'predicted' must be non-empty and of equal length"}), 400

        predictor = CostPredictor()
        evaluation_result = predictor.evaluate(actual, predicted)
        logger.info("Model evaluation completed")
        return jsonify(evaluation_result)
    except Exception as e:
        logger.error(f"Error in model evaluation: {str(e)}")
        return jsonify({'error': str(e)}), 500
def get_required_params(equipment_type):
    """Return the names of the parameters required for ``equipment_type``.

    Every type needs the five shared dimension/weight/range fields; '火箭炮'
    (rocket artillery) and '巡飞弹' (loitering munition) append their own
    type-specific fields. Any other value yields only the shared fields.
    A fresh list is built on every call, so callers may mutate it.
    """
    shared = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']
    type_specific = (
        ('火箭炮', [
            'firing_angle_horizontal',
            'firing_angle_vertical',
            'rocket_length_m',
            'rocket_diameter_mm',
            'rocket_weight_kg',
        ]),
        # NOTE(review): the loitering-munition SQL in this module uses the
        # column max_speed_ms, not max_speed_kmh — confirm which name the
        # callers of this function expect.
        ('巡飞弹', [
            'max_speed_kmh',
            'cruise_speed_kmh',
            'flight_time_min',
            'folded_length_mm',
            'folded_width_mm',
            'folded_height_mm',
        ]),
    )
    for type_name, extra in type_specific:
        if equipment_type == type_name:
            return shared + extra
    return shared
@api_bp.errorhandler(404)
def not_found(error):
    """Answer 404 errors handled by this blueprint with a JSON error body."""
    body = {'error': 'Not found'}
    return jsonify(body), 404
@api_bp.errorhandler(500)
def internal_error(error):
    """Log the failure and answer with a generic JSON 500 body.

    The error's text is written only to the log; the client receives a fixed
    message.
    """
    message = f"Internal server error: {error}"
    logger.error(message)
    return jsonify({'error': 'Internal server error'}), 500
@api_bp.route('/data', methods=['GET'])
def get_equipment_data():
    """List all rocket-artillery ('火箭炮') and loitering-munition ('巡飞弹') equipment.

    Each record contains the equipment row, its common and type-specific
    parameters, ``actual_cost`` and ``custom_params``. ``custom_params`` holds
    the equipment's non-empty custom parameters aggregated with
    JSON_ARRAYAGG, COALESCEd to '[]' in SQL. A NULL that still reaches Python
    is replaced by an empty list.

    Returns:
        200 with ``{'rocket_artillery': [...], 'loitering_munition': [...]}``;
        500 with ``{'error': ...}`` on failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute('SET SESSION group_concat_max_len = 1000000')

            logger.info("Fetching rocket artillery data...")
            cursor.execute("""
                SELECT
                e.id,
                e.name,
                e.type,
                e.manufacturer,
                e.created_at,
                cp.length_m,
                cp.width_m,
                cp.height_m,
                cp.weight_kg,
                cp.max_range_km,
                rap.firing_angle_horizontal,
                rap.firing_angle_vertical,
                rap.rocket_length_m,
                rap.rocket_diameter_mm,
                rap.rocket_weight_kg,
                rap.rate_of_fire,
                rap.combat_weight_kg,
                rap.speed_kmh,
                rap.min_range_km,
                rap.mobility_type,
                rap.structure_layout,
                rap.engine_model,
                rap.engine_params,
                rap.power_hp,
                rap.travel_range_km,
                cd.actual_cost,
                (
                SELECT COALESCE(
                JSON_ARRAYAGG(
                JSON_OBJECT(
                'id', csp.id,
                'param_name', csp.param_name,
                'param_value', csp.param_value,
                'param_unit', csp.param_unit,
                'description', csp.description
                )
                ),
                '[]'
                )
                FROM custom_params csp
                WHERE csp.equipment_id = e.id
                AND csp.param_name IS NOT NULL
                AND csp.param_value IS NOT NULL
                ) as custom_params
                FROM equipment e
                LEFT JOIN common_params cp ON e.id = cp.equipment_id
                LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE e.type = '火箭炮'
            """)
            rocket_artillery = cursor.fetchall()
            logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
            if rocket_artillery:
                logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
                logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")

            logger.info("Fetching missile data...")
            cursor.execute("""
                SELECT
                e.id,
                e.name,
                e.type,
                e.manufacturer,
                e.created_at,
                cp.length_m,
                cp.width_m,
                cp.height_m,
                cp.weight_kg,
                cp.max_range_km,
                lmp.wingspan_m,
                lmp.warhead_weight_kg,
                lmp.max_speed_ms,
                lmp.cruise_speed_kmh,
                lmp.flight_time_min,
                lmp.warhead_type,
                lmp.launch_mode,
                lmp.folded_length_mm,
                lmp.folded_width_mm,
                lmp.folded_height_mm,
                lmp.power_system,
                lmp.guidance_system,
                cd.actual_cost,
                (
                SELECT COALESCE(
                JSON_ARRAYAGG(
                JSON_OBJECT(
                'id', csp.id,
                'param_name', csp.param_name,
                'param_value', csp.param_value,
                'param_unit', csp.param_unit,
                'description', csp.description
                )
                ),
                '[]'
                )
                FROM custom_params csp
                WHERE csp.equipment_id = e.id
                AND csp.param_name IS NOT NULL
                AND csp.param_value IS NOT NULL
                ) as custom_params
                FROM equipment e
                LEFT JOIN common_params cp ON e.id = cp.equipment_id
                LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE e.type = '巡飞弹'
            """)
            loitering_munition = cursor.fetchall()
            logger.info(f"Found {len(loitering_munition)} missile records")
            if loitering_munition:
                logger.info(f"First missile: {loitering_munition[0]['name']}")
                logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")

            # Make sure custom_params is never NULL in the response.
            for item in rocket_artillery + loitering_munition:
                params = item['custom_params']
                if params is None:
                    item['custom_params'] = []
                    logger.debug(f"Set empty custom_params for equipment {item['id']}")
                    continue
                # The aggregate may arrive as JSON text; count its elements
                # (not its characters) for the debug log. The response value
                # itself is left untouched.
                counted = params
                if isinstance(params, (str, bytes, bytearray)):
                    try:
                        counted = json.loads(params)
                    except ValueError:
                        pass
                logger.debug(f"Equipment {item['id']} has {len(counted)} custom params")

            logger.info("Data fetching completed")
            return jsonify({
                'rocket_artillery': rocket_artillery,
                'loitering_munition': loitering_munition
            })
    except Exception as e:
        logger.error(f"Error getting equipment data: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/<int:id>', methods=['DELETE'])
def delete_equipment(id):
    """Delete one equipment record together with every row that references it.

    Dependent rows (cost, common/type-specific/custom parameters and dataset
    memberships) are deleted first and the ``equipment`` row last. Everything
    is committed as one transaction. The connection is always closed, and on
    any error nothing is committed.

    Returns:
        ``{'status': 'success'}`` (also when the id did not exist), or 500
        with ``{'error': ...}``.
    """
    # Fixed, code-defined table names (never user input), children first so
    # that foreign keys — if the schema declares any — are satisfied.
    dependent_tables = (
        'cost_data',
        'rocket_artillery_params',
        'loitering_munition_params',
        'common_params',
        'custom_params',
        'dataset_equipment',
    )
    try:
        with get_db_connection() as db:
            cursor = db.cursor()
            for table in dependent_tables:
                cursor.execute(f"DELETE FROM {table} WHERE equipment_id = %s", (id,))
            cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
            db.commit()
        return jsonify({'status': 'success'})
    except Exception as e:
        logger.error(f"Error deleting equipment: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/template', methods=['GET'])
def download_template():
    """Generate the Excel data-import template and send it as a download.

    Returns:
        The generated workbook as ``equipment_data_template.xlsx``, or 500 with
        ``{'error': ...}`` when generation fails or no file was produced.
    """
    xlsx_mimetype = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    try:
        # create_excel_template is imported at module level from .create_template.
        path = create_excel_template()
        if not os.path.exists(path):
            raise FileNotFoundError("模板文件不存在")
        return send_file(
            path,
            as_attachment=True,
            download_name='equipment_data_template.xlsx',
            mimetype=xlsx_mimetype,
        )
    except Exception as e:
        logger.error(f"Error creating template: {str(e)}")
        return jsonify({'error': str(e)}), 500
def _db_config():
    """Build the mysql.connector connection settings from the environment.

    Reads DB_HOST, DB_USER, DB_PASSWORD and DB_NAME. Each falls back to the
    value that used to be hard-coded here, so existing setups keep working,
    while deployments can supply real credentials without editing source.
    """
    return {
        'host': os.environ.get('DB_HOST', 'localhost'),
        'user': os.environ.get('DB_USER', 'root'),
        'password': os.environ.get('DB_PASSWORD', '123456'),
        'database': os.environ.get('DB_NAME', 'equipment_cost_db'),
    }


def get_db_connection():
    """Open and return a new MySQL connection using ``_db_config()``.

    Callers own the connection and must close it (most routes use it as a
    context manager).
    """
    return mysql.connector.connect(**_db_config())
@api_bp.route('/pls/predict', methods=['POST'])
def pls_predict():
    """Predict cost with the PLS regression model of an equipment type.

    Request JSON: must contain ``type``. The values of that type's features
    (``FeatureAnalysis.get_equipment_specific_features``) are read from the
    same object. Absent features are passed to the model as None; how the
    trainer treats them is not visible here.

    Returns:
        200 with ``predicted_cost``, ``model_info`` (metrics as floats or null,
        matching /predict) and ``confidence_interval``;
        400 when the body is not a JSON object or lacks ``type``;
        404 when no PLS model can be loaded or none is active in
        ``trained_models``;
        500 on any other error.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")
        if 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        # Predict with the PLS model held by ModelTrainer.
        trainer = ModelTrainer()
        if not trainer.load_model(data['type'], model_type='pls'):
            return jsonify({'error': '未找到可用的模型'}), 404

        # Single-row feature matrix in the analyzer's feature order.
        feature_analyzer = FeatureAnalysis()
        features = feature_analyzer.get_equipment_specific_features(data['type'])
        X = np.array([[data.get(feature) for feature in features]])

        result = trainer.predict(X)
        confidence_interval = trainer._calculate_confidence_interval(result[0])

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT model_type, model_name, r2_score, mae, rmse
                FROM trained_models
                WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
                LIMIT 1
            """, (data['type'],))
            model_info = cursor.fetchone()

        if model_info is None:
            logger.warning(f"No active PLS model registered for equipment type: {data['type']}")
            return jsonify({'error': '未找到可用的模型'}), 404

        # Plain floats (or null) so the JSON matches /predict's model_info.
        metrics = {
            key: float(model_info[key]) if model_info[key] is not None else None
            for key in ('r2_score', 'mae', 'rmse')
        }
        response = {
            'predicted_cost': float(result[0]),
            'model_info': {
                'type': model_info['model_type'],
                'name': model_info['model_name'],
                **metrics,
            },
            'confidence_interval': {
                'lower': float(confidence_interval[0]),
                'upper': float(confidence_interval[1])
            }
        }
        logger.info(f"PLS prediction completed: {response}")
        return jsonify(response)
    except Exception as e:
        logger.error(f"Error in PLS prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/import', methods=['POST'])
def import_data():
    """Accept an uploaded Excel workbook and import it as training data.

    Expects a multipart upload in the ``file`` field whose name ends in .xls
    or .xlsx (case-insensitive). The file is saved into the project-level
    ``data`` directory and passed to ``import_training_data``.

    Returns:
        200 with ``{'success': True, 'message': ...}``;
        400 when no file or a non-Excel file name is uploaded;
        500 when saving or importing fails.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': '没有上传文件'}), 400
        file = request.files['file']

        # Keep only the last path component of the client-supplied name
        # (backslashes treated as separators too), so names such as
        # '../../x.xlsx' cannot write outside upload_dir. A missing name
        # becomes '' and is rejected by the extension check.
        filename = os.path.basename((file.filename or '').replace('\\', '/'))
        if not filename.lower().endswith(('.xls', '.xlsx')):
            return jsonify({'error': '请上传Excel文件'}), 400

        # Save the upload next to the package, in <project>/data.
        upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        os.makedirs(upload_dir, exist_ok=True)
        file_path = os.path.join(upload_dir, filename)
        file.save(file_path)

        from .import_data import import_training_data
        import_training_data(file_path)
        return jsonify({
            'success': True,
            'message': '数据导入成功'
        })
    except Exception as e:
        logger.error(f"Error importing data: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/<int:id>', methods=['PUT'])
def update_equipment(id):
    """
    Update one equipment record and its parameter, cost and custom-param rows.

    Request JSON — keys read with ``data[...]`` are required; a missing key
    raises KeyError and is reported as a 500:
        name, manufacturer: written to ``equipment``.
        length_m, width_m, height_m, weight_kg, max_range_km: written to
            ``common_params``.
        type: '火箭炮' updates ``rocket_artillery_params`` (firing angles,
            rocket length/diameter/weight, rate_of_fire); any other value
            updates ``loitering_munition_params`` (speeds, flight time,
            warhead type, launch mode, folded dimensions).
    Optional keys:
        actual_cost: written to ``cost_data`` when present.
        custom_params: list of objects with ``id`` and ``param_value``; each
            updates the matching ``custom_params`` row of this equipment.

    Only UPDATE statements are issued, so rows that do not exist yet are not
    created. All statements share one connection and are committed together
    at the end; an exception skips the commit.

    NOTE(review): the type-specific table is chosen from the client-supplied
    ``type`` rather than the stored equipment type — confirm that is intended.

    Returns:
        ``{'success': True}``, or 500 with ``{'error': ...}``.
    """
    try:
        data = request.get_json()
        logger.info(f"Updating equipment ID: {id}")
        logger.info(f"Update data: {data}")
        with get_db_connection() as conn:
            cursor = conn.cursor()
            # Update the basic info (name and manufacturer).
            cursor.execute("""
                UPDATE equipment
                SET name = %s, manufacturer = %s
                WHERE id = %s
            """, (data['name'], data['manufacturer'], id))
            logger.info("Basic info updated")
            # Update the parameters shared by all equipment types.
            cursor.execute("""
                UPDATE common_params
                SET length_m = %s, width_m = %s, height_m = %s,
                weight_kg = %s, max_range_km = %s
                WHERE equipment_id = %s
            """, (
                data['length_m'], data['width_m'], data['height_m'],
                data['weight_kg'], data['max_range_km'], id
            ))
            logger.info("Common params updated")
            # Update the type-specific parameters for the client-supplied type.
            if data['type'] == '火箭炮':
                cursor.execute("""
                    UPDATE rocket_artillery_params
                    SET firing_angle_horizontal = %s, firing_angle_vertical = %s,
                    rocket_length_m = %s, rocket_diameter_mm = %s,
                    rocket_weight_kg = %s, rate_of_fire = %s
                    WHERE equipment_id = %s
                """, (
                    data['firing_angle_horizontal'], data['firing_angle_vertical'],
                    data['rocket_length_m'], data['rocket_diameter_mm'],
                    data['rocket_weight_kg'], data['rate_of_fire'], id
                ))
                logger.info("Rocket artillery params updated")
            else:
                cursor.execute("""
                    UPDATE loitering_munition_params
                    SET max_speed_ms = %s, cruise_speed_kmh = %s,
                    flight_time_min = %s, warhead_type = %s,
                    launch_mode = %s, folded_length_mm = %s,
                    folded_width_mm = %s, folded_height_mm = %s
                    WHERE equipment_id = %s
                """, (
                    data['max_speed_ms'], data['cruise_speed_kmh'],
                    data['flight_time_min'], data['warhead_type'],
                    data['launch_mode'], data['folded_length_mm'],
                    data['folded_width_mm'], data['folded_height_mm'], id
                ))
                logger.info("Missile params updated")
            # Update the cost only when the client sent actual_cost.
            if 'actual_cost' in data:
                cursor.execute("""
                    UPDATE cost_data
                    SET actual_cost = %s
                    WHERE equipment_id = %s
                """, (data['actual_cost'], id))
                logger.info("Cost data updated")
            # Update custom params; the equipment_id condition keeps a client
            # from editing another equipment's params by id.
            if 'custom_params' in data and data['custom_params']:
                logger.info(f"Updating custom params: {data['custom_params']}")
                for param in data['custom_params']:
                    cursor.execute("""
                        UPDATE custom_params
                        SET param_value = %s
                        WHERE id = %s AND equipment_id = %s
                    """, (param['param_value'], param['id'], id))
                logger.info("Custom params updated")
            # Single commit: all of the updates above become visible together.
            conn.commit()
            logger.info("All updates committed successfully")
            return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error updating equipment: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
@api_bp.route('/data/details/<int:id>', methods=['GET'])
def get_equipment_details(id):
    """Return every stored field of one equipment record.

    The row merges ``equipment``, ``common_params``, the type-specific params
    table (``rocket_artillery_params`` for '火箭炮', otherwise
    ``loitering_munition_params``), cost columns (``actual_cost``,
    ``cost_estimate_date``, ``predicted_cost``) and ``custom_params``
    aggregated with JSON_ARRAYAGG. The aggregate has no COALESCE, so it may
    be NULL when the equipment has no custom params.

    Returns:
        200 with the merged record; 404 when the equipment does not exist;
        500 on any other error.
    """
    try:
        logger.info(f"Getting details for equipment ID: {id}")
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            # The equipment type decides which params table to join.
            cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,))
            equipment = cursor.fetchone()
            if not equipment:
                logger.warning(f"Equipment not found: {id}")
                return jsonify({'error': 'Equipment not found'}), 404
            equipment_type = equipment['type']
            logger.info(f"Equipment type: {equipment_type}")

            if equipment_type == '火箭炮':
                query = """
                    SELECT
                    e.*,
                    cp.*,
                    rap.*,
                    cd.actual_cost,
                    cd.prediction_date as cost_estimate_date,
                    cd.predicted_cost,
                    (
                    SELECT JSON_ARRAYAGG(
                    CASE
                    WHEN csp.id IS NOT NULL THEN
                    JSON_OBJECT(
                    'id', csp.id,
                    'param_name', csp.param_name,
                    'param_value', csp.param_value,
                    'param_unit', csp.param_unit,
                    'description', csp.description
                    )
                    END
                    )
                    FROM custom_params csp
                    WHERE csp.equipment_id = e.id
                    AND csp.param_name IS NOT NULL
                    AND csp.param_value IS NOT NULL
                    ) as custom_params
                    FROM equipment e
                    LEFT JOIN common_params cp ON e.id = cp.equipment_id
                    LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
                    LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                    WHERE e.id = %s
                """
            else:
                query = """
                    SELECT
                    e.*,
                    cp.*,
                    lmp.*,
                    cd.actual_cost,
                    cd.prediction_date as cost_estimate_date,
                    cd.predicted_cost,
                    (
                    SELECT JSON_ARRAYAGG(
                    CASE
                    WHEN csp.id IS NOT NULL THEN
                    JSON_OBJECT(
                    'id', csp.id,
                    'param_name', csp.param_name,
                    'param_value', csp.param_value,
                    'param_unit', csp.param_unit,
                    'description', csp.description
                    )
                    END
                    )
                    FROM custom_params csp
                    WHERE csp.equipment_id = e.id
                    AND csp.param_name IS NOT NULL
                    AND csp.param_value IS NOT NULL
                    ) as custom_params
                    FROM equipment e
                    LEFT JOIN common_params cp ON e.id = cp.equipment_id
                    LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
                    LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                    WHERE e.id = %s
                """
            cursor.execute(query, (id,))
            result = cursor.fetchone()

        if not result:
            logger.warning(f"Equipment not found: {id}")
            return jsonify({'error': 'Equipment not found'}), 404

        # With `e.*, cp.*, <params>.*` the dictionary cursor keys the row by
        # column name, so a same-named column of a joined table (e.g. an `id`
        # primary key, if common_params / the params table have one) replaces
        # the equipment's own value — and is None when the joined row is
        # missing. Pin the identity fields to the equipment's real values.
        result['id'] = id
        result['type'] = equipment_type

        logger.info(f"Found equipment details: {result['name']}")
        logger.info(f"Custom params: {result.get('custom_params')}")
        return jsonify(result)
    except Exception as e:
        logger.error(f"Error getting equipment details: {str(e)}")
        return jsonify({'error': str(e)}), 500
# Dataset routes.
@api_bp.route('/datasets', methods=['GET'])
def get_datasets():
    """List all datasets with their equipment count and equipment names.

    Returns:
        200 with a list of dataset rows. Each row is extended with
        ``equipment_count`` and ``equipment_names`` (a list, possibly empty).
        500 with ``{'error': ...}`` on failure.
    """
    # ASCII unit separator: practically never part of an equipment name,
    # unlike the default ',' separator, which split names containing commas.
    name_sep = '\x1f'
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            # GROUP_CONCAT output is capped by group_concat_max_len (1024 bytes
            # by default) and silently truncated beyond it; raise the cap for
            # this session so long name lists come back whole.
            cursor.execute('SET SESSION group_concat_max_len = 1000000')
            cursor.execute("""
                SELECT d.*,
                COUNT(de.equipment_id) as equipment_count,
                GROUP_CONCAT(e.name SEPARATOR '\x1f') as equipment_names
                FROM datasets d
                LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
                LEFT JOIN equipment e ON de.equipment_id = e.id
                GROUP BY d.id
            """)
            datasets = cursor.fetchall()

        # Turn the concatenated names into a list.
        for dataset in datasets:
            names = dataset['equipment_names']
            dataset['equipment_names'] = names.split(name_sep) if names else []
        return jsonify(datasets)
    except Exception as e:
        logger.error(f"Error getting datasets: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/datasets/<int:id>', methods=['GET'])
def get_dataset(id):
    """Return one dataset with its equipment list and cost statistics.

    ``statistics`` holds ``equipment_count`` (all linked equipment),
    ``total_cost`` (sum of known actual costs) and ``average_cost`` (mean over
    equipment that has a cost; 0 when none does). Equipment without a cost is
    no longer averaged in as 0, which understated the average.

    Returns:
        200 with the dataset row plus ``statistics`` and ``equipment``;
        404 when the dataset does not exist; 500 on any other error.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT d.*,
                COUNT(de.equipment_id) as equipment_count
                FROM datasets d
                LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
                WHERE d.id = %s
                GROUP BY d.id
            """, (id,))
            dataset = cursor.fetchone()
            if not dataset:
                return jsonify({'error': 'Dataset not found'}), 404

            cursor.execute("""
                SELECT e.*, cd.actual_cost
                FROM equipment e
                JOIN dataset_equipment de ON e.id = de.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE de.dataset_id = %s
            """, (id,))
            equipment = cursor.fetchall()

        costs = [item['actual_cost'] for item in equipment if item['actual_cost'] is not None]
        total_cost = sum(costs)
        dataset['statistics'] = {
            'equipment_count': len(equipment),
            'total_cost': total_cost,
            'average_cost': total_cost / len(costs) if costs else 0,
        }
        dataset['equipment'] = equipment
        return jsonify(dataset)
    except Exception as e:
        logger.error(f"Error getting dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/datasets', methods=['POST'])
def create_dataset():
    """Create a dataset and, optionally, link equipment to it.

    Request JSON: ``name``, ``description``, ``equipment_type`` and
    ``purpose`` (all required), plus optional ``equipment_ids``.

    Returns:
        200 with ``{'id': <new id>, 'message': ...}``;
        400 when the body is not a JSON object or a required field is missing;
        500 on a database error.
    """
    required = ('name', 'description', 'equipment_type', 'purpose')
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        # Report missing fields as a client error instead of a KeyError 500.
        missing = [field for field in required if field not in data]
        if missing:
            return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO datasets (name, description, equipment_type, purpose)
                VALUES (%s, %s, %s, %s)
            """, (data['name'], data['description'], data['equipment_type'], data['purpose']))
            dataset_id = cursor.lastrowid

            equipment_ids = data.get('equipment_ids')
            if equipment_ids:
                cursor.executemany("""
                    INSERT INTO dataset_equipment (dataset_id, equipment_id)
                    VALUES (%s, %s)
                """, [(dataset_id, equipment_id) for equipment_id in equipment_ids])
            # One commit: the dataset and its links appear together.
            conn.commit()
        return jsonify({'id': dataset_id, 'message': '数据集创建成功'})
    except Exception as e:
        logger.error(f"Error creating dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/datasets/<int:id>', methods=['PUT'])
def update_dataset(id):
    """Update a dataset's fields and, when given, replace its equipment links.

    Request JSON: ``name``, ``description``, ``equipment_type`` and
    ``purpose`` (all required). When ``equipment_ids`` is present, the
    dataset's links are replaced by that list (null or [] clears them). When
    it is absent, the existing links are left untouched instead of being
    silently wiped.

    Returns:
        200 with ``{'success': True}``;
        400 when the body is not a JSON object or a required field is missing;
        500 on a database error.
    """
    required = ('name', 'description', 'equipment_type', 'purpose')
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        missing = [field for field in required if field not in data]
        if missing:
            return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE datasets
                SET name = %s, description = %s, equipment_type = %s, purpose = %s
                WHERE id = %s
            """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))

            if 'equipment_ids' in data:
                # Replace the links: drop the old ones, insert the new list.
                cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))
                equipment_ids = data['equipment_ids'] or []
                if equipment_ids:
                    cursor.executemany("""
                        INSERT INTO dataset_equipment (dataset_id, equipment_id)
                        VALUES (%s, %s)
                    """, [(id, equipment_id) for equipment_id in equipment_ids])
            conn.commit()
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error updating dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/datasets/<int:id>', methods=['DELETE'])
def delete_dataset(id):
    """Delete a dataset and its equipment links; equipment rows are kept.

    Returns:
        ``{'success': True}``, or 500 with ``{'error': ...}``.
    """
    # Links first, then the dataset row itself; committed together.
    statements = (
        "DELETE FROM dataset_equipment WHERE dataset_id = %s",
        "DELETE FROM datasets WHERE id = %s",
    )
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            for sql in statements:
                cursor.execute(sql, (id,))
            conn.commit()
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error deleting dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models/<equipment_type>/latest', methods=['GET'])
def get_latest_model(equipment_type):
    """Return the most recently trained active model of ``equipment_type``.

    The raw ``trained_models`` row is returned unconverted (JSON ``null``
    when no active model exists).

    Returns:
        200 with the row or null, or 500 with ``{'error': ...}``.
    """
    latest_active_sql = """
        SELECT * FROM trained_models
        WHERE equipment_type = %s AND is_active = TRUE
        ORDER BY training_date DESC LIMIT 1
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute(latest_active_sql, (equipment_type,))
            latest = cursor.fetchone()
        return jsonify(latest)
    except Exception as e:
        logger.error(f"Error getting latest model: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models', methods=['GET'])
def get_models():
    """List every trained model, newest first.

    Non-NULL ``r2_score``, ``mae`` and ``rmse`` are converted to float, and a
    non-empty ``feature_importance`` is decoded from its JSON text.

    Returns:
        200 with the list of model rows, or 500 with ``{'error': ...}``.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT * FROM trained_models
                ORDER BY training_date DESC
            """)
            rows = cursor.fetchall()
            for row in rows:
                for metric in ('r2_score', 'mae', 'rmse'):
                    if row[metric] is not None:
                        row[metric] = float(row[metric])
                if row['feature_importance']:
                    row['feature_importance'] = json.loads(row['feature_importance'])
            return jsonify(rows)
    except Exception as e:
        logger.error(f"Error getting models: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
def activate_model(id):
    """Make the given model the active one of its equipment type and family.

    /predict serves the active non-PLS model and /pls/predict the active PLS
    model of an equipment type, so both may be active at once. Activation
    therefore deactivates only the models of the same family (PLS vs.
    non-PLS). Previously, activating an ML model also switched off the PLS
    model and broke /pls/predict.

    Returns:
        200 with ``{'success': True}``; 404 when the model does not exist;
        500 on a database error.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT equipment_type, model_type FROM trained_models
                WHERE id = %s
            """, (id,))
            model = cursor.fetchone()
            if not model:
                return jsonify({'error': 'Model not found'}), 404
            equipment_type, model_type = model

            # Deactivate the other models in the same family only.
            if model_type == 'pls':
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND model_type = 'pls'
                """, (equipment_type,))
            else:
                # NULL-safe comparison so models with a NULL type count as non-PLS.
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND NOT (model_type <=> 'pls')
                """, (equipment_type,))

            cursor.execute("""
                UPDATE trained_models
                SET is_active = TRUE
                WHERE id = %s
            """, (id,))
            conn.commit()
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error activating model: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/models/<int:id>', methods=['DELETE'])
def delete_model(id):
    """Delete a trained model's database record and its files on disk.

    The record is deleted and committed first. The model and scaler files are
    then removed on a best-effort basis, so a failed commit never leaves a
    record pointing at files that are already gone. NULL or non-existent
    paths are skipped, and a failed removal is logged, not fatal.

    Returns:
        200 with ``{'success': True}``; 404 when the model does not exist;
        500 on a database error.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT model_path, scaler_path
                FROM trained_models
                WHERE id = %s
            """, (id,))
            model = cursor.fetchone()
            if not model:
                return jsonify({'error': 'Model not found'}), 404
            cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
            conn.commit()

        for path in model:  # (model_path, scaler_path)
            # os.path.exists(None) raises TypeError, hence the truthiness check.
            if not path or not os.path.exists(path):
                continue
            try:
                os.remove(path)
            except OSError as err:
                logger.warning(f"Could not remove model file {path}: {err}")
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error deleting model: {str(e)}")
        return jsonify({'error': str(e)}), 500
@api_bp.route('/predict/all', methods=['POST'])
def predict_all():
    """Return the predictions of every ML model for one piece of equipment.

    Request JSON: must contain ``type`` (mirroring /predict); the whole object
    is passed to ``CostPredictor.predict_all``.

    Returns:
        200 with the predictor's results;
        400 when the body is not a JSON object or lacks ``type``;
        500 on any other error.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        logger.info(f"Received prediction request for all models, equipment type: {data.get('type')}")
        if 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400
        predictor = CostPredictor()
        results = predictor.predict_all(data)
        return jsonify(results)
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500