/*
Crafter®.Graphics
Copyright (C) 2026 Catcrafts®
catcrafts.net

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License version 3.0 as published by the Free Software Foundation;

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
*/

// Module partition Crafter.Graphics:PipelineRTVulkan.
// Declares ShaderGroup (a compile-time description of one ray tracing shader
// group) and PipelineRTVulkan (static storage plus an Init routine that builds
// a ray tracing pipeline and its shader binding table).
//
// NOTE(review): in this copy of the file every angle-bracket span is absent:
// the #include target, all template parameter lists, and the template
// arguments of std::vector / std::span / std::array / std::tuple_element_t /
// static_cast / VulkanBuffer. The code tokens below are left exactly as found;
// restore the missing spans from version control before building.

module;
#ifdef CRAFTER_GRAPHICS_VULKAN
// NOTE(review): the header name is missing here (presumably the Vulkan header) — confirm.
#include
#endif
export module Crafter.Graphics:PipelineRTVulkan;
#ifdef CRAFTER_GRAPHICS_VULKAN
import std;
import :VulkanDevice;
import :VulkanBuffer;
import :ShaderBindingTableVulkan;
import :Types;

export namespace Crafter {
    // Compile-time description of a single shader group: four shader indices
    // exposed as static constants. The code below compares them against
    // VK_SHADER_UNUSED_KHR to detect an empty slot and uses generalShader /
    // closestHitShader as indices into a shader tuple.
    // NOTE(review): GeneralShader, ClosestHitShader, AnyHitShader and
    // IntersectionShader are presumably the (missing) template parameters.
    template
    struct ShaderGroup {
        static constexpr std::uint32_t generalShader = GeneralShader;
        static constexpr std::uint32_t closestHitShader = ClosestHitShader;
        static constexpr std::uint32_t anyHitShader = AnyHitShader;
        static constexpr std::uint32_t intersectionShader = IntersectionShader;
    };

    // Holds one ray tracing pipeline, its layout and its shader binding table
    // (SBT) as inline static members; all state is written by Init().
    // NOTE(review): the template parameter list is missing; the member
    // functions rely on at least a tuple-like list of shader groups and a
    // tuple-like list of shaders (both accessed through std::tuple_element_t).
    // NOTE(review): no code visible here destroys `pipeline` or
    // `pipelineLayout` — confirm cleanup happens elsewhere.
    template
    class PipelineRTVulkan {
    public:
        // Pipeline handle written by vkCreateRayTracingPipelinesKHR in Init().
        inline static VkPipeline pipeline;
        // Layout created in Init() from the caller's descriptor set layouts.
        inline static VkPipelineLayout pipelineLayout;
        // Shader group handles as returned by vkGetRayTracingShaderGroupHandlesKHR:
        // shaderGroupHandleSize bytes per group, tightly packed, in group order.
        inline static std::vector shaderHandles;
        // Buffer backing the SBT; Init() creates it, fills it via `value`
        // and reads its device address from `address`.
        inline static VulkanBuffer sbtBuffer;
        // SBT regions filled in by Init(). After Init() returns, the
        // deviceAddress of raygen/miss/hit has sbtBuffer.address added to it.
        inline static VkStridedDeviceAddressRegionKHR raygenRegion;
        inline static VkStridedDeviceAddressRegionKHR missRegion;
        inline static VkStridedDeviceAddressRegionKHR hitRegion;
        // Always zeroed by Init(): no callable shaders are placed in the SBT.
        inline static VkStridedDeviceAddressRegionKHR callableRegion;

