You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp

116 lines
5.3 KiB
C++

/*
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
* its licensors.
*
* For complete copyright and license terms please see the LICENSE at the root of this
* distribution (the "License"). All use of this software is governed by the License,
* or, if provided, by the license below or the license accompanying this file. Do not
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/
#include <Atom/RPI.Public/Shader/ShaderVariant.h>
#include <Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h>
#include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
#include <Atom/RHI/DrawListTagRegistry.h>
#include <Atom/RHI/RHISystemInterface.h>
#include <Atom/RHI.Reflect/ShaderStageFunction.h>
namespace AZ
{
namespace RPI
{
// Binds this variant to its owning shader asset and to the compiled variant asset.
//
// Any earlier asset-bus subscriptions are dropped first. The object then subscribes to exactly the
// two asset ids passed in, so OnAssetReloaded() fires when either of them is reloaded. The pipeline
// state type and pipeline layout are cached from the shader asset for ConfigurePipelineState().
//
// @param shaderAsset        Shader asset the variant belongs to. It is dereferenced here.
//                           NOTE(review): there is no null/ready check; confirm callers always pass a loaded asset.
// @param shaderVariantAsset Variant asset supplying stage functions, render states and input/output contracts.
// @return Always true in the current implementation.
bool ShaderVariant::Init(
    const Data::Asset<ShaderAsset>& shaderAsset,
    const Data::Asset<ShaderVariantAsset>& shaderVariantAsset)
{
    using AssetHandler = Data::AssetBus::MultiHandler;

    // Reset subscriptions, then listen only to the two assets this variant depends on.
    AssetHandler::BusDisconnect();
    AssetHandler::BusConnect(shaderAsset.GetId());
    AssetHandler::BusConnect(shaderVariantAsset.GetId());

    // Cache the shader-level data that ConfigurePipelineState() writes into descriptors.
    m_shaderAsset = shaderAsset;
    m_pipelineStateType = shaderAsset->GetPipelineStateType();
    m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor();

    m_shaderVariantAsset = shaderVariantAsset;

    const bool initialized = true;
    return initialized;
}
// Unsubscribes from every asset id connected in Init(), so asset-bus notifications stop
// being routed to this object once it is destroyed.
ShaderVariant::~ShaderVariant()
{
    using AssetHandler = Data::AssetBus::MultiHandler;
    AssetHandler::BusDisconnect();
}
// Copies this variant's compiled data into a pipeline state descriptor.
//
// The pipeline layout cached from the shader asset is always written. The remaining fields depend
// on the type the descriptor reports:
//   - Draw:       vertex, tessellation and fragment stage functions plus the variant's render states.
//   - Dispatch:   the compute stage function.
//   - RayTracing: the ray tracing stage function.
// AZ_Assert fires when the requested pipeline type differs from the shader asset's pipeline type,
// or when the descriptor type is none of the three above.
//
// @param descriptor Descriptor to fill. Its GetType() selects which concrete descriptor it is cast to.
void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
{
    descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;

    const RHI::PipelineStateType descriptorType = descriptor.GetType();

    if (descriptorType == RHI::PipelineStateType::Draw)
    {
        AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Draw, "ShaderVariant is not intended for the raster pipeline.");
        // The static_cast relies on GetType() identifying the concrete descriptor class.
        auto& drawDescriptor = static_cast<RHI::PipelineStateDescriptorForDraw&>(descriptor);
        drawDescriptor.m_vertexFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex);
        drawDescriptor.m_tessellationFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Tessellation);
        drawDescriptor.m_fragmentFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Fragment);
        drawDescriptor.m_renderStates = m_shaderVariantAsset->GetRenderStates();
    }
    else if (descriptorType == RHI::PipelineStateType::Dispatch)
    {
        AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Dispatch, "ShaderVariant is not intended for the compute pipeline.");
        auto& dispatchDescriptor = static_cast<RHI::PipelineStateDescriptorForDispatch&>(descriptor);
        dispatchDescriptor.m_computeFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Compute);
    }
    else if (descriptorType == RHI::PipelineStateType::RayTracing)
    {
        AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::RayTracing, "ShaderVariant is not intended for the ray tracing pipeline.");
        auto& rayTracingDescriptor = static_cast<RHI::PipelineStateDescriptorForRayTracing&>(descriptor);
        rayTracingDescriptor.m_rayTracingFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::RayTracing);
    }
    else
    {
        AZ_Assert(false, "Unexpected PipelineStateType");
    }
}
// Returns the input contract of the variant asset bound by Init().
// NOTE(review): the reference presumably points into that ShaderVariantAsset. It may not outlive a
// reload that replaces the asset; confirm before caching it.
const ShaderInputContract& ShaderVariant::GetInputContract() const
{
    const ShaderInputContract& inputContract = m_shaderVariantAsset->GetInputContract();
    return inputContract;
}
// Returns the output contract of the variant asset bound by Init().
// NOTE(review): the reference presumably points into that ShaderVariantAsset. It may not outlive a
// reload that replaces the asset; confirm before caching it.
const ShaderOutputContract& ShaderVariant::GetOutputContract() const
{
    const ShaderOutputContract& outputContract = m_shaderVariantAsset->GetOutputContract();
    return outputContract;
}
// Asset-bus callback, invoked when one of the assets subscribed in Init() is reloaded.
//
// If the reloaded asset is a ShaderVariantAsset, the variant is re-initialized with the current
// shader asset and the new variant data. If it is a ShaderAsset, it is re-initialized with the new
// shader asset and the current variant data. After each re-initialization, listeners on
// ShaderReloadNotificationBus (addressed by the shader asset id) receive OnShaderVariantReinitialized.
//
// @param asset The reloaded asset, received as generic AssetData and cast to the concrete type.
void ShaderVariant::OnAssetReloaded(Data::Asset<Data::AssetData> asset)
{
    ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->ShaderVariant::OnAssetReloaded %s", this, asset.GetHint().c_str());

    // `asset` is a local copy that Init() cannot touch, so both casts can be resolved up front.
    auto* const reloadedVariant = asset.GetAs<ShaderVariantAsset>();
    auto* const reloadedShader = asset.GetAs<ShaderAsset>();

    if (reloadedVariant)
    {
        // Keep the current shader asset and swap in the freshly loaded variant data.
        Data::Asset<ShaderVariantAsset> newVariantAsset{ reloadedVariant, AZ::Data::AssetLoadBehavior::PreLoad };
        Init(m_shaderAsset, newVariantAsset);
        ShaderReloadNotificationBus::Event(m_shaderAsset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, *this);
    }

    if (reloadedShader)
    {
        // Swap in the freshly loaded shader asset and keep the current variant data.
        Data::Asset<ShaderAsset> newShaderAsset{ reloadedShader, AZ::Data::AssetLoadBehavior::PreLoad };
        Init(newShaderAsset, m_shaderVariantAsset);
        ShaderReloadNotificationBus::Event(m_shaderAsset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, *this);
    }
}
} // namespace RPI
} // namespace AZ