You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingFeatureProcessor.cpp

334 lines
16 KiB
C++

/*
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
* its licensors.
*
* For complete copyright and license terms please see the LICENSE at the root of this
* distribution (the "License"). All use of this software is governed by the License,
* or, if provided, by the license below or the license accompanying this file. Do not
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/
#include <RayTracing/RayTracingFeatureProcessor.h>
#include <AzCore/Debug/EventTrace.h>
#include <Atom/Feature/TransformService/TransformServiceFeatureProcessor.h>
#include <Atom/RHI/CpuProfiler.h>
#include <Atom/RHI/Factory.h>
#include <Atom/RHI/RHISystemInterface.h>
#include <Atom/RPI.Public/Scene.h>
#include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
#include <Atom/RPI.Reflect/Asset/AssetUtils.h>
#include <Atom/Feature/ImageBasedLights/ImageBasedLightFeatureProcessor.h>
#include <CoreLights/DirectionalLightFeatureProcessor.h>
#include <CoreLights/SpotLightFeatureProcessor.h>
#include <CoreLights/PointLightFeatureProcessor.h>
#include <CoreLights/DiskLightFeatureProcessor.h>
#include <CoreLights/CapsuleLightFeatureProcessor.h>
#include <CoreLights/QuadLightFeatureProcessor.h>
namespace AZ
{
namespace Render
{
// Registers RayTracingFeatureProcessor (derived from FeatureProcessor) with the serialization system.
// Only a SerializeContext is handled; any other kind of reflection context is ignored.
void RayTracingFeatureProcessor::Reflect(ReflectContext* context)
{
    auto* serialize = azrtti_cast<SerializeContext*>(context);
    if (serialize == nullptr)
    {
        return;
    }

    // version 0, no ->Field() entries are registered for this class
    serialize->Class<RayTracingFeatureProcessor, FeatureProcessor>()->Version(0);
}
void RayTracingFeatureProcessor::Activate()
{
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
m_rayTracingEnabled = device->GetFeatures().m_rayTracing;
if (!m_rayTracingEnabled)
{
return;
}
m_transformServiceFeatureProcessor = GetParentScene()->GetFeatureProcessor<TransformServiceFeatureProcessor>();
// initialize the ray tracing buffer pools
m_bufferPools = RHI::RayTracingBufferPools::CreateRHIRayTracingBufferPools();
m_bufferPools->Init(device);
// create TLAS attachmentId
AZStd::string uuidString = AZ::Uuid::CreateRandom().ToString<AZStd::string>();
m_tlasAttachmentId = RHI::AttachmentId(AZStd::string::format("RayTracingTlasAttachmentId_%s", uuidString.c_str()));
// create the TLAS object
m_tlas = AZ::RHI::RayTracingTlas::CreateRHIRayTracingTlas();
// load the RayTracingSceneSrg asset
Data::Asset<RPI::ShaderResourceGroupAsset> rayTracingSceneSrgAsset =
RPI::AssetUtils::LoadAssetByProductPath<RPI::ShaderResourceGroupAsset>("shaderlib/raytracingscenesrg_raytracingscenesrg.azsrg", RPI::AssetUtils::TraceLevel::Error);
AZ_Assert(rayTracingSceneSrgAsset.IsReady(), "Failed to load RayTracingSceneSrg asset");
m_rayTracingSceneSrg = RPI::ShaderResourceGroup::Create(rayTracingSceneSrgAsset);
}
void RayTracingFeatureProcessor::SetMesh(const ObjectId objectId, const SubMeshVector& subMeshes)
{
if (!m_rayTracingEnabled)
{
return;
}
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
uint32_t objectIndex = objectId.GetIndex();
MeshMap::iterator itMesh = m_meshes.find(objectIndex);
if (itMesh == m_meshes.end())
{
m_meshes.insert(AZStd::make_pair(objectIndex, Mesh{ subMeshes }));
}
else
{
// updating an existing entry
// decrement the mesh count by the number of meshes in the existing entry in case the number of meshes changed
m_subMeshCount -= aznumeric_cast<uint32_t>(itMesh->second.m_subMeshes.size());
m_meshes[objectIndex].m_subMeshes = subMeshes;
}
// create the BLAS buffers for each sub-mesh
// Note: the buffer is just reserved here, the BLAS is built in the RayTracingAccelerationStructurePass
Mesh& mesh = m_meshes[objectIndex];
for (auto& subMesh : mesh.m_subMeshes)
{
RHI::RayTracingBlasDescriptor blasDescriptor;
blasDescriptor.Build()
->Geometry()
->VertexFormat(subMesh.m_vertexFormat)
->VertexBuffer(subMesh.m_positionVertexBufferView)
->IndexBuffer(subMesh.m_indexBufferView)
;
// create the BLAS object
subMesh.m_blas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas();
// create the buffers from the descriptor
subMesh.m_blas->CreateBuffers(*device, &blasDescriptor, *m_bufferPools);
}
// set initial transform
mesh.m_transform = m_transformServiceFeatureProcessor->GetTransformForId(objectId);
m_revision++;
m_subMeshCount += aznumeric_cast<uint32_t>(subMeshes.size());
m_meshInfoBufferNeedsUpdate = true;
}
void RayTracingFeatureProcessor::RemoveMesh(const ObjectId objectId)
{
if (!m_rayTracingEnabled)
{
return;
}
MeshMap::iterator itMesh = m_meshes.find(objectId.GetIndex());
if (itMesh != m_meshes.end())
{
m_subMeshCount -= aznumeric_cast<uint32_t>(itMesh->second.m_subMeshes.size());
m_meshes.erase(itMesh);
m_revision++;
}
m_meshInfoBufferNeedsUpdate = true;
}
void RayTracingFeatureProcessor::SetMeshTransform(const ObjectId objectId, AZ::Transform transform)
{
if (!m_rayTracingEnabled)
{
return;
}
MeshMap::iterator itMesh = m_meshes.find(objectId.GetIndex());
if (itMesh != m_meshes.end())
{
itMesh->second.m_transform = transform;
m_revision++;
}
m_meshInfoBufferNeedsUpdate = true;
}
void RayTracingFeatureProcessor::UpdateRayTracingSceneSrg()
{
if (!m_tlas->GetTlasBuffer())
{
return;
}
if (m_rayTracingSceneSrg->IsQueuedForCompile())
{
//[GFX TODO][ATOM-14792] AtomSampleViewer: Reset scene and feature processors before switching to sample
return;
}
// update the mesh info buffer with the latest ray tracing enabled meshes
UpdateMeshInfoBuffer();
// update the RayTracingSceneSrg
const RHI::ShaderResourceGroupLayout* srgLayout = m_rayTracingSceneSrg->GetLayout();
RHI::ShaderInputImageIndex imageIndex;
RHI::ShaderInputBufferIndex bufferIndex;
RHI::ShaderInputConstantIndex constantIndex;
// TLAS
uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_tlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_scene"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_tlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get());
// directional lights
const auto directionalLightFP = GetParentScene()->GetFeatureProcessor<DirectionalLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_directionalLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, directionalLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_directionalLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, directionalLightFP->GetLightCount());
// spot lights
const auto spotLightFP = GetParentScene()->GetFeatureProcessor<SpotLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_spotLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, spotLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_spotLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, spotLightFP->GetLightCount());
// point lights
const auto pointLightFP = GetParentScene()->GetFeatureProcessor<PointLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_pointLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, pointLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_pointLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, pointLightFP->GetLightCount());
// disk lights
const auto diskLightFP = GetParentScene()->GetFeatureProcessor<DiskLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_diskLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, diskLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_diskLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, diskLightFP->GetLightCount());
// capsule lights
const auto capsuleLightFP = GetParentScene()->GetFeatureProcessor<CapsuleLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_capsuleLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, capsuleLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_capsuleLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, capsuleLightFP->GetLightCount());
// quad lights
const auto quadLightFP = GetParentScene()->GetFeatureProcessor<QuadLightFeatureProcessor>();
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_quadLights"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, quadLightFP->GetLightBuffer()->GetBufferView());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_quadLightCount"));
m_rayTracingSceneSrg->SetConstant(constantIndex, quadLightFP->GetLightCount());
// diffuse environment map for sky hits
ImageBasedLightFeatureProcessor* imageBasedLightFeatureProcessor = GetParentScene()->GetFeatureProcessor<ImageBasedLightFeatureProcessor>();
if (imageBasedLightFeatureProcessor)
{
imageIndex = srgLayout->FindShaderInputImageIndex(AZ::Name("m_diffuseEnvMap"));
m_rayTracingSceneSrg->SetImage(imageIndex, imageBasedLightFeatureProcessor->GetDiffuseImage());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblOrientation"));
m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetOrientation());
constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_iblExposure"));
m_rayTracingSceneSrg->SetConstant(constantIndex, imageBasedLightFeatureProcessor->GetExposure());
}
bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_meshInfo"));
m_rayTracingSceneSrg->SetBufferView(bufferIndex, m_meshInfoBuffer->GetBufferView());
if (m_subMeshCount)
{
AZStd::vector<const RHI::BufferView*> meshBuffers;
for (const auto& mesh : m_meshes)
{
const SubMeshVector& subMeshes = mesh.second.m_subMeshes;
for (const auto& subMesh : subMeshes)
{
// add the index, position, and normal buffers for this sub-mesh to the mesh buffer list, this will
// go into the shader as an unbounded array in the Srg
meshBuffers.push_back(subMesh.m_indexShaderBufferView.get());
meshBuffers.push_back(subMesh.m_positionShaderBufferView.get());
meshBuffers.push_back(subMesh.m_normalShaderBufferView.get());
}
}
RHI::ShaderInputBufferUnboundedArrayIndex bufferUnboundedArrayIndex = srgLayout->FindShaderInputBufferUnboundedArrayIndex(AZ::Name("m_meshBuffers"));
m_rayTracingSceneSrg->SetBufferViewUnboundedArray(bufferUnboundedArrayIndex, meshBuffers);
}
m_rayTracingSceneSrg->Compile();
}
void RayTracingFeatureProcessor::UpdateMeshInfoBuffer()
{
if (m_meshInfoBufferNeedsUpdate && (m_subMeshCount > 0))
{
TransformServiceFeatureProcessor* transformFeatureProcessor = GetParentScene()->GetFeatureProcessor<TransformServiceFeatureProcessor>();
AZStd::vector<MeshInfo> meshInfos;
meshInfos.reserve(m_subMeshCount);
uint32_t newMeshByteCount = m_subMeshCount * sizeof(MeshInfo);
if (m_meshInfoBuffer == nullptr)
{
// allocate the MeshInfo structured buffer
RPI::CommonBufferDescriptor desc;
desc.m_poolType = RPI::CommonBufferPoolType::ReadOnly;
desc.m_bufferName = "RayTracingMeshInfo";
desc.m_byteCount = newMeshByteCount;
desc.m_elementSize = sizeof(MeshInfo);
m_meshInfoBuffer = RPI::BufferSystemInterface::Get()->CreateBufferFromCommonPool(desc);
}
else if (m_meshInfoBuffer->GetBufferSize() < newMeshByteCount)
{
// resize for the new sub-mesh count
m_meshInfoBuffer->Resize(newMeshByteCount);
}
for (const auto& mesh : m_meshes)
{
AZ::Transform meshTransform = transformFeatureProcessor->GetTransformForId(TransformServiceFeatureProcessorInterface::ObjectId(mesh.first));
AZ::Transform noScaleTransform = meshTransform;
noScaleTransform.ExtractScale();
AZ::Matrix3x3 rotationMatrix = Matrix3x3::CreateFromTransform(noScaleTransform);
rotationMatrix = rotationMatrix.GetInverseFull().GetTranspose();
const RayTracingFeatureProcessor::SubMeshVector& subMeshes = mesh.second.m_subMeshes;
for (const auto& subMesh : subMeshes)
{
MeshInfo meshInfo;
meshInfo.m_indexOffset = subMesh.m_indexBufferView.GetByteOffset();
meshInfo.m_positionOffset = subMesh.m_positionVertexBufferView.GetByteOffset();
meshInfo.m_normalOffset = subMesh.m_normalVertexBufferView.GetByteOffset();
subMesh.m_irradianceColor.StoreToFloat4(meshInfo.m_irradianceColor.data());
rotationMatrix.StoreToRowMajorFloat9(meshInfo.m_worldInvTranspose.data());
meshInfos.emplace_back(meshInfo);
}
}
m_meshInfoBuffer->UpdateData(meshInfos.data(), newMeshByteCount);
m_meshInfoBufferNeedsUpdate = false;
}
}
}
}