fix(vulkan-rt): configurable recursion depth + per-shader TLAS push for compute (#21)
Two gaps in the Vulkan RT path that fault the device on the NVIDIA
proprietary driver with a non-trivial pipeline (simple VulkanTriangle
never hit them):
1. maxPipelineRayRecursionDepth was hardcoded to 1, so any closest-hit
shader that traces a secondary ray (shadow ray — a very common
pattern) recursed past the pipeline limit (UB → device fault).
PipelineRTVulkan::Init now takes a maxRecursionDepth parameter
(default 1, clamped to the device's maxRayRecursionDepth).
2. The NVIDIA descriptor-heap AS-read workaround rewrites every shader
that reads an accelerationStructureEXT from the heap — including
compute shaders — to read the TLAS device address from a push
constant, but only RTPass pushed that address. A compute shader that
ray-queries the TLAS (rayQueryEXT) therefore ran against an unwritten
push slot → garbage AS handle → VK_ERROR_DEVICE_LOST.
WorkaroundNvidiaAS::Patch now returns a per-shader PatchResult
{patched, tlasPushOffset} instead of writing the clobber-prone global
Device::workaroundTlasPushOffset (removed). VulkanShader stores it;
ShaderBindingTableVulkan/PipelineRTVulkan carry it for RTPass, and
ComputeShader tracks its own offset and pushes the caller-supplied
TLAS address in Dispatch (new defaulted tlasAddress parameter),
mirroring RTPass::Record.
The PushConstantRewrite regression test now asserts Patch's returned
patched/offset and adds two ray-querying compute-shader cases, proving
the rewrite is stage-agnostic and the per-shader offset is correct.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
2790bbd576
commit
1c310762a7
8 changed files with 248 additions and 75 deletions
|
|
@ -52,8 +52,12 @@ import :Types;
|
|||
// append a ulong member (the TLAS address) to the *existing* block and read
|
||||
// from there; only shaders with no push constant of their own get a freshly
|
||||
// synthesized single-member block. Its byte offset is the offset of that
|
||||
// member (published via Crafter::Device::workaroundTlasPushOffset) which RTPass feeds to
|
||||
// vkCmdPushDataEXT so the address lands where the rewritten load reads it.
|
||||
// member, returned in PatchResult::tlasPushOffset so the caller (RTPass for the
|
||||
// RT pipeline, ComputeShader::Dispatch for a compute pipeline) can feed it to
|
||||
// vkCmdPushDataEXT — landing the address exactly where the rewritten load reads
|
||||
// it. The offset is per-shader rather than a global: a global is clobbered by
|
||||
// whichever shader was patched last and so cannot serve several shaders whose
|
||||
// push-constant layouts differ.
|
||||
//
|
||||
// Exported so tests/PushConstantRewrite can drive Patch() over real compiled
|
||||
// SPIR-V and check the result with spirv-val; nothing in the engine calls it
|
||||
|
|
@ -87,15 +91,24 @@ export namespace WorkaroundNvidiaAS {
|
|||
return (v + a - 1u) & ~(a - 1u);
|
||||
}
|
||||
|
||||
inline void Patch(std::vector<std::uint32_t>& words) {
|
||||
if (words.size() < 5) return; // not a SPIR-V module we understand.
|
||||
// Outcome of patching one shader module. `patched` is true only when the
|
||||
// shader read an acceleration structure and was rewritten; `tlasPushOffset`
|
||||
// is then the byte offset of the TLAS-address member in the (possibly
|
||||
// pre-existing) push-constant block the caller must write.
|
||||
struct PatchResult {
|
||||
bool patched = false;
|
||||
std::uint32_t tlasPushOffset = 0;
|
||||
};
|
||||
|
||||
inline PatchResult Patch(std::vector<std::uint32_t>& words) {
|
||||
if (words.size() < 5) return {}; // not a SPIR-V module we understand.
|
||||
|
||||
// Split header (5 words) from the instruction stream.
|
||||
std::uint32_t bound = words[3];
|
||||
std::vector<Instr> instrs;
|
||||
for (std::size_t i = 5; i < words.size();) {
|
||||
std::uint32_t len = words[i] >> 16;
|
||||
if (len == 0 || i + len > words.size()) return; // malformed — bail.
|
||||
if (len == 0 || i + len > words.size()) return {}; // malformed — bail.
|
||||
instrs.emplace_back(words.begin() + i, words.begin() + i + len);
|
||||
i += len;
|
||||
}
|
||||
|
|
@ -163,7 +176,10 @@ export namespace WorkaroundNvidiaAS {
|
|||
if (op == 54 /*OpFunction*/ && firstFuncIdx == instrs.size()) firstFuncIdx = k;
|
||||
}
|
||||
|
||||
if (asTypeId == 0) return; // shader never reads an acceleration structure.
|
||||
if (asTypeId == 0) return {}; // shader never reads an acceleration structure.
|
||||
|
||||
// Set on whichever path runs below; returned to the caller.
|
||||
std::uint32_t tlasPushOffset = 0;
|
||||
|
||||
auto newId = [&] { return bound++; };
|
||||
auto mk = [](std::initializer_list<std::uint32_t> ops) {
|
||||
|
|
@ -230,7 +246,7 @@ export namespace WorkaroundNvidiaAS {
|
|||
pcVarId = existingPcVarId;
|
||||
const Instr* structInstr = typeInstr[existingPcStructId];
|
||||
memberIdx = static_cast<std::uint32_t>(structInstr->size() - 2);
|
||||
Crafter::Device::workaroundTlasPushOffset = AlignUp(footprint(existingPcStructId), 8);
|
||||
tlasPushOffset = AlignUp(footprint(existingPcStructId), 8);
|
||||
|
||||
ptrPushUlongId = existingPtrUlongId;
|
||||
if (ptrPushUlongId == 0) {
|
||||
|
|
@ -247,7 +263,7 @@ export namespace WorkaroundNvidiaAS {
|
|||
memberIdxConstId = newId();
|
||||
typeDefs.push_back(mk({OpConstant, uintTypeId, memberIdxConstId, memberIdx}));
|
||||
}
|
||||
decorations.push_back(mk({OpMemberDecorate, existingPcStructId, memberIdx, DecorationOffset, Crafter::Device::workaroundTlasPushOffset}));
|
||||
decorations.push_back(mk({OpMemberDecorate, existingPcStructId, memberIdx, DecorationOffset, tlasPushOffset}));
|
||||
} else {
|
||||
// No user push constant — synthesize a fresh single-member block.
|
||||
if (uintZeroId == 0) { uintZeroId = newId(); typeDefs.push_back(mk({OpConstant, uintTypeId, uintZeroId, 0})); }
|
||||
|
|
@ -262,7 +278,7 @@ export namespace WorkaroundNvidiaAS {
|
|||
decorations.push_back(mk({OpMemberDecorate, pcStructId, 0, DecorationOffset, 0}));
|
||||
decorations.push_back(mk({OpDecorate, pcStructId, DecorationBlock}));
|
||||
memberIdxConstId = uintZeroId;
|
||||
Crafter::Device::workaroundTlasPushOffset = 0;
|
||||
tlasPushOffset = 0;
|
||||
}
|
||||
|
||||
// ── Rewrite each `OpLoad %asType <ptr>` into address-load + convert, and
|
||||
|
|
@ -327,6 +343,8 @@ export namespace WorkaroundNvidiaAS {
|
|||
out[3] = bound;
|
||||
for (const Instr& in : instrs) out.insert(out.end(), in.begin(), in.end());
|
||||
words.swap(out);
|
||||
|
||||
return {true, tlasPushOffset};
|
||||
}
|
||||
}
|
||||
// ─── END NVIDIA descriptor-heap AS-read workaround ────────────────────────
|
||||
|
|
@ -339,6 +357,15 @@ export namespace Crafter {
|
|||
VkShaderStageFlagBits stage;
|
||||
std::string entrypoint;
|
||||
VkShaderModule shader;
|
||||
// NVIDIA descriptor-heap AS-read workaround (issue #15 / #7): set when
|
||||
// this module read an acceleration structure and was rewritten to fetch
|
||||
// the TLAS device address from a push constant. `tlasPushOffset` is the
|
||||
// byte offset of that member, which whoever records the dispatch
|
||||
// (RTPass / ComputeShader) must write with vkCmdPushDataEXT. Per-shader
|
||||
// rather than a global because each shader's push-constant layout — and
|
||||
// therefore the offset — can differ. Both false/0 on every other driver.
|
||||
bool patchedAS = false;
|
||||
std::uint32_t tlasPushOffset = 0;
|
||||
VulkanShader(const std::filesystem::path& path, std::string entrypoint, VkShaderStageFlagBits stage, VkSpecializationInfo* specilizationInfo) : stage(stage), entrypoint(entrypoint), specilizationInfo(specilizationInfo) {
|
||||
std::ifstream file(path, std::ios::binary);
|
||||
if (!file) {
|
||||
|
|
@ -364,7 +391,9 @@ export namespace Crafter {
|
|||
// acceleration structure. Remove with the rest of the workaround
|
||||
// once a fixed NVIDIA driver ships.
|
||||
if (Device::workaroundDescriptorHeapAS) {
|
||||
WorkaroundNvidiaAS::Patch(spirv);
|
||||
WorkaroundNvidiaAS::PatchResult patch = WorkaroundNvidiaAS::Patch(spirv);
|
||||
patchedAS = patch.patched;
|
||||
tlasPushOffset = patch.tlasPushOffset;
|
||||
}
|
||||
|
||||
VkShaderModuleCreateInfo module_info{VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue