# EG/plugins/user/network_sync/prediction/prediction_manager.py
# 2025-12-12 16:16:15 +08:00
#
# 791 lines
# 29 KiB
# Python

"""
预测管理模块
负责处理网络延迟下的对象状态预测
"""
import copy
import threading
import time
from typing import Any, Callable, Dict, List, Optional
class PredictionManager:
    """
    Prediction manager.

    Maintains a locally predicted state for every registered object so it
    can be shown ahead of delayed network updates, and reconciles that
    prediction when authoritative server state arrives.

    Per-object bookkeeping, all keyed by object id:

    * ``predicted_objects``  - current predicted state dict.
    * ``object_history``     - up to ``MAX_HISTORY_LENGTH`` entries shaped
      ``{'state': <state dict>, 'timestamp': <float>}``, oldest first.
    * ``object_velocities``  - per-key rates of change for ``position``,
      ``rotation`` and ``scale`` (state units per timestamp unit).

    Transform values are treated as plain numeric sequences and combined
    component-wise; rotations are not wrapped around.
    """

    # State keys that are extrapolated, differentiated and smoothed.
    TRANSFORM_KEYS = ('position', 'rotation', 'scale')
    # Number of history entries kept per object (oldest are dropped first).
    MAX_HISTORY_LENGTH = 20
    # Frame time assumed by smooth correction when none is supplied (~60 FPS).
    DEFAULT_FRAME_TIME = 0.016

    def __init__(self, plugin):
        """
        Create the prediction manager.

        Args:
            plugin: Owning network-sync plugin. Its ``clock_sync`` and
                ``network_manager`` attributes are consulted (when present
                and truthy) to compute the prediction target time.
        """
        self.plugin = plugin
        self.enabled = False
        self.initialized = False

        # Prediction configuration. Keys marked "unused here" are not read
        # anywhere in this class.
        self.prediction_config = {
            'enable_prediction': True,       # master switch for update()/apply_correction()
            'enable_interpolation': True,    # unused here
            'enable_extrapolation': True,    # extrapolate objects during update()
            'max_prediction_time': 0.5,      # maximum look-ahead (seconds)
            'prediction_smoothing': 0.1,     # unused here
            'correction_speed': 5.0,         # smooth-correction rate (blend = speed * frame time)
            'velocity_damping': 0.95,        # velocities are scaled by damping ** dt per update
            'position_threshold': 0.01,      # unused here
            'rotation_threshold': 0.1,       # unused here
            'enable_snap_correction': True,  # allow snapping straight to server state
            'snap_threshold': 1.0,           # position + rotation error above which to snap
        }

        # Per-object state (see class docstring).
        self.predicted_objects: Dict[str, Dict[str, Any]] = {}
        self.object_history: Dict[str, List[Dict[str, Any]]] = {}
        self.object_velocities: Dict[str, Dict[str, List[float]]] = {}

        # Bookkeeping about the latest prediction pass / corrections.
        self.prediction_state = {
            'last_prediction_time': 0.0,
            'predicted_objects_count': 0,
            'corrections_applied': 0,
            'snap_corrections': 0,
        }

        # Aggregate statistics (see get_prediction_stats()).
        self.prediction_stats = self._new_prediction_stats()

        # RLock: callbacks are invoked while it is held, so re-entrant calls
        # back into this manager from the same thread do not deadlock.
        self.prediction_lock = threading.RLock()

        # Callback type -> list of callables, each called with one data dict.
        self.prediction_callbacks = {
            'prediction_updated': [],
            'correction_applied': [],
            'snap_correction': [],
            'prediction_error': [],
        }

        # Wall-clock times (time.time()) of the latest update / correction.
        self.last_prediction_update = 0.0
        self.last_correction_apply = 0.0
        print("✓ 预测管理器已创建")

    @staticmethod
    def _new_prediction_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'predictions_made': 0,
            'corrections_applied': 0,
            'snap_corrections': 0,
            'prediction_errors': 0,
            'average_error': 0.0,
            'max_error': 0.0,
        }

    @classmethod
    def _zero_velocity(cls) -> Dict[str, List[float]]:
        """Return a zero 3-component velocity for every transform key."""
        return {key: [0.0, 0.0, 0.0] for key in cls.TRANSFORM_KEYS}

    def initialize(self) -> bool:
        """
        Mark the manager as initialized.

        Returns:
            True on success, False if an exception was raised.
        """
        try:
            self.initialized = True
            print("✓ 预测管理器初始化完成")
            return True
        except Exception as e:
            print(f"✗ 预测管理器初始化失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    def enable(self) -> bool:
        """
        Enable per-frame prediction.

        Returns:
            True if enabled; False if the manager is not initialized or an
            exception was raised.
        """
        try:
            if not self.initialized:
                print("✗ 预测管理器未初始化")
                return False
            self.enabled = True
            print("✓ 预测管理器已启用")
            return True
        except Exception as e:
            print(f"✗ 预测管理器启用失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    def disable(self):
        """Disable per-frame prediction (update() becomes a no-op)."""
        try:
            self.enabled = False
            print("✓ 预测管理器已禁用")
        except Exception as e:
            print(f"✗ 预测管理器禁用失败: {e}")
            import traceback
            traceback.print_exc()

    def finalize(self):
        """
        Disable the manager and release all per-object state.

        Registered callbacks are dropped, but the callback *types* are kept
        so the manager can be initialized and used again afterwards.
        """
        try:
            self.disable()
            with self.prediction_lock:
                self.predicted_objects.clear()
                self.object_history.clear()
                self.object_velocities.clear()
            # Empty each list rather than the dict itself: clearing the dict
            # deleted the valid callback types, so every later
            # register_prediction_callback() call was rejected.
            for callbacks in self.prediction_callbacks.values():
                callbacks.clear()
            self.initialized = False
            print("✓ 预测管理器资源已清理")
        except Exception as e:
            print(f"✗ 预测管理器资源清理失败: {e}")
            import traceback
            traceback.print_exc()

    def update(self, dt: float):
        """
        Per-frame tick: advance predictions when enabled.

        Args:
            dt: Time elapsed since the previous tick (drives velocity damping).
        """
        try:
            if not self.enabled:
                return
            current_time = time.time()
            self.last_prediction_update = current_time
            if self.prediction_config['enable_prediction']:
                self._update_predictions(current_time, dt)
        except Exception as e:
            print(f"✗ 预测管理器更新失败: {e}")
            import traceback
            traceback.print_exc()

    def _compute_prediction_target(self, current_time: float) -> float:
        """
        Return the time that predictions should target.

        The target is the server time (``current_time`` when the plugin has
        no clock sync) pushed forward by the reported connection latency,
        capped at ``max_prediction_time`` ahead of the server time.
        """
        clock_sync = getattr(self.plugin, 'clock_sync', None)
        server_time = clock_sync.get_server_time() if clock_sync else current_time

        latency = 0.0
        network_manager = getattr(self.plugin, 'network_manager', None)
        if network_manager:
            # NOTE(review): latency is assumed to be in the same unit as the
            # timestamps (seconds); confirm against the network manager.
            connection_state = network_manager.get_connection_state() or {}
            latency = connection_state.get('latency', 0.0) or 0.0

        max_ahead = self.prediction_config['max_prediction_time']
        return min(server_time + latency, server_time + max_ahead)

    def _update_predictions(self, current_time: float, dt: float):
        """
        Extrapolate every registered object towards the prediction target.

        Args:
            current_time: Local wall-clock time of this tick.
            dt: Time elapsed since the previous tick.
        """
        try:
            # Independent of the object, so computed once per tick.
            prediction_time = self._compute_prediction_target(current_time)
            with self.prediction_lock:
                if self.prediction_config['enable_extrapolation']:
                    # Iterate a snapshot: callbacks fired during extrapolation
                    # may (re-entrantly) register or unregister objects.
                    for obj_id, pred_data in list(self.predicted_objects.items()):
                        if self._extrapolate_object(obj_id, pred_data, prediction_time, current_time, dt):
                            self.prediction_stats['predictions_made'] += 1
                self.prediction_state['predicted_objects_count'] = len(self.predicted_objects)
                self.prediction_state['last_prediction_time'] = current_time
        except Exception as e:
            print(f"✗ 预测更新失败: {e}")
            self.prediction_stats['prediction_errors'] += 1

    def _extrapolate_object(self, obj_id: str, pred_data: Dict[str, Any],
                            prediction_time: float, current_time: float, dt: float) -> bool:
        """
        Dead-reckon one object from its latest recorded state.

        Each transform key of the latest history state is advanced by
        ``velocity * (prediction_time - timestamp)`` and written into
        ``pred_data`` in place; the object's velocity is then damped.

        NOTE(review): register_object() stamps history with local
        time.time() while the target comes from clock_sync; if the two
        clocks differ by more than max_prediction_time, extrapolation is
        skipped. Confirm clock_sync returns a comparable epoch.

        Args:
            obj_id: Object id.
            pred_data: The object's live predicted-state dict (mutated).
            prediction_time: Target time to extrapolate to.
            current_time: Fallback timestamp for history entries lacking one.
            dt: Tick duration used for velocity damping.

        Returns:
            True if a prediction was produced, False otherwise.
        """
        try:
            history = self.object_history.get(obj_id)
            if not history:
                return False

            latest_entry = history[-1]
            # History entries wrap the state: {'state': ..., 'timestamp': ...}.
            # Reading transform keys from the entry itself (as before) meant
            # nothing was ever extrapolated.
            latest_state = latest_entry.get('state', {})
            latest_timestamp = latest_entry.get('timestamp', current_time)

            time_diff = prediction_time - latest_timestamp
            # Too far from the last known sample to extrapolate sensibly.
            if abs(time_diff) > self.prediction_config['max_prediction_time']:
                return False

            velocity = self.object_velocities.get(obj_id) or self._zero_velocity()
            for key in self.TRANSFORM_KEYS:
                if key not in latest_state:
                    continue
                rate = velocity.get(key, (0.0, 0.0, 0.0))
                pred_data[key] = [
                    base + delta * time_diff
                    for base, delta in zip(latest_state[key], rate)
                ]

            self._apply_velocity_damping(obj_id, dt)

            self._trigger_prediction_callback('prediction_updated', {
                'object_id': obj_id,
                'predicted_state': copy.deepcopy(pred_data),
                'prediction_time': prediction_time,
                'time_diff': time_diff,
            })
            return True
        except Exception as e:
            print(f"✗ 对象外推失败: {e}")
            self.prediction_stats['prediction_errors'] += 1
            return False

    def _apply_velocity_damping(self, obj_id: str, dt: float):
        """
        Scale an object's velocities by ``velocity_damping ** dt``.

        Args:
            obj_id: Object id (ignored if it has no velocity record).
            dt: Elapsed time used as the damping exponent.
        """
        try:
            velocity = self.object_velocities.get(obj_id)
            if velocity is None:
                return
            factor = self.prediction_config['velocity_damping'] ** dt
            for key in self.TRANSFORM_KEYS:
                if key in velocity:
                    velocity[key] = [v * factor for v in velocity[key]]
        except Exception as e:
            print(f"✗ 速度阻尼应用失败: {e}")

    def register_object(self, obj_id: str, initial_state: Dict[str, Any]) -> bool:
        """
        Start predicting an object (re-registering resets its state).

        Args:
            obj_id: Object id.
            initial_state: Initial state dict (copied, not retained).

        Returns:
            True on success, False if an exception was raised.
        """
        try:
            with self.prediction_lock:
                self.predicted_objects[obj_id] = copy.deepcopy(initial_state)
                self.object_history[obj_id] = [{
                    'state': copy.deepcopy(initial_state),
                    'timestamp': time.time(),
                }]
                self.object_velocities[obj_id] = self._zero_velocity()
            print(f"✓ 预测对象已注册: {obj_id}")
            return True
        except Exception as e:
            print(f"✗ 预测对象注册失败: {e}")
            return False

    def unregister_object(self, obj_id: str) -> bool:
        """
        Stop predicting an object and drop all of its state.

        Args:
            obj_id: Object id (unknown ids are ignored).

        Returns:
            True on success, False if an exception was raised.
        """
        try:
            with self.prediction_lock:
                self.predicted_objects.pop(obj_id, None)
                self.object_history.pop(obj_id, None)
                self.object_velocities.pop(obj_id, None)
            print(f"✓ 预测对象已注销: {obj_id}")
            return True
        except Exception as e:
            print(f"✗ 预测对象注销失败: {e}")
            return False

    def update_object_state(self, obj_id: str, new_state: Dict[str, Any], timestamp: Optional[float] = None):
        """
        Record a new state sample and apply it to the predicted state.

        The sample is appended to the history (capped at
        ``MAX_HISTORY_LENGTH``), velocities are re-derived from the last two
        samples, and, if the object is registered, its predicted state is
        updated with ``new_state``.

        Args:
            obj_id: Object id.
            new_state: New (possibly partial) state dict.
            timestamp: Sample time; defaults to ``time.time()``.
        """
        try:
            if timestamp is None:
                timestamp = time.time()
            with self.prediction_lock:
                self._record_state(obj_id, new_state, timestamp)
                if obj_id in self.predicted_objects:
                    self.predicted_objects[obj_id].update(copy.deepcopy(new_state))
        except Exception as e:
            print(f"✗ 对象状态更新失败: {e}")

    def _record_state(self, obj_id: str, state: Dict[str, Any], timestamp: float):
        """
        Append a sample to the object's history and refresh its velocity.

        Does not touch the predicted state. Caller must hold the lock.
        """
        history = self.object_history.setdefault(obj_id, [])
        history.append({
            'state': copy.deepcopy(state),
            'timestamp': timestamp,
        })
        # Keep only the most recent samples.
        if len(history) > self.MAX_HISTORY_LENGTH:
            del history[:-self.MAX_HISTORY_LENGTH]
        self._calculate_object_velocity(obj_id, state, timestamp)

    def _calculate_object_velocity(self, obj_id: str, new_state: Dict[str, Any], timestamp: float):
        """
        Derive per-key velocities from the last two history samples.

        Keys present in both the newest sample and the previous one get
        ``(current - previous) / dt``; other keys keep their old velocity.

        Args:
            obj_id: Object id.
            new_state: The state just appended to the history.
            timestamp: Its timestamp (the history entries are authoritative).
        """
        try:
            history = self.object_history.get(obj_id)
            # A rate needs two samples.
            if not history or len(history) < 2:
                return

            previous_entry, current_entry = history[-2], history[-1]
            time_diff = current_entry['timestamp'] - previous_entry['timestamp']
            # Out-of-order or duplicate timestamps give no usable rate.
            if time_diff <= 0:
                return

            previous_state = previous_entry['state']
            current_state = current_entry['state']
            # Objects fed through update_object_state() without
            # register_object() have no velocity record yet; create one
            # instead of failing with KeyError.
            velocity = self.object_velocities.setdefault(obj_id, self._zero_velocity())
            for key in self.TRANSFORM_KEYS:
                if key in new_state and key in previous_state:
                    velocity[key] = [
                        (cur - prev) / time_diff
                        for cur, prev in zip(current_state[key], previous_state[key])
                    ]
        except Exception as e:
            print(f"✗ 对象速度计算失败: {e}")

    def apply_correction(self, obj_id: str, server_state: Dict[str, Any], server_time: float):
        """
        Reconcile an object's prediction with authoritative server state.

        The error between prediction and server state (position distance +
        summed rotation difference) is recorded. If snap correction is
        enabled and the error exceeds ``snap_threshold``, the prediction is
        replaced by the server state; otherwise it is blended towards it.
        The server sample is then recorded in the history / velocity.

        Args:
            obj_id: Object id (unregistered ids are ignored).
            server_state: Authoritative state dict.
            server_time: Timestamp of ``server_state``.
        """
        try:
            if not self.prediction_config['enable_prediction']:
                return

            with self.prediction_lock:
                if obj_id not in self.predicted_objects:
                    return

                # Snapshot before correcting: the live dict is mutated below,
                # and callbacks must see what was actually predicted.
                predicted_before = copy.deepcopy(self.predicted_objects[obj_id])

                position_error = 0.0
                if 'position' in server_state and 'position' in predicted_before:
                    position_error = self._calculate_distance(
                        server_state['position'], predicted_before['position'])

                rotation_error = 0.0
                if 'rotation' in server_state and 'rotation' in predicted_before:
                    rotation_error = self._calculate_rotation_difference(
                        server_state['rotation'], predicted_before['rotation'])

                total_error = position_error + rotation_error
                # Must run before the correction counters are incremented.
                self._record_error(total_error)

                if (self.prediction_config['enable_snap_correction'] and
                        total_error > self.prediction_config['snap_threshold']):
                    self.predicted_objects[obj_id].update(copy.deepcopy(server_state))
                    self.prediction_stats['snap_corrections'] += 1
                    self.prediction_state['snap_corrections'] += 1
                    self._trigger_prediction_callback('snap_correction', {
                        'object_id': obj_id,
                        'server_state': server_state,
                        'predicted_state': predicted_before,
                        'error': total_error,
                    })
                else:
                    self._smooth_correct_object(obj_id, server_state, server_time)
                    self.prediction_stats['corrections_applied'] += 1
                    self.prediction_state['corrections_applied'] += 1
                    self._trigger_prediction_callback('correction_applied', {
                        'object_id': obj_id,
                        'server_state': server_state,
                        'predicted_state': predicted_before,
                        'error': total_error,
                    })

                # Record the sample for history/velocity only. Going through
                # update_object_state() here overwrote the predicted state
                # with server_state, discarding the smooth correction above.
                sample_time = time.time() if server_time is None else server_time
                self._record_state(obj_id, server_state, sample_time)
                self.last_correction_apply = time.time()
        except Exception as e:
            print(f"✗ 校正应用失败: {e}")
            self.prediction_stats['prediction_errors'] += 1
            self._trigger_prediction_callback('prediction_error', {
                'object_id': obj_id,
                'error': str(e),
            })

    def _record_error(self, total_error: float):
        """Fold one correction error into the running average and maximum."""
        stats = self.prediction_stats
        # Every correction (smooth or snap) contributes exactly one error
        # sample, so these counters give the number already averaged.
        # (Dividing by predictions_made, a per-frame count, was wrong.)
        samples = stats['corrections_applied'] + stats['snap_corrections']
        stats['average_error'] = (stats['average_error'] * samples + total_error) / (samples + 1)
        stats['max_error'] = max(stats['max_error'], total_error)

    def _smooth_correct_object(self, obj_id: str, server_state: Dict[str, Any], server_time: float,
                               frame_time: Optional[float] = None):
        """
        Blend an object's predicted transform towards the server state.

        Each component moves by ``(server - predicted) * blend`` where
        ``blend = correction_speed * frame_time``, clamped to [0, 1] so a
        large speed cannot overshoot the server value.

        Args:
            obj_id: Object id (unregistered ids are ignored).
            server_state: Authoritative state dict.
            server_time: Timestamp of ``server_state`` (currently unused).
            frame_time: Frame duration; defaults to ``DEFAULT_FRAME_TIME``.
        """
        try:
            predicted_state = self.predicted_objects.get(obj_id)
            if predicted_state is None:
                return

            if frame_time is None:
                frame_time = self.DEFAULT_FRAME_TIME
            blend = self.prediction_config['correction_speed'] * frame_time
            blend = min(max(blend, 0.0), 1.0)

            for key in self.TRANSFORM_KEYS:
                if key in server_state and key in predicted_state:
                    predicted_state[key] = [
                        predicted + (server - predicted) * blend
                        for server, predicted in zip(server_state[key], predicted_state[key])
                    ]
        except Exception as e:
            print(f"✗ 平滑校正失败: {e}")

    def _calculate_distance(self, pos1: List[float], pos2: List[float]) -> float:
        """
        Euclidean distance over the first 3 (or, failing that, 2) components.

        Args:
            pos1: First position.
            pos2: Second position.

        Returns:
            The distance, or 0.0 if either position has fewer than 2
            components or an error occurred.
        """
        try:
            if len(pos1) >= 3 and len(pos2) >= 3:
                dx = pos1[0] - pos2[0]
                dy = pos1[1] - pos2[1]
                dz = pos1[2] - pos2[2]
                return (dx*dx + dy*dy + dz*dz) ** 0.5
            if len(pos1) >= 2 and len(pos2) >= 2:
                dx = pos1[0] - pos2[0]
                dy = pos1[1] - pos2[1]
                return (dx*dx + dy*dy) ** 0.5
            return 0.0
        except Exception as e:
            print(f"✗ 距离计算失败: {e}")
            return 0.0

    def _calculate_rotation_difference(self, rot1: List[float], rot2: List[float]) -> float:
        """
        Sum of absolute component differences between two rotations.

        Simple Euler-style metric with no angle wrap-around handling.

        Args:
            rot1: First rotation.
            rot2: Second rotation.

        Returns:
            The summed difference over the shared components, or 0.0 on error.
        """
        try:
            return sum(abs(a - b) for a, b in zip(rot1, rot2))
        except Exception as e:
            print(f"✗ 旋转差异计算失败: {e}")
            return 0.0

    def get_predicted_state(self, obj_id: str) -> Optional[Dict[str, Any]]:
        """
        Return a copy of an object's predicted state.

        Args:
            obj_id: Object id.

        Returns:
            A deep copy of the predicted state, ``{}`` if the object is not
            registered, or ``None`` if an unexpected error occurred.
        """
        try:
            with self.prediction_lock:
                state = self.predicted_objects.get(obj_id)
                # Deep copy so callers cannot mutate internal lists.
                return copy.deepcopy(state) if state is not None else {}
        except Exception as e:
            print(f"✗ 预测状态获取失败: {e}")
            return None

    def get_prediction_stats(self) -> Dict[str, Any]:
        """
        Return a copy of the prediction statistics.

        Returns:
            Dict with predictions_made, corrections_applied, snap_corrections,
            prediction_errors, average_error and max_error.
        """
        return self.prediction_stats.copy()

    def reset_prediction_stats(self):
        """Reset all prediction statistics to zero."""
        try:
            self.prediction_stats = self._new_prediction_stats()
            print("✓ 预测统计信息已重置")
        except Exception as e:
            print(f"✗ 预测统计信息重置失败: {e}")

    def set_prediction_config(self, config: Dict[str, Any]) -> bool:
        """
        Merge ``config`` into the prediction configuration.

        Args:
            config: Keys/values to overwrite (unknown keys are stored as-is).

        Returns:
            True on success, False if an exception was raised.
        """
        try:
            self.prediction_config.update(config)
            print(f"✓ 预测配置已更新: {self.prediction_config}")
            return True
        except Exception as e:
            print(f"✗ 预测配置设置失败: {e}")
            return False

    def get_prediction_config(self) -> Dict[str, Any]:
        """
        Return a (shallow) copy of the prediction configuration.

        Returns:
            The configuration dictionary.
        """
        return self.prediction_config.copy()

    def _trigger_prediction_callback(self, callback_type: str, data: Dict[str, Any]):
        """
        Invoke every callback registered for ``callback_type`` with ``data``.

        Exceptions from individual callbacks are printed and do not stop
        the remaining callbacks.

        Args:
            callback_type: Callback type key.
            data: Payload passed to each callback.
        """
        try:
            # Iterate a copy so a callback may (un)register callbacks safely.
            for callback in list(self.prediction_callbacks.get(callback_type, ())):
                try:
                    callback(data)
                except Exception as e:
                    print(f"✗ 预测回调执行失败: {e}")
        except Exception as e:
            print(f"✗ 预测回调触发失败: {e}")

    def register_prediction_callback(self, callback_type: str,
                                     callback: Callable[[Dict[str, Any]], Any]):
        """
        Register a callback for one of the known callback types.

        Args:
            callback_type: One of 'prediction_updated', 'correction_applied',
                'snap_correction', 'prediction_error'.
            callback: Callable receiving the event's data dict.
        """
        try:
            if callback_type in self.prediction_callbacks:
                self.prediction_callbacks[callback_type].append(callback)
                print(f"✓ 预测回调已注册: {callback_type}")
            else:
                print(f"✗ 无效的回调类型: {callback_type}")
        except Exception as e:
            print(f"✗ 预测回调注册失败: {e}")

    def unregister_prediction_callback(self, callback_type: str,
                                       callback: Callable[[Dict[str, Any]], Any]):
        """
        Remove one registration of ``callback`` for ``callback_type``.

        Unknown types or callbacks are ignored.

        Args:
            callback_type: Callback type key.
            callback: Previously registered callable.
        """
        try:
            callbacks = self.prediction_callbacks.get(callback_type)
            if callbacks is not None and callback in callbacks:
                callbacks.remove(callback)
                print(f"✓ 预测回调已注销: {callback_type}")
        except Exception as e:
            print(f"✗ 预测回调注销失败: {e}")