fixes
This commit is contained in:
parent
770313ca40
commit
8275e01b6c
3 changed files with 62 additions and 26 deletions
|
|
@ -29,7 +29,7 @@ import :VulkanBuffer;
|
|||
import :Types;
|
||||
|
||||
export namespace Crafter {
|
||||
template <typename Raygen, typename ClosestHit, typename Miss, typename... Shaders>
|
||||
template <typename Raygen, typename ClosestHit, typename Miss, typename ShadowClosestHit, typename ShadowMiss, typename... Shaders>
|
||||
class PipelineRTVulkan {
|
||||
public:
|
||||
inline static VkPipeline pipeline;
|
||||
|
|
@ -50,7 +50,7 @@ export namespace Crafter {
|
|||
|
||||
VulkanDevice::CheckVkResult(vkCreatePipelineLayout(VulkanDevice::device, &pipelineLayoutInfo, nullptr, &pipelineLayout));
|
||||
|
||||
std::array<VkPipelineShaderStageCreateInfo, 3> shaderStages;
|
||||
std::array<VkPipelineShaderStageCreateInfo, 5> shaderStages;
|
||||
|
||||
shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[0].stage = Raygen::_stage;
|
||||
|
|
@ -76,8 +76,23 @@ export namespace Crafter {
|
|||
shaderStages[2].pSpecializationInfo = nullptr;
|
||||
shaderStages[2].pNext = nullptr;
|
||||
|
||||
shaderStages[3].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[3].stage = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
|
||||
shaderStages[3].module = ShadowClosestHit::shader;
|
||||
shaderStages[3].pName = ShadowClosestHit::_entrypoint.value;
|
||||
shaderStages[3].flags = 0;
|
||||
shaderStages[3].pSpecializationInfo = nullptr;
|
||||
shaderStages[3].pNext = nullptr;
|
||||
|
||||
std::array<VkRayTracingShaderGroupCreateInfoKHR, 3> groups {{
|
||||
shaderStages[4].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
shaderStages[4].stage = ShadowMiss::_stage;
|
||||
shaderStages[4].module = ShadowMiss::shader;
|
||||
shaderStages[4].pName = ShadowMiss::_entrypoint.value;
|
||||
shaderStages[4].flags = 0;
|
||||
shaderStages[4].pSpecializationInfo = nullptr;
|
||||
shaderStages[4].pNext = nullptr;
|
||||
|
||||
std::array<VkRayTracingShaderGroupCreateInfoKHR, 5> groups {{
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||||
|
|
@ -101,6 +116,22 @@ export namespace Crafter {
|
|||
.closestHitShader = 2,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
|
||||
.generalShader = VK_SHADER_UNUSED_KHR,
|
||||
.closestHitShader = 3,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
},
|
||||
{
|
||||
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
|
||||
.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||||
.generalShader = 4,
|
||||
.closestHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.anyHitShader = VK_SHADER_UNUSED_KHR,
|
||||
.intersectionShader = VK_SHADER_UNUSED_KHR
|
||||
}
|
||||
}};
|
||||
|
||||
|
|
@ -110,7 +141,7 @@ export namespace Crafter {
|
|||
.pStages = shaderStages.data(),
|
||||
.groupCount = static_cast<std::uint32_t>(groups.size()),
|
||||
.pGroups = groups.data(),
|
||||
.maxPipelineRayRecursionDepth = 1,
|
||||
.maxPipelineRayRecursionDepth = 2,
|
||||
.layout = pipelineLayout
|
||||
};
|
||||
|
||||
|
|
@ -125,35 +156,33 @@ export namespace Crafter {
|
|||
shaderHandles.resize(dataSize);
|
||||
VulkanDevice::CheckVkResult(VulkanDevice::vkGetRayTracingShaderGroupHandlesKHR(VulkanDevice::device, pipeline, 0, groupCount, dataSize, shaderHandles.data()));
|
||||
|
||||
std::uint32_t raygenSize = AlignUp(handleSize, handleAlignment);
|
||||
std::uint32_t missSize = AlignUp(handleSize, handleAlignment);
|
||||
std::uint32_t hitSize = AlignUp(handleSize, handleAlignment);
|
||||
std::uint32_t callableSize = 0;
|
||||
std::uint32_t sbtStride = AlignUp(handleSize, handleAlignment);
|
||||
std::uint32_t raygenOffset = 0;
|
||||
std::uint32_t missOffset = AlignUp(raygenOffset + sbtStride, baseAlignment);
|
||||
std::uint32_t hitOffset = AlignUp(missOffset + sbtStride * 2, baseAlignment);
|
||||
|
||||
std::uint32_t raygenOffset = 0;
|
||||
std::uint32_t missOffset = AlignUp(raygenSize, baseAlignment);
|
||||
std::uint32_t hitOffset = AlignUp(missOffset + missSize, baseAlignment);
|
||||
std::uint32_t callableOffset = AlignUp(hitOffset + hitSize, baseAlignment);
|
||||
|
||||
std::size_t bufferSize = callableOffset + callableSize;
|
||||
std::uint32_t hitGroupCount = 2;
|
||||
std::size_t bufferSize = hitOffset + sbtStride * hitGroupCount;
|
||||
|
||||
sbtBuffer.Create(VK_BUFFER_USAGE_2_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, bufferSize);
|
||||
|
||||
// Ray generation shader (group 0)
|
||||
std::memcpy(sbtBuffer.value + raygenOffset, shaderHandles.data() + 0 * handleSize, handleSize);
|
||||
raygenRegion.deviceAddress = sbtBuffer.address + raygenOffset;
|
||||
raygenRegion.stride = raygenSize;
|
||||
raygenRegion.size = raygenSize;
|
||||
raygenRegion.stride = sbtStride;
|
||||
raygenRegion.size = sbtStride;
|
||||
|
||||
std::memcpy(sbtBuffer.value + missOffset, shaderHandles.data() + 1 * handleSize, handleSize);
|
||||
std::memcpy(sbtBuffer.value + missOffset + 1 * sbtStride, shaderHandles.data() + 4 * handleSize, handleSize);
|
||||
missRegion.deviceAddress = sbtBuffer.address + missOffset;
|
||||
missRegion.stride = missSize;
|
||||
missRegion.size = missSize;
|
||||
missRegion.stride = sbtStride;
|
||||
missRegion.size = sbtStride * 2;
|
||||
|
||||
std::memcpy(sbtBuffer.value + hitOffset, shaderHandles.data() + 2 * handleSize, handleSize);
|
||||
std::memcpy(sbtBuffer.value + hitOffset + 0 * sbtStride, shaderHandles.data() + 2 * handleSize, handleSize);
|
||||
std::memcpy(sbtBuffer.value + hitOffset + 1 * sbtStride, shaderHandles.data() + 3 * handleSize, handleSize);
|
||||
hitRegion.deviceAddress = sbtBuffer.address + hitOffset;
|
||||
hitRegion.stride = hitSize;
|
||||
hitRegion.size = hitSize;
|
||||
hitRegion.stride = sbtStride;
|
||||
hitRegion.size = sbtStride * hitGroupCount;
|
||||
|
||||
callableRegion.deviceAddress = 0;
|
||||
callableRegion.stride = 0;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue