EG/ssbo_component/ssbo_controller.py
2026-02-25 11:49:31 +08:00

551 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from panda3d.core import (
GeomVertexFormat, GeomVertexWriter, GeomVertexReader, GeomVertexRewriter,
InternalName, Vec3, Vec4, LMatrix4f, ShaderBuffer, GeomEnums,
BoundingSphere, NodePath, GeomNode, Texture, SamplerState,
Point3, BoundingBox, Quat
)
import struct
import time
class ObjectController:
"""
物体控制器 (No Custom Shader Mode)
====================================
Uses RP's default rendering (no rp.set_effect) for maximum FPS.
Vertex colors baked for picking. Movement modifies vertex data directly.
Stores original vertex positions per object for rotation/translation.
"""
def __init__(self):
    """Create an empty controller; ``bake_ids_and_collect`` fills every table."""
    # Object naming / lookup.
    self.name_to_ids = {}          # unique key (node path string) -> [global_id, ...]
    self.id_to_name = {}           # global_id -> unique key
    self.key_to_node = {}          # unique key -> NodePath
    self.node_list = []            # unique keys (sorted after baking)
    self.display_names = {}        # unique key -> human-readable name
    # Per-object transforms and chunk bookkeeping.
    self.global_transforms = []    # original transform per global_id (center/position source)
    self.id_to_chunk = {}          # global_id -> (chunk_key, local_idx)
    self.chunks = {}               # chunk_key -> {'node': ..., 'base_id': ...}
    # Vertex data, keyed by local index.
    self.vertex_index = {}         # local_idx -> [(GeomNode NodePath, geom index, row indices)]
    self.original_positions = {}   # local_idx -> [positions, one per vertex_index entry]
    self.position_offsets = {}     # local_idx -> accumulated Vec3 translation
    self.local_to_global_id = {}   # local_idx -> global_id
    # Gizmo transform state, keyed by local index.
    self.local_transform_state = {}            # local_idx -> {'offset', 'quat', 'scale', 'pivot'}
    self.local_transform_base_positions = {}   # local_idx -> base positions used by transforms
    # Cached read-only hierarchy view (built by build_virtual_hierarchy).
    self.virtual_tree = None
    self.virtual_tree_meta = None
    # Scene handles.
    self.model = None
    self.chunk_node = None         # the single chunk node (the baked model itself)
