WebGPU RT: dynamic TLAS sweep-tree depth (next_pow2 instances)

The LBVH bitonic sort still runs over the full 16384 (sentinels sink to
the tail), but the sweep tree is now built and traced at depth
log2(next_pow2(nReal)) instead of a fixed 14. Add nPadded to LbvhPC; leaf
init + bottom-up refit use it; the host passes the same next_pow2 to the
trace via WfParams.tlasNPadded. Renders correctly at 512 instances
(depth 9). The fragile sort phases are untouched.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
catbot 2026-05-31 20:28:12 +00:00
commit afc0292fab

View file

@ -2281,7 +2281,11 @@ struct BvhNode {
// runtime resize-on-grow caused subtle BVH corruption (driver-level // runtime resize-on-grow caused subtle BVH corruption (driver-level
// memory recycling, suspected) and was the root cause of mid-game // memory recycling, suspected) and was the root cause of mid-game
// geometry flicker when projectiles entered the TLAS. // geometry flicker when projectiles entered the TLAS.
struct LbvhPC { nReal: u32, _pad0: u32, _pad1: u32, _pad2: u32 }; // nPadded = next_pow2(max(nReal,1)), supplied by the host. The bitonic
// sort still runs over the full N_PADDED (sentinels sink to the tail), but
// the sweep tree is built (and traced) at depth log2(nPadded) so descent
// tracks the real instance count instead of a fixed 14.
struct LbvhPC { nReal: u32, nPadded: u32, _pad1: u32, _pad2: u32 };
@group(0) @binding(5) var<uniform> lbvhPc : LbvhPC; @group(0) @binding(5) var<uniform> lbvhPc : LbvhPC;
const N_PADDED: u32 = 16384u; const N_PADDED: u32 = 16384u;
@ -2436,9 +2440,14 @@ fn lbvhBuildMain(@builtin(local_invocation_id) lid: vec3<u32>) {
storageBarrier(); storageBarrier();
// ── Phase 4: initialize BVH leaf AABBs ─────────────────────────────── // ── Phase 4: initialize BVH leaf AABBs ───────────────────────────────
for (var k: u32 = 0u; k < K_PER; k = k + 1u) { // Only the first nPadded sorted slots become leaves of the (smaller)
// sweep tree; reals occupy [0,nReal), the rest sink as sentinels.
let nPadded = max(lbvhPc.nPadded, 1u);
let leafPerThread = (nPadded + THREADS - 1u) / THREADS;
for (var k: u32 = 0u; k < leafPerThread; k = k + 1u) {
let i = k * THREADS + tid; let i = k * THREADS + tid;
let leafIdx = N_PADDED - 1u + i; if (i < nPadded) {
let leafIdx = nPadded - 1u + i;
let leafKey = sortA[i]; let leafKey = sortA[i];
if (leafKey == 0xFFFFFFFFu) { if (leafKey == 0xFFFFFFFFu) {
outBvh[leafIdx].aabbMin = vec3<f32>( 1e30); outBvh[leafIdx].aabbMin = vec3<f32>( 1e30);
@ -2449,16 +2458,18 @@ fn lbvhBuildMain(@builtin(local_invocation_id) lid: vec3<u32>) {
outBvh[leafIdx].aabbMax = e.aabbMax; outBvh[leafIdx].aabbMax = e.aabbMax;
} }
} }
}
workgroupBarrier(); workgroupBarrier();
storageBarrier(); storageBarrier();
// ── Phase 5: bottom-up sweep-tree refit, LEVELS iterations ────────── // ── Phase 5: bottom-up sweep-tree refit, log2(nPadded) levels ───────
// Deepest internal level has N_PADDED/2 nodes; perThread = ceil of // Deepest internal level has nPadded/2 nodes. The loop bound is uniform
// levelCount / THREADS is uniform per step, so workgroupBarrier // across the workgroup (depends only on nPadded), so the barriers stay
// stays in uniform control flow. // in uniform control flow.
var levelCount: u32 = N_PADDED / 2u; var levelCount: u32 = nPadded / 2u;
var levelStart: u32 = N_PADDED / 2u - 1u; var levelStart: u32 = nPadded / 2u - 1u;
for (var step: u32 = 0u; step < LEVELS; step = step + 1u) { loop {
if (levelCount == 0u) { break; }
let perThread = (levelCount + THREADS - 1u) / THREADS; let perThread = (levelCount + THREADS - 1u) / THREADS;
for (var k: u32 = 0u; k < perThread; k = k + 1u) { for (var k: u32 = 0u; k < perThread; k = k + 1u) {
let nodeOff = k * THREADS + tid; let nodeOff = k * THREADS + tid;
@ -2723,11 +2734,12 @@ env.wgpuBuildTLAS = (instanceBufHandle, instanceCount, tlasOutBufHandle,
{ binding: 4, resource: { buffer: morton } }, { binding: 4, resource: { buffer: morton } },
], ],
}); });
// Write the real instance count to the LBVH count uniform so the // Write the real instance count + the dynamic padded leaf count
// shader can iterate exactly the right number of entries even // (next_pow2) to the LBVH uniform. The sort still runs over the full
// though the storage buffers stay sized for N_PADDED. // N_PADDED, but the sweep tree is built at depth log2(nPadded).
const countBuf = new Uint32Array(4); const countBuf = new Uint32Array(4);
countBuf[0] = instanceCount; countBuf[0] = instanceCount;
countBuf[1] = wfNextPow2(instanceCount);
queue.writeBuffer(rtState.lbvhCountBuf, 0, countBuf); queue.writeBuffer(rtState.lbvhCountBuf, 0, countBuf);
const lbvhBg = device.createBindGroup({ const lbvhBg = device.createBindGroup({
@ -2798,7 +2810,15 @@ const WF_PAYLOAD_BYTES = 64;
// Dynamic-offset uniform ring: one WfParams slot per wavefront pass. 128 // Dynamic-offset uniform ring: one WfParams slot per wavefront pass. 128
// slots covers maxDepth up to ~42 (1 + 3·maxDepth + 1 passes). // slots covers maxDepth up to ~42 (1 + 3·maxDepth + 1 passes).
const WF_PARAM_SLOTS = 128; const WF_PARAM_SLOTS = 128;
const WF_FIXED_TLAS_NPADDED = 16384; // matches lbvhBuildWgsl N_PADDED const WF_TLAS_MAX_NPADDED = 16384; // LBVH sort capacity (N_PADDED)
// Smallest power of two >= max(n,1), clamped to the LBVH capacity. The
// TLAS sweep tree is built and traced at this depth so descent tracks the
// real instance count instead of a fixed 16384-leaf (depth-14) tree.
function wfNextPow2(n) {
let p = 1;
while (p < n && p < WF_TLAS_MAX_NPADDED) p <<= 1;
return p;
}
function ensureWavefrontBuffers(W, H) { function ensureWavefrontBuffers(W, H) {
const cap = W * H; const cap = W * H;
@ -3041,11 +3061,14 @@ env.wgpuDispatchRT = (pipelineHandle, pushPtr, pushBytes,
// 1+3*d .. +2 PREP / TRACE / SHADE for bounce d // 1+3*d .. +2 PREP / TRACE / SHADE for bounce d
// 1+3*depth RESOLVE // 1+3*depth RESOLVE
const passCount = 2 + 3 * depth; const passCount = 2 + 3 * depth;
// TLAS descent depth = log2(tlasNPadded); must match the value the
// build used (both derive next_pow2 from the same instance count).
const tlasNPadded = wfNextPow2(instanceCount);
const ring = new Uint32Array(WF_PARAM_SLOTS * 64); // 256 B = 64 u32 per slot const ring = new Uint32Array(WF_PARAM_SLOTS * 64); // 256 B = 64 u32 per slot
const writeSlot = (slot, curIsA, bounce) => { const writeSlot = (slot, curIsA, bounce) => {
const o = slot * 64; const o = slot * 64;
ring[o + 0] = W; ring[o + 1] = H; ring[o + 2] = cap; ring[o + 3] = curIsA; ring[o + 0] = W; ring[o + 1] = H; ring[o + 2] = cap; ring[o + 3] = curIsA;
ring[o + 4] = bounce; ring[o + 5] = depth; ring[o + 6] = WF_FIXED_TLAS_NPADDED; ring[o + 7] = 0; ring[o + 4] = bounce; ring[o + 5] = depth; ring[o + 6] = tlasNPadded; ring[o + 7] = 0;
}; };
writeSlot(0, 1, 0); // GENERATE writeSlot(0, 1, 0); // GENERATE
for (let d = 0; d < depth; d++) { for (let d = 0; d < depth; d++) {