Create helper function for getting threads per (#4480)

* Create helper function for getting threads per
group from a compute shader

Added GetComputeShaderNumThreads() functions to RPIUtils.
By default the function returns 1, 1, 1 in case of errors.

Updated existing code that was looking for 'numthreads' attribute data
with the new GetComputeShaderNumThreads() API.

Signed-off-by: garrieta <garrieta@amazon.com>
monroegm-disable-blank-issue-2
galibzon 4 years ago committed by GitHub
parent bb8971a3ad
commit 643bd84739
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,27 +54,10 @@ namespace AZ
m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
// retrieve the number of threads per thread group from the shader
const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[DiffuseProbeGridBlendDistancePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
return;
}
m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[DiffuseProbeGridBlendDistancePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
}
}

@ -54,27 +54,10 @@ namespace AZ
m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
// retrieve the number of threads per thread group from the shader
const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[DiffuseProbeBlendIrradiancePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
return;
}
m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[DiffuseProbeBlendIrradiancePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
}
}

@ -67,27 +67,10 @@ namespace AZ
srgLayout = shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
// retrieve the number of threads per thread group from the shader
const auto numThreads = shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(shader->GetAsset(), dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[DiffuseProbeGridBorderUpdatePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
return;
}
dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[DiffuseProbeGridBorderUpdatePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
}
}

@ -58,27 +58,10 @@ namespace AZ
m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
// retrieve the number of threads per thread group from the shader
const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[DiffuseProbeClassificationPass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
return;
}
m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[DiffuseProbeClassificationPass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
}
}

@ -58,27 +58,10 @@ namespace AZ
m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
// retrieve the number of threads per thread group from the shader
const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[DiffuseProbeRelocationPass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
return;
}
m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[DiffuseProbeRelocationPass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
}
}

@ -13,6 +13,7 @@
#include <Atom/RPI.Public/Shader/Shader.h>
#include <Atom/RPI.Public/Model/ModelLod.h>
#include <Atom/RPI.Public/Buffer/Buffer.h>
#include <Atom/RPI.Public/RPIUtils.h>
#include <Atom/RHI/Factory.h>
#include <Atom/RHI/BufferView.h>
@ -79,15 +80,11 @@ namespace AZ
m_dispatchItem.m_pipelineState = m_morphTargetShader->AcquirePipelineState(pipelineStateDescriptor);
// Get the threads-per-group values from the compute shader [numthreads(x,y,z)]
const auto& numThreads = m_morphTargetShader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, AZ::Name{ "numthreads" });
auto& arguments = m_dispatchItem.m_arguments.m_direct;
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_morphTargetShader->GetAsset(), arguments);
if (!outcome.IsSuccess())
{
const auto& args = *numThreads;
// Check that the arguments are valid integers, and fall back to 1,1,1 if there is an error
arguments.m_threadsPerGroupX = static_cast<uint16_t>(args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : 1);
arguments.m_threadsPerGroupY = static_cast<uint16_t>(args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : 1);
arguments.m_threadsPerGroupZ = static_cast<uint16_t>(args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : 1);
AZ_Error("MorphTargetDispatchItem", false, outcome.GetError().c_str());
}
arguments.m_totalNumberOfThreadsX = m_morphTargetMetaData.m_vertexCount;

@ -14,6 +14,7 @@
#include <Atom/RPI.Public/Shader/Shader.h>
#include <Atom/RPI.Public/Model/ModelLod.h>
#include <Atom/RPI.Public/Buffer/Buffer.h>
#include <Atom/RPI.Public/RPIUtils.h>
#include <Atom/RHI/Factory.h>
#include <Atom/RHI/BufferView.h>
@ -199,17 +200,14 @@ namespace AZ
m_instanceSrg->Compile();
m_dispatchItem.m_uniqueShaderResourceGroup = m_instanceSrg->GetRHIShaderResourceGroup();
m_dispatchItem.m_pipelineState = m_skinningShader->AcquirePipelineState(pipelineStateDescriptor);
const auto& numThreads = m_skinningShader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, AZ::Name{ "numthreads" });
auto& arguments = m_dispatchItem.m_arguments.m_direct;
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_skinningShader->GetAsset(), arguments);
if (!outcome.IsSuccess())
{
const auto& args = *numThreads;
arguments.m_threadsPerGroupX = static_cast<uint16_t>(args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : 1);
arguments.m_threadsPerGroupY = static_cast<uint16_t>(args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : 1);
arguments.m_threadsPerGroupZ = static_cast<uint16_t>(args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : 1);
AZ_Error("SkinnedMeshInputBuffers", false, outcome.GetError().c_str());
}
arguments.m_totalNumberOfThreadsX = xThreads;
arguments.m_totalNumberOfThreadsY = yThreads;
arguments.m_totalNumberOfThreadsZ = 1;

