You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Assets/Shaders/PostProcessing/DepthUpsample.azsl

257 lines
9.9 KiB
Plaintext

/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#include <Atom/Features/SrgSemantics.azsli>
// --- Algorithm Overview ---
//
// This shader will upsample an input image using two input depth textures.
// For simplicity, we call the source image 'sourceHalfRes', the low resolution depth 'depthHalfRes',
// and the higher resolution depth 'depthFullRes' (which has the same resolution as the output image).
// In order to reduce texture operations, each thread writes 2x2 pixels in the target output image.
// This allows each thread to re-use results from texture gather operations between output pixels.
//
// To illustrate, consider the following texture (each number denotes a pixel)
//
// 00 01 02 03 04 05 06 07 08 09
//
// 10 11 12 13 14 15 16 17 18 19
//
// 20 21 22 23 24 25 26 27 28 29
//
// 30 31 32 33 34 35 36 37 38 39
//
// The downsampled version of this texture would have pixels at the following H* locations:
//
// 00 01 02 03 04 05 06 07 08 09
// H0 H1 H2 H3 H4
// 10 11 12 13 14 15 16 17 18 19
//
// 20 21 22 23 24 25 26 27 28 29
// H5 H6 H7 H8 H9
// 30 31 32 33 34 35 36 37 38 39
//
// To calculate the upsampled output pixel (11), we need four half-res depth values (H0, H1, H5, H6),
// four half-res source values (H0, H1, H5, H6) and a full-res depth value (11).
// Note that these same half-res depth and source values are also used to calculate output pixels (12, 21, 22)
// Also note that pixels (H0, H1, H5, H6) can be fetched with a single gather, as can (11, 12, 21, 22)
//
// Thus, we can use a single thread to calculate and output upsampled pixels (11, 12, 21, 22)
// For this, the thread would only need to perform three gathers (assuming source is a single channel texture)
// Gather 1: half-res depth (H0, H1, H5, H6)
// Gather 2: half-res source (H0, H1, H5, H6)
// Gather 3: full-res depth (11, 12, 21, 22)
//
// Thus, we dispatch threads at the following T* locations
//
// T-00 T-01 T-02 T-03 T-04 T-05
// 00 01 02 03 04 05 06 07 08 09
// H0 H1 H2 H3 H4
// 10 11 12 13 14 15 16 17 18 19
// T-10 T-11 T-12 T-13 T-14 T-15
// 20 21 22 23 24 25 26 27 28 29
// H5 H6 H7 H8 H9
// 30 31 32 33 34 35 36 37 38 39
// T-20 T-21 T-22 T-23 T-24 T-25
//
// Continuing our example, here the thread T-11 would calculate the full res output pixels (11, 12, 21, 22)
// Two things to note about the thread dispatch:
//
// 1) the width and height of the thread group are equal to the width and height of the half-res textures + 1.
// The +1 is to have enough threads to output to row 3* and column *9.
// Note that if the full-res texture has uneven dimensions this +1 is not necessary:
//
// T-00 T-01 T-02 T-03 T-04
// 00 01 02 03 04 05 06 07 08
// H0 H1 H2 H3 H4
// 10 11 12 13 14 15 16 17 18
// T-10 T-11 T-12 T-13 T-14
// 20 21 22 23 24 25 26 27 28
// H5 H6 H7 H8 H9
//
// 2) While the thread dispatch has similar dimensions to the half-res textures, the positions are shifted by (-1, -1),
// i.e. they are shifted up and to the left by the width of a full-res pixel. This is so the threads can properly use
// texture gather instructions on both the downsampled depth/source and the full-res depth (see above example for T-11)
//
// Edge length of the compute thread group: MainCS is declared with THREADS x THREADS x 1 threads.
#define THREADS 16
// Per-pass resources for the depth-aware upsample: two depth inputs, the half-res image
// to upsample, the full-res output, the pixel-size constants and the sampler used for gathers.
ShaderResourceGroup PassSrg : SRG_PerPass
{
// Depth at the resolution of the output image; gathered once per thread in MainCS
Texture2D<float> m_depthFullRes;
// Depth at the resolution of the source image
Texture2D<float> m_depthHalfRes;
// Single channel image to upsample (gathered with the same UV as m_depthHalfRes)
Texture2D<float> m_sourceHalfRes;
// Upsampled result; each MainCS thread writes a 2x2 block of pixels here
RWTexture2D<float> m_outputFullRes;
// Must match the struct in DepthDownsamplePasses.cpp
struct UpsampleConstants
{
// The size of a pixel in the input image relative to screenspace UV
// Calculated by taking the inverse of the texture dimensions
// MainCS multiplies the thread position by this to get the half-res gather UV
float2 m_inputPixelSize;
// The size of a pixel in the output image relative to screenspace UV
// Calculated by taking the inverse of the texture dimensions
// MainCS multiplies twice the thread position by this to get the full-res gather UV
float2 m_outputPixelSize;
};
UpsampleConstants m_constants;
// Point-filtered, clamped sampler passed to every Gather call in MainCS
Sampler PointSampler
{
MinFilter = Point;
MagFilter = Point;
MipFilter = Point;
AddressU = Clamp;
AddressV = Clamp;
AddressW = Clamp;
};
}
// Returns a weight that grows as two depth values get closer together:
// the inverse of the squared depth difference.
// A small bias is added to the difference so that identical depths yield
// a large finite weight rather than a division by zero.
float GetDepthFactor(float depth1, float depth2)
{
    const float bias = 0.00001f;
    const float gap = abs(depth1 - depth2) + bias;
    return 1.0f / (gap * gap);
}
// Blends four gathered half-res source values into a single full-res output value.
//
// fullDepth    - depth of the full-res pixel being produced
// halfDepths   - the four gathered half-res depths, ordered relative to the output pixel:
//                x = nearest, y = horizontal neighbor, z = vertical neighbor, w = diagonal
// sourceValues - the four gathered half-res source values, in the same order as halfDepths
//
// Each sample is weighted by its bilinear weight multiplied by GetDepthFactor, and the
// weighted sum is normalized by the total weight.
float BlendHalfResSamples(float fullDepth, float4 halfDepths, float4 sourceValues)
{
    float blended = 0.0f;
    float totalWeight = 0.0f;
    float weight = 0.0f;

    // 0.75 and 0.25 here is how far this full-res pixel is from the half-res pixels we are sampling
    // Consider two half-res pixels and two full-res pixels in between. The full-res pixel on the right
    // is 3x closer to the half-res pixel on the right than the half-res pixel on the left, thus the
    // weights become 3/4 and 1/4 or 0.75 and 0.25

    // Nearest half-res pixel
    weight = (0.75f * 0.75f) * GetDepthFactor(fullDepth, halfDepths.x);
    blended += sourceValues.x * weight;
    totalWeight += weight;

    // Horizontal neighbor
    weight = (0.25f * 0.75f) * GetDepthFactor(fullDepth, halfDepths.y);
    blended += sourceValues.y * weight;
    totalWeight += weight;

    // Vertical neighbor
    weight = (0.75f * 0.25f) * GetDepthFactor(fullDepth, halfDepths.z);
    blended += sourceValues.z * weight;
    totalWeight += weight;

    // Diagonal neighbor
    weight = (0.25f * 0.25f) * GetDepthFactor(fullDepth, halfDepths.w);
    blended += sourceValues.w * weight;
    totalWeight += weight;

    return blended / totalWeight;
}

