/* Crafter®.Graphics Copyright (C) 2026 Catcrafts® catcrafts.net */ // Megakernel WGSL assembly. The library prelude lives JS-side // (additional/dom-webgpu.js, rtWgslPrelude) — we don't have access to // it from C++ — so this file emits only the *user-controlled* portions // (concatenated SBT sources + the generated switch statements) and the // stable entry-point glue. The JS side wraps these with the prelude // before handing to device.createShaderModule. // // Wire format passed across the JS boundary is a single WGSL string // containing the substitution markers `// @CRAFTER_RT_USER_SOURCES`, // `// @CRAFTER_RT_CLOSESTHIT_CASES`, `// @CRAFTER_RT_ANYHIT_CASES`, // `// @CRAFTER_RT_MISS_CASES`, `// @CRAFTER_RT_RAYGEN_BODY` already // expanded; the JS side just concatenates prelude + this string. module; module Crafter.Graphics:PipelineRTWebGPU_impl; import :PipelineRTWebGPU; import :ShaderBindingTableWebGPU; import :WebGPUComputeShader; import :RT; import :WebGPU; import std; using namespace Crafter; namespace { constexpr std::string_view kPlaceholderClosestHit = "fn _crafter_default_closesthit(ray: RayDesc, hit: HitInfo, payload: ptr) {}"; constexpr std::string_view kPlaceholderAnyHit = "fn _crafter_default_anyhit(ray: RayDesc, hit: HitInfo, payload: ptr) -> u32 { return RT_ANYHIT_ACCEPT; }"; constexpr std::string_view kPlaceholderMiss = "fn _crafter_default_miss(ray: RayDesc, payload: ptr) {}"; constexpr std::string_view kPlaceholderIntersection = "fn _crafter_default_intersection(ray: RayDesc, aabbMin: vec3, aabbMax: vec3, primitiveId: u32) -> IntersectionResult { var r: IntersectionResult; r.hit = false; return r; }"; void AppendCase(std::string& out, std::uint32_t hitGroupIndex, std::string_view entryFn, std::string_view args) { out += " case "; out += std::to_string(hitGroupIndex); out += "u: { "; out += entryFn; out += "("; out += args; out += "); }\n"; } // anyhit has a return type — case body forwards the result. void AppendAnyHitCase(std::string& out, std::uint32_t hitGroupIndex, std::string_view entryFn) { out += " case "; out += std::to_string(hitGroupIndex); out += "u: { return "; out += entryFn; out += "(ray, hit, payload); }\n"; } // intersection has a return type — forwards the AABB args + the result. void AppendIntersectionCase(std::string& out, std::uint32_t hitGroupIndex, std::string_view entryFn) { out += " case "; out += std::to_string(hitGroupIndex); out += "u: { return "; out += entryFn; out += "(ray, aabbMin, aabbMax, primitiveId); }\n"; } } void PipelineRTWebGPU::Init(WebGPUCommandEncoderRef /*cmd*/, std::span raygenGroups, std::span missGroups, std::span hitGroups, const ShaderBindingTableWebGPU& sbt, std::span bindings) { userBindings.assign(bindings.begin(), bindings.end()); std::string wgsl; wgsl.reserve(8 * 1024); // ── Section 1: user closesthit / anyhit / miss source files ──────── // // Raygens come later (after `traceRay` is declared) so we partition // shaders by stage. Concatenating *all* non-raygen sources here lets // them declare shared helpers, `struct Payload`, etc., in any order. wgsl += "// ── user closesthit / anyhit / miss / resolve sources ─────\n"; for (const auto& shader : sbt.shaders) { if (shader.stage == WebGPURTStage::Raygen) continue; wgsl += shader.source; wgsl += "\n"; } // ── Payload-typed wavefront storage binding ──────────────────────── // // Emitted *after* the user sources so it can name the user's `Payload` // type. Holds one Payload per in-flight ray slot across both ping/pong // ray buffers (capacity = 2·W·H). SHADE loads ray.payloadSlot here; // emit helpers (rtEmitPrimaryRay / rtEmitRay) store into it. wgsl += "\n@group(1) @binding(15) var " "wfPayload : array;\n"; // ── Section 2: mega-switch dispatchers ───────────────────────────── // // runClosestHit, runAnyHit, runMiss each dispatch on the per-hit / // per-ray index registered against the appropriate group span. // Indices match the user's expectations from VkRayTracingShaderGroup // ordering: closest-hit group N (N from 0..hitGroups.size()-1) is // selected by hitGroupIndex == N. wgsl += "\nfn runClosestHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr) {\n"; wgsl += " switch hg {\n"; bool anyClosestHit = false; for (std::uint32_t i = 0; i < hitGroups.size(); ++i) { const auto& g = hitGroups[i]; if (g.closestHitShader == kRTShaderUnused) continue; if (g.closestHitShader >= sbt.shaders.size()) continue; const auto& fn = sbt.shaders[g.closestHitShader].entryFn; AppendCase(wgsl, i, fn, "ray, hit, payload"); anyClosestHit = true; } if (!anyClosestHit) wgsl += " // (no closest-hit shaders registered)\n"; wgsl += " default: { }\n"; wgsl += " }\n"; wgsl += "}\n\n"; wgsl += "fn runAnyHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr) -> u32 {\n"; wgsl += " switch hg {\n"; bool anyAnyhit = false; for (std::uint32_t i = 0; i < hitGroups.size(); ++i) { const auto& g = hitGroups[i]; if (g.anyHitShader == kRTShaderUnused) continue; if (g.anyHitShader >= sbt.shaders.size()) continue; const auto& fn = sbt.shaders[g.anyHitShader].entryFn; AppendAnyHitCase(wgsl, i, fn); anyAnyhit = true; } if (!anyAnyhit) wgsl += " // (no any-hit shaders registered)\n"; wgsl += " default: { return RT_ANYHIT_ACCEPT; }\n"; wgsl += " }\n"; wgsl += "}\n\n"; wgsl += "fn runMiss(idx: u32, ray: RayDesc, payload: ptr) {\n"; wgsl += " switch idx {\n"; bool anyMiss = false; for (std::uint32_t i = 0; i < missGroups.size(); ++i) { const auto& g = missGroups[i]; if (g.generalShader == kRTShaderUnused) continue; if (g.generalShader >= sbt.shaders.size()) continue; const auto& fn = sbt.shaders[g.generalShader].entryFn; AppendCase(wgsl, i, fn, "ray, payload"); anyMiss = true; } if (!anyMiss) wgsl += " // (no miss shaders registered)\n"; wgsl += " default: { }\n"; wgsl += " }\n"; wgsl += "}\n"; // runIntersection — per-AABB procedural intersection dispatch. For a // ProceduralHitGroup the intersection shader determines the hit; for // triangle groups (or groups with no intersection shader) the default // reports no hit, so the BLAS leaf falls back to the triangle path. wgsl += "\nfn runIntersection(hg: u32, ray: RayDesc, aabbMin: vec3, aabbMax: vec3, primitiveId: u32) -> IntersectionResult {\n"; wgsl += " switch hg {\n"; bool anyIntersection = false; for (std::uint32_t i = 0; i < hitGroups.size(); ++i) { const auto& g = hitGroups[i]; if (g.intersectionShader == kRTShaderUnused) continue; if (g.intersectionShader >= sbt.shaders.size()) continue; const auto& fn = sbt.shaders[g.intersectionShader].entryFn; AppendIntersectionCase(wgsl, i, fn); anyIntersection = true; } if (!anyIntersection) wgsl += " // (no intersection shaders registered)\n"; wgsl += " default: { }\n"; wgsl += " }\n"; wgsl += " var none: IntersectionResult;\n"; wgsl += " none.hit = false;\n"; wgsl += " return none;\n"; wgsl += "}\n"; // Trace-time capability flags. The library traversal (injected at the // marker below) gates its any-hit / intersection callbacks on these // consts, so a triangle-only opaque scene dead-strips all user code out // of TRACE and keeps its zero-user-code register footprint. When either // is set the JS side also gives the TRACE pipeline the user bind-group // layout (so any-hit / intersection shaders can sample @group(3+) // resources) — it scans for the exact `@CRAFTER_RT_TRACE_USER` marker. wgsl += "\nconst RT_HAS_ANYHIT: bool = "; wgsl += (anyAnyhit ? "true" : "false"); wgsl += ";\n"; wgsl += "const RT_HAS_INTERSECTION: bool = "; wgsl += (anyIntersection ? "true" : "false"); wgsl += ";\n"; if (anyAnyhit || anyIntersection) { wgsl += "// @CRAFTER_RT_TRACE_USER = true\n"; } // runResolve — RESOLVE-stage tonemap hook. The first registered // Resolve shader wins; with none, identity passthrough (alpha forced // to 1) so the wavefront output matches a megakernel that wrote raw // colors. std::string resolveEntryFn; for (const auto& shader : sbt.shaders) { if (shader.stage == WebGPURTStage::Resolve) { resolveEntryFn = shader.entryFn; break; } } wgsl += "\nfn runResolve(coord: vec2, hdr: vec4) -> vec4 {\n"; if (!resolveEntryFn.empty()) { wgsl += " return "; wgsl += resolveEntryFn; wgsl += "(coord, hdr);\n"; } else { wgsl += " return vec4(hdr.rgb, 1.0);\n"; } wgsl += "}\n"; // Marker — JS-side prelude/post-amble searches for this token to know // where the library helpers (traverseBlas/traverseTlas/traceRay) get // injected, followed by raygen sources and the @compute entry point. wgsl += "\n// @CRAFTER_RT_LIBRARY_HELPERS_HERE\n"; // ── Section 3: user raygen source files ──────────────────────────── // // Comes after the library injects traceRay, so raygens can call it. wgsl += "\n// ── user raygen sources ───────────────────────────────────\n"; std::uint32_t raygenEntryIndex = kRTShaderUnused; std::string raygenEntryFn; for (const auto& shader : sbt.shaders) { if (shader.stage != WebGPURTStage::Raygen) continue; wgsl += shader.source; wgsl += "\n"; // Pick the first raygen group's general shader as the entry. Mirrors // Vulkan's pRayGenShaderBindingTable[0] → first invoked raygen. if (raygenEntryFn.empty()) raygenEntryFn = shader.entryFn; } if (!raygenGroups.empty() && raygenGroups[0].generalShader != kRTShaderUnused && raygenGroups[0].generalShader < sbt.shaders.size()) { raygenEntryIndex = raygenGroups[0].generalShader; raygenEntryFn = sbt.shaders[raygenEntryIndex].entryFn; } if (raygenEntryFn.empty()) { std::println("PipelineRTWebGPU::Init: no raygen shader registered"); pipelineHandle = 0; return; } // ── Section 4: wavefront @compute entry points ───────────────────── // // Five kernels share this one module; createComputePipeline selects // each by entryPoint name. GENERATE/RESOLVE are 8x8 screen tiles; // TRACE/SHADE are 64-wide 1-D over the compacted ray list (dispatched // indirectly from PREP); PREP is a single thread. The library helper // bodies (_rtwTraverseTlas, rtEmit*, rtAccumulate, _wfCurCount, …) are // injected JS-side at the marker above. // GENERATE — one thread per pixel; clears the pixel's accumulator and // runs the user raygen, which calls rtEmitPrimaryRay. wgsl += "\n@compute @workgroup_size(8, 8, 1)\n"; wgsl += "fn wfGenerate(@builtin(global_invocation_id) gid: vec3) {\n"; wgsl += " if (gid.x >= wfParams.surfaceW || gid.y >= wfParams.surfaceH) { return; }\n"; wgsl += " let pixel = gid.y * wfParams.surfaceW + gid.x;\n"; wgsl += " wfAccum[pixel] = vec4(0.0, 0.0, 0.0, 0.0);\n"; wgsl += " _wfPixel = pixel;\n"; wgsl += " "; wgsl += raygenEntryFn; wgsl += "(gid);\n"; wgsl += "}\n"; // PREP — single thread; reads the live ray count and publishes the // indirect dispatch args for the upcoming TRACE/SHADE, then zeroes the // next buffer's emit counter so SHADE starts compacting from 0. wgsl += "\n@compute @workgroup_size(1)\n"; wgsl += "fn wfPrep() { _wfPrep(); }\n"; // TRACE — zero user code: pure traversal + intersection. One thread // per live ray; writes a HitResult into wfHits[i]. wgsl += "\n@compute @workgroup_size(64)\n"; wgsl += "fn wfTrace(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { _wfTrace(gid.y * nwg.x * 64u + gid.x); }\n"; // SHADE — one thread per live ray; loads the ray + its hit + payload, // dispatches to runMiss / runClosestHit, which may rtAccumulate and // rtEmitRay continuation/shadow rays into the next buffer. wgsl += "\n@compute @workgroup_size(64)\n"; wgsl += "fn wfShade(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { _wfShade(gid.y * nwg.x * 64u + gid.x); }\n"; // RESOLVE — one thread per pixel; runs the user resolve (or identity) // over the linear accumulator and stores to the output image. wgsl += "\n@compute @workgroup_size(8, 8, 1)\n"; wgsl += "fn wfResolve(@builtin(global_invocation_id) gid: vec3) {\n"; wgsl += " if (gid.x >= wfParams.surfaceW || gid.y >= wfParams.surfaceH) { return; }\n"; wgsl += " let pixel = gid.y * wfParams.surfaceW + gid.x;\n"; wgsl += " let outc = runResolve(gid.xy, wfAccum[pixel]);\n"; wgsl += " textureStore(outImage, vec2(i32(gid.x), i32(gid.y)), outc);\n"; wgsl += "}\n"; pipelineHandle = WebGPU::wgpuLoadRTPipeline( wgsl.data(), static_cast(wgsl.size()), userBindings.empty() ? nullptr : userBindings.data(), static_cast(userBindings.size())); }