/* * All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or * its licensors. * * For complete copyright and license terms please see the LICENSE at the root of this * distribution (the "License"). All use of this software is governed by the License, * or, if provided, by the license below or the license accompanying this file. Do not * remove or modify any license notices. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace AZ { namespace Render { void RayTracingFeatureProcessor::Reflect(ReflectContext* context) { if (auto* serializeContext = azrtti_cast(context)) { serializeContext ->Class() ->Version(0); } } void RayTracingFeatureProcessor::Activate() { RHI::Ptr device = RHI::RHISystemInterface::Get()->GetDevice(); m_rayTracingEnabled = device->GetFeatures().m_rayTracing; if (!m_rayTracingEnabled) { return; } m_transformServiceFeatureProcessor = GetParentScene()->GetFeatureProcessor(); // initialize the ray tracing buffer pools m_bufferPools = RHI::RayTracingBufferPools::CreateRHIRayTracingBufferPools(); m_bufferPools->Init(device); // create TLAS attachmentId AZStd::string uuidString = AZ::Uuid::CreateRandom().ToString(); m_tlasAttachmentId = RHI::AttachmentId(AZStd::string::format("RayTracingTlasAttachmentId_%s", uuidString.c_str())); // create the TLAS object m_tlas = AZ::RHI::RayTracingTlas::CreateRHIRayTracingTlas(); // load the RayTracingSceneSrg asset Data::Asset rayTracingSceneSrgAsset = RPI::AssetUtils::LoadAssetByProductPath("shaderlib/raytracingscenesrg_raytracingscenesrg.azsrg", RPI::AssetUtils::TraceLevel::Error); AZ_Assert(rayTracingSceneSrgAsset.IsReady(), "Failed to load RayTracingSceneSrg asset"); m_rayTracingSceneSrg = RPI::ShaderResourceGroup::Create(rayTracingSceneSrgAsset); } void RayTracingFeatureProcessor::SetMesh(const ObjectId objectId, const SubMeshVector& subMeshes) { if (!m_rayTracingEnabled) { return; } RHI::Ptr device = RHI::RHISystemInterface::Get()->GetDevice(); uint32_t objectIndex = objectId.GetIndex(); MeshMap::iterator itMesh = m_meshes.find(objectIndex); if (itMesh == m_meshes.end()) { m_meshes.insert(AZStd::make_pair(objectIndex, Mesh{ subMeshes })); } else { // updating an existing entry // decrement the mesh count by the number of meshes in the existing entry in case the number of meshes changed m_subMeshCount -= aznumeric_cast(itMesh->second.m_subMeshes.size()); m_meshes[objectIndex].m_subMeshes = subMeshes; } // create the BLAS buffers for each sub-mesh // Note: the buffer is just reserved here, the BLAS is built in the RayTracingAccelerationStructurePass Mesh& mesh = m_meshes[objectIndex]; for (auto& subMesh : mesh.m_subMeshes) { RHI::RayTracingBlasDescriptor blasDescriptor; blasDescriptor.Build() ->Geometry() ->VertexFormat(subMesh.m_vertexFormat) ->VertexBuffer(subMesh.m_positionVertexBufferView) ->IndexBuffer(subMesh.m_indexBufferView) ; // create the BLAS object subMesh.m_blas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas(); // create the buffers from the descriptor subMesh.m_blas->CreateBuffers(*device, &blasDescriptor, *m_bufferPools); } // set initial transform mesh.m_transform = m_transformServiceFeatureProcessor->GetTransformForId(objectId); m_revision++; m_subMeshCount += aznumeric_cast(subMeshes.size()); m_meshInfoBufferNeedsUpdate = true; } void RayTracingFeatureProcessor::RemoveMesh(const ObjectId objectId) { if (!m_rayTracingEnabled) { return; } MeshMap::iterator itMesh = m_meshes.find(objectId.GetIndex()); if (itMesh != m_meshes.end()) { m_subMeshCount -= aznumeric_cast(itMesh->second.m_subMeshes.size()); m_meshes.erase(itMesh); m_revision++; } m_meshInfoBufferNeedsUpdate = true; } void RayTracingFeatureProcessor::SetMeshTransform(const ObjectId objectId, AZ::Transform transform) { if (!m_rayTracingEnabled) { return; } MeshMap::iterator itMesh = m_meshes.find(objectId.GetIndex()); if (itMesh != m_meshes.end()) { itMesh->second.m_transform = transform; m_revision++; } m_meshInfoBufferNeedsUpdate = true; } void RayTracingFeatureProcessor::UpdateRayTracingSceneSrg() { if (!m_tlas->GetTlasBuffer()) { return; } if (m_rayTracingSceneSrg->IsQueuedForCompile()) { //[GFX TODO][ATOM-14792] AtomSampleViewer: Reset scene and feature processors before switching to sample return; } // update the mesh info buffer with the latest ray tracing enabled meshes UpdateMeshInfoBuffer(); // update the RayTracingSceneSrg const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingSceneSrg->GetLayout(); RHI::ShaderInputImageIndex imageIndex; RHI::ShaderInputBufferIndex bufferIndex; RHI::ShaderInputConstantIndex constantIndex; // TLAS uint32_t tlasBufferByteCount = aznumeric_cast(m_tlas->GetTlasBuffer()->GetDescriptor().m_byteCount); RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_scene")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_tlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get()); // directional lights const auto directionalLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_directionalLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, directionalLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_directionalLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, directionalLightFP->GetLightCount()); // spot lights const auto spotLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_spotLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, spotLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_spotLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, spotLightFP->GetLightCount()); // point lights const auto pointLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_pointLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, pointLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_pointLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, pointLightFP->GetLightCount()); // disk lights const auto diskLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_diskLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, diskLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_diskLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, diskLightFP->GetLightCount()); // capsule lights const auto capsuleLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_capsuleLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, capsuleLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_capsuleLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, capsuleLightFP->GetLightCount()); // quad lights const auto quadLightFP = GetParentScene()->GetFeatureProcessor(); bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_quadLights")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, quadLightFP->GetLightBuffer()->GetBufferView()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_quadLightCount")); m_rayTracingSceneSrg->SetConstant(constantIndex, quadLightFP->GetLightCount()); // diffuse environment map for sky hits ImageBasedLightFeatureProcessor* imageBasedLightFeatureProcessor = GetParentScene()->GetFeatureProcessor(); if (imageBasedLightFeatureProcessor) { imageIndex = srgLayout->FindShaderInputImageIndex(AZ::Name("m_diffuseEnvMap")); m_rayTracingSceneSrg->SetImage(imageIndex, imageBasedLightFeatureProcessor->GetDiffuseImage()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblOrientation")); m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetOrientation()); constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblExposure")); m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetExposure()); } bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshInfo")); m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshInfoBuffer->GetBufferView()); if (m_subMeshCount) { AZStd::vector meshBuffers; for (const auto& mesh : m_meshes) { const SubMeshVector& subMeshes = mesh.second.m_subMeshes; for (const auto& subMesh : subMeshes) { // add the index, position, and normal buffers for this sub-mesh to the mesh buffer list, this will // go into the shader as an unbounded array in the Srg meshBuffers.push_back(subMesh.m_indexShaderBufferView.get()); meshBuffers.push_back(subMesh.m_positionShaderBufferView.get()); meshBuffers.push_back(subMesh.m_normalShaderBufferView.get()); } } RHI::ShaderInputBufferUnboundedArrayIndex bufferUnboundedArrayIndex = srgLayout->FindShaderInputBufferUnboundedArrayIndex(AZ::Name("m_meshBuffers")); m_rayTracingSceneSrg->SetBufferViewUnboundedArray(bufferUnboundedArrayIndex, meshBuffers); } m_rayTracingSceneSrg->Compile(); } void RayTracingFeatureProcessor::UpdateMeshInfoBuffer() { if (m_meshInfoBufferNeedsUpdate && (m_subMeshCount > 0)) { TransformServiceFeatureProcessor* transformFeatureProcessor = GetParentScene()->GetFeatureProcessor(); AZStd::vector meshInfos; meshInfos.reserve(m_subMeshCount); uint32_t newMeshByteCount = m_subMeshCount * sizeof(MeshInfo); if (m_meshInfoBuffer == nullptr) { // allocate the MeshInfo structured buffer RPI::CommonBufferDescriptor desc; desc.m_poolType = RPI::CommonBufferPoolType::ReadOnly; desc.m_bufferName = "RayTracingMeshInfo"; desc.m_byteCount = newMeshByteCount; desc.m_elementSize = sizeof(MeshInfo); m_meshInfoBuffer = RPI::BufferSystemInterface::Get()->CreateBufferFromCommonPool(desc); } else if (m_meshInfoBuffer->GetBufferSize() < newMeshByteCount) { // resize for the new sub-mesh count m_meshInfoBuffer->Resize(newMeshByteCount); } for (const auto& mesh : m_meshes) { AZ::Transform meshTransform = transformFeatureProcessor->GetTransformForId(TransformServiceFeatureProcessorInterface::ObjectId(mesh.first)); AZ::Transform noScaleTransform = meshTransform; noScaleTransform.ExtractScale(); AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform); rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose(); const RayTracingFeatureProcessor::SubMeshVector& subMeshes = mesh.second.m_subMeshes; for (const auto& subMesh : subMeshes) { MeshInfo meshInfo; meshInfo.m_indexOffset = subMesh.m_indexBufferView.GetByteOffset(); meshInfo.m_positionOffset = subMesh.m_positionVertexBufferView.GetByteOffset(); meshInfo.m_normalOffset = subMesh.m_normalVertexBufferView.GetByteOffset(); subMesh.m_irradianceColor.StoreToFloat4(meshInfo.m_irradianceColor.data()); rotationMatrix.StoreToRowMajorFloat9(meshInfo.m_worldInvTranspose.data()); meshInfos.emplace_back(meshInfo); } } m_meshInfoBuffer->UpdateData(meshInfos.data(), newMeshByteCount); m_meshInfoBufferNeedsUpdate = false; } } } }