/* Crafter®.Graphics Copyright (C) 2026 Catcrafts® catcrafts.net */
// Megakernel WGSL assembly. The library prelude lives JS-side
// (additional/dom-webgpu.js, rtWgslPrelude) — we don't have access to
// it from C++ — so this file emits only the *user-controlled* portions
// (concatenated SBT sources + the generated switch statements) and the
// stable entry-point glue. The JS side wraps these with the prelude
// before handing to device.createShaderModule.
//
// Wire format passed across the JS boundary is a single WGSL string
// containing the substitution markers `// @CRAFTER_RT_USER_SOURCES`,
// `// @CRAFTER_RT_CLOSESTHIT_CASES`, `// @CRAFTER_RT_ANYHIT_CASES`,
// `// @CRAFTER_RT_MISS_CASES`, `// @CRAFTER_RT_RAYGEN_BODY` already
// expanded; the JS side just concatenates prelude + this string.
//
// NOTE(review): Init() below also emits the literal marker
// `// @CRAFTER_RT_LIBRARY_HELPERS_HERE` and its comment says the JS side
// searches for that token to inject the library helpers. That reads as
// more than plain concatenation — confirm which description matches the
// JS implementation and update the other.
module;
module Crafter.Graphics:PipelineRTWebGPU_impl;
import :PipelineRTWebGPU;
import :ShaderBindingTableWebGPU;
import :RT;
import :WebGPU;
import std;

using namespace Crafter;

namespace {
    // WGSL one-liners for do-nothing closest-hit / any-hit / miss functions.
    // None of the three is referenced by the code visible in this file.
    // NOTE(review): `payload: ptr` carries no address-space / store-type
    // arguments as seen here — confirm the literals against the prelude's
    // payload type before relying on them.
    constexpr std::string_view kPlaceholderClosestHit =
        "fn _crafter_default_closesthit(ray: RayDesc, hit: HitInfo, payload: ptr) {}";
    constexpr std::string_view kPlaceholderAnyHit =
        "fn _crafter_default_anyhit(ray: RayDesc, hit: HitInfo, payload: ptr) -> u32 { return RT_ANYHIT_ACCEPT; }";
    constexpr std::string_view kPlaceholderMiss =
        "fn _crafter_default_miss(ray: RayDesc, payload: ptr) {}";

    // Appends one WGSL switch arm of the form
    //   ` case <hitGroupIndex>u: { <entryFn>(<args>); }` + newline
    // to `out`. Used for dispatchers whose target function returns nothing.
    //
    //   out            string being built; appended to in place
    //   hitGroupIndex  selector value for the arm (the callers pass the
    //                  group's position in its span)
    //   entryFn        name of the WGSL function to call
    //   args           argument list text, inserted verbatim between the
    //                  parentheses
    void AppendCase(std::string& out, std::uint32_t hitGroupIndex, std::string_view entryFn, std::string_view args) {
        out += " case ";
        out += std::to_string(hitGroupIndex);
        out += "u: { ";
        out += entryFn;
        out += "(";
        out += args;
        out += "); }\n";
    }

    // anyhit has a return type — case body forwards the result.
    //
    // Appends ` case <hitGroupIndex>u: { return <entryFn>(ray, hit, payload); }`
    // + newline to `out`. The argument list is fixed, so unlike AppendCase
    // there is no `args` parameter.
    void AppendAnyHitCase(std::string& out, std::uint32_t hitGroupIndex, std::string_view entryFn) {
        out += " case ";
        out += std::to_string(hitGroupIndex);
        out += "u: { return ";
        out += entryFn;
        out += "(ray, hit, payload); }\n";
    }
}