// Writes value to the output texture only when pixel lies inside it.
// Threads along the top/left edge compute a coordinate of -1, which wraps to a huge
// unsigned value, and threads along the bottom/right edge (or in the padding of the last
// thread group) can land past the texture. The explicit test drops those writes instead
// of relying on the graphics API to discard out-of-bounds stores.
void StoreIfInBounds(uint2 pixel, uint2 outputSize, float value)
{
    if (all(pixel < outputSize))
    {
        PassSrg::m_outputFullRes[pixel] = value;
    }
}

// Each thread upsamples one 2x2 block of full-res output pixels using three gathers:
// half-res depth, half-res source and full-res depth.
[numthreads(THREADS, THREADS, 1)]
void MainCS(uint3 dispatch_id: SV_DispatchThreadID)
{
    float2 position = dispatch_id.xy;

    // Gather half res depth and source values
    float2 halfResGatherUV = position * PassSrg::m_constants.m_inputPixelSize;
    float4 halfDepths = PassSrg::m_depthHalfRes.Gather(PassSrg::PointSampler, halfResGatherUV);
    float4 sourceValues = PassSrg::m_sourceHalfRes.Gather(PassSrg::PointSampler, halfResGatherUV);

    // Gather full res depth
    float2 fullResGatherUV = position * 2.0f * PassSrg::m_constants.m_outputPixelSize;
    float4 fullDepths = PassSrg::m_depthFullRes.Gather(PassSrg::PointSampler, fullResGatherUV);

    // Gather operation retrieves values with the following layout:
    //
    // W Z
    // X Y
    //
    // The swizzles below reorder the gathered half-res values for each output pixel as
    // (nearest, horizontal neighbor, vertical neighbor, diagonal) to match BlendHalfResSamples.
    const float outputW = BlendHalfResSamples(fullDepths.w, halfDepths.wzxy, sourceValues.wzxy);
    const float outputZ = BlendHalfResSamples(fullDepths.z, halfDepths.zwyx, sourceValues.zwyx);
    const float outputY = BlendHalfResSamples(fullDepths.y, halfDepths.yxzw, sourceValues.yxzw);
    const float outputX = BlendHalfResSamples(fullDepths.x, halfDepths.xywz, sourceValues.xywz);

    uint outputWidth = 0;
    uint outputHeight = 0;
    PassSrg::m_outputFullRes.GetDimensions(outputWidth, outputHeight);
    const uint2 outputSize = uint2(outputWidth, outputHeight);

    // The -1 shifts the 2x2 block up and to the left by one full-res pixel so it lines up
    // with the texels returned by the gathers. For dispatch_id 0 the result wraps around in
    // unsigned arithmetic; the following increments bring it back to 0.
    uint2 outputPixel = mad(dispatch_id.xy, 2, -1);
    StoreIfInBounds(outputPixel, outputSize, outputW); // top-left
    ++outputPixel.x;
    StoreIfInBounds(outputPixel, outputSize, outputZ); // top-right
    ++outputPixel.y;
    StoreIfInBounds(outputPixel, outputSize, outputY); // bottom-right
    --outputPixel.x;
    StoreIfInBounds(outputPixel, outputSize, outputX); // bottom-left
}