def bake_ids_and_collect(self, model):
    """
    Bake per-object IDs into vertex colors, flatten, then build the vertex index.

    Every GeomNode under ``model`` is reparented directly under ``model``
    (keeping its net transform) and assigned a sequential local index. That
    index is written into the vertex colors of all of its Geoms as
    ``(idx % 256) / 255`` in red and ``(idx // 256) / 255`` in green, so only
    indices 0..65535 fit in 8-bit channels. ``model`` is then flattened with
    ``flatten_strong`` and the vertex index, per-object transform state and
    virtual hierarchy are rebuilt from the flattened geometry.

    Transforms are NOT reset, so vertices keep their positions relative to
    ``model``. No SSBO or custom shader is involved.

    Args:
        model: NodePath to process. It is modified in place (reparenting,
            vertex colors, flattening).

    Returns:
        int: number of objects (global IDs) assigned.
    """
    t0 = time.time()
    geom_nodes = list(model.find_all_matches("**/+GeomNode"))
    print(f"[控制器] 找到 {len(geom_nodes)} 个 GeomNode")

    # Start from a clean slate so a re-bake never mixes old and new IDs.
    self.name_to_ids = {}
    self.id_to_name = {}
    self.key_to_node = {}
    self.node_list = []
    self.display_names = {}
    self.global_transforms = []
    self.id_to_chunk = {}
    self.chunks = {}
    self.vertex_index = {}
    self.original_positions = {}
    self.position_offsets = {}
    self.local_to_global_id = {}
    self.local_transform_state = {}
    self.local_transform_base_positions = {}
    self.virtual_tree = None
    self.virtual_tree_meta = None

    global_id_counter = 0
    chunk_key = model.get_name() or "default"
    # No chunk wrapper: the model itself is the single chunk and is flattened in place.
    self.chunk_node = model
    self.chunks[chunk_key] = {'node': model, 'base_id': 0}

    # Make every GeomNode a direct child of the model, keeping its net transform.
    for node_path in geom_nodes:
        node_path.wrt_reparent_to(model)

    color_name = InternalName.make("color")
    # One byte of the ID goes in red and one in green.
    id_capacity = 256 * 256
    local_idx = 0
    for node_path in geom_nodes:
        gnode = node_path.node()
        # An instanced GeomNode (several parents) is shared by all its paths:
        # copy it under this parent and detach the shared one so the ID
        # written below only applies to this path.
        if gnode.get_num_parents() > 1:
            parent = node_path.get_parent()
            if not parent.is_empty():
                copy_np = node_path.copy_to(parent)
                node_path.detach_node()
                node_path = copy_np
                gnode = node_path.node()

        if local_idx == id_capacity:
            print(f"[控制器] WARNING: more than {id_capacity} objects; IDs >= "
                  f"{id_capacity} do not fit in 8-bit R/G vertex colors and "
                  f"may be decoded as other objects")

        unique_key = str(node_path)
        display_name = node_path.get_name() or f"Object_{global_id_counter}"
        if unique_key not in self.name_to_ids:
            self.name_to_ids[unique_key] = []
            self.key_to_node[unique_key] = node_path
            self.node_list.append(unique_key)
            self.display_names[unique_key] = display_name

        # Original transform relative to the model (source of centers/pivots).
        original_transform = LMatrix4f(node_path.get_mat())

        # Encode the local ID: low byte in red, high byte in green.
        red = (local_idx % 256) / 255.0
        green = (local_idx // 256) / 255.0
        for i in range(gnode.get_num_geoms()):
            geom = gnode.modify_geom(i)
            vdata = geom.modify_vertex_data()
            if not vdata.has_column("color"):
                new_format = vdata.get_format().get_union_format(GeomVertexFormat.get_v3c4())
                vdata.set_format(new_format)
            writer = GeomVertexWriter(vdata, color_name)
            for row in range(vdata.get_num_rows()):
                writer.set_row(row)
                writer.set_data4f(red, green, 0.0, 1.0)

        self.global_transforms.append(original_transform)
        self.id_to_chunk[global_id_counter] = (chunk_key, local_idx)
        self.name_to_ids[unique_key].append(global_id_counter)
        self.id_to_name[global_id_counter] = unique_key
        self.local_to_global_id[local_idx] = global_id_counter
        self.position_offsets[local_idx] = Vec3(0, 0, 0)
        global_id_counter += 1
        local_idx += 1

    # Transforms are intentionally NOT reset: vertices keep their positions.
    # Flatten directly on the model without set_final, which keeps per-geom
    # frustum culling possible.
    model.flatten_strong()
    t1 = time.time()
    print(f"[控制器] Flatten took {(t1-t0)*1000:.0f}ms")

    # The vertex index must be built AFTER flattening.
    self._build_vertex_index(model)
    self._init_local_transform_state()
    # Sort before building the tree so its children follow the sorted key order.
    self.node_list.sort()
    self.build_virtual_hierarchy()
    t2 = time.time()
    print(f"[控制器] Vertex index built in {(t2-t1)*1000:.0f}ms, "
          f"{len(self.vertex_index)} unique IDs indexed")
    self.model = model
    return global_id_counter
def build_virtual_hierarchy(self):
    """
    Build and cache a read-only tree view of ``self.node_list``.

    Each key is split on "/" (empty segments are dropped) and inserted as a
    chain of nested nodes. Every tree node is a dict with:
        name         -- this path segment
        path         -- segments from the root joined with "/"
        children     -- dict mapping segment -> child node
        leaf_key     -- the node_list key that ends here, or None
        display_name -- display_names[key] for leaves, otherwise the segment
    The root has empty name, path and display_name. The tree is stored in
    ``self.virtual_tree`` and ``{"max_depth", "leaf_count"}`` in
    ``self.virtual_tree_meta``.

    Returns:
        dict: the root node.
    """
    def make_node(segment, full_path):
        return {
            "name": segment,
            "path": full_path,
            "children": {},
            "leaf_key": None,
            "display_name": segment,
        }

    tree_root = make_node("", "")
    deepest = 0
    leaves = 0
    for node_key in self.node_list:
        if not node_key:
            continue
        segments = [seg for seg in str(node_key).split("/") if seg]
        if not segments:
            continue
        deepest = max(deepest, len(segments))
        node = tree_root
        for depth, segment in enumerate(segments, start=1):
            children = node["children"]
            if segment not in children:
                children[segment] = make_node(segment, "/".join(segments[:depth]))
            node = children[segment]
        # `node` is now the deepest segment: mark it as this key's leaf.
        node["leaf_key"] = node_key
        node["display_name"] = self.display_names.get(node_key, segments[-1])
        leaves += 1

    self.virtual_tree = tree_root
    self.virtual_tree_meta = {"max_depth": deepest, "leaf_count": leaves}
    return tree_root
def get_virtual_hierarchy(self):
    """Return the cached virtual tree, building it first when nothing is cached."""
    cached = self.virtual_tree
    if cached is not None:
        return cached
    return self.build_virtual_hierarchy()
def _build_vertex_index(self, chunk_root):
    """
    Group the flattened vertices under ``chunk_root`` by their baked object ID.

    Intended to run after ``flatten_strong``. For every Geom, the position and
    color columns are read in bulk with numpy. The ID is decoded as
    ``red_byte + (green_byte << 8)`` and rows are grouped per ID. Results are
    appended to:

    * ``self.vertex_index[id]`` -> list of
      ``(GeomNode NodePath, geom index, numpy array of row indices)``
    * ``self.original_positions[id]`` -> list of ``(N, 3)`` float32 position
      arrays, one per ``vertex_index`` entry and in the same row order.

    Supported storage (Geoms with other layouts are skipped):

    * position: float32 or float64 with at least 3 components;
    * color: float32 or uint8 with at least 2 components, or a packed
      32-bit color (DABC or DCBA order).
    """
    import numpy as np

    vertex_name = InternalName.get_vertex()
    color_name = InternalName.make("color")

    def read_rows(vdata, fmt, array_idx, num_rows):
        # One vertex array as a (num_rows, stride) uint8 matrix.
        stride = fmt.get_array(array_idx).get_stride()
        raw = bytes(vdata.get_array(array_idx).get_handle().get_data())
        return np.frombuffer(raw, dtype=np.uint8).reshape(num_rows, stride)

    def read_column(rows_u8, column, dtype, count):
        # First `count` components of `column` reinterpreted as `dtype`,
        # shape (num_rows, count). Copies, so the result is contiguous.
        start = column.get_start()
        width = np.dtype(dtype).itemsize * count
        return np.ascontiguousarray(rows_u8[:, start:start + width]).view(dtype)

    for gn_np in chunk_root.find_all_matches("**/+GeomNode"):
        gnode = gn_np.node()
        for gi in range(gnode.get_num_geoms()):
            vdata = gnode.get_geom(gi).get_vertex_data()
            num_rows = vdata.get_num_rows()
            if num_rows == 0:
                continue
            fmt = vdata.get_format()
            pos_col = fmt.get_column(vertex_name)
            color_col = fmt.get_column(color_name)
            if pos_col is None or color_col is None:
                continue

            # --- positions -> (num_rows, 3) float32 ---
            if pos_col.get_num_components() < 3:
                continue
            pos_type = pos_col.get_numeric_type()
            pos_array_idx = fmt.get_array_with(vertex_name)
            pos_rows = read_rows(vdata, fmt, pos_array_idx, num_rows)
            if pos_type == GeomEnums.NT_float32:
                positions = read_column(pos_rows, pos_col, np.float32, 3)
            elif pos_type == GeomEnums.NT_float64:
                positions = read_column(pos_rows, pos_col, np.float64, 3).astype(np.float32)
            else:
                continue

            # --- colors -> red/green bytes ---
            color_array_idx = fmt.get_array_with(color_name)
            if color_array_idx == pos_array_idx:
                color_rows = pos_rows
            else:
                color_rows = read_rows(vdata, fmt, color_array_idx, num_rows)
            color_type = color_col.get_numeric_type()
            color_comps = color_col.get_num_components()
            if color_type == GeomEnums.NT_float32 and color_comps >= 2:
                rg = read_column(color_rows, color_col, np.float32, 2)
                r_vals = (rg[:, 0] * 255.0 + 0.5).astype(np.int32)
                g_vals = (rg[:, 1] * 255.0 + 0.5).astype(np.int32)
            elif color_type == GeomEnums.NT_uint8 and color_comps >= 2:
                rg = read_column(color_rows, color_col, np.uint8, 2).astype(np.int32)
                r_vals = rg[:, 0]
                g_vals = rg[:, 1]
            elif color_type == GeomEnums.NT_packed_dabc:
                # One native-endian uint32 per row: A<<24 | R<<16 | G<<8 | B.
                packed = read_column(color_rows, color_col, np.uint32, 1)[:, 0]
                r_vals = ((packed >> 16) & 0xFF).astype(np.int32)
                g_vals = ((packed >> 8) & 0xFF).astype(np.int32)
            elif color_type == GeomEnums.NT_packed_dcba:
                # One native-endian uint32 per row: A<<24 | B<<16 | G<<8 | R.
                packed = read_column(color_rows, color_col, np.uint32, 1)[:, 0]
                r_vals = (packed & 0xFF).astype(np.int32)
                g_vals = ((packed >> 8) & 0xFF).astype(np.int32)
            else:
                continue
            local_ids = r_vals + (g_vals << 8)

            # Group rows by ID with one sort (O(N log N)) instead of one pass per ID.
            order = np.argsort(local_ids, kind="stable")
            sorted_ids = local_ids[order]
            sorted_positions = positions[order]
            boundaries = np.flatnonzero(np.diff(sorted_ids)) + 1
            starts = np.concatenate(([0], boundaries))
            row_groups = np.split(order, boundaries)
            pos_groups = np.split(sorted_positions, boundaries)
            for start, rows, pos in zip(starts, row_groups, pos_groups):
                uid = int(sorted_ids[start])
                self.vertex_index.setdefault(uid, []).append((gn_np, gi, rows))
                self.original_positions.setdefault(uid, []).append(pos.copy())
def _init_local_transform_state(self):
    """
    Reset the per-object transform state for every indexed local ID.

    Must run after the vertex index is built. Each key of
    ``self.vertex_index`` gets identity state: zero offset, identity
    rotation, unit scale, and a pivot from ``get_local_pivot``. Its base
    positions are taken from ``self.original_positions`` (empty list if
    missing).
    """
    self.local_transform_state = state_by_idx = {}
    self.local_transform_base_positions = bases_by_idx = {}
    for idx in self.vertex_index.keys():
        bases_by_idx[idx] = self.original_positions.get(idx, [])
        state_by_idx[idx] = {
            "offset": Vec3(0, 0, 0),
            "quat": Quat.identQuat(),
            "scale": Vec3(1, 1, 1),
            "pivot": self.get_local_pivot(idx),
        }
def get_local_indices_from_global_ids(self, global_ids):
    """
    Translate global IDs into de-duplicated local indices.

    IDs missing from ``id_to_chunk``, and local indices without a
    ``vertex_index`` entry, are dropped. The result keeps first-seen order.

    Returns:
        list: local indices (empty for empty/None input).
    """
    picked = []
    if not global_ids:
        return picked
    already = set()
    for gid in global_ids:
        entry = self.id_to_chunk.get(gid)
        if not entry:
            continue
        _chunk_key, local_idx = entry
        if local_idx in already or local_idx not in self.vertex_index:
            continue
        already.add(local_idx)
        picked.append(local_idx)
    return picked
def get_local_pivot(self, local_idx):
    """
    Return the pivot of one local object: its center from ``get_object_center``.

    Falls back to the origin when ``local_idx`` has no global ID.
    """
    global_id = self.local_to_global_id.get(local_idx)
    if global_id is not None:
        return self.get_object_center(global_id)
    return Vec3(0, 0, 0)
def get_selection_center(self, local_indices):
    """
    Return the average of ``pivot + offset`` over the selected objects.

    Objects without transform state are ignored. The origin is returned for
    an empty selection or when no selected object has state.
    """
    total = Vec3(0, 0, 0)
    count = 0
    for idx in (local_indices or ()):
        state = self.local_transform_state.get(idx)
        if not state:
            continue
        total += state.get("pivot", Vec3(0, 0, 0)) + state.get("offset", Vec3(0, 0, 0))
        count += 1
    if count == 0:
        return Vec3(0, 0, 0)
    return total / float(count)
def begin_transform_session(self, local_indices):
    """
    Snapshot the current transform state of the given objects for one gizmo drag.

    Offset, rotation, scale and pivot are copied so later changes to
    ``local_transform_state`` do not alter the snapshot. The vertex-index
    entries and base positions are stored by reference. Objects without
    transform state are skipped.

    Returns:
        dict: ``{"locals": {local_idx: {...}}}``.
    """
    snapshot_locals = {}
    for idx in (local_indices or ()):
        state = self.local_transform_state.get(idx)
        if not state:
            continue
        snapshot_locals[idx] = {
            "offset": Vec3(state["offset"]),
            "quat": Quat(state["quat"]),
            "scale": Vec3(state["scale"]),
            "pivot": Vec3(state["pivot"]),
            "entries": self.vertex_index.get(idx, []),
            "base_positions": self.local_transform_base_positions.get(idx, []),
        }
    return {"locals": snapshot_locals}
def apply_transform_session(self, snapshot, delta_pos, delta_quat, delta_scale):
    """
    Apply a transform delta to every object in ``snapshot`` and rewrite its vertices.

    The delta is always applied to the snapshot's baseline, not to the
    current state. Repeated calls within one session therefore replace each
    other rather than compound: pass the total delta since
    ``begin_transform_session``.

    For each object the new state is:
        offset = base offset + delta_pos
        quat   = delta_quat * base quat
        scale  = base scale * delta_scale (component-wise)
    Each vertex is rewritten as
        R(quat) @ ((base_pos - pivot) * scale) + pivot + offset
    NOTE(review): the composition order ``delta_quat * base_quat`` depends on
    Panda3D's Quat multiplication convention; confirm it matches the space
    in which the gizmo reports its delta.

    Args:
        snapshot: dict returned by ``begin_transform_session``. Falsy or
            missing "locals" -> no-op.
        delta_pos: Vec3 translation, or None for zero.
        delta_quat: Quat rotation, or None for identity.
        delta_scale: Vec3 per-axis scale factor, or None for (1, 1, 1).

    Side effects:
        Updates ``self.local_transform_state[idx]`` offset/quat/scale and
        ``self.position_offsets[idx]``; the same Vec3 object is stored in
        both. Vertices are written one row at a time via GeomVertexWriter.
    """
    import numpy as np
    if not snapshot or "locals" not in snapshot:
        return
    if delta_pos is None:
        delta_pos = Vec3(0, 0, 0)
    if delta_quat is None:
        delta_quat = Quat.identQuat()
    if delta_scale is None:
        delta_scale = Vec3(1, 1, 1)
    # numpy copies of the deltas, hoisted out of the per-object loop.
    dscale = np.array([delta_scale.x, delta_scale.y, delta_scale.z], dtype=np.float32)
    dpos = np.array([delta_pos.x, delta_pos.y, delta_pos.z], dtype=np.float32)
    for local_idx, local_data in snapshot["locals"].items():
        base_offset = local_data["offset"]
        base_quat = local_data["quat"]
        base_scale = local_data["scale"]
        pivot = local_data["pivot"]
        # Final state as Panda3D types (stored in local_transform_state).
        final_offset = Vec3(base_offset) + delta_pos
        final_quat = Quat(delta_quat * base_quat)
        final_scale = Vec3(
            base_scale.x * delta_scale.x,
            base_scale.y * delta_scale.y,
            base_scale.z * delta_scale.z,
        )
        rot_mat = self._quat_to_np_mat3(final_quat)
        self.local_transform_state[local_idx]["offset"] = final_offset
        self.local_transform_state[local_idx]["quat"] = final_quat
        self.local_transform_state[local_idx]["scale"] = final_scale
        self.position_offsets[local_idx] = final_offset
        # Same final scale/offset as float32 numpy arrays for the vertex math.
        pivot_np = np.array([pivot.x, pivot.y, pivot.z], dtype=np.float32)
        base_s = np.array([base_scale.x, base_scale.y, base_scale.z], dtype=np.float32)
        total_scale = base_s * dscale
        total_offset = np.array([base_offset.x, base_offset.y, base_offset.z], dtype=np.float32) + dpos
        entries = local_data["entries"]
        base_positions = local_data["base_positions"]
        # base_positions[i] holds the untransformed positions for entries[i].
        for i, (gn_np, gi, rows) in enumerate(entries):
            if i >= len(base_positions):
                continue
            orig_pos = base_positions[i]
            if orig_pos is None or len(orig_pos) == 0:
                continue
            # Scale then rotate about the pivot, then translate.
            centered = orig_pos - pivot_np
            scaled = centered * total_scale
            rotated = scaled @ rot_mat.T
            new_pos = rotated + pivot_np + total_offset
            gnode = gn_np.node()
            geom = gnode.modify_geom(gi)
            vdata = geom.modify_vertex_data()
            writer = GeomVertexWriter(vdata, "vertex")
            for j in range(len(rows)):
                writer.set_row(int(rows[j]))
                writer.set_data3f(float(new_pos[j, 0]), float(new_pos[j, 1]), float(new_pos[j, 2]))
def _quat_to_np_mat3(self, quat):
    """
    Return the 3x3 float32 rotation matrix of ``quat`` as a numpy array.

    The argument is copied and the copy normalized, so ``quat`` itself is not
    modified. Components are read as w=R, x=I, y=J, z=K and the matrix uses
    the standard unit-quaternion formula.
    """
    import numpy as np
    unit = Quat(quat)
    unit.normalize()
    w, x, y, z = (float(c) for c in (unit.getR(), unit.getI(), unit.getJ(), unit.getK()))
    return np.array(
        [
            [1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - w * z), 2.0 * (x * z + w * y)],
            [2.0 * (x * y + w * z), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - w * x)],
            [2.0 * (x * z - w * y), 2.0 * (y * z + w * x), 1.0 - 2.0 * (x * x + y * y)],
        ],
        dtype=np.float32,
    )
