From c158ca178fb780497d446513ca8a4a97f10f32a8 Mon Sep 17 00:00:00 2001 From: Chris Santora Date: Tue, 15 Jun 2021 19:26:11 -0700 Subject: [PATCH] Added new shader reinitialization signaling. This was done while working on "ATOM-15728 Shader Hot Reload Fails in Debug Build", but it turned out these changes did not actually fix the issue (or any other known hot-reload issue). Still, these improvements are appropriate as they correct logical oversights. ShaderVariant was not listening to asset reloads. It needs to know when the ShaderVariantAsset reload happens so it can reinitialize it's members as well as propagate reinitialization messages. I added a member for the ShaderAsset as the class needs this to reinitialize itself. So now the class listens for reloads of both the ShaderVariantAsset and the ShaderAsset. Shader was not listening for ShaderAsset reinitialization events. Updated the API for ShaderReloadNotificationBus's OnShaderVariantReinitialized to include the ShaderVariant which is the most relevant information (the other information wasn't really being used anyway). --- .../LightCullingTilePreparePass.cpp | 4 +- .../CoreLights/LightCullingTilePreparePass.h | 2 +- .../MorphTargets/MorphTargetDispatchItem.cpp | 4 +- .../MorphTargets/MorphTargetDispatchItem.h | 4 +- .../Code/Source/RayTracing/RayTracingPass.cpp | 4 +- .../Code/Source/RayTracing/RayTracingPass.h | 2 +- .../SkinnedMesh/SkinnedMeshComputePass.cpp | 8 +++- .../SkinnedMesh/SkinnedMeshComputePass.h | 2 +- .../Atom/RPI.Public/Material/Material.h | 2 +- .../Atom/RPI.Public/Pass/ComputePass.h | 2 +- .../RPI.Public/Pass/FullscreenTrianglePass.h | 2 +- .../Include/Atom/RPI.Public/PipelineState.h | 2 +- .../Shader/ShaderReloadNotificationBus.h | 9 ++-- .../Atom/RPI.Public/Shader/ShaderVariant.h | 15 ++++++- .../Source/RPI.Public/Material/Material.cpp | 4 +- .../Source/RPI.Public/Pass/ComputePass.cpp | 3 +- .../Pass/FullscreenTrianglePass.cpp | 2 +- .../Code/Source/RPI.Public/PipelineState.cpp | 7 +-- .../Code/Source/RPI.Public/Shader/Shader.cpp | 30 ++++++++----- .../RPI.Public/Shader/ShaderVariant.cpp | 45 ++++++++++++++++--- 20 files changed, 103 insertions(+), 50 deletions(-) diff --git a/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp b/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp index 4f2a4f0346..7742036171 100644 --- a/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp +++ b/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp @@ -192,9 +192,7 @@ namespace AZ OnShaderReloaded(); } - void LightCullingTilePreparePass::OnShaderVariantReinitialized( - const AZ::RPI::Shader&, const AZ::RPI::ShaderVariantId&, - AZ::RPI::ShaderVariantStableId) + void LightCullingTilePreparePass::OnShaderVariantReinitialized(const AZ::RPI::ShaderVariant&) { OnShaderReloaded(); } diff --git a/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h b/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h index 0febb66e3d..674f5d9914 100644 --- a/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h +++ b/Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h @@ -56,7 +56,7 @@ namespace AZ // ShaderReloadNotificationBus overrides... void OnShaderReinitialized(const AZ::RPI::Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const AZ::RPI::Shader& shader, const AZ::RPI::ShaderVariantId& shaderVariantId, AZ::RPI::ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const AZ::RPI::ShaderVariant& shaderVariant) override; // Scope producer functions... void CompileResources(const RHI::FrameGraphCompileContext& context) override; diff --git a/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp b/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp index 7b3aedd64e..55a6d5d87a 100644 --- a/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp +++ b/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp @@ -199,7 +199,7 @@ namespace AZ } } - void MorphTargetDispatchItem::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset& shaderAsset) + void MorphTargetDispatchItem::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset& shaderAsset) { if (!Init()) { @@ -207,7 +207,7 @@ namespace AZ } } - void MorphTargetDispatchItem::OnShaderVariantReinitialized([[maybe_unused]] const RPI::Shader& shader, [[maybe_unused]] const RPI::ShaderVariantId& shaderVariantId, [[maybe_unused]] RPI::ShaderVariantStableId shaderVariantStableId) + void MorphTargetDispatchItem::OnShaderVariantReinitialized(const RPI::ShaderVariant&) { if (!Init()) { diff --git a/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.h b/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.h index ad1fd969a5..e7d78caf5f 100644 --- a/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.h +++ b/Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.h @@ -72,8 +72,8 @@ namespace AZ // ShaderInstanceNotificationBus::Handler overrides void OnShaderReinitialized(const RPI::Shader& shader) override; - void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const RPI::Shader& shader, const RPI::ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId shaderVariantStableId) override; + void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; + void OnShaderVariantReinitialized(const RPI::ShaderVariant& shaderVariant) override; RHI::DispatchItem m_dispatchItem; diff --git a/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp b/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp index 988870cc0e..6e11ba838b 100644 --- a/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp +++ b/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp @@ -354,9 +354,9 @@ namespace AZ Init(); } - void RayTracingPass::OnShaderVariantReinitialized([[maybe_unused]] const RPI::Shader& shader, [[maybe_unused]] const RPI::ShaderVariantId& shaderVariantId, [[maybe_unused]] RPI::ShaderVariantStableId shaderVariantStableId) + void RayTracingPass::OnShaderVariantReinitialized(const RPI::ShaderVariant&) { Init(); } - } // namespace RPI + } // namespace Render } // namespace AZ diff --git a/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.h b/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.h index 6ad082e894..92da052d09 100644 --- a/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.h +++ b/Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.h @@ -53,7 +53,7 @@ namespace AZ // ShaderReloadNotificationBus::Handler overrides void OnShaderReinitialized(const RPI::Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const RPI::Shader& shader, const RPI::ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const RPI::ShaderVariant& shaderVariant) override; // load the raytracing shaders and setup pipeline states void Init(); diff --git a/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.cpp b/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.cpp index a3feddb0b6..b0d08314b1 100644 --- a/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.cpp +++ b/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.cpp @@ -65,9 +65,13 @@ namespace AZ } } - void SkinnedMeshComputePass::OnShaderVariantReinitialized(const RPI::Shader& shader, const RPI::ShaderVariantId&, RPI::ShaderVariantStableId) + void SkinnedMeshComputePass::OnShaderVariantReinitialized(const RPI::ShaderVariant& shaderVariant) { - OnShaderReinitialized(shader); + ComputePass::OnShaderVariantReinitialized(shaderVariant); + if (m_skinnedMeshFeatureProcessor) + { + m_skinnedMeshFeatureProcessor->OnSkinningShaderReinitialized(m_shader); + } } } // namespace Render } // namespace AZ diff --git a/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.h b/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.h index 5f7ff08e47..dbd8704c54 100644 --- a/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.h +++ b/Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshComputePass.h @@ -44,7 +44,7 @@ namespace AZ // ShaderReloadNotificationBus::Handler overrides... void OnShaderReinitialized(const RPI::Shader& shader) override; - void OnShaderVariantReinitialized(const RPI::Shader& shader, const RPI::ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const RPI::ShaderVariant& shaderVariant) override; SkinnedMeshFeatureProcessor* m_skinnedMeshFeatureProcessor = nullptr; }; diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Material/Material.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Material/Material.h index d93f7acc85..fa1bb57166 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Material/Material.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Material/Material.h @@ -145,7 +145,7 @@ namespace AZ // ShaderReloadNotificationBus overrides... void OnShaderReinitialized(const Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override; /////////////////////////////////////////////////////////////////// template diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h index a4130deed1..450a9ce834 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h @@ -78,7 +78,7 @@ namespace AZ // ShaderReloadNotificationBus::Handler overrides... void OnShaderReinitialized(const Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override; void LoadShader(); PassDescriptor m_passDescriptor; diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h index 616fc4639a..e168a07f08 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h @@ -78,7 +78,7 @@ namespace AZ // ShaderReloadNotificationBus overrides... void OnShaderReinitialized(const Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override; /////////////////////////////////////////////////////////////////// void LoadShader(); diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h index f0fa39d20f..815fc28cb6 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h @@ -88,7 +88,7 @@ namespace AZ // ShaderReloadNotificationBus overrides... void OnShaderReinitialized(const AZ::RPI::Shader& shader) override; void OnShaderAssetReinitialized(const Data::Asset& shaderAsset) override; - void OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) override; + void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override; /////////////////////////////////////////////////////////////////// // Update shader variant from m_shader. It's called whenever shader, shader asset or shader variant were changed. diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h index 58b9809e2b..c63ba8f5b5 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h @@ -22,10 +22,11 @@ namespace AZ { class Shader; class ShaderAsset; + class ShaderVariant; /** - * Connect to this EBus to get notifications whenever a Data::Instance reloads its ShaderAsset. - * The bus address is the AssetId of the ShaderAsset. + * Connect to this EBus to get notifications whenever a shader system class reinitializes itself. + * The bus address is the AssetId of the ShaderAsset, even when the thing being reinitialized is a ShaderVariant or other shader related class. */ class ShaderReloadNotifications : public EBusTraits @@ -35,7 +36,7 @@ namespace AZ ////////////////////////////////////////////////////////////////////////// // EBusTraits overrides static const AZ::EBusAddressPolicy AddressPolicy = AZ::EBusAddressPolicy::ById; - typedef Data::AssetId BusIdType; + typedef Data::AssetId BusIdType; ////////////////////////////////////////////////////////////////////////// virtual ~ShaderReloadNotifications() {} @@ -47,7 +48,7 @@ namespace AZ virtual void OnShaderReinitialized(const Shader& shader) { AZ_UNUSED(shader); } //! Called when a particular shader variant is reinitialized. - virtual void OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) { AZ_UNUSED(shader); AZ_UNUSED(shaderVariantId); AZ_UNUSED(shaderVariantStableId); } + virtual void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) { AZ_UNUSED(shaderVariant); } }; typedef EBus ShaderReloadNotificationBus; diff --git a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h index d189d26b13..7363ab4d1a 100644 --- a/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h +++ b/Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h @@ -23,10 +23,12 @@ namespace AZ //! the RHI::PipelineStateType of the parent Shader instance. For shaders on the raster //! pipeline, the RHI::DrawFilterTag is also provided. class ShaderVariant final + : public Data::AssetBus::MultiHandler { friend class Shader; public: ShaderVariant() = default; + virtual ~ShaderVariant(); AZ_DEFAULT_COPY_MOVE(ShaderVariant); //! Fills a pipeline state descriptor with settings provided by the ShaderVariant. (Note that @@ -54,12 +56,21 @@ namespace AZ bool IsRootVariant() const { return m_shaderVariantAsset->IsRootVariant(); } ShaderVariantStableId GetStableId() const { return m_shaderVariantAsset->GetStableId(); } + + const Data::Asset& GetShaderAsset() const { return m_shaderAsset; } + const Data::Asset& GetShaderVariantAsset() const { return m_shaderVariantAsset; } private: // Called by Shader. Initializes runtime data from asset data. Returns whether the call succeeded. bool Init( - const ShaderAsset& shaderAsset, - Data::Asset shaderVariantAsset); + const Data::Asset& shaderAsset, + const Data::Asset& shaderVariantAsset); + + // AssetBus overrides... + void OnAssetReloaded(Data::Asset asset) override; + + //! A reference to the shader asset that this is a variant of. + Data::Asset m_shaderAsset; // Cached state from the asset to avoid an indirection. RHI::PipelineStateType m_pipelineStateType = RHI::PipelineStateType::Count; diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/Material/Material.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/Material/Material.cpp index d9691ca175..e134b701ac 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/Material/Material.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/Material/Material.cpp @@ -268,9 +268,9 @@ namespace AZ OnAssetReloaded(m_materialAsset); } - void Material::OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& /*shaderVariantId*/, ShaderVariantStableId shaderVariantStableId) + void Material::OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) { - ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Material::OnShaderVariantReinitialized %s variant %u", this, shader.GetAsset().GetHint().c_str(), shaderVariantStableId.GetIndex()); + ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->Material::OnShaderVariantReinitialized %s", this, shaderVariant.GetShaderVariantAsset().GetHint().c_str()); // Note that it would be better to check the shaderVariantId to see if that variant is relevant to this particular material before reinitializing it. // There could be hundreds or even thousands of variants for a shader, but only one of those variants will be used by any given material. So we could diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp index 5077ccaa51..5902c208c6 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp @@ -241,9 +241,8 @@ namespace AZ LoadShader(); } - void ComputePass::OnShaderVariantReinitialized(const Shader& shader, const ShaderVariantId& shaderVariantId, ShaderVariantStableId shaderVariantStableId) + void ComputePass::OnShaderVariantReinitialized(const ShaderVariant&) { - AZ_UNUSED(shader); AZ_UNUSED(shaderVariantId); AZ_UNUSED(shaderVariantStableId); LoadShader(); } diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp index a854867998..aee4fc4f48 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp @@ -57,7 +57,7 @@ namespace AZ LoadShader(); } - void FullscreenTrianglePass::OnShaderVariantReinitialized(const Shader&, const ShaderVariantId&, ShaderVariantStableId) + void FullscreenTrianglePass::OnShaderVariantReinitialized(const ShaderVariant&) { LoadShader(); } diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp index 545e71e117..8b13817897 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp @@ -124,12 +124,9 @@ namespace AZ RefreshShaderVariant(); } - void PipelineStateForDraw::OnShaderVariantReinitialized( - [[maybe_unused]] const Shader& shader, - const ShaderVariantId& shaderVariantId, - [[maybe_unused]] ShaderVariantStableId shaderVariantStableId) + void PipelineStateForDraw::OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) { - if(shaderVariantId == m_shaderVariantId) + if(shaderVariant.GetShaderVariantId() == m_shaderVariantId) { RefreshShaderVariant(); } diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp index 3a79f32b05..6f65bd75c9 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp @@ -68,7 +68,7 @@ namespace AZ AZStd::unique_lock lock(m_variantCacheMutex); m_shaderVariants.clear(); } - m_rootVariant.Init(shaderAsset, shaderAsset.GetRootVariant()); + m_rootVariant.Init(Data::Asset{&shaderAsset, AZ::Data::AssetLoadBehavior::PreLoad}, shaderAsset.GetRootVariant()); if (m_pipelineLibraryHandle.IsNull()) { @@ -154,7 +154,14 @@ namespace AZ { AZ_Assert(shaderVariantAsset, "Reloaded ShaderVariantAsset is null"); const ShaderVariantStableId stableId = shaderVariantAsset->GetStableId(); - const ShaderVariantId& shaderVariantId = shaderVariantAsset->GetShaderVariantId(); + + // We make a copy of the updated variant because OnShaderVariantReinitialized must not be called inside + // m_variantCacheMutex or deadlocks may occur. + // Or if there is an error, we leave this object in its default state to indicate there was an error. + // [GFX TODO] We really should have a dedicated message/event for this, but that will be covered by a future task where + // we will merge ShaderReloadNotificationBus messages into one. For now, we just indicate the error by passing an empty ShaderVariant, + // all our call sites don't use this data anyway. + ShaderVariant updatedVariant; if (isError) { @@ -165,7 +172,7 @@ namespace AZ return; } AZStd::unique_lock lock(m_variantCacheMutex); - m_shaderVariants.erase(stableId); + m_shaderVariants.erase(stableId); } else { @@ -178,23 +185,26 @@ namespace AZ { ShaderVariant& shaderVariant = iter->second; - if (!shaderVariant.Init(*m_asset.Get(), shaderVariantAsset)) + if (!shaderVariant.Init(m_asset, shaderVariantAsset)) { AZ_Error("Shader", false, "Failed to init shaderVariant with StableId=%u", shaderVariantAsset->GetStableId()); m_shaderVariants.erase(stableId); } + else + { + updatedVariant = shaderVariant; + } } else { //This is the first time the shader variant asset comes to life. - ShaderVariant newVariant; - newVariant.Init(*m_asset, shaderVariantAsset); - m_shaderVariants.emplace(stableId, newVariant); + updatedVariant.Init(m_asset, shaderVariantAsset); + m_shaderVariants.emplace(stableId, updatedVariant); } } - //Even if there was an error, the interested parties should be notified. - ShaderReloadNotificationBus::Event(m_asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, *this, shaderVariantId, stableId); + // [GFX TODO] It might make more sense to call OnShaderReinitialized here + ShaderReloadNotificationBus::Event(m_asset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, updatedVariant); } /////////////////////////////////////////////////////////////////// @@ -340,7 +350,7 @@ namespace AZ } ShaderVariant newVariant; - newVariant.Init(*m_asset, shaderVariantAsset); + newVariant.Init(m_asset, shaderVariantAsset); m_shaderVariants.emplace(shaderVariantStableId, newVariant); return m_shaderVariants.at(shaderVariantStableId); diff --git a/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp b/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp index 4d50f10c9c..acd9922b68 100644 --- a/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp +++ b/Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp @@ -9,11 +9,13 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * */ + #include +#include +#include #include #include - #include namespace AZ @@ -21,15 +23,26 @@ namespace AZ namespace RPI { bool ShaderVariant::Init( - const ShaderAsset& shaderAsset, - Data::Asset shaderVariantAsset) - { - m_pipelineStateType = shaderAsset.GetPipelineStateType(); - m_pipelineLayoutDescriptor = shaderAsset.GetPipelineLayoutDescriptor(); + const Data::Asset& shaderAsset, + const Data::Asset& shaderVariantAsset) + { + Data::AssetBus::MultiHandler::BusDisconnect(); + Data::AssetBus::MultiHandler::BusConnect(shaderAsset.GetId()); + Data::AssetBus::MultiHandler::BusConnect(shaderVariantAsset.GetId()); + + m_shaderAsset = shaderAsset; + m_pipelineStateType = shaderAsset->GetPipelineStateType(); + m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(); m_shaderVariantAsset = shaderVariantAsset; + return true; } + ShaderVariant::~ShaderVariant() + { + Data::AssetBus::MultiHandler::BusDisconnect(); + } + void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const { descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor; @@ -78,5 +91,25 @@ namespace AZ { return m_shaderVariantAsset->GetOutputContract(); } + + void ShaderVariant::OnAssetReloaded(Data::Asset asset) + { + ShaderReloadDebugTracker::ScopedSection reloadSection("{%p}->ShaderVariant::OnAssetReloaded %s", this, asset.GetHint().c_str()); + + if (asset.GetAs()) + { + Data::Asset shaderVariantAsset = { asset.GetAs(), AZ::Data::AssetLoadBehavior::PreLoad }; + Init(m_shaderAsset, shaderVariantAsset); + ShaderReloadNotificationBus::Event(m_shaderAsset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, *this); + } + + if (asset.GetAs()) + { + Data::Asset shaderAsset = { asset.GetAs(), AZ::Data::AssetLoadBehavior::PreLoad }; + Init(shaderAsset, m_shaderVariantAsset); + ShaderReloadNotificationBus::Event(m_shaderAsset.GetId(), &ShaderReloadNotificationBus::Events::OnShaderVariantReinitialized, *this); + } + } + } // namespace RPI } // namespace AZ