        // Creates the pipeline layout and the ray tracing pipeline, fetches the
        // shader group handles, and builds the shader binding table.
        //
        // cmd        - command buffer handed to sbtBuffer.FlushDevice().
        // setLayouts - descriptor set layouts used for the pipeline layout.
        //
        // Vulkan results are passed through VulkanDevice::CheckVkResult.
        static void Init(VkCommandBuffer cmd, std::span setLayouts) {
            VkPipelineLayoutCreateInfo pipelineLayoutInfo {
                .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
                .setLayoutCount = static_cast(setLayouts.size()),
                .pSetLayouts = setLayouts.data()
            };

            VulkanDevice::CheckVkResult(vkCreatePipelineLayout(VulkanDevice::device, &pipelineLayoutInfo, nullptr, &pipelineLayout));

            // The group create-infos are produced at compile time (GetShaderGroups
            // is consteval); groupIndexSeq supplies one index per group.
            constexpr auto groupIndexSeq = std::make_index_sequence>{};
            constexpr std::array> groups = GetShaderGroups(groupIndexSeq);

            VkRayTracingPipelineCreateInfoKHR rtPipelineInfo{
                .sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
                .stageCount = static_cast(ShaderBindingTableVulkan::shaderStages.size()),
                .pStages = ShaderBindingTableVulkan::shaderStages.data(),
                .groupCount = static_cast(groups.size()),
                .pGroups = groups.data(),
                // Recursion depth is fixed at 1.
                .maxPipelineRayRecursionDepth = 1,
                .layout = pipelineLayout
            };

            VulkanDevice::CheckVkResult(VulkanDevice::vkCreateRayTracingPipelinesKHR(VulkanDevice::device, {}, {}, 1, &rtPipelineInfo, nullptr, &pipeline));

            // Fetch the handles of all groups in one call.
            std::size_t dataSize = VulkanDevice::rayTracingProperties.shaderGroupHandleSize * rtPipelineInfo.groupCount;
            shaderHandles.resize(dataSize);

            VulkanDevice::CheckVkResult(VulkanDevice::vkGetRayTracingShaderGroupHandlesKHR(VulkanDevice::device, pipeline, 0, rtPipelineInfo.groupCount, dataSize, shaderHandles.data()));

            // Per-record stride: handle size aligned up to the handle alignment.
            std::uint32_t sbtStride = AlignUp(VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleAlignment);

            // Region layout inside the buffer: raygen, then miss, then hit.
            // deviceAddress temporarily holds an offset relative to the buffer
            // start; the real base address is added after the buffer exists.
            // NOTE(review): the three GetGroupCount calls read identically in
            // this text; presumably each had a different stage template argument
            // (raygen / miss / hit) — confirm.
            // NOTE(review): vkCmdTraceRaysKHR requires the raygen region's size
            // to equal its stride; that only holds here if exactly one group is
            // counted for the raygen region — confirm.
            raygenRegion.stride = sbtStride;
            raygenRegion.deviceAddress = 0;
            raygenRegion.size = GetGroupCount(groupIndexSeq) * sbtStride;

            missRegion.stride = sbtStride;
            missRegion.deviceAddress = AlignUp(raygenRegion.size, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
            missRegion.size = GetGroupCount(groupIndexSeq) * sbtStride;

            hitRegion.stride = sbtStride;
            hitRegion.deviceAddress = AlignUp(missRegion.deviceAddress + missRegion.size, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
            hitRegion.size = GetGroupCount(groupIndexSeq) * sbtStride;

            // The hit region is last, so its end is the total buffer size.
            std::size_t bufferSize = hitRegion.deviceAddress + hitRegion.size;
            // NOTE(review): VK_BUFFER_USAGE_2_* (a VkBufferUsageFlagBits2 value)
            // is OR-ed with a VkBufferUsageFlagBits value; confirm which flag
            // type Create() expects.
            sbtBuffer.Create(VK_BUFFER_USAGE_2_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, bufferSize);

            // Copy the handles into the buffer using the same layout as above.
            AddShaderGroupsToBuffer(sbtStride, groupIndexSeq);
            sbtBuffer.FlushDevice(cmd, VK_ACCESS_MEMORY_READ_BIT, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR);

            // Turn the relative offsets into absolute addresses.
            // NOTE(review): the regions stay base-aligned only if
            // sbtBuffer.address is itself a multiple of shaderGroupBaseAlignment
            // — confirm against VulkanBuffer.
            raygenRegion.deviceAddress += sbtBuffer.address;
            missRegion.deviceAddress += sbtBuffer.address;
            hitRegion.deviceAddress += sbtBuffer.address;

            callableRegion.deviceAddress = 0;
            callableRegion.stride = 0;
            callableRegion.size = 0;
        }
    private:
        // Writes the create-info for the group at `index` into groups[index].
        // Group type: GENERAL when the group has a general shader, otherwise
        // TRIANGLES_HIT_GROUP when it has a closest-hit shader; a group with
        // neither fails the static_assert.
        // NOTE(review): `index` is presumably a (missing) template parameter.
        // NOTE(review): intersectionShader is forwarded, but the type is never
        // set to a procedural hit group, and a group with only any-hit /
        // intersection shaders is rejected — confirm this is intended.
        template
        consteval static void AddShaderGroup(std::array>& groups) {
            using groupTemplate = std::tuple_element_t;

            VkRayTracingShaderGroupTypeKHR groupType;
            if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
                groupType = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
            } else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
                groupType = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
            } else {
                // Reached only when both slots are unused, so the condition is
                // false here and compilation stops with the message below.
                static_assert(
                    groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
                    groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
                    "Shader group must define either a general or closest-hit shader"
                );
            }

            groups[index] = {
                .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
                .type = groupType,
                .generalShader = groupTemplate::generalShader,
                .closestHitShader = groupTemplate::closestHitShader,
                .anyHitShader = groupTemplate::anyHitShader,
                .intersectionShader = groupTemplate::intersectionShader
            };
        }