def create_ssbo(self):
    """
    Compatibility stub: rendering uses RP's default pipeline, so there is no SSBO.

    Returns:
        None, always.
    """
    ssbo = None
    return ssbo
def move_object(self, global_id, delta):
    """
    Translate one object by ``delta`` by rewriting its vertex positions.

    Moves accumulate. When the object has transform state in
    ``self.local_transform_state`` (``_init_local_transform_state`` creates
    it for every indexed object), the move goes through
    ``begin_transform_session`` / ``apply_transform_session`` as a pure
    translation. This keeps any rotation/scale from earlier transform
    sessions and keeps the stored offset (and ``position_offsets``) in sync
    for the next drag. Without transform state, vertices are rewritten as
    ``original position + accumulated offset``.

    Args:
        global_id: Global object ID; unknown IDs are ignored.
        delta: Vec3 translation to add.
    """
    import numpy as np
    if global_id not in self.id_to_chunk:
        return
    _, local_idx = self.id_to_chunk[global_id]
    if local_idx not in self.vertex_index:
        return

    if self.local_transform_state.get(local_idx):
        # Rebuilding from the untransformed originals would discard rotation/
        # scale and leave local_transform_state["offset"] stale, making the
        # object jump on the next gizmo drag. The session path avoids both and
        # also updates position_offsets.
        session = self.begin_transform_session([local_idx])
        self.apply_transform_session(session, Vec3(delta), None, None)
        return

    # Fallback (no transform state): translation relative to the originals.
    offset = self.position_offsets.get(local_idx, Vec3(0)) + delta
    self.position_offsets[local_idx] = offset
    offset_arr = np.array([offset.x, offset.y, offset.z], dtype=np.float32)
    entries = self.vertex_index[local_idx]
    originals = self.original_positions[local_idx]
    for i, (gn_np, gi, rows) in enumerate(entries):
        new_pos = originals[i] + offset_arr  # (N, 3), vectorized add
        vdata = gn_np.node().modify_geom(gi).modify_vertex_data()
        writer = GeomVertexWriter(vdata, "vertex")
        for j in range(len(rows)):
            writer.set_row(int(rows[j]))
            writer.set_data3f(float(new_pos[j, 0]), float(new_pos[j, 1]), float(new_pos[j, 2]))
