Merge pull request #946 from aws-lumberyard-dev/Atom/dmcdiar/ATOM-15555
[ATOM-15555] Data-Driven RayTracingPassmain
commit
b20f3bc72f
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
|
||||
#include <Atom/Features/SrgSemantics.azsli>
|
||||
|
||||
ShaderResourceGroup RayTracingMaterialSrg : SRG_RayTracingMaterial
|
||||
{
|
||||
Sampler LinearSampler
|
||||
{
|
||||
AddressU = Wrap;
|
||||
AddressV = Wrap;
|
||||
MinFilter = Linear;
|
||||
MagFilter = Linear;
|
||||
MipFilter = Linear;
|
||||
MaxAnisotropy = 16;
|
||||
};
|
||||
|
||||
// material info structured buffer
|
||||
struct MaterialInfo
|
||||
{
|
||||
float4 m_baseColor;
|
||||
float m_metallicFactor;
|
||||
float m_roughnessFactor;
|
||||
uint m_textureFlags;
|
||||
uint m_textureStartIndex;
|
||||
};
|
||||
|
||||
// hit shaders can retrieve the MaterialInfo for a mesh hit using: RayTracingMaterialSrg::m_materialInfo[InstanceIndex()]
|
||||
StructuredBuffer<MaterialInfo> m_materialInfo;
|
||||
|
||||
// texture flag bits indicating if optional textures are present
|
||||
#define TEXTURE_FLAG_BASECOLOR 1
|
||||
#define TEXTURE_FLAG_NORMAL 2
|
||||
#define TEXTURE_FLAG_METALLIC 4
|
||||
#define TEXTURE_FLAG_ROUGHNESS 8
|
||||
|
||||
// unbounded array of Material textures
|
||||
Texture2D m_materialTextures[];
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
|
||||
struct TextureData
|
||||
{
|
||||
float4 m_baseColor;
|
||||
float3 m_normal;
|
||||
float m_metallic;
|
||||
float m_roughness;
|
||||
};
|
||||
|
||||
TextureData GetHitTextureData(RayTracingMaterialSrg::MaterialInfo materialInfo, float2 uv)
|
||||
{
|
||||
TextureData textureData = (TextureData)0;
|
||||
|
||||
uint textureIndex = materialInfo.m_textureStartIndex;
|
||||
|
||||
// base color
|
||||
if (materialInfo.m_textureFlags & TEXTURE_FLAG_BASECOLOR)
|
||||
{
|
||||
textureData.m_baseColor = RayTracingMaterialSrg::m_materialTextures[textureIndex++].SampleLevel(RayTracingMaterialSrg::LinearSampler, uv, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
textureData.m_baseColor = materialInfo.m_baseColor;
|
||||
}
|
||||
|
||||
// normal
|
||||
if (materialInfo.m_textureFlags & TEXTURE_FLAG_NORMAL)
|
||||
{
|
||||
textureData.m_normal = RayTracingMaterialSrg::m_materialTextures[textureIndex++].SampleLevel(RayTracingMaterialSrg::LinearSampler, uv, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
textureData.m_normal = float3(0.0f, 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
// metallic
|
||||
if (materialInfo.m_textureFlags & TEXTURE_FLAG_METALLIC)
|
||||
{
|
||||
textureData.m_metallic = RayTracingMaterialSrg::m_materialTextures[textureIndex++].SampleLevel(RayTracingMaterialSrg::LinearSampler, uv, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
textureData.m_metallic = materialInfo.m_metallicFactor;
|
||||
}
|
||||
|
||||
// roughness
|
||||
if (materialInfo.m_textureFlags & TEXTURE_FLAG_ROUGHNESS)
|
||||
{
|
||||
textureData.m_roughness = RayTracingMaterialSrg::m_materialTextures[textureIndex++].SampleLevel(RayTracingMaterialSrg::LinearSampler, uv, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
textureData.m_roughness = materialInfo.m_roughnessFactor;
|
||||
}
|
||||
|
||||
return textureData;
|
||||
}
|
||||
|
||||
@ -0,0 +1,126 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
|
||||
// returns the normalized camera view ray into the scene for this raytracing dispatch thread
|
||||
float3 GetViewRayDirection(float4x4 viewProjectionInverseMatrix)
|
||||
{
|
||||
float2 pixel = ((float2)DispatchRaysIndex().xy + float2(0.5f, 0.5f)) / (float2)DispatchRaysDimensions();
|
||||
float2 ndc = pixel * float2(2.0f, -2.0f) + float2(-1.0f, 1.0f);
|
||||
return normalize(mul(viewProjectionInverseMatrix, float4(ndc, 0.0f, 1.0f)).xyz);
|
||||
}
|
||||
|
||||
// returns the vertex indices for the primitive hit by the ray
|
||||
// Note: usable only in a raytracing Hit shader
|
||||
uint3 GetHitIndices(RayTracingSceneSrg::MeshInfo meshInfo)
|
||||
{
|
||||
// compute the array index of the index buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshIndexBufferArrayIndex = meshInfo.m_bufferStartIndex + MESH_INDEX_BUFFER_OFFSET;
|
||||
|
||||
// compute the offset into the index buffer for this primitve of the mesh
|
||||
uint offsetBytes = meshInfo.m_indexOffset + (PrimitiveIndex() * 12);
|
||||
|
||||
// load the indices for this primitive from the index buffer
|
||||
return RayTracingSceneSrg::m_meshBuffers[meshIndexBufferArrayIndex].Load3(offsetBytes);
|
||||
}
|
||||
|
||||
// returns the interpolated vertex data for the primitive hit by the ray
|
||||
// Note: usable only in a raytracing hit shader
|
||||
struct VertexData
|
||||
{
|
||||
float3 m_position;
|
||||
float3 m_normal;
|
||||
float3 m_tangent;
|
||||
float3 m_bitangent;
|
||||
float2 m_uv;
|
||||
};
|
||||
|
||||
VertexData GetHitInterpolatedVertexData(RayTracingSceneSrg::MeshInfo meshInfo, float2 builtInBarycentrics)
|
||||
{
|
||||
// retrieve the poly indices
|
||||
uint3 indices = GetHitIndices(meshInfo);
|
||||
|
||||
// compute barycentrics
|
||||
float3 barycentrics = float3((1.0f - builtInBarycentrics.x - builtInBarycentrics.y), builtInBarycentrics.x, builtInBarycentrics.y);
|
||||
|
||||
// compute the vertex data using barycentric interpolation
|
||||
VertexData vertexData = (VertexData)0;
|
||||
for (uint i = 0; i < 3; ++i)
|
||||
{
|
||||
// position
|
||||
{
|
||||
// array index of the position buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshVertexPositionArrayIndex = meshInfo.m_bufferStartIndex + MESH_POSITION_BUFFER_OFFSET;
|
||||
|
||||
// offset into the position buffer for this vertex
|
||||
uint positionOffset = meshInfo.m_positionOffset + (indices[i] * 12);
|
||||
|
||||
// load the position data
|
||||
vertexData.m_position += asfloat(RayTracingSceneSrg::m_meshBuffers[meshVertexPositionArrayIndex].Load3(positionOffset)) * barycentrics[i];
|
||||
}
|
||||
|
||||
// normal
|
||||
{
|
||||
// array index of the normal buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshVertexNormalArrayIndex = meshInfo.m_bufferStartIndex + MESH_NORMAL_BUFFER_OFFSET;
|
||||
|
||||
// offset into the normal buffer for this vertex
|
||||
uint normalOffset = meshInfo.m_normalOffset + (indices[i] * 12);
|
||||
|
||||
// load the normal data
|
||||
vertexData.m_normal += asfloat(RayTracingSceneSrg::m_meshBuffers[meshVertexNormalArrayIndex].Load3(normalOffset)) * barycentrics[i];
|
||||
}
|
||||
|
||||
// tangent
|
||||
{
|
||||
// array index of the tangent buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshVertexTangentArrayIndex = meshInfo.m_bufferStartIndex + MESH_TANGENT_BUFFER_OFFSET;
|
||||
|
||||
// offset into the tangent buffer for this vertex
|
||||
uint tangentOffset = meshInfo.m_tangentOffset + (indices[i] * 12);
|
||||
|
||||
// load the tangent data
|
||||
vertexData.m_tangent += asfloat(RayTracingSceneSrg::m_meshBuffers[meshVertexTangentArrayIndex].Load3(tangentOffset)) * barycentrics[i];
|
||||
}
|
||||
|
||||
// bitangent
|
||||
{
|
||||
// array index of the bitangent buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshVertexBitangentArrayIndex = meshInfo.m_bufferStartIndex + MESH_BITANGENT_BUFFER_OFFSET;
|
||||
|
||||
// offset into the bitangent buffer for this vertex
|
||||
uint bitangentOffset = meshInfo.m_bitangentOffset + (indices[i] * 12);
|
||||
|
||||
// load the bitangent data
|
||||
vertexData.m_bitangent += asfloat(RayTracingSceneSrg::m_meshBuffers[meshVertexBitangentArrayIndex].Load3(bitangentOffset)) * barycentrics[i];
|
||||
}
|
||||
|
||||
// optional streams begin after MESH_BITANGENT_BUFFER_OFFSET
|
||||
uint optionalBufferOffset = MESH_BITANGENT_BUFFER_OFFSET + 1;
|
||||
|
||||
// UV
|
||||
if (meshInfo.m_bufferFlags & MESH_BUFFER_FLAG_UV)
|
||||
{
|
||||
// array index of the UV buffer for this mesh in the m_meshBuffers unbounded array
|
||||
uint meshVertexUVArrayIndex = meshInfo.m_bufferStartIndex + optionalBufferOffset++;
|
||||
|
||||
// offset into the UV buffer for this vertex
|
||||
uint uvOffset = meshInfo.m_uvOffset + (indices[i] * 8);
|
||||
|
||||
// load the UV data
|
||||
vertexData.m_uv += asfloat(RayTracingSceneSrg::m_meshBuffers[meshVertexUVArrayIndex].Load2(uvOffset)) * barycentrics[i];
|
||||
}
|
||||
}
|
||||
|
||||
vertexData.m_normal = normalize(vertexData.m_normal);
|
||||
|
||||
return vertexData;
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,362 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
|
||||
#include <AzCore/Asset/AssetCommon.h>
|
||||
#include <AzCore/Asset/AssetManagerBus.h>
|
||||
#include <Atom/RHI/CommandList.h>
|
||||
#include <Atom/RHI/Factory.h>
|
||||
#include <Atom/RHI/FrameScheduler.h>
|
||||
#include <Atom/RHI/DispatchRaysItem.h>
|
||||
#include <Atom/RHI/RHISystemInterface.h>
|
||||
#include <Atom/RHI/PipelineState.h>
|
||||
#include <Atom/RPI.Reflect/Pass/PassTemplate.h>
|
||||
#include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
|
||||
#include <Atom/RPI.Public/Base.h>
|
||||
#include <Atom/RPI.Public/Pass/PassUtils.h>
|
||||
#include <Atom/RPI.Public/RPIUtils.h>
|
||||
#include <Atom/RPI.Public/RenderPipeline.h>
|
||||
#include <Atom/RPI.Public/Scene.h>
|
||||
#include <Atom/RPI.Public/View.h>
|
||||
#include <RayTracing/RayTracingPass.h>
|
||||
#include <RayTracing/RayTracingPassData.h>
|
||||
#include <RayTracing/RayTracingFeatureProcessor.h>
|
||||
|
||||
namespace AZ
|
||||
{
|
||||
namespace Render
|
||||
{
|
||||
RPI::Ptr<RayTracingPass> RayTracingPass::Create(const RPI::PassDescriptor& descriptor)
|
||||
{
|
||||
RPI::Ptr<RayTracingPass> pass = aznew RayTracingPass(descriptor);
|
||||
return pass;
|
||||
}
|
||||
|
||||
RayTracingPass::RayTracingPass(const RPI::PassDescriptor& descriptor)
|
||||
: RenderPass(descriptor)
|
||||
, m_passDescriptor(descriptor)
|
||||
{
|
||||
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
|
||||
if (device->GetFeatures().m_rayTracing == false)
|
||||
{
|
||||
// raytracing is not supported on this platform
|
||||
SetEnabled(false);
|
||||
return;
|
||||
}
|
||||
|
||||
Init();
|
||||
}
|
||||
|
||||
RayTracingPass::~RayTracingPass()
|
||||
{
|
||||
RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
|
||||
}
|
||||
|
||||
void RayTracingPass::Init()
|
||||
{
|
||||
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
|
||||
|
||||
m_passData = RPI::PassUtils::GetPassData<RayTracingPassData>(m_passDescriptor);
|
||||
if (m_passData == nullptr)
|
||||
{
|
||||
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Invalid RayTracingPassData", GetPathName().GetCStr());
|
||||
return;
|
||||
}
|
||||
|
||||
// ray generation shader
|
||||
m_rayGenerationShader = LoadShader(m_passData->m_rayGenerationShaderAssetReference);
|
||||
if (m_rayGenerationShader == nullptr)
|
||||
{
|
||||
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load RayGeneration shader [%s]", GetPathName().GetCStr(), m_passData->m_rayGenerationShaderAssetReference.m_filePath.data());
|
||||
return;
|
||||
}
|
||||
|
||||
auto shaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
|
||||
RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
|
||||
shaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
|
||||
|
||||
// closest hit shader
|
||||
m_closestHitShader = LoadShader(m_passData->m_closestHitShaderAssetReference);
|
||||
if (m_closestHitShader == nullptr)
|
||||
{
|
||||
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load ClosestHit shader [%s]", GetPathName().GetCStr(), m_passData->m_closestHitShaderAssetReference.m_filePath.data());
|
||||
return;
|
||||
}
|
||||
|
||||
shaderVariant = m_closestHitShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
|
||||
RHI::PipelineStateDescriptorForRayTracing closestHitShaderDescriptor;
|
||||
shaderVariant.ConfigurePipelineState(closestHitShaderDescriptor);
|
||||
|
||||
// miss shader
|
||||
m_missShader = LoadShader(m_passData->m_missShaderAssetReference);
|
||||
if (m_missShader == nullptr)
|
||||
{
|
||||
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load Miss shader [%s]", GetPathName().GetCStr(), m_passData->m_missShaderAssetReference.m_filePath.data());
|
||||
return;
|
||||
}
|
||||
|
||||
shaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
|
||||
RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
|
||||
shaderVariant.ConfigurePipelineState(missShaderDescriptor);
|
||||
|
||||
// retrieve global pipeline state
|
||||
m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
|
||||
AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
|
||||
|
||||
// create global srg
|
||||
Data::Asset<RPI::ShaderResourceGroupAsset> globalSrgAsset = m_rayGenerationShader->FindShaderResourceGroupAsset(RayTracingGlobalSrgBindingSlot);
|
||||
AZ_Error("PassSystem", globalSrgAsset.GetId().IsValid(), "RayTracingPass [%s] Failed to find RayTracingGlobalSrg asset", GetPathName().GetCStr());
|
||||
AZ_Error("PassSystem", globalSrgAsset.IsReady(), "RayTracingPass [%s] asset is not loaded for shader", GetPathName().GetCStr());
|
||||
|
||||
m_shaderResourceGroup = RPI::ShaderResourceGroup::Create(globalSrgAsset);
|
||||
AZ_Assert(m_shaderResourceGroup, "RayTracingPass [%s]: Failed to create RayTracingGlobalSrg", GetPathName().GetCStr());
|
||||
RPI::PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
|
||||
|
||||
// check to see if the shader requires the View and RayTracingMaterial Srgs
|
||||
Data::Asset<RPI::ShaderResourceGroupAsset> viewSrgAsset = m_rayGenerationShader->FindShaderResourceGroupAsset(RPI::SrgBindingSlot::View);
|
||||
m_requiresViewSrg = viewSrgAsset.GetId().IsValid();
|
||||
|
||||
Data::Asset<RPI::ShaderResourceGroupAsset> rayTracingMaterialSrgAsset = m_rayGenerationShader->FindShaderResourceGroupAsset(RayTracingMaterialSrgBindingSlot);
|
||||
m_requiresRayTracingMaterialSrg = rayTracingMaterialSrgAsset.GetId().IsValid();
|
||||
|
||||
// build the ray tracing pipeline state descriptor
|
||||
RHI::RayTracingPipelineStateDescriptor descriptor;
|
||||
descriptor.Build()
|
||||
->PipelineState(m_globalPipelineState.get())
|
||||
->MaxPayloadSize(m_passData->m_maxPayloadSize)
|
||||
->MaxAttributeSize(m_passData->m_maxAttributeSize)
|
||||
->MaxRecursionDepth(m_passData->m_maxRecursionDepth)
|
||||
->ShaderLibrary(rayGenerationShaderDescriptor)
|
||||
->RayGenerationShaderName(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
|
||||
->ShaderLibrary(missShaderDescriptor)
|
||||
->MissShaderName(AZ::Name(m_passData->m_missShaderName.c_str()))
|
||||
->ShaderLibrary(closestHitShaderDescriptor)
|
||||
->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()))
|
||||
->HitGroup(AZ::Name("HitGroup"))
|
||||
->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()));
|
||||
|
||||
// create the ray tracing pipeline state object
|
||||
m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
|
||||
m_rayTracingPipelineState->Init(*device.get(), &descriptor);
|
||||
|
||||
// make sure the shader table rebuilds if we're hotreloading
|
||||
m_rayTracingRevision = 0;
|
||||
|
||||
RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
|
||||
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_rayGenerationShaderAssetReference.m_assetId);
|
||||
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitShaderAssetReference.m_assetId);
|
||||
RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_missShaderAssetReference.m_assetId);
|
||||
}
|
||||
|
||||
Data::Instance<RPI::Shader> RayTracingPass::LoadShader(const RPI::AssetReference& shaderAssetReference)
|
||||
{
|
||||
Data::Asset<RPI::ShaderAsset> shaderAsset;
|
||||
if (shaderAssetReference.m_assetId.IsValid())
|
||||
{
|
||||
shaderAsset = RPI::FindShaderAsset(shaderAssetReference.m_assetId, shaderAssetReference.m_filePath);
|
||||
}
|
||||
|
||||
if (!shaderAsset.GetId().IsValid())
|
||||
{
|
||||
AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load shader asset [%s]", GetPathName().GetCStr(), shaderAssetReference.m_filePath.data());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return RPI::Shader::FindOrCreate(shaderAsset);
|
||||
}
|
||||
|
||||
void RayTracingPass::FrameBeginInternal(FramePrepareParams params)
|
||||
{
|
||||
RPI::Scene* scene = m_pipeline->GetScene();
|
||||
RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
|
||||
if (!rayTracingFeatureProcessor)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_rayTracingShaderTable)
|
||||
{
|
||||
RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
|
||||
RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
|
||||
|
||||
m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
|
||||
m_rayTracingShaderTable->Init(*device.get(), rayTracingBufferPools);
|
||||
}
|
||||
|
||||
RPI::RenderPass::FrameBeginInternal(params);
|
||||
}
|
||||
|
||||
void RayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
|
||||
{
|
||||
RPI::Scene* scene = m_pipeline->GetScene();
|
||||
RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
|
||||
AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
|
||||
|
||||
RPI::RenderPass::SetupFrameGraphDependencies(frameGraph);
|
||||
frameGraph.SetEstimatedItemCount(1);
|
||||
|
||||
// TLAS
|
||||
{
|
||||
const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
|
||||
if (rayTracingTlasBuffer)
|
||||
{
|
||||
AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
|
||||
if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
|
||||
{
|
||||
[[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
|
||||
AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
|
||||
}
|
||||
|
||||
uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer()->GetDescriptor().m_byteCount);
|
||||
RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, tlasBufferByteCount);
|
||||
|
||||
RHI::BufferScopeAttachmentDescriptor desc;
|
||||
desc.m_attachmentId = tlasAttachmentId;
|
||||
desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
|
||||
desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
|
||||
|
||||
frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RayTracingPass::CompileResources(const RHI::FrameGraphCompileContext& context)
|
||||
{
|
||||
RPI::Scene* scene = m_pipeline->GetScene();
|
||||
RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
|
||||
AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
|
||||
|
||||
if (m_shaderResourceGroup != nullptr)
|
||||
{
|
||||
BindPassSrg(context, m_shaderResourceGroup);
|
||||
m_shaderResourceGroup->Compile();
|
||||
}
|
||||
|
||||
uint32_t rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
|
||||
if (m_rayTracingRevision != rayTracingRevision)
|
||||
{
|
||||
// scene changed, need to rebuild the shader table
|
||||
m_rayTracingRevision = rayTracingRevision;
|
||||
|
||||
AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
|
||||
|
||||
if (rayTracingFeatureProcessor->GetSubMeshCount())
|
||||
{
|
||||
// build the ray tracing shader table descriptor
|
||||
RHI::RayTracingShaderTableDescriptor* descriptorBuild = descriptor->Build(AZ::Name("RayTracingShaderTable"), m_rayTracingPipelineState)
|
||||
->RayGenerationRecord(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
|
||||
->MissRecord(AZ::Name(m_passData->m_missShaderName.c_str()));
|
||||
|
||||
// add a hit group for each mesh to the shader table
|
||||
for (uint32_t i = 0; i < rayTracingFeatureProcessor->GetSubMeshCount(); ++i)
|
||||
{
|
||||
descriptorBuild->HitGroupRecord(AZ::Name("HitGroup"));
|
||||
}
|
||||
}
|
||||
|
||||
m_rayTracingShaderTable->Build(descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
void RayTracingPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
|
||||
{
|
||||
RPI::Scene* scene = m_pipeline->GetScene();
|
||||
RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
|
||||
AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
|
||||
|
||||
if (!rayTracingFeatureProcessor ||
|
||||
!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer() ||
|
||||
!rayTracingFeatureProcessor->GetSubMeshCount() ||
|
||||
!m_rayTracingShaderTable)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
RHI::DispatchRaysItem dispatchRaysItem;
|
||||
|
||||
// calculate thread counts if this is a full screen raytracing pass
|
||||
if (m_passData->m_makeFullscreenPass)
|
||||
{
|
||||
RPI::PassAttachment* outputAttachment = nullptr;
|
||||
|
||||
if (GetOutputCount() > 0)
|
||||
{
|
||||
outputAttachment = GetOutputBinding(0).m_attachment.get();
|
||||
}
|
||||
else if (GetInputOutputCount() > 0)
|
||||
{
|
||||
outputAttachment = GetInputOutputBinding(0).m_attachment.get();
|
||||
}
|
||||
|
||||
AZ_Assert(outputAttachment != nullptr, "[RayTracingPass '%s']: A fullscreen RayTracing pass must have a valid output or input/output.", GetPathName().GetCStr());
|
||||
AZ_Assert(outputAttachment->GetAttachmentType() == RHI::AttachmentType::Image, "[RayTracingPass '%s']: The output of a fullscreen RayTracing pass must be an image.", GetPathName().GetCStr());
|
||||
|
||||
RHI::Size imageSize = outputAttachment->m_descriptor.m_image.m_size;
|
||||
|
||||
dispatchRaysItem.m_width = imageSize.m_width;
|
||||
dispatchRaysItem.m_height = imageSize.m_height;
|
||||
dispatchRaysItem.m_depth = imageSize.m_depth;
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatchRaysItem.m_width = m_passData->m_threadCountX;
|
||||
dispatchRaysItem.m_height = m_passData->m_threadCountY;
|
||||
dispatchRaysItem.m_depth = m_passData->m_threadCountZ;
|
||||
}
|
||||
|
||||
// bind RayTracingGlobal, RayTracingScene, and View Srgs
|
||||
// [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
|
||||
AZStd::vector<RHI::ShaderResourceGroup*> shaderResourceGroups =
|
||||
{
|
||||
m_shaderResourceGroup->GetRHIShaderResourceGroup(),
|
||||
rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()
|
||||
};
|
||||
|
||||
if (m_requiresViewSrg)
|
||||
{
|
||||
const AZStd::vector<RPI::ViewPtr>& views = m_pipeline->GetViews(m_passData->m_pipelineViewTag);
|
||||
if (views.size() > 0)
|
||||
{
|
||||
shaderResourceGroups.push_back(views[0]->GetRHIShaderResourceGroup());
|
||||
}
|
||||
}
|
||||
|
||||
if (m_requiresRayTracingMaterialSrg)
|
||||
{
|
||||
shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup());
|
||||
}
|
||||
|
||||
dispatchRaysItem.m_shaderResourceGroupCount = aznumeric_cast<uint32_t>(shaderResourceGroups.size());
|
||||
dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups.data();
|
||||
dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
|
||||
dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
|
||||
dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();
|
||||
|
||||
// submit the DispatchRays item
|
||||
context.GetCommandList()->Submit(dispatchRaysItem);
|
||||
}
|
||||
|
||||
void RayTracingPass::OnShaderReinitialized([[maybe_unused]] const RPI::Shader& shader)
|
||||
{
|
||||
Init();
|
||||
}
|
||||
|
||||
void RayTracingPass::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset<RPI::ShaderAsset>& shaderAsset)
|
||||
{
|
||||
Init();
|
||||
}
|
||||
|
||||
void RayTracingPass::OnShaderVariantReinitialized([[maybe_unused]] const RPI::Shader& shader, [[maybe_unused]] const RPI::ShaderVariantId& shaderVariantId, [[maybe_unused]] RPI::ShaderVariantStableId shaderVariantStableId)
|
||||
{
|
||||
Init();
|
||||
}
|
||||
} // namespace RPI
|
||||
} // namespace AZ
|
||||
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <AzCore/Memory/SystemAllocator.h>
|
||||
#include <Atom/RHI/RayTracingPipelineState.h>
|
||||
#include <Atom/RHI/RayTracingShaderTable.h>
|
||||
#include <Atom/RPI.Public/Pass/RenderPass.h>
|
||||
#include <Atom/RPI.Public/Shader/Shader.h>
|
||||
#include <Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h>
|
||||
|
||||
namespace AZ
|
||||
{
|
||||
namespace Render
|
||||
{
|
||||
struct RayTracingPassData;
|
||||
|
||||
//! This pass executes a raytracing shader as specified in the PassData.
|
||||
class RayTracingPass
|
||||
: public RPI::RenderPass
|
||||
, private RPI::ShaderReloadNotificationBus::MultiHandler
|
||||
{
|
||||
AZ_RPI_PASS(RayTracingPass);
|
||||
|
||||
public:
|
||||
AZ_RTTI(RayTracingPass, "{7A68A36E-956A-4258-93FE-38686042C4D9}", RPI::RenderPass);
|
||||
AZ_CLASS_ALLOCATOR(RayTracingPass, SystemAllocator, 0);
|
||||
virtual ~RayTracingPass();
|
||||
|
||||
//! Creates a RayTracingPass
|
||||
static RPI::Ptr<RayTracingPass> Create(const RPI::PassDescriptor& descriptor);
|
||||
|
||||
protected:
|
||||
RayTracingPass(const RPI::PassDescriptor& descriptor);
|
||||
|
||||
// Pass overrides
|
||||
void FrameBeginInternal(FramePrepareParams params) override;
|
||||
|
||||
// Scope producer functions
|
||||
void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
|
||||
void CompileResources(const RHI::FrameGraphCompileContext& context) override;
|
||||
void BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context) override;
|
||||
|
||||
// ShaderReloadNotificationBus::Handler overrides
|
||||
void OnShaderReinitialized(const RPI::Shader& shader) override;
|
||||
void OnShaderAssetReinitialized(const Data::Asset<RPI::ShaderAsset>& shaderAsset) override;
|
||||
void OnShaderVariantReinitialized(const RPI::Shader& shader, const RPI::ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId shaderVariantStableId) override;
|
||||
|
||||
// load the raytracing shaders and setup pipeline states
|
||||
void Init();
|
||||
|
||||
// helper for loading a shader from a shader asset reference
|
||||
Data::Instance<RPI::Shader> LoadShader(const RPI::AssetReference& shaderAssetReference);
|
||||
|
||||
// pass data
|
||||
RPI::PassDescriptor m_passDescriptor;
|
||||
const RayTracingPassData* m_passData = nullptr;
|
||||
|
||||
// revision number of the ray tracing TLAS when the shader table was built
|
||||
uint32_t m_rayTracingRevision = 0;
|
||||
|
||||
// raytracing shaders, pipeline states, and shader table
|
||||
Data::Instance<RPI::Shader> m_rayGenerationShader;
|
||||
Data::Instance<RPI::Shader> m_missShader;
|
||||
Data::Instance<RPI::Shader> m_closestHitShader;
|
||||
RHI::Ptr<RHI::RayTracingPipelineState> m_rayTracingPipelineState;
|
||||
RHI::ConstPtr<RHI::PipelineState> m_globalPipelineState;
|
||||
RHI::Ptr<RHI::RayTracingShaderTable> m_rayTracingShaderTable;
|
||||
bool m_requiresViewSrg = false;
|
||||
bool m_requiresRayTracingMaterialSrg = false;
|
||||
};
|
||||
} // namespace RPI
|
||||
} // namespace AZ
|
||||
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
|
||||
* its licensors.
|
||||
*
|
||||
* For complete copyright and license terms please see the LICENSE at the root of this
|
||||
* distribution (the "License"). All use of this software is governed by the License,
|
||||
* or, if provided, by the license below or the license accompanying this file. Do not
|
||||
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <Atom/RPI.Reflect/Asset/AssetReference.h>
|
||||
#include <Atom/RPI.Reflect/Pass/RenderPassData.h>
|
||||
|
||||
namespace AZ
|
||||
{
|
||||
namespace Render
|
||||
{
|
||||
//! Custom data for the RayTracingPass, specified in the PassRequest.
|
||||
struct RayTracingPassData
|
||||
: public RPI::RenderPassData
|
||||
{
|
||||
AZ_RTTI(RayTracingPassData, "{26C2E2FD-D30A-4142-82A3-0167BC94B3EE}", RPI::RenderPassData);
|
||||
AZ_CLASS_ALLOCATOR(RayTracingPassData, SystemAllocator, 0);
|
||||
|
||||
RayTracingPassData() = default;
|
||||
virtual ~RayTracingPassData() = default;
|
||||
|
||||
static void Reflect(ReflectContext* context)
|
||||
{
|
||||
if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
|
||||
{
|
||||
serializeContext->Class<RayTracingPassData, RenderPassData>()
|
||||
->Version(1)
|
||||
->Field("RayGenerationShaderAsset", &RayTracingPassData::m_rayGenerationShaderAssetReference)
|
||||
->Field("RayGenerationShaderName", &RayTracingPassData::m_rayGenerationShaderName)
|
||||
->Field("ClosestHitShaderAsset", &RayTracingPassData::m_closestHitShaderAssetReference)
|
||||
->Field("ClosestHitShaderName", &RayTracingPassData::m_closestHitShaderName)
|
||||
->Field("MissShaderAsset", &RayTracingPassData::m_missShaderAssetReference)
|
||||
->Field("MissShaderName", &RayTracingPassData::m_missShaderName)
|
||||
->Field("MaxPayloadSize", &RayTracingPassData::m_maxPayloadSize)
|
||||
->Field("MaxAttributeSize", &RayTracingPassData::m_maxAttributeSize)
|
||||
->Field("MaxRecursionDepth", &RayTracingPassData::m_maxRecursionDepth)
|
||||
->Field("Thread Count X", &RayTracingPassData::m_threadCountX)
|
||||
->Field("Thread Count Y", &RayTracingPassData::m_threadCountY)
|
||||
->Field("Thread Count Z", &RayTracingPassData::m_threadCountZ)
|
||||
->Field("Make Fullscreen Pass", &RayTracingPassData::m_makeFullscreenPass)
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
RPI::AssetReference m_rayGenerationShaderAssetReference;
|
||||
AZStd::string m_rayGenerationShaderName;
|
||||
RPI::AssetReference m_closestHitShaderAssetReference;
|
||||
AZStd::string m_closestHitShaderName;
|
||||
RPI::AssetReference m_missShaderAssetReference;
|
||||
AZStd::string m_missShaderName;
|
||||
|
||||
uint32_t m_maxPayloadSize = 64;
|
||||
uint32_t m_maxAttributeSize = 32;
|
||||
uint32_t m_maxRecursionDepth = 1;
|
||||
|
||||
uint32_t m_threadCountX = 1;
|
||||
uint32_t m_threadCountY = 1;
|
||||
uint32_t m_threadCountZ = 1;
|
||||
|
||||
bool m_makeFullscreenPass = false;
|
||||
};
|
||||
} // namespace RPI
|
||||
} // namespace AZ
|
||||
|
||||
Loading…
Reference in New Issue