This commit is contained in:
Jorijn van der Graaf 2026-01-30 07:19:59 +01:00
commit 8275e01b6c
3 changed files with 62 additions and 26 deletions

View file

@ -29,7 +29,7 @@ import :VulkanBuffer;
import :Types;
export namespace Crafter {
template <typename Raygen, typename ClosestHit, typename Miss, typename... Shaders>
template <typename Raygen, typename ClosestHit, typename Miss, typename ShadowClosestHit, typename ShadowMiss, typename... Shaders>
class PipelineRTVulkan {
public:
inline static VkPipeline pipeline;
@ -50,7 +50,7 @@ export namespace Crafter {
VulkanDevice::CheckVkResult(vkCreatePipelineLayout(VulkanDevice::device, &pipelineLayoutInfo, nullptr, &pipelineLayout));
std::array<VkPipelineShaderStageCreateInfo, 3> shaderStages;
std::array<VkPipelineShaderStageCreateInfo, 5> shaderStages;
shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStages[0].stage = Raygen::_stage;
@ -76,8 +76,23 @@ export namespace Crafter {
shaderStages[2].pSpecializationInfo = nullptr;
shaderStages[2].pNext = nullptr;
shaderStages[3].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStages[3].stage = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
shaderStages[3].module = ShadowClosestHit::shader;
shaderStages[3].pName = ShadowClosestHit::_entrypoint.value;
shaderStages[3].flags = 0;
shaderStages[3].pSpecializationInfo = nullptr;
shaderStages[3].pNext = nullptr;
std::array<VkRayTracingShaderGroupCreateInfoKHR, 3> groups {{
shaderStages[4].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStages[4].stage = ShadowMiss::_stage;
shaderStages[4].module = ShadowMiss::shader;
shaderStages[4].pName = ShadowMiss::_entrypoint.value;
shaderStages[4].flags = 0;
shaderStages[4].pSpecializationInfo = nullptr;
shaderStages[4].pNext = nullptr;
std::array<VkRayTracingShaderGroupCreateInfoKHR, 5> groups {{
{
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
@ -101,6 +116,22 @@ export namespace Crafter {
.closestHitShader = 2,
.anyHitShader = VK_SHADER_UNUSED_KHR,
.intersectionShader = VK_SHADER_UNUSED_KHR
},
{
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
.generalShader = VK_SHADER_UNUSED_KHR,
.closestHitShader = 3,
.anyHitShader = VK_SHADER_UNUSED_KHR,
.intersectionShader = VK_SHADER_UNUSED_KHR
},
{
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
.generalShader = 4,
.closestHitShader = VK_SHADER_UNUSED_KHR,
.anyHitShader = VK_SHADER_UNUSED_KHR,
.intersectionShader = VK_SHADER_UNUSED_KHR
}
}};
@ -110,7 +141,7 @@ export namespace Crafter {
.pStages = shaderStages.data(),
.groupCount = static_cast<std::uint32_t>(groups.size()),
.pGroups = groups.data(),
.maxPipelineRayRecursionDepth = 1,
.maxPipelineRayRecursionDepth = 2,
.layout = pipelineLayout
};
@ -125,35 +156,33 @@ export namespace Crafter {
shaderHandles.resize(dataSize);
VulkanDevice::CheckVkResult(VulkanDevice::vkGetRayTracingShaderGroupHandlesKHR(VulkanDevice::device, pipeline, 0, groupCount, dataSize, shaderHandles.data()));
std::uint32_t raygenSize = AlignUp(handleSize, handleAlignment);
std::uint32_t missSize = AlignUp(handleSize, handleAlignment);
std::uint32_t hitSize = AlignUp(handleSize, handleAlignment);
std::uint32_t callableSize = 0;
std::uint32_t sbtStride = AlignUp(handleSize, handleAlignment);
std::uint32_t raygenOffset = 0;
std::uint32_t missOffset = AlignUp(raygenOffset + sbtStride, baseAlignment);
std::uint32_t hitOffset = AlignUp(missOffset + sbtStride * 2, baseAlignment);
std::uint32_t raygenOffset = 0;
std::uint32_t missOffset = AlignUp(raygenSize, baseAlignment);
std::uint32_t hitOffset = AlignUp(missOffset + missSize, baseAlignment);
std::uint32_t callableOffset = AlignUp(hitOffset + hitSize, baseAlignment);
std::size_t bufferSize = callableOffset + callableSize;
std::uint32_t hitGroupCount = 2;
std::size_t bufferSize = hitOffset + sbtStride * hitGroupCount;
sbtBuffer.Create(VK_BUFFER_USAGE_2_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, bufferSize);
// Ray generation shader (group 0)
std::memcpy(sbtBuffer.value + raygenOffset, shaderHandles.data() + 0 * handleSize, handleSize);
raygenRegion.deviceAddress = sbtBuffer.address + raygenOffset;
raygenRegion.stride = raygenSize;
raygenRegion.size = raygenSize;
raygenRegion.stride = sbtStride;
raygenRegion.size = sbtStride;
std::memcpy(sbtBuffer.value + missOffset, shaderHandles.data() + 1 * handleSize, handleSize);
std::memcpy(sbtBuffer.value + missOffset + 1 * sbtStride, shaderHandles.data() + 4 * handleSize, handleSize);
missRegion.deviceAddress = sbtBuffer.address + missOffset;
missRegion.stride = missSize;
missRegion.size = missSize;
missRegion.stride = sbtStride;
missRegion.size = sbtStride * 2;
std::memcpy(sbtBuffer.value + hitOffset, shaderHandles.data() + 2 * handleSize, handleSize);
std::memcpy(sbtBuffer.value + hitOffset + 0 * sbtStride, shaderHandles.data() + 2 * handleSize, handleSize);
std::memcpy(sbtBuffer.value + hitOffset + 1 * sbtStride, shaderHandles.data() + 3 * handleSize, handleSize);
hitRegion.deviceAddress = sbtBuffer.address + hitOffset;
hitRegion.stride = hitSize;
hitRegion.size = hitSize;
hitRegion.stride = sbtStride;
hitRegion.size = sbtStride * hitGroupCount;
callableRegion.deviceAddress = 0;
callableRegion.stride = 0;