Crafter.Graphics/implementations/Crafter.Graphics-PipelineRTWebGPU.cpp
catbot 1e749818ef fix(webgpu): reshape wavefront TRACE/SHADE to 2-D to survive >4.19M rays
A 1-D indirect dispatch of ceil(W*H/64) workgroups for the wavefront
TRACE/SHADE stages overflows maxComputeWorkgroupsPerDimension (65535 on
Dawn/Firefox) once the surface exceeds ~4.19M rays (~2560x1640). Per the
WebGPU spec such a dispatch is silently dropped — no validation error —
so at 4K the world is never traced and the accumulator stays black while
non-RT passes survive.

_wfPrep now spreads the workgroups across a 2-D grid (x clamped to 65535,
y = ceil(wg/65535)), and the wfTrace/wfShade entry points rebuild the
linear ray index from (global_invocation_id, num_workgroups). The existing
`i >= _wfCurCount()` guard absorbs the grid overshoot. GENERATE/RESOLVE
already use a 2-D tile dispatch and are unchanged.

Verified in Firefox/WebGPU with RTStress at a 3449x1739 surface (5.99M
rays, 93716 workgroups — well over the 65535 cap): renders the full cube
grid where master shows a black screen.

Resolves #11

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-01 11:09:15 +00:00

257 lines
12 KiB
C++

