791 lines
29 KiB
Python
791 lines
29 KiB
Python
"""
|
|
预测管理模块
|
|
负责处理网络延迟下的对象状态预测
|
|
"""
|
|
|
|
import time
|
|
from typing import Dict, Any, List, Optional
|
|
import threading
|
|
import copy
|
|
|
|
class PredictionManager:
    """
    Prediction manager.

    Keeps a per-object predicted state (position / rotation / scale) that is
    extrapolated from the most recent authoritative state plus an estimated
    velocity, and reconciles it with server corrections either by snapping
    (large error) or by blending toward the server value (small error).

    NOTE(review): the config keys 'enable_interpolation',
    'prediction_smoothing', 'position_threshold' and 'rotation_threshold' are
    stored but not read anywhere in this class.
    """

    # State components that are predicted, corrected and differentiated.
    _STATE_KEYS = ('position', 'rotation', 'scale')

    # Number of authoritative states kept per object.
    MAX_HISTORY_SIZE = 20

    # Frame time used by smooth correction until update() supplies a real dt
    # (matches the previously hard-coded 0.016, i.e. ~60 FPS).
    DEFAULT_FRAME_TIME = 0.016

    def __init__(self, plugin):
        """
        Create the prediction manager.

        Args:
            plugin: Network sync plugin instance. Its optional ``clock_sync``
                and ``network_manager`` attributes are used to compute the
                prediction target time.
        """
        self.plugin = plugin
        self.enabled = False
        self.initialized = False

        # Prediction configuration
        self.prediction_config = {
            'enable_prediction': True,
            'enable_interpolation': True,
            'enable_extrapolation': True,
            'max_prediction_time': 0.5,  # maximum extrapolation horizon (seconds)
            'prediction_smoothing': 0.1,  # smoothing coefficient (not read by this class)
            'correction_speed': 5.0,  # smooth-correction rate, multiplied by the frame dt
            'velocity_damping': 0.95,  # velocities are scaled by damping ** dt per step
            'position_threshold': 0.01,  # position threshold (not read by this class)
            'rotation_threshold': 0.1,  # rotation threshold (not read by this class)
            'enable_snap_correction': True,  # snap instead of blend for large errors
            'snap_threshold': 1.0  # combined error above which a snap is applied
        }

        # Per-object bookkeeping
        self.predicted_objects = {}  # obj_id -> predicted state dict
        self.object_history = {}  # obj_id -> list of {'state': ..., 'timestamp': ...}
        self.object_velocities = {}  # obj_id -> {'position'|'rotation'|'scale': [rates]}

        # Prediction state
        self.prediction_state = {
            'last_prediction_time': 0.0,
            'predicted_objects_count': 0,
            'corrections_applied': 0,
            'snap_corrections': 0
        }

        # Prediction statistics
        self.prediction_stats = self._new_stats()

        # Reentrant lock: callbacks fired while it is held may call back into
        # methods such as register_object / unregister_object.
        self.prediction_lock = threading.RLock()

        # Callback lists, keyed by callback type
        self.prediction_callbacks = {
            'prediction_updated': [],
            'correction_applied': [],
            'snap_correction': [],
            'prediction_error': []
        }

        # Timestamps
        self.last_prediction_update = 0.0
        self.last_correction_apply = 0.0

        # Most recent positive frame time passed to update(); used by smooth correction.
        self._frame_dt = self.DEFAULT_FRAME_TIME

        print("✓ 预测管理器已创建")

    @staticmethod
    def _new_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'predictions_made': 0,
            'corrections_applied': 0,
            'snap_corrections': 0,
            'prediction_errors': 0,
            'average_error': 0.0,
            'max_error': 0.0
        }

    @staticmethod
    def _zero_velocity() -> Dict[str, List[float]]:
        """Return a fresh zero velocity record for all state components."""
        return {
            'position': [0.0, 0.0, 0.0],
            'rotation': [0.0, 0.0, 0.0],
            'scale': [0.0, 0.0, 0.0]
        }

    @staticmethod
    def _copy_state(state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Copy a state dict one level deep.

        List values (the component vectors) are copied so that callers and
        this manager never share mutable lists; other values are shared.
        """
        return {key: list(value) if isinstance(value, list) else value
                for key, value in state.items()}

    def initialize(self) -> bool:
        """
        Initialize the prediction manager.

        Returns:
            True on success, False on failure.
        """
        try:
            self.initialized = True
            print("✓ 预测管理器初始化完成")
            return True

        except Exception as e:
            print(f"✗ 预测管理器初始化失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    def enable(self) -> bool:
        """
        Enable the prediction manager (requires a prior initialize()).

        Returns:
            True if enabled, False if not initialized or on failure.
        """
        try:
            if not self.initialized:
                print("✗ 预测管理器未初始化")
                return False

            self.enabled = True
            print("✓ 预测管理器已启用")
            return True

        except Exception as e:
            print(f"✗ 预测管理器启用失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    def disable(self):
        """Disable the prediction manager; update() becomes a no-op."""
        try:
            self.enabled = False
            print("✓ 预测管理器已禁用")

        except Exception as e:
            print(f"✗ 预测管理器禁用失败: {e}")
            import traceback
            traceback.print_exc()

    def finalize(self):
        """Disable the manager and drop all tracked objects and registered callbacks."""
        try:
            self.disable()
            with self.prediction_lock:
                self.predicted_objects.clear()
                self.object_history.clear()
                self.object_velocities.clear()
                # Empty each callback list but keep the callback types, so that
                # register_prediction_callback still accepts them afterwards.
                for callbacks in self.prediction_callbacks.values():
                    callbacks.clear()
            self.initialized = False
            print("✓ 预测管理器资源已清理")

        except Exception as e:
            print(f"✗ 预测管理器资源清理失败: {e}")
            import traceback
            traceback.print_exc()

    def update(self, dt: float):
        """
        Per-frame update: extrapolate every predicted object.

        Args:
            dt: Time elapsed since the previous frame.
        """
        try:
            if not self.enabled:
                return

            current_time = time.time()
            self.last_prediction_update = current_time

            # Remember the real frame time for smooth correction.
            if dt is not None and dt > 0:
                self._frame_dt = dt

            if self.prediction_config['enable_prediction']:
                self._update_predictions(current_time, dt)

        except Exception as e:
            print(f"✗ 预测管理器更新失败: {e}")
            import traceback
            traceback.print_exc()

    def _compute_prediction_time(self, current_time: float) -> float:
        """
        Compute the time to predict to: server time plus latency, capped at
        server time + max_prediction_time.

        Falls back to ``current_time`` when the plugin has no clock sync, and
        to zero latency when it has no network manager.

        NOTE(review): assumes the connection state's 'latency' is expressed in
        seconds, like max_prediction_time — confirm.
        """
        clock_sync = getattr(self.plugin, 'clock_sync', None)
        server_time = clock_sync.get_server_time() if clock_sync else current_time

        network_manager = getattr(self.plugin, 'network_manager', None)
        if network_manager:
            latency = network_manager.get_connection_state().get('latency', 0.0)
        else:
            latency = 0.0

        max_prediction_time = server_time + self.prediction_config['max_prediction_time']
        return min(server_time + latency, max_prediction_time)

    def _update_predictions(self, current_time: float, dt: float):
        """
        Extrapolate all predicted objects for this frame.

        Args:
            current_time: Local wall-clock time of this frame.
            dt: Time elapsed since the previous frame.
        """
        try:
            with self.prediction_lock:
                # Iterate over a snapshot: callbacks fired during extrapolation
                # may register/unregister objects (the lock is reentrant).
                items = list(self.predicted_objects.items())

                if items:
                    # The target time is the same for every object this frame.
                    prediction_time = self._compute_prediction_time(current_time)

                    for obj_id, pred_data in items:
                        # Skip objects removed/replaced by a callback earlier this frame.
                        if self.predicted_objects.get(obj_id) is not pred_data:
                            continue

                        if self.prediction_config['enable_extrapolation']:
                            self._extrapolate_object(obj_id, pred_data, prediction_time, current_time, dt)

                        self.prediction_stats['predictions_made'] += 1

                self.prediction_state['predicted_objects_count'] = len(self.predicted_objects)
                self.prediction_state['last_prediction_time'] = current_time

        except Exception as e:
            print(f"✗ 预测更新失败: {e}")
            self.prediction_stats['prediction_errors'] += 1

    def _extrapolate_object(self, obj_id: str, pred_data: Dict[str, Any],
                            prediction_time: float, current_time: float, dt: float):
        """
        Extrapolate one object's state linearly from its latest recorded state.

        Each component is computed as ``last_value + velocity * time_diff``,
        where ``time_diff = prediction_time - last_timestamp``. Nothing is done
        when there is no history or when ``|time_diff|`` exceeds
        max_prediction_time.

        NOTE(review): register_object / update_object_state default to local
        time.time() while apply_correction records server_time, and
        prediction_time is derived from server time; the two clocks are
        assumed comparable — confirm.

        Args:
            obj_id: Object id.
            pred_data: The object's predicted state dict (mutated in place).
            prediction_time: Time to predict to.
            current_time: Local time of this frame (fallback timestamp).
            dt: Frame time, used for velocity damping.
        """
        try:
            history = self.object_history.get(obj_id)
            if not history:
                return

            # History entries wrap the state: {'state': {...}, 'timestamp': t}.
            latest_entry = history[-1]
            latest_state = latest_entry.get('state', {})
            latest_timestamp = latest_entry.get('timestamp', current_time)

            time_diff = prediction_time - latest_timestamp

            # Too far from the last known state: do not extrapolate.
            if abs(time_diff) > self.prediction_config['max_prediction_time']:
                return

            velocity = self.object_velocities.get(obj_id) or self._zero_velocity()

            for key in self._STATE_KEYS:
                if key not in latest_state:
                    continue
                base = latest_state[key]
                rate = velocity.get(key)
                if rate is None:
                    rate = [0.0] * len(base)
                pred_data[key] = [value + speed * time_diff for value, speed in zip(base, rate)]

            self._apply_velocity_damping(obj_id, dt)

            self._trigger_prediction_callback('prediction_updated', {
                'object_id': obj_id,
                'predicted_state': self._copy_state(pred_data),
                'prediction_time': prediction_time,
                'time_diff': time_diff
            })

        except Exception as e:
            print(f"✗ 对象外推失败: {e}")
            self.prediction_stats['prediction_errors'] += 1

    def _apply_velocity_damping(self, obj_id: str, dt: float):
        """
        Scale the object's stored velocities by ``velocity_damping ** dt``.

        Args:
            obj_id: Object id.
            dt: Frame time.
        """
        try:
            velocity = self.object_velocities.get(obj_id)
            if velocity is None:
                return

            factor = self.prediction_config['velocity_damping'] ** dt
            for key in self._STATE_KEYS:
                if key in velocity:
                    velocity[key] = [v * factor for v in velocity[key]]

        except Exception as e:
            print(f"✗ 速度阻尼应用失败: {e}")

    def register_object(self, obj_id: str, initial_state: Dict[str, Any]) -> bool:
        """
        Start tracking an object for prediction.

        Replaces any existing prediction, history and velocity for ``obj_id``.

        Args:
            obj_id: Object id.
            initial_state: Initial state dict (copied; the caller keeps ownership).

        Returns:
            True on success, False on failure.
        """
        try:
            with self.prediction_lock:
                self.predicted_objects[obj_id] = self._copy_state(initial_state)
                self.object_history[obj_id] = [{
                    'state': self._copy_state(initial_state),
                    'timestamp': time.time()
                }]
                self.object_velocities[obj_id] = self._zero_velocity()

            print(f"✓ 预测对象已注册: {obj_id}")
            return True

        except Exception as e:
            print(f"✗ 预测对象注册失败: {e}")
            return False

    def unregister_object(self, obj_id: str) -> bool:
        """
        Stop tracking an object (unknown ids are ignored).

        Args:
            obj_id: Object id.

        Returns:
            True on success, False on failure.
        """
        try:
            with self.prediction_lock:
                self.predicted_objects.pop(obj_id, None)
                self.object_history.pop(obj_id, None)
                self.object_velocities.pop(obj_id, None)

            print(f"✓ 预测对象已注销: {obj_id}")
            return True

        except Exception as e:
            print(f"✗ 预测对象注销失败: {e}")
            return False

    def update_object_state(self, obj_id: str, new_state: Dict[str, Any], timestamp: Optional[float] = None):
        """
        Record an authoritative state and overwrite the predicted state with it.

        Args:
            obj_id: Object id.
            new_state: New state dict (copied).
            timestamp: Time of the state; defaults to local time.time().
        """
        try:
            if timestamp is None:
                timestamp = time.time()

            with self.prediction_lock:
                self._record_state(obj_id, new_state, timestamp)

                if obj_id in self.predicted_objects:
                    self.predicted_objects[obj_id].update(self._copy_state(new_state))

        except Exception as e:
            print(f"✗ 对象状态更新失败: {e}")

    def _record_state(self, obj_id: str, state: Dict[str, Any], timestamp: float):
        """
        Append a state to the object's bounded history and refresh its velocity.

        Does not touch the predicted state.
        """
        with self.prediction_lock:
            history = self.object_history.setdefault(obj_id, [])
            history.append({
                'state': self._copy_state(state),
                'timestamp': timestamp
            })

            # Keep only the most recent MAX_HISTORY_SIZE entries.
            overflow = len(history) - self.MAX_HISTORY_SIZE
            if overflow > 0:
                del history[:overflow]

            self._calculate_object_velocity(obj_id, state, timestamp)

    def _calculate_object_velocity(self, obj_id: str, new_state: Dict[str, Any], timestamp: float):
        """
        Estimate velocities from the two most recent history entries.

        Uses a finite difference ``(current - previous) / dt`` per component.
        Nothing is updated when there are fewer than two entries or when the
        newest entry is not strictly later than the previous one.

        Args:
            obj_id: Object id.
            new_state: The state just recorded (the newest history entry).
            timestamp: Its timestamp.
        """
        try:
            history = self.object_history.get(obj_id)
            if not history or len(history) < 2:
                return

            previous_entry, current_entry = history[-2], history[-1]
            time_diff = current_entry['timestamp'] - previous_entry['timestamp']
            if time_diff <= 0:
                return

            current_state = current_entry['state']
            previous_state = previous_entry['state']

            # Objects tracked only via update_object_state have no velocity record yet.
            velocity = self.object_velocities.setdefault(obj_id, self._zero_velocity())

            for key in self._STATE_KEYS:
                if key in current_state and key in previous_state:
                    velocity[key] = [
                        (cur - prev) / time_diff
                        for cur, prev in zip(current_state[key], previous_state[key])
                    ]

        except Exception as e:
            print(f"✗ 对象速度计算失败: {e}")

    def apply_correction(self, obj_id: str, server_state: Dict[str, Any], server_time: float):
        """
        Reconcile an object's prediction with an authoritative server state.

        The error is the position distance plus the summed rotation difference.
        If snap correction is enabled and the error exceeds snap_threshold,
        the prediction is replaced by the server state. Otherwise it is blended
        toward it (see _smooth_correct_object). In both cases the server state
        is then recorded in history (updating the velocity estimate) without
        overwriting the corrected prediction.

        Callbacks receive the prediction as it was *before* the correction.

        Args:
            obj_id: Object id (unknown ids are ignored).
            server_state: Authoritative state dict.
            server_time: Timestamp of the server state.
        """
        try:
            if not self.prediction_config['enable_prediction']:
                return

            with self.prediction_lock:
                predicted_state = self.predicted_objects.get(obj_id)
                if predicted_state is None:
                    return

                # Snapshot before correction mutates predicted_state.
                predicted_before = self._copy_state(predicted_state)

                position_error = 0.0
                if 'position' in server_state and 'position' in predicted_state:
                    position_error = self._calculate_distance(server_state['position'], predicted_state['position'])

                rotation_error = 0.0
                if 'rotation' in server_state and 'rotation' in predicted_state:
                    rotation_error = self._calculate_rotation_difference(server_state['rotation'], predicted_state['rotation'])

                total_error = position_error + rotation_error

                # Running mean over error samples. Each correction that reaches
                # this point increments exactly one of the two counters below.
                samples = (self.prediction_stats['corrections_applied'] +
                           self.prediction_stats['snap_corrections'] + 1)
                average = self.prediction_stats['average_error']
                self.prediction_stats['average_error'] = average + (total_error - average) / samples
                self.prediction_stats['max_error'] = max(self.prediction_stats['max_error'], total_error)

                if (self.prediction_config['enable_snap_correction'] and
                        total_error > self.prediction_config['snap_threshold']):
                    # Snap correction
                    predicted_state.update(self._copy_state(server_state))
                    self.prediction_stats['snap_corrections'] += 1
                    self.prediction_state['snap_corrections'] += 1

                    self._trigger_prediction_callback('snap_correction', {
                        'object_id': obj_id,
                        'server_state': server_state,
                        'predicted_state': predicted_before,
                        'error': total_error
                    })
                else:
                    # Smooth correction
                    self._smooth_correct_object(obj_id, server_state, server_time)
                    self.prediction_stats['corrections_applied'] += 1
                    self.prediction_state['corrections_applied'] += 1

                    self._trigger_prediction_callback('correction_applied', {
                        'object_id': obj_id,
                        'server_state': server_state,
                        'predicted_state': predicted_before,
                        'error': total_error
                    })

                # Record the authoritative state without clobbering the
                # (possibly smoothed) prediction.
                self._record_state(obj_id, server_state, server_time)

        except Exception as e:
            print(f"✗ 校正应用失败: {e}")
            self.prediction_stats['prediction_errors'] += 1
            self._trigger_prediction_callback('prediction_error', {
                'object_id': obj_id,
                'error': str(e)
            })

    def _smooth_correct_object(self, obj_id: str, server_state: Dict[str, Any], server_time: float):
        """
        Move each predicted component a fraction of the way toward the server value.

        The fraction is ``correction_speed * frame_dt``, clamped to [0, 1] so
        that a long frame cannot overshoot the server value. ``frame_dt`` is
        the last positive dt passed to update(), or DEFAULT_FRAME_TIME.

        Args:
            obj_id: Object id.
            server_state: Authoritative state dict.
            server_time: Timestamp of the server state (unused).
        """
        try:
            predicted_state = self.predicted_objects.get(obj_id)
            if predicted_state is None:
                return

            factor = self.prediction_config['correction_speed'] * self._frame_dt
            factor = max(0.0, min(1.0, factor))

            for key in self._STATE_KEYS:
                if key in server_state and key in predicted_state:
                    predicted_state[key] = [
                        predicted + (server - predicted) * factor
                        for server, predicted in zip(server_state[key], predicted_state[key])
                    ]

        except Exception as e:
            print(f"✗ 平滑校正失败: {e}")

    def _calculate_distance(self, pos1: List[float], pos2: List[float]) -> float:
        """
        Euclidean distance between two positions.

        Uses 3D when both have at least 3 components, 2D when both have at
        least 2, and returns 0.0 otherwise.

        Args:
            pos1: First position.
            pos2: Second position.

        Returns:
            The distance.
        """
        try:
            if len(pos1) >= 3 and len(pos2) >= 3:
                dx = pos1[0] - pos2[0]
                dy = pos1[1] - pos2[1]
                dz = pos1[2] - pos2[2]
                return (dx*dx + dy*dy + dz*dz) ** 0.5
            elif len(pos1) >= 2 and len(pos2) >= 2:
                dx = pos1[0] - pos2[0]
                dy = pos1[1] - pos2[1]
                return (dx*dx + dy*dy) ** 0.5
            else:
                return 0.0
        except Exception as e:
            print(f"✗ 距离计算失败: {e}")
            return 0.0

    def _calculate_rotation_difference(self, rot1: List[float], rot2: List[float]) -> float:
        """
        Sum of absolute per-component differences between two rotations.

        NOTE(review): angle wrap-around is not handled (e.g. 359 vs 1 counts as
        358); the angle units are not established here.

        Args:
            rot1: First rotation.
            rot2: Second rotation.

        Returns:
            The summed difference.
        """
        try:
            return sum((abs(a - b) for a, b in zip(rot1, rot2)), 0.0)
        except Exception as e:
            print(f"✗ 旋转差异计算失败: {e}")
            return 0.0

    def get_predicted_state(self, obj_id: str) -> Optional[Dict[str, Any]]:
        """
        Return a copy of an object's predicted state.

        Args:
            obj_id: Object id.

        Returns:
            A copy of the predicted state, or None if the object is unknown
            or on failure.
        """
        try:
            with self.prediction_lock:
                state = self.predicted_objects.get(obj_id)
                return None if state is None else self._copy_state(state)
        except Exception as e:
            print(f"✗ 预测状态获取失败: {e}")
            return None

    def get_prediction_stats(self) -> Dict[str, Any]:
        """
        Return a shallow copy of the prediction statistics.

        Returns:
            Statistics dict.
        """
        return self.prediction_stats.copy()

    def reset_prediction_stats(self):
        """Reset all prediction statistics to zero."""
        try:
            self.prediction_stats = self._new_stats()
            print("✓ 预测统计信息已重置")
        except Exception as e:
            print(f"✗ 预测统计信息重置失败: {e}")

    def set_prediction_config(self, config: Dict[str, Any]) -> bool:
        """
        Merge the given keys into the prediction configuration.

        Args:
            config: Config entries to set.

        Returns:
            True on success, False on failure.
        """
        try:
            self.prediction_config.update(config)
            print(f"✓ 预测配置已更新: {self.prediction_config}")
            return True
        except Exception as e:
            print(f"✗ 预测配置设置失败: {e}")
            return False

    def get_prediction_config(self) -> Dict[str, Any]:
        """
        Return a shallow copy of the prediction configuration.

        Returns:
            Config dict.
        """
        return self.prediction_config.copy()

    def _trigger_prediction_callback(self, callback_type: str, data: Dict[str, Any]):
        """
        Invoke every callback registered for ``callback_type`` with ``data``.

        Exceptions raised by individual callbacks are printed and do not stop
        the remaining callbacks.

        Args:
            callback_type: Callback type key.
            data: Payload passed to each callback.
        """
        try:
            # Snapshot so a callback that (un)registers callbacks does not
            # cause another callback to be skipped.
            for callback in list(self.prediction_callbacks.get(callback_type, ())):
                try:
                    callback(data)
                except Exception as e:
                    print(f"✗ 预测回调执行失败: {e}")
        except Exception as e:
            print(f"✗ 预测回调触发失败: {e}")

    def register_prediction_callback(self, callback_type: str, callback: callable):
        """
        Register a callback for one of the known callback types.

        Args:
            callback_type: 'prediction_updated', 'correction_applied',
                'snap_correction' or 'prediction_error'.
            callback: Callable taking one dict argument.
        """
        try:
            if callback_type in self.prediction_callbacks:
                self.prediction_callbacks[callback_type].append(callback)
                print(f"✓ 预测回调已注册: {callback_type}")
            else:
                print(f"✗ 无效的回调类型: {callback_type}")
        except Exception as e:
            print(f"✗ 预测回调注册失败: {e}")

    def unregister_prediction_callback(self, callback_type: str, callback: callable):
        """
        Remove a previously registered callback (no-op if it is not registered).

        Args:
            callback_type: Callback type key.
            callback: The callable to remove.
        """
        try:
            callbacks = self.prediction_callbacks.get(callback_type)
            if callbacks is not None and callback in callbacks:
                callbacks.remove(callback)
                print(f"✓ 预测回调已注销: {callback_type}")
        except Exception as e:
            print(f"✗ 预测回调注销失败: {e}")