1254 lines
47 KiB
Python
1254 lines
47 KiB
Python
from flask import Blueprint, request, jsonify, send_file
|
||
from .cost_prediction import CostPredictor
|
||
from .feature_analysis import FeatureAnalysis
|
||
import pandas as pd
|
||
from datetime import datetime
|
||
import numpy as np
|
||
import mysql.connector
|
||
from sklearn.metrics import mean_absolute_error
|
||
from .create_template import create_excel_template
|
||
import json
|
||
import os
|
||
from .data_preparation import DataPreparation
|
||
from .model_trainer import ModelTrainer
|
||
from .logger import setup_logger
|
||
|
||
# Create the blueprint
|
||
api_bp = Blueprint('api', __name__)
|
||
|
||
# Obtain the module-level logger
|
||
logger = setup_logger(__name__)
|
||
|
||
@api_bp.route('/', methods=['GET'])
def index():
    """API root route.

    Returns a JSON document with the API name, its version and a short
    catalogue of the advertised endpoints (URL, HTTP method, description).
    """
    # (endpoint key, HTTP method, description); the URL is always '/api/<key>'.
    advertised = (
        ('predict', 'POST', '成本预测'),
        ('analyze-features', 'POST', '特征分析'),
        ('train', 'POST', '模型训练'),
        ('evaluate', 'POST', '模型评估'),
    )

    endpoints = {}
    for key, method, description in advertised:
        endpoints[key] = {
            'url': f'/api/{key}',
            'method': method,
            'description': description,
        }

    payload = {
        'name': '装备成本估算系统 API',
        'version': '1.0.0',
        'endpoints': endpoints,
    }
    return jsonify(payload)