def get_world_pos(self, global_id):
    """
    Return an object's current position.

    The position is the translation row of its original transform plus the
    accumulated offset from ``position_offsets``. Rotation and scale are not
    reflected. Unknown IDs (not in ``id_to_chunk``) give the origin.
    """
    if global_id in self.id_to_chunk:
        _chunk_key, local_idx = self.id_to_chunk[global_id]
        translation = self.global_transforms[global_id].get_row3(3)
        accumulated = self.position_offsets.get(local_idx, Vec3(0))
        return Vec3(translation) + accumulated
    return Vec3(0, 0, 0)
def get_object_center(self, global_id):
    """
    Return an object's original center (the translation row of its original transform).

    Used as the rotation pivot. IDs outside ``[0, len(global_transforms))``
    give the origin. This includes negative IDs (e.g. a sentinel value),
    which would otherwise index from the end of the list and silently return
    the last object's center.
    """
    if not 0 <= global_id < len(self.global_transforms):
        return Vec3(0, 0, 0)
    return Vec3(self.global_transforms[global_id].get_row3(3))
def get_transform(self, global_id):
    """
    Return the original transform recorded for ``global_id``.

    IDs outside ``[0, len(global_transforms))`` give the identity matrix.
    This includes negative IDs, which would otherwise index from the end of
    the list.
    """
    if not 0 <= global_id < len(self.global_transforms):
        return LMatrix4f.ident_mat()
    return self.global_transforms[global_id]
@property
def transforms(self):
    """Original per-object transforms indexed by global ID (the ``global_transforms`` list itself)."""
    original_list = self.global_transforms
    return original_list