// Builds the user-controlled part of the ray-tracing WGSL module and passes
// it to WebGPU::wgpuLoadRTPipeline, storing the result in `pipelineHandle`.
//
// Emission order (each piece is appended to a single string):
//   1. sources of every SBT shader whose stage is not Raygen
//   2. the generated dispatchers runClosestHit / runAnyHit / runMiss
//   3. the `// @CRAFTER_RT_LIBRARY_HELPERS_HERE` marker line
//   4. sources of every SBT shader whose stage is Raygen
//   5. the @compute entry point `main`, which calls the chosen raygen entry
//
// Parameters:
//   (cmd)         command encoder; unnamed and unused in this function
//   raygenGroups  only element 0 is consulted, to choose the raygen entry
//   missGroups    element i's generalShader becomes arm i of runMiss
//   hitGroups     element i's closestHitShader / anyHitShader become arm i
//                 of runClosestHit / runAnyHit
//   sbt           supplies `shaders`, each read for `stage`, `source` and
//                 `entryFn`
//
// Group slots equal to kRTShaderUnused, or indexing past sbt.shaders, are
// skipped without any diagnostic; those selector values then fall through
// to the dispatcher's `default` arm.
//
// If no raygen entry function can be determined, a message is printed,
// `pipelineHandle` is set to 0 and wgpuLoadRTPipeline is not called.
//
// NOTE(review): the three `std::span` parameters appear without element-type
// arguments as seen here — confirm the signature against the module
// interface declaration.
void PipelineRTWebGPU::Init(WebGPUCommandEncoderRef /*cmd*/, std::span raygenGroups, std::span missGroups, std::span hitGroups, const ShaderBindingTableWebGPU& sbt) {
    std::string wgsl;
    wgsl.reserve(8 * 1024);

    // ── Section 1: user closesthit / anyhit / miss source files ────────
    //
    // Raygens come later (after `traceRay` is declared) so we partition
    // shaders by stage. Concatenating *all* non-raygen sources here lets
    // them declare shared helpers, `struct Payload`, etc., in any order.
    wgsl += "// ── user closesthit / anyhit / miss sources ───────────────\n";
    for (const auto& shader : sbt.shaders) {
        if (shader.stage == WebGPURTStage::Raygen) continue;
        wgsl += shader.source;
        wgsl += "\n";
    }

    // ── Section 2: mega-switch dispatchers ─────────────────────────────
    //
    // runClosestHit, runAnyHit, runMiss each dispatch on the per-hit /
    // per-ray index registered against the appropriate group span.
    // Indices match the user's expectations from VkRayTracingShaderGroup
    // ordering: closest-hit group N (N from 0..hitGroups.size()-1) is
    // selected by hitGroupIndex == N.
    //
    // The arm label is the group's position in its span (the loop index),
    // not the shader's index inside sbt.shaders.

    // runClosestHit: one arm per hit group that has a usable closest-hit
    // shader; every other selector value hits the empty default arm.
    wgsl += "\nfn runClosestHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr) {\n";
    wgsl += " switch hg {\n";
    bool anyClosestHit = false;
    for (std::uint32_t i = 0; i < hitGroups.size(); ++i) {
        const auto& g = hitGroups[i];
        if (g.closestHitShader == kRTShaderUnused) continue;
        if (g.closestHitShader >= sbt.shaders.size()) continue; // out-of-range index: skipped silently
        const auto& fn = sbt.shaders[g.closestHitShader].entryFn;
        AppendCase(wgsl, i, fn, "ray, hit, payload");
        anyClosestHit = true;
    }
    // Leave a WGSL comment in the generated text when no arm was emitted.
    if (!anyClosestHit) wgsl += " // (no closest-hit shaders registered)\n";
    wgsl += " default: { }\n";
    wgsl += " }\n";
    wgsl += "}\n\n";

    // runAnyHit: same walk over hitGroups, keyed on anyHitShader. Arms
    // return the shader's result; the default arm returns RT_ANYHIT_ACCEPT.
    wgsl += "fn runAnyHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr) -> u32 {\n";
    wgsl += " switch hg {\n";
    bool anyAnyhit = false;
    for (std::uint32_t i = 0; i < hitGroups.size(); ++i) {
        const auto& g = hitGroups[i];
        if (g.anyHitShader == kRTShaderUnused) continue;
        if (g.anyHitShader >= sbt.shaders.size()) continue; // out-of-range index: skipped silently
        const auto& fn = sbt.shaders[g.anyHitShader].entryFn;
        AppendAnyHitCase(wgsl, i, fn);
        anyAnyhit = true;
    }
    if (!anyAnyhit) wgsl += " // (no any-hit shaders registered)\n";
    wgsl += " default: { return RT_ANYHIT_ACCEPT; }\n";
    wgsl += " }\n";
    wgsl += "}\n\n";

    // runMiss: one arm per miss group with a usable general shader. Miss
    // functions are called without a `hit` argument.
    wgsl += "fn runMiss(idx: u32, ray: RayDesc, payload: ptr) {\n";
    wgsl += " switch idx {\n";
    bool anyMiss = false;
    for (std::uint32_t i = 0; i < missGroups.size(); ++i) {
        const auto& g = missGroups[i];
        if (g.generalShader == kRTShaderUnused) continue;
        if (g.generalShader >= sbt.shaders.size()) continue; // out-of-range index: skipped silently
        const auto& fn = sbt.shaders[g.generalShader].entryFn;
        AppendCase(wgsl, i, fn, "ray, payload");
        anyMiss = true;
    }
    if (!anyMiss) wgsl += " // (no miss shaders registered)\n";
    wgsl += " default: { }\n";
    wgsl += " }\n";
    wgsl += "}\n";

    // Marker — JS-side prelude/post-amble searches for this token to know
    // where the library helpers (traverseBlas/traverseTlas/traceRay) get
    // injected, followed by raygen sources and the @compute entry point.
    wgsl += "\n// @CRAFTER_RT_LIBRARY_HELPERS_HERE\n";

    // ── Section 3: user raygen source files ────────────────────────────
    //
    // Comes after the library injects traceRay, so raygens can call it.
    wgsl += "\n// ── user raygen sources ───────────────────────────────────\n";
    std::uint32_t raygenEntryIndex = kRTShaderUnused;
    std::string raygenEntryFn;
    for (const auto& shader : sbt.shaders) {
        if (shader.stage != WebGPURTStage::Raygen) continue;
        wgsl += shader.source;
        wgsl += "\n";
        // Fallback entry: remember the first raygen-stage shader in SBT
        // order. It is only used when the raygenGroups[0] override below
        // does not apply.
        if (raygenEntryFn.empty()) raygenEntryFn = shader.entryFn;
    }
    // Pick the first raygen group's general shader as the entry. Mirrors
    // Vulkan's pRayGenShaderBindingTable[0] → first invoked raygen.
    // Takes precedence over the fallback chosen in the loop above.
    // NOTE(review): the referenced shader's stage is not checked here; a
    // group pointing at a non-raygen shader would still become the entry.
    if (!raygenGroups.empty() && raygenGroups[0].generalShader != kRTShaderUnused && raygenGroups[0].generalShader < sbt.shaders.size()) {
        raygenEntryIndex = raygenGroups[0].generalShader;
        raygenEntryFn = sbt.shaders[raygenEntryIndex].entryFn;
    }
    // Neither the override nor the fallback produced a name: report it,
    // clear the handle and stop without loading a pipeline.
    if (raygenEntryFn.empty()) {
        std::println("PipelineRTWebGPU::Init: no raygen shader registered");
        pipelineHandle = 0;
        return;
    }

    // ── Section 4: @compute entry point ────────────────────────────────
    //
    // 8x8 tile workgroup matching the rest of the WebGPU backend.
    // The generated `main` forwards its global invocation id to the chosen
    // raygen entry function.
    // NOTE(review): `gid: vec3` has no element type as seen here — confirm
    // the literal against what the raygen entry functions expect.
    wgsl += "\n@compute @workgroup_size(8, 8, 1)\n";
    wgsl += "fn main(@builtin(global_invocation_id) gid: vec3) {\n";
    wgsl += " ";
    wgsl += raygenEntryFn;
    wgsl += "(gid);\n";
    wgsl += "}\n";

    // Hand the assembled text (pointer + length) across the JS boundary.
    // NOTE(review): the static_cast below shows no target type as seen
    // here — confirm against wgpuLoadRTPipeline's length parameter.
    pipelineHandle = WebGPU::wgpuLoadRTPipeline( wgsl.data(), static_cast(wgsl.size()));
}