/* * Copyright (c) Contributors to the Open 3D Engine Project. * For complete copyright and license terms please see the LICENSE at the root of this distribution. * * SPDX-License-Identifier: Apache-2.0 OR MIT * */ #include #include #include #include #include #include #include #include #include #include namespace AZ { namespace Render { RPI::Ptr RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor) { RPI::Ptr rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor); return AZStd::move(rayTracingAccelerationStructurePass); } RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor) : Pass(descriptor) { // disable this pass if we're on a platform that doesn't support raytracing RHI::Ptr device = RHI::RHISystemInterface::Get()->GetDevice(); if (device->GetFeatures().m_rayTracing == false) { SetEnabled(false); } } void RayTracingAccelerationStructurePass::BuildInternal() { InitScope(RHI::ScopeId(GetPathName())); } void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params) { params.m_frameGraphBuilder->ImportScopeProducer(*this); } void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) { RHI::Ptr device = RHI::RHISystemInterface::Get()->GetDevice(); RPI::Scene* scene = m_pipeline->GetScene(); RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor(); if (rayTracingFeatureProcessor) { if (rayTracingFeatureProcessor->GetRevision() != m_rayTracingRevision) { RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools(); RayTracingFeatureProcessor::MeshMap& rayTracingMeshes = rayTracingFeatureProcessor->GetMeshes(); uint32_t rayTracingSubMeshCount = rayTracingFeatureProcessor->GetSubMeshCount(); // create the TLAS descriptor RHI::RayTracingTlasDescriptor tlasDescriptor; RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build(); uint32_t blasIndex = 0; for (auto& rayTracingMesh : rayTracingMeshes) { for (auto& rayTracingSubMesh : rayTracingMesh.second.m_subMeshes) { tlasDescriptorBuild->Instance() ->InstanceID(blasIndex) ->HitGroupIndex(blasIndex) ->Blas(rayTracingSubMesh.m_blas) ->Transform(rayTracingMesh.second.m_transform) ->NonUniformScale(rayTracingMesh.second.m_nonUniformScale) ; } blasIndex++; } // create the TLAS buffers based on the descriptor RHI::Ptr& rayTracingTlas = rayTracingFeatureProcessor->GetTlas(); rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, rayTracingBufferPools); // import and attach the TLAS buffer const RHI::Ptr& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer(); if (rayTracingTlasBuffer && rayTracingSubMeshCount) { AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId(); if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false) { [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer); AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result); } uint32_t tlasBufferByteCount = aznumeric_cast(rayTracingTlasBuffer->GetDescriptor().m_byteCount); RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount); RHI::BufferScopeAttachmentDescriptor desc; desc.m_attachmentId = tlasAttachmentId; desc.m_bufferViewDescriptor = tlasBufferViewDescriptor; desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare; frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write); } } // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must // exactly match the TLAS. Any mismatch in this data may result in a TDR. rayTracingFeatureProcessor->UpdateRayTracingSrgs(); } } void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context) { RPI::Scene* scene = m_pipeline->GetScene(); RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor(); if (!rayTracingFeatureProcessor) { return; } if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer()) { return; } if (rayTracingFeatureProcessor->GetRevision() == m_rayTracingRevision) { // TLAS is up to date return; } // update the stored revision, even if we don't have any meshes to process m_rayTracingRevision = rayTracingFeatureProcessor->GetRevision(); if (!rayTracingFeatureProcessor->GetSubMeshCount()) { // no ray tracing meshes in the scene return; } // build newly added BLAS objects RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances(); for (auto& blasInstance : blasInstances) { if (blasInstance.second.m_blasBuilt == false) { for (auto& blasInstanceSubMesh : blasInstance.second.m_subMeshes) { context.GetCommandList()->BuildBottomLevelAccelerationStructure(*blasInstanceSubMesh.m_blas); } blasInstance.second.m_blasBuilt = true; } } // build the TLAS object context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()); } } // namespace RPI } // namespace AZ