You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
430 lines
16 KiB
C++
430 lines
16 KiB
C++
/*
|
|
* Copyright (c) Contributors to the Open 3D Engine Project.
|
|
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0 OR MIT
|
|
*
|
|
*/
|
|
|
|
#include "RHITestFixture.h"
|
|
#include <Tests/Factory.h>
|
|
#include <Tests/Device.h>
|
|
#include <Atom/RHI/ScopeProducer.h>
|
|
#include <Atom/RHI/FrameScheduler.h>
|
|
#include <AzCore/Math/Random.h>
|
|
|
|
namespace UnitTest
|
|
{
|
|
using namespace AZ;
|
|
|
|
struct ImportedImage
|
|
{
|
|
RHI::AttachmentId m_id;
|
|
RHI::Ptr<RHI::Image> m_image;
|
|
};
|
|
|
|
struct ImportedBuffer
|
|
{
|
|
RHI::AttachmentId m_id;
|
|
RHI::Ptr<RHI::Buffer> m_buffer;
|
|
};
|
|
|
|
struct TransientImage
|
|
{
|
|
RHI::AttachmentId m_id;
|
|
RHI::ImageDescriptor m_descriptor;
|
|
};
|
|
|
|
struct TransientBuffer
|
|
{
|
|
RHI::AttachmentId m_id;
|
|
RHI::BufferDescriptor m_descriptor;
|
|
};
|
|
|
|
class ScopeProducer
|
|
: public RHI::ScopeProducer
|
|
{
|
|
public:
|
|
AZ_CLASS_ALLOCATOR(ScopeProducer, SystemAllocator, 0);
|
|
|
|
ScopeProducer(const RHI::ScopeId& scopeId)
|
|
: RHI::ScopeProducer(scopeId)
|
|
{}
|
|
|
|
void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override
|
|
{
|
|
RHI::FrameGraphAttachmentInterface attachmentDatabase = frameGraph.GetAttachmentDatabase();
|
|
|
|
for (ImportedImage& image : m_imageImports)
|
|
{
|
|
ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
|
|
attachmentDatabase.ImportImage(image.m_id, image.m_image);
|
|
ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
|
|
}
|
|
|
|
for (ImportedBuffer& buffer : m_bufferImports)
|
|
{
|
|
ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
|
|
attachmentDatabase.ImportBuffer(buffer.m_id, buffer.m_buffer);
|
|
ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
|
|
}
|
|
|
|
for (const TransientImage& image : m_transientImages)
|
|
{
|
|
ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(image.m_id));
|
|
attachmentDatabase.CreateTransientImage(RHI::TransientImageDescriptor{image.m_id, image.m_descriptor});
|
|
ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(image.m_id));
|
|
}
|
|
|
|
for (const TransientBuffer& buffer : m_transientBuffers)
|
|
{
|
|
ASSERT_FALSE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
|
|
attachmentDatabase.CreateTransientBuffer(RHI::TransientBufferDescriptor{buffer.m_id, buffer.m_descriptor});
|
|
ASSERT_TRUE(attachmentDatabase.IsAttachmentValid(buffer.m_id));
|
|
}
|
|
|
|
for (const ImageUsage& usage : m_imageUsages)
|
|
{
|
|
frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access);
|
|
}
|
|
|
|
for (const BufferUsage& usage : m_bufferUsages)
|
|
{
|
|
frameGraph.UseShaderAttachment(usage.m_descriptor, usage.m_access);
|
|
}
|
|
}
|
|
|
|
void CompileResources(const RHI::FrameGraphCompileContext& context) override
|
|
{
|
|
ASSERT_TRUE(context.GetScopeId() == GetScopeId());
|
|
|
|
for (const ImageUsage& usage : m_imageUsages)
|
|
{
|
|
ASSERT_TRUE(context.GetImageView(usage.m_descriptor.m_attachmentId) != nullptr);
|
|
}
|
|
|
|
for (const BufferUsage& usage : m_bufferUsages)
|
|
{
|
|
ASSERT_TRUE(context.GetBufferView(usage.m_descriptor.m_attachmentId) != nullptr);
|
|
}
|
|
}
|
|
|
|
void BuildCommandList(const RHI::FrameGraphExecuteContext& context) override
|
|
{
|
|
ASSERT_TRUE(context.GetScopeId() == GetScopeId());
|
|
ASSERT_TRUE(context.GetCommandListIndex() == 0);
|
|
ASSERT_TRUE(context.GetCommandListCount() == 1);
|
|
}
|
|
|
|
AZStd::vector<ImportedImage> m_imageImports;
|
|
AZStd::vector<ImportedBuffer> m_bufferImports;
|
|
AZStd::vector<TransientImage> m_transientImages;
|
|
AZStd::vector<TransientBuffer> m_transientBuffers;
|
|
|
|
struct ImageUsage
|
|
{
|
|
RHI::ImageScopeAttachmentDescriptor m_descriptor;
|
|
RHI::ScopeAttachmentAccess m_access;
|
|
};
|
|
|
|
struct BufferUsage
|
|
{
|
|
RHI::BufferScopeAttachmentDescriptor m_descriptor;
|
|
RHI::ScopeAttachmentAccess m_access;
|
|
};
|
|
|
|
AZStd::vector<ImageUsage> m_imageUsages;
|
|
AZStd::vector<BufferUsage> m_bufferUsages;
|
|
};
|
|
|
|
class FrameSchedulerTests
|
|
: public RHITestFixture
|
|
{
|
|
public:
|
|
FrameSchedulerTests()
|
|
: RHITestFixture()
|
|
{
|
|
}
|
|
|
|
void SetUp() override
|
|
{
|
|
UnitTest::RHITestFixture::SetUp();
|
|
|
|
m_rootFactory.reset(aznew Factory());
|
|
|
|
RHI::Ptr<RHI::Device> device = MakeTestDevice();
|
|
|
|
m_device = device;
|
|
m_state.reset(new State);
|
|
|
|
{
|
|
m_state->m_bufferPool = RHI::Factory::Get().CreateBufferPool();
|
|
|
|
RHI::BufferPoolDescriptor desc;
|
|
desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
|
|
m_state->m_bufferPool->Init(*device, desc);
|
|
}
|
|
|
|
for (uint32_t i = 0; i < ImportedBufferCount; ++i)
|
|
{
|
|
RHI::Ptr<RHI::Buffer> buffer;
|
|
buffer = RHI::Factory::Get().CreateBuffer();
|
|
|
|
RHI::BufferDescriptor desc;
|
|
desc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
|
|
desc.m_byteCount = BufferSize;
|
|
|
|
RHI::BufferInitRequest request;
|
|
request.m_descriptor = desc;
|
|
request.m_buffer = buffer.get();
|
|
m_state->m_bufferPool->InitBuffer(request);
|
|
|
|
m_state->m_bufferAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("B%d", i)};
|
|
m_state->m_bufferAttachments[i].m_buffer = AZStd::move(buffer);
|
|
}
|
|
|
|
{
|
|
m_state->m_imagePool = RHI::Factory::Get().CreateImagePool();
|
|
|
|
RHI::ImagePoolDescriptor desc;
|
|
desc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
|
|
m_state->m_imagePool->Init(*device, desc);
|
|
}
|
|
|
|
for (uint32_t i = 0; i < ImportedImageCount; ++i)
|
|
{
|
|
RHI::Ptr<RHI::Image> image;
|
|
image = RHI::Factory::Get().CreateImage();
|
|
|
|
RHI::ImageDescriptor desc = RHI::ImageDescriptor::Create2D(
|
|
RHI::ImageBindFlags::ShaderReadWrite,
|
|
ImageSize,
|
|
ImageSize,
|
|
RHI::Format::R8G8B8A8_UNORM);
|
|
|
|
RHI::ImageInitRequest request;
|
|
request.m_descriptor = desc;
|
|
request.m_image = image.get();
|
|
m_state->m_imagePool->InitImage(request);
|
|
|
|
m_state->m_imageAttachments[i].m_id = RHI::AttachmentId{AZStd::string::format("I%d", i)};
|
|
m_state->m_imageAttachments[i].m_image = AZStd::move(image);
|
|
}
|
|
|
|
for (uint32_t i = 0; i < ScopeCount; ++i)
|
|
{
|
|
m_state->m_producers.emplace_back(aznew ScopeProducer(RHI::ScopeId{AZStd::string::format("S%d", i)}));
|
|
}
|
|
}
|
|
|
|
void TearDown() override
|
|
{
|
|
m_state.reset();
|
|
m_device = nullptr;
|
|
m_rootFactory.reset();
|
|
RHITestFixture::TearDown();
|
|
}
|
|
|
|
void Test()
|
|
{
|
|
RHI::FrameScheduler frameScheduler;
|
|
|
|
RHI::FrameSchedulerDescriptor descriptor;
|
|
descriptor.m_transientAttachmentPoolDescriptor.m_bufferBudgetInBytes = 80 * 1024 * 1024;
|
|
frameScheduler.Init(*m_device, descriptor);
|
|
|
|
RHI::ImageScopeAttachmentDescriptor imageBindingDescs[2];
|
|
imageBindingDescs[0].m_imageViewDescriptor = RHI::ImageViewDescriptor();
|
|
imageBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
|
|
imageBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
|
|
imageBindingDescs[1] = imageBindingDescs[0];
|
|
imageBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
|
|
|
|
RHI::BufferScopeAttachmentDescriptor bufferBindingDescs[2];
|
|
bufferBindingDescs[0].m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, BufferSize);
|
|
bufferBindingDescs[0].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
|
|
bufferBindingDescs[0].m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 0.0, 0.0, 0.0);
|
|
bufferBindingDescs[1] = bufferBindingDescs[0];
|
|
bufferBindingDescs[1].m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
|
|
|
|
AZ::SimpleLcgRandom random;
|
|
|
|
struct Interval
|
|
{
|
|
uint32_t m_begin;
|
|
uint32_t m_end;
|
|
};
|
|
|
|
Interval bufferScopeIntervals[BufferCount];
|
|
for (uint32_t i = 0; i < BufferCount; ++i)
|
|
{
|
|
uint32_t b = random.GetRandom() % ScopeCount;
|
|
uint32_t e = random.GetRandom() % ScopeCount;
|
|
if (b > e)
|
|
{
|
|
AZStd::swap(b, e);
|
|
}
|
|
|
|
bufferScopeIntervals[i].m_begin = b;
|
|
bufferScopeIntervals[i].m_end = e;
|
|
}
|
|
|
|
Interval imageScopeIntervals[ImageCount];
|
|
for (uint32_t i = 0; i < ImageCount; ++i)
|
|
{
|
|
uint32_t b = random.GetRandom() % ScopeCount;
|
|
uint32_t e = random.GetRandom() % ScopeCount;
|
|
if (b > e)
|
|
{
|
|
AZStd::swap(b, e);
|
|
}
|
|
|
|
imageScopeIntervals[i].m_begin = b;
|
|
imageScopeIntervals[i].m_end = e;
|
|
}
|
|
|
|
for (uint32_t scopeIdx = 0; scopeIdx < ScopeCount; ++scopeIdx)
|
|
{
|
|
ScopeProducer& producer = *m_state->m_producers[scopeIdx];
|
|
|
|
//
|
|
// IMPORTS
|
|
//
|
|
|
|
for (uint32_t i = 0; i < ImportedBufferCount; ++i)
|
|
{
|
|
if (scopeIdx == bufferScopeIntervals[i].m_begin)
|
|
{
|
|
producer.m_bufferImports.push_back(m_state->m_bufferAttachments[i]);
|
|
bufferBindingDescs[0].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
|
|
producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
|
|
}
|
|
else if (scopeIdx == bufferScopeIntervals[i].m_end)
|
|
{
|
|
bufferBindingDescs[1].m_attachmentId = m_state->m_bufferAttachments[i].m_id;
|
|
producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
|
|
}
|
|
}
|
|
|
|
for (uint32_t i = 0; i < ImportedImageCount; ++i)
|
|
{
|
|
if (scopeIdx == imageScopeIntervals[i].m_begin)
|
|
{
|
|
producer.m_imageImports.push_back(m_state->m_imageAttachments[i]);
|
|
imageBindingDescs[0].m_attachmentId = m_state->m_imageAttachments[i].m_id;
|
|
producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
|
|
}
|
|
else if (scopeIdx == imageScopeIntervals[i].m_end)
|
|
{
|
|
imageBindingDescs[1].m_attachmentId = m_state->m_imageAttachments[i].m_id;
|
|
producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
|
|
}
|
|
}
|
|
|
|
//
|
|
// TRANSIENTS
|
|
//
|
|
|
|
for (uint32_t i = 0; i < TransientBufferCount; ++i)
|
|
{
|
|
const uint32_t adjustedIndex = i + ImportedBufferCount;
|
|
|
|
TransientBuffer transientBuffer =
|
|
{
|
|
RHI::AttachmentId{AZStd::string::format("B%d", adjustedIndex)},
|
|
RHI::BufferDescriptor(RHI::BufferBindFlags::ShaderReadWrite, BufferSize)
|
|
};
|
|
|
|
bufferBindingDescs[0].m_attachmentId = transientBuffer.m_id;
|
|
bufferBindingDescs[1].m_attachmentId = transientBuffer.m_id;
|
|
|
|
if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_begin)
|
|
{
|
|
producer.m_transientBuffers.push_back(transientBuffer);
|
|
producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
|
|
}
|
|
else if (scopeIdx == bufferScopeIntervals[adjustedIndex].m_end)
|
|
{
|
|
producer.m_bufferUsages.push_back(ScopeProducer::BufferUsage{ bufferBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
|
|
}
|
|
}
|
|
|
|
for (uint32_t i = 0; i < TransientImageCount; ++i)
|
|
{
|
|
const uint32_t adjustedIndex = i + ImportedImageCount;
|
|
|
|
TransientImage transientImage =
|
|
{
|
|
RHI::AttachmentId{AZStd::string::format("I%d", adjustedIndex)},
|
|
RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, ImageSize, ImageSize, RHI::Format::R8G8B8A8_UNORM)
|
|
};
|
|
|
|
imageBindingDescs[0].m_attachmentId = transientImage.m_id;
|
|
imageBindingDescs[1].m_attachmentId = transientImage.m_id;
|
|
|
|
if (scopeIdx == imageScopeIntervals[adjustedIndex].m_begin)
|
|
{
|
|
producer.m_transientImages.push_back(transientImage);
|
|
producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[0], RHI::ScopeAttachmentAccess::ReadWrite });
|
|
}
|
|
else if (scopeIdx == imageScopeIntervals[adjustedIndex].m_end)
|
|
{
|
|
producer.m_imageUsages.push_back(ScopeProducer::ImageUsage{ imageBindingDescs[1], RHI::ScopeAttachmentAccess::Read });
|
|
}
|
|
}
|
|
}
|
|
|
|
for (uint32_t frameIdx = 0; frameIdx < FrameIterationCount; ++frameIdx)
|
|
{
|
|
frameScheduler.BeginFrame();
|
|
|
|
for (AZStd::unique_ptr<ScopeProducer>& producer : m_state->m_producers)
|
|
{
|
|
frameScheduler.ImportScopeProducer(*producer);
|
|
}
|
|
|
|
RHI::FrameSchedulerCompileRequest compileRequest;
|
|
compileRequest.m_jobPolicy = RHI::JobPolicy::Serial;
|
|
frameScheduler.Compile(compileRequest);
|
|
|
|
frameScheduler.Execute(RHI::JobPolicy::Serial);
|
|
|
|
frameScheduler.EndFrame();
|
|
}
|
|
|
|
frameScheduler.Shutdown();
|
|
}
|
|
|
|
private:
|
|
static const uint32_t FrameIterationCount = 128;
|
|
static const uint32_t ImportedImageCount = 16;
|
|
static const uint32_t ImportedBufferCount = 16;
|
|
static const uint32_t TransientBufferCount = 16;
|
|
static const uint32_t TransientImageCount = 16;
|
|
static const uint32_t BufferCount = ImportedBufferCount + TransientBufferCount;
|
|
static const uint32_t ImageCount = ImportedImageCount + TransientImageCount;
|
|
static const uint32_t BufferSize = 64;
|
|
static const uint32_t ImageSize = 16;
|
|
static const uint32_t ScopeCount = 16;
|
|
|
|
AZStd::unique_ptr<Factory> m_rootFactory;
|
|
RHI::Ptr<RHI::Device> m_device;
|
|
|
|
struct State
|
|
{
|
|
RHI::Ptr<RHI::BufferPool> m_bufferPool;
|
|
RHI::Ptr<RHI::ImagePool> m_imagePool;
|
|
ImportedImage m_imageAttachments[ImportedImageCount];
|
|
ImportedBuffer m_bufferAttachments[ImportedBufferCount];
|
|
AZStd::vector<AZStd::unique_ptr<ScopeProducer>> m_producers;
|
|
};
|
|
|
|
AZStd::unique_ptr<State> m_state;
|
|
};
|
|
|
|
// Runs the randomized frame-scheduling scenario implemented by FrameSchedulerTests::Test().
TEST_F(FrameSchedulerTests, Test)
{
    // `this->` makes it explicit that this is the fixture's member function, not a reference
    // to the ::testing::Test base class whose name it shares.
    this->Test();
}
|
|
}
|