/*
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
* its licensors.
*
* For complete copyright and license terms please see the LICENSE at the root of this
* distribution (the "License"). All use of this software is governed by the License,
* or, if provided, by the license below or the license accompanying this file. Do not
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/

#include "RHITestFixture.h"
// NOTE(review): the header names of the eight includes below were missing from the text
// under review (everything between angle brackets had been stripped). They are
// reconstructed from the symbols this file uses - confirm against version control.
#include <Tests/Factory.h>
#include <Atom/RHI/DrawPacket.h>
#include <Atom/RHI/DrawPacketBuilder.h>
#include <Atom/RHI/DrawListContext.h>
#include <Atom/RHI/DrawListTagRegistry.h>
#include <Atom/RHI/PipelineState.h>

#include <AzCore/Math/Random.h>
#include <AzCore/std/sort.h>

namespace UnitTest
{
    using namespace AZ;

    // Expected values for a single draw item. Everything except the pipeline state is
    // filled from the supplied random generator, so two runs that start from the same
    // generator state produce the same data.
    struct DrawItemData
    {
        // random      - generator the expected values are drawn from.
        // bufferEmpty - buffer every stream buffer view is created against; dereferenced, must not be null.
        // psoEmpty    - pipeline state pointer stored as-is for later comparison.
        DrawItemData(SimpleLcgRandom& random, const RHI::Buffer* bufferEmpty, const RHI::PipelineState* psoEmpty)
        {
            m_pipelineState = psoEmpty;

            // Fill with deterministic random data to compare against.
            for (auto& streamBufferView : m_streamBufferViews)
            {
                streamBufferView = RHI::StreamBufferView{ *bufferEmpty, random.GetRandom(), random.GetRandom(), random.GetRandom() };
            }

            // Keep the tag index inside the valid draw list tag range.
            m_tag = RHI::DrawListTag(random.GetRandom() % RHI::Limits::Pipeline::DrawListTagCountMax);
            m_stencilRef = static_cast<uint8_t>(random.GetRandom());
            m_sortKey = random.GetRandom();
        }

        // NOTE(review): the template arguments of this array were missing from the reviewed
        // text. The element type follows from the constructor; the extent is assumed - confirm.
        AZStd::array<RHI::StreamBufferView, RHI::Limits::Pipeline::StreamCountMax> m_streamBufferViews;
        const RHI::PipelineState* m_pipelineState;
        RHI::DrawListTag m_tag;
        RHI::DrawItemSortKey m_sortKey;
        uint8_t m_stencilRef;
    };

    // Owns the resources and expected values for one complete draw packet, and knows how
    // to feed them through a DrawPacketBuilder and check the packet that comes out.
    struct DrawPacketData
    {
        // Number of draw items generated for (and expected in) the packet.
        static constexpr size_t DrawItemCountMax = 8;

        // Creates the buffer, pipeline state and shader resource groups through
        // RHI::Factory::Get(), then fills root constants, draw items and the index buffer
        // view from 'random'.
        DrawPacketData(SimpleLcgRandom& random)
        {
            m_bufferEmpty = RHI::Factory::Get().CreateBuffer();
            m_psoEmpty = RHI::Factory::Get().CreatePipelineState();

            for (auto& srg : m_srgs)
            {
                srg = RHI::Factory::Get().CreateShaderResourceGroup();
            }

            // The root constant bytes are written one 'unsigned int' at a time.
            unsigned int* data = reinterpret_cast<unsigned int*>(m_rootConstants.data());
            for (uint32_t i = 0; i < m_rootConstants.size() / sizeof(unsigned int); ++i)
            {
                data[i] = random.GetRandom();
            }

            m_drawItemDatas.reserve(DrawItemCountMax);
            for (size_t i = 0; i < DrawItemCountMax; ++i)
            {
                m_drawItemDatas.emplace_back(random, m_bufferEmpty.get(), m_psoEmpty.get());
            }

            m_indexBufferView = RHI::IndexBufferView(*m_bufferEmpty, random.GetRandom(), random.GetRandom(), RHI::IndexFormat::Uint16);
        }

        // Checks one built draw item (and its sort key) against the expected per-item
        // data and against the packet-wide data held by this object.
        void ValidateDrawItem(const DrawItemData& drawItemData, RHI::DrawItemProperties itemProperties) const
        {
            const RHI::DrawItem* drawItem = itemProperties.m_item;

            EXPECT_EQ(itemProperties.m_sortKey, drawItemData.m_sortKey);
            EXPECT_EQ(drawItem->m_stencilRef, drawItemData.m_stencilRef);
            EXPECT_EQ(drawItem->m_pipelineState, drawItemData.m_pipelineState);

            // Stream buffer views are compared by hash.
            EXPECT_EQ(static_cast<size_t>(drawItem->m_streamBufferViewCount), drawItemData.m_streamBufferViews.size());
            for (size_t i = 0; i < static_cast<size_t>(drawItem->m_streamBufferViewCount); ++i)
            {
                EXPECT_EQ(drawItemData.m_streamBufferViews[i].GetHash(), drawItem->m_streamBufferViews[i].GetHash());
            }

            EXPECT_EQ(static_cast<size_t>(drawItem->m_shaderResourceGroupCount), m_srgs.size());
            for (size_t i = 0; i < m_srgs.size(); ++i)
            {
                EXPECT_EQ(m_srgs[i], drawItem->m_shaderResourceGroups[i]);
            }

            EXPECT_EQ(static_cast<size_t>(drawItem->m_rootConstantSize), m_rootConstants.size());
            for (size_t i = 0; i < m_rootConstants.size(); ++i)
            {
                EXPECT_EQ(m_rootConstants[i], drawItem->m_rootConstants[i]);
            }

            EXPECT_EQ(drawItem->m_indexBufferView->GetHash(), m_indexBufferView.GetHash());
        }

        // Feeds all stored data through 'builder', validates the resulting packet and
        // returns it. The tests in this file delete the returned packet themselves.
        // Returns nullptr (after recording a test failure) if the builder produced no packet.
        const RHI::DrawPacket* Build(RHI::DrawPacketBuilder& builder)
        {
            builder.Begin(nullptr);

            for (auto& srgPtr : m_srgs)
            {
                builder.AddShaderResourceGroup(srgPtr.get());
            }

            builder.SetRootConstants(m_rootConstants);
            builder.SetIndexBufferView(m_indexBufferView);

            // Accumulates the tags of all submitted items; compared with the packet's mask below.
            RHI::DrawListMask drawListMask;

            for (const DrawItemData& drawItemData : m_drawItemDatas)
            {
                drawListMask[drawItemData.m_tag.GetIndex()] = true;

                RHI::DrawPacketBuilder::DrawRequest drawRequest;
                drawRequest.m_listTag = drawItemData.m_tag;
                drawRequest.m_sortKey = drawItemData.m_sortKey;
                drawRequest.m_stencilRef = drawItemData.m_stencilRef;
                drawRequest.m_streamBufferViews = drawItemData.m_streamBufferViews;
                drawRequest.m_pipelineState = drawItemData.m_pipelineState;
                builder.AddDrawItem(drawRequest);
            }

            const RHI::DrawPacket* drawPacket = builder.End();

            EXPECT_NE(drawPacket, nullptr);
            if (drawPacket == nullptr)
            {
                // EXPECT_* does not abort and ASSERT_* cannot be used in a function that
                // returns a value, so bail out explicitly instead of dereferencing null.
                return nullptr;
            }

            EXPECT_EQ(drawPacket->GetDrawListMask(), drawListMask);
            EXPECT_EQ(drawPacket->GetDrawItemCount(), m_drawItemDatas.size());

            for (size_t i = 0; i < m_drawItemDatas.size(); ++i)
            {
                ValidateDrawItem(m_drawItemDatas[i], drawPacket->GetDrawItem(i));
            }

            return drawPacket;
        }

        RHI::Ptr<RHI::Buffer> m_bufferEmpty;
        RHI::ConstPtr<RHI::PipelineState> m_psoEmpty;
        AZStd::array<RHI::Ptr<RHI::ShaderResourceGroup>, RHI::Limits::Pipeline::ShaderResourceGroupCountMax> m_srgs;
        // NOTE(review): the template arguments of this array were missing from the reviewed
        // text. Byte elements follow from how the array is filled and compared; the extent
        // is assumed - confirm.
        AZStd::array<uint8_t, sizeof(unsigned int) * 4> m_rootConstants;
        RHI::IndexBufferView m_indexBufferView;
        AZStd::vector<DrawItemData> m_drawItemDatas;
    };

    // Fixture that creates a Factory instance and a fresh draw list tag registry for
    // every test and releases both again in TearDown.
    class DrawPacketTest
        : public RHITestFixture
    {
    public:
        void SetUp() override
        {
            RHITestFixture::SetUp();

            m_factory.reset(aznew Factory());
            m_drawListTagRegistry = RHI::DrawListTagRegistry::Create();
        }

        void TearDown() override
        {
            // Release in reverse order of creation before the base fixture shuts down.
            m_drawListTagRegistry = nullptr;
            m_factory.reset();

            RHITestFixture::TearDown();
        }

    protected:
        // Fixed seed so every run sees the same pseudo-random sequence.
        static const uint32_t s_randomSeed = 1234;

        RHI::Ptr<RHI::DrawListTagRegistry> m_drawListTagRegistry;
        RHI::DrawListContext m_drawListContext;
        AZStd::unique_ptr<Factory> m_factory;
    };

    // Acquiring with a default-constructed Name yields a null tag and allocates nothing;
    // releasing that null tag leaves the count at zero.
    TEST_F(DrawPacketTest, TestDrawListTagRegistryNullCase)
    {
        RHI::DrawListTag nullTag = m_drawListTagRegistry->AcquireTag(Name());
        EXPECT_TRUE(nullTag.IsNull());
        EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), 0);

        m_drawListTagRegistry->ReleaseTag(nullTag);
        EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), 0);
    }

    // Names differing only in case map to different tags, the same name maps to the same
    // tag, and one release per acquire empties the registry.
    TEST_F(DrawPacketTest, TestDrawListTagRegistrySimple)
    {
        const Name forwardName1("Forward");
        const Name forwardName2("forward");

        RHI::DrawListTag forwardTag1 = m_drawListTagRegistry->AcquireTag(forwardName1);
        RHI::DrawListTag forwardTag2 = m_drawListTagRegistry->AcquireTag(forwardName2);

        EXPECT_FALSE(forwardTag1.IsNull());
        EXPECT_FALSE(forwardTag2.IsNull());
        EXPECT_NE(forwardTag1, forwardTag2);

        RHI::DrawListTag forwardTag3 = m_drawListTagRegistry->AcquireTag(forwardName1);
        EXPECT_EQ(forwardTag1, forwardTag3);

        m_drawListTagRegistry->ReleaseTag(forwardTag1);
        m_drawListTagRegistry->ReleaseTag(forwardTag2);
        m_drawListTagRegistry->ReleaseTag(forwardTag3);

        EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), 0);
    }

    // Three acquires followed by four releases of the same tag: the extra release is
    // expected to trigger an assert.
    TEST_F(DrawPacketTest, TestDrawListTagRegistryDeAllocateAssert)
    {
        AZ_TEST_START_ASSERTTEST;

        EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), 0);

        const Name tagName{"Test"};

        RHI::DrawListTag tag = m_drawListTagRegistry->AcquireTag(tagName);
        m_drawListTagRegistry->AcquireTag(tagName);
        m_drawListTagRegistry->AcquireTag(tagName);

        m_drawListTagRegistry->ReleaseTag(tag);
        m_drawListTagRegistry->ReleaseTag(tag);
        m_drawListTagRegistry->ReleaseTag(tag);

        // One additional forfeit should assert.
        m_drawListTagRegistry->ReleaseTag(tag);

        AZ_TEST_STOP_ASSERTTEST(1);
    }

    // Randomly interleaves acquires of unique names with releases of previously acquired
    // tags and checks the registry's allocation count against local bookkeeping.
    TEST_F(DrawPacketTest, TestDrawListTagRegistryRandomAllocations)
    {
        AZ::SimpleLcgRandom random(s_randomSeed);

        AZStd::vector<RHI::DrawListTag> acquiredTags;

        const uint32_t IterationCount = 1000;

        for (uint32_t iter = 0; iter < IterationCount; ++iter)
        {
            // 'iter' is unsigned, so format it with %u.
            Name tagNameUnique = Name(AZStd::string::format("Tag_%u", iter));

            // Acquire
            if (random.GetRandom() % 2)
            {
                RHI::DrawListTag tag = m_drawListTagRegistry->AcquireTag(tagNameUnique);

                if (tag.IsNull())
                {
                    EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), RHI::Limits::Pipeline::DrawListTagCountMax);
                }
                else
                {
                    // The count already includes the tag just acquired, so it may
                    // legitimately be equal to the maximum here.
                    EXPECT_LE(m_drawListTagRegistry->GetAllocatedTagCount(), RHI::Limits::Pipeline::DrawListTagCountMax);
                    acquiredTags.emplace_back(tag);
                }
            }

            // Forfeit
            else if (!acquiredTags.empty())
            {
                size_t tagIndex = random.GetRandom() % static_cast<uint32_t>(acquiredTags.size());

                RHI::DrawListTag tag = acquiredTags[tagIndex];

                size_t allocationCountBefore = m_drawListTagRegistry->GetAllocatedTagCount();
                m_drawListTagRegistry->ReleaseTag(tag);
                size_t allocationCountAfter = m_drawListTagRegistry->GetAllocatedTagCount();

                EXPECT_EQ(allocationCountBefore - allocationCountAfter, 1);

                acquiredTags.erase(acquiredTags.begin() + tagIndex);
            }

            EXPECT_EQ(acquiredTags.size(), m_drawListTagRegistry->GetAllocatedTagCount());
        }

        // Erase all references, make sure the registry is empty again.
        for (RHI::DrawListTag tag : acquiredTags)
        {
            m_drawListTagRegistry->ReleaseTag(tag);
        }
        acquiredTags.clear();

        EXPECT_EQ(m_drawListTagRegistry->GetAllocatedTagCount(), 0);
    }

    // Begin() directly followed by End() is expected to yield no packet.
    TEST_F(DrawPacketTest, DrawPacketEmpty)
    {
        RHI::DrawPacketBuilder builder;
        builder.Begin(nullptr);

        const RHI::DrawPacket* drawPacket = builder.End();
        EXPECT_EQ(drawPacket, nullptr);
    }

    // Adding only a default-constructed draw request is expected to yield no packet.
    TEST_F(DrawPacketTest, DrawPacketNullItem)
    {
        RHI::DrawPacketBuilder builder;
        builder.Begin(nullptr);

        RHI::DrawPacketBuilder::DrawRequest drawRequest;
        builder.AddDrawItem(drawRequest);

        const RHI::DrawPacket* drawPacket = builder.End();
        EXPECT_EQ(drawPacket, nullptr);
    }

    // Builds a fully populated packet; all checks happen inside DrawPacketData::Build.
    TEST_F(DrawPacketTest, DrawPacketBuild)
    {
        AZ::SimpleLcgRandom random(s_randomSeed);

        DrawPacketData drawPacketData(random);

        RHI::DrawPacketBuilder builder;
        const RHI::DrawPacket* drawPacket = drawPacketData.Build(builder);
        delete drawPacket;
    }

    // Reuses a builder: after building a full packet, an empty Begin()/End() pair is
    // expected to yield a null packet.
    TEST_F(DrawPacketTest, DrawPacketBuildClearBuildNull)
    {
        AZ::SimpleLcgRandom random(s_randomSeed);

        DrawPacketData drawPacketData(random);

        RHI::DrawPacketBuilder builder;
        const RHI::DrawPacket* drawPacket = drawPacketData.Build(builder);
        delete drawPacket;

        // Try to build a 'null' packet. This should result in a null pointer.
        builder.Begin(nullptr);
        drawPacket = builder.End();
        EXPECT_EQ(drawPacket, nullptr);
    }

    // With every bit of the mask set, the context's lists are empty before
    // FinalizeLists() and afterwards match the packet's items grouped by tag and sorted.
    TEST_F(DrawPacketTest, DrawListContextFilter)
    {
        AZ::SimpleLcgRandom random(s_randomSeed);

        DrawPacketData drawPacketData(random);

        RHI::DrawPacketBuilder builder;
        const RHI::DrawPacket* drawPacket = drawPacketData.Build(builder);
        // Everything below dereferences the packet.
        ASSERT_NE(drawPacket, nullptr);

        RHI::DrawListContext drawListContext;
        drawListContext.Init(RHI::DrawListMask{}.set());
        drawListContext.AddDrawPacket(drawPacket);

        for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
        {
            RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);

            RHI::DrawListView drawList = drawListContext.GetList(tag);
            EXPECT_TRUE(drawList.empty());
        }

        drawListContext.FinalizeLists();

        // Build the expected per-tag lists directly from the packet.
        RHI::DrawListsByTag listsByTag;

        for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
        {
            RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);

            listsByTag[tag.GetIndex()].push_back(drawPacket->GetDrawItem(i));
        }

        size_t tagIndex = 0;
        for (auto& drawList : listsByTag)
        {
            SortDrawList(drawList, RHI::DrawListSortType::KeyThenDepth);

            RHI::DrawListTag tag(tagIndex);

            RHI::DrawListView drawListView = drawListContext.GetList(tag);
            EXPECT_EQ(drawListView.size(), drawList.size());

            for (size_t i = 0; i < drawList.size(); ++i)
            {
                EXPECT_EQ(drawList[i], drawListView[i]);
            }

            tagIndex++;
        }

        drawListContext.Shutdown();

        delete drawPacket;
    }

    // With an empty mask, every list of the context stays empty after finalization.
    TEST_F(DrawPacketTest, DrawListContextNullFilter)
    {
        AZ::SimpleLcgRandom random(s_randomSeed);

        DrawPacketData drawPacketData(random);

        RHI::DrawPacketBuilder builder;
        const RHI::DrawPacket* drawPacket = drawPacketData.Build(builder);
        // Everything below dereferences the packet.
        ASSERT_NE(drawPacket, nullptr);

        RHI::DrawListContext drawListContext;
        drawListContext.Init(RHI::DrawListMask{}); // Mask set to not contain any draw lists.
        drawListContext.AddDrawPacket(drawPacket);
        drawListContext.FinalizeLists();

        for (size_t i = 0; i < drawPacket->GetDrawItemCount(); ++i)
        {
            RHI::DrawListTag tag = drawPacket->GetDrawListTag(i);
            RHI::DrawListView drawList = drawListContext.GetList(tag);
            EXPECT_TRUE(drawList.empty());
        }

        drawListContext.Shutdown();

        delete drawPacket;
    }
}

AZ_UNIT_TEST_HOOK(DEFAULT_UNIT_TEST_ENV);