EG/plugins/user/behavior_tree/nodes/control_nodes.py
2025-12-12 16:16:15 +08:00

396 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
行为树控制节点
包括选择节点、序列节点、并行节点等
提供多种控制流逻辑实现
"""
import time
import random
from typing import List, Optional
from ..core.behavior_tree import BTNode, NodeStatus, CompositeNode, BTNodeConfig
class SelectorNode(CompositeNode):
"""
选择节点OR节点
从左到右执行子节点,直到有一个成功为止
如果所有子节点都失败,则返回失败
支持优先级排序和中断机制
"""
def __init__(self, config: BTNodeConfig = None):
if config is None:
config = BTNodeConfig(name="Selector")
super().__init__(config)
self.current_child_index = 0
self.interruptible_children: List[int] = [] # 可中断的子节点索引
def add_interruptible_child(self, index: int) -> None:
"""添加可中断的子节点索引"""
if index not in self.interruptible_children:
self.interruptible_children.append(index)
def _execute_child_logic(self, blackboard) -> NodeStatus:
"""执行选择节点逻辑"""
# 从当前子节点开始执行
for i in range(self.current_child_index, len(self.children)):
child = self.children[i]
# 检查是否可以中断当前运行的子节点
if (self.current_child_index != i and
self.current_child_index in self.interruptible_children and
hasattr(child, 'config') and child.config.can_be_interrupted):
# 中断之前的子节点
if hasattr(self.children[self.current_child_index], 'on_abort'):
self.children[self.current_child_index].on_abort(blackboard)
status = child.execute(blackboard)
if status == NodeStatus.RUNNING:
# 记录当前正在运行的子节点索引
self.current_child_index = i
self.memory.set('running_child_index', i)
return NodeStatus.RUNNING
elif status == NodeStatus.SUCCESS:
# 重置当前子节点索引
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.SUCCESS
# 如果是FAILURE继续执行下一个子节点
# 所有子节点都失败
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.FAILURE
def reset(self) -> None:
"""重置节点状态"""
super().reset()
self.current_child_index = 0
self.memory.set('running_child_index', -1)
class SequenceNode(CompositeNode):
"""
序列节点AND节点
从左到右执行子节点,直到有一个失败为止
如果所有子节点都成功,则返回成功
支持失败重试和条件检查
"""
def __init__(self, config: BTNodeConfig = None, max_retries: int = 0):
if config is None:
config = BTNodeConfig(name="Sequence")
super().__init__(config)
self.current_child_index = 0
self.max_retries = max_retries # 最大重试次数
self.retry_count = 0 # 当前重试次数
self.failed_child_index = -1 # 最后失败的子节点索引
def _execute_child_logic(self, blackboard) -> NodeStatus:
"""执行序列节点逻辑"""
# 检查是否需要重置(所有子节点都成功后再次执行)
if self.current_child_index >= len(self.children):
self.current_child_index = 0
self.retry_count = 0
# 从当前子节点开始执行
for i in range(self.current_child_index, len(self.children)):
child = self.children[i]
status = child.execute(blackboard)
if status == NodeStatus.RUNNING:
# 记录当前正在运行的子节点索引
self.current_child_index = i
self.memory.set('running_child_index', i)
return NodeStatus.RUNNING
elif status == NodeStatus.FAILURE:
# 记录失败的子节点
self.failed_child_index = i
# 检查是否可以重试
if self.retry_count < self.max_retries:
self.retry_count += 1
# 重置到失败的子节点重新开始
self.current_child_index = i
return NodeStatus.RUNNING
else:
# 重置状态并返回失败
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.FAILURE
# 如果是SUCCESS继续执行下一个子节点
# 所有子节点都成功
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.SUCCESS
def reset(self) -> None:
"""重置节点状态"""
super().reset()
self.current_child_index = 0
self.retry_count = 0
self.failed_child_index = -1
self.memory.set('running_child_index', -1)
class ParallelNode(CompositeNode):
"""
并行节点
同时执行所有子节点
根据成功和失败的阈值决定返回状态
支持多种并行模式
"""
class ParallelMode:
"""并行模式枚举"""
SEQUENCE = "sequence" # 顺序启动
ALL = "all" # 同时启动所有
RACE = "race" # 竞速模式,第一个完成的决定结果
def __init__(self, config: BTNodeConfig = None,
success_threshold: int = 1,
failure_threshold: int = 1,
mode: str = ParallelMode.ALL):
if config is None:
config = BTNodeConfig(name="Parallel")
super().__init__(config)
self.success_threshold = success_threshold # 成功阈值
self.failure_threshold = failure_threshold # 失败阈值
self.mode = mode # 并行模式
self.child_results: List[NodeStatus] = [] # 子节点执行结果
self.child_start_times: List[float] = [] # 子节点开始时间
self.completed_count = 0 # 已完成的子节点数量
def _execute_child_logic(self, blackboard) -> NodeStatus:
"""执行并行节点逻辑"""
success_count = 0
failure_count = 0
running_count = 0
current_time = time.time()
# 初始化结果列表
while len(self.child_results) < len(self.children):
self.child_results.append(NodeStatus.RUNNING)
self.child_start_times.append(0.0)
# 根据模式执行子节点
if self.mode == self.ParallelMode.SEQUENCE:
self._execute_sequence_mode(blackboard, current_time)
elif self.mode == self.ParallelMode.RACE:
self._execute_race_mode(blackboard, current_time)
else: # ALL模式
self._execute_all_mode(blackboard, current_time)
# 统计结果
for status in self.child_results:
if status == NodeStatus.SUCCESS:
success_count += 1
elif status == NodeStatus.FAILURE:
failure_count += 1
elif status == NodeStatus.RUNNING:
running_count += 1
# 检查是否达到阈值
if success_count >= self.success_threshold:
# 重置状态
self._reset_state()
return NodeStatus.SUCCESS
elif failure_count >= self.failure_threshold:
# 重置状态
self._reset_state()
return NodeStatus.FAILURE
elif running_count == 0:
# 所有子节点都已完成但未达到阈值
self._reset_state()
return NodeStatus.FAILURE
else:
return NodeStatus.RUNNING
def _execute_all_mode(self, blackboard, current_time: float) -> None:
"""执行ALL模式"""
for i, child in enumerate(self.children):
# 如果该子节点还没有结果或者之前是RUNNING状态则执行它
if (i >= len(self.child_results) or
self.child_results[i] == NodeStatus.RUNNING):
# 记录开始时间
if self.child_start_times[i] == 0.0:
self.child_start_times[i] = current_time
status = child.execute(blackboard)
# 确保child_results列表足够长
while len(self.child_results) <= i:
self.child_results.append(NodeStatus.RUNNING)
self.child_start_times.append(0.0)
self.child_results[i] = status
def _execute_sequence_mode(self, blackboard, current_time: float) -> None:
"""执行SEQUENCE模式"""
# 只执行未完成的子节点,按顺序启动
for i, child in enumerate(self.children):
if i >= self.completed_count:
# 启动下一个子节点
if (i >= len(self.child_results) or
self.child_results[i] == NodeStatus.RUNNING):
# 记录开始时间
if self.child_start_times[i] == 0.0:
self.child_start_times[i] = current_time
status = child.execute(blackboard)
# 确保child_results列表足够长
while len(self.child_results) <= i:
self.child_results.append(NodeStatus.RUNNING)
self.child_start_times.append(0.0)
self.child_results[i] = status
# 如果子节点已完成,增加完成计数
if status != NodeStatus.RUNNING:
self.completed_count = i + 1
def _execute_race_mode(self, blackboard, current_time: float) -> None:
"""执行RACE模式"""
# 执行所有未完成的子节点,第一个完成的决定结果
for i, child in enumerate(self.children):
if (i >= len(self.child_results) or
self.child_results[i] == NodeStatus.RUNNING):
# 记录开始时间
if self.child_start_times[i] == 0.0:
self.child_start_times[i] = current_time
status = child.execute(blackboard)
# 确保child_results列表足够长
while len(self.child_results) <= i:
self.child_results.append(NodeStatus.RUNNING)
self.child_start_times.append(0.0)
self.child_results[i] = status
def _reset_state(self) -> None:
"""重置节点状态"""
self.child_results = []
self.child_start_times = []
self.completed_count = 0
def reset(self) -> None:
"""重置节点状态"""
super().reset()
self._reset_state()
class RandomSelectorNode(SelectorNode):
"""
随机选择节点
随机选择一个子节点执行,而不是按顺序
"""
def __init__(self, config: BTNodeConfig = None,
weighted: bool = False):
if config is None:
config = BTNodeConfig(name="RandomSelector")
super().__init__(config)
self.weighted = weighted # 是否使用权重
self.weights: List[float] = [] # 子节点权重
self.selected_index = -1 # 当前选择的索引
def add_weighted_child(self, child: BTNode, weight: float = 1.0) -> None:
"""添加带权重的子节点"""
self.add_child(child)
self.weights.append(weight)
def _execute_child_logic(self, blackboard) -> NodeStatus:
"""执行随机选择节点逻辑"""
if not self.children:
return NodeStatus.FAILURE
# 如果还没有选择子节点,随机选择一个
if self.selected_index == -1:
if self.weighted and self.weights:
# 使用权重随机选择
total_weight = sum(self.weights)
if total_weight > 0:
random_value = random.uniform(0, total_weight)
cumulative_weight = 0
for i, weight in enumerate(self.weights):
cumulative_weight += weight
if random_value <= cumulative_weight:
self.selected_index = i
break
# 确保选择了有效索引
if self.selected_index == -1:
self.selected_index = len(self.children) - 1
else:
# 权重总和为0均匀随机选择
self.selected_index = random.randint(0, len(self.children) - 1)
else:
# 均匀随机选择
self.selected_index = random.randint(0, len(self.children) - 1)
# 执行选中的子节点
selected_child = self.children[self.selected_index]
status = selected_child.execute(blackboard)
if status != NodeStatus.RUNNING:
# 子节点完成,重置选择
self.selected_index = -1
return status
def reset(self) -> None:
"""重置节点状态"""
super().reset()
self.selected_index = -1
class PrioritySelectorNode(SelectorNode):
"""
优先级选择节点
根据优先级选择子节点执行,优先级高的先执行
"""
def __init__(self, config: BTNodeConfig = None):
if config is None:
config = BTNodeConfig(name="PrioritySelector")
super().__init__(config)
self.priorities: List[int] = [] # 子节点优先级
self.sorted_indices: List[int] = [] # 按优先级排序的索引
def add_prioritized_child(self, child: BTNode, priority: int = 0) -> None:
"""添加带优先级的子节点"""
self.add_child(child)
self.priorities.append(priority)
self._sort_children()
def _sort_children(self) -> None:
"""根据优先级对子节点进行排序"""
# 创建索引列表并按优先级排序(优先级高的在前)
self.sorted_indices = list(range(len(self.children)))
self.sorted_indices.sort(key=lambda i: self.priorities[i], reverse=True)
def _execute_child_logic(self, blackboard) -> NodeStatus:
"""执行优先级选择节点逻辑"""
if not self.children:
return NodeStatus.FAILURE
# 按优先级顺序执行子节点
for i in self.sorted_indices:
child = self.children[i]
status = child.execute(blackboard)
if status == NodeStatus.RUNNING:
# 记录当前正在运行的子节点索引
self.current_child_index = i
self.memory.set('running_child_index', i)
return NodeStatus.RUNNING
elif status == NodeStatus.SUCCESS:
# 重置当前子节点索引
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.SUCCESS
# 如果是FAILURE继续执行下一个优先级的子节点
# 所有子节点都失败
self.current_child_index = 0
self.memory.set('running_child_index', -1)
return NodeStatus.FAILURE
def reset(self) -> None:
"""重置节点状态"""
super().reset()
self.current_child_index = 0
self.memory.set('running_child_index', -1)