        // Builds the array of group create-infos by invoking AddShaderGroup via
        // a fold over the comma operator, then returns the array by value.
        // NOTE(review): presumably one call per index of the sequence — the
        // index pack is not visible in this text.
        template
        consteval static std::array> GetShaderGroups(std::index_sequence) {
            std::array> groups{};
            (AddShaderGroup(groups), ...);
            return groups;
        }

        // Increments `count` when the group's defining shader has `_stage`
        // equal to `stage`. The defining shader is the general shader when
        // present, otherwise the closest-hit shader.
        // NOTE(review): `stage` and the group index are presumably (missing)
        // template parameters.
        template
        consteval static void GetGroupCountImpl(std::uint32_t& count) {
            using groupTemplate = std::tuple_element_t;
            if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
                using shaderTemplate = std::tuple_element_t;
                if constexpr(shaderTemplate::_stage == stage) {
                    count++;
                }
            } else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
                using shaderTemplate = std::tuple_element_t;
                if constexpr(shaderTemplate::_stage == stage) {
                    count++;
                }
            } else {
                static_assert(
                    groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
                    groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
                    "Shader group must define either a general or closest-hit shader"
                );
            }
        }

        // Returns the number of groups counted by GetGroupCountImpl, starting
        // from zero and folding over the comma operator.
        template
        consteval static std::uint32_t GetGroupCount(std::index_sequence) {
            std::uint32_t count = 0;
            (GetGroupCountImpl(count), ...);
            return count;
        }

        // Runtime counterpart of GetGroupCountImpl: when the group's defining
        // shader matches `stage`, copies that group's handle
        // (shaderGroupHandleSize bytes, read at index * shaderGroupHandleSize in
        // shaderHandles) to sbtBuffer.value + offset and advances `offset` by
        // sbtStride. Non-matching groups leave `offset` untouched.
        template
        static void AddShaderGroupToBuffer(std::uint32_t sbtStride, std::uint32_t& offset) {
            using groupTemplate = std::tuple_element_t;
            if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
                using shaderTemplate = std::tuple_element_t;
                if constexpr(shaderTemplate::_stage == stage) {
                    std::memcpy(sbtBuffer.value + offset, shaderHandles.data() + index * VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleSize);
                    offset += sbtStride;
                }
            } else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
                using shaderTemplate = std::tuple_element_t;
                if constexpr(shaderTemplate::_stage == stage) {
                    std::memcpy(sbtBuffer.value + offset, shaderHandles.data() + index * VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleSize);
                    offset += sbtStride;
                }
            } else {
                static_assert(
                    groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
                    groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
                    "Shader group must define either a general or closest-hit shader"
                );
            }
        }

        // Fills the SBT in three passes, aligning `offset` up to
        // shaderGroupBaseAlignment between passes. This mirrors the offsets
        // Init() computes for the raygen, miss and hit regions.
        // NOTE(review): the three passes read identically in this text;
        // presumably each had a different stage template argument — confirm.
        template
        static void AddShaderGroupsToBuffer(std::uint32_t sbtStride, std::index_sequence) {
            std::uint32_t offset = 0;

            (AddShaderGroupToBuffer(sbtStride, offset), ...);

            offset = AlignUp(offset, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);

            (AddShaderGroupToBuffer(sbtStride, offset), ...);

            offset = AlignUp(offset, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);

            (AddShaderGroupToBuffer(sbtStride, offset), ...);
        }
    };
}
#endif