@ -11,6 +11,7 @@
#include <AtomCore/Instance/Instance.h>
#include <Atom/RHI/DispatchItem.h>
#include <Atom/RPI.Public/Base.h>
#include <Atom/RPI.Public/Image/StreamingImage.h>
#include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
@ -40,6 +41,23 @@ namespace AZ
//! Loads a streaming image asset for the given file path
Data::Instance<RPI::StreamingImage> LoadStreamingTexture(AZStd::string_view path);
//! Looks for a three arguments attribute named @attributeName in the given shader asset.
//! Assigns the value to each non-null output variables.
//! @param shaderAsset
//! @param attributeName
//! @param numThreadsX Can be NULL. If not NULL it takes the value of the 1st argument of the attribute. Becomes 1 on error.
//! @param numThreadsY Can be NULL. If not NULL it takes the value of the 2nd argument of the attribute. Becomes 1 on error.
//! @param numThreadsZ Can be NULL. If not NULL it takes the value of the 3rd argument of the attribute. Becomes 1 on error.
//! @returns An Outcome instance with error message in case of error.
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ);
//! Same as above, but assumes the name of the attribute to be 'numthreads'.
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ);
//! Same as above. Provided as a convenience when all arguments of the 'numthreads' attributes should be assigned to RHI::DispatchDirect::m_threadsPerGroup* variables.
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, RHI::DispatchDirect& dispatchDirect);
} // namespace RPI
} // namespace AZ

@ -107,30 +107,13 @@ namespace AZ
dispatchArgs.m_totalNumberOfThreadsY = passData->m_totalNumberOfThreadsY;
dispatchArgs.m_totalNumberOfThreadsZ = passData->m_totalNumberOfThreadsZ;
const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
if (numThreads)
const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), dispatchArgs);
if (!outcome.IsSuccess())
{
const RHI::ShaderStageAttributeArguments& args = *numThreads;
bool validArgs = args.size() == 3;
if (validArgs)
{
validArgs &= args[0].type() == azrtti_typeid<int>();
validArgs &= args[1].type() == azrtti_typeid<int>();
validArgs &= args[2].type() == azrtti_typeid<int>();
}
if (!validArgs)
{
AZ_Error("PassSystem", false, "[ComputePass '%s']: Shader '%s' contains invalid numthreads arguments.",
GetPathName().GetCStr(),
passData->m_shaderReference.m_filePath.data());
return;
}
dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
AZ_Error("PassSystem", false, "[ComputePass '%s']: Shader '%.*s' contains invalid numthreads arguments:\n%s",
GetPathName().GetCStr(), passData->m_shaderReference.m_filePath.size(), passData->m_shaderReference.m_filePath.data(), outcome.GetError().c_str());
}
m_dispatchItem.m_arguments = dispatchArgs;
m_isFullscreenPass = passData->m_makeFullscreenPass;

@ -143,5 +143,79 @@ namespace AZ
return RPI::StreamingImage::FindOrCreate(streamingImageAsset);
}
//! A helper function for GetComputeShaderNumThreads(), to consolidate error messages, etc.
static bool GetAttributeArgumentByIndex(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, const RHI::ShaderStageAttributeArguments& args, const size_t argIndex, uint16_t* value, AZStd::string& errorMsg)
{
if (value)
{
const auto numArguments = args.size();
if (numArguments > argIndex)
{
if (args[argIndex].type() == azrtti_typeid<int>())
{
*value = aznumeric_caster(AZStd::any_cast<int>(args[argIndex]));
}
else
{
errorMsg = AZStd::string::format("Was expecting argument '%zu' in attribute '%s' to be of type 'int' from shader asset '%s'", argIndex, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
return false;
}
}
else
{
errorMsg = AZStd::string::format("Was expecting at least '%zu' arguments in attribute '%s' from shader asset '%s'", argIndex + 1, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
return false;
}
}
return true;
}
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
{
// Set default 1, 1, 1 now. In case of errors later this is what the caller will get.
if (numThreadsX)
{
*numThreadsX = 1;
}
if (numThreadsY)
{
*numThreadsY = 1;
}
if (numThreadsZ)
{
*numThreadsZ = 1;
}
const auto numThreads = shaderAsset->GetAttribute(RHI::ShaderStage::Compute, attributeName);
if (!numThreads)
{
return AZ::Failure(AZStd::string::format("Couldn't find attribute '%s' in shader asset '%s'", attributeName.GetCStr(), shaderAsset.GetHint().c_str()));
}
const RHI::ShaderStageAttributeArguments& args = *numThreads;
AZStd::string errorMsg;
if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 0, numThreadsX, errorMsg))
{
return AZ::Failure(errorMsg);
}
if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 1, numThreadsY, errorMsg))
{
return AZ::Failure(errorMsg);
}
if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 2, numThreadsZ, errorMsg))
{
return AZ::Failure(errorMsg);
}
return AZ::Success();
}
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
{
return GetComputeShaderNumThreads(shaderAsset, Name{ "numthreads" }, numThreadsX, numThreadsY, numThreadsZ);
}
AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, RHI::DispatchDirect& dispatchDirect)
{
return GetComputeShaderNumThreads(shaderAsset, &dispatchDirect.m_threadsPerGroupX, &dispatchDirect.m_threadsPerGroupY, &dispatchDirect.m_threadsPerGroupZ);
}
}
}

Loading…
Cancel
Save