fix(vulkan-rt): configurable recursion depth + per-shader TLAS push for compute (#21)

Two gaps in the Vulkan RT path that fault the device on the NVIDIA
proprietary driver with a non-trivial pipeline (simple VulkanTriangle
never hit them):

1. maxPipelineRayRecursionDepth was hardcoded to 1, so any closest-hit
   shader that traces a secondary ray (shadow ray — a very common
   pattern) recursed past the pipeline limit (UB → device fault).
   PipelineRTVulkan::Init now takes a maxRecursionDepth parameter
   (default 1, clamped to the device's maxRayRecursionDepth).

2. The NVIDIA descriptor-heap AS-read workaround rewrites every shader
   that reads an accelerationStructureEXT from the heap — including
   compute shaders — to read the TLAS device address from a push
   constant, but only RTPass pushed that address. A compute shader that
   ray-queries the TLAS (rayQueryEXT) therefore ran against an unwritten
   push slot → garbage AS handle → VK_ERROR_DEVICE_LOST.

   WorkaroundNvidiaAS::Patch now returns a per-shader PatchResult
   {patched, tlasPushOffset} instead of writing the clobber-prone global
   Device::workaroundTlasPushOffset (removed). VulkanShader stores it;
   ShaderBindingTableVulkan/PipelineRTVulkan carry it for RTPass, and
   ComputeShader tracks its own offset and pushes the caller-supplied
   TLAS address in Dispatch (new defaulted tlasAddress parameter),
   mirroring RTPass::Record.

The PushConstantRewrite regression test now asserts Patch's returned
patched/offset and adds two ray-querying compute-shader cases, proving
the rewrite is stage-agnostic and the per-shader offset is correct.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
catbot 2026-06-03 18:35:39 +00:00
commit 1c310762a7
8 changed files with 248 additions and 75 deletions

View file

@ -36,6 +36,16 @@ export namespace Crafter {
public:
VkPipeline pipeline = VK_NULL_HANDLE;
// NVIDIA descriptor-heap AS-read workaround (issue #15 / #7): set by
// Load when this shader ray-queries the TLAS through the descriptor
// heap and was rewritten to read its device address from a push
// constant. `workaroundTlasPushOffset` is the byte offset of that member
// (after the caller's own push payload, or 0 if the shader had none).
// Tracked per-shader — a global is clobbered by whichever shader was
// patched last. Both inert (false/0) on every other driver.
bool workaroundNeedsTlas = false;
std::uint32_t workaroundTlasPushOffset = 0;
ComputeShader() = default;
ComputeShader(const ComputeShader&) = delete;
ComputeShader& operator=(const ComputeShader&) = delete;
@ -50,11 +60,21 @@ export namespace Crafter {
// Bind, push constants (if any), dispatch. Caller computes group counts
// and is responsible for any inter-dispatch barriers (UIRenderer::Dispatch
// wraps this with the standard write-after-write barrier).
//
// tlasAddress is the NVIDIA descriptor-heap AS-read workaround hook
// (issue #15 / #7): a shader that ray-queries the TLAS through the
// descriptor heap is rewritten to read its device address from a push
// constant, so the caller must supply the active frame's TLAS address
// (RenderingElement3D::tlases[frameIdx].address) here. It is pushed at
// the shader's workaroundTlasPushOffset only when the shader was
// rewritten (workaroundNeedsTlas) — ignored otherwise and on every
// other driver, so shaders that don't touch an AS pass nothing.
void Dispatch(VkCommandBuffer cmd,
const void* push, std::uint32_t pushBytes,
std::uint32_t gx,
std::uint32_t gy = 1,
std::uint32_t gz = 1) const;
std::uint32_t gz = 1,
VkDeviceAddress tlasAddress = 0) const;
};
}
#endif // !CRAFTER_GRAPHICS_WINDOW_DOM

View file

@ -178,12 +178,12 @@ export namespace Crafter {
// path and RTPass pushes the active TLAS address as push data. Delete
// this flag and everything keyed on it once a fixed driver ships.
inline static bool workaroundDescriptorHeapAS = false;
// Byte offset of the TLAS-address member inside the patched raygen's
// push-constant block — 0 for a freshly synthesized block, or the end
// of the user's own block when the address is appended to it (the
// shader can't have two push-constant blocks). VulkanShader sets this
// at module load; RTPass feeds it to vkCmdPushDataEXT.
inline static std::uint32_t workaroundTlasPushOffset = 0;
// The byte offset of the TLAS-address member inside a patched shader's
// push-constant block is tracked per-shader (VulkanShader::tlasPushOffset),
// not here: a single global is clobbered by whichever shader was patched
// last and so cannot serve several shaders with differing push layouts
// (e.g. an RT raygen and a ray-querying compute shader). RTPass and
// ComputeShader read the offset off the pipeline they record.
static void CheckVkResult(VkResult result);
static std::uint32_t GetMemoryType(std::uint32_t typeBits, VkMemoryPropertyFlags properties);

View file

@ -39,7 +39,25 @@ export namespace Crafter {
VkStridedDeviceAddressRegionKHR hitRegion;
VkStridedDeviceAddressRegionKHR callableRegion;
void Init(VkCommandBuffer cmd, std::span<VkRayTracingShaderGroupCreateInfoKHR> raygenGroups, std::span<VkRayTracingShaderGroupCreateInfoKHR> missGroups, std::span<VkRayTracingShaderGroupCreateInfoKHR> hitGroups, ShaderBindingTableVulkan& shaderTable) {
// NVIDIA descriptor-heap AS-read workaround (issue #15 / #7): copied
// from the shader table at Init so RTPass can push the active TLAS
// device address into the patched shaders' push constant. Inert on
// every other driver.
bool workaroundNeedsTlas = false;
std::uint32_t workaroundTlasPushOffset = 0;
// maxRecursionDepth: the maximum ray-recursion depth the pipeline must
// support — i.e. the deepest chain of nested traceRayEXT calls. The
// raygen counts as depth 1, so a closest-hit shader that traces a shadow
// ray needs 2. Tracing beyond the value the pipeline was created with is
// undefined behaviour and faults the device, so a consumer with any
// recursion past the raygen must raise this. Defaults to 1 (raygen-only,
// matching the simple examples) and is clamped to the device's
// maxRayRecursionDepth.
void Init(VkCommandBuffer cmd, std::span<VkRayTracingShaderGroupCreateInfoKHR> raygenGroups, std::span<VkRayTracingShaderGroupCreateInfoKHR> missGroups, std::span<VkRayTracingShaderGroupCreateInfoKHR> hitGroups, ShaderBindingTableVulkan& shaderTable, std::uint32_t maxRecursionDepth = 1) {
workaroundNeedsTlas = shaderTable.workaroundNeedsTlas;
workaroundTlasPushOffset = shaderTable.workaroundTlasPushOffset;
std::vector<VkRayTracingShaderGroupCreateInfoKHR> groups;
groups.reserve(raygenGroups.size() + missGroups.size() + hitGroups.size());
@ -60,7 +78,7 @@ export namespace Crafter {
.pStages = shaderTable.shaderStages.data(),
.groupCount = static_cast<std::uint32_t>(groups.size()),
.pGroups = groups.data(),
.maxPipelineRayRecursionDepth = 1,
.maxPipelineRayRecursionDepth = std::min(maxRecursionDepth, Device::rayTracingProperties.maxRayRecursionDepth),
.layout = VK_NULL_HANDLE
};

View file

@ -42,14 +42,16 @@ export namespace Crafter {
// block that VulkanShader synthesizes, so the rewritten raygen can
// reach the acceleration structure by address instead of through
// the faulting heap descriptor. Inert on every other driver.
if (Device::workaroundDescriptorHeapAS) {
if (Device::workaroundDescriptorHeapAS && pipeline->workaroundNeedsTlas) {
VkDeviceAddress tlasAddr = RenderingElement3D::tlases[frameIdx].address;
VkPushDataInfoEXT pushInfo {
.sType = VK_STRUCTURE_TYPE_PUSH_DATA_INFO_EXT,
// Where the rewritten raygen reads the TLAS address: 0 when
// VulkanShader synthesized a fresh block, or the offset of
// the member it appended to the shader's existing block.
.offset = Device::workaroundTlasPushOffset,
// Tracked per-pipeline (copied from the shader table) so a
// later-loaded shader can't clobber it.
.offset = pipeline->workaroundTlasPushOffset,
.data = { .address = &tlasAddr, .size = sizeof(tlasAddr) },
};
Device::vkCmdPushDataEXT(cmd, &pushInfo);

View file

@ -33,10 +33,22 @@ export namespace Crafter {
class ShaderBindingTableVulkan {
public:
std::vector<VkPipelineShaderStageCreateInfo> shaderStages;
// NVIDIA descriptor-heap AS-read workaround (issue #15 / #7): true when
// any stage in this table reads an acceleration structure and was
// rewritten to fetch the TLAS address from a push constant, with the
// byte offset that stage expects it at. PipelineRTVulkan copies these so
// RTPass can push the address without consulting a clobber-prone global.
// Both inert (false/0) on every other driver.
bool workaroundNeedsTlas = false;
std::uint32_t workaroundTlasPushOffset = 0;
void Init(const std::span<const VulkanShader> shaders) {
shaderStages.reserve(shaders.size());
for(const VulkanShader& shader: shaders) {
shaderStages.emplace_back(VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, nullptr, 0, shader.stage, shader.shader, shader.entrypoint.c_str(), shader.specilizationInfo);
if (shader.patchedAS) {
workaroundNeedsTlas = true;
workaroundTlasPushOffset = shader.tlasPushOffset;
}
}
}
};

View file

@ -52,8 +52,12 @@ import :Types;
// append a ulong member (the TLAS address) to the *existing* block and read
// from there; only shaders with no push constant of their own get a freshly
// synthesized single-member block. Its byte offset is the offset of that
// member (published via Crafter::Device::workaroundTlasPushOffset) which RTPass feeds to
// vkCmdPushDataEXT so the address lands where the rewritten load reads it.
// member, returned in PatchResult::tlasPushOffset so the caller (RTPass for the
// RT pipeline, ComputeShader::Dispatch for a compute pipeline) can feed it to
// vkCmdPushDataEXT — landing the address exactly where the rewritten load reads
// it. The offset is per-shader rather than a global: a global is clobbered by
// whichever shader was patched last and so cannot serve several shaders whose
// push-constant layouts differ.
//
// Exported so tests/PushConstantRewrite can drive Patch() over real compiled
// SPIR-V and check the result with spirv-val; nothing in the engine calls it
@ -87,15 +91,24 @@ export namespace WorkaroundNvidiaAS {
return (v + a - 1u) & ~(a - 1u);
}
inline void Patch(std::vector<std::uint32_t>& words) {
if (words.size() < 5) return; // not a SPIR-V module we understand.
// Outcome of patching one shader module. `patched` is true only when the
// shader read an acceleration structure and was rewritten; `tlasPushOffset`
// is then the byte offset of the TLAS-address member in the (possibly
// pre-existing) push-constant block the caller must write.
struct PatchResult {
bool patched = false;
std::uint32_t tlasPushOffset = 0;
};
inline PatchResult Patch(std::vector<std::uint32_t>& words) {
if (words.size() < 5) return {}; // not a SPIR-V module we understand.
// Split header (5 words) from the instruction stream.
std::uint32_t bound = words[3];
std::vector<Instr> instrs;
for (std::size_t i = 5; i < words.size();) {
std::uint32_t len = words[i] >> 16;
if (len == 0 || i + len > words.size()) return; // malformed — bail.
if (len == 0 || i + len > words.size()) return {}; // malformed — bail.
instrs.emplace_back(words.begin() + i, words.begin() + i + len);
i += len;
}
@ -163,7 +176,10 @@ export namespace WorkaroundNvidiaAS {
if (op == 54 /*OpFunction*/ && firstFuncIdx == instrs.size()) firstFuncIdx = k;
}
if (asTypeId == 0) return; // shader never reads an acceleration structure.
if (asTypeId == 0) return {}; // shader never reads an acceleration structure.
// Set on whichever path runs below; returned to the caller.
std::uint32_t tlasPushOffset = 0;
auto newId = [&] { return bound++; };
auto mk = [](std::initializer_list<std::uint32_t> ops) {
@ -230,7 +246,7 @@ export namespace WorkaroundNvidiaAS {
pcVarId = existingPcVarId;
const Instr* structInstr = typeInstr[existingPcStructId];
memberIdx = static_cast<std::uint32_t>(structInstr->size() - 2);
Crafter::Device::workaroundTlasPushOffset = AlignUp(footprint(existingPcStructId), 8);
tlasPushOffset = AlignUp(footprint(existingPcStructId), 8);
ptrPushUlongId = existingPtrUlongId;
if (ptrPushUlongId == 0) {
@ -247,7 +263,7 @@ export namespace WorkaroundNvidiaAS {
memberIdxConstId = newId();
typeDefs.push_back(mk({OpConstant, uintTypeId, memberIdxConstId, memberIdx}));
}
decorations.push_back(mk({OpMemberDecorate, existingPcStructId, memberIdx, DecorationOffset, Crafter::Device::workaroundTlasPushOffset}));
decorations.push_back(mk({OpMemberDecorate, existingPcStructId, memberIdx, DecorationOffset, tlasPushOffset}));
} else {
// No user push constant — synthesize a fresh single-member block.
if (uintZeroId == 0) { uintZeroId = newId(); typeDefs.push_back(mk({OpConstant, uintTypeId, uintZeroId, 0})); }
@ -262,7 +278,7 @@ export namespace WorkaroundNvidiaAS {
decorations.push_back(mk({OpMemberDecorate, pcStructId, 0, DecorationOffset, 0}));
decorations.push_back(mk({OpDecorate, pcStructId, DecorationBlock}));
memberIdxConstId = uintZeroId;
Crafter::Device::workaroundTlasPushOffset = 0;
tlasPushOffset = 0;
}
// ── Rewrite each `OpLoad %asType <ptr>` into address-load + convert, and
@ -327,6 +343,8 @@ export namespace WorkaroundNvidiaAS {
out[3] = bound;
for (const Instr& in : instrs) out.insert(out.end(), in.begin(), in.end());
words.swap(out);
return {true, tlasPushOffset};
}
}
// ─── END NVIDIA descriptor-heap AS-read workaround ────────────────────────
@ -339,6 +357,15 @@ export namespace Crafter {
VkShaderStageFlagBits stage;
std::string entrypoint;
VkShaderModule shader;
// NVIDIA descriptor-heap AS-read workaround (issue #15 / #7): set when
// this module read an acceleration structure and was rewritten to fetch
// the TLAS device address from a push constant. `tlasPushOffset` is the
// byte offset of that member, which whoever records the dispatch
// (RTPass / ComputeShader) must write with vkCmdPushDataEXT. Per-shader
// rather than a global because each shader's push-constant layout — and
// therefore the offset — can differ. Both false/0 on every other driver.
bool patchedAS = false;
std::uint32_t tlasPushOffset = 0;
VulkanShader(const std::filesystem::path& path, std::string entrypoint, VkShaderStageFlagBits stage, VkSpecializationInfo* specilizationInfo) : stage(stage), entrypoint(entrypoint), specilizationInfo(specilizationInfo) {
std::ifstream file(path, std::ios::binary);
if (!file) {
@ -364,7 +391,9 @@ export namespace Crafter {
// acceleration structure. Remove with the rest of the workaround
// once a fixed NVIDIA driver ships.
if (Device::workaroundDescriptorHeapAS) {
WorkaroundNvidiaAS::Patch(spirv);
WorkaroundNvidiaAS::PatchResult patch = WorkaroundNvidiaAS::Patch(spirv);
patchedAS = patch.patched;
tlasPushOffset = patch.tlasPushOffset;
}
VkShaderModuleCreateInfo module_info{VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO};