396 lines
16 KiB
Python
396 lines
16 KiB
Python
"""
|
||
行为树控制节点
|
||
包括选择节点、序列节点、并行节点等
|
||
提供多种控制流逻辑实现
|
||
"""
|
||
|
||
import time
|
||
import random
|
||
from typing import List, Optional
|
||
from ..core.behavior_tree import BTNode, NodeStatus, CompositeNode, BTNodeConfig
|
||
|
||
class SelectorNode(CompositeNode):
|
||
"""
|
||
选择节点(OR节点)
|
||
从左到右执行子节点,直到有一个成功为止
|
||
如果所有子节点都失败,则返回失败
|
||
支持优先级排序和中断机制
|
||
"""
|
||
|
||
def __init__(self, config: BTNodeConfig = None):
|
||
if config is None:
|
||
config = BTNodeConfig(name="Selector")
|
||
super().__init__(config)
|
||
self.current_child_index = 0
|
||
self.interruptible_children: List[int] = [] # 可中断的子节点索引
|
||
|
||
def add_interruptible_child(self, index: int) -> None:
|
||
"""添加可中断的子节点索引"""
|
||
if index not in self.interruptible_children:
|
||
self.interruptible_children.append(index)
|
||
|
||
def _execute_child_logic(self, blackboard) -> NodeStatus:
|
||
"""执行选择节点逻辑"""
|
||
# 从当前子节点开始执行
|
||
for i in range(self.current_child_index, len(self.children)):
|
||
child = self.children[i]
|
||
|
||
# 检查是否可以中断当前运行的子节点
|
||
if (self.current_child_index != i and
|
||
self.current_child_index in self.interruptible_children and
|
||
hasattr(child, 'config') and child.config.can_be_interrupted):
|
||
# 中断之前的子节点
|
||
if hasattr(self.children[self.current_child_index], 'on_abort'):
|
||
self.children[self.current_child_index].on_abort(blackboard)
|
||
|
||
status = child.execute(blackboard)
|
||
|
||
if status == NodeStatus.RUNNING:
|
||
# 记录当前正在运行的子节点索引
|
||
self.current_child_index = i
|
||
self.memory.set('running_child_index', i)
|
||
return NodeStatus.RUNNING
|
||
elif status == NodeStatus.SUCCESS:
|
||
# 重置当前子节点索引
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.SUCCESS
|
||
# 如果是FAILURE,继续执行下一个子节点
|
||
|
||
# 所有子节点都失败
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.FAILURE
|
||
|
||
def reset(self) -> None:
|
||
"""重置节点状态"""
|
||
super().reset()
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
|
||
class SequenceNode(CompositeNode):
|
||
"""
|
||
序列节点(AND节点)
|
||
从左到右执行子节点,直到有一个失败为止
|
||
如果所有子节点都成功,则返回成功
|
||
支持失败重试和条件检查
|
||
"""
|
||
|
||
def __init__(self, config: BTNodeConfig = None, max_retries: int = 0):
|
||
if config is None:
|
||
config = BTNodeConfig(name="Sequence")
|
||
super().__init__(config)
|
||
self.current_child_index = 0
|
||
self.max_retries = max_retries # 最大重试次数
|
||
self.retry_count = 0 # 当前重试次数
|
||
self.failed_child_index = -1 # 最后失败的子节点索引
|
||
|
||
def _execute_child_logic(self, blackboard) -> NodeStatus:
|
||
"""执行序列节点逻辑"""
|
||
# 检查是否需要重置(所有子节点都成功后再次执行)
|
||
if self.current_child_index >= len(self.children):
|
||
self.current_child_index = 0
|
||
self.retry_count = 0
|
||
|
||
# 从当前子节点开始执行
|
||
for i in range(self.current_child_index, len(self.children)):
|
||
child = self.children[i]
|
||
status = child.execute(blackboard)
|
||
|
||
if status == NodeStatus.RUNNING:
|
||
# 记录当前正在运行的子节点索引
|
||
self.current_child_index = i
|
||
self.memory.set('running_child_index', i)
|
||
return NodeStatus.RUNNING
|
||
elif status == NodeStatus.FAILURE:
|
||
# 记录失败的子节点
|
||
self.failed_child_index = i
|
||
|
||
# 检查是否可以重试
|
||
if self.retry_count < self.max_retries:
|
||
self.retry_count += 1
|
||
# 重置到失败的子节点重新开始
|
||
self.current_child_index = i
|
||
return NodeStatus.RUNNING
|
||
else:
|
||
# 重置状态并返回失败
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.FAILURE
|
||
# 如果是SUCCESS,继续执行下一个子节点
|
||
|
||
# 所有子节点都成功
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.SUCCESS
|
||
|
||
def reset(self) -> None:
|
||
"""重置节点状态"""
|
||
super().reset()
|
||
self.current_child_index = 0
|
||
self.retry_count = 0
|
||
self.failed_child_index = -1
|
||
self.memory.set('running_child_index', -1)
|
||
|
||
class ParallelNode(CompositeNode):
|
||
"""
|
||
并行节点
|
||
同时执行所有子节点
|
||
根据成功和失败的阈值决定返回状态
|
||
支持多种并行模式
|
||
"""
|
||
|
||
class ParallelMode:
|
||
"""并行模式枚举"""
|
||
SEQUENCE = "sequence" # 顺序启动
|
||
ALL = "all" # 同时启动所有
|
||
RACE = "race" # 竞速模式,第一个完成的决定结果
|
||
|
||
def __init__(self, config: BTNodeConfig = None,
|
||
success_threshold: int = 1,
|
||
failure_threshold: int = 1,
|
||
mode: str = ParallelMode.ALL):
|
||
if config is None:
|
||
config = BTNodeConfig(name="Parallel")
|
||
super().__init__(config)
|
||
self.success_threshold = success_threshold # 成功阈值
|
||
self.failure_threshold = failure_threshold # 失败阈值
|
||
self.mode = mode # 并行模式
|
||
self.child_results: List[NodeStatus] = [] # 子节点执行结果
|
||
self.child_start_times: List[float] = [] # 子节点开始时间
|
||
self.completed_count = 0 # 已完成的子节点数量
|
||
|
||
def _execute_child_logic(self, blackboard) -> NodeStatus:
|
||
"""执行并行节点逻辑"""
|
||
success_count = 0
|
||
failure_count = 0
|
||
running_count = 0
|
||
|
||
current_time = time.time()
|
||
|
||
# 初始化结果列表
|
||
while len(self.child_results) < len(self.children):
|
||
self.child_results.append(NodeStatus.RUNNING)
|
||
self.child_start_times.append(0.0)
|
||
|
||
# 根据模式执行子节点
|
||
if self.mode == self.ParallelMode.SEQUENCE:
|
||
self._execute_sequence_mode(blackboard, current_time)
|
||
elif self.mode == self.ParallelMode.RACE:
|
||
self._execute_race_mode(blackboard, current_time)
|
||
else: # ALL模式
|
||
self._execute_all_mode(blackboard, current_time)
|
||
|
||
# 统计结果
|
||
for status in self.child_results:
|
||
if status == NodeStatus.SUCCESS:
|
||
success_count += 1
|
||
elif status == NodeStatus.FAILURE:
|
||
failure_count += 1
|
||
elif status == NodeStatus.RUNNING:
|
||
running_count += 1
|
||
|
||
# 检查是否达到阈值
|
||
if success_count >= self.success_threshold:
|
||
# 重置状态
|
||
self._reset_state()
|
||
return NodeStatus.SUCCESS
|
||
elif failure_count >= self.failure_threshold:
|
||
# 重置状态
|
||
self._reset_state()
|
||
return NodeStatus.FAILURE
|
||
elif running_count == 0:
|
||
# 所有子节点都已完成但未达到阈值
|
||
self._reset_state()
|
||
return NodeStatus.FAILURE
|
||
else:
|
||
return NodeStatus.RUNNING
|
||
|
||
def _execute_all_mode(self, blackboard, current_time: float) -> None:
|
||
"""执行ALL模式"""
|
||
for i, child in enumerate(self.children):
|
||
# 如果该子节点还没有结果或者之前是RUNNING状态,则执行它
|
||
if (i >= len(self.child_results) or
|
||
self.child_results[i] == NodeStatus.RUNNING):
|
||
# 记录开始时间
|
||
if self.child_start_times[i] == 0.0:
|
||
self.child_start_times[i] = current_time
|
||
|
||
status = child.execute(blackboard)
|
||
# 确保child_results列表足够长
|
||
while len(self.child_results) <= i:
|
||
self.child_results.append(NodeStatus.RUNNING)
|
||
self.child_start_times.append(0.0)
|
||
self.child_results[i] = status
|
||
|
||
def _execute_sequence_mode(self, blackboard, current_time: float) -> None:
|
||
"""执行SEQUENCE模式"""
|
||
# 只执行未完成的子节点,按顺序启动
|
||
for i, child in enumerate(self.children):
|
||
if i >= self.completed_count:
|
||
# 启动下一个子节点
|
||
if (i >= len(self.child_results) or
|
||
self.child_results[i] == NodeStatus.RUNNING):
|
||
# 记录开始时间
|
||
if self.child_start_times[i] == 0.0:
|
||
self.child_start_times[i] = current_time
|
||
|
||
status = child.execute(blackboard)
|
||
# 确保child_results列表足够长
|
||
while len(self.child_results) <= i:
|
||
self.child_results.append(NodeStatus.RUNNING)
|
||
self.child_start_times.append(0.0)
|
||
self.child_results[i] = status
|
||
|
||
# 如果子节点已完成,增加完成计数
|
||
if status != NodeStatus.RUNNING:
|
||
self.completed_count = i + 1
|
||
|
||
def _execute_race_mode(self, blackboard, current_time: float) -> None:
|
||
"""执行RACE模式"""
|
||
# 执行所有未完成的子节点,第一个完成的决定结果
|
||
for i, child in enumerate(self.children):
|
||
if (i >= len(self.child_results) or
|
||
self.child_results[i] == NodeStatus.RUNNING):
|
||
# 记录开始时间
|
||
if self.child_start_times[i] == 0.0:
|
||
self.child_start_times[i] = current_time
|
||
|
||
status = child.execute(blackboard)
|
||
# 确保child_results列表足够长
|
||
while len(self.child_results) <= i:
|
||
self.child_results.append(NodeStatus.RUNNING)
|
||
self.child_start_times.append(0.0)
|
||
self.child_results[i] = status
|
||
|
||
def _reset_state(self) -> None:
|
||
"""重置节点状态"""
|
||
self.child_results = []
|
||
self.child_start_times = []
|
||
self.completed_count = 0
|
||
|
||
def reset(self) -> None:
|
||
"""重置节点状态"""
|
||
super().reset()
|
||
self._reset_state()
|
||
|
||
class RandomSelectorNode(SelectorNode):
|
||
"""
|
||
随机选择节点
|
||
随机选择一个子节点执行,而不是按顺序
|
||
"""
|
||
|
||
def __init__(self, config: BTNodeConfig = None,
|
||
weighted: bool = False):
|
||
if config is None:
|
||
config = BTNodeConfig(name="RandomSelector")
|
||
super().__init__(config)
|
||
self.weighted = weighted # 是否使用权重
|
||
self.weights: List[float] = [] # 子节点权重
|
||
self.selected_index = -1 # 当前选择的索引
|
||
|
||
def add_weighted_child(self, child: BTNode, weight: float = 1.0) -> None:
|
||
"""添加带权重的子节点"""
|
||
self.add_child(child)
|
||
self.weights.append(weight)
|
||
|
||
def _execute_child_logic(self, blackboard) -> NodeStatus:
|
||
"""执行随机选择节点逻辑"""
|
||
if not self.children:
|
||
return NodeStatus.FAILURE
|
||
|
||
# 如果还没有选择子节点,随机选择一个
|
||
if self.selected_index == -1:
|
||
if self.weighted and self.weights:
|
||
# 使用权重随机选择
|
||
total_weight = sum(self.weights)
|
||
if total_weight > 0:
|
||
random_value = random.uniform(0, total_weight)
|
||
cumulative_weight = 0
|
||
for i, weight in enumerate(self.weights):
|
||
cumulative_weight += weight
|
||
if random_value <= cumulative_weight:
|
||
self.selected_index = i
|
||
break
|
||
# 确保选择了有效索引
|
||
if self.selected_index == -1:
|
||
self.selected_index = len(self.children) - 1
|
||
else:
|
||
# 权重总和为0,均匀随机选择
|
||
self.selected_index = random.randint(0, len(self.children) - 1)
|
||
else:
|
||
# 均匀随机选择
|
||
self.selected_index = random.randint(0, len(self.children) - 1)
|
||
|
||
# 执行选中的子节点
|
||
selected_child = self.children[self.selected_index]
|
||
status = selected_child.execute(blackboard)
|
||
|
||
if status != NodeStatus.RUNNING:
|
||
# 子节点完成,重置选择
|
||
self.selected_index = -1
|
||
|
||
return status
|
||
|
||
def reset(self) -> None:
|
||
"""重置节点状态"""
|
||
super().reset()
|
||
self.selected_index = -1
|
||
|
||
class PrioritySelectorNode(SelectorNode):
|
||
"""
|
||
优先级选择节点
|
||
根据优先级选择子节点执行,优先级高的先执行
|
||
"""
|
||
|
||
def __init__(self, config: BTNodeConfig = None):
|
||
if config is None:
|
||
config = BTNodeConfig(name="PrioritySelector")
|
||
super().__init__(config)
|
||
self.priorities: List[int] = [] # 子节点优先级
|
||
self.sorted_indices: List[int] = [] # 按优先级排序的索引
|
||
|
||
def add_prioritized_child(self, child: BTNode, priority: int = 0) -> None:
|
||
"""添加带优先级的子节点"""
|
||
self.add_child(child)
|
||
self.priorities.append(priority)
|
||
self._sort_children()
|
||
|
||
def _sort_children(self) -> None:
|
||
"""根据优先级对子节点进行排序"""
|
||
# 创建索引列表并按优先级排序(优先级高的在前)
|
||
self.sorted_indices = list(range(len(self.children)))
|
||
self.sorted_indices.sort(key=lambda i: self.priorities[i], reverse=True)
|
||
|
||
def _execute_child_logic(self, blackboard) -> NodeStatus:
|
||
"""执行优先级选择节点逻辑"""
|
||
if not self.children:
|
||
return NodeStatus.FAILURE
|
||
|
||
# 按优先级顺序执行子节点
|
||
for i in self.sorted_indices:
|
||
child = self.children[i]
|
||
status = child.execute(blackboard)
|
||
|
||
if status == NodeStatus.RUNNING:
|
||
# 记录当前正在运行的子节点索引
|
||
self.current_child_index = i
|
||
self.memory.set('running_child_index', i)
|
||
return NodeStatus.RUNNING
|
||
elif status == NodeStatus.SUCCESS:
|
||
# 重置当前子节点索引
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.SUCCESS
|
||
# 如果是FAILURE,继续执行下一个优先级的子节点
|
||
|
||
# 所有子节点都失败
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1)
|
||
return NodeStatus.FAILURE
|
||
|
||
def reset(self) -> None:
|
||
"""重置节点状态"""
|
||
super().reset()
|
||
self.current_child_index = 0
|
||
self.memory.set('running_child_index', -1) |