diff --git a/examples/VulkanTriangle/README.md b/examples/VulkanTriangle/README.md index 5ff2ab5..45f9fd7 100644 --- a/examples/VulkanTriangle/README.md +++ b/examples/VulkanTriangle/README.md @@ -44,12 +44,16 @@ bug (full investigation in #7, summarised below). proprietary driver only, `VulkanShader` rewrites the compiled SPIR-V at module-load time so that every `OpLoad` of an `accelerationStructureEXT` out of the heap becomes a load of the TLAS *device address* (from a -synthesized push-constant block) followed by +push-constant block) followed by `OpConvertUToAccelerationStructureKHR` — which reads no descriptor and so never touches the faulting path. `RTPass` feeds the active frame's TLAS -address in as push data. `raygen.glsl` and the example code are unchanged; -acceleration structures still bind into the heap normally. On every other -driver the workaround is inert. It's gated on +address in as push data. SPIR-V allows only one push-constant block per +entry point, so when a shader already declares one the TLAS address is +appended to *that* block (rather than adding a second, which would fail +validation — issue #18); shaders without a push constant get a freshly +synthesized single-member block. `raygen.glsl` and the example code are +unchanged; acceleration structures still bind into the heap normally. On +every other driver the workaround is inert. It's gated on `Device::workaroundDescriptorHeapAS` and confined to one fenced block in `interfaces/Crafter.Graphics-ShaderVulkan.cppm` so it can be deleted wholesale once a fixed NVIDIA driver ships. diff --git a/interfaces/Crafter.Graphics-Device.cppm b/interfaces/Crafter.Graphics-Device.cppm index 679a40c..ed331ff 100644 --- a/interfaces/Crafter.Graphics-Device.cppm +++ b/interfaces/Crafter.Graphics-Device.cppm @@ -178,6 +178,12 @@ export namespace Crafter { // path and RTPass pushes the active TLAS address as push data. Delete // this flag and everything keyed on it once a fixed driver ships. inline static bool workaroundDescriptorHeapAS = false; + // Byte offset of the TLAS-address member inside the patched raygen's + // push-constant block — 0 for a freshly synthesized block, or the end + // of the user's own block when the address is appended to it (the + // shader can't have two push-constant blocks). VulkanShader sets this + // at module load; RTPass feeds it to vkCmdPushDataEXT. + inline static std::uint32_t workaroundTlasPushOffset = 0; static void CheckVkResult(VkResult result); static std::uint32_t GetMemoryType(std::uint32_t typeBits, VkMemoryPropertyFlags properties); diff --git a/interfaces/Crafter.Graphics-RTPass.cppm b/interfaces/Crafter.Graphics-RTPass.cppm index 65907d6..d436387 100644 --- a/interfaces/Crafter.Graphics-RTPass.cppm +++ b/interfaces/Crafter.Graphics-RTPass.cppm @@ -46,7 +46,10 @@ export namespace Crafter { VkDeviceAddress tlasAddr = RenderingElement3D::tlases[frameIdx].address; VkPushDataInfoEXT pushInfo { .sType = VK_STRUCTURE_TYPE_PUSH_DATA_INFO_EXT, - .offset = 0, + // Where the rewritten raygen reads the TLAS address: 0 when + // VulkanShader synthesized a fresh block, or the offset of + // the member it appended to the shader's existing block. + .offset = Device::workaroundTlasPushOffset, .data = { .address = &tlasAddr, .size = sizeof(tlasAddr) }, }; Device::vkCmdPushDataEXT(cmd, &pushInfo); diff --git a/interfaces/Crafter.Graphics-ShaderVulkan.cppm b/interfaces/Crafter.Graphics-ShaderVulkan.cppm index 7bff83f..49a6699 100644 --- a/interfaces/Crafter.Graphics-ShaderVulkan.cppm +++ b/interfaces/Crafter.Graphics-ShaderVulkan.cppm @@ -42,22 +42,36 @@ import :Types; // // glslang has no GLSL spelling for that conversion, so we rewrite the compiled // SPIR-V at module-load time: every `OpLoad %accelStruct ` becomes a -// load of the TLAS device address from a synthesized push-constant block -// followed by OpConvertUToAccelerationStructureKHR. RTPass pushes the active -// frame's TLAS address into that push constant. Shaders that never touch an -// acceleration structure (no OpTypeAccelerationStructureKHR) are left untouched. -namespace WorkaroundNvidiaAS { +// load of the TLAS device address from a push-constant block followed by +// OpConvertUToAccelerationStructureKHR. RTPass pushes the active frame's TLAS +// address into that push constant. Shaders that never touch an acceleration +// structure (no OpTypeAccelerationStructureKHR) are left untouched. +// +// SPIR-V allows at most one push-constant variable per entry point, so we never +// add a second one: if the shader already declares a push-constant block we +// append a ulong member (the TLAS address) to the *existing* block and read +// from there; only shaders with no push constant of their own get a freshly +// synthesized single-member block. Its byte offset is the offset of that +// member (published via Crafter::Device::workaroundTlasPushOffset) which RTPass feeds to +// vkCmdPushDataEXT so the address lands where the rewritten load reads it. +// +// Exported so tests/PushConstantRewrite can drive Patch() over real compiled +// SPIR-V and check the result with spirv-val; nothing in the engine calls it +// from outside this file. Goes away with the rest of the workaround. +export namespace WorkaroundNvidiaAS { // SPIR-V numeric opcodes / enums used below. enum : std::uint32_t { OpEntryPoint = 15, OpCapability = 17, - OpTypeInt = 21, OpTypeStruct = 30, OpTypePointer = 32, + OpTypeInt = 21, OpTypeFloat = 22, OpTypeVector = 23, OpTypeMatrix = 24, + OpTypeArray = 28, OpTypeStruct = 30, OpTypePointer = 32, OpConstant = 43, OpVariable = 59, OpLoad = 61, OpAccessChain = 65, OpDecorate = 71, OpMemberDecorate = 72, OpConvertUToAccelerationStructureKHR = 4447, OpTypeAccelerationStructureKHR = 5341, CapabilityInt64 = 11, StorageClassPushConstant = 9, - DecorationBlock = 2, DecorationOffset = 35, + DecorationBlock = 2, DecorationMatrixStride = 7, + DecorationArrayStride = 6, DecorationOffset = 35, }; inline bool IsAnnotation(std::uint32_t op) { @@ -69,6 +83,10 @@ namespace WorkaroundNvidiaAS { using Instr = std::vector; + inline std::uint32_t AlignUp(std::uint32_t v, std::uint32_t a) { + return (v + a - 1u) & ~(a - 1u); + } + inline void Patch(std::vector& words) { if (words.size() < 5) return; // not a SPIR-V module we understand. @@ -82,23 +100,61 @@ namespace WorkaroundNvidiaAS { i += len; } - // ── Scan for the AS type, reusable int/long types+constants, and the - // section boundaries we need to insert into. + // ── Scan for the AS type, reusable int/long types+constants, any + // existing push-constant block, the type/decoration/constant tables + // needed to size that block, and the section boundaries to insert into. std::uint32_t asTypeId = 0, ulongTypeId = 0, uintTypeId = 0, uintZeroId = 0; + std::uint32_t existingPcVarId = 0, existingPcStructId = 0, existingPtrUlongId = 0; std::size_t lastCapIdx = 0, lastAnnotIdx = 0, firstFuncIdx = instrs.size(); std::size_t entryIdx = instrs.size(); + std::map typeInstr; // type-result-id → defining instr + std::map constU32; // OpConstant id → 32-bit value + std::map uintConstByValue; // uint value → OpConstant id + std::map arrayStride; // array type id → ArrayStride + std::map memberOffset; // (struct<<32|idx) → Offset + std::map memberMatStride; // (struct<<32|idx) → MatrixStride + std::map ptrPointee; // pointer type id → pointee type id for (std::size_t k = 0; k < instrs.size(); ++k) { - std::uint32_t op = instrs[k][0] & 0xFFFFu; + const Instr& in = instrs[k]; + std::uint32_t op = in[0] & 0xFFFFu; switch (op) { - case OpTypeAccelerationStructureKHR: asTypeId = instrs[k][1]; break; + case OpTypeAccelerationStructureKHR: asTypeId = in[1]; typeInstr[in[1]] = ∈ break; case OpTypeInt: - if (instrs[k][2] == 64 && instrs[k][3] == 0) ulongTypeId = instrs[k][1]; - else if (instrs[k][2] == 32 && instrs[k][3] == 0) uintTypeId = instrs[k][1]; + if (in[2] == 64 && in[3] == 0) ulongTypeId = in[1]; + else if (in[2] == 32 && in[3] == 0) uintTypeId = in[1]; + typeInstr[in[1]] = ∈ + break; + case OpTypeFloat: case OpTypeVector: case OpTypeMatrix: + case OpTypeArray: case OpTypeStruct: + typeInstr[in[1]] = ∈ + break; + case OpTypePointer: + typeInstr[in[1]] = ∈ ptrPointee[in[1]] = in[3]; + if (in[2] == StorageClassPushConstant && in[3] == ulongTypeId) + existingPtrUlongId = in[1]; break; case OpConstant: - if (uintTypeId && instrs[k][1] == uintTypeId && instrs[k][3] == 0) - uintZeroId = instrs[k][2]; + if (in.size() >= 4) constU32[in[2]] = in[3]; + if (uintTypeId && in[1] == uintTypeId && in.size() >= 4) { + uintConstByValue.emplace(in[3], in[2]); + if (in[3] == 0) uintZeroId = in[2]; + } break; + case OpVariable: + if (in[3] == StorageClassPushConstant) { + existingPcVarId = in[2]; + existingPcStructId = ptrPointee.count(in[1]) ? ptrPointee[in[1]] : 0; + } + break; + case OpDecorate: + if (in.size() >= 4 && in[2] == DecorationArrayStride) arrayStride[in[1]] = in[3]; + break; + case OpMemberDecorate: { + std::uint64_t key = (static_cast(in[1]) << 32) | in[2]; + if (in.size() >= 5 && in[3] == DecorationOffset) memberOffset[key] = in[4]; + if (in.size() >= 5 && in[3] == DecorationMatrixStride) memberMatStride[key] = in[4]; + break; + } case OpCapability: lastCapIdx = k; break; case OpEntryPoint: if (entryIdx == instrs.size()) entryIdx = k; break; default: break; @@ -116,73 +172,153 @@ namespace WorkaroundNvidiaAS { return in; }; - // ── Synthesize the types/constants/push-constant we need, reusing any - // the module already defines (SPIR-V forbids duplicate type defs). - std::vector typeDefs; - if (uintTypeId == 0) { - uintTypeId = newId(); - typeDefs.push_back(mk({OpTypeInt, uintTypeId, 32, 0})); - } - if (uintZeroId == 0) { - uintZeroId = newId(); - typeDefs.push_back(mk({OpConstant, uintTypeId, uintZeroId, 0})); - } - if (ulongTypeId == 0) { - ulongTypeId = newId(); - typeDefs.push_back(mk({OpTypeInt, ulongTypeId, 64, 0})); - } - std::uint32_t pcStructId = newId(); - std::uint32_t ptrPushStructId = newId(); - std::uint32_t ptrPushUlongId = newId(); - std::uint32_t pcVarId = newId(); - typeDefs.push_back(mk({OpTypeStruct, pcStructId, ulongTypeId})); - typeDefs.push_back(mk({OpTypePointer, ptrPushStructId, StorageClassPushConstant, pcStructId})); - typeDefs.push_back(mk({OpTypePointer, ptrPushUlongId, StorageClassPushConstant, ulongTypeId})); - typeDefs.push_back(mk({OpVariable, ptrPushStructId, pcVarId, StorageClassPushConstant})); - - std::vector decorations = { - mk({OpMemberDecorate, pcStructId, 0, DecorationOffset, 0}), - mk({OpDecorate, pcStructId, DecorationBlock}), + // Byte footprint of a type, honouring the explicit Array/Matrix strides + // glslang emits so the result is correct under both scalar and std140 + // block layout. Used only to find where an existing push block ends. + std::function footprint = + [&](std::uint32_t tid) -> std::uint32_t { + auto it = typeInstr.find(tid); + if (it == typeInstr.end()) return 0; + const Instr& t = *it->second; + switch (t[0] & 0xFFFFu) { + case OpTypeInt: case OpTypeFloat: return t[2] / 8u; + case OpTypeVector: return t[3] * footprint(t[2]); + case OpTypeMatrix: return t[3] * footprint(t[2]); // cols × column-vec + case OpTypeArray: { + std::uint32_t len = constU32.count(t[3]) ? constU32[t[3]] : 0; + std::uint32_t stride = arrayStride.count(tid) ? arrayStride[tid] + : footprint(t[2]); + return len * stride; + } + case OpTypeStruct: { + std::uint32_t end = 0; + for (std::size_t m = 2; m < t.size(); ++m) { + std::uint32_t idx = static_cast(m - 2); + std::uint64_t key = (static_cast(t[1]) << 32) | idx; + std::uint32_t off = memberOffset.count(key) ? memberOffset[key] : 0; + std::uint32_t sz; + auto mt = typeInstr.find(t[m]); + if (mt != typeInstr.end() && (mt->second->at(0) & 0xFFFFu) == OpTypeMatrix + && memberMatStride.count(key)) + sz = memberMatStride[key] * (*mt->second)[3]; + else + sz = footprint(t[m]); + end = std::max(end, off + sz); + } + return end; + } + case OpTypePointer: return 8; + default: return 0; + } }; - // ── Rewrite each `OpLoad %asType ` into address-load + convert. + bool merge = existingPcVarId != 0 && existingPcStructId != 0 + && typeInstr.count(existingPcStructId) + && (typeInstr[existingPcStructId]->at(0) & 0xFFFFu) == OpTypeStruct; + + // ── Synthesize/ensure the int/long types and constants we need, reusing + // any the module already defines (SPIR-V forbids duplicate type defs). + std::vector typeDefs; + if (uintTypeId == 0) { uintTypeId = newId(); typeDefs.push_back(mk({OpTypeInt, uintTypeId, 32, 0})); } + if (ulongTypeId == 0) { ulongTypeId = newId(); typeDefs.push_back(mk({OpTypeInt, ulongTypeId, 64, 0})); } + + std::uint32_t pcVarId, ptrPushUlongId, memberIdxConstId, memberIdx; + std::vector decorations; + + if (merge) { + // Append a ulong member to the user's existing block; read from it. + pcVarId = existingPcVarId; + const Instr* structInstr = typeInstr[existingPcStructId]; + memberIdx = static_cast(structInstr->size() - 2); + Crafter::Device::workaroundTlasPushOffset = AlignUp(footprint(existingPcStructId), 8); + + ptrPushUlongId = existingPtrUlongId; + if (ptrPushUlongId == 0) { + ptrPushUlongId = newId(); + typeDefs.push_back(mk({OpTypePointer, ptrPushUlongId, StorageClassPushConstant, ulongTypeId})); + } + // Member index constant for the access chain — reuse an existing + // uint constant of the right value, else mint one (must be an + // integer constant, so only uint-typed ones qualify for reuse). + auto found = uintConstByValue.find(memberIdx); + if (found != uintConstByValue.end()) { + memberIdxConstId = found->second; + } else { + memberIdxConstId = newId(); + typeDefs.push_back(mk({OpConstant, uintTypeId, memberIdxConstId, memberIdx})); + } + decorations.push_back(mk({OpMemberDecorate, existingPcStructId, memberIdx, DecorationOffset, Crafter::Device::workaroundTlasPushOffset})); + } else { + // No user push constant — synthesize a fresh single-member block. + if (uintZeroId == 0) { uintZeroId = newId(); typeDefs.push_back(mk({OpConstant, uintTypeId, uintZeroId, 0})); } + std::uint32_t pcStructId = newId(); + std::uint32_t ptrPushStructId = newId(); + ptrPushUlongId = newId(); + pcVarId = newId(); + typeDefs.push_back(mk({OpTypeStruct, pcStructId, ulongTypeId})); + typeDefs.push_back(mk({OpTypePointer, ptrPushStructId, StorageClassPushConstant, pcStructId})); + typeDefs.push_back(mk({OpTypePointer, ptrPushUlongId, StorageClassPushConstant, ulongTypeId})); + typeDefs.push_back(mk({OpVariable, ptrPushStructId, pcVarId, StorageClassPushConstant})); + decorations.push_back(mk({OpMemberDecorate, pcStructId, 0, DecorationOffset, 0})); + decorations.push_back(mk({OpDecorate, pcStructId, DecorationBlock})); + memberIdxConstId = uintZeroId; + Crafter::Device::workaroundTlasPushOffset = 0; + } + + // ── Rewrite each `OpLoad %asType ` into address-load + convert, and + // (when merging) append the ulong member to the existing struct type. std::vector rebuilt; rebuilt.reserve(instrs.size() + 8); - for (const Instr& in : instrs) { + for (Instr in : instrs) { std::uint32_t op = in[0] & 0xFFFFu; if (op == OpLoad && in[1] == asTypeId) { std::uint32_t resultId = in[2]; std::uint32_t chainId = newId(); std::uint32_t addrId = newId(); - rebuilt.push_back(mk({OpAccessChain, ptrPushUlongId, chainId, pcVarId, uintZeroId})); + rebuilt.push_back(mk({OpAccessChain, ptrPushUlongId, chainId, pcVarId, memberIdxConstId})); rebuilt.push_back(mk({OpLoad, ulongTypeId, addrId, chainId})); rebuilt.push_back(mk({OpConvertUToAccelerationStructureKHR, asTypeId, resultId, addrId})); } else { - rebuilt.push_back(in); + if (merge && op == OpTypeStruct && in[1] == existingPcStructId) { + in.push_back(ulongTypeId); + in[0] = static_cast(in.size() << 16) | OpTypeStruct; + } + rebuilt.push_back(std::move(in)); } } instrs.swap(rebuilt); // Recompute structural anchors (the rewrite above shifted indices). lastCapIdx = 0; lastAnnotIdx = 0; firstFuncIdx = instrs.size(); entryIdx = instrs.size(); + std::size_t structIdx = instrs.size(); for (std::size_t k = 0; k < instrs.size(); ++k) { std::uint32_t op = instrs[k][0] & 0xFFFFu; if (op == OpCapability) lastCapIdx = k; if (op == OpEntryPoint && entryIdx == instrs.size()) entryIdx = k; if (IsAnnotation(op)) lastAnnotIdx = k; if (op == 54 && firstFuncIdx == instrs.size()) firstFuncIdx = k; + if (merge && op == OpTypeStruct && instrs[k][1] == existingPcStructId) structIdx = k; } - // Append the push-constant variable to the entry point's interface - // list (required for SPIR-V ≥ 1.4 — both raygen modules are 1.4). - if (entryIdx != instrs.size() && words[1] >= 0x00010400u) { + // The newly-defined types (notably ulong) must precede every use. When + // merging, the user's struct — now carrying the appended ulong member — + // already sits in the type section, so the defs go in just before it; + // for a fresh block the whole bundle can go at the end of the type + // section (right before the first function). + std::size_t typeDefsIdx = (merge && structIdx != instrs.size()) ? structIdx : firstFuncIdx; + + // A freshly synthesized push-constant variable must join the entry + // point's interface list (required for SPIR-V ≥ 1.4 — raygen is 1.4). + // A merged-into variable is already used, so it is already listed. + if (!merge && entryIdx != instrs.size() && words[1] >= 0x00010400u) { instrs[entryIdx].push_back(pcVarId); instrs[entryIdx][0] = static_cast(instrs[entryIdx].size() << 16) | OpEntryPoint; } - // Insert highest-index-first so earlier anchors stay valid. - instrs.insert(instrs.begin() + firstFuncIdx, typeDefs.begin(), typeDefs.end()); + // Insert highest-index-first so earlier anchors stay valid (typeDefsIdx + // ≥ lastAnnotIdx+1 ≥ lastCapIdx+1 in both the merge and synthesize cases). + instrs.insert(instrs.begin() + typeDefsIdx, typeDefs.begin(), typeDefs.end()); instrs.insert(instrs.begin() + lastAnnotIdx + 1, decorations.begin(), decorations.end()); instrs.insert(instrs.begin() + lastCapIdx + 1, mk({OpCapability, CapabilityInt64})); diff --git a/project.cpp b/project.cpp index c5d39a1..62df28d 100644 --- a/project.cpp +++ b/project.cpp @@ -205,6 +205,38 @@ extern "C" Configuration CrafterBuildProject(std::span a cfg.shaders.emplace_back(fs::path("shaders/ui-images.comp.glsl"), std::string("main"), ShaderType::Compute); cfg.shaders.emplace_back(fs::path("shaders/ui-text.comp.glsl"), std::string("main"), ShaderType::Compute); cfg.buildFiles.emplace_back(fs::path("shaders/ui-shared.glsl")); + + // Regression test for issue #18: drive the NVIDIA descriptor-heap + // AS-read workaround's SPIR-V rewrite over real compiled shaders and + // check the result with spirv-val (one push-constant block, correct + // TLAS offset). The test executable recompiles the whole module plus + // tests/PushConstantRewrite/main.cpp; Configuration isn't copyable + // (it owns the parsed module list), so the shared build settings are + // mirrored field by field. glslang and spirv-val are invoked at + // runtime, so the test declares them as required tools. Remove with + // the rest of the workaround. + Test pcTest; + Configuration& tc = pcTest.config; + tc.path = cfg.path; + tc.name = "PushConstantRewrite"; + tc.outputName = "PushConstantRewrite"; + tc.type = ConfigurationType::Executable; + tc.target = cfg.target; + tc.march = cfg.march; + tc.mtune = cfg.mtune; + tc.debug = cfg.debug; + tc.sysroot = cfg.sysroot; + tc.dependencies = cfg.dependencies; + tc.externalDependencies = cfg.externalDependencies; + tc.compileFlags = cfg.compileFlags; + tc.linkFlags = cfg.linkFlags; + tc.defines = cfg.defines; + tc.cFiles = cfg.cFiles; + std::vector testImpls(impls.begin(), impls.end()); + testImpls.emplace_back("tests/PushConstantRewrite/main"); + tc.GetInterfacesAndImplementations(ifaces, testImpls); + pcTest.requires_ = { "tool:glslang", "tool:spirv-val" }; + cfg.tests.push_back(std::move(pcTest)); } return cfg; diff --git a/tests/PushConstantRewrite/main.cpp b/tests/PushConstantRewrite/main.cpp new file mode 100644 index 0000000..d55ef9f --- /dev/null +++ b/tests/PushConstantRewrite/main.cpp @@ -0,0 +1,228 @@ +/* +Crafter®.Graphics +Copyright (C) 2026 Catcrafts® +catcrafts.net + +This library is free software; you can redistribute it and/or +modify it under the terms of the GNU Lesser General Public +License version 3.0 as published by the Free Software Foundation; + +This library is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public +License along with this library; if not, write to the Free Software +Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +// Regression test for issue #18: the NVIDIA descriptor-heap AS-read workaround +// (WorkaroundNvidiaAS::Patch) used to bolt a brand-new push-constant block onto +// every patched ray-tracing shader. SPIR-V allows at most one push-constant +// block statically used per entry point, so any shader that already declared +// one ended up with two and failed spirv-val: +// +// Entry point id '4' uses more than one PushConstant interface. +// +// This test compiles representative ray-generation shaders with glslang, runs +// them through the real Patch(), and asserts with spirv-val that the result is +// valid and contains exactly one push-constant variable — both for shaders +// that already have a push constant (merge path) and for those that don't +// (synthesize path). It also checks the published TLAS push-constant offset. +// +// Delete this test together with the rest of the workaround once a fixed NVIDIA +// driver ships. + +#include "vulkan/vulkan.h" +#include + +import Crafter.Graphics; +import std; +using namespace Crafter; + +namespace { + +namespace fs = std::filesystem; + +int RunCommand(const std::string& cmd) { + int status = std::system(cmd.c_str()); + if (status == -1) return -1; + // Mirror WEXITSTATUS without pulling in : glibc encodes the + // exit code in bits 8..15 of the wait status when the low byte is zero. + return (status & 0x7f) == 0 ? ((status >> 8) & 0xff) : 128 + (status & 0x7f); +} + +std::vector ReadSpirv(const fs::path& p) { + std::ifstream f(p, std::ios::binary | std::ios::ate); + if (!f) return {}; + std::streamsize size = f.tellg(); + f.seekg(0); + std::vector words(static_cast(size) / sizeof(std::uint32_t)); + f.read(reinterpret_cast(words.data()), size); + return words; +} + +void WriteSpirv(const fs::path& p, const std::vector& words) { + std::ofstream f(p, std::ios::binary); + f.write(reinterpret_cast(words.data()), + static_cast(words.size() * sizeof(std::uint32_t))); +} + +// Count OpVariable instructions in the PushConstant storage class (SC == 9). +int CountPushConstantVariables(const std::vector& words) { + constexpr std::uint32_t OpVariable = 59; + constexpr std::uint32_t StorageClassPushConstant = 9; + int count = 0; + for (std::size_t i = 5; i < words.size();) { + std::uint32_t len = words[i] >> 16; + if (len == 0 || i + len > words.size()) break; + if ((words[i] & 0xFFFFu) == OpVariable && len >= 4 && words[i + 3] == StorageClassPushConstant) + ++count; + i += len; + } + return count; +} + +struct Case { + std::string_view name; + std::string_view glsl; + bool readsAccelStruct; // whether Patch should rewrite anything + bool hasExistingPushConst; // whether the source already declares a push block + std::uint32_t expectedOffset; // expected Device::workaroundTlasPushOffset (only checked when readsAccelStruct) +}; + +// Shared raygen scaffolding: a heap AS + heap image, traced and stored to. +constexpr std::string_view kHeader = + "#version 460\n" + "#extension GL_EXT_ray_tracing : enable\n" + "#extension GL_EXT_shader_image_load_formatted : enable\n" + "#extension GL_EXT_descriptor_heap : enable\n" + "#extension GL_EXT_nonuniform_qualifier : enable\n" + "layout(descriptor_heap) uniform accelerationStructureEXT topLevelAS[];\n" + "layout(descriptor_heap) uniform writeonly image2D image[];\n" + "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"; + +const std::array kCases = {{ + // No push constant at all → Patch synthesizes a fresh single-member block at offset 0. + { "no-push-constant", std::string_view{ + "" + }, true, false, 0 }, + + // Existing block {mat4 @0, vec3 @64, uint @76}; ends at 80, already 8-aligned. + { "merge-mat4-vec3-uint", std::string_view{ + "layout(push_constant) uniform PC { mat4 m; vec3 l; uint f; } pc;\n" + }, true, true, 80 }, + + // Existing block {uint @0}; ends at 4, TLAS rounds up to the next 8. + { "merge-uint", std::string_view{ + "layout(push_constant) uniform PC { uint f; } pc;\n" + }, true, true, 8 }, + + // Existing block {vec4 v[2] @0 (32 bytes), uint @32}; ends at 36, rounds to 40. + { "merge-array", std::string_view{ + "layout(push_constant) uniform PC { vec4 v[2]; uint f; } pc;\n" + }, true, true, 40 }, + + // Push constant but NO acceleration-structure read → Patch is a no-op; the + // single user block must survive untouched and still validate. + { "push-constant-no-as", std::string_view{ + "layout(push_constant) uniform PC { vec4 tint; } pc;\n" + }, false, true, 0 }, +}}; + +std::string BuildSource(const Case& c) { + std::string s(kHeader); + s += c.glsl; + s += "void main() {\n"; + s += " uvec2 pixel = gl_LaunchIDEXT.xy;\n"; + s += " vec3 origin = vec3(0.0, 0.0, -300.0);\n"; + s += " vec3 dir = normalize(vec3(0.0, 0.0, 1.0));\n"; + if (c.readsAccelStruct) + s += " traceRayEXT(topLevelAS[0], gl_RayFlagsNoneEXT, 0xff, 0,0,0, origin, 0.001, dir, 10000.0, 0);\n"; + // Reference the push constant so glslang keeps the block in the module. + std::string_view g = c.glsl; + std::string extra = "vec4(hitValue, 1.0)"; + if (g.find("mat4 m;") != std::string_view::npos) + extra = "pc.m * vec4(hitValue, 1.0) + vec4(pc.l, float(pc.f))"; + else if (g.find("uint f; } pc;") != std::string_view::npos && g.find("vec4 v[2]") != std::string_view::npos) + extra = "vec4(hitValue, 1.0) + pc.v[0] + pc.v[1] + vec4(float(pc.f))"; + else if (g.find("uint f; } pc;") != std::string_view::npos) + extra = "vec4(hitValue, float(pc.f))"; + else if (g.find("vec4 tint;") != std::string_view::npos) + extra = "vec4(hitValue, 1.0) + pc.tint"; + s += " imageStore(image[0], ivec2(pixel), " + extra + ");\n"; + s += "}\n"; + return s; +} + +} // namespace + +int main() { + const fs::path dir = fs::temp_directory_path() / "crafter-pcrewrite-test"; + std::error_code ec; + fs::create_directories(dir, ec); + + int failures = 0; + for (const Case& c : kCases) { + const fs::path glslPath = dir / (std::string(c.name) + ".rgen.glsl"); + const fs::path spvPath = dir / (std::string(c.name) + ".spv"); + const fs::path patched = dir / (std::string(c.name) + ".patched.spv"); + + { std::ofstream f(glslPath); f << BuildSource(c); } + + std::string compile = "glslang --target-env vulkan1.4 -V -S rgen \"" + + glslPath.string() + "\" -o \"" + spvPath.string() + "\" > /dev/null"; + if (RunCommand(compile) != 0) { + std::println(std::cerr, "[{}] glslang failed to compile the source shader", c.name); + ++failures; + continue; + } + + std::vector words = ReadSpirv(spvPath); + if (words.size() < 5) { + std::println(std::cerr, "[{}] could not read compiled SPIR-V", c.name); + ++failures; + continue; + } + + Device::workaroundTlasPushOffset = 0xDEADBEEFu; // poison so we know Patch set it + WorkaroundNvidiaAS::Patch(words); + WriteSpirv(patched, words); + + // 1. The patched module must pass spirv-val under the engine's flags. + std::string validate = "spirv-val \"" + patched.string() + + "\" --relax-block-layout --scalar-block-layout --target-env vulkan1.4"; + if (RunCommand(validate) != 0) { + std::println(std::cerr, "[{}] spirv-val rejected the patched module", c.name); + ++failures; + continue; + } + + // 2. Exactly one push-constant variable — the whole point of issue #18. + int pcVars = CountPushConstantVariables(words); + if (pcVars != 1) { + std::println(std::cerr, "[{}] expected exactly 1 push-constant variable, found {}", c.name, pcVars); + ++failures; + continue; + } + + // 3. The TLAS offset Patch published must match the expected layout end. + if (c.readsAccelStruct && Device::workaroundTlasPushOffset != c.expectedOffset) { + std::println(std::cerr, "[{}] expected TLAS push offset {}, got {}", + c.name, c.expectedOffset, Device::workaroundTlasPushOffset); + ++failures; + continue; + } + + std::println(std::cout, "[{}] ok (push-constant vars: {}, tlas offset: {})", + c.name, pcVars, c.readsAccelStruct ? Device::workaroundTlasPushOffset : 0u); + } + + if (failures != 0) { + std::println(std::cerr, "{} case(s) failed", failures); + return 1; + } + std::println(std::cout, "all push-constant rewrite cases passed"); + return 0; +}