feat(webgpu-rt): any-hit + AABB (procedural) geometry support #14

Merged
catbot merged 4 commits from claude/issue-13 into master 2026-06-03 00:10:17 +02:00
Showing only changes of commit 1628e1a58c - Show all commits

feat(webgpu-rt): wire any-hit + AABB intersection into wavefront traversal

The TRACE-stage BLAS descent now threads the payload through, runs the
any-hit shader for non-opaque candidates (DXR/VK opacity resolution:
ray FORCE flags > instance FORCE flags > geometry opaque bit), and
handles AABB leaves via the intersection shader. MeshRecord grows to 64
bytes with geomType + opaque. When any-hit/intersection are present the
TRACE pipeline takes the user bind-group layout so those shaders can
sample @group(3+) resources; otherwise TRACE keeps its zero-user-code
path unchanged. rayQuery stays triangle-only (skips AABB leaves).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
catbot 2026-06-02 22:09:25 +00:00

View file

@ -1357,6 +1357,15 @@ struct BVHNode {
// per-vertex stride lives in the user's WGSL — the library doesn't store
// it because the layout is example-defined (Sponza uses 8 u32 / vertex
// for VertexNormalTangentUVPacked).
//
// geomType selects the BLAS primitive kind: 0 = triangles (the default;
// the vertices/indices streams are positions + a triangle index buffer),
// 1 = AABBs (VK_GEOMETRY_TYPE_AABBS_KHR — the vertices stream holds 2 vec3
// per primitive [min, max] and there are no indices; a registered
// intersection shader determines the hit). opaque is the geometry's opaque
// bit (1 = opaque, no any-hit; 0 = any-hit may run, subject to the ray /
// instance force flags). triangleCount doubles as the AABB primitive count
// for geomType == 1.
struct MeshRecord {
rootAabbMin: vec3<f32>,
vertexOffset: u32,
@ -1366,6 +1375,21 @@ struct MeshRecord {
primRemapOffset: u32,
triangleCount: u32,
attribsOffset: u32,
geomType: u32,
opaque: u32,
_padMr0: u32,
_padMr1: u32,
};
// Value returned by an intersection shader for one procedural (AABB)
// primitive; the TRACE-stage BLAS descent inspects it to decide whether
// a candidate hit exists.
//   hit     - false means "no intersection"; the caller skips the
//             primitive and the remaining fields are not consulted.
//   t       - ray parameter of the hit along the object-space ray; the
//             caller range-checks it against the ray's tMin and the
//             current best t before building a candidate.
//   attribs - copied into HitInfo.attribs, which is what closest-hit /
//             any-hit receive for the candidate.
//   hitKind - user-defined tag (e.g. front/back).
//             NOTE(review): the TRACE AABB path shown in this change
//             copies t and attribs into the candidate but never reads
//             hitKind — confirm whether it is consumed elsewhere.
struct IntersectionResult {
hit: bool,
t: f32,
attribs: vec2<f32>,
hitKind: u32,
};
// Per-instance TLAS record built by the TLAS-build compute pass.
@ -1469,6 +1493,28 @@ fn _rtFetchTri(meshRec: MeshRecord, triIndex: u32) -> array<vec3<f32>, 3> {
);
}
// Load the bounds of one procedural primitive (geomType == 1).
// Each primitive occupies two consecutive vec3 slots in the vertices
// heap — min corner first, then max corner. meshRec.vertexOffset is the
// mesh's base expressed in vec3 slots, so the slot index is scaled by 3
// to address the flat scalar array.
// Returns [min, max].
fn _rtFetchAabb(meshRec: MeshRecord, primIndex: u32) -> array<vec3<f32>, 2> {
    // Scalar index of the min corner's x component.
    let lo = (meshRec.vertexOffset + primIndex * 2u) * 3u;
    // The max corner follows immediately, one vec3 (3 scalars) later.
    let hi = lo + 3u;
    let boxMin = vec3<f32>(vertices[lo], vertices[lo + 1u], vertices[lo + 2u]);
    let boxMax = vec3<f32>(vertices[hi], vertices[hi + 1u], vertices[hi + 2u]);
    return array<vec3<f32>, 2>(boxMin, boxMax);
}
// Decide whether a candidate is treated as opaque (true = any-hit is
// skipped for it). Three levels can express an opinion, in increasing
// priority: the geometry's own opaque bit, the instance FORCE flags, and
// the ray FORCE flags. Within one level the OPAQUE flag beats the
// NO_OPAQUE flag if both happen to be set.
fn _rtResolveOpaque(rayFlags: u32, instFlags: u32, geomOpaque: bool) -> bool {
    // Start from the weakest source and let each stronger one overwrite
    // it; the last assignment that fires is the highest-priority opinion.
    var resolved = geomOpaque;
    if ((instFlags & RT_INSTANCE_FORCE_NO_OPAQUE) != 0u) {
        resolved = false;
    }
    if ((instFlags & RT_INSTANCE_FORCE_OPAQUE) != 0u) {
        resolved = true;
    }
    if ((rayFlags & RT_FLAG_NO_OPAQUE) != 0u) {
        resolved = false;
    }
    if ((rayFlags & RT_FLAG_OPAQUE) != 0u) {
        resolved = true;
    }
    return resolved;
}
fn _rtAabb(ro: vec3<f32>, invRd: vec3<f32>, mn: vec3<f32>, mx: vec3<f32>, tMax: f32) -> bool {
// Reject degenerate (mn > mx) boxes outright. The min(t0,t1)/
// max(t0,t1) trick below silently re-orients an inverted box
@ -1681,9 +1727,41 @@ fn rtEmitRay(origin: vec3<f32>, tMin: f32, dir: vec3<f32>, tMax: f32,
wfPayload[r.payloadSlot] = payload;
}
// Opaque-only BLAS descent (no anyhit — TRACE runs zero user code).
fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord,
instanceId: u32, hitGroupBase: u32,
// Decide the fate of a single candidate hit found during BLAS descent.
//
// The candidate's opacity is resolved from the ray flags, instance flags
// and geometry bit. If it is non-opaque and the pipeline has any-hit
// shaders (RT_HAS_ANYHIT), runAnyHit is invoked with the world-space ray,
// the candidate and the payload pointer, and its verdict is honoured;
// otherwise the candidate is accepted outright. With RT_HAS_ANYHIT false
// the call sits behind a constant-false condition, so the function
// amounts to a plain commit.
//
// Return codes:
//   0 - candidate ignored; bestHit / bestT are left untouched.
//   1 - candidate committed; traversal continues with bestT = cand.t.
//   2 - candidate committed and the search must stop (any-hit asked to
//       end the search, or RT_FLAG_TERMINATE_ON_FIRST_HIT is set).
fn _rtwCommitCandidate(rayWorld: RayDesc, flags: u32, instFlags: u32,
                       geomOpaque: bool, hitGroupBase: u32, cand: HitInfo,
                       payload: ptr<function, Payload>,
                       bestHit: ptr<function, HitInfo>,
                       bestT: ptr<function, f32>) -> u32 {
    let wantsAnyHit = !_rtResolveOpaque(flags, instFlags, geomOpaque);

    // Accept by default; only a non-opaque candidate can be overruled.
    var decision = RT_ANYHIT_ACCEPT;
    if (RT_HAS_ANYHIT && wantsAnyHit) {
        decision = runAnyHit(hitGroupBase, rayWorld, cand, payload);
    }

    if (decision == RT_ANYHIT_IGNORE) {
        return 0u;
    }

    // Commit: this candidate becomes the closest hit so far.
    *bestT = cand.t;
    *bestHit = cand;

    let stopSearch = (decision == RT_ANYHIT_END_SEARCH)
                  || ((flags & RT_FLAG_TERMINATE_ON_FIRST_HIT) != 0u);
    return select(1u, 2u, stopSearch);
}
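// BLAS descent. Handles both triangle and AABB (procedural) geometry and
// runs the any-hit shader for non-opaque candidates. payload is threaded
// in so any-hit / intersection can read & mutate it (_wfTrace writes it
// back). For opaque triangle-only scenes the user-callback branches
// const-fold out (RT_HAS_ANYHIT / RT_HAS_INTERSECTION both false), leaving
// the original traversal loop plus the RT_FLAG_SKIP_TRIANGLES test.
// NOTE(review): the AABB branch is gated on RT_HAS_INTERSECTION, so when
// that const is false a geomType == 1 leaf falls through to the triangle
// branch and its [min, max] stream is read by _rtFetchTri. _rqTraverseBlas
// skips such leaves explicitly; confirm whether this path should too.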
fn _rtwTraverseBlas(rayObj: RayDesc, rayWorld: RayDesc, flags: u32, instFlags: u32,
meshRec: MeshRecord, instanceId: u32, hitGroupBase: u32,
payload: ptr<function, Payload>,
bestHit: ptr<function, HitInfo>,
bestT: ptr<function, f32>) -> bool {
let invD = vec3<f32>(1.0) / rayObj.direction;
@ -1698,27 +1776,57 @@ fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord,
sp = sp - 1u; nodeRel = stack[sp]; continue;
}
if (node.primCount > 0u) {
for (var i: u32 = 0u; i < node.primCount; i = i + 1u) {
let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i];
let verts = _rtFetchTri(meshRec, triIndex);
let tr = _rtTri(rayObj.origin, rayObj.direction,
verts[0], verts[1], verts[2], rayObj.tMin, *bestT);
if (!tr.hit) { continue; }
let geomNormal = cross(verts[1] - verts[0], verts[2] - verts[0]);
let facing = dot(geomNormal, rayObj.direction);
if ((flags & RT_FLAG_CULL_BACK_FACING_TRIANGLES) != 0u && facing > 0.0) { continue; }
if ((flags & RT_FLAG_CULL_FRONT_FACING_TRIANGLES) != 0u && facing < 0.0) { continue; }
var candidate: HitInfo;
candidate.t = tr.t;
candidate.instanceId = instanceId;
candidate.primitiveId = triIndex;
candidate.hitGroupIndex = hitGroupBase;
candidate.attribs = vec2<f32>(tr.u, tr.v);
candidate.objectRayOrigin = rayObj.origin;
candidate.objectRayDirection = rayObj.direction;
*bestHit = candidate;
*bestT = tr.t;
if ((flags & RT_FLAG_TERMINATE_ON_FIRST_HIT) != 0u) { return true; }
if (RT_HAS_INTERSECTION && meshRec.geomType == 1u) {
// ── AABB / procedural geometry (VK_GEOMETRY_TYPE_AABBS) ──
if ((flags & RT_FLAG_SKIP_AABBS) == 0u) {
for (var i: u32 = 0u; i < node.primCount; i = i + 1u) {
let primId = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i];
let box = _rtFetchAabb(meshRec, primId);
if (!_rtAabb(rayObj.origin, invD, box[0], box[1], *bestT)) { continue; }
var iray: RayDesc;
iray.origin = rayObj.origin; iray.tMin = rayObj.tMin;
iray.direction = rayObj.direction; iray.tMax = *bestT;
let ir = runIntersection(hitGroupBase, iray, box[0], box[1], primId);
if (!ir.hit) { continue; }
if (ir.t < rayObj.tMin || ir.t > *bestT) { continue; }
var cand: HitInfo;
cand.t = ir.t;
cand.instanceId = instanceId;
cand.primitiveId = primId;
cand.hitGroupIndex = hitGroupBase;
cand.attribs = ir.attribs;
cand.objectRayOrigin = rayObj.origin;
cand.objectRayDirection = rayObj.direction;
let r = _rtwCommitCandidate(rayWorld, flags, instFlags,
meshRec.opaque != 0u, hitGroupBase,
cand, payload, bestHit, bestT);
if (r == 2u) { return true; }
}
}
} else if ((flags & RT_FLAG_SKIP_TRIANGLES) == 0u) {
for (var i: u32 = 0u; i < node.primCount; i = i + 1u) {
let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i];
let verts = _rtFetchTri(meshRec, triIndex);
let tr = _rtTri(rayObj.origin, rayObj.direction,
verts[0], verts[1], verts[2], rayObj.tMin, *bestT);
if (!tr.hit) { continue; }
let geomNormal = cross(verts[1] - verts[0], verts[2] - verts[0]);
let facing = dot(geomNormal, rayObj.direction);
if ((flags & RT_FLAG_CULL_BACK_FACING_TRIANGLES) != 0u && facing > 0.0) { continue; }
if ((flags & RT_FLAG_CULL_FRONT_FACING_TRIANGLES) != 0u && facing < 0.0) { continue; }
var candidate: HitInfo;
candidate.t = tr.t;
candidate.instanceId = instanceId;
candidate.primitiveId = triIndex;
candidate.hitGroupIndex = hitGroupBase;
candidate.attribs = vec2<f32>(tr.u, tr.v);
candidate.objectRayOrigin = rayObj.origin;
candidate.objectRayDirection = rayObj.direction;
let r = _rtwCommitCandidate(rayWorld, flags, instFlags,
meshRec.opaque != 0u, hitGroupBase,
candidate, payload, bestHit, bestT);
if (r == 2u) { return true; }
}
}
if (sp == 0u) { break; }
sp = sp - 1u; nodeRel = stack[sp]; continue;
@ -1749,6 +1857,7 @@ fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord,
fn _rtwTraverseTlas(rayWorld: RayDesc, flags: u32, cullMask: u32,
sbtRecordOffset: u32,
payload: ptr<function, Payload>,
bestHit: ptr<function, HitInfo>,
bestT: ptr<function, f32>) -> bool {
let invD = vec3<f32>(1.0) / rayWorld.direction;
@ -1793,7 +1902,7 @@ fn _rtwTraverseTlas(rayWorld: RayDesc, flags: u32, cullMask: u32,
let hitGroupBase = sbtRecordOffset + hitGroupOffset;
let meshRec = meshRecords[inst.blasMeshIdx];
let pre = *bestT;
let endSearch = _rtwTraverseBlas(rayObj, effective, meshRec, i, hitGroupBase, bestHit, bestT);
let endSearch = _rtwTraverseBlas(rayObj, rayWorld, effective, iflags, meshRec, i, hitGroupBase, payload, bestHit, bestT);
if ((*bestT) < pre || endSearch) {
(*bestHit).objectToWorldR0 = inst.objectToWorldR0;
(*bestHit).objectToWorldR1 = inst.objectToWorldR1;
@ -1861,7 +1970,19 @@ fn _wfTrace(i: u32) {
var bestHit: HitInfo;
bestHit.t = ray.tMax;
var bestT = ray.tMax;
_rtwTraverseTlas(rd, ray.flags, ray.cullMask & 0xFFu, ray.sbtRecordOffset, &bestHit, &bestT);
// Any-hit / intersection shaders run inside traversal and may read &
// mutate the payload, so load it here and write it back below. For an
// opaque triangle-only scene both consts are false and the payload
// touch const-folds away — TRACE keeps reading zero user state.
var payload: Payload;
if (RT_HAS_ANYHIT || RT_HAS_INTERSECTION) {
_wfPixel = ray.pixel;
payload = wfPayload[ray.payloadSlot];
}
_rtwTraverseTlas(rd, ray.flags, ray.cullMask & 0xFFu, ray.sbtRecordOffset, &payload, &bestHit, &bestT);
if (RT_HAS_ANYHIT || RT_HAS_INTERSECTION) {
wfPayload[ray.payloadSlot] = payload;
}
var hr: HitResult;
if (bestT < ray.tMax) {
hr.t = bestHit.t;
@ -1966,7 +2087,10 @@ fn _rqTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord,
sp = sp - 1u; nodeRel = stack[sp]; continue;
}
if (node.primCount > 0u) {
for (var i: u32 = 0u; i < node.primCount; i = i + 1u) {
// rayQuery is triangle-only; AABB (procedural) BLAS need an
// intersection shader the rayQuery path doesn't run, so skip
// their leaves rather than misread the AABB stream as triangles.
for (var i: u32 = 0u; i < node.primCount * select(0u, 1u, meshRec.geomType == 0u); i = i + 1u) {
let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i];
let verts = _rtFetchTri(meshRec, triIndex);
let tr = _rtTri(rayObj.origin, rayObj.direction,
@ -2572,7 +2696,7 @@ function rtInit() {
rtState.attribsHeap = makeRtHeap();
rtState.meshRecordsCapacity = 16;
rtState.meshRecordsBuffer = device.createBuffer({
size: rtState.meshRecordsCapacity * 48,
size: rtState.meshRecordsCapacity * 64,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
});
rtState.rtHeader = device.createBuffer({
@ -2635,12 +2759,12 @@ function rtMeshRecordsEnsure(meshCount) {
let cap = rtState.meshRecordsCapacity;
while (cap < meshCount) cap *= 2;
const ng = device.createBuffer({
size: cap * 48,
size: cap * 64,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
});
const enc = device.createCommandEncoder();
enc.copyBufferToBuffer(rtState.meshRecordsBuffer, 0, ng, 0,
rtState.meshRecordsCapacity * 48);
rtState.meshRecordsCapacity * 64);
queue.submit([enc.finish()]);
rtState.meshRecordsBuffer.destroy();
rtState.meshRecordsBuffer = ng;
@ -2653,9 +2777,11 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ,
indicesPtr, indexCount,
bvhNodesPtr, bvhNodeCount,
primRemapPtr, primRemapCount,
attribsPtr, attribsByteCount) => {
attribsPtr, attribsByteCount,
geomType, opaqueFlag, primCount) => {
if (!rtState.vertHeap) rtInit();
console.log(`[crafter-wgpu] mesh BLAS: bbox=(${minX.toFixed(1)}..${maxX.toFixed(1)}, ${minY.toFixed(1)}..${maxY.toFixed(1)}, ${minZ.toFixed(1)}..${maxZ.toFixed(1)}), ${vertexCount} verts, ${indexCount/3} tris, attribs=${attribsByteCount}B`);
const kind = (geomType === 1) ? `${primCount} aabbs` : `${indexCount/3} tris`;
console.log(`[crafter-wgpu] mesh BLAS: bbox=(${minX.toFixed(1)}..${maxX.toFixed(1)}, ${minY.toFixed(1)}..${maxY.toFixed(1)}, ${minZ.toFixed(1)}..${maxZ.toFixed(1)}), ${vertexCount} verts, ${kind}, opaque=${opaqueFlag}, attribs=${attribsByteCount}B`);
const vBytes = vertexCount * 12;
const iBytes = indexCount * 4;
@ -2701,8 +2827,8 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ,
const handle = rtState.nextMeshHandle++;
rtMeshRecordsEnsure(handle + 1);
// Build the MeshRecord (48 bytes) and write it.
const rec = new ArrayBuffer(48);
// Build the MeshRecord (64 bytes) and write it.
const rec = new ArrayBuffer(64);
const f32 = new Float32Array(rec);
const u32 = new Uint32Array(rec);
f32[0] = minX; f32[1] = minY; f32[2] = minZ;
@ -2711,9 +2837,16 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ,
u32[7] = iOff;
u32[8] = nOff;
u32[9] = rOff;
u32[10] = (vertexCount > 0) ? (indexCount / 3) : 0;
// triangleCount field doubles as the primitive count (= AABB count for
// geomType 1). Triangle meshes derive it from the index buffer.
u32[10] = (geomType === 1) ? (primCount >>> 0)
: ((vertexCount > 0) ? (indexCount / 3) : 0);
u32[11] = aOff;
queue.writeBuffer(rtState.meshRecordsBuffer, handle * 48, rec);
u32[12] = (geomType === 1) ? 1 : 0; // geomType
u32[13] = opaqueFlag ? 1 : 0; // opaque bit
u32[14] = 0; // _padMr0
u32[15] = 0; // _padMr1
queue.writeBuffer(rtState.meshRecordsBuffer, handle * 64, rec);
return handle;
};
@ -2924,6 +3057,12 @@ env.wgpuLoadRTPipeline = (wgslPtr, wgslLen, bindingsPtr, bindingsCount) => {
+ beforeHelpers + "\n" + rtWgslPureHelpers + "\n"
+ rtWgslWavefrontHelpers + "\n" + afterHelpers;
// When the pipeline registers any-hit / intersection shaders, those run
// inside TRACE and may sample the user's @group(3+) resources — so TRACE
// needs the full user pipeline layout (and its bind groups set at
// dispatch). PipelineRTWebGPU emits this exact marker when so.
const traceHasUser = fullWgsl.includes("@CRAFTER_RT_TRACE_USER = true");
// Parse user bindings (same wire format as wgpuLoadCustomShader). For
// the wavefront RT pipeline, group 0 = WfParams, group 1 = data heaps,
// group 2 = indirect args — so user bindings must start at group 3.
@ -3003,11 +3142,14 @@ env.wgpuLoadRTPipeline = (wgslPtr, wgslLen, bindingsPtr, bindingsCount) => {
const entry = {
genPipe: mk(userLayout, "wfGenerate"),
prepPipe: mk(prepLayout, "wfPrep"),
tracePipe: mk(traceLayout, "wfTrace"),
// TRACE gets the user layout only when any-hit / intersection
// shaders run there; otherwise it keeps the minimal params+data
// layout and zero user code (the common opaque-triangle path).
tracePipe: mk(traceHasUser ? userLayout : traceLayout, "wfTrace"),
shadePipe: mk(userLayout, "wfShade"),
resolvePipe: mk(userLayout, "wfResolve"),
paramsBgl, dataBgl, indirectBgl, emptyBgl, userBgls,
byGroup, sortedGroups,
byGroup, sortedGroups, traceHasUser,
};
const handle = newHandle();
rtPipelines.set(handle, entry);
@ -3187,6 +3329,9 @@ env.wgpuDispatchRT = (pipelineHandle, pushPtr, pushBytes,
p.setPipeline(pipe.tracePipe);
p.setBindGroup(0, paramsBg, [slotOff(traceSlot)]);
p.setBindGroup(1, dataBg);
// Any-hit / intersection shaders run in TRACE and may read the
// user @group(3+) resources — bind them when present.
if (pipe.traceHasUser) setUser(p);
p.dispatchWorkgroupsIndirect(wf.indirect, 0);
p.end();
}