EG/plugins/user/swarm_intelligence/advanced_boids.py
2025-12-12 16:16:15 +08:00

908 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
高级Boids算法实现
这是一个经过优化和扩展的Boids算法实现包含空间分区优化、并行计算和高级行为规则
"""
from panda3d.core import Vec3, Point3
import random
import math
from typing import List, Dict, Tuple, Optional
import concurrent.futures
import threading
from multiprocessing import Pool, cpu_count
import numpy as np
class SpatialPartition:
"""
空间分区类,用于优化邻居查找算法
将3D空间划分为网格每个网格格子存储位于其中的群体成员
"""
def __init__(self, bounds: Dict[str, float], cell_size: float):
"""
初始化空间分区
:param bounds: 边界信息 {'min_x': float, 'max_x': float, 'min_y': float, 'max_y': float, 'min_z': float, 'max_z': float}
:param cell_size: 网格格子大小
"""
self.bounds = bounds
self.cell_size = cell_size
# 计算网格尺寸
self.grid_width = int((bounds['max_x'] - bounds['min_x']) / cell_size) + 1
self.grid_height = int((bounds['max_y'] - bounds['min_y']) / cell_size) + 1
self.grid_depth = int((bounds['max_z'] - bounds['min_z']) / cell_size) + 1
# 创建网格
self.grid = {}
# 邻居格子偏移量
self.neighbor_offsets = [
(-1, -1, -1), (-1, -1, 0), (-1, -1, 1),
(-1, 0, -1), (-1, 0, 0), (-1, 0, 1),
(-1, 1, -1), (-1, 1, 0), (-1, 1, 1),
(0, -1, -1), (0, -1, 0), (0, -1, 1),
(0, 0, -1), (0, 0, 0), (0, 0, 1),
(0, 1, -1), (0, 1, 0), (0, 1, 1),
(1, -1, -1), (1, -1, 0), (1, -1, 1),
(1, 0, -1), (1, 0, 0), (1, 0, 1),
(1, 1, -1), (1, 1, 0), (1, 1, 1)
]
def get_grid_coords(self, position: Vec3) -> Tuple[int, int, int]:
"""
获取位置对应的网格坐标
:param position: 3D位置
:return: (x, y, z) 网格坐标
"""
grid_x = int((position.x - self.bounds['min_x']) / self.cell_size)
grid_y = int((position.y - self.bounds['min_y']) / self.cell_size)
grid_z = int((position.z - self.bounds['min_z']) / self.cell_size)
# 确保坐标在有效范围内
grid_x = max(0, min(grid_x, self.grid_width - 1))
grid_y = max(0, min(grid_y, self.grid_height - 1))
grid_z = max(0, min(grid_z, self.grid_depth - 1))
return (grid_x, grid_y, grid_z)
def add_member(self, member: Dict, position: Vec3):
"""
将成员添加到网格中
:param member: 成员对象
:param position: 成员位置
"""
grid_coords = self.get_grid_coords(position)
grid_key = f"{grid_coords[0]}_{grid_coords[1]}_{grid_coords[2]}"
if grid_key not in self.grid:
self.grid[grid_key] = []
self.grid[grid_key].append(member)
def remove_member(self, member: Dict, position: Vec3):
"""
从网格中移除成员
:param member: 成员对象
:param position: 成员位置
"""
grid_coords = self.get_grid_coords(position)
grid_key = f"{grid_coords[0]}_{grid_coords[1]}_{grid_coords[2]}"
if grid_key in self.grid:
if member in self.grid[grid_key]:
self.grid[grid_key].remove(member)
def update_member(self, member: Dict, old_position: Vec3, new_position: Vec3):
"""
更新成员在网格中的位置
:param member: 成员对象
:param old_position: 原位置
:param new_position: 新位置
"""
# 从旧位置移除
self.remove_member(member, old_position)
# 添加到新位置
self.add_member(member, new_position)
def get_neighbors(self, position: Vec3, radius: float) -> List[Dict]:
"""
获取指定位置周围的邻居
:param position: 中心位置
:param radius: 搜索半径
:return: 邻居列表
"""
neighbors = []
# 计算需要检查的网格格子范围
cells_to_check = int(radius / self.cell_size) + 1
center_coords = self.get_grid_coords(position)
# 检查中心格子及其邻居
for offset in self.neighbor_offsets:
grid_x = center_coords[0] + offset[0]
grid_y = center_coords[1] + offset[1]
grid_z = center_coords[2] + offset[2]
# 检查格子是否在有效范围内
if (0 <= grid_x < self.grid_width and
0 <= grid_y < self.grid_height and
0 <= grid_z < self.grid_depth):
grid_key = f"{grid_x}_{grid_y}_{grid_z}"
if grid_key in self.grid:
# 检查格子中的每个成员是否在半径范围内
for member in self.grid[grid_key]:
distance = (position - member['position']).length()
if distance <= radius and member not in neighbors:
neighbors.append(member)
return neighbors
def clear(self):
"""
清空网格
"""
self.grid.clear()
class PhysicsSimulator:
"""
物理模拟器,提供更真实的物理行为
"""
def __init__(self):
self.gravity = Vec3(0, 0, -9.81) # 重力加速度
self.air_resistance = 0.01 # 空气阻力系数
self.fluid_density = 1.225 # 空气密度 (kg/m^3)
def apply_gravity(self, member: Dict, dt: float):
"""
应用重力
"""
if 'mass' in member:
mass = member['mass']
else:
mass = 1.0 # 默认质量
gravity_force = self.gravity * mass
member['velocity'] += gravity_force * dt / mass
def apply_air_resistance(self, member: Dict, dt: float):
"""
应用空气阻力
"""
velocity = member['velocity']
speed = velocity.length()
if speed > 0:
# 空气阻力与速度平方成正比
drag_force_magnitude = 0.5 * self.fluid_density * speed * speed * 0.01 # 简化的阻力系数
drag_force = -velocity.normalized() * drag_force_magnitude
# 应用阻力
member['velocity'] += drag_force * dt
def apply_buoyancy(self, member: Dict, fluid_level: float, dt: float):
"""
应用浮力(适用于模拟水中生物)
"""
if member['position'].z < fluid_level:
# 计算浮力(简化模型)
buoyancy_force = Vec3(0, 0, 0.5) # 向上的浮力
member['velocity'] += buoyancy_force * dt
def update_physics(self, member: Dict, dt: float = 1.0/60.0):
"""
更新物理状态
"""
# 应用重力
self.apply_gravity(member, dt)
# 应用空气阻力
self.apply_air_resistance(member, dt)
class AdvancedBoidsAlgorithm:
"""
高级Boids算法实现类
包含空间分区优化、物理模拟和高级行为规则
"""
def __init__(self, config):
self.config = config
# 初始化空间分区系统
bounds = {
'min_x': -100, 'max_x': 100,
'min_y': -100, 'max_y': 100,
'min_z': -10, 'max_z': 100
}
cell_size = self.config.get('perception_radius', 10.0) # 使用感知半径作为网格大小
self.spatial_partition = SpatialPartition(bounds, cell_size)
# 初始化物理模拟器
self.physics_simulator = PhysicsSimulator()
# 用于记录性能统计
self.stats = {
'neighbor_lookups': 0,
'total_neighbors_found': 0,
'algorithm_calls': 0
}
def calculate_cohesion(self, member: Dict, neighbors: List[Dict]) -> Vec3:
"""
计算聚集力
个体向邻居群集中心移动
"""
if not neighbors:
return Vec3(0, 0, 0)
# 计算邻居的中心位置
center = Vec3(0, 0, 0)
for neighbor in neighbors:
center += neighbor['position']
center /= len(neighbors)
# 返回指向中心的向量
return (center - member['position']).normalized()
def calculate_separation(self, member: Dict, neighbors: List[Dict]) -> Vec3:
"""
计算分离力
个体避免与邻居过于靠近
"""
if not neighbors:
return Vec3(0, 0, 0)
separation = Vec3(0, 0, 0)
for neighbor in neighbors:
diff = member['position'] - neighbor['position']
distance = diff.length()
if distance > 0: # 避免除零
# 距离越近,分离力越大
separation += diff.normalized() / distance
if len(neighbors) > 0:
separation /= len(neighbors)
return separation.normalized() if separation.length() > 0 else separation
def calculate_alignment(self, member: Dict, neighbors: List[Dict]) -> Vec3:
"""
计算对齐力
个体与邻居的移动方向保持一致
"""
if not neighbors:
return Vec3(0, 0, 0)
# 计算邻居的平均速度方向
avg_velocity = Vec3(0, 0, 0)
for neighbor in neighbors:
avg_velocity += neighbor['velocity']
avg_velocity /= len(neighbors)
return avg_velocity.normalized()
def calculate_cohesion_with_weighted_neighbors(self, member: Dict, neighbors: List[Dict]) -> Vec3:
"""
计算加权聚集力
距离越近的邻居对聚集力的影响越大
"""
if not neighbors:
return Vec3(0, 0, 0)
weighted_center = Vec3(0, 0, 0)
total_weight = 0.0
for neighbor in neighbors:
distance = (member['position'] - neighbor['position']).length()
if distance > 0:
# 距离越近,权重越大
weight = 1.0 / distance
weighted_center += neighbor['position'] * weight
total_weight += weight
if total_weight > 0:
weighted_center /= total_weight
return (weighted_center - member['position']).normalized()
else:
return Vec3(0, 0, 0)
def calculate_separation_with_personal_space(self, member: Dict, neighbors: List[Dict]) -> Vec3:
"""
计算分离力,考虑个人空间
"""
if not neighbors:
return Vec3(0, 0, 0)
separation = Vec3(0, 0, 0)
personal_space = self.config.get("personal_space", 2.0) # 个人空间半径
for neighbor in neighbors:
diff = member['position'] - neighbor['position']
distance = diff.length()
if distance > 0 and distance < personal_space:
# 仅在邻居进入个人空间时才产生分离力
separation += diff.normalized() * (personal_space - distance) / personal_space
return separation.normalized() if separation.length() > 0 else separation
def calculate_obstacle_avoidance(self, member: Dict, obstacles: List[Dict]) -> Vec3:
"""
计算避障力
避开场景中的障碍物
"""
if not obstacles:
return Vec3(0, 0, 0)
avoidance = Vec3(0, 0, 0)
obstacle_radius = self.config.get("obstacle_radius", 5.0)
for obstacle in obstacles:
diff = member['position'] - obstacle['position']
distance = diff.length()
if distance < obstacle_radius + obstacle.get('radius', 1.0):
# 距离越近,避障力越大
force_magnitude = max(0, (obstacle_radius + obstacle.get('radius', 1.0)) - distance)
avoidance += diff.normalized() * force_magnitude
return avoidance.normalized() if avoidance.length() > 0 else avoidance
def calculate_seek(self, member: Dict, target: Vec3) -> Vec3:
"""
计算寻求力
向目标位置移动
"""
desired = target - member['position']
distance = desired.length()
if distance > 0:
desired.normalize()
# 如果启用了达到行为,当接近目标时减速
if self.config.get("arrival_enabled", False):
slow_radius = self.config.get("arrival_slow_radius", 5.0)
if distance < slow_radius:
desired *= self.config.get("max_speed", 5.0) * (distance / slow_radius)
else:
desired *= self.config.get("max_speed", 5.0)
else:
desired *= self.config.get("max_speed", 5.0)
# 计算转向力
steer = desired - member['velocity']
# 限制转向力
max_force = self.config.get("max_force", 0.5)
if steer.length() > max_force:
steer.normalize()
steer *= max_force
return steer
return Vec3(0, 0, 0)
def calculate_flee(self, member: Dict, target: Vec3) -> Vec3:
"""
计算逃离力
远离目标位置
"""
desired = member['position'] - target
distance = desired.length()
if distance > 0:
desired.normalize()
desired *= self.config.get("max_speed", 5.0)
# 计算转向力
steer = desired - member['velocity']
# 限制转向力
max_force = self.config.get("max_force", 0.5)
if steer.length() > max_force:
steer.normalize()
steer *= max_force
return steer
return Vec3(0, 0, 0)
def calculate_wander(self, member: Dict) -> Vec3:
"""
计算游走力
产生自然的随机移动
"""
# 获取游走参数
wander_radius = self.config.get("wander_radius", 5.0)
wander_distance = self.config.get("wander_distance", 10.0)
wander_jitter = self.config.get("wander_jitter", 1.0)
# 如果成员没有游走目标,创建一个
if 'wander_target' not in member:
member['wander_target'] = Vec3(
random.uniform(-1, 1),
random.uniform(-1, 1),
random.uniform(-1, 1)
).normalized() * wander_radius
# 添加随机扰动
jitter = Vec3(
random.uniform(-1, 1),
random.uniform(-1, 1),
random.uniform(-1, 1)
).normalized() * wander_jitter
member['wander_target'] += jitter
member['wander_target'].normalize()
member['wander_target'] *= wander_radius
# 计算游走目标在前方的距离
target = member['wander_target'] + Vec3(0, 0, 0) # 相对于成员前方的点
target *= wander_distance
# 转换到世界坐标系
# 这里简化处理,实际应用中需要考虑成员的朝向
desired = target - member['position']
if desired.length() > 0:
desired.normalize()
desired *= self.config.get("max_speed", 5.0)
# 计算转向力
steer = desired - member['velocity']
# 限制转向力
max_force = self.config.get("max_force", 0.5)
if steer.length() > max_force:
steer.normalize()
steer *= max_force
return steer
return Vec3(0, 0, 0)
def calculate_boundaries(self, member: Dict, bounds: Dict[str, float]) -> Vec3:
"""
计算边界力
将成员保持在指定区域内
"""
if not bounds:
return Vec3(0, 0, 0)
# 获取边界参数
min_x, max_x = bounds['min_x'], bounds['max_x']
min_y, max_y = bounds['min_y'], bounds['max_y']
min_z, max_z = bounds['min_z'], bounds['max_z']
# 边界缓冲距离
buffer = self.config.get("boundary_buffer", 5.0)
# 计算边界力
bound_force = Vec3(0, 0, 0)
if member['position'].x < min_x + buffer:
bound_force.x = 1.0
elif member['position'].x > max_x - buffer:
bound_force.x = -1.0
if member['position'].y < min_y + buffer:
bound_force.y = 1.0
elif member['position'].y > max_y - buffer:
bound_force.y = -1.0
if member['position'].z < min_z + buffer:
bound_force.z = 1.0
elif member['position'].z > max_z - buffer:
bound_force.z = -1.0
# 应用边界权重
bound_force *= self.config.get("boundary_weight", 2.0)
return bound_force
def calculate_follow_path(self, member: Dict, path: List[Vec3]) -> Vec3:
"""
计算路径跟随力
沿着预定义路径移动
"""
if not path or len(path) < 2:
return Vec3(0, 0, 0)
# 获取路径参数
path_radius = self.config.get("path_radius", 3.0)
# 找到最近的路径点
nearest_index = 0
nearest_distance = float('inf')
for i, point in enumerate(path):
distance = (member['position'] - point).length()
if distance < nearest_distance:
nearest_distance = distance
nearest_index = i
# 预测位置
prediction = member['velocity'].normalized()
prediction *= self.config.get("prediction_distance", 10.0)
predict_pos = member['position'] + prediction
# 找到预测位置最近的路径点
predict_nearest_index = 0
predict_nearest_distance = float('inf')
for i, point in enumerate(path):
distance = (predict_pos - point).length()
if distance < predict_nearest_distance:
predict_nearest_distance = distance
predict_nearest_index = i
# 如果偏离路径太远,直接寻求最近的路径点
if nearest_distance > path_radius:
target = path[nearest_index]
return self.calculate_seek(member, target)
# 否则寻求预测位置前方的路径点
if predict_nearest_index < len(path) - 1:
target = path[predict_nearest_index + 1]
return self.calculate_seek(member, target)
return Vec3(0, 0, 0)
def calculate_flock_formation(self, member: Dict, neighbors: List[Dict], formation_type: str) -> Vec3:
"""
计算队形保持力
维持特定的群体队形
"""
if not neighbors:
return Vec3(0, 0, 0)
# 根据队形类型计算期望位置
target_position = Vec3(0, 0, 0)
if formation_type == "V":
# V字形队形
# 简化实现,实际应用中需要更复杂的计算
if 'index' in member:
index = member['index']
angle = math.radians(index * 15) # 每个成员间隔15度
distance = 5.0 # 固定间距
target_position = Vec3(
math.cos(angle) * distance,
math.sin(angle) * distance,
0
)
elif formation_type == "line":
# 直线队形
if 'index' in member:
index = member['index']
spacing = 3.0
target_position = Vec3(0, index * spacing, 0)
elif formation_type == "circle":
# 圆形队形
if 'index' in member:
index = member['index']
total = len(neighbors) + 1
angle = 2 * math.pi * index / total
radius = 10.0
target_position = Vec3(
math.cos(angle) * radius,
math.sin(angle) * radius,
0
)
elif formation_type == "sphere":
# 球形队形
if 'index' in member:
index = member['index']
total = len(neighbors) + 1
# 使用斐波那契螺旋法在球面上分布点
phi = math.pi * (3 - math.sqrt(5)) # 金角
y = 1 - (index / float(total - 1)) * 2 # y: 1 to -1
radius = math.sqrt(1 - y * y) # 半径在y处
theta = phi * index # 金角旋转
x = math.cos(theta) * radius
z = math.sin(theta) * radius
target_position = Vec3(x * 10, y * 10, z * 10) # 缩放到合适大小
# 计算到达期望位置的力
if target_position.length() > 0:
desired = target_position - member['position']
if desired.length() > 0:
desired.normalize()
desired *= self.config.get("max_speed", 5.0)
# 计算转向力
steer = desired - member['velocity']
# 应用队形权重
steer *= self.config.get("formation_weight", 1.0)
# 限制转向力
max_force = self.config.get("max_force", 0.5)
if steer.length() > max_force:
steer.normalize()
steer *= max_force
return steer
return Vec3(0, 0, 0)
def calculate_avoid_predator(self, member: Dict, predators: List[Dict]) -> Vec3:
"""
计算躲避捕食者力
远离捕食者
"""
if not predators:
return Vec3(0, 0, 0)
avoidance = Vec3(0, 0, 0)
predator_radius = self.config.get("predator_radius", 15.0)
for predator in predators:
diff = member['position'] - predator['position']
distance = diff.length()
if distance < predator_radius:
# 距离越近,躲避力越大
force = diff.normalized() / (distance / predator_radius)
avoidance += force
# 应用捕食者躲避权重
avoidance *= self.config.get("predator_avoid_weight", 3.0)
return avoidance.normalized() if avoidance.length() > 0 else avoidance
def calculate_cohesion_with_predator_awareness(self, member: Dict, neighbors: List[Dict], predators: List[Dict]) -> Vec3:
"""
计算受捕食者影响的聚集力
在有捕食者时,成员会更紧密地聚集
"""
if not neighbors:
return Vec3(0, 0, 0)
# 检查是否有捕食者在附近
has_nearby_predator = False
for predator in predators:
if (member['position'] - predator['position']).length() < self.config.get("predator_radius", 15.0):
has_nearby_predator = True
break
# 如果有捕食者在附近,增加聚集权重
target_factor = 1.0
if has_nearby_predator:
target_factor = 2.0
# 计算邻居的中心位置
center = Vec3(0, 0, 0)
for neighbor in neighbors:
center += neighbor['position']
center /= len(neighbors)
# 返回指向中心的向量
return (center - member['position']).normalized() * target_factor
def calculate_formation_with_obstacle_avoidance(self, member: Dict, neighbors: List[Dict], formation_type: str, obstacles: List[Dict]) -> Vec3:
"""
计算考虑障碍物的队形保持力
在保持队形的同时避开障碍物
"""
# 首先计算队形力
formation_force = self.calculate_flock_formation(member, neighbors, formation_type)
# 然后计算避障力
obstacle_force = self.calculate_obstacle_avoidance(member, obstacles)
# 结合两种力量
combined_force = formation_force + obstacle_force * 1.5 # 避障力权重稍大
return combined_force.normalized() * formation_force.length()
def update_member(self, member: Dict, neighbors: List[Dict], obstacles: List[Dict] = [],
target: Vec3 = None, bounds: Dict[str, float] = None,
path: List[Vec3] = None, predators: List[Dict] = [],
interaction_forces: Dict = None, environment_force: Vec3 = None):
"""
更新群体成员状态
计算所有作用力并更新成员的位置和速度
"""
# 记录统计信息
self.stats['algorithm_calls'] += 1
self.stats['total_neighbors_found'] += len(neighbors)
# 计算基本Boids力
cohesion_force = self.calculate_cohesion_with_weighted_neighbors(member, neighbors)
separation_force = self.calculate_separation_with_personal_space(member, neighbors)
alignment_force = self.calculate_alignment(member, neighbors)
# 计算扩展力
obstacle_force = self.calculate_obstacle_avoidance(member, obstacles)
bound_force = self.calculate_boundaries(member, bounds)
predator_force = self.calculate_avoid_predator(member, predators)
# 计算目标导向力
seek_force = Vec3(0, 0, 0)
if target:
seek_force = self.calculate_seek(member, target)
# 计算路径跟随力
path_force = Vec3(0, 0, 0)
if path:
path_force = self.calculate_follow_path(member, path)
# 计算游走力
wander_force = Vec3(0, 0, 0)
if self.config.get("wander_enabled", False):
wander_force = self.calculate_wander(member)
# 计算受捕食者影响的聚集力
predator_aware_cohesion = self.calculate_cohesion_with_predator_awareness(member, neighbors, predators)
# 计算群体间交互力
interaction_force = Vec3(0, 0, 0)
if interaction_forces:
# 合作力
if 'cooperative_force' in interaction_forces:
interaction_force += interaction_forces['cooperative_force']
# 竞争力
if 'competition_force' in interaction_forces:
interaction_force += interaction_forces['competition_force']
# 捕食力
if 'predation_force' in interaction_forces:
interaction_force += interaction_forces['predation_force']
# 逃避力
if 'escape_force' in interaction_forces:
interaction_force += interaction_forces['escape_force']
# 应用权重
cohesion_force *= self.config.get("cohesion_weight", 1.0)
separation_force *= self.config.get("separation_weight", 1.5)
alignment_force *= self.config.get("alignment_weight", 1.0)
obstacle_force *= self.config.get("obstacle_weight", 2.0)
seek_force *= self.config.get("seek_weight", 1.0)
bound_force *= self.config.get("boundary_weight", 2.0)
path_force *= self.config.get("path_weight", 1.0)
wander_force *= self.config.get("wander_weight", 0.5)
predator_force *= self.config.get("predator_avoid_weight", 3.0)
predator_aware_cohesion *= self.config.get("cohesion_weight", 1.0) * 0.5 # 给予一半权重
interaction_force *= self.config.get("interaction_weight", 1.0) # 交互力权重
# 计算总加速度
acceleration = (
cohesion_force +
separation_force +
alignment_force +
obstacle_force +
seek_force +
bound_force +
path_force +
wander_force +
predator_force +
predator_aware_cohesion +
interaction_force
)
# 添加环境力
if environment_force:
acceleration += environment_force * 0.1 # 环境力权重
# 应用加速度限制
max_acceleration = self.config.get("max_acceleration", 1.0)
if acceleration.length() > max_acceleration:
acceleration.normalize()
acceleration *= max_acceleration
# 应用物理模拟
dt = 1.0 / 60.0 # 时间步长
self.physics_simulator.update_physics(member, dt)
# 更新速度和位置
member['velocity'] += acceleration
# 限制最大速度
max_speed = self.config.get("max_speed", 5.0)
if member['velocity'].length() > max_speed:
member['velocity'].normalize()
member['velocity'] *= max_speed
member['position'] += member['velocity']
# 更新成员的朝向(使其面向移动方向)
if member['velocity'].length() > 0 and 'node' in member:
# 简化实现,实际应用中可能需要更复杂的朝向计算
member['node'].lookAt(member['position'] + member['velocity'])
def update_swarm_parallel(self, members: List[Dict], all_members: List[Dict],
obstacles: List[Dict] = [], target: Vec3 = None,
bounds: Dict[str, float] = None, path: List[Vec3] = None,
predators: List[Dict] = [], interaction_forces: Dict = None,
environment_forces: Dict = None):
"""
并行更新群体成员状态
使用多线程或多进程来并行计算每个成员的更新
"""
if not self.config.get("spatial_partitioning", False):
# 如果没有使用空间分区,则无法高效并行计算邻居
# 回退到串行计算
for member in members:
neighbors = self._find_neighbors_serial(member, all_members)
member_interactions = interaction_forces.get(id(member), {}) if interaction_forces else {}
member_environment_force = environment_forces.get(id(member), Vec3(0, 0, 0)) if environment_forces else Vec3(0, 0, 0)
self.update_member(member, neighbors, obstacles, target, bounds, path, predators, member_interactions, member_environment_force)
return
# 如果启用了空间分区,则可以使用并行计算
use_threading = True # 对于I/O密集型任务使用线程更好
# 如果需要处理大量成员,可以使用多线程
if use_threading:
# 使用ThreadPoolExecutor进行并行计算
num_threads = min(len(members), cpu_count())
# 创建线程池
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
# 提交任务
futures = []
for member in members:
# 获取邻居(这一步可能也需要优化)
neighbors = self.spatial_partition.get_neighbors(member['position'],
self.config.get("perception_radius", 10.0))
member_interactions = interaction_forces.get(id(member), {}) if interaction_forces else {}
member_environment_force = environment_forces.get(id(member), Vec3(0, 0, 0)) if environment_forces else Vec3(0, 0, 0)
future = executor.submit(self._update_member_task, member, neighbors, obstacles,
target, bounds, path, predators, member_interactions, member_environment_force)
futures.append(future)
# 等待所有任务完成
for future in concurrent.futures.as_completed(futures):
# 获取结果(如果需要的话)
result = future.result()
else:
# 使用多进程进行并行计算
num_processes = min(len(members), cpu_count())
# 准备参数
tasks = []
for member in members:
neighbors = self.spatial_partition.get_neighbors(member['position'],
self.config.get("perception_radius", 10.0))
member_interactions = interaction_forces.get(id(member), {}) if interaction_forces else {}
member_environment_force = environment_forces.get(id(member), Vec3(0, 0, 0)) if environment_forces else Vec3(0, 0, 0)
task_args = (member, neighbors, obstacles, target, bounds, path, predators, member_interactions, member_environment_force)
tasks.append(task_args)
# 使用多进程池
with Pool(processes=num_processes) as pool:
pool.starmap(self._update_member_task, tasks)
def _update_member_task(self, member: Dict, neighbors: List[Dict], obstacles: List[Dict],
target: Vec3, bounds: Dict[str, float], path: List[Vec3],
predators: List[Dict], interaction_forces: Dict = None,
environment_force: Vec3 = None):
"""
用于并行任务的成员更新方法
"""
self.update_member(member, neighbors, obstacles, target, bounds, path, predators, interaction_forces, environment_force)
def _find_neighbors_serial(self, member: Dict, all_members: List[Dict]) -> List[Dict]:
"""
串行查找邻居(当不使用空间分区时)
"""
neighbors = []
perception_radius = self.config.get("perception_radius", 10.0)
for other in all_members:
if other != member: # 排除自己
distance = (member['position'] - other['position']).length()
if distance < perception_radius:
neighbors.append(other)
return neighbors
def get_stats(self) -> Dict[str, int]:
"""
获取性能统计信息
"""
return self.stats.copy()
def clear_stats(self):
"""
清空性能统计信息
"""
self.stats = {
'neighbor_lookups': 0,
'total_neighbors_found': 0,
'algorithm_calls': 0
}