You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp

366 lines
17 KiB
C++

/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#include <AzCore/Asset/AssetCommon.h>
#include <AzCore/Asset/AssetManagerBus.h>
#include <Atom/RHI/CommandList.h>
#include <Atom/RHI/Factory.h>
#include <Atom/RHI/FrameScheduler.h>
#include <Atom/RHI/DispatchRaysItem.h>
#include <Atom/RHI/RHISystemInterface.h>
#include <Atom/RHI/PipelineState.h>
#include <Atom/RPI.Reflect/Pass/PassTemplate.h>
#include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
#include <Atom/RPI.Public/Base.h>
#include <Atom/RPI.Public/Pass/PassUtils.h>
#include <Atom/RPI.Public/RPIUtils.h>
#include <Atom/RPI.Public/RenderPipeline.h>
#include <Atom/RPI.Public/Scene.h>
#include <Atom/RPI.Public/View.h>
#include <RayTracing/RayTracingPass.h>
#include <RayTracing/RayTracingPassData.h>
#include <RayTracing/RayTracingFeatureProcessor.h>
namespace AZ
{
namespace Render
{
RPI::Ptr<RayTracingPass> RayTracingPass::Create(const RPI::PassDescriptor& descriptor)
{
    // Factory used by the pass system: allocate a new pass from the given descriptor
    // and hand it back wrapped in an RPI::Ptr.
    return RPI::Ptr<RayTracingPass>(aznew RayTracingPass(descriptor));
}
RayTracingPass::RayTracingPass(const RPI::PassDescriptor& descriptor)
    : RenderPass(descriptor)
    , m_passDescriptor(descriptor)
{
    // Only build shaders and pipeline state when the device reports ray tracing support;
    // otherwise the pass is disabled and Init() is never run.
    RHI::Ptr<RHI::Device> renderDevice = RHI::RHISystemInterface::Get()->GetDevice();
    const bool rayTracingSupported = !(renderDevice->GetFeatures().m_rayTracing == false);
    if (rayTracingSupported)
    {
        Init();
    }
    else
    {
        // raytracing is not supported on this platform
        SetEnabled(false);
    }
}
RayTracingPass::~RayTracingPass()
{
    // Stop receiving shader hot-reload notifications (connected in Init) before this pass is destroyed.
    using ShaderReloadHandler = RPI::ShaderReloadNotificationBus::MultiHandler;
    ShaderReloadHandler::BusDisconnect();
}
void RayTracingPass::Init()
{
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
m_passData = RPI::PassUtils::GetPassData<RayTracingPassData>(m_passDescriptor);
if (m_passData == nullptr)
{
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Invalid RayTracingPassData", GetPathName().GetCStr());
return;
}
// ray generation shader
m_rayGenerationShader = LoadShader(m_passData->m_rayGenerationShaderAssetReference);
if (m_rayGenerationShader == nullptr)
{
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load RayGeneration shader [%s]", GetPathName().GetCStr(), m_passData->m_rayGenerationShaderAssetReference.m_filePath.data());
return;
}
auto shaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
shaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
// closest hit shader
m_closestHitShader = LoadShader(m_passData->m_closestHitShaderAssetReference);
if (m_closestHitShader == nullptr)
{
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load ClosestHit shader [%s]", GetPathName().GetCStr(), m_passData->m_closestHitShaderAssetReference.m_filePath.data());
return;
}
shaderVariant = m_closestHitShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
RHI::PipelineStateDescriptorForRayTracing closestHitShaderDescriptor;
shaderVariant.ConfigurePipelineState(closestHitShaderDescriptor);
// miss shader
m_missShader = LoadShader(m_passData->m_missShaderAssetReference);
if (m_missShader == nullptr)
{
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load Miss shader [%s]", GetPathName().GetCStr(), m_passData->m_missShaderAssetReference.m_filePath.data());
return;
}
shaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
shaderVariant.ConfigurePipelineState(missShaderDescriptor);
// retrieve global pipeline state
m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
// create global srg
const auto& globalSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingGlobalSrgBindingSlot);
AZ_Error("PassSystem", globalSrgLayout != nullptr, "RayTracingPass [%s] Failed to find RayTracingGlobalSrg layout", GetPathName().GetCStr());
m_shaderResourceGroup = RPI::ShaderResourceGroup::Create( m_rayGenerationShader->GetAsset(), m_rayGenerationShader->GetSupervariantIndex(), globalSrgLayout->GetName());
AZ_Assert(m_shaderResourceGroup, "RayTracingPass [%s]: Failed to create RayTracingGlobalSrg", GetPathName().GetCStr());
RPI::PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
// check to see if the shader requires the View, Scene, or RayTracingMaterial Srgs
const auto& viewSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::View);
m_requiresViewSrg = (viewSrgLayout != nullptr);
const auto& sceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Scene);
m_requiresSceneSrg = (sceneSrgLayout != nullptr);
const auto& rayTracingMaterialSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingMaterialSrgBindingSlot);
m_requiresRayTracingMaterialSrg = (rayTracingMaterialSrgLayout != nullptr);
// build the ray tracing pipeline state descriptor
RHI::RayTracingPipelineStateDescriptor descriptor;
descriptor.Build()
->PipelineState(m_globalPipelineState.get())
->MaxPayloadSize(m_passData->m_maxPayloadSize)
->MaxAttributeSize(m_passData->m_maxAttributeSize)
->MaxRecursionDepth(m_passData->m_maxRecursionDepth)
->ShaderLibrary(rayGenerationShaderDescriptor)
->RayGenerationShaderName(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
->ShaderLibrary(missShaderDescriptor)
->MissShaderName(AZ::Name(m_passData->m_missShaderName.c_str()))
->ShaderLibrary(closestHitShaderDescriptor)
->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()))
->HitGroup(AZ::Name("HitGroup"))
->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()));
// create the ray tracing pipeline state object
m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
m_rayTracingPipelineState->Init(*device.get(), &descriptor);
// make sure the shader table rebuilds if we're hotreloading
m_rayTracingRevision = 0;
RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_rayGenerationShaderAssetReference.m_assetId);
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitShaderAssetReference.m_assetId);
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_missShaderAssetReference.m_assetId);
}
// Resolves a shader asset reference to a shader instance.
// Returns nullptr (after reporting an error) when the reference has no valid asset id
// or the asset lookup does not yield a valid asset.
Data::Instance<RPI::Shader> RayTracingPass::LoadShader(const RPI::AssetReference& shaderAssetReference)
{
    const auto& referencedAssetId = shaderAssetReference.m_assetId;
    const auto& referencedFilePath = shaderAssetReference.m_filePath;

    Data::Asset<RPI::ShaderAsset> foundAsset;
    if (referencedAssetId.IsValid())
    {
        foundAsset = RPI::FindShaderAsset(referencedAssetId, referencedFilePath);
    }

    const bool haveAsset = foundAsset.GetId().IsValid();
    if (haveAsset)
    {
        return RPI::Shader::FindOrCreate(foundAsset);
    }

    AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load shader asset [%s]", GetPathName().GetCStr(), referencedFilePath.data());
    return nullptr;
}
// Per-frame setup. If the scene has no RayTracingFeatureProcessor, returns without calling the
// base FrameBeginInternal. Otherwise creates the ray tracing shader table on first use (backed by
// the feature processor's buffer pools) and then defers to RenderPass.
void RayTracingPass::FrameBeginInternal(FramePrepareParams params)
{
    RayTracingFeatureProcessor* featureProcessor =
        m_pipeline->GetScene()->GetFeatureProcessor<RayTracingFeatureProcessor>();
    if (featureProcessor == nullptr)
    {
        return;
    }

    const bool needsShaderTable = !m_rayTracingShaderTable;
    if (needsShaderTable)
    {
        RHI::Ptr<RHI::Device> renderDevice = RHI::RHISystemInterface::Get()->GetDevice();
        RHI::RayTracingBufferPools& bufferPools = featureProcessor->GetBufferPools();

        m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
        m_rayTracingShaderTable->Init(*renderDevice.get(), bufferPools);
    }

    RPI::RenderPass::FrameBeginInternal(params);
}
// Declares the pass's frame graph dependencies: the base RenderPass attachments plus,
// when a TLAS buffer exists, the ray tracing TLAS buffer (imported into the attachment
// database on first use this frame) as a shader attachment.
void RayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
{
    RPI::Scene* scene = m_pipeline->GetScene();
    RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
    AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");

    RPI::RenderPass::SetupFrameGraphDependencies(frameGraph);
    frameGraph.SetEstimatedItemCount(1);

    // asserts are compiled out in release builds; don't dereference a missing feature processor
    if (!rayTracingFeatureProcessor)
    {
        return;
    }

    // TLAS
    const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
    if (!rayTracingTlasBuffer)
    {
        return;
    }

    AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
    if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
    {
        [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
        // ResultCode is a scoped enum, which is not promoted for varargs; cast explicitly for %d
        AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", static_cast<int>(result));
    }

    // view the whole TLAS buffer (reuse the buffer fetched above instead of querying it again)
    uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
    RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, tlasBufferByteCount);

    RHI::BufferScopeAttachmentDescriptor desc;
    desc.m_attachmentId = tlasAttachmentId;
    desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
    desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;

    frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
}
// Compiles the pass's global SRG and, whenever the feature processor's revision changes,
// rebuilds the ray tracing shader table: one ray generation record, one miss record, and
// one "HitGroup" record per sub-mesh (an empty table when there are no sub-meshes).
void RayTracingPass::CompileResources(const RHI::FrameGraphCompileContext& context)
{
    RPI::Scene* scene = m_pipeline->GetScene();
    RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
    AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");

    if (m_shaderResourceGroup != nullptr)
    {
        BindPassSrg(context, m_shaderResourceGroup);
        m_shaderResourceGroup->Compile();
    }

    // The shader table is created in FrameBeginInternal and the pipeline state in Init (which can
    // bail out early); without all of these there is nothing to rebuild. The revision is left
    // untouched so the table is rebuilt once they become available.
    if (!rayTracingFeatureProcessor || !m_rayTracingShaderTable || !m_rayTracingPipelineState)
    {
        return;
    }

    uint32_t rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
    if (m_rayTracingRevision != rayTracingRevision)
    {
        // scene changed, need to rebuild the shader table
        m_rayTracingRevision = rayTracingRevision;

        AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();

        const auto subMeshCount = rayTracingFeatureProcessor->GetSubMeshCount();
        if (subMeshCount)
        {
            // build the ray tracing shader table descriptor
            RHI::RayTracingShaderTableDescriptor* descriptorBuild = descriptor->Build(AZ::Name("RayTracingShaderTable"), m_rayTracingPipelineState)
                ->RayGenerationRecord(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
                ->MissRecord(AZ::Name(m_passData->m_missShaderName.c_str()));

            // add a hit group for each mesh to the shader table
            for (uint32_t i = 0; i < subMeshCount; ++i)
            {
                descriptorBuild->HitGroupRecord(AZ::Name("HitGroup"));
            }
        }

        m_rayTracingShaderTable->Build(descriptor);
    }
}
// Records a single DispatchRays item. Dispatch dimensions come from the first output (or
// input/output) image for fullscreen passes, otherwise from the pass data's thread counts.
// Binds the global and ray tracing scene SRGs, plus the view, scene, and material SRGs when
// the ray generation shader uses them. Skips recording when any required resource is missing.
void RayTracingPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
{
    RPI::Scene* scene = m_pipeline->GetScene();
    RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
    AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");

    // Init() can fail early and leave the SRG / pipeline state null, so check them as well.
    if (!rayTracingFeatureProcessor ||
        !rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer() ||
        !rayTracingFeatureProcessor->GetSubMeshCount() ||
        !m_rayTracingShaderTable ||
        !m_rayTracingPipelineState ||
        m_shaderResourceGroup == nullptr)
    {
        return;
    }

    RHI::DispatchRaysItem dispatchRaysItem;

    // calculate thread counts if this is a full screen raytracing pass
    if (m_passData->m_makeFullscreenPass)
    {
        RPI::PassAttachment* outputAttachment = nullptr;

        if (GetOutputCount() > 0)
        {
            outputAttachment = GetOutputBinding(0).m_attachment.get();
        }
        else if (GetInputOutputCount() > 0)
        {
            outputAttachment = GetInputOutputBinding(0).m_attachment.get();
        }

        AZ_Assert(outputAttachment != nullptr, "[RayTracingPass '%s']: A fullscreen RayTracing pass must have a valid output or input/output.", GetPathName().GetCStr());
        if (outputAttachment == nullptr)
        {
            // asserts are compiled out in release builds; don't dereference a missing attachment
            return;
        }

        AZ_Assert(outputAttachment->GetAttachmentType() == RHI::AttachmentType::Image, "[RayTracingPass '%s']: The output of a fullscreen RayTracing pass must be an image.", GetPathName().GetCStr());
        if (outputAttachment->GetAttachmentType() != RHI::AttachmentType::Image)
        {
            // the image size below is only meaningful for image attachments
            return;
        }

        RHI::Size imageSize = outputAttachment->m_descriptor.m_image.m_size;

        dispatchRaysItem.m_width = imageSize.m_width;
        dispatchRaysItem.m_height = imageSize.m_height;
        dispatchRaysItem.m_depth = imageSize.m_depth;
    }
    else
    {
        dispatchRaysItem.m_width = m_passData->m_threadCountX;
        dispatchRaysItem.m_height = m_passData->m_threadCountY;
        dispatchRaysItem.m_depth = m_passData->m_threadCountZ;
    }

    // bind RayTracingGlobal, RayTracingScene, and View Srgs
    // [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
    AZStd::vector<RHI::ShaderResourceGroup*> shaderResourceGroups =
    {
        m_shaderResourceGroup->GetRHIShaderResourceGroup(),
        rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()
    };

    if (m_requiresViewSrg)
    {
        const AZStd::vector<RPI::ViewPtr>& views = m_pipeline->GetViews(m_passData->m_pipelineViewTag);
        if (!views.empty())
        {
            shaderResourceGroups.push_back(views[0]->GetRHIShaderResourceGroup());
        }
    }

    if (m_requiresSceneSrg)
    {
        shaderResourceGroups.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup());
    }

    if (m_requiresRayTracingMaterialSrg)
    {
        shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup());
    }

    // shaderResourceGroups must stay alive until Submit() below, since the item only stores a pointer to its data
    dispatchRaysItem.m_shaderResourceGroupCount = aznumeric_cast<uint32_t>(shaderResourceGroups.size());
    dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups.data();
    dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
    dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
    dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();

    // submit the DispatchRays item
    context.GetCommandList()->Submit(dispatchRaysItem);
}
// ShaderReloadNotificationBus handlers. Any reload of one of the pass's shaders (shader instance,
// shader asset, or shader variant) re-runs Init(), which reloads the shaders, rebuilds the pipeline
// states and global SRG, resets the shader table revision, and re-registers for notifications.
void RayTracingPass::OnShaderReinitialized([[maybe_unused]] const RPI::Shader& shader)
{
    this->Init();
}

void RayTracingPass::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset<RPI::ShaderAsset>& shaderAsset)
{
    this->Init();
}

void RayTracingPass::OnShaderVariantReinitialized(const RPI::ShaderVariant&)
{
    this->Init();
}
} // namespace Render
} // namespace AZ