// Source: o3de/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant2.cpp
// (77 lines, 3.6 KiB, C++)

/*
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
* its licensors.
*
* For complete copyright and license terms please see the LICENSE at the root of this
* distribution (the "License"). All use of this software is governed by the License,
* or, if provided, by the license below or the license accompanying this file. Do not
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/
#include <Atom/RPI.Public/Shader/ShaderVariant2.h>
#include <Atom/RHI/DrawListTagRegistry.h>
#include <Atom/RHI/RHISystemInterface.h>
#include <Atom/RHI.Reflect/ShaderStageFunction.h>
namespace AZ
{
namespace RPI
{
//! Initializes this variant from its parent shader asset and its own variant asset.
//!
//! Caches the pipeline state type, plus the pipeline layout descriptor and render
//! states of the requested supervariant, and keeps the variant asset that
//! ConfigurePipelineState() later queries for shader stage functions.
//!
//! @param shaderAsset        Shader asset queried for pipeline-level data. Only the ADDRESS of
//!                           its render states is stored.
//!                           NOTE(review): this presumably requires the asset to outlive this
//!                           variant - confirm against callers.
//! @param shaderVariantAsset Variant asset to keep; copied into the member.
//! @param supervariantIndex  Supervariant used for the layout and render state lookups.
//! @return Always true.
bool ShaderVariant2::Init(
    const ShaderAsset2& shaderAsset,
    Data::Asset<ShaderVariantAsset2> shaderVariantAsset,
    SupervariantIndex supervariantIndex)
{
    // Pipeline-wide data owned by the parent shader asset.
    m_pipelineStateType = shaderAsset.GetPipelineStateType();
    m_pipelineLayoutDescriptor = shaderAsset.GetPipelineLayoutDescriptor(supervariantIndex);

    // Data specific to this variant.
    m_shaderVariantAsset = shaderVariantAsset;

    // Render states are referenced, not copied: remember where they live inside the asset.
    auto& supervariantRenderStates = shaderAsset.GetRenderStates(supervariantIndex);
    m_renderStates = &supervariantRenderStates;

    return true;
}
//! Fills in the shader-dependent fields of a pipeline state descriptor.
//!
//! The pipeline layout descriptor is always assigned. The remaining fields depend on
//! the descriptor's own type (descriptor.GetType()), which selects the concrete
//! descriptor class it is downcast to:
//!  - Draw:       vertex, tessellation and fragment functions, plus a copy of the render states.
//!  - Dispatch:   compute function.
//!  - RayTracing: ray tracing function.
//! Any other type triggers an assert and leaves the remaining fields untouched.
//!
//! Each handled case also asserts that the descriptor's type matches the pipeline
//! state type cached by Init().
//!
//! @param descriptor Descriptor to populate; modified in place.
void ShaderVariant2::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
{
    // Common to every pipeline type, so it is set before branching.
    descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;

    switch (descriptor.GetType())
    {
    case RHI::PipelineStateType::Draw:
        {
            AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Draw, "ShaderVariant2 is not intended for the raster pipeline.");
            AZ_Assert(m_renderStates, "Invalid RenderStates");

            auto& drawDescriptor = static_cast<RHI::PipelineStateDescriptorForDraw&>(descriptor);

            // One stage function per raster stage, all sourced from the variant asset.
            drawDescriptor.m_vertexFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex);
            drawDescriptor.m_tessellationFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Tessellation);
            drawDescriptor.m_fragmentFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Fragment);

            // The descriptor receives its own copy of the render states cached by Init().
            drawDescriptor.m_renderStates = *m_renderStates;
            break;
        }

    case RHI::PipelineStateType::Dispatch:
        {
            AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Dispatch, "ShaderVariant2 is not intended for the compute pipeline.");

            auto& dispatchDescriptor = static_cast<RHI::PipelineStateDescriptorForDispatch&>(descriptor);
            dispatchDescriptor.m_computeFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Compute);
            break;
        }

    case RHI::PipelineStateType::RayTracing:
        {
            AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::RayTracing, "ShaderVariant2 is not intended for the ray tracing pipeline.");

            auto& rayTracingDescriptor = static_cast<RHI::PipelineStateDescriptorForRayTracing&>(descriptor);
            rayTracingDescriptor.m_rayTracingFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::RayTracing);
            break;
        }

    default:
        // No other descriptor type is handled here.
        AZ_Assert(false, "Unexpected PipelineStateType");
        break;
    }
}
} // namespace RPI
} // namespace AZ