更新预测模块,解决模型字符集问题
This commit is contained in:
parent
dba9f2fcc9
commit
e67da8eaed
@ -359,7 +359,8 @@ const formatDateTime = (value) => {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
hour12: false,
|
||||
timeZone: 'Asia/Shanghai'
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@
|
||||
{{ scope.row.rmse.toFixed(2) }}
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="training_date" label="训练时间">
|
||||
<el-table-column prop="training_date" label="训练时间" width="180">
|
||||
<template #default="scope">
|
||||
{{ formatDateTime(scope.row.training_date) }}
|
||||
</template>
|
||||
@ -211,10 +211,12 @@ const renderImportanceChart = () => {
|
||||
// Display labels keyed by model-type identifier.
// A Map is used so that only the keys listed here can match; a plain object
// literal would also resolve inherited names such as 'constructor' or
// 'toString' to Object.prototype members instead of falling back to the input.
const MODEL_TYPE_LABELS = new Map([
  ['pytorch', 'PyTorch'],
  ['xgboost', 'XGBoost'],
  ['lightgbm', 'LightGBM'],
  ['gbm', 'GBM'],
  // NOTE(review): 'gbdt' is the identifier that 'gbm' replaced; kept so that
  // values using the old identifier still get a label. Confirm it is still needed.
  ['gbdt', 'GBDT'],
  ['rf', 'Random Forest'],
  ['pls', 'PLS回归']
])

/**
 * Convert a model-type identifier into the label shown in the UI.
 *
 * @param {string} type - Model-type identifier, e.g. 'xgboost'.
 * @returns {string} The display label, or `type` itself when the identifier
 *   is not one of the known model types.
 */
const formatModelType = (type) => {
  return MODEL_TYPE_LABELS.get(type) ?? type
}
|
||||
@ -230,7 +232,8 @@ const formatDateTime = (value) => {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
hour12: false,
|
||||
timeZone: 'Asia/Shanghai'
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -223,72 +223,46 @@ import { API_BASE_URL } from '@/config'
|
||||
|
||||
const formData = reactive({
|
||||
type: '',
|
||||
length_m: null,
|
||||
width_m: null,
|
||||
height_m: null,
|
||||
weight_kg: null,
|
||||
max_range_km: null,
|
||||
// 火箭炮特有参数
|
||||
firing_angle_horizontal: null,
|
||||
firing_angle_vertical: null,
|
||||
rocket_length_m: null,
|
||||
rocket_diameter_mm: null,
|
||||
rocket_weight_kg: null,
|
||||
rate_of_fire: null,
|
||||
combat_weight_kg: null,
|
||||
speed_kmh: null,
|
||||
min_range_km: null,
|
||||
mobility_type: '',
|
||||
structure_layout: '',
|
||||
engine_model: '',
|
||||
engine_params: '',
|
||||
power_hp: null,
|
||||
travel_range_km: null,
|
||||
// 巡飞弹特有参数 - 补充
|
||||
max_payload_kg: null, // 最大载荷
|
||||
ceiling_altitude_m: null, // 升限
|
||||
combat_radius_km: null, // 作战半径
|
||||
engine_power_kw: null, // 发动机功率
|
||||
engine_thrust_n: null, // 发动机推力
|
||||
datalink_range_km: null, // 通信链路距离
|
||||
guidance_accuracy_m: null, // 制导精度
|
||||
min_altitude_m: null, // 最小作战高度
|
||||
max_altitude_m: null, // 最大作战高度
|
||||
|
||||
// 特征工程参数
|
||||
length_width_ratio: null, // 长宽比
|
||||
weight_range_ratio: null, // 重量/射程比
|
||||
speed_weight_ratio: null, // 速度/重量比
|
||||
guidance_system_score: null, // 制导系统复杂度评分
|
||||
warhead_power_score: null // 战斗部威力评分
|
||||
length_m: 7.35,
|
||||
width_m: 2.4,
|
||||
height_m: 3.1,
|
||||
weight_kg: 13700,
|
||||
max_range_km: 20.4,
|
||||
firing_angle_horizontal: 102,
|
||||
firing_angle_vertical: 55,
|
||||
rocket_length_m: 2.87,
|
||||
rocket_diameter_mm: 122,
|
||||
rocket_weight_kg: 66.6,
|
||||
rate_of_fire: 40,
|
||||
combat_weight_kg: 15000,
|
||||
speed_kmh: 60,
|
||||
min_range_km: 5,
|
||||
mobility_type: '轮式',
|
||||
structure_layout: '6x6轮式底盘',
|
||||
engine_model: 'WD615',
|
||||
engine_params: '6缸直列柴油机',
|
||||
power_hp: 280,
|
||||
travel_range_km: 600,
|
||||
wingspan_m: 2.5,
|
||||
warhead_weight_kg: 20,
|
||||
max_speed_ms: 200,
|
||||
cruise_speed_kmh: 720,
|
||||
endurance_min: 30,
|
||||
warhead_type: '破片杀伤战斗部',
|
||||
launch_mode: '箱式发射',
|
||||
power_system: '电动机',
|
||||
guidance_system: 'GPS/INS/光电',
|
||||
max_payload_kg: 25,
|
||||
ceiling_altitude_m: 5000,
|
||||
combat_radius_km: 100,
|
||||
datalink_range_km: 150,
|
||||
guidance_accuracy_m: 3
|
||||
})
|
||||
|
||||
const predictionResults = ref(null)
|
||||
const mlPrediction = ref(null)
|
||||
const plsPrediction = ref(null)
|
||||
|
||||
const handleTypeChange = () => {
|
||||
// 重置特有参数
|
||||
if (formData.type === '火箭炮') {
|
||||
formData.firing_angle_horizontal = null
|
||||
formData.firing_angle_vertical = null
|
||||
formData.rocket_length_m = null
|
||||
formData.rocket_diameter_mm = null
|
||||
formData.rocket_weight_kg = null
|
||||
formData.rate_of_fire = null
|
||||
} else if (formData.type === '巡飞弹') {
|
||||
formData.wingspan_m = null
|
||||
formData.warhead_weight_kg = null
|
||||
formData.max_speed_ms = null
|
||||
formData.cruise_speed_kmh = null
|
||||
formData.endurance_min = null
|
||||
formData.warhead_type = ''
|
||||
formData.launch_mode = ''
|
||||
formData.power_system = ''
|
||||
formData.guidance_system = ''
|
||||
}
|
||||
}
|
||||
|
||||
const submitForm = async () => {
|
||||
try {
|
||||
// 验证必填字段
|
||||
@ -327,7 +301,7 @@ const submitForm = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取预测结果
|
||||
// 同时调用两个预测接口
|
||||
const [mlResponse, plsResponse] = await Promise.all([
|
||||
axios.post(`${API_BASE_URL}/predict`, formData),
|
||||
axios.post(`${API_BASE_URL}/pls/predict`, formData)
|
||||
@ -396,9 +370,96 @@ const getModelName = (modelType) => {
|
||||
}
|
||||
return modelNames[modelType] || modelType
|
||||
}
|
||||
|
||||
// Form values filled in when the equipment type is switched to '火箭炮'.
const ROCKET_LAUNCHER_DEFAULTS = Object.freeze({
  length_m: 7.35,
  width_m: 2.4,
  height_m: 3.1,
  weight_kg: 13700,
  max_range_km: 20.4,
  firing_angle_horizontal: 102,
  firing_angle_vertical: 55,
  rocket_length_m: 2.87,
  rocket_diameter_mm: 122,
  rocket_weight_kg: 66.6,
  rate_of_fire: 40,
  combat_weight_kg: 15000,
  speed_kmh: 60,
  min_range_km: 5,
  mobility_type: '轮式',
  structure_layout: '6x6轮式底盘',
  engine_model: 'WD615',
  engine_params: '6缸直列柴油机',
  power_hp: 280,
  travel_range_km: 600
})

// Form values filled in when the equipment type is switched to '巡飞弹'.
const LOITERING_MUNITION_DEFAULTS = Object.freeze({
  length_m: 2.5,
  width_m: 0.4,
  height_m: 0.4,
  weight_kg: 120,
  max_range_km: 100,
  wingspan_m: 2.5,
  warhead_weight_kg: 20,
  max_speed_ms: 200,
  cruise_speed_kmh: 720,
  endurance_min: 30,
  warhead_type: '破片杀伤战斗部',
  launch_mode: '箱式发射',
  power_system: '电动机',
  guidance_system: 'GPS/INS/光电',
  max_payload_kg: 25,
  ceiling_altitude_m: 5000,
  combat_radius_km: 100,
  datalink_range_km: 150,
  guidance_accuracy_m: 3
})

// '火箭炮'-only fields, blanked when '巡飞弹' is selected
// (numeric fields become null, text fields become '').
const ROCKET_LAUNCHER_CLEARED = Object.freeze({
  firing_angle_horizontal: null,
  firing_angle_vertical: null,
  rocket_length_m: null,
  rocket_diameter_mm: null,
  rocket_weight_kg: null,
  rate_of_fire: null,
  combat_weight_kg: null,
  speed_kmh: null,
  min_range_km: null,
  mobility_type: '',
  structure_layout: '',
  engine_model: '',
  engine_params: '',
  power_hp: null,
  travel_range_km: null
})

// '巡飞弹'-only fields, blanked when '火箭炮' is selected.
const LOITERING_MUNITION_CLEARED = Object.freeze({
  wingspan_m: null,
  warhead_weight_kg: null,
  max_speed_ms: null,
  cruise_speed_kmh: null,
  endurance_min: null,
  warhead_type: '',
  launch_mode: '',
  power_system: '',
  guidance_system: '',
  max_payload_kg: null,
  ceiling_altitude_m: null,
  combat_radius_km: null,
  datalink_range_km: null,
  guidance_accuracy_m: null
})

/**
 * React to a change of `formData.type`.
 *
 * Discards any prediction currently held in `predictionResults`,
 * `mlPrediction` and `plsPrediction`, then loads the default parameters of
 * the selected equipment type into `formData` and blanks the fields that
 * belong only to the other type. Any other `type` value leaves the form
 * fields untouched.
 */
const handleTypeChange = () => {
  // All three refs go back to null, the value `predictionResults` is created
  // with, so "no result yet" has a single representation (previously
  // `predictionResults` was reset to `false`).
  predictionResults.value = null
  mlPrediction.value = null
  plsPrediction.value = null

  // Object.assign writes each key through `formData`, exactly like the
  // individual assignments it replaces.
  if (formData.type === '火箭炮') {
    Object.assign(formData, ROCKET_LAUNCHER_DEFAULTS, LOITERING_MUNITION_CLEARED)
  } else if (formData.type === '巡飞弹') {
    Object.assign(formData, LOITERING_MUNITION_DEFAULTS, ROCKET_LAUNCHER_CLEARED)
  }
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
<style lang="scss" scoped>
|
||||
.predict-page {
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
<template>
|
||||
<div class="training-page">
|
||||
<!-- 上部分:模型训练区域 -->
|
||||
<el-card class="training-card">
|
||||
<template #header>
|
||||
<h2>模型训练</h2>
|
||||
@ -38,11 +39,12 @@
|
||||
|
||||
<el-form-item label="选择模型">
|
||||
<el-checkbox-group v-model="trainingConfig.models">
|
||||
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
|
||||
<el-checkbox value="pytorch" checked>PyTorch</el-checkbox>
|
||||
<el-checkbox value="xgboost" checked>XGBoost</el-checkbox>
|
||||
<el-checkbox value="lightgbm" checked>LightGBM</el-checkbox>
|
||||
<el-checkbox value="gbm" checked>GBM</el-checkbox>
|
||||
<el-checkbox value="rf" checked>Random Forest</el-checkbox>
|
||||
<el-checkbox value="pls" disabled checked>PLS回归</el-checkbox>
|
||||
</el-checkbox-group>
|
||||
</el-form-item>
|
||||
|
||||
@ -130,6 +132,159 @@
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
|
||||
<!-- 下部分:模型简介区域 -->
|
||||
<el-card class="model-intro-card">
|
||||
<template #header>
|
||||
<h2>模型简介</h2>
|
||||
</template>
|
||||
|
||||
<el-collapse>
|
||||
<el-collapse-item name="pytorch">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">PyTorch</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>深度学习框架,可以构建复杂的神经网络结构</li>
|
||||
<li>分别处理装备特征和生产商特征,然后合并进行预测</li>
|
||||
<li>使用批量归一化和Dropout防止过拟合</li>
|
||||
<li>适合处理非线性关系和复杂特征交互</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>强大的特征学习能力</li>
|
||||
<li>可以自动学习特征之间的复杂关系</li>
|
||||
<li>灵活的网络结构设计</li>
|
||||
<li>支持GPU加速训练</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="xgboost">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">XGBoost</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>基于梯度提升树的集成学习算法</li>
|
||||
<li>使用二阶导数进行优化</li>
|
||||
<li>内置正则化机制防止过拟合</li>
|
||||
<li>支持特征重要性评估</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>优秀的预测性能</li>
|
||||
<li>处理缺失值的能力强</li>
|
||||
<li>训练速度快</li>
|
||||
<li>可解释性好</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="lightgbm">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">LightGBM</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>微软开发的轻量级梯度提升框架</li>
|
||||
<li>使用直方图算法优化训练速度</li>
|
||||
<li>支持类别特征的高效处理</li>
|
||||
<li>叶子优先的生长策略</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>训练速度非常快</li>
|
||||
<li>内存占用低</li>
|
||||
<li>支持大规模数据训练</li>
|
||||
<li>准确率高</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="gbm">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">Gradient Boosting (GBM)</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>经典的梯度提升算法</li>
|
||||
<li>逐步减少残差的思想</li>
|
||||
<li>可以使用不同的损失函数</li>
|
||||
<li>支持特征重要性分析</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>稳定的性能表现</li>
|
||||
<li>较好的可解释性</li>
|
||||
<li>对异常值不敏感</li>
|
||||
<li>适合各种回归问题</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="rf">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">Random Forest</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>基于决策树的集成学习方法</li>
|
||||
<li>使用随机采样和特征选择</li>
|
||||
<li>多个决策树投票或平均</li>
|
||||
<li>自带特征重要性评估</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>不易过拟合</li>
|
||||
<li>训练过程可并行化</li>
|
||||
<li>对噪声数据鲁棒</li>
|
||||
<li>较少的参数调整</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
|
||||
<el-collapse-item name="pls">
|
||||
<template #title>
|
||||
<span class="model-title">
|
||||
<el-link type="primary" :underline="false">PLS回归</el-link>
|
||||
</span>
|
||||
</template>
|
||||
<div class="model-intro">
|
||||
<h4>特点:</h4>
|
||||
<ul>
|
||||
<li>偏最小二乘回归</li>
|
||||
<li>同时考虑自变量和因变量的变异</li>
|
||||
<li>处理多重共线性问题</li>
|
||||
<li>降维和回归的结合</li>
|
||||
</ul>
|
||||
<h4>优势:</h4>
|
||||
<ul>
|
||||
<li>适合小样本数据</li>
|
||||
<li>处理变量间相关性强的数据</li>
|
||||
<li>计算效率高</li>
|
||||
<li>结果稳定可靠</li>
|
||||
</ul>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
</el-collapse>
|
||||
</el-card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@ -144,7 +299,7 @@ const trainingConfig = ref({
|
||||
type: '',
|
||||
train_dataset_id: null,
|
||||
validation_dataset_id: null,
|
||||
models: ['xgboost', 'lightgbm', 'gbm', 'rf']
|
||||
models: ['pytorch', 'xgboost', 'lightgbm', 'gbm', 'rf']
|
||||
})
|
||||
|
||||
// 数据集列表
|
||||
@ -241,10 +396,12 @@ const formatNumber = (value) => {
|
||||
// Display names keyed by model-type identifier.
// A Map only matches the keys listed here; a plain object literal would also
// resolve inherited names such as 'constructor' to Object.prototype members.
const MODEL_DISPLAY_NAMES = new Map([
  ['pytorch', 'PyTorch'],
  ['xgboost', 'XGBoost'],
  ['lightgbm', 'LightGBM'],
  ['gbm', 'GBM'],
  ['rf', 'Random Forest'],
  ['pls', 'PLS回归']
])

/**
 * Look up the display name of a model type.
 *
 * @param {string} modelType - Model-type identifier, e.g. 'lightgbm'.
 * @returns {string} The display name, or `modelType` itself when the
 *   identifier is not one of the known model types.
 */
const getModelName = (modelType) => {
  return MODEL_DISPLAY_NAMES.get(modelType) ?? modelType
}
|
||||
@ -334,8 +491,13 @@ onMounted(() => {
|
||||
<style lang="scss" scoped>
|
||||
.training-page {
|
||||
padding: 20px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
|
||||
.training-card {
|
||||
width: 100%;
|
||||
|
||||
.training-result {
|
||||
margin-top: 20px;
|
||||
|
||||
@ -366,5 +528,62 @@ onMounted(() => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.model-intro-card {
|
||||
width: 100%;
|
||||
|
||||
.model-title {
|
||||
.el-link {
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
|
||||
&:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
&:active {
|
||||
opacity: 0.6;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.model-intro {
|
||||
padding: 10px;
|
||||
|
||||
h4 {
|
||||
margin: 10px 0;
|
||||
color: #409EFF;
|
||||
font-size: 15px;
|
||||
}
|
||||
|
||||
ul {
|
||||
padding-left: 20px;
|
||||
margin: 5px 0;
|
||||
|
||||
li {
|
||||
line-height: 1.8;
|
||||
color: #606266;
|
||||
font-size: 14px;
|
||||
|
||||
&:hover {
|
||||
color: #409EFF;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.el-collapse-item__header) {
|
||||
padding: 12px 0;
|
||||
font-size: 16px;
|
||||
|
||||
&:hover {
|
||||
background-color: #f5f7fa;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.el-collapse-item__content) {
|
||||
padding: 10px 20px;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@ -26,6 +26,8 @@ dependencies = [
|
||||
"torch==2.5.1",
|
||||
"torchvision==0.20.1",
|
||||
"torchaudio==2.5.1",
|
||||
"xgboost>=2.1.0", # 添加 XGBoost
|
||||
"lightgbm>=4.5.0", # 添加 LightGBM
|
||||
|
||||
# 工具
|
||||
"openpyxl>=3.1.5", # Excel支持
|
||||
|
||||
@ -6,6 +6,8 @@ cryptography>=43.0.0 # MySQL 8.0+ 认证需要
|
||||
mysql-connector-python>=8.0.0 # MySQL 官方 Python 驱动
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
xgboost>=2.1.0
|
||||
lightgbm>=4.5.0
|
||||
|
||||
scikit-learn>=1.5.2
|
||||
|
||||
|
||||
@ -81,50 +81,64 @@ class CostPredictor:
|
||||
self.model = DefaultModel(X.shape[1]).to(self.device)
|
||||
self.equipment_type = '火箭炮'
|
||||
|
||||
def predict(self, data):
|
||||
"""
|
||||
使用训练好的模型进行预测
|
||||
"""
|
||||
def predict(self, data, model_record):
|
||||
"""使用训练好的模型进行预测"""
|
||||
try:
|
||||
logger.info(f"Starting prediction for {data.get('type')}")
|
||||
logger.info(f"Starting prediction for {data.get('type')} using {model_record['model_type']}")
|
||||
equipment_type = data.get('type')
|
||||
|
||||
# 加载已训练的最优模型
|
||||
trainer = ModelTrainer()
|
||||
if not trainer.load_model(equipment_type):
|
||||
raise ValueError(f"No trained model found for {equipment_type}")
|
||||
# 使用ModelTrainer加载模型
|
||||
model_trainer = ModelTrainer()
|
||||
success = model_trainer.load_model(equipment_type, model_record['model_type'])
|
||||
if not success:
|
||||
raise ValueError(f"Failed to load model for {equipment_type}")
|
||||
|
||||
# 从ModelTrainer获取模型和标准化器
|
||||
model = model_trainer.model
|
||||
feature_scaler = model_trainer.feature_scaler
|
||||
target_scaler = model_trainer.target_scaler
|
||||
|
||||
# 准备特征数据
|
||||
features = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
feature_analyzer = FeatureAnalysis()
|
||||
features = feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
X = []
|
||||
for feature in features:
|
||||
value = data.get(feature, 0.0)
|
||||
X.append(float(value))
|
||||
|
||||
# 转换为 tensor
|
||||
X = torch.tensor([X], dtype=torch.float32).to(self.device)
|
||||
# 转换为numpy数组并标准化
|
||||
X = np.array([X])
|
||||
X_scaled = feature_scaler.transform(X)
|
||||
|
||||
# 预测
|
||||
with torch.no_grad():
|
||||
trainer.model.eval() # 设置为评估模式
|
||||
y_pred = trainer.model(X)
|
||||
# 根据模型类型进行预测
|
||||
if isinstance(model, torch.nn.Module):
|
||||
# PyTorch模型预测
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
X_tensor = torch.FloatTensor(X_scaled).to(self.device)
|
||||
y_pred = model(X_tensor)
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
elif model_record['model_type'] == 'pls':
|
||||
# PLS模型预测
|
||||
y_pred = model.predict(X_scaled).reshape(-1, 1)
|
||||
else:
|
||||
# 其他sklearn模型预测
|
||||
y_pred = model.predict(X_scaled).reshape(-1, 1)
|
||||
|
||||
# 转回 numpy
|
||||
y_pred = y_pred.cpu().numpy()
|
||||
# 转换回原始尺度并确保为正数
|
||||
y_pred_original = target_scaler.inverse_transform(y_pred)
|
||||
predicted_cost = abs(float(y_pred_original[0][0])) # 确保预测值为正数
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = self._calculate_confidence_interval(y_pred[0])
|
||||
|
||||
# 获取模型类型
|
||||
model_type = trainer.get_model_type()
|
||||
std = predicted_cost * 0.2 # 使用预测值的20%作为标准差
|
||||
confidence_interval = {
|
||||
'lower': max(predicted_cost - std, predicted_cost * 0.5), # 至少是预测值的50%
|
||||
'upper': predicted_cost + std
|
||||
}
|
||||
|
||||
return {
|
||||
'predicted_cost': float(y_pred[0]),
|
||||
'model_type': model_type,
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
'predicted_cost': predicted_cost,
|
||||
'confidence_interval': confidence_interval
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -30,34 +30,11 @@ class DataPreparation:
|
||||
self.target_scaler = StandardScaler()
|
||||
|
||||
def prepare_training_data(self, equipment_data, equipment_type, batch_size=32):
|
||||
"""
|
||||
准备训练数据
|
||||
"""
|
||||
"""准备训练数据"""
|
||||
try:
|
||||
logger.info(f"Preparing training data for {equipment_type}")
|
||||
logger.info(f"Raw data size: {len(equipment_data)}")
|
||||
|
||||
# 如果输入已经是 numpy 数组,转换为 torch.Tensor
|
||||
if isinstance(equipment_data, np.ndarray):
|
||||
X = equipment_data
|
||||
logger.info(f"Input is numpy array with shape: {X.shape}")
|
||||
|
||||
# 处理无效值
|
||||
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
# 转换为 PyTorch 数据集
|
||||
dataset = EquipmentDataset(X)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
return {
|
||||
'dataloader': dataloader,
|
||||
'feature_names': self.feature_analyzer.get_equipment_specific_features(equipment_type),
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler,
|
||||
'raw_shape': X.shape
|
||||
}
|
||||
|
||||
# 从原始数据中提取特征和目标值
|
||||
# 获取特征名称(包含生产商特征)
|
||||
feature_names = self.feature_analyzer.get_equipment_specific_features(equipment_type)
|
||||
features = []
|
||||
targets = []
|
||||
@ -66,18 +43,24 @@ class DataPreparation:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 获取所有生产商数据,用于计算特征
|
||||
cursor.execute("""
|
||||
SELECT * FROM manufacturers
|
||||
""")
|
||||
manufacturers = {row['id']: row for row in cursor.fetchall()}
|
||||
|
||||
for item in equipment_data:
|
||||
# 获取该装备的生产商数据
|
||||
manufacturer_data = self._get_manufacturer_data(item['manufacturer'], cursor)
|
||||
# 获取生产商数据
|
||||
manufacturer = manufacturers.get(item['manufacturer_id'], {})
|
||||
|
||||
# 计算生产商特征
|
||||
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer_data)
|
||||
manufacturer_features = self.feature_analyzer.calculate_manufacturer_features(manufacturer)
|
||||
|
||||
# 合并装备特征和生产商特征
|
||||
feature_values = []
|
||||
for name in feature_names:
|
||||
if name in manufacturer_features:
|
||||
value = manufacturer_features[name]
|
||||
if name.startswith('manufacturer_'):
|
||||
value = manufacturer_features.get(name, 0.0)
|
||||
else:
|
||||
value = item.get(name)
|
||||
feature_values.append(float(value) if value is not None else 0.0)
|
||||
@ -89,7 +72,7 @@ class DataPreparation:
|
||||
X = np.array(features, dtype=float)
|
||||
y = np.array(targets, dtype=float)
|
||||
|
||||
# 记录原始数据范围
|
||||
# 记录数据范围
|
||||
logger.info(f"Raw X range: min={X.min()}, max={X.max()}")
|
||||
logger.info(f"Raw y range: min={y.min()}, max={y.max()}")
|
||||
|
||||
@ -97,7 +80,7 @@ class DataPreparation:
|
||||
X_scaled = self.feature_scaler.fit_transform(X)
|
||||
y_scaled = self.target_scaler.fit_transform(y.reshape(-1, 1)).ravel()
|
||||
|
||||
# 创建 PyTorch 数据集和数据加载器
|
||||
# 创建数据集和数据加载器
|
||||
dataset = EquipmentDataset(X_scaled, y_scaled)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
@ -106,7 +89,9 @@ class DataPreparation:
|
||||
'feature_names': feature_names,
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler,
|
||||
'raw_shape': X.shape
|
||||
'raw_shape': X.shape,
|
||||
'X': X_scaled,
|
||||
'y': y_scaled
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -165,7 +150,6 @@ class DataPreparation:
|
||||
# 提取目标值(成本)并验证范围
|
||||
try:
|
||||
cost = float(item['actual_cost'])
|
||||
logger.info(f"Raw cost value: {cost}")
|
||||
if cost > 0: # 只使用正数成本值
|
||||
targets.append(cost)
|
||||
else:
|
||||
|
||||
@ -4,6 +4,7 @@ import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
@ -12,26 +13,104 @@ from src.feature_analysis import FeatureAnalysis
|
||||
from src.database import get_db_connection
|
||||
from src.data_preparation import DataPreparation, EquipmentDataset
|
||||
from .logger import setup_logger
|
||||
import math
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
class CostPredictionModel(nn.Module):
    """Cost regressor whose architecture is selected by equipment type.

    * ``'火箭炮'``: one small stack (``input_size -> 32 -> 16 -> 8 -> 1``) of
      Linear/ReLU/BatchNorm layers, stored as ``self.net``.
    * any other type: the input is split into equipment columns and the
      trailing ``MANUFACTURER_FEATURE_COUNT`` manufacturer columns. Each part
      goes through its own sub-network (``self.equipment_net`` and
      ``self.manufacturer_net``) and the concatenated outputs are fed to
      ``self.combined_net``.

    Args:
        input_size: Number of feature columns in each input row.
        equipment_type: Equipment type string; ``'火箭炮'`` selects the
            single-stack network, every other value the two-branch network.

    Raises:
        ValueError: If the two-branch network is requested and ``input_size``
            leaves no equipment columns once the manufacturer columns are
            taken off.
    """

    # Number of trailing input columns the two-branch network treats as
    # manufacturer features.
    # NOTE(review): assumes the feature order puts manufacturer features last -
    # confirm against the feature list used to build the training matrix.
    MANUFACTURER_FEATURE_COUNT = 5

    def __init__(self, input_size, equipment_type):
        super().__init__()
        self.equipment_type = equipment_type
        # Kept so callers can read the expected input width for either
        # architecture without reaching into a specific sub-network.
        self.input_size = input_size

        if equipment_type == '火箭炮':
            # Simpler network for this equipment type.
            self.net = nn.Sequential(
                # Layer 1: feature mapping
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.BatchNorm1d(32),

                # Layer 2: feature extraction
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.BatchNorm1d(16),

                # Layer 3: feature integration
                nn.Linear(16, 8),
                nn.ReLU(),
                nn.BatchNorm1d(8),

                # Output layer
                nn.Linear(8, 1)
            )
            self.net.apply(self._init_single_stack_weights)
        else:
            manufacturer_width = self.MANUFACTURER_FEATURE_COUNT
            equipment_width = input_size - manufacturer_width
            if equipment_width <= 0:
                raise ValueError(
                    f"input_size must be greater than {manufacturer_width} for "
                    f"equipment type {equipment_type!r}, got {input_size}"
                )

            # Manufacturer branch: 5 -> 4
            self.manufacturer_net = nn.Sequential(
                nn.Linear(manufacturer_width, 4),
                nn.ReLU(),
                nn.BatchNorm1d(4),
                nn.Dropout(0.2)
            )

            # Equipment branch: (input_size - 5) -> 64 -> 32 -> 16
            self.equipment_net = nn.Sequential(
                nn.Linear(equipment_width, 64),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(64),
                nn.Dropout(0.2),
                nn.Linear(64, 32),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(32),
                nn.Dropout(0.2),
                nn.Linear(32, 16),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(16),
                nn.Dropout(0.2)
            )

            # Fusion head: 16 equipment outputs + 4 manufacturer outputs = 20
            self.combined_net = nn.Sequential(
                nn.Linear(16 + 4, 32),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(32),
                nn.Dropout(0.2),
                nn.Linear(32, 16),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(16),
                nn.Dropout(0.2),
                nn.Linear(16, 8),
                nn.LeakyReLU(0.1),
                nn.BatchNorm1d(8),
                nn.Linear(8, 1)
            )

    @staticmethod
    def _init_single_stack_weights(m):
        """Initialise one layer of the single-stack network.

        Linear layers get orthogonal weights (gain 0.5) and zero bias;
        BatchNorm layers get weight 0.5 and zero bias.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=0.5)
            torch.nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm1d):
            torch.nn.init.constant_(m.weight, 0.5)
            torch.nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Run the network selected at construction time.

        Args:
            x: 2-D tensor with one row per sample and ``input_size`` columns.

        Returns:
            Tensor with one row per sample and a single output column.
        """
        if self.equipment_type == '火箭炮':
            return self.net(x)

        # Split the trailing manufacturer columns from the equipment columns.
        manufacturer_width = self.MANUFACTURER_FEATURE_COUNT
        manufacturer_features = x[:, -manufacturer_width:]
        equipment_features = x[:, :-manufacturer_width]

        manu_out = self.manufacturer_net(manufacturer_features)
        equip_out = self.equipment_net(equipment_features)

        # Equipment outputs first, then manufacturer outputs.
        combined = torch.cat([equip_out, manu_out], dim=1)
        return self.combined_net(combined)
|
||||
|
||||
class ModelTrainer:
|
||||
def __init__(self):
|
||||
@ -42,24 +121,86 @@ class ModelTrainer:
|
||||
self.equipment_type = None
|
||||
self.feature_analyzer = FeatureAnalysis()
|
||||
|
||||
def train_model(self, dataloader, epochs=100, learning_rate=0.001):
|
||||
def train_model(self, dataloader, epochs=100, learning_rate=0.001, equipment_type=None):
|
||||
"""训练模型"""
|
||||
try:
|
||||
# 获取输入特征维度
|
||||
sample_features, _ = next(iter(dataloader))
|
||||
input_size = sample_features.shape[1]
|
||||
|
||||
# 创建模型
|
||||
self.model = CostPredictionModel(input_size).to(self.device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
# 设置确定性
|
||||
torch.manual_seed(42)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(42)
|
||||
np.random.seed(42)
|
||||
|
||||
self.model = CostPredictionModel(input_size, equipment_type).to(self.device)
|
||||
|
||||
if equipment_type == '火箭炮':
|
||||
# 火箭炮使用更保守和稳定的训练设置
|
||||
criterion = nn.SmoothL1Loss(beta=0.1) # 使用Huber损失,beta值较小
|
||||
learning_rate = 0.0003 # 更小的学习率
|
||||
weight_decay = 0.001 # 适中的权重衰减
|
||||
|
||||
# 使用AdamW优化器,更小的beta值
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
betas=(0.8, 0.9), # 更小的动量值
|
||||
eps=1e-8
|
||||
)
|
||||
|
||||
# 使用带预热的学习率调度
|
||||
num_steps = len(dataloader) * epochs
|
||||
warmup_steps = num_steps // 10 # 10%的预热步数
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / float(max(1, warmup_steps))
|
||||
return 0.5 * (1.0 + math.cos(
|
||||
math.pi * (current_step - warmup_steps) / float(max(1, num_steps - warmup_steps))
|
||||
))
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
else: # 巡飞弹保持原有设置
|
||||
# 巡飞弹使用更激进的训练设置
|
||||
criterion = nn.MSELoss()
|
||||
learning_rate = 0.001 # 较大的学习率
|
||||
weight_decay = 0.001 # 较小的权重衰减
|
||||
patience = 20 # 较短的耐心值
|
||||
|
||||
# 使用Adam优化器
|
||||
optimizer = torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
betas=(0.9, 0.999)
|
||||
)
|
||||
|
||||
# 使用余弦退火学习率调度
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=epochs,
|
||||
eta_min=learning_rate * 0.01
|
||||
)
|
||||
|
||||
# 训练循环
|
||||
best_loss = float('inf')
|
||||
patience = 30
|
||||
patience_counter = 0
|
||||
best_model_state = None
|
||||
moving_avg_loss = None
|
||||
alpha = 0.9 # 移动平均系数
|
||||
|
||||
for epoch in range(epochs):
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
batch_count = 0
|
||||
|
||||
for batch_features, batch_targets in dataloader:
|
||||
# 移动数据到设备
|
||||
batch_features = batch_features.to(self.device)
|
||||
batch_targets = batch_targets.to(self.device)
|
||||
|
||||
@ -68,16 +209,57 @@ class ModelTrainer:
|
||||
loss = criterion(outputs, batch_targets.view(-1, 1))
|
||||
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
optimizer.zero_grad(set_to_none=True) # 更高效的梯度清零
|
||||
loss.backward()
|
||||
|
||||
# 梯度裁剪
|
||||
if equipment_type == '火箭炮':
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
max_norm=0.1
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
if equipment_type == '火箭炮':
|
||||
scheduler.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
batch_count += 1
|
||||
|
||||
avg_loss = total_loss / batch_count
|
||||
|
||||
# 使用移动平均计算损失
|
||||
if moving_avg_loss is None:
|
||||
moving_avg_loss = avg_loss
|
||||
else:
|
||||
moving_avg_loss = alpha * moving_avg_loss + (1 - alpha) * avg_loss
|
||||
|
||||
# 早停检查使用移动平均损失
|
||||
if moving_avg_loss < best_loss:
|
||||
best_loss = moving_avg_loss
|
||||
patience_counter = 0
|
||||
best_model_state = {
|
||||
'state_dict': self.model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict() if equipment_type == '火箭炮' else None
|
||||
}
|
||||
else:
|
||||
patience_counter += 1
|
||||
|
||||
# 记录训练进度
|
||||
if (epoch + 1) % 10 == 0:
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
|
||||
logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {moving_avg_loss:.4f}, '
|
||||
f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
|
||||
|
||||
if patience_counter >= patience:
|
||||
logger.info(f"Early stopping triggered at epoch {epoch+1}")
|
||||
break
|
||||
|
||||
# 恢复最佳模型
|
||||
if best_model_state is not None:
|
||||
self.model.load_state_dict(best_model_state['state_dict'])
|
||||
optimizer.load_state_dict(best_model_state['optimizer'])
|
||||
if equipment_type == '火箭炮' and best_model_state['scheduler']:
|
||||
scheduler.load_state_dict(best_model_state['scheduler'])
|
||||
|
||||
return True
|
||||
|
||||
@ -85,62 +267,78 @@ class ModelTrainer:
|
||||
logger.error(f"Error in model training: {str(e)}")
|
||||
raise
|
||||
|
||||
def save_model(self, equipment_type):
|
||||
def save_model(self, equipment_type, metrics=None):
|
||||
"""保存模型"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_dir = 'models'
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 转换装备类型为英文
|
||||
equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile'
|
||||
|
||||
# 保存模型
|
||||
model_path = f'{model_dir}/{equipment_type}_{timestamp}.pth'
|
||||
model_path = f'{model_dir}/{equipment_type_en}_{timestamp}.pth'
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'input_size': self.model.layers[0].in_features
|
||||
'input_size': self.model.equipment_net[0].in_features + 5,
|
||||
'manufacturer_net_state': self.model.manufacturer_net.state_dict(),
|
||||
'equipment_net_state': self.model.equipment_net.state_dict(),
|
||||
'combined_net_state': self.model.combined_net.state_dict()
|
||||
}, model_path)
|
||||
|
||||
# 保存标准化器
|
||||
scaler_path = f'{model_dir}/{equipment_type}_{timestamp}_scaler.pth'
|
||||
scaler_path = f'{model_dir}/{equipment_type_en}_{timestamp}_scaler.pth'
|
||||
torch.save({
|
||||
'feature_scaler': self.feature_scaler,
|
||||
'target_scaler': self.target_scaler
|
||||
}, scaler_path)
|
||||
|
||||
# 获取评估指标
|
||||
r2 = metrics['validation']['r2'] if metrics and metrics.get('validation') else metrics['train']['r2']
|
||||
mae = metrics['validation']['mae'] if metrics and metrics.get('validation') else metrics['train']['mae']
|
||||
rmse = metrics['validation']['rmse'] if metrics and metrics.get('validation') else metrics['train']['rmse']
|
||||
|
||||
# 更新数据库
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 将所有模型设置为非激活
|
||||
# 将所有同类型模型设置为非激活, 除了 PLS 模型
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s
|
||||
""", (equipment_type,))
|
||||
WHERE equipment_type = %s AND model_type != %s
|
||||
""", (equipment_type, 'pls'))
|
||||
|
||||
# 保存新模型记录
|
||||
cursor.execute("""
|
||||
INSERT INTO trained_models (
|
||||
model_name, model_type, equipment_type, model_path,
|
||||
scaler_path, training_date, is_active, created_by
|
||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s)
|
||||
scaler_path, training_date, is_active, created_by,
|
||||
r2_score, mae, rmse
|
||||
) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
|
||||
""", (
|
||||
f"{equipment_type}_{timestamp}",
|
||||
'pytorch',
|
||||
equipment_type,
|
||||
model_path,
|
||||
scaler_path,
|
||||
'system'
|
||||
'system',
|
||||
r2,
|
||||
mae,
|
||||
rmse
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
logger.info(f"Model saved successfully: {model_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {str(e)}")
|
||||
return False
|
||||
|
||||
def load_model(self, equipment_type):
|
||||
def load_model(self, equipment_type, model_type):
|
||||
"""加载模型"""
|
||||
try:
|
||||
# 从数据库获取最新的激活模型
|
||||
@ -148,25 +346,39 @@ class ModelTrainer:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s AND is_active = TRUE
|
||||
WHERE equipment_type = %s AND model_type = %s AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
""", (equipment_type, model_type))
|
||||
model_record = cursor.fetchone()
|
||||
|
||||
if not model_record:
|
||||
return False
|
||||
raise ValueError(f"No trained model found for {equipment_type}")
|
||||
|
||||
# 加载模型
|
||||
checkpoint = torch.load(model_record['model_path'])
|
||||
input_size = checkpoint['input_size']
|
||||
self.model = CostPredictionModel(input_size).to(self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
# 加载标准化器
|
||||
scalers = torch.load(model_record['scaler_path'])
|
||||
self.feature_scaler = scalers['feature_scaler']
|
||||
self.target_scaler = scalers['target_scaler']
|
||||
if model_record['model_type'] == 'pytorch':
|
||||
# 加载PyTorch模型
|
||||
checkpoint = torch.load(model_record['model_path'], encoding='latin1')
|
||||
input_size = checkpoint['input_size']
|
||||
|
||||
# 创建新模型实例
|
||||
self.model = CostPredictionModel(input_size, equipment_type).to(self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
# 加载标准化器
|
||||
scalers = torch.load(model_record['scaler_path'], encoding='latin1')
|
||||
self.feature_scaler = scalers['feature_scaler']
|
||||
self.target_scaler = scalers['target_scaler']
|
||||
else:
|
||||
# 加载sklearn模型
|
||||
from joblib import load
|
||||
with open(model_record['model_path'], 'rb') as f:
|
||||
self.model = load(f)
|
||||
with open(model_record['scaler_path'], 'rb') as f:
|
||||
scalers = load(f)
|
||||
self.feature_scaler = scalers['feature_scaler']
|
||||
self.target_scaler = scalers['target_scaler']
|
||||
|
||||
logger.info(f"Model loaded successfully from {model_record['model_path']}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@ -178,12 +390,356 @@ class ModelTrainer:
|
||||
try:
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# 转换为tensor并移动到正确的设备
|
||||
features_tensor = torch.FloatTensor(features).to(self.device)
|
||||
# 进行预测
|
||||
predictions = self.model(features_tensor)
|
||||
# 移回CPU并转换为numpy数组
|
||||
predictions = self.model(features_tensor) # 直接返回预测值
|
||||
return predictions.cpu().numpy()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
def fit_model(self, X_train, y_train, models, X_val=None, y_val=None, equipment_type=None):
    """Train every requested model type and return their evaluation metrics.

    Args:
        X_train: Training feature matrix (indexed as ``X_train.shape[1]`` for
            the feature count).
        y_train: Training targets in scaled space; they are passed through
            ``self.target_scaler.inverse_transform`` before scoring.
        models: Iterable of model type names. Recognised values are
            'pls', 'xgboost', 'lightgbm', 'gbm', 'rf' and 'pytorch'.
        X_val: Optional validation feature matrix.
        y_val: Optional validation targets (scaled). Validation metrics are
            only produced when both ``X_val`` and ``y_val`` are given.
        equipment_type: When truthy, models are persisted: PLS is always
            saved on its own, and the best non-PLS model is saved once all
            candidates have been trained.

    Returns:
        dict with keys 'metrics' (per-model {'train', 'validation'} dicts),
        'feature_importance' (always None here) and 'best_model' (summary of
        the best non-PLS model, or None if none was trained).

    Raises:
        Any exception raised while training or evaluating is logged and
        re-raised unchanged.
    """
    try:
        logger.info(f"Starting model training for {equipment_type}")
        logger.info(f"Selected models: {models}")
        logger.info(f"Training data shape: {X_train.shape}")

        # Validation metrics need both features and targets.
        has_val = X_val is not None and y_val is not None

        all_metrics = {}
        best_model = None
        best_score = float('-inf')
        best_model_type = None

        for model_type in models:
            logger.info(f"Training {model_type} model...")

            if model_type == 'pytorch':
                train_dataset = EquipmentDataset(X_train, y_train)
                train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
                # NOTE(review): train_model is defined outside this block;
                # confirm it does not also need equipment_type here.
                if not self.train_model(train_loader):
                    continue

                model = None
                # Switch to inference mode so layers such as dropout or
                # batch-norm do not distort the evaluation predictions.
                self.model.eval()
                with torch.no_grad():
                    X_train_tensor = torch.FloatTensor(X_train).to(self.device)
                    y_train_pred = self.model(X_train_tensor).cpu().numpy()
                    y_val_pred = None
                    if has_val:
                        X_val_tensor = torch.FloatTensor(X_val).to(self.device)
                        y_val_pred = self.model(X_val_tensor).cpu().numpy()
            else:
                model = self._build_regressor(model_type, X_train.shape[1])
                if model is None:
                    # Without this guard an unknown name would re-evaluate
                    # the previous iteration's model under the wrong label.
                    logger.warning(f"Unknown model type '{model_type}', skipping")
                    continue
                model.fit(X_train, y_train)
                y_train_pred = model.predict(X_train)
                y_val_pred = model.predict(X_val) if has_val else None

            train_metrics = self._regression_metrics(y_train, y_train_pred)
            val_metrics = self._regression_metrics(y_val, y_val_pred) if has_val else None

            all_metrics[model_type] = {
                'train': train_metrics,
                'validation': val_metrics
            }

            if model_type == 'pls':
                # PLS is stored separately and never competes for "best model".
                if equipment_type:
                    self._save_sklearn_model(equipment_type, model_type, model, all_metrics[model_type])
                continue

            # Prefer validation metrics for ranking; fall back to training.
            ranking_metrics = val_metrics if val_metrics else train_metrics
            current_score = ranking_metrics['r2']
            if current_score > best_score:
                best_score = current_score
                best_model = {
                    'type': model_type,
                    'r2': current_score,
                    'mae': ranking_metrics['mae'],
                    'rmse': ranking_metrics['rmse']
                }
                best_model_type = model_type

                # Keep the winning model in memory; it is written to disk and
                # to the database only after every candidate has been trained.
                if model_type == 'pytorch':
                    self.best_pytorch_model = self.model.state_dict()
                    self.best_pytorch_metrics = all_metrics[model_type]
                else:
                    self.best_model = model
                    self.best_model_metrics = all_metrics[model_type]

        # Persist only the single best non-PLS model.
        if best_model_type and equipment_type:
            if best_model_type == 'pytorch':
                self.model.load_state_dict(self.best_pytorch_model)
                self.save_model(equipment_type, self.best_pytorch_metrics)
            else:
                self._save_sklearn_model(equipment_type, best_model_type, self.best_model, self.best_model_metrics)

        return {
            'metrics': all_metrics,
            'feature_importance': None,
            'best_model': best_model
        }

    except Exception as e:
        logger.error(f"Error in model fitting: {str(e)}")
        logger.error("Detailed traceback:", exc_info=True)
        raise

def _build_regressor(self, model_type, n_features):
    """Return an unfitted regressor for ``model_type``, or None if the name is unknown."""
    if model_type == 'pls':
        from sklearn.cross_decomposition import PLSRegression
        # Few components to limit over-fitting, but never fewer than one:
        # n_features // 5 is 0 for narrow feature sets, which PLS rejects.
        n_components = max(1, min(3, n_features // 5))
        return PLSRegression(
            n_components=n_components,
            scale=True,
            max_iter=500,
            tol=1e-6
        )

    if model_type == 'xgboost':
        import xgboost as xgb
        return xgb.XGBRegressor(
            n_estimators=50,
            learning_rate=0.03,
            max_depth=3,
            min_child_weight=5,
            subsample=0.6,
            colsample_bytree=0.6,
            reg_alpha=0.5,
            reg_lambda=2.0,
            gamma=1,
            random_state=42
        )

    if model_type == 'lightgbm':
        import lightgbm as lgb
        return lgb.LGBMRegressor(
            n_estimators=50,
            learning_rate=0.03,
            max_depth=3,
            num_leaves=8,
            subsample=0.6,
            colsample_bytree=0.6,
            reg_alpha=0.5,
            reg_lambda=2.0,
            min_child_samples=10,
            min_split_gain=1.0,
            random_state=42
        )

    if model_type == 'gbm':
        from sklearn.ensemble import GradientBoostingRegressor
        return GradientBoostingRegressor(
            n_estimators=50,
            learning_rate=0.03,
            max_depth=3,
            min_samples_split=10,
            min_samples_leaf=5,
            subsample=0.6,
            min_impurity_decrease=0.01,
            random_state=42
        )

    if model_type == 'rf':
        from sklearn.ensemble import RandomForestRegressor
        return RandomForestRegressor(
            n_estimators=100,
            max_depth=4,
            min_samples_split=5,
            min_samples_leaf=3,
            max_features='sqrt',
            bootstrap=True,
            random_state=42
        )

    return None

def _regression_metrics(self, y_true_scaled, y_pred_scaled):
    """Score predictions after mapping both arrays back through ``self.target_scaler``."""
    y_pred = self.target_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
    y_true = self.target_scaler.inverse_transform(y_true_scaled.reshape(-1, 1)).ravel()
    return {
        'r2': float(r2_score(y_true, y_pred)),
        'mae': float(mean_absolute_error(y_true, y_pred)),
        'rmse': float(np.sqrt(mean_squared_error(y_true, y_pred)))
    }
|
||||
|
||||
def _calculate_feature_importance(self, X):
|
||||
"""计算特征重要性"""
|
||||
try:
|
||||
if self.model is None:
|
||||
return None
|
||||
|
||||
self.model.eval()
|
||||
feature_importance = np.zeros(X.shape[1])
|
||||
|
||||
# 使用特征扰动计算重要性
|
||||
with torch.no_grad():
|
||||
X_tensor = torch.FloatTensor(X).to(self.device)
|
||||
baseline_pred = self.model(X_tensor).cpu().numpy() # 直接获取预测值
|
||||
|
||||
for i in range(X.shape[1]):
|
||||
# 创建动后的特征
|
||||
X_perturbed = X.copy()
|
||||
X_perturbed[:, i] = np.random.permutation(X_perturbed[:, i])
|
||||
|
||||
# 预测并计算影响
|
||||
X_perturbed_tensor = torch.FloatTensor(X_perturbed).to(self.device)
|
||||
perturbed_pred = self.model(X_perturbed_tensor).cpu().numpy() # 直接获取预测值
|
||||
|
||||
# 特征重要性为预变化的平均绝对值
|
||||
feature_importance[i] = np.mean(np.abs(baseline_pred - perturbed_pred))
|
||||
|
||||
# 归一化特征重要性
|
||||
feature_importance = feature_importance / np.sum(feature_importance)
|
||||
|
||||
return feature_importance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating feature importance: {str(e)}")
|
||||
return None
|
||||
|
||||
def _save_sklearn_model(self, equipment_type, model_type, model, metrics=None):
    """Persist a scikit-learn style model, its scalers, and a database record.

    The model and the feature/target scalers are written with joblib under
    ``models/``. Previously active rows in ``trained_models`` for the same
    equipment type are deactivated (PLS rows only replace PLS rows; any other
    type replaces all non-PLS rows), and a new active row is inserted.

    Args:
        equipment_type: Equipment type stored in the database; '火箭炮' maps
            to the 'rocket' file prefix, anything else to 'missile'.
        model_type: Model type name stored in the database and file name.
        model: Fitted estimator to serialize.
        metrics: Optional dict with 'validation' and/or 'train' sub-dicts
            holding 'r2', 'mae' and 'rmse'. Validation values are preferred.
            When omitted, the metric columns are written as NULL.

    Returns:
        True on success, False if any step fails (the error is logged).
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_dir = 'models'
        os.makedirs(model_dir, exist_ok=True)

        # File names use an ASCII equipment label.
        equipment_type_en = 'rocket' if equipment_type == '火箭炮' else 'missile'

        # Resolve the metrics first. The default metrics=None previously
        # raised TypeError here, after the files had already been written.
        chosen = None
        if metrics:
            chosen = metrics.get('validation') or metrics.get('train')
        r2 = chosen['r2'] if chosen else None
        mae = chosen['mae'] if chosen else None
        rmse = chosen['rmse'] if chosen else None

        from joblib import dump

        model_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}.joblib'
        dump(model, model_path)

        scaler_path = f'{model_dir}/{equipment_type_en}_{model_type}_{timestamp}_scaler.joblib'
        dump({
            'feature_scaler': self.feature_scaler,
            'target_scaler': self.target_scaler
        }, scaler_path)

        with get_db_connection() as conn:
            cursor = conn.cursor()

            # PLS models live alongside the "best" model, so each group only
            # deactivates its own previous entries.
            if model_type != 'pls':
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND model_type != %s
                """, (equipment_type, 'pls'))
            else:
                cursor.execute("""
                    UPDATE trained_models
                    SET is_active = FALSE
                    WHERE equipment_type = %s AND model_type = %s
                """, (equipment_type, 'pls'))

            cursor.execute("""
                INSERT INTO trained_models (
                    model_name, model_type, equipment_type, model_path,
                    scaler_path, training_date, is_active, created_by,
                    r2_score, mae, rmse
                ) VALUES (%s, %s, %s, %s, %s, NOW(), TRUE, %s, %s, %s, %s)
            """, (
                f"{equipment_type}_{model_type}_{timestamp}",
                model_type,
                equipment_type,
                model_path,
                scaler_path,
                'system',
                r2,
                mae,
                rmse
            ))

            conn.commit()

        logger.info(f"Model {model_type} saved successfully: {model_path}")
        return True

    except Exception as e:
        logger.error(f"Error saving sklearn model: {str(e)}")
        return False
|
||||
231
src/routes.py
231
src/routes.py
@ -12,6 +12,7 @@ import os
|
||||
from .data_preparation import DataPreparation
|
||||
from .model_trainer import ModelTrainer
|
||||
from .logger import setup_logger
|
||||
import torch
|
||||
|
||||
# 创建蓝图
|
||||
api_bp = Blueprint('api', __name__)
|
||||
@ -58,49 +59,45 @@ def index():
|
||||
})
|
||||
|
||||
@api_bp.route('/predict', methods=['POST'])
|
||||
def predict_cost():
|
||||
"""
|
||||
成本预测接口
|
||||
"""
|
||||
def predict():
|
||||
"""使用最优机器学习模型进行预测"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
logger.info(f"Received prediction request for equipment type: {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
logger.info(f"Received prediction request for equipment type: {equipment_type}")
|
||||
|
||||
# 验证装备类型
|
||||
if 'type' not in data:
|
||||
return jsonify({'error': 'Equipment type is required'}), 400
|
||||
|
||||
# 预测成本
|
||||
predictor = CostPredictor()
|
||||
result = predictor.predict(data)
|
||||
|
||||
# 获取当前使用的模型信息
|
||||
# 获取最新的激活模型(非PLS模型)
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("""
|
||||
SELECT model_type, model_name, r2_score, mae, rmse
|
||||
FROM trained_models
|
||||
WHERE equipment_type = %s AND model_type != 'pls' AND is_active = TRUE
|
||||
LIMIT 1
|
||||
""", (data['type'],))
|
||||
model_info = cursor.fetchone()
|
||||
|
||||
# 在结果中添加模型信息
|
||||
result.update({
|
||||
'model_info': {
|
||||
'type': model_info['model_type'],
|
||||
'name': model_info['model_name'],
|
||||
'r2_score': float(model_info['r2_score']),
|
||||
'mae': float(model_info['mae']),
|
||||
'rmse': float(model_info['rmse'])
|
||||
}
|
||||
})
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type != 'pls' # 明确排除PLS模型
|
||||
AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
logger.info(f"Prediction completed: {result}")
|
||||
return jsonify(result)
|
||||
if not model:
|
||||
return jsonify({'error': '未找到可用的模型'}), 404
|
||||
|
||||
# 使用普通预测方法
|
||||
predictor = CostPredictor()
|
||||
prediction = predictor.predict(data, model) # 使用普通predict方法
|
||||
|
||||
# 返回预测结果
|
||||
return jsonify({
|
||||
'model_info': {
|
||||
'type': model['model_type'], # 使用数据库中的模型类型
|
||||
'name': model['model_name']
|
||||
},
|
||||
'predicted_cost': prediction['predicted_cost'],
|
||||
'confidence_interval': prediction['confidence_interval']
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/analyze-features', methods=['POST'])
|
||||
@ -180,27 +177,7 @@ def analyze_features():
|
||||
""", (dataset_id,))
|
||||
|
||||
equipment_data = cursor.fetchall()
|
||||
|
||||
# 添加数据检查日志
|
||||
logger.info(f"Total records found: {len(equipment_data)}")
|
||||
if equipment_data:
|
||||
# 检查第一条记录的所有字段
|
||||
first_record = equipment_data[0]
|
||||
logger.info("First record details:")
|
||||
for key, value in first_record.items():
|
||||
logger.info(f"{key}: {value}")
|
||||
|
||||
# 检查所有记录的 max_range_km 字段
|
||||
logger.info("Checking max_range_km for all records:")
|
||||
for item in equipment_data:
|
||||
logger.info(f"Equipment: {item['name']}")
|
||||
logger.info(f" max_range_km: {item.get('max_range_km')}")
|
||||
logger.info(f" type: {item['type']}")
|
||||
if item['type'] == '火箭炮':
|
||||
logger.info(f" rocket_artillery_params fields:")
|
||||
for key in ['firing_angle_horizontal', 'rocket_length_m', 'rate_of_fire']:
|
||||
logger.info(f" {key}: {item.get(key)}")
|
||||
|
||||
|
||||
# 提取特征和目标值
|
||||
analyzer = FeatureAnalysis()
|
||||
feature_names = analyzer.get_equipment_specific_features(equipment_data[0]['type'])
|
||||
@ -270,9 +247,7 @@ def analyze_features():
|
||||
|
||||
@api_bp.route('/train', methods=['POST'])
|
||||
def train_model():
|
||||
"""
|
||||
训练模型
|
||||
"""
|
||||
"""训练模型"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
logger.info(f"Starting model training for {data.get('type')}")
|
||||
@ -289,54 +264,65 @@ def train_model():
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 获取训练集数据
|
||||
# 获取训练集数据(包含生产商信息)
|
||||
if equipment_type == '火箭炮':
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level,
|
||||
m.id as manufacturer_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (train_dataset_id,))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level,
|
||||
m.id as manufacturer_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (train_dataset_id,))
|
||||
|
||||
train_data = cursor.fetchall()
|
||||
|
||||
# 获取验证集据(如果有)
|
||||
validation_data = None
|
||||
# 获取验证集数据
|
||||
if validation_dataset_id:
|
||||
if equipment_type == '火箭炮':
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost
|
||||
SELECT e.*, cp.*, rap.*, cd.actual_cost,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level,
|
||||
m.id as manufacturer_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN rocket_artillery_params rap ON e.id = rap.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (validation_dataset_id,))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost
|
||||
SELECT e.*, cp.*, lmp.*, cd.actual_cost,
|
||||
m.tech_level, m.scale_level, m.supply_chain_level,
|
||||
m.id as manufacturer_id
|
||||
FROM equipments e
|
||||
JOIN dataset_equipments de ON e.id = de.equipment_id
|
||||
LEFT JOIN common_params cp ON e.id = cp.equipment_id
|
||||
LEFT JOIN loitering_munition_params lmp ON e.id = lmp.equipment_id
|
||||
LEFT JOIN cost_data cd ON e.id = cd.equipment_id
|
||||
LEFT JOIN manufacturers m ON e.manufacturer_id = m.id
|
||||
WHERE de.dataset_id = %s
|
||||
AND cd.actual_cost IS NOT NULL
|
||||
""", (validation_dataset_id,))
|
||||
@ -351,7 +337,7 @@ def train_model():
|
||||
# 准备训练数据
|
||||
train_prepared = data_processor.prepare_training_data(train_data, equipment_type)
|
||||
|
||||
# 准备验数据(如果有)
|
||||
# 准备验证数据(如果有)
|
||||
validation_prepared = None
|
||||
if validation_data:
|
||||
validation_prepared = data_processor.prepare_validation_data(
|
||||
@ -369,14 +355,14 @@ def train_model():
|
||||
model_trainer.feature_scaler = train_prepared['feature_scaler']
|
||||
model_trainer.target_scaler = train_prepared['target_scaler']
|
||||
|
||||
# 执行训练,传入 equipment_type 参数
|
||||
# 执行训练,传入equipment_type参数
|
||||
training_result = model_trainer.fit_model(
|
||||
train_prepared['X'],
|
||||
train_prepared['y'],
|
||||
models,
|
||||
validation_prepared['X'] if validation_prepared else None,
|
||||
validation_prepared['y'] if validation_prepared else None,
|
||||
equipment_type=equipment_type
|
||||
equipment_type=equipment_type # 添加这个参数
|
||||
)
|
||||
|
||||
return jsonify(training_result)
|
||||
@ -588,65 +574,43 @@ def get_db_connection():
|
||||
|
||||
@api_bp.route('/pls/predict', methods=['POST'])
|
||||
def pls_predict():
|
||||
"""
|
||||
PLS回归预测接口
|
||||
"""
|
||||
"""使用PLS模型进行预测"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
logger.info(f"Received PLS prediction request for equipment type: {data.get('type')}")
|
||||
equipment_type = data.get('type')
|
||||
|
||||
# 验证装备类型
|
||||
if 'type' not in data:
|
||||
return jsonify({'error': 'Equipment type is required'}), 400
|
||||
|
||||
# 使用 ModelTrainer 中的 PLS 模型行预测
|
||||
trainer = ModelTrainer()
|
||||
if not trainer.load_model(data['type'], model_type='pls'): # 指定加载 PLS 模型
|
||||
return jsonify({'error': '未找到可用的模型'}), 404
|
||||
|
||||
# 准备特征数据
|
||||
feature_analyzer = FeatureAnalysis()
|
||||
features = feature_analyzer.get_equipment_specific_features(data['type'])
|
||||
X = np.array([[data.get(feature) for feature in features]])
|
||||
|
||||
# 预测
|
||||
result = trainer.predict(X)
|
||||
|
||||
# 计算置信区间
|
||||
confidence_interval = trainer._calculate_confidence_interval(result[0])
|
||||
|
||||
# 获取模型信息
|
||||
# 获取最新的PLS模型
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("""
|
||||
SELECT model_type, model_name, r2_score, mae, rmse
|
||||
FROM trained_models
|
||||
WHERE equipment_type = %s AND model_type = 'pls' AND is_active = TRUE
|
||||
LIMIT 1
|
||||
""", (data['type'],))
|
||||
model_info = cursor.fetchone()
|
||||
|
||||
# 确保返回的数据可以序列化为JSON
|
||||
response = {
|
||||
'predicted_cost': float(result[0]),
|
||||
'model_info': {
|
||||
'type': model_info['model_type'],
|
||||
'name': model_info['model_name'],
|
||||
'r2_score': model_info['r2_score'],
|
||||
'mae': model_info['mae'],
|
||||
'rmse': model_info['rmse']
|
||||
},
|
||||
'confidence_interval': {
|
||||
'lower': float(confidence_interval[0]),
|
||||
'upper': float(confidence_interval[1])
|
||||
}
|
||||
}
|
||||
SELECT * FROM trained_models
|
||||
WHERE equipment_type = %s
|
||||
AND model_type = 'pls' # 只选择PLS模型
|
||||
AND is_active = TRUE
|
||||
ORDER BY training_date DESC LIMIT 1
|
||||
""", (equipment_type,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
logger.info(f"PLS prediction completed: {response}")
|
||||
return jsonify(response)
|
||||
if not model:
|
||||
return jsonify({'error': '未找到可用的PLS模型'}), 404
|
||||
|
||||
# 使用普通predict方法,传入模型信息
|
||||
predictor = CostPredictor()
|
||||
prediction = predictor.predict(data, model) # 传入model参数
|
||||
|
||||
# 返回预测结果
|
||||
return jsonify({
|
||||
'model_info': {
|
||||
'type': 'pls',
|
||||
'name': model['model_name']
|
||||
},
|
||||
'predicted_cost': prediction['predicted_cost'],
|
||||
'confidence_interval': prediction['confidence_interval']
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in PLS prediction: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/data/import', methods=['POST'])
|
||||
@ -852,7 +816,7 @@ def get_equipment_details(id):
|
||||
@api_bp.route('/datasets', methods=['GET'])
|
||||
def get_datasets():
|
||||
"""
|
||||
获取数集列表
|
||||
获取数据集列表
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
@ -868,13 +832,20 @@ def get_datasets():
|
||||
""")
|
||||
datasets = cursor.fetchall()
|
||||
|
||||
# 理装备名称列表
|
||||
# 整理装备名称列表
|
||||
for dataset in datasets:
|
||||
if dataset['equipment_names']:
|
||||
dataset['equipment_names'] = dataset['equipment_names'].split(',')
|
||||
else:
|
||||
dataset['equipment_names'] = []
|
||||
|
||||
# 格式化时间和数值字段
|
||||
for dataset in datasets:
|
||||
if dataset['created_at']:
|
||||
dataset['created_at'] = dataset['created_at'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
if dataset['updated_at']:
|
||||
dataset['updated_at'] = dataset['updated_at'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
return jsonify(datasets)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting datasets: {str(e)}")
|
||||
@ -930,6 +901,7 @@ def get_dataset(id):
|
||||
}
|
||||
|
||||
dataset['equipment'] = equipment
|
||||
|
||||
return jsonify(dataset)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dataset: {str(e)}")
|
||||
@ -1067,7 +1039,7 @@ def delete_dataset(id):
|
||||
@api_bp.route('/models/<equipment_type>/latest', methods=['GET'])
|
||||
def get_latest_model(equipment_type):
|
||||
"""
|
||||
获取最新训练的型信息
|
||||
获取最新训练的模型信息
|
||||
"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
@ -1087,9 +1059,7 @@ def get_latest_model(equipment_type):
|
||||
|
||||
@api_bp.route('/models', methods=['GET'])
|
||||
def get_models():
|
||||
"""
|
||||
获取模型列表
|
||||
"""
|
||||
"""获取模型列表"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
@ -1100,8 +1070,11 @@ def get_models():
|
||||
|
||||
models = cursor.fetchall()
|
||||
|
||||
# 确保数值型字段是 float
|
||||
# 格式化时间和数值字段
|
||||
for model in models:
|
||||
# 将数据库中的datetime转换为ISO格式字符串
|
||||
if model['training_date']:
|
||||
model['training_date'] = model['training_date'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
if model['r2_score'] is not None:
|
||||
model['r2_score'] = float(model['r2_score'])
|
||||
if model['mae'] is not None:
|
||||
@ -1121,16 +1094,14 @@ def get_models():
|
||||
|
||||
@api_bp.route('/models/<int:id>/activate', methods=['POST'])
|
||||
def activate_model(id):
|
||||
"""
|
||||
激活定的模型
|
||||
"""
|
||||
"""激活指定的模型"""
|
||||
try:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor = conn.cursor(dictionary=True) # 使用字典游标
|
||||
|
||||
# 获取模型信息
|
||||
cursor.execute("""
|
||||
SELECT equipment_type FROM trained_models
|
||||
SELECT equipment_type, model_type FROM trained_models
|
||||
WHERE id = %s
|
||||
""", (id,))
|
||||
model = cursor.fetchone()
|
||||
@ -1142,8 +1113,8 @@ def activate_model(id):
|
||||
cursor.execute("""
|
||||
UPDATE trained_models
|
||||
SET is_active = FALSE
|
||||
WHERE equipment_type = %s
|
||||
""", (model[0],))
|
||||
WHERE equipment_type = %s AND model_type = %s
|
||||
""", (model['equipment_type'], model['model_type']))
|
||||
|
||||
# 激活指定模型
|
||||
cursor.execute("""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user