|
||
|
||
@api_bp.route('/predict', methods=['POST'])
def predict_cost():
    """Cost prediction endpoint.

    Expects a JSON object containing at least ``type`` (the equipment type).
    The whole object is passed to ``CostPredictor.predict``; the prediction
    result is then extended with a ``model_info`` entry describing the active
    non-PLS model registered for that equipment type.

    Returns:
        200 with the prediction result. ``model_info`` is ``None`` when no
        matching row exists in ``trained_models``; individual metrics are
        ``None`` when stored as NULL.
        400 when the body is not a JSON object or lacks ``type``.
        500 with the error text on any unexpected failure.
    """
    def as_float(value):
        """Convert a DB numeric to float, passing NULL through as None."""
        return None if value is None else float(value)

    try:
        # silent=True yields None for a missing/invalid body instead of raising,
        # so malformed requests get a 400 rather than falling into the 500 path.
        data = request.get_json(silent=True)

        # Validate the payload before anything dereferences it.
        if not isinstance(data, dict) or 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        logger.info(f"Received prediction request for equipment type: {data.get('type')}")

        # Predict the cost.
        predictor = CostPredictor()
        result = predictor.predict(data)

        # Look up the metadata of the model currently in use.
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT model_type, model_name, r2_score, mae, rmse
                FROM trained_models
                WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
                LIMIT 1
            """, (data['type'],))
            model_info = cursor.fetchone()

        # Attach the model metadata; a missing row must not discard a
        # prediction that has already been computed.
        if model_info is None:
            logger.warning(f"No active model metadata found for equipment type: {data['type']}")
            result.update({'model_info': None})
        else:
            result.update({
                'model_info': {
                    'type': model_info['model_type'],
                    'name': model_info['model_name'],
                    'r2_score': as_float(model_info['r2_score']),
                    'mae': as_float(model_info['mae']),
                    'rmse': as_float(model_info['rmse'])
                }
            })

        logger.info(f"Prediction completed: {result}")
        return jsonify(result)

    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/analyze-features', methods=['POST'])
def analyze_features():
    """Run feature analysis on the equipment belonging to one dataset.

    Expects a JSON object with ``dataset_id``. The handler loads the dataset,
    fetches its equipment rows that have an ``actual_cost``, drops features
    whose missing rate is 70% or higher, fills the remaining gaps with the
    per-feature mean and passes the result to ``FeatureAnalysis.analyze_features``.

    Returns:
        200 with the analysis result.
        400 when ``dataset_id`` is missing, the dataset has no cost data,
        fewer than 3 usable features remain, or a value is not numeric.
        404 when the dataset does not exist.
        500 with the error text on any unexpected failure.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            # Treat a missing/invalid body like an empty one so the request is
            # rejected by the dataset_id check below with a 400.
            data = {}
        dataset_id = data.get('dataset_id')

        logger.info(f"Starting feature analysis for dataset {dataset_id}")

        if not dataset_id:
            logger.warning("No dataset_id provided")
            return jsonify({'error': '请选择数据集'}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Load the dataset together with the type of one of its equipment rows.
            cursor.execute("""
                SELECT d.*,
                       e.type as equipment_type
                FROM datasets d
                JOIN dataset_equipment de ON d.id = de.dataset_id
                JOIN equipment e ON de.equipment_id = e.id
                WHERE d.id = %s
                LIMIT 1
            """, (dataset_id,))
            dataset = cursor.fetchone()

            if not dataset:
                logger.warning(f"Dataset {dataset_id} not found")
                return jsonify({'error': '数据集不存在'}), 404

            logger.info(f"Dataset info: {dataset}")

            # Use the module-level FeatureAnalysis import (same class the rest
            # of this module uses).
            analyzer = FeatureAnalysis()

            # Feature list for this equipment type.
            feature_names = analyzer.get_equipment_specific_features(dataset['equipment_type'])
            logger.info(f"Feature names: {feature_names}")

            # The type-specific parameter table is the only difference between
            # the two equipment queries. The name comes from this fixed pair,
            # never from request data.
            if dataset['equipment_type'] == '火箭炮':
                params_table = 'rocket_artillery_params'
            else:
                params_table = 'loitering_munition_params'

            cursor.execute(f"""
                SELECT e.*, cp.*, sp.*, cd.actual_cost
                FROM equipment e
                JOIN dataset_equipment de ON e.id = de.equipment_id
                LEFT JOIN common_params cp ON e.id = cp.equipment_id
                LEFT JOIN {params_table} sp ON e.id = sp.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE de.dataset_id = %s
                AND cd.actual_cost IS NOT NULL
            """, (dataset_id,))

            equipment_data = cursor.fetchall()
            logger.info(f"Found {len(equipment_data)} equipment records")

            if not equipment_data:
                logger.warning("No valid equipment data found in dataset")
                return jsonify({'error': '数据集没有有效的成本数据'}), 400

            # Missing rate of every feature.
            missing_rates = {}
            for name in feature_names:
                missing_count = sum(1 for item in equipment_data if item.get(name) is None)
                missing_rate = missing_count / len(equipment_data)
                missing_rates[name] = missing_rate
                logger.info(f"Feature {name} missing rate: {missing_rate:.2%}")

            # Drop features that are missing too often.
            valid_features = [name for name in feature_names if missing_rates[name] < 0.7]
            logger.info(f"Valid features after filtering: {valid_features}")

            if len(valid_features) < 3:  # at least 3 features are required
                return jsonify({'error': '有效特征数量不足'}), 400

            # Convert every cell exactly once (None marks a missing value).
            # Converting before computing the means lets a non-numeric value
            # produce the 400 below instead of crashing the mean computation.
            parsed_rows = []
            for item in equipment_data:
                row = []
                for name in valid_features:
                    value = item.get(name)
                    try:
                        row.append(None if value is None else float(value))
                    except (ValueError, TypeError) as e:
                        logger.error(f"Error converting value for feature {name}: {value}")
                        logger.error(f"Error details: {str(e)}")
                        return jsonify({'error': f'特征 {name} 的值 {value} 无法转换为数值'}), 400
                parsed_rows.append(row)

            # Mean of the present values of every feature.
            feature_means = {}
            for name, column in zip(valid_features, zip(*parsed_rows)):
                present = [value for value in column if value is not None]
                feature_means[name] = sum(present) / len(present) if present else 0
                logger.info(f"Feature {name} mean value: {feature_means[name]:.2f}")

            # Feature vectors with missing values replaced by the feature mean.
            features = [
                [feature_means[name] if value is None else value
                 for name, value in zip(valid_features, row)]
                for row in parsed_rows
            ]

            # Target values (the actual costs) as floats.
            target = []
            for item in equipment_data:
                try:
                    target.append(float(item['actual_cost']))
                except (ValueError, TypeError) as e:
                    logger.error(f"Error converting actual_cost: {item['actual_cost']}")
                    logger.error(f"Error details: {str(e)}")
                    return jsonify({'error': '成本值无法换为数值'}), 400

            logger.info(f"Prepared {len(features)} feature vectors")
            logger.info(f"First feature vector: {features[0] if features else None}")
            logger.info(f"First target value: {target[0] if target else None}")

            # Run the analysis.
            result = analyzer.analyze_features(features, target, valid_features)
            logger.info("Analysis completed successfully")

            return jsonify(result)

    except Exception as e:
        logger.error(f"Error analyzing features: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/train', methods=['POST'])
def train_model():
    """Train models on a dataset.

    Expects a JSON object with ``type`` (equipment type), ``train_dataset_id``,
    an optional ``validation_dataset_id`` and ``models`` (list of model
    identifiers, defaults to an empty list). Training and validation rows are
    loaded from the database, prepared with ``DataPreparation`` and handed to
    ``ModelTrainer.fit_model``.

    Returns:
        200 with the training result returned by ``fit_model``.
        400 when the body is not a JSON object or the training dataset
        yields no rows with an actual cost.
        500 with the error text on any unexpected failure.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        logger.info(f"Starting model training for {data.get('type')}")
        equipment_type = data.get('type')
        train_dataset_id = data.get('train_dataset_id')
        validation_dataset_id = data.get('validation_dataset_id')
        models = data.get('models', [])

        logger.info(f"Training dataset: {train_dataset_id}")
        logger.info(f"Validation dataset: {validation_dataset_id}")
        logger.info(f"Selected models: {models}")

        if train_dataset_id is None:
            # Comparing dataset_id with NULL can never match a row, so this is
            # the same answer the query would produce, without the round trip.
            return jsonify({'error': '训练数据集为空'}), 400

        # One query serves both datasets; only the type-specific parameter
        # table differs. The table name comes from this fixed pair, never
        # from request data.
        if equipment_type == '火箭炮':
            params_table = 'rocket_artillery_params'
        else:
            params_table = 'loitering_munition_params'
        dataset_query = f"""
            SELECT e.*, cp.*, sp.*, cd.actual_cost
            FROM equipment e
            JOIN dataset_equipment de ON e.id = de.equipment_id
            LEFT JOIN common_params cp ON e.id = cp.equipment_id
            LEFT JOIN {params_table} sp ON e.id = sp.equipment_id
            LEFT JOIN cost_data cd ON e.id = cd.equipment_id
            WHERE de.dataset_id = %s
            AND cd.actual_cost IS NOT NULL
        """

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Training rows.
            cursor.execute(dataset_query, (train_dataset_id,))
            train_data = cursor.fetchall()

            # Validation rows (optional).
            validation_data = None
            if validation_dataset_id:
                cursor.execute(dataset_query, (validation_dataset_id,))
                validation_data = cursor.fetchall()

        if not train_data:
            return jsonify({'error': '训练数据集为空'}), 400

        # 1. Prepare the data.
        data_processor = DataPreparation()
        train_prepared = data_processor.prepare_training_data(train_data, equipment_type)

        # The validation set reuses the feature names and scalers produced
        # when preparing the training set.
        validation_prepared = None
        if validation_data:
            validation_prepared = data_processor.prepare_validation_data(
                validation_data,
                equipment_type,
                train_prepared['feature_names'],
                {
                    'feature_scaler': train_prepared['feature_scaler'],
                    'target_scaler': train_prepared['target_scaler']
                }
            )

        # 2. Train the models.
        model_trainer = ModelTrainer()
        model_trainer.feature_scaler = train_prepared['feature_scaler']
        model_trainer.target_scaler = train_prepared['target_scaler']

        training_result = model_trainer.fit_model(
            train_prepared['X'],
            train_prepared['y'],
            models,
            validation_prepared['X'] if validation_prepared else None,
            validation_prepared['y'] if validation_prepared else None,
            equipment_type=equipment_type
        )

        return jsonify(training_result)

    except Exception as e:
        logger.error(f"Error in model training: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/evaluate', methods=['POST'])
def evaluate_model():
    """Model evaluation endpoint.

    Expects a JSON object of the form
    ``{"test_data": {"actual": ..., "predicted": ...}}`` and forwards both
    series to ``CostPredictor.evaluate``.

    Returns:
        200 with the evaluation result.
        400 when the body is not a JSON object, ``test_data`` is missing, or
        ``test_data`` lacks ``actual`` / ``predicted``.
        500 with the error text on any unexpected failure.
    """
    try:
        # silent=True yields None for a missing/invalid body instead of raising.
        data = request.get_json(silent=True)
        logger.info("Received model evaluation request")

        if not isinstance(data, dict) or 'test_data' not in data:
            return jsonify({'error': 'Test data is required'}), 400

        test_data = data['test_data']
        # Reject incomplete payloads as a client error instead of letting a
        # KeyError surface as a 500.
        if (not isinstance(test_data, dict)
                or 'actual' not in test_data
                or 'predicted' not in test_data):
            return jsonify({'error': "Test data must contain 'actual' and 'predicted'"}), 400

        predictor = CostPredictor()
        evaluation_result = predictor.evaluate(
            test_data['actual'],
            test_data['predicted']
        )

        logger.info("Model evaluation completed")
        return jsonify(evaluation_result)

    except Exception as e:
        logger.error(f"Error in model evaluation: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
def get_required_params(equipment_type):
    """Return the list of required parameter names for an equipment type.

    Every type shares the five common parameters; '火箭炮' and '巡飞弹'
    append their own type-specific parameters. Any other value gets the
    common parameters only. A new list is returned on every call.
    """
    shared = ['length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km']

    # Type-specific additions, checked in order with plain equality.
    extras_by_type = (
        ('火箭炮', [
            'firing_angle_horizontal',
            'firing_angle_vertical',
            'rocket_length_m',
            'rocket_diameter_mm',
            'rocket_weight_kg',
        ]),
        ('巡飞弹', [
            'max_speed_kmh',
            'cruise_speed_kmh',
            'flight_time_min',
            'folded_length_mm',
            'folded_width_mm',
            'folded_height_mm',
        ]),
    )

    for known_type, extras in extras_by_type:
        if equipment_type == known_type:
            return shared + extras
    return shared
|
||
|
||
@api_bp.errorhandler(404)
def not_found(error):
    """Blueprint handler for 404 errors: answer with a JSON error body."""
    body = jsonify({'error': 'Not found'})
    return body, 404
|
||
|
||
@api_bp.errorhandler(500)
def internal_error(error):
    """Blueprint handler for 500 errors: log the error, answer with JSON."""
    message = f"Internal server error: {str(error)}"
    logger.error(message)
    body = jsonify({'error': 'Internal server error'})
    return body, 500
|
||
|
||
@api_bp.route('/data', methods=['GET'])
def get_equipment_data():
    """Return all equipment records, grouped by equipment type.

    The response has two keys, ``rocket_artillery`` and
    ``loitering_munition``. Each row carries the basic equipment fields, the
    common parameters, the type-specific parameters, the actual cost and a
    ``custom_params`` value aggregated in SQL as a JSON array.

    NOTE(review): ``custom_params`` is passed through exactly as the driver
    returns it; it looks like that is JSON text rather than a parsed list —
    confirm against the client before changing it.

    Returns:
        200 with both lists, 500 with the error text on failure.
    """
    # Columns read from the type-specific parameter table of each type.
    rocket_columns = (
        'firing_angle_horizontal', 'firing_angle_vertical', 'rocket_length_m',
        'rocket_diameter_mm', 'rocket_weight_kg', 'rate_of_fire',
        'combat_weight_kg', 'speed_kmh', 'min_range_km', 'mobility_type',
        'structure_layout', 'engine_model', 'engine_params', 'power_hp',
        'travel_range_km',
    )
    loitering_columns = (
        'wingspan_m', 'warhead_weight_kg', 'max_speed_ms', 'cruise_speed_kmh',
        'flight_time_min', 'warhead_type', 'launch_mode', 'folded_length_mm',
        'folded_width_mm', 'folded_height_mm', 'power_system',
        'guidance_system',
    )

    def build_query(params_table, specific_columns):
        """Build the listing query for one equipment type.

        ``params_table`` and ``specific_columns`` are the constants defined
        above (never request data); the equipment type itself is bound as a
        query parameter.
        """
        specific_select = ',\n                '.join(
            f"sp.{column}" for column in specific_columns
        )
        return f"""
            SELECT
                e.id,
                e.name,
                e.type,
                e.manufacturer,
                e.created_at,
                cp.length_m,
                cp.width_m,
                cp.height_m,
                cp.weight_kg,
                cp.max_range_km,
                {specific_select},
                cd.actual_cost,
                (
                    SELECT COALESCE(
                        JSON_ARRAYAGG(
                            JSON_OBJECT(
                                'id', csp.id,
                                'param_name', csp.param_name,
                                'param_value', csp.param_value,
                                'param_unit', csp.param_unit,
                                'description', csp.description
                            )
                        ),
                        '[]'
                    )
                    FROM custom_params csp
                    WHERE csp.equipment_id = e.id
                    AND csp.param_name IS NOT NULL
                    AND csp.param_value IS NOT NULL
                ) as custom_params
            FROM equipment e
            LEFT JOIN common_params cp ON e.id = cp.equipment_id
            LEFT JOIN {params_table} sp ON e.id = sp.equipment_id
            LEFT JOIN cost_data cd ON e.id = cd.equipment_id
            WHERE e.type = %s
        """

    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute('SET SESSION group_concat_max_len = 1000000')

            # Rocket artillery.
            logger.info("Fetching rocket artillery data...")
            cursor.execute(
                build_query('rocket_artillery_params', rocket_columns),
                ('火箭炮',)
            )
            rocket_artillery = cursor.fetchall()
            logger.info(f"Found {len(rocket_artillery)} rocket artillery records")
            if rocket_artillery:
                logger.info(f"First rocket artillery: {rocket_artillery[0]['name']}")
                logger.info(f"First rocket custom_params: {rocket_artillery[0].get('custom_params')}")

            # Loitering munitions.
            logger.info("Fetching missile data...")
            cursor.execute(
                build_query('loitering_munition_params', loitering_columns),
                ('巡飞弹',)
            )
            loitering_munition = cursor.fetchall()
            logger.info(f"Found {len(loitering_munition)} missile records")
            if loitering_munition:
                logger.info(f"First missile: {loitering_munition[0]['name']}")
                logger.info(f"First missile custom_params: {loitering_munition[0].get('custom_params')}")

            # Make sure custom_params is never NULL in the response.
            for item in rocket_artillery + loitering_munition:
                if item['custom_params'] is None:
                    item['custom_params'] = []
                    logger.debug(f"Set empty custom_params for equipment {item['id']}")

            logger.info("Data fetching completed")
            return jsonify({
                'rocket_artillery': rocket_artillery,
                'loitering_munition': loitering_munition
            })

    except Exception as e:
        logger.error(f"Error getting equipment data: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/data/<int:id>', methods=['DELETE'])
def delete_equipment(id):
    """Delete one equipment record and every row that references it.

    All deletes run in one transaction: either everything is removed and
    committed, or the transaction is rolled back and a 500 is returned.
    Deleting an id that does not exist still reports success.

    Args:
        id: Primary key of the equipment row (taken from the URL).

    Returns:
        200 ``{'status': 'success'}`` or 500 with the error text.
    """
    # Tables holding rows keyed by equipment_id. They are cleared before the
    # equipment row itself so no dependent row is left behind.
    dependent_tables = (
        'dataset_equipment',
        'custom_params',
        'cost_data',
        'rocket_artillery_params',
        'loitering_munition_params',
        'common_params',
    )

    try:
        # The context manager closes the connection on the error path too.
        with get_db_connection() as conn:
            cursor = conn.cursor()
            try:
                for table in dependent_tables:
                    # Table names are the constants above; only the id is
                    # bound as a query parameter.
                    cursor.execute(f"DELETE FROM {table} WHERE equipment_id = %s", (id,))
                cursor.execute("DELETE FROM equipment WHERE id = %s", (id,))
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            finally:
                cursor.close()

        return jsonify({'status': 'success'})

    except Exception as e:
        logger.error(f"Error deleting equipment: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/data/template', methods=['GET'])
def download_template():
    """Generate the Excel data template and send it as a download.

    Returns the generated file as an ``.xlsx`` attachment, or a JSON error
    with status 500 when the template cannot be created or is missing.
    """
    xlsx_mimetype = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    try:
        # Build the template file and make sure it really exists on disk.
        template_path = create_excel_template()
        if os.path.exists(template_path):
            return send_file(
                template_path,
                as_attachment=True,
                download_name='equipment_data_template.xlsx',
                mimetype=xlsx_mimetype
            )
        raise FileNotFoundError("模板文件不存在")

    except Exception as e:
        logger.error(f"Error creating template: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
def get_db_connection():
    """Open a new MySQL connection.

    Connection settings are read from the environment variables ``DB_HOST``,
    ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME``. When a variable is not set
    the previous hard-coded value is used, so existing deployments keep
    working unchanged.

    Returns:
        The connection object returned by ``mysql.connector.connect``. The
        caller is responsible for closing it.
    """
    # NOTE(review): the fallback credentials are kept only for backward
    # compatibility; real deployments should set the environment variables.
    return mysql.connector.connect(
        host=os.environ.get("DB_HOST", "localhost"),
        user=os.environ.get("DB_USER", "root"),
        password=os.environ.get("DB_PASSWORD", "123456"),
        database=os.environ.get("DB_NAME", "equipment_cost_db")
    )
|
||
|
||
@api_bp.route('/pls/predict', methods=['POST'])
def pls_predict():
    """PLS regression prediction endpoint.

    Expects a JSON object containing ``type`` (the equipment type) plus the
    feature values. The active PLS model is loaded through ``ModelTrainer``,
    the features listed by ``FeatureAnalysis`` are read from the request, and
    the prediction is returned with a confidence interval and the metadata of
    the active PLS model.

    Returns:
        200 with ``predicted_cost``, ``model_info`` and ``confidence_interval``.
        ``model_info`` is ``None`` when no matching ``trained_models`` row
        exists; metrics are floats, or ``None`` when stored as NULL.
        400 when the body is not a JSON object or lacks ``type``.
        404 when no PLS model can be loaded.
        500 with the error text on any unexpected failure.
    """
    def as_float(value):
        """Convert a DB numeric to float, passing NULL through as None."""
        return None if value is None else float(value)

    try:
        # silent=True yields None for a missing/invalid body instead of raising.
        data = request.get_json(silent=True)

        # Validate the payload before anything dereferences it.
        if not isinstance(data, dict) or 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")

        # Load the PLS model through ModelTrainer.
        trainer = ModelTrainer()
        if not trainer.load_model(data['type'], model_type='pls'):
            return jsonify({'error': '未找到可用的模型'}), 404

        # Build the single-row feature matrix in the analyzer's feature order.
        feature_analyzer = FeatureAnalysis()
        features = feature_analyzer.get_equipment_specific_features(data['type'])
        X = np.array([[data.get(feature) for feature in features]])

        # Predict.
        result = trainer.predict(X)

        # Confidence interval around the prediction.
        confidence_interval = trainer._calculate_confidence_interval(result[0])

        # Metadata of the active PLS model.
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT model_type, model_name, r2_score, mae, rmse
                FROM trained_models
                WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
                LIMIT 1
            """, (data['type'],))
            model_info = cursor.fetchone()

        if model_info is None:
            logger.warning(f"No active PLS model metadata found for equipment type: {data['type']}")
            model_summary = None
        else:
            # Metrics are converted to float, as the other model endpoints do.
            model_summary = {
                'type': model_info['model_type'],
                'name': model_info['model_name'],
                'r2_score': as_float(model_info['r2_score']),
                'mae': as_float(model_info['mae']),
                'rmse': as_float(model_info['rmse'])
            }

        # Everything in the response is a plain JSON-serialisable value.
        response = {
            'predicted_cost': float(result[0]),
            'model_info': model_summary,
            'confidence_interval': {
                'lower': float(confidence_interval[0]),
                'upper': float(confidence_interval[1])
            }
        }

        logger.info(f"PLS prediction completed: {response}")
        return jsonify(response)

    except Exception as e:
        logger.error(f"Error in PLS prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/data/import', methods=['POST'])
def import_data():
    """Import equipment data from an uploaded Excel file.

    Expects a multipart upload with a ``file`` part whose name ends in
    ``.xls`` or ``.xlsx`` (case-insensitive). The file is saved into the
    ``data`` directory one level above this module's directory and then
    passed to ``import_training_data``.

    Returns:
        200 ``{'success': True, 'message': ...}`` on success.
        400 when no file was uploaded or it is not an Excel file.
        500 with the error text on any unexpected failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': '没有上传文件'}), 400

        file = request.files['file']

        # The filename is client-controlled. Keep only its last path
        # component (for both '/' and '\\' separators) so a crafted name such
        # as '../../x.xlsx' cannot write outside the upload directory.
        raw_name = file.filename or ''
        safe_name = os.path.basename(raw_name.replace('\\', '/'))

        if not safe_name.lower().endswith(('.xls', '.xlsx')):
            return jsonify({'error': '请上传Excel文件'}), 400

        # Save the uploaded file.
        upload_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        os.makedirs(upload_dir, exist_ok=True)
        file_path = os.path.join(upload_dir, safe_name)
        file.save(file_path)

        # Import the data.
        from .import_data import import_training_data
        import_training_data(file_path)

        return jsonify({
            'success': True,
            'message': '数据导入成功'
        })

    except Exception as e:
        logger.error(f"Error importing data: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/data/<int:id>', methods=['PUT'])
def update_equipment(id):
    """Update one equipment record.

    Expects a JSON object with the basic fields (``name``, ``manufacturer``,
    ``type``), the common parameters and the parameters specific to the
    equipment type ('火箭炮' uses the rocket-artillery fields, any other type
    the loitering-munition fields). ``actual_cost`` and ``custom_params`` (a
    list of objects with ``id`` and ``param_value``) are optional. All
    updates are committed together.

    Args:
        id: Primary key of the equipment row (taken from the URL).

    Returns:
        200 ``{'success': True}`` on success.
        400 when the body is not a JSON object or required fields are missing.
        500 with the error text on any unexpected failure.
    """
    base_fields = (
        'name', 'manufacturer', 'type',
        'length_m', 'width_m', 'height_m', 'weight_kg', 'max_range_km',
    )
    rocket_fields = (
        'firing_angle_horizontal', 'firing_angle_vertical',
        'rocket_length_m', 'rocket_diameter_mm',
        'rocket_weight_kg', 'rate_of_fire',
    )
    loitering_fields = (
        'max_speed_ms', 'cruise_speed_kmh', 'flight_time_min', 'warhead_type',
        'launch_mode', 'folded_length_mm', 'folded_width_mm', 'folded_height_mm',
    )

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        logger.info(f"Updating equipment ID: {id}")
        logger.info(f"Update data: {data}")

        # Check every field the statements below read, so an incomplete
        # request is reported as a client error before the database is touched.
        is_rocket = data.get('type') == '火箭炮'
        specific_fields = rocket_fields if is_rocket else loitering_fields
        missing = [field for field in base_fields + specific_fields if field not in data]
        if missing:
            return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Basic information.
            cursor.execute("""
                UPDATE equipment
                SET name = %s, manufacturer = %s
                WHERE id = %s
            """, (data['name'], data['manufacturer'], id))
            logger.info("Basic info updated")

            # Common parameters.
            cursor.execute("""
                UPDATE common_params
                SET length_m = %s, width_m = %s, height_m = %s,
                    weight_kg = %s, max_range_km = %s
                WHERE equipment_id = %s
            """, (
                data['length_m'], data['width_m'], data['height_m'],
                data['weight_kg'], data['max_range_km'], id
            ))
            logger.info("Common params updated")

            # Type-specific parameters.
            if is_rocket:
                cursor.execute("""
                    UPDATE rocket_artillery_params
                    SET firing_angle_horizontal = %s, firing_angle_vertical = %s,
                        rocket_length_m = %s, rocket_diameter_mm = %s,
                        rocket_weight_kg = %s, rate_of_fire = %s
                    WHERE equipment_id = %s
                """, (
                    data['firing_angle_horizontal'], data['firing_angle_vertical'],
                    data['rocket_length_m'], data['rocket_diameter_mm'],
                    data['rocket_weight_kg'], data['rate_of_fire'], id
                ))
                logger.info("Rocket artillery params updated")
            else:
                cursor.execute("""
                    UPDATE loitering_munition_params
                    SET max_speed_ms = %s, cruise_speed_kmh = %s,
                        flight_time_min = %s, warhead_type = %s,
                        launch_mode = %s, folded_length_mm = %s,
                        folded_width_mm = %s, folded_height_mm = %s
                    WHERE equipment_id = %s
                """, (
                    data['max_speed_ms'], data['cruise_speed_kmh'],
                    data['flight_time_min'], data['warhead_type'],
                    data['launch_mode'], data['folded_length_mm'],
                    data['folded_width_mm'], data['folded_height_mm'], id
                ))
                logger.info("Missile params updated")

            # Cost data (optional).
            if 'actual_cost' in data:
                cursor.execute("""
                    UPDATE cost_data
                    SET actual_cost = %s
                    WHERE equipment_id = %s
                """, (data['actual_cost'], id))
                logger.info("Cost data updated")

            # Custom parameters (optional); only the value of existing rows
            # that belong to this equipment is changed.
            if 'custom_params' in data and data['custom_params']:
                logger.info(f"Updating custom params: {data['custom_params']}")
                for param in data['custom_params']:
                    cursor.execute("""
                        UPDATE custom_params
                        SET param_value = %s
                        WHERE id = %s AND equipment_id = %s
                    """, (param['param_value'], param['id'], id))
                logger.info("Custom params updated")

            conn.commit()
            logger.info("All updates committed successfully")

        return jsonify({'success': True})

    except Exception as e:
        logger.error(f"Error updating equipment: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/data/details/<int:id>', methods=['GET'])
def get_equipment_details(id):
    """Return the full details of one equipment record.

    The equipment type is looked up first and decides which type-specific
    parameter table is joined ('火箭炮' uses the rocket-artillery table, any
    other type the loitering-munition table). The row also carries the cost
    fields and a ``custom_params`` value aggregated in SQL as a JSON array.

    Args:
        id: Primary key of the equipment row (taken from the URL).

    Returns:
        200 with the detail row, 404 when the equipment does not exist,
        500 with the error text on any unexpected failure.
    """
    try:
        logger.info(f"Getting details for equipment ID: {id}")

        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Look up the equipment type first.
            cursor.execute("SELECT type FROM equipment WHERE id = %s", (id,))
            equipment = cursor.fetchone()

            if not equipment:
                logger.warning(f"Equipment not found: {id}")
                return jsonify({'error': 'Equipment not found'}), 404

            equipment_type = equipment['type']
            logger.info(f"Equipment type: {equipment_type}")

            # The two detail queries differ only in the joined parameter
            # table. The name comes from this fixed pair, never from request
            # data.
            if equipment_type == '火箭炮':
                params_table = 'rocket_artillery_params'
            else:
                params_table = 'loitering_munition_params'

            query = f"""
                SELECT
                    e.*,
                    cp.*,
                    sp.*,
                    cd.actual_cost,
                    cd.prediction_date as cost_estimate_date,
                    cd.predicted_cost,
                    (
                        SELECT JSON_ARRAYAGG(
                            CASE
                                WHEN csp.id IS NOT NULL THEN
                                    JSON_OBJECT(
                                        'id', csp.id,
                                        'param_name', csp.param_name,
                                        'param_value', csp.param_value,
                                        'param_unit', csp.param_unit,
                                        'description', csp.description
                                    )
                            END
                        )
                        FROM custom_params csp
                        WHERE csp.equipment_id = e.id
                        AND csp.param_name IS NOT NULL
                        AND csp.param_value IS NOT NULL
                    ) as custom_params
                FROM equipment e
                LEFT JOIN common_params cp ON e.id = cp.equipment_id
                LEFT JOIN {params_table} sp ON e.id = sp.equipment_id
                LEFT JOIN cost_data cd ON e.id = cd.equipment_id
                WHERE e.id = %s
            """

            cursor.execute(query, (id,))
            result = cursor.fetchone()

            if result:
                logger.info(f"Found equipment details: {result['name']}")
                logger.info(f"Custom params: {result.get('custom_params')}")

            return jsonify(result)

    except Exception as e:
        logger.error(f"Error getting equipment details: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
# Dataset-related routes
|
||
@api_bp.route('/datasets', methods=['GET'])
def get_datasets():
    """Return the list of datasets.

    Every dataset row is extended with ``equipment_count`` and
    ``equipment_names`` (a list of the names of the linked equipment, empty
    when nothing is linked).

    The names are aggregated with ``JSON_ARRAYAGG`` rather than
    ``GROUP_CONCAT`` + ``split(',')``: a name containing a comma is no longer
    split in two, and the list is not cut off by ``group_concat_max_len``.

    Returns:
        200 with the list of datasets, 500 with the error text on failure.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT d.*,
                       COUNT(de.equipment_id) as equipment_count,
                       JSON_ARRAYAGG(e.name) as equipment_names
                FROM datasets d
                LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
                LEFT JOIN equipment e ON de.equipment_id = e.id
                GROUP BY d.id
            """)
            datasets = cursor.fetchall()

            # Turn the aggregated JSON into a plain list of names.
            for dataset in datasets:
                raw_names = dataset['equipment_names']
                if isinstance(raw_names, (str, bytes, bytearray)):
                    # presumably the driver returns JSON as text — decode it
                    raw_names = json.loads(raw_names)
                # A dataset without equipment aggregates to [null] because of
                # the LEFT JOINs; drop those placeholders.
                dataset['equipment_names'] = [
                    name for name in (raw_names or []) if name is not None
                ]

            return jsonify(datasets)
    except Exception as e:
        logger.error(f"Error getting datasets: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/datasets/<int:id>', methods=['GET'])
def get_dataset(id):
    """Return one dataset with its equipment and cost statistics.

    The response is the dataset row extended with ``equipment_count``,
    ``equipment`` (the linked equipment rows including ``actual_cost``) and
    ``statistics`` (equipment count, total cost and average cost; a missing
    cost counts as 0). Responds with 404 when the dataset does not exist and
    with 500 plus the error text on failure.
    """
    dataset_sql = """
        SELECT d.*,
               COUNT(de.equipment_id) as equipment_count
        FROM datasets d
        LEFT JOIN dataset_equipment de ON d.id = de.dataset_id
        WHERE d.id = %s
        GROUP BY d.id
    """
    members_sql = """
        SELECT e.*, cd.actual_cost
        FROM equipment e
        JOIN dataset_equipment de ON e.id = de.equipment_id
        LEFT JOIN cost_data cd ON e.id = cd.equipment_id
        WHERE de.dataset_id = %s
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor(dictionary=True)

            # Basic dataset information.
            cursor.execute(dataset_sql, (id,))
            dataset = cursor.fetchone()
            if not dataset:
                return jsonify({'error': 'Dataset not found'}), 404

            # Equipment linked to the dataset.
            cursor.execute(members_sql, (id,))
            equipment = cursor.fetchall()

            # Statistics: zeros by default, real figures when rows exist.
            statistics = {
                'equipment_count': 0,
                'total_cost': 0,
                'average_cost': 0
            }
            if equipment:
                member_count = len(equipment)
                total_cost = sum(row['actual_cost'] or 0 for row in equipment)
                statistics = {
                    'equipment_count': member_count,
                    'total_cost': total_cost,
                    'average_cost': total_cost / member_count
                }

            dataset['statistics'] = statistics
            dataset['equipment'] = equipment
            return jsonify(dataset)
    except Exception as e:
        logger.error(f"Error getting dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/datasets', methods=['POST'])
def create_dataset():
    """Create a dataset.

    Expects a JSON object with ``name``, ``description``, ``equipment_type``
    and ``purpose``; an optional ``equipment_ids`` list links equipment to
    the new dataset. The dataset row and its links are committed together.

    Returns:
        200 with the new dataset ``id`` and a success message.
        400 when the body is not a JSON object or required fields are missing.
        500 with the error text on any unexpected failure.
    """
    required_fields = ('name', 'description', 'equipment_type', 'purpose')

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        # Report incomplete requests as a client error before opening a
        # connection, instead of letting a KeyError surface as a 500.
        missing = [field for field in required_fields if field not in data]
        if missing:
            return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Create the dataset row.
            cursor.execute("""
                INSERT INTO datasets (name, description, equipment_type, purpose)
                VALUES (%s, %s, %s, %s)
            """, (data['name'], data['description'], data['equipment_type'], data['purpose']))

            dataset_id = cursor.lastrowid

            # Link the selected equipment, if any.
            equipment_ids = data.get('equipment_ids')
            if equipment_ids:
                values = [(dataset_id, equipment_id) for equipment_id in equipment_ids]
                cursor.executemany("""
                    INSERT INTO dataset_equipment (dataset_id, equipment_id)
                    VALUES (%s, %s)
                """, values)

            conn.commit()
            return jsonify({'id': dataset_id, 'message': '数据集创建成功'})
    except Exception as e:
        logger.error(f"Error creating dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/datasets/<int:id>', methods=['PUT'])
def update_dataset(id):
    """Update a dataset and replace its equipment links.

    Expects a JSON object with ``name``, ``description``, ``equipment_type``
    and ``purpose``. All existing equipment links of the dataset are removed
    and replaced by the optional ``equipment_ids`` list (a missing, null or
    empty list leaves the dataset without links). Everything is committed
    together.

    Args:
        id: Primary key of the dataset row (taken from the URL).

    Returns:
        200 ``{'success': True}`` on success.
        400 when the body is not a JSON object or required fields are missing.
        500 with the error text on any unexpected failure.
    """
    required_fields = ('name', 'description', 'equipment_type', 'purpose')

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        # Report incomplete requests as a client error before opening a
        # connection, instead of letting a KeyError surface as a 500.
        missing = [field for field in required_fields if field not in data]
        if missing:
            return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Update the dataset row.
            cursor.execute("""
                UPDATE datasets
                SET name = %s, description = %s, equipment_type = %s, purpose = %s
                WHERE id = %s
            """, (data['name'], data['description'], data['equipment_type'], data['purpose'], id))

            # Remove the old equipment links.
            cursor.execute("DELETE FROM dataset_equipment WHERE dataset_id = %s", (id,))

            # Insert the new links in one batch, as create_dataset does.
            equipment_ids = data.get('equipment_ids')
            if equipment_ids:
                values = [(id, equipment_id) for equipment_id in equipment_ids]
                cursor.executemany("""
                    INSERT INTO dataset_equipment (dataset_id, equipment_id)
                    VALUES (%s, %s)
                """, values)

            conn.commit()
            return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error updating dataset: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/datasets/<int:id>', methods=['DELETE'])
def delete_dataset(id):
    """
    Delete a dataset together with its equipment links.

    Args:
        id: Primary key of the dataset, taken from the URL.

    Returns:
        JSON ``{'success': True}`` once both deletes are committed, or
        JSON ``{'error': ...}`` with status 500 if anything raises.
    """
    # Link rows go first, then the dataset row itself.
    statements = (
        "DELETE FROM dataset_equipment WHERE dataset_id = %s",
        "DELETE FROM datasets WHERE id = %s",
    )
    try:
        with get_db_connection() as db:
            cur = db.cursor()
            for sql in statements:
                cur.execute(sql, (id,))
            db.commit()
            return jsonify({'success': True})
    except Exception as exc:
        logger.error(f"Error deleting dataset: {exc}")
        return jsonify({'error': str(exc)}), 500
|
||
|
||
@api_bp.route('/models/<equipment_type>/latest', methods=['GET'])
def get_latest_model(equipment_type):
    """
    Return the most recently trained active model for an equipment type.

    Args:
        equipment_type: Equipment type taken from the URL.

    Returns:
        The matching ``trained_models`` row serialized as JSON (whatever
        ``fetchone`` yields, including ``None`` when no row matches), or
        JSON ``{'error': ...}`` with status 500 if anything raises.
    """
    query = """
        SELECT * FROM trained_models
        WHERE equipment_type = %s AND is_active = TRUE
        ORDER BY training_date DESC LIMIT 1
    """
    try:
        with get_db_connection() as db:
            # Dictionary cursor so the row serializes as a JSON object.
            cur = db.cursor(dictionary=True)
            cur.execute(query, (equipment_type,))
            return jsonify(cur.fetchone())
    except Exception as exc:
        logger.error(f"Error getting latest model: {exc}")
        return jsonify({'error': str(exc)}), 500
|
||
|
||
@api_bp.route('/models', methods=['GET'])
def get_models():
    """
    List every trained model, newest first.

    For each row, ``r2_score``, ``mae`` and ``rmse`` are converted to
    ``float`` when they are not ``None``, and a truthy
    ``feature_importance`` value is decoded with ``json.loads``.

    Returns:
        JSON list of ``trained_models`` rows ordered by ``training_date``
        descending, or JSON ``{'error': ...}`` with status 500 if anything
        raises.
    """
    metric_columns = ('r2_score', 'mae', 'rmse')
    try:
        with get_db_connection() as db:
            cur = db.cursor(dictionary=True)
            cur.execute("""
                SELECT * FROM trained_models
                ORDER BY training_date DESC
            """)
            rows = cur.fetchall()

            for row in rows:
                # Make sure the numeric metrics are plain floats.
                for column in metric_columns:
                    value = row[column]
                    if value is not None:
                        row[column] = float(value)

                # Decode the stored feature-importance JSON text.
                raw_importance = row['feature_importance']
                if raw_importance:
                    row['feature_importance'] = json.loads(raw_importance)

            return jsonify(rows)
    except Exception as exc:
        logger.error(f"Error getting models: {exc}")
        return jsonify({'error': str(exc)}), 500
|
||
|
||
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
def activate_model(id):
    """
    Mark one trained model as the active model for its equipment type.

    Looks up the model's ``equipment_type``, clears ``is_active`` on every
    ``trained_models`` row with that equipment type, then sets ``is_active``
    on the requested row and commits.

    Args:
        id: Primary key of the model, taken from the URL.

    Returns:
        JSON ``{'success': True}`` on success,
        JSON ``{'error': 'Model not found'}`` with status 404 when no row has
        this id, or JSON ``{'error': ...}`` with status 500 if anything raises.
    """
    lookup_sql = """
        SELECT equipment_type FROM trained_models
        WHERE id = %s
    """
    deactivate_sql = """
        UPDATE trained_models
        SET is_active = FALSE
        WHERE equipment_type = %s
    """
    activate_sql = """
        UPDATE trained_models
        SET is_active = TRUE
        WHERE id = %s
    """
    try:
        with get_db_connection() as db:
            cur = db.cursor()

            # Find out which equipment type the model belongs to.
            cur.execute(lookup_sql, (id,))
            row = cur.fetchone()
            if not row:
                return jsonify({'error': 'Model not found'}), 404

            # Deactivate every model of that type, then activate the target.
            equipment_type = row[0]
            cur.execute(deactivate_sql, (equipment_type,))
            cur.execute(activate_sql, (id,))

            db.commit()
            return jsonify({'success': True})
    except Exception as exc:
        logger.error(f"Error activating model: {exc}")
        return jsonify({'error': str(exc)}), 500
|
||
|
||
def _remove_model_file(path):
    """Best-effort removal of one model artifact; logs instead of raising."""
    if not path:
        # NULL/empty path stored for this model: nothing to remove.
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already gone, so there is nothing left to clean up.
        pass
    except OSError as e:
        logger.error(f"Error removing model file {path}: {str(e)}")


@api_bp.route('/models/<int:id>', methods=['DELETE'])
def delete_model(id):
    """
    Delete a trained model record and its files on disk.

    The database row is deleted and committed first; the model and scaler
    files are removed afterwards. A failed database delete therefore never
    leaves a record that points at files which no longer exist. File removal
    is best-effort: missing files and NULL paths are ignored, and other OS
    errors are logged without failing the request.

    Args:
        id: Primary key of the model, taken from the URL.

    Returns:
        JSON ``{'success': True}`` once the record is deleted,
        JSON ``{'error': 'Model not found'}`` with status 404 when no row has
        this id, or JSON ``{'error': ...}`` with status 500 if the database
        work raises.
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Fetch the artifact paths before the row disappears.
            cursor.execute("""
                SELECT model_path, scaler_path
                FROM trained_models
                WHERE id = %s
            """, (id,))
            model = cursor.fetchone()

            if not model:
                return jsonify({'error': 'Model not found'}), 404

            model_path, scaler_path = model

            # Delete the database record.
            cursor.execute("DELETE FROM trained_models WHERE id = %s", (id,))
            conn.commit()

        # Remove the files only after the delete has been committed.
        for path in (model_path, scaler_path):
            _remove_model_file(path)

        return jsonify({'success': True})

    except Exception as e:
        logger.error(f"Error deleting model: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
||
|
||
@api_bp.route('/predict/all', methods=['POST'])
def predict_all():
    """
    Return the prediction results of all machine-learning models.

    The request body must be a JSON object that includes the equipment
    ``type``; it is passed unchanged to ``CostPredictor.predict_all``.

    Returns:
        The value returned by ``CostPredictor.predict_all`` serialized as JSON.
        JSON ``{'error': ...}`` with status 400 when the body is not a JSON
        object or has no ``type`` key.
        JSON ``{'error': ...}`` with status 500 for any other failure.
    """
    try:
        # silent=True turns a missing/invalid JSON body into None so it can be
        # rejected as a client error rather than crashing on data.get().
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        # Same requirement and message as the single-model /predict endpoint.
        if 'type' not in data:
            return jsonify({'error': 'Equipment type is required'}), 400

        logger.info(f"Received prediction request for all models, equipment type: {data.get('type')}")

        predictor = CostPredictor()
        results = predictor.predict_all(data)

        return jsonify(results)

    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500