/*
Crafter®.Graphics
Copyright (C) 2026 Catcrafts®
catcrafts.net
*/
// Megakernel WGSL assembly. The library prelude lives JS-side
// (additional/dom-webgpu.js, rtWgslPrelude) — we don't have access to
// it from C++ — so this file emits only the *user-controlled* portions
// (concatenated SBT sources + the generated switch statements) and the
// stable entry-point glue. The JS side wraps these with the prelude
// before handing to device.createShaderModule.
//
// Wire format passed across the JS boundary is a single WGSL string
// containing the substitution markers `// @CRAFTER_RT_USER_SOURCES`,
// `// @CRAFTER_RT_CLOSESTHIT_CASES`, `// @CRAFTER_RT_ANYHIT_CASES`,
// `// @CRAFTER_RT_MISS_CASES`, `// @CRAFTER_RT_RAYGEN_BODY` already
// expanded. The one marker left unexpanded is
// `// @CRAFTER_RT_LIBRARY_HELPERS_HERE`: the JS side prepends its prelude
// and injects the library helpers (traverseBlas/traverseTlas/traceRay)
// at that token.
module;
module Crafter.Graphics:PipelineRTWebGPU_impl;
import :PipelineRTWebGPU;
import :ShaderBindingTableWebGPU;
import :WebGPUComputeShader;
import :RT;
import :WebGPU;
import std;
using namespace Crafter;
namespace {
// WGSL no-op default bodies for the closest-hit / any-hit / miss stages.
// NOTE(review): none of these three constants is referenced elsewhere in
// this file — confirm whether they are still needed before removing.
constexpr std::string_view kPlaceholderClosestHit =
    "fn _crafter_default_closesthit(ray: RayDesc, hit: HitInfo, payload: ptr<function, Payload>) {}";
constexpr std::string_view kPlaceholderAnyHit =
    "fn _crafter_default_anyhit(ray: RayDesc, hit: HitInfo, payload: ptr<function, Payload>) -> u32 { return RT_ANYHIT_ACCEPT; }";
constexpr std::string_view kPlaceholderMiss =
    "fn _crafter_default_miss(ray: RayDesc, payload: ptr<function, Payload>) {}";

// Writes the opening of one generated switch arm: ` case <N>u: { `.
void AppendCaseLabel(std::string& out, std::uint32_t hitGroupIndex) {
    out.append(" case ").append(std::to_string(hitGroupIndex)).append("u: { ");
}

// Appends a void switch arm that calls `entryFn(args)`:
//   ` case <N>u: { entryFn(args); }\n`
void AppendCase(std::string& out,
                std::uint32_t hitGroupIndex,
                std::string_view entryFn,
                std::string_view args) {
    AppendCaseLabel(out, hitGroupIndex);
    out.append(entryFn).append("(").append(args).append("); }\n");
}

// Appends an any-hit switch arm. Any-hit shaders return a u32 verdict, so
// the arm forwards it: ` case <N>u: { return entryFn(ray, hit, payload); }\n`
void AppendAnyHitCase(std::string& out,
                      std::uint32_t hitGroupIndex,
                      std::string_view entryFn) {
    AppendCaseLabel(out, hitGroupIndex);
    out.append("return ").append(entryFn).append("(ray, hit, payload); }\n");
}
}
// Assembles the user-controlled part of the wavefront ray-tracing WGSL
// module and hands it to the JS side (WebGPU::wgpuLoadRTPipeline), storing
// the returned handle in `pipelineHandle`.
//
//   raygenGroups — only [0].generalShader is consulted, to choose the raygen
//                  entry function; otherwise the first Raygen-stage shader
//                  in SBT order is used.
//   missGroups   — group N becomes `case Nu` of the generated runMiss.
//   hitGroups    — group N becomes `case Nu` of runClosestHit / runAnyHit.
//   sbt          — supplies each shader's source, stage and entry function.
//   bindings     — copied into `userBindings` and forwarded to the loader.
//
// If no raygen shader can be found, logs a message, sets pipelineHandle to
// 0 and returns without loading a pipeline.
void PipelineRTWebGPU::Init(WebGPUCommandEncoderRef /*cmd*/,
                            std::span<const RTShaderGroup> raygenGroups,
                            std::span<const RTShaderGroup> missGroups,
                            std::span<const RTShaderGroup> hitGroups,
                            const ShaderBindingTableWebGPU& sbt,
                            std::span<const UICustomBinding> bindings) {
    userBindings.assign(bindings.begin(), bindings.end());
    std::string wgsl;
    wgsl.reserve(8 * 1024);

    // ── Section 1: user closesthit / anyhit / miss source files ────────
    //
    // Raygens come later (after `traceRay` is declared) so we partition
    // shaders by stage. Concatenating *all* non-raygen sources here
    // (including resolve) lets them declare shared helpers,
    // `struct Payload`, etc., in any order.
    wgsl += "// ── user closesthit / anyhit / miss / resolve sources ─────\n";
    for (const auto& shader : sbt.shaders) {
        if (shader.stage == WebGPURTStage::Raygen) continue;
        wgsl += shader.source;
        wgsl += "\n";
    }

    // ── Payload-typed wavefront storage binding ────────────────────────
    //
    // Emitted *after* the user sources so it can name the user's `Payload`
    // type. Holds one Payload per in-flight ray slot across both ping/pong
    // ray buffers (capacity = 2·W·H). SHADE loads ray.payloadSlot here;
    // emit helpers (rtEmitPrimaryRay / rtEmitRay) store into it.
    wgsl += "\n@group(1) @binding(15) var<storage, read_write> "
            "wfPayload : array<Payload>;\n";

    // ── Section 2: mega-switch dispatchers ─────────────────────────────
    //
    // runClosestHit, runAnyHit and runMiss each switch on the group index.
    // Closest-hit group N (N in 0..hitGroups.size()-1) is selected by
    // hitGroupIndex == N, matching VkRayTracingShaderGroup ordering. Groups
    // whose shader is unused or out of range simply get no case and fall
    // through to `default`.
    wgsl += "\nfn runClosestHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr<function, Payload>) {\n";
    wgsl += " switch hg {\n";
    bool anyClosestHit = false;
    for (std::uint32_t i = 0; i < hitGroups.size(); ++i) {
        const auto& g = hitGroups[i];
        if (g.closestHitShader == kRTShaderUnused) continue;
        if (g.closestHitShader >= sbt.shaders.size()) continue;
        const auto& fn = sbt.shaders[g.closestHitShader].entryFn;
        AppendCase(wgsl, i, fn, "ray, hit, payload");
        anyClosestHit = true;
    }
    if (!anyClosestHit) wgsl += " // (no closest-hit shaders registered)\n";
    wgsl += " default: { }\n";
    wgsl += " }\n";
    wgsl += "}\n\n";

    // Any-hit returns a verdict; the default arm accepts the hit.
    wgsl += "fn runAnyHit(hg: u32, ray: RayDesc, hit: HitInfo, payload: ptr<function, Payload>) -> u32 {\n";
    wgsl += " switch hg {\n";
    bool anyAnyhit = false;
    for (std::uint32_t i = 0; i < hitGroups.size(); ++i) {
        const auto& g = hitGroups[i];
        if (g.anyHitShader == kRTShaderUnused) continue;
        if (g.anyHitShader >= sbt.shaders.size()) continue;
        const auto& fn = sbt.shaders[g.anyHitShader].entryFn;
        AppendAnyHitCase(wgsl, i, fn);
        anyAnyhit = true;
    }
    if (!anyAnyhit) wgsl += " // (no any-hit shaders registered)\n";
    wgsl += " default: { return RT_ANYHIT_ACCEPT; }\n";
    wgsl += " }\n";
    wgsl += "}\n\n";

    wgsl += "fn runMiss(idx: u32, ray: RayDesc, payload: ptr<function, Payload>) {\n";
    wgsl += " switch idx {\n";
    bool anyMiss = false;
    for (std::uint32_t i = 0; i < missGroups.size(); ++i) {
        const auto& g = missGroups[i];
        if (g.generalShader == kRTShaderUnused) continue;
        if (g.generalShader >= sbt.shaders.size()) continue;
        const auto& fn = sbt.shaders[g.generalShader].entryFn;
        AppendCase(wgsl, i, fn, "ray, payload");
        anyMiss = true;
    }
    if (!anyMiss) wgsl += " // (no miss shaders registered)\n";
    wgsl += " default: { }\n";
    wgsl += " }\n";
    wgsl += "}\n";

    // runResolve — RESOLVE-stage tonemap hook. The first Resolve shader in
    // SBT order wins; with none, identity passthrough (alpha forced to 1)
    // so the wavefront output matches a megakernel that wrote raw colors.
    std::string resolveEntryFn;
    for (const auto& shader : sbt.shaders) {
        if (shader.stage == WebGPURTStage::Resolve) { resolveEntryFn = shader.entryFn; break; }
    }
    wgsl += "\nfn runResolve(coord: vec2<u32>, hdr: vec4<f32>) -> vec4<f32> {\n";
    if (!resolveEntryFn.empty()) {
        wgsl += " return ";
        wgsl += resolveEntryFn;
        wgsl += "(coord, hdr);\n";
    } else {
        wgsl += " return vec4<f32>(hdr.rgb, 1.0);\n";
    }
    wgsl += "}\n";

    // Marker — JS-side prelude/post-amble searches for this token to know
    // where the library helpers (traverseBlas/traverseTlas/traceRay) get
    // injected, followed by raygen sources and the @compute entry point.
    wgsl += "\n// @CRAFTER_RT_LIBRARY_HELPERS_HERE\n";

    // ── Section 3: user raygen source files ────────────────────────────
    //
    // Comes after the library injects traceRay, so raygens can call it.
    wgsl += "\n// ── user raygen sources ───────────────────────────────────\n";
    std::string raygenEntryFn;
    for (const auto& shader : sbt.shaders) {
        if (shader.stage != WebGPURTStage::Raygen) continue;
        wgsl += shader.source;
        wgsl += "\n";
        // Fallback entry: the first Raygen-stage shader in SBT order. It is
        // overridden below when raygenGroups[0] names a valid shader.
        if (raygenEntryFn.empty()) raygenEntryFn = shader.entryFn;
    }
    // Prefer the first raygen group's general shader as the entry, mirroring
    // Vulkan's pRayGenShaderBindingTable[0] → first invoked raygen.
    if (!raygenGroups.empty()
        && raygenGroups[0].generalShader != kRTShaderUnused
        && raygenGroups[0].generalShader < sbt.shaders.size()) {
        raygenEntryFn = sbt.shaders[raygenGroups[0].generalShader].entryFn;
    }
    if (raygenEntryFn.empty()) {
        std::println("PipelineRTWebGPU::Init: no raygen shader registered");
        pipelineHandle = 0;
        return;
    }

    // ── Section 4: wavefront @compute entry points ─────────────────────
    //
    // Five kernels share this one module; createComputePipeline selects
    // each by entryPoint name. GENERATE/RESOLVE are 8x8 screen tiles.
    // TRACE/SHADE use 1-D workgroups over the compacted ray list, launched
    // indirectly from PREP on a 2-D grid of workgroups. PREP is a single
    // thread. The library helper bodies (_rtwTraverseTlas, rtEmit*,
    // rtAccumulate, _wfCurCount, …) are injected JS-side at the marker above.

    // GENERATE — one thread per pixel; clears the pixel's accumulator and
    // runs the user raygen, which calls rtEmitPrimaryRay.
    wgsl += "\n@compute @workgroup_size(8, 8, 1)\n";
    wgsl += "fn wfGenerate(@builtin(global_invocation_id) gid: vec3<u32>) {\n";
    wgsl += " if (gid.x >= wfParams.surfaceW || gid.y >= wfParams.surfaceH) { return; }\n";
    wgsl += " let pixel = gid.y * wfParams.surfaceW + gid.x;\n";
    wgsl += " wfAccum[pixel] = vec4<f32>(0.0, 0.0, 0.0, 0.0);\n";
    wgsl += " _wfPixel = pixel;\n";
    wgsl += " ";
    wgsl += raygenEntryFn;
    wgsl += "(gid);\n";
    wgsl += "}\n";

    // PREP — single thread. It reads the live ray count and publishes the
    // indirect dispatch args for the upcoming TRACE/SHADE, spread over a
    // 2-D grid (x clamped to 65535) so large surfaces stay within
    // maxComputeWorkgroupsPerDimension. It then zeroes the next buffer's
    // emit counter so SHADE starts compacting from 0.
    wgsl += "\n@compute @workgroup_size(1)\n";
    wgsl += "fn wfPrep() { _wfPrep(); }\n";

    // TRACE and SHADE share one kernel shape: a 1-D workgroup of
    // kRayKernelWorkgroupSize threads on a 2-D grid of workgroups. The
    // linear ray index is rebuilt row-major from that grid:
    //     i = gid.y * (nwg.x * size) + gid.x
    // gid.x spans [0, nwg.x * size) and gid.y is the workgroup row, because
    // each workgroup is one thread tall. The grid can overshoot the live
    // ray count; the helpers' `i >= _wfCurCount()` guard absorbs that.
    // The @workgroup_size attribute and the index multiplier are generated
    // from this single constant, so they cannot drift apart.
    // NOTE(review): the JS-side _wfPrep derives the workgroup count from
    // the same size (ceil(count/64)); keep it in sync if this changes.
    constexpr std::uint32_t kRayKernelWorkgroupSize = 64;
    const std::string rayWgSize = std::to_string(kRayKernelWorkgroupSize);
    const auto emitRayKernel = [&](std::string_view entryPoint, std::string_view callee) {
        wgsl += "\n@compute @workgroup_size(";
        wgsl += rayWgSize;
        wgsl += ")\n";
        wgsl += "fn ";
        wgsl += entryPoint;
        wgsl += "(@builtin(global_invocation_id) gid: vec3<u32>, "
                "@builtin(num_workgroups) nwg: vec3<u32>) { ";
        wgsl += callee;
        wgsl += "(gid.y * nwg.x * ";
        wgsl += rayWgSize;
        wgsl += "u + gid.x); }\n";
    };
    // TRACE — zero user code: pure traversal + intersection. One thread
    // per live ray; writes a HitResult into wfHits[i].
    emitRayKernel("wfTrace", "_wfTrace");
    // SHADE — one thread per live ray; loads the ray + its hit + payload,
    // dispatches to runMiss / runClosestHit, which may rtAccumulate and
    // rtEmitRay continuation/shadow rays into the next buffer.
    emitRayKernel("wfShade", "_wfShade");

    // RESOLVE — one thread per pixel; runs the user resolve (or identity)
    // over the linear accumulator and stores to the output image.
    wgsl += "\n@compute @workgroup_size(8, 8, 1)\n";
    wgsl += "fn wfResolve(@builtin(global_invocation_id) gid: vec3<u32>) {\n";
    wgsl += " if (gid.x >= wfParams.surfaceW || gid.y >= wfParams.surfaceH) { return; }\n";
    wgsl += " let pixel = gid.y * wfParams.surfaceW + gid.x;\n";
    wgsl += " let outc = runResolve(gid.xy, wfAccum[pixel]);\n";
    wgsl += " textureStore(outImage, vec2<i32>(i32(gid.x), i32(gid.y)), outc);\n";
    wgsl += "}\n";

    pipelineHandle = WebGPU::wgpuLoadRTPipeline(
        wgsl.data(),
        static_cast<std::int32_t>(wgsl.size()),
        userBindings.empty() ? nullptr : userBindings.data(),
        static_cast<std::int32_t>(userBindings.size()));
}