diff --git a/additional/dom-webgpu.js b/additional/dom-webgpu.js index a15142f..e5fbbd3 100644 --- a/additional/dom-webgpu.js +++ b/additional/dom-webgpu.js @@ -1357,6 +1357,15 @@ struct BVHNode { // per-vertex stride lives in the user's WGSL — the library doesn't store // it because the layout is example-defined (Sponza uses 8 u32 / vertex // for VertexNormalTangentUVPacked). +// +// geomType selects the BLAS primitive kind: 0 = triangles (the default; +// the vertices/indices streams are positions + a triangle index buffer), +// 1 = AABBs (VK_GEOMETRY_TYPE_AABBS_KHR — the vertices stream holds 2 vec3 +// per primitive [min, max] and there are no indices; a registered +// intersection shader determines the hit). opaque is the geometry's opaque +// bit (1 = opaque, no any-hit; 0 = any-hit may run, subject to the ray / +// instance force flags). triangleCount doubles as the AABB primitive count +// for geomType == 1. struct MeshRecord { rootAabbMin: vec3, vertexOffset: u32, @@ -1366,6 +1375,21 @@ struct MeshRecord { primRemapOffset: u32, triangleCount: u32, attribsOffset: u32, + geomType: u32, + opaque: u32, + _padMr0: u32, + _padMr1: u32, +}; + +// Result of an intersection shader (procedural / AABB geometry) and the +// TRACE-stage candidate inspection. hit gates the rest; t is the +// object-space ray parameter; attribs are forwarded to closest-hit / +// any-hit (HitInfo.attribs); hitKind is user-defined (e.g. front/back). +struct IntersectionResult { + hit: bool, + t: f32, + attribs: vec2, + hitKind: u32, }; // Per-instance TLAS record built by the TLAS-build compute pass. @@ -1469,6 +1493,28 @@ fn _rtFetchTri(meshRec: MeshRecord, triIndex: u32) -> array, 3> { ); } +// Fetch one procedural AABB (geomType == 1). The vertices heap stores +// 2 vec3 per primitive [min, max]; vertexOffset is the per-mesh base in +// vec3 units (matching the triangle path's vertexOffset units). +fn _rtFetchAabb(meshRec: MeshRecord, primIndex: u32) -> array, 2> { + let base = (meshRec.vertexOffset + primIndex * 2u) * 3u; + return array, 2>( + vec3(vertices[base + 0u], vertices[base + 1u], vertices[base + 2u]), + vec3(vertices[base + 3u], vertices[base + 4u], vertices[base + 5u]), + ); +} + +// DXR/VK opacity resolution. Ray FORCE flags win, then instance FORCE +// flags, then the geometry's own opaque bit. Non-opaque candidates run +// the any-hit shader during traversal. +fn _rtResolveOpaque(rayFlags: u32, instFlags: u32, geomOpaque: bool) -> bool { + if ((rayFlags & RT_FLAG_OPAQUE) != 0u) { return true; } + if ((rayFlags & RT_FLAG_NO_OPAQUE) != 0u) { return false; } + if ((instFlags & RT_INSTANCE_FORCE_OPAQUE) != 0u) { return true; } + if ((instFlags & RT_INSTANCE_FORCE_NO_OPAQUE) != 0u) { return false; } + return geomOpaque; +} + fn _rtAabb(ro: vec3, invRd: vec3, mn: vec3, mx: vec3, tMax: f32) -> bool { // Reject degenerate (mn > mx) boxes outright. The min(t0,t1)/ // max(t0,t1) trick below silently re-orients an inverted box @@ -1681,9 +1727,41 @@ fn rtEmitRay(origin: vec3, tMin: f32, dir: vec3, tMax: f32, wfPayload[r.payloadSlot] = payload; } -// Opaque-only BLAS descent (no anyhit — TRACE runs zero user code). -fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord, - instanceId: u32, hitGroupBase: u32, +// Inspect one candidate hit: run the any-hit shader for non-opaque +// geometry, commit on accept. Returns 0 = ignored (keep searching, do not +// shrink bestT), 1 = committed (continue traversal with the tighter bestT), +// 2 = end search (committed + terminate). rayWorld is what any-hit sees +// as the ray (world space, matching closest-hit); cand carries the +// object-space hit. When RT_HAS_ANYHIT is false the any-hit call is +// const-folded away and this reduces to a plain commit — so opaque, +// triangle-only scenes keep TRACE's zero-user-code footprint. +fn _rtwCommitCandidate(rayWorld: RayDesc, flags: u32, instFlags: u32, + geomOpaque: bool, hitGroupBase: u32, cand: HitInfo, + payload: ptr, + bestHit: ptr, + bestT: ptr) -> u32 { + let opaque = _rtResolveOpaque(flags, instFlags, geomOpaque); + var verdict = RT_ANYHIT_ACCEPT; + if (RT_HAS_ANYHIT && !opaque) { + verdict = runAnyHit(hitGroupBase, rayWorld, cand, payload); + } + if (verdict == RT_ANYHIT_IGNORE) { return 0u; } + *bestHit = cand; + *bestT = cand.t; + if (verdict == RT_ANYHIT_END_SEARCH) { return 2u; } + if ((flags & RT_FLAG_TERMINATE_ON_FIRST_HIT) != 0u) { return 2u; } + return 1u; +} + +// BLAS descent. Handles both triangle and AABB (procedural) geometry and +// runs the any-hit shader for non-opaque candidates. payload is threaded +// in so any-hit / intersection can read & mutate it (_wfTrace writes it +// back). For opaque triangle-only scenes the user-callback branches +// const-fold out (RT_HAS_ANYHIT / RT_HAS_INTERSECTION both false), leaving +// the original pure-traversal loop. +fn _rtwTraverseBlas(rayObj: RayDesc, rayWorld: RayDesc, flags: u32, instFlags: u32, + meshRec: MeshRecord, instanceId: u32, hitGroupBase: u32, + payload: ptr, bestHit: ptr, bestT: ptr) -> bool { let invD = vec3(1.0) / rayObj.direction; @@ -1698,27 +1776,57 @@ fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord, sp = sp - 1u; nodeRel = stack[sp]; continue; } if (node.primCount > 0u) { - for (var i: u32 = 0u; i < node.primCount; i = i + 1u) { - let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i]; - let verts = _rtFetchTri(meshRec, triIndex); - let tr = _rtTri(rayObj.origin, rayObj.direction, - verts[0], verts[1], verts[2], rayObj.tMin, *bestT); - if (!tr.hit) { continue; } - let geomNormal = cross(verts[1] - verts[0], verts[2] - verts[0]); - let facing = dot(geomNormal, rayObj.direction); - if ((flags & RT_FLAG_CULL_BACK_FACING_TRIANGLES) != 0u && facing > 0.0) { continue; } - if ((flags & RT_FLAG_CULL_FRONT_FACING_TRIANGLES) != 0u && facing < 0.0) { continue; } - var candidate: HitInfo; - candidate.t = tr.t; - candidate.instanceId = instanceId; - candidate.primitiveId = triIndex; - candidate.hitGroupIndex = hitGroupBase; - candidate.attribs = vec2(tr.u, tr.v); - candidate.objectRayOrigin = rayObj.origin; - candidate.objectRayDirection = rayObj.direction; - *bestHit = candidate; - *bestT = tr.t; - if ((flags & RT_FLAG_TERMINATE_ON_FIRST_HIT) != 0u) { return true; } + if (RT_HAS_INTERSECTION && meshRec.geomType == 1u) { + // ── AABB / procedural geometry (VK_GEOMETRY_TYPE_AABBS) ── + if ((flags & RT_FLAG_SKIP_AABBS) == 0u) { + for (var i: u32 = 0u; i < node.primCount; i = i + 1u) { + let primId = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i]; + let box = _rtFetchAabb(meshRec, primId); + if (!_rtAabb(rayObj.origin, invD, box[0], box[1], *bestT)) { continue; } + var iray: RayDesc; + iray.origin = rayObj.origin; iray.tMin = rayObj.tMin; + iray.direction = rayObj.direction; iray.tMax = *bestT; + let ir = runIntersection(hitGroupBase, iray, box[0], box[1], primId); + if (!ir.hit) { continue; } + if (ir.t < rayObj.tMin || ir.t > *bestT) { continue; } + var cand: HitInfo; + cand.t = ir.t; + cand.instanceId = instanceId; + cand.primitiveId = primId; + cand.hitGroupIndex = hitGroupBase; + cand.attribs = ir.attribs; + cand.objectRayOrigin = rayObj.origin; + cand.objectRayDirection = rayObj.direction; + let r = _rtwCommitCandidate(rayWorld, flags, instFlags, + meshRec.opaque != 0u, hitGroupBase, + cand, payload, bestHit, bestT); + if (r == 2u) { return true; } + } + } + } else if ((flags & RT_FLAG_SKIP_TRIANGLES) == 0u) { + for (var i: u32 = 0u; i < node.primCount; i = i + 1u) { + let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i]; + let verts = _rtFetchTri(meshRec, triIndex); + let tr = _rtTri(rayObj.origin, rayObj.direction, + verts[0], verts[1], verts[2], rayObj.tMin, *bestT); + if (!tr.hit) { continue; } + let geomNormal = cross(verts[1] - verts[0], verts[2] - verts[0]); + let facing = dot(geomNormal, rayObj.direction); + if ((flags & RT_FLAG_CULL_BACK_FACING_TRIANGLES) != 0u && facing > 0.0) { continue; } + if ((flags & RT_FLAG_CULL_FRONT_FACING_TRIANGLES) != 0u && facing < 0.0) { continue; } + var candidate: HitInfo; + candidate.t = tr.t; + candidate.instanceId = instanceId; + candidate.primitiveId = triIndex; + candidate.hitGroupIndex = hitGroupBase; + candidate.attribs = vec2(tr.u, tr.v); + candidate.objectRayOrigin = rayObj.origin; + candidate.objectRayDirection = rayObj.direction; + let r = _rtwCommitCandidate(rayWorld, flags, instFlags, + meshRec.opaque != 0u, hitGroupBase, + candidate, payload, bestHit, bestT); + if (r == 2u) { return true; } + } } if (sp == 0u) { break; } sp = sp - 1u; nodeRel = stack[sp]; continue; @@ -1749,6 +1857,7 @@ fn _rtwTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord, fn _rtwTraverseTlas(rayWorld: RayDesc, flags: u32, cullMask: u32, sbtRecordOffset: u32, + payload: ptr, bestHit: ptr, bestT: ptr) -> bool { let invD = vec3(1.0) / rayWorld.direction; @@ -1793,7 +1902,7 @@ fn _rtwTraverseTlas(rayWorld: RayDesc, flags: u32, cullMask: u32, let hitGroupBase = sbtRecordOffset + hitGroupOffset; let meshRec = meshRecords[inst.blasMeshIdx]; let pre = *bestT; - let endSearch = _rtwTraverseBlas(rayObj, effective, meshRec, i, hitGroupBase, bestHit, bestT); + let endSearch = _rtwTraverseBlas(rayObj, rayWorld, effective, iflags, meshRec, i, hitGroupBase, payload, bestHit, bestT); if ((*bestT) < pre || endSearch) { (*bestHit).objectToWorldR0 = inst.objectToWorldR0; (*bestHit).objectToWorldR1 = inst.objectToWorldR1; @@ -1861,7 +1970,19 @@ fn _wfTrace(i: u32) { var bestHit: HitInfo; bestHit.t = ray.tMax; var bestT = ray.tMax; - _rtwTraverseTlas(rd, ray.flags, ray.cullMask & 0xFFu, ray.sbtRecordOffset, &bestHit, &bestT); + // Any-hit / intersection shaders run inside traversal and may read & + // mutate the payload, so load it here and write it back below. For an + // opaque triangle-only scene both consts are false and the payload + // touch const-folds away — TRACE keeps reading zero user state. + var payload: Payload; + if (RT_HAS_ANYHIT || RT_HAS_INTERSECTION) { + _wfPixel = ray.pixel; + payload = wfPayload[ray.payloadSlot]; + } + _rtwTraverseTlas(rd, ray.flags, ray.cullMask & 0xFFu, ray.sbtRecordOffset, &payload, &bestHit, &bestT); + if (RT_HAS_ANYHIT || RT_HAS_INTERSECTION) { + wfPayload[ray.payloadSlot] = payload; + } var hr: HitResult; if (bestT < ray.tMax) { hr.t = bestHit.t; @@ -1966,7 +2087,10 @@ fn _rqTraverseBlas(rayObj: RayDesc, flags: u32, meshRec: MeshRecord, sp = sp - 1u; nodeRel = stack[sp]; continue; } if (node.primCount > 0u) { - for (var i: u32 = 0u; i < node.primCount; i = i + 1u) { + // rayQuery is triangle-only; AABB (procedural) BLAS need an + // intersection shader the rayQuery path doesn't run, so skip + // their leaves rather than misread the AABB stream as triangles. + for (var i: u32 = 0u; i < node.primCount * select(0u, 1u, meshRec.geomType == 0u); i = i + 1u) { let triIndex = primRemap[meshRec.primRemapOffset + node.firstChildOrPrim + i]; let verts = _rtFetchTri(meshRec, triIndex); let tr = _rtTri(rayObj.origin, rayObj.direction, @@ -2572,7 +2696,7 @@ function rtInit() { rtState.attribsHeap = makeRtHeap(); rtState.meshRecordsCapacity = 16; rtState.meshRecordsBuffer = device.createBuffer({ - size: rtState.meshRecordsCapacity * 48, + size: rtState.meshRecordsCapacity * 64, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC, }); rtState.rtHeader = device.createBuffer({ @@ -2635,12 +2759,12 @@ function rtMeshRecordsEnsure(meshCount) { let cap = rtState.meshRecordsCapacity; while (cap < meshCount) cap *= 2; const ng = device.createBuffer({ - size: cap * 48, + size: cap * 64, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC, }); const enc = device.createCommandEncoder(); enc.copyBufferToBuffer(rtState.meshRecordsBuffer, 0, ng, 0, - rtState.meshRecordsCapacity * 48); + rtState.meshRecordsCapacity * 64); queue.submit([enc.finish()]); rtState.meshRecordsBuffer.destroy(); rtState.meshRecordsBuffer = ng; @@ -2653,9 +2777,11 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ, indicesPtr, indexCount, bvhNodesPtr, bvhNodeCount, primRemapPtr, primRemapCount, - attribsPtr, attribsByteCount) => { + attribsPtr, attribsByteCount, + geomType, opaqueFlag, primCount) => { if (!rtState.vertHeap) rtInit(); - console.log(`[crafter-wgpu] mesh BLAS: bbox=(${minX.toFixed(1)}..${maxX.toFixed(1)}, ${minY.toFixed(1)}..${maxY.toFixed(1)}, ${minZ.toFixed(1)}..${maxZ.toFixed(1)}), ${vertexCount} verts, ${indexCount/3} tris, attribs=${attribsByteCount}B`); + const kind = (geomType === 1) ? `${primCount} aabbs` : `${indexCount/3} tris`; + console.log(`[crafter-wgpu] mesh BLAS: bbox=(${minX.toFixed(1)}..${maxX.toFixed(1)}, ${minY.toFixed(1)}..${maxY.toFixed(1)}, ${minZ.toFixed(1)}..${maxZ.toFixed(1)}), ${vertexCount} verts, ${kind}, opaque=${opaqueFlag}, attribs=${attribsByteCount}B`); const vBytes = vertexCount * 12; const iBytes = indexCount * 4; @@ -2701,8 +2827,8 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ, const handle = rtState.nextMeshHandle++; rtMeshRecordsEnsure(handle + 1); - // Build the MeshRecord (48 bytes) and write it. - const rec = new ArrayBuffer(48); + // Build the MeshRecord (64 bytes) and write it. + const rec = new ArrayBuffer(64); const f32 = new Float32Array(rec); const u32 = new Uint32Array(rec); f32[0] = minX; f32[1] = minY; f32[2] = minZ; @@ -2711,9 +2837,16 @@ env.wgpuRegisterMeshBLAS = (minX, minY, minZ, maxX, maxY, maxZ, u32[7] = iOff; u32[8] = nOff; u32[9] = rOff; - u32[10] = (vertexCount > 0) ? (indexCount / 3) : 0; + // triangleCount field doubles as the primitive count (= AABB count for + // geomType 1). Triangle meshes derive it from the index buffer. + u32[10] = (geomType === 1) ? (primCount >>> 0) + : ((vertexCount > 0) ? (indexCount / 3) : 0); u32[11] = aOff; - queue.writeBuffer(rtState.meshRecordsBuffer, handle * 48, rec); + u32[12] = (geomType === 1) ? 1 : 0; // geomType + u32[13] = opaqueFlag ? 1 : 0; // opaque bit + u32[14] = 0; // _padMr0 + u32[15] = 0; // _padMr1 + queue.writeBuffer(rtState.meshRecordsBuffer, handle * 64, rec); return handle; }; @@ -2924,6 +3057,12 @@ env.wgpuLoadRTPipeline = (wgslPtr, wgslLen, bindingsPtr, bindingsCount) => { + beforeHelpers + "\n" + rtWgslPureHelpers + "\n" + rtWgslWavefrontHelpers + "\n" + afterHelpers; + // When the pipeline registers any-hit / intersection shaders, those run + // inside TRACE and may sample the user's @group(3+) resources — so TRACE + // needs the full user pipeline layout (and its bind groups set at + // dispatch). PipelineRTWebGPU emits this exact marker when so. + const traceHasUser = fullWgsl.includes("@CRAFTER_RT_TRACE_USER = true"); + // Parse user bindings (same wire format as wgpuLoadCustomShader). For // the wavefront RT pipeline, group 0 = WfParams, group 1 = data heaps, // group 2 = indirect args — so user bindings must start at group 3. @@ -3003,11 +3142,14 @@ env.wgpuLoadRTPipeline = (wgslPtr, wgslLen, bindingsPtr, bindingsCount) => { const entry = { genPipe: mk(userLayout, "wfGenerate"), prepPipe: mk(prepLayout, "wfPrep"), - tracePipe: mk(traceLayout, "wfTrace"), + // TRACE gets the user layout only when any-hit / intersection + // shaders run there; otherwise it keeps the minimal params+data + // layout and zero user code (the common opaque-triangle path). + tracePipe: mk(traceHasUser ? userLayout : traceLayout, "wfTrace"), shadePipe: mk(userLayout, "wfShade"), resolvePipe: mk(userLayout, "wfResolve"), paramsBgl, dataBgl, indirectBgl, emptyBgl, userBgls, - byGroup, sortedGroups, + byGroup, sortedGroups, traceHasUser, }; const handle = newHandle(); rtPipelines.set(handle, entry); @@ -3187,6 +3329,9 @@ env.wgpuDispatchRT = (pipelineHandle, pushPtr, pushBytes, p.setPipeline(pipe.tracePipe); p.setBindGroup(0, paramsBg, [slotOff(traceSlot)]); p.setBindGroup(1, dataBg); + // Any-hit / intersection shaders run in TRACE and may read the + // user @group(3+) resources — bind them when present. + if (pipe.traceHasUser) setUser(p); p.dispatchWorkgroupsIndirect(wf.indirect, 0); p.end(); }