templated pipeline
This commit is contained in:
parent
8275e01b6c
commit
74832c6824
8 changed files with 229 additions and 170 deletions
|
|
@ -26,10 +26,19 @@ export module Crafter.Graphics:PipelineRTVulkan;
|
|||
import std;
|
||||
import :VulkanDevice;
|
||||
import :VulkanBuffer;
|
||||
import :ShaderBindingTableVulkan;
|
||||
import :Types;
|
||||
|
||||
export namespace Crafter {
|
||||
template <typename Raygen, typename ClosestHit, typename Miss, typename ShadowClosestHit, typename ShadowMiss, typename... Shaders>
|
||||
template <std::uint32_t GeneralShader, std::uint32_t ClosestHitShader, std::uint32_t AnyHitShader, std::uint32_t IntersectionShader>
|
||||
struct ShaderGroup {
|
||||
static constexpr std::uint32_t generalShader = GeneralShader;
|
||||
static constexpr std::uint32_t closestHitShader = ClosestHitShader;
|
||||
static constexpr std::uint32_t anyHitShader = AnyHitShader;
|
||||
static constexpr std::uint32_t intersectionShader = IntersectionShader;
|
||||
};
|
||||
|
||||
template <typename Shaders, typename ShaderGroups>
|
||||
class PipelineRTVulkan {
|
||||
public:
|
||||
inline static VkPipeline pipeline;
|
||||
|
|
@ -41,7 +50,7 @@ export namespace Crafter {
|
|||
inline static VkStridedDeviceAddressRegionKHR hitRegion;
|
||||
inline static VkStridedDeviceAddressRegionKHR callableRegion;
|
||||
|
||||
static void Init(std::span<VkDescriptorSetLayout> setLayouts) {
|
||||
static void Init(VkCommandBuffer cmd, std::span<VkDescriptorSetLayout> setLayouts) {
|
||||
VkPipelineLayoutCreateInfo pipelineLayoutInfo {
|
||||
.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
|
||||
.setLayoutCount = static_cast<std::uint32_t>(setLayouts.size()),
|
||||
|
|
@ -50,144 +59,146 @@ export namespace Crafter {
|
|||
|
||||
VulkanDevice::CheckVkResult(vkCreatePipelineLayout(VulkanDevice::device, &pipelineLayoutInfo, nullptr, &pipelineLayout));
|
||||
|
||||
std::array<VkPipelineShaderStageCreateInfo, 5> shaderStages;
|
||||
constexpr auto groupIndexSeq = std::make_index_sequence<std::tuple_size_v<ShaderGroups>>{};
|
||||
|
||||
shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[0].stage = Raygen::_stage;
|
||||
shaderStages[0].module = Raygen::shader;
|
||||
shaderStages[0].pName = Raygen::_entrypoint.value;
|
||||
shaderStages[0].flags = 0;
|
||||
shaderStages[0].pSpecializationInfo = nullptr;
|
||||
shaderStages[0].pNext = nullptr;
|
||||
|
||||
shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[1].stage = Miss::_stage;
|
||||
shaderStages[1].module = Miss::shader;
|
||||
shaderStages[1].pName = Miss::_entrypoint.value;
|
||||
shaderStages[1].flags = 0;
|
||||
shaderStages[1].pSpecializationInfo = nullptr;
|
||||
shaderStages[1].pNext = nullptr;
|
||||
|
||||
shaderStages[2].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[2].stage = ClosestHit::_stage;
|
||||
shaderStages[2].module = ClosestHit::shader;
|
||||
shaderStages[2].pName = ClosestHit::_entrypoint.value;
|
||||
shaderStages[2].flags = 0;
|
||||
shaderStages[2].pSpecializationInfo = nullptr;
|
||||
shaderStages[2].pNext = nullptr;
|
||||
|
||||
shaderStages[3].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[3].stage = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
|
||||
shaderStages[3].module = ShadowClosestHit::shader;
|
||||
shaderStages[3].pName = ShadowClosestHit::_entrypoint.value;
|
||||
shaderStages[3].flags = 0;
|
||||
shaderStages[3].pSpecializationInfo = nullptr;
|
||||
shaderStages[3].pNext = nullptr;
|
||||
|
||||
shaderStages[4].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[4].stage = ShadowMiss::_stage;
|
||||
shaderStages[4].module = ShadowMiss::shader;
|
||||
shaderStages[4].pName = ShadowMiss::_entrypoint.value;
|
||||
shaderStages[4].flags = 0;
|
||||
shaderStages[4].pSpecializationInfo = nullptr;
|
||||
shaderStages[4].pNext = nullptr;
|
||||
|
||||
std::array<VkRayTracingShaderGroupCreateInfoKHR, 5> groups {{
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||||
.generalShader = 0,
|
||||
.closestHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||||
.generalShader = 1,
|
||||
.closestHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
|
||||
.generalShader = VK_SHADER_UNUSED_KHR,
|
||||
.closestHitShader = 2,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
|
||||
.generalShader = VK_SHADER_UNUSED_KHR,
|
||||
.closestHitShader = 3,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||||
.generalShader = 4,
|
||||
.closestHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
}
|
||||
}};
|
||||
constexpr std::array<VkRayTracingShaderGroupCreateInfoKHR, std::tuple_size_v<ShaderGroups>> groups = GetShaderGroups(groupIndexSeq);
|
||||
|
||||
VkRayTracingPipelineCreateInfoKHR rtPipelineInfo{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
|
||||
.stageCount = static_cast<std::uint32_t>(shaderStages.size()),
|
||||
.pStages = shaderStages.data(),
|
||||
.stageCount = static_cast<std::uint32_t>(ShaderBindingTableVulkan<Shaders>::shaderStages.size()),
|
||||
.pStages = ShaderBindingTableVulkan<Shaders>::shaderStages.data(),
|
||||
.groupCount = static_cast<std::uint32_t>(groups.size()),
|
||||
.pGroups = groups.data(),
|
||||
.maxPipelineRayRecursionDepth = 2,
|
||||
.maxPipelineRayRecursionDepth = 1,
|
||||
.layout = pipelineLayout
|
||||
};
|
||||
|
||||
VulkanDevice::CheckVkResult(VulkanDevice::vkCreateRayTracingPipelinesKHR(VulkanDevice::device, {}, {}, 1, &rtPipelineInfo, nullptr, &pipeline));
|
||||
|
||||
std::uint32_t handleSize = VulkanDevice::rayTracingProperties.shaderGroupHandleSize;
|
||||
std::uint32_t handleAlignment = VulkanDevice::rayTracingProperties.shaderGroupHandleAlignment;
|
||||
std::uint32_t baseAlignment = VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment;
|
||||
std::uint32_t groupCount = rtPipelineInfo.groupCount;
|
||||
|
||||
std::size_t dataSize = handleSize * groupCount;
|
||||
std::size_t dataSize = VulkanDevice::rayTracingProperties.shaderGroupHandleSize * rtPipelineInfo.groupCount;
|
||||
shaderHandles.resize(dataSize);
|
||||
VulkanDevice::CheckVkResult(VulkanDevice::vkGetRayTracingShaderGroupHandlesKHR(VulkanDevice::device, pipeline, 0, groupCount, dataSize, shaderHandles.data()));
|
||||
VulkanDevice::CheckVkResult(VulkanDevice::vkGetRayTracingShaderGroupHandlesKHR(VulkanDevice::device, pipeline, 0, rtPipelineInfo.groupCount, dataSize, shaderHandles.data()));
|
||||
|
||||
std::uint32_t sbtStride = AlignUp(handleSize, handleAlignment);
|
||||
std::uint32_t raygenOffset = 0;
|
||||
std::uint32_t missOffset = AlignUp(raygenOffset + sbtStride, baseAlignment);
|
||||
std::uint32_t hitOffset = AlignUp(missOffset + sbtStride * 2, baseAlignment);
|
||||
std::uint32_t sbtStride = AlignUp(VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleAlignment);
|
||||
|
||||
std::uint32_t hitGroupCount = 2;
|
||||
std::size_t bufferSize = hitOffset + sbtStride * hitGroupCount;
|
||||
raygenRegion.stride = sbtStride;
|
||||
raygenRegion.deviceAddress = 0;
|
||||
raygenRegion.size = GetGroupCount<VK_SHADER_STAGE_RAYGEN_BIT_KHR>(groupIndexSeq) * sbtStride;
|
||||
|
||||
missRegion.stride = sbtStride;
|
||||
missRegion.deviceAddress = AlignUp(raygenRegion.size, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
|
||||
missRegion.size = GetGroupCount<VK_SHADER_STAGE_MISS_BIT_KHR>(groupIndexSeq) * sbtStride;
|
||||
|
||||
hitRegion.stride = sbtStride;
|
||||
hitRegion.deviceAddress = AlignUp(missRegion.deviceAddress + missRegion.size, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
|
||||
hitRegion.size = GetGroupCount<VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR>(groupIndexSeq) * sbtStride;
|
||||
|
||||
std::size_t bufferSize = hitRegion.deviceAddress + hitRegion.size;
|
||||
sbtBuffer.Create(VK_BUFFER_USAGE_2_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, bufferSize);
|
||||
|
||||
// Ray generation shader (group 0)
|
||||
std::memcpy(sbtBuffer.value + raygenOffset, shaderHandles.data() + 0 * handleSize, handleSize);
|
||||
raygenRegion.deviceAddress = sbtBuffer.address + raygenOffset;
|
||||
raygenRegion.stride = sbtStride;
|
||||
raygenRegion.size = sbtStride;
|
||||
AddShaderGroupsToBuffer(sbtStride, groupIndexSeq);
|
||||
sbtBuffer.FlushDevice(cmd, VK_ACCESS_MEMORY_READ_BIT, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR);
|
||||
|
||||
std::memcpy(sbtBuffer.value + missOffset, shaderHandles.data() + 1 * handleSize, handleSize);
|
||||
std::memcpy(sbtBuffer.value + missOffset + 1 * sbtStride, shaderHandles.data() + 4 * handleSize, handleSize);
|
||||
missRegion.deviceAddress = sbtBuffer.address + missOffset;
|
||||
missRegion.stride = sbtStride;
|
||||
missRegion.size = sbtStride * 2;
|
||||
|
||||
std::memcpy(sbtBuffer.value + hitOffset + 0 * sbtStride, shaderHandles.data() + 2 * handleSize, handleSize);
|
||||
std::memcpy(sbtBuffer.value + hitOffset + 1 * sbtStride, shaderHandles.data() + 3 * handleSize, handleSize);
|
||||
hitRegion.deviceAddress = sbtBuffer.address + hitOffset;
|
||||
hitRegion.stride = sbtStride;
|
||||
hitRegion.size = sbtStride * hitGroupCount;
|
||||
raygenRegion.deviceAddress += sbtBuffer.address;
|
||||
missRegion.deviceAddress += sbtBuffer.address;
|
||||
hitRegion.deviceAddress += sbtBuffer.address;
|
||||
|
||||
callableRegion.deviceAddress = 0;
|
||||
callableRegion.stride = 0;
|
||||
callableRegion.size = 0;
|
||||
}
|
||||
private:
|
||||
template<std::size_t index>
|
||||
consteval static void AddShaderGroup(std::array<VkRayTracingShaderGroupCreateInfoKHR, std::tuple_size_v<ShaderGroups>>& groups) {
|
||||
using groupTemplate = std::tuple_element_t<index, ShaderGroups>;
|
||||
VkRayTracingShaderGroupTypeKHR groupType;
|
||||
if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
|
||||
groupType = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
|
||||
} else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
|
||||
groupType = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
|
||||
} else {
|
||||
static_assert(
|
||||
groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
|
||||
groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
|
||||
"Shader group must define either a general or closest-hit shader"
|
||||
);
|
||||
}
|
||||
groups[index] = {
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = groupType,
|
||||
.generalShader = groupTemplate::generalShader,
|
||||
.closestHitShader = groupTemplate::closestHitShader,
|
||||
.anyHitShader = groupTemplate::anyHitShader,
|
||||
.intersectionShader = groupTemplate::intersectionShader
|
||||
};
|
||||
}
|
||||
template<std::size_t... Is>
|
||||
consteval static std::array<VkRayTracingShaderGroupCreateInfoKHR, std::tuple_size_v<ShaderGroups>> GetShaderGroups(std::index_sequence<Is...>) {
|
||||
std::array<VkRayTracingShaderGroupCreateInfoKHR, std::tuple_size_v<ShaderGroups>> groups{};
|
||||
(AddShaderGroup<Is>(groups), ...);
|
||||
return groups;
|
||||
}
|
||||
|
||||
template<std::size_t index, VkShaderStageFlagBits stage>
|
||||
consteval static void GetGroupCountImpl(std::uint32_t& count) {
|
||||
using groupTemplate = std::tuple_element_t<index, ShaderGroups>;
|
||||
if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
|
||||
using shaderTemplate = std::tuple_element_t<groupTemplate::generalShader, Shaders>;
|
||||
if constexpr(shaderTemplate::_stage == stage) {
|
||||
count++;
|
||||
}
|
||||
} else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
|
||||
using shaderTemplate = std::tuple_element_t<groupTemplate::closestHitShader, Shaders>;
|
||||
if constexpr(shaderTemplate::_stage == stage) {
|
||||
count++;
|
||||
}
|
||||
} else {
|
||||
static_assert(
|
||||
groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
|
||||
groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
|
||||
"Shader group must define either a general or closest-hit shader"
|
||||
);
|
||||
}
|
||||
}
|
||||
template<VkShaderStageFlagBits stage, std::size_t... Is>
|
||||
consteval static std::uint32_t GetGroupCount(std::index_sequence<Is...>) {
|
||||
std::uint32_t count = 0;
|
||||
(GetGroupCountImpl<Is, stage>(count), ...);
|
||||
return count;
|
||||
}
|
||||
|
||||
template<std::size_t index, VkShaderStageFlagBits stage>
|
||||
static void AddShaderGroupToBuffer(std::uint32_t sbtStride, std::uint32_t& offset) {
|
||||
using groupTemplate = std::tuple_element_t<index, ShaderGroups>;
|
||||
if constexpr(groupTemplate::generalShader != VK_SHADER_UNUSED_KHR) {
|
||||
using shaderTemplate = std::tuple_element_t<groupTemplate::generalShader, Shaders>;
|
||||
if constexpr(shaderTemplate::_stage == stage) {
|
||||
std::memcpy(sbtBuffer.value + offset, shaderHandles.data() + index * VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleSize);
|
||||
offset += sbtStride;
|
||||
}
|
||||
} else if constexpr(groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR) {
|
||||
using shaderTemplate = std::tuple_element_t<groupTemplate::closestHitShader, Shaders>;
|
||||
if constexpr(shaderTemplate::_stage == stage) {
|
||||
std::memcpy(sbtBuffer.value + offset, shaderHandles.data() + index * VulkanDevice::rayTracingProperties.shaderGroupHandleSize, VulkanDevice::rayTracingProperties.shaderGroupHandleSize);
|
||||
offset += sbtStride;
|
||||
}
|
||||
} else {
|
||||
static_assert(
|
||||
groupTemplate::generalShader != VK_SHADER_UNUSED_KHR ||
|
||||
groupTemplate::closestHitShader != VK_SHADER_UNUSED_KHR,
|
||||
"Shader group must define either a general or closest-hit shader"
|
||||
);
|
||||
}
|
||||
}
|
||||
template<std::size_t... Is>
|
||||
static void AddShaderGroupsToBuffer(std::uint32_t sbtStride, std::index_sequence<Is...>) {
|
||||
std::uint32_t offset = 0;
|
||||
(AddShaderGroupToBuffer<Is, VK_SHADER_STAGE_RAYGEN_BIT_KHR>(sbtStride, offset), ...);
|
||||
offset = AlignUp(offset, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
|
||||
(AddShaderGroupToBuffer<Is, VK_SHADER_STAGE_MISS_BIT_KHR>(sbtStride, offset), ...);
|
||||
offset = AlignUp(offset, VulkanDevice::rayTracingProperties.shaderGroupBaseAlignment);
|
||||
(AddShaderGroupToBuffer<Is, VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR>(sbtStride, offset), ...);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue