You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Assets/Shaders/PostProcessing/FastDepthAwareBlurCommon.azsli

263 lines
11 KiB
Plaintext

/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#pragma once
#include <Atom/Features/SrgSemantics.azsli>
#include <Atom/RPI/Math.azsli>
// --- Algorithm Overview ---
//
// This is a highly optimized blur that uses LDS (groupshared memory) and gather operations to reduce texture
// fetches and blur calculations. By using gather ops to read 4 values simultaneously, each thread can output
// 3 pixels instead of 1 (3x not 4x because although we are gathering 4 values per thread, some values need
// to be LDS padding, where the padding amount is equal to the blur radius).
//
// A normal Gaussian blur with a radius of 6 pixels would calculate the output pixel as such
// (X represents target output pixels and R represents reads from LDS within the 6 pixel blur radius)
//
// Blur loop direction:  -->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->
// LDS access thread 1:  R  R  R  R  R  R  X  R  R  R  R  R  R
// LDS access thread 2:     R  R  R  R  R  R  X  R  R  R  R  R  R
// LDS access thread 3:        R  R  R  R  R  R  X  R  R  R  R  R  R
//
// As you can see, each row is 13 reads from LDS, for a total of 39 LDS reads to output 3 pixels
//
// Now here is how things look with this algorithm:
//
// Blur direction loop 1:  -->-->-->-->-->-->-->-->-->-->-->
// LDS access thread 1:    R  R  R  R  R  R  X  X  X
//
// Blur direction loop 2:                    <--<--<--<--<--<--<--<--<--<--<--
// LDS access thread 1:                      X  X  X  R  R  R  R  R  R
//
// Instead of weighting each sample directly in relation to the output pixel, we accumulate values from one pixel
// to the next, using both a constant falloff (values of around 2/3 work well) and a falloff based on the depth
// slope (or depth gradient). We continuously accumulate values in one direction until we have the values for our
// three target output pixels, and then we loop in the opposite direction to blur from the other side. As you can
// see, to output 3 pixels we now only have to perform 9 + 9 = 18 LDS reads, instead of 39 in the example above.
//
// The depth falloff works by measuring the depth difference/slope/gradient between pixels. This slope is analogous
// to the surface normal. A constant slope value across pixels reflects a flat surface, and thus we want to blur more.
// A changing slope value across pixels reflects geometry changes/edges, and in such cases we want to blur less.
//
// Here's an example of how that works. The line represents the depth, and the numbers along the line the depth slope:
//
//   Depth
//
//   9               0  0  0  0                            (E)
//   8        (A) ---------------- (B)                      /\
//   7           /                \                    +1  /  \
//   6          /               -1 \                      /    \ -1
//   5         / +1                 \                    /      \
//   4        /                      \                  /        \
//   3       /                        \            +1  /          \
//   2      /                       -1 \              /            \ -1
//   1     / +1                         --------------
//                                  (C)    0  0  0     (D)
//
//
// In this example, our geometry is six flat surfaces with five edges labeled A, B, C, D and E. The slope values are
// written along the geometry line. To detect edges, we simply subtract neighboring slope values. If the difference
// is zero, then the geometry is a flat surface. If the difference is non-zero, then it's an edge. The greater the absolute
// value of the difference is, the sharper the edge. For the first flat surface, the slope is +1. Any two neighboring
// slope values will both be +1, and so the difference is 1 - 1 = 0, which confirms the flat surface. At point (A),
// the neighboring slope values are +1 and 0. The difference is 1 - 0 = 1, and so this means we have an edge. At point
// (E), the neighboring slope values are +1 and -1. The difference is 1 - (-1) = 2, which indicates a sharper edge than (A).
//
// By factoring in the slope difference into the falloff as we accumulate values across pixels, we create a depth aware blur
// that smooths values across flat surfaces but is very good at preserving edges and geometric detail.
//
// --- SRG ---
// Per-pass shader resources for the fast depth-aware blur: the two input textures, the blur output,
// the tuning constants and the sampler used for the gather reads.
ShaderResourceGroup PassSrg : SRG_PerPass
{
// Depth input. Gathered into LDS next to the source values and used to compute depth slopes.
Texture2D<float> m_linearDepth;
// Single-channel values to be blurred. Gathered with the same UV as the depth.
Texture2D<float> m_blurSource;
// Blur output. Not written anywhere in this part of the file.
// NOTE(review): presumably written by the shader entry point that includes this file - confirm.
RWTexture2D<float> m_blurTarget;
// Must match the struct in FastDepthAwareBlurPasses.h
// (member order and types are part of that contract, so do not reorder fields here)
struct FastDepthAwareBlurPassConstants
{
// The texture dimensions of blur output
uint2 m_outputSize;
// The size of a pixel relative to screenspace UV
// Calculated by taking the inverse of the texture dimensions
float2 m_pixelSize;
// The size of half a pixel relative to screenspace UV
float2 m_halfPixelSize;
// How much a value is reduced from pixel to pixel on a perfectly flat surface
// (multiplied with the depth falloff in BlurStep to form the per-step blend weight)
float m_constFalloff;
// Threshold used to reduce computed depth difference during blur and thus the depth falloff
// Can be thought of as a bias that blurs curved surfaces more like flat surfaces
// but generally not needed and can be set to 0.0f
float m_depthFalloffThreshold;
// How much the difference in depth slopes between pixels affects the blur falloff.
// The higher this value, the sharper edges will appear
float m_depthFalloffStrength;
// Unused by the shader code in this part of the file.
// NOTE(review): presumably rounds the 36 bytes of fields above up to 48 bytes (a 16-byte multiple) - confirm against FastDepthAwareBlurPasses.h
uint3 PADDING;
};
FastDepthAwareBlurPassConstants m_constants;
// Point-filtered, clamp-addressed sampler used by the Gather calls in GatherDepthAndValueToLDS
Sampler PointSampler
{
MinFilter = Point;
MagFilter = Point;
MipFilter = Point;
AddressU = Clamp;
AddressV = Clamp;
AddressW = Clamp;
};
}
// Accessors for the individual fields of PassSrg::m_constants.

// Blur output dimensions in texels. Stored as uint2 and returned as float2.
float2 GetOutputSize()
{
    return float2(PassSrg::m_constants.m_outputSize);
}

// Size of one pixel in screenspace UV units.
float2 GetPixelSize()
{
    return PassSrg::m_constants.m_pixelSize;
}

// Size of half a pixel in screenspace UV units.
float2 GetHalfPixelSize()
{
    return PassSrg::m_constants.m_halfPixelSize;
}

// Constant per-pixel falloff applied on perfectly flat surfaces.
float GetConstFalloff()
{
    return PassSrg::m_constants.m_constFalloff;
}

// Bias subtracted from the depth-slope difference before the falloff is computed.
float GetDepthFalloffThreshold()
{
    return PassSrg::m_constants.m_depthFalloffThreshold;
}

// Scale applied to the depth-slope difference; higher values preserve edges more strongly.
float GetDepthFalloffStrength()
{
    return PassSrg::m_constants.m_depthFalloffStrength;
}
// --- LDS FUNCTIONS ---
// Returns the depth sample stored in LDS (groupshared memory) for the 2D position 'ldsPosition'.
// depthLDS and GetLdsIndex are not declared in this part of the file; they must be in scope
// before this function is compiled.
float ReadDepthFromLDS(int2 ldsPosition)
{
    const float cachedDepth = depthLDS[ GetLdsIndex(ldsPosition) ];
    return cachedDepth;
}
// Returns the blur-source sample stored in LDS (groupshared memory) for the 2D position 'ldsPosition'.
// valueLDS and GetLdsIndex are not declared in this part of the file; they must be in scope
// before this function is compiled.
float ReadValueFromLDS(int2 ldsPosition)
{
    const float cachedValue = valueLDS[ GetLdsIndex(ldsPosition) ];
    return cachedValue;
}
// Stores the four depth samples of one gather into depthLDS as a 2x2 footprint anchored at 'ldsPosition':
//   .w -> (x,   y)      .z -> (x+1, y)
//   .x -> (x,   y+1)    .y -> (x+1, y+1)
// This mirrors the component order HLSL's Gather returns (w = top-left, z = top-right,
// y = bottom-right, x = bottom-left).
void WriteDepthGatherToLDS(int2 ldsPosition, float4 depthGather)
{
    const int2 corner = ldsPosition;

    depthLDS[ GetLdsIndex(corner) ]              = depthGather.w;
    depthLDS[ GetLdsIndex(corner + int2(1, 0)) ] = depthGather.z;
    depthLDS[ GetLdsIndex(corner + int2(1, 1)) ] = depthGather.y;
    depthLDS[ GetLdsIndex(corner + int2(0, 1)) ] = depthGather.x;
}
// Stores the four blur-source samples of one gather into valueLDS as a 2x2 footprint anchored at 'ldsPosition':
//   .w -> (x,   y)      .z -> (x+1, y)
//   .x -> (x,   y+1)    .y -> (x+1, y+1)
// This mirrors the component order HLSL's Gather returns (w = top-left, z = top-right,
// y = bottom-right, x = bottom-left).
void WriteValueGatherToLDS(int2 ldsPosition, float4 valueGather)
{
    const int2 corner = ldsPosition;

    valueLDS[ GetLdsIndex(corner) ]              = valueGather.w;
    valueLDS[ GetLdsIndex(corner + int2(1, 0)) ] = valueGather.z;
    valueLDS[ GetLdsIndex(corner + int2(1, 1)) ] = valueGather.y;
    valueLDS[ GetLdsIndex(corner + int2(0, 1)) ] = valueGather.x;
}
// Fills one 2x2 LDS footprint for this thread with depth and blur-source samples, then synchronizes the group.
//   ldsPosition - anchor of the 2x2 footprint in LDS; also scaled by the pixel size to form the gather UV
//   ldsOffsetUV - UV added on top of the scaled position
// Ends with a group-wide barrier, so every thread of the group has to reach this call.
void GatherDepthAndValueToLDS(int2 ldsPosition, float2 ldsOffsetUV)
{
    // uv = ldsPosition * pixelSize + ldsOffsetUV
    const float2 uv = mad(float2(ldsPosition), GetPixelSize(), ldsOffsetUV);

    // One gather per texture, both at the same UV
    const float4 depthQuad = PassSrg::m_linearDepth.Gather(PassSrg::PointSampler, uv);
    const float4 valueQuad = PassSrg::m_blurSource.Gather(PassSrg::PointSampler, uv);

    // Cache both quads in LDS
    WriteDepthGatherToLDS(ldsPosition, depthQuad);
    WriteValueGatherToLDS(ldsPosition, valueQuad);

    // Make the LDS writes of all threads visible before anyone starts reading
    GroupMemoryBarrierWithGroupSync();
}
// --- BLUR FUNCTIONS ---
// Returns the depth-based falloff in [0, 1] for the current sample.
//   previousDepthSlope - depth slope of the preceding step
//   currentDepthSlope  - depth slope of this step
// 1 means the two slopes agree (flat surface, blur fully); 0 means they differ strongly (edge, do not blur).
// A slope difference below the threshold goes negative and is clamped to a falloff of 1 by saturate.
float CalculateDepthFalloff(float previousDepthSlope, float currentDepthSlope)
{
    const float slopeDelta = abs(previousDepthSlope - currentDepthSlope);
    const float biasedDelta = slopeDelta - GetDepthFalloffThreshold();

    // 1 - biasedDelta * strength, expressed as a single mad
    const float rawFalloff = mad(-biasedDelta, GetDepthFalloffStrength(), 1.0f);
    return saturate(rawFalloff);
}
// One iteration of the blur sweep. Moves 'ldsPosition' by 'direction', reads the sample there and blends
// the previously accumulated value into it.
// All inout parameters are running state owned by the caller: on return the 'previous*' values hold what
// the 'current*' values held on entry, and the 'current*' values describe the new sample.
// The blend weight is the depth falloff (from the change in depth slope) multiplied by the constant falloff.
void BlurStep(inout int2 ldsPosition, int2 direction,
              inout float previousValue, inout float currentValue,
              inout float previousDepth, inout float currentDepth,
              inout float previousDepthSlope, inout float currentDepthSlope)
{
    // Shift the running state back by one sample
    previousDepthSlope = currentDepthSlope;
    previousDepth      = currentDepth;
    previousValue      = currentValue;

    // Step to the next sample and read it from LDS
    ldsPosition += direction;
    const float sampledValue = ReadValueFromLDS(ldsPosition);
    currentDepth      = ReadDepthFromLDS(ldsPosition);
    currentDepthSlope = currentDepth - previousDepth;

    // Weight of the accumulated value: 0 keeps only the new sample, 1 carries the accumulation over unchanged
    const float carryWeight = CalculateDepthFalloff(previousDepthSlope, currentDepthSlope) * GetConstFalloff();
    currentValue = lerp(sampledValue, previousValue, carryWeight);
}
// Executes one blur sweep through LDS, starting at 'ldsPosition' and stepping by 'direction'.
//   ldsPosition - first LDS sample of the sweep (passed by value; the caller's copy is not modified)
//   direction   - per-step LDS offset
//   outputs     - per-thread accumulators; each sweep ADDS half of its blurred value to every entry,
//                 so the array has to be initialized by the caller
// The sweep reads 2 + LOOP_COUNT + OUTPUT_PIXEL_PER_THREAD consecutive samples. The first
// 2 + LOOP_COUNT only build up the accumulated value; the last OUTPUT_PIXEL_PER_THREAD are written out.
// LOOP_COUNT and OUTPUT_PIXEL_PER_THREAD are not defined in this part of the file and must be in scope
// before this function is compiled.
void BlurDirection(int2 ldsPosition, int2 direction, inout float outputs[OUTPUT_PIXEL_PER_THREAD])
{
    // Seed the running state with the first two samples of the sweep
    float previousDepth = ReadDepthFromLDS(ldsPosition);
    float previousValue = ReadValueFromLDS(ldsPosition);

    ldsPosition += direction;
    float currentDepth = ReadDepthFromLDS(ldsPosition);
    float currentValue = ReadValueFromLDS(ldsPosition);

    // Depth slope represents how much the depth changes from one pixel to the next (kind of like computing a normal from depth)
    // By comparing two consecutive depth slopes we get pretty accurate edge detection
    float previousDepthSlope = currentDepth - previousDepth;
    float currentDepthSlope = previousDepthSlope;

    // We don't have a previous depth slope for the first sample so just apply the constant falloff.
    // Operand order matches BlurStep: lerp(newSample, accumulated, weight), so the carried value is
    // weighted by the constant falloff itself. The previous order, lerp(previousValue, currentValue, ...),
    // weighted it by (1 - constFalloff), which is inverted relative to every later step.
    currentValue = lerp(currentValue, previousValue, GetConstFalloff());

    // Blur the pixels before the ones we write to
    [unroll]
    for(int warmUp = 0; warmUp < LOOP_COUNT; ++warmUp)
    {
        BlurStep(ldsPosition, direction,
                 previousValue, currentValue,
                 previousDepth, currentDepth,
                 previousDepthSlope, currentDepthSlope);
    }

    // Blur the pixels we write to
    [unroll]
    for(int i = 0; i < OUTPUT_PIXEL_PER_THREAD; ++i)
    {
        BlurStep(ldsPosition, direction,
                 previousValue, currentValue,
                 previousDepth, currentDepth,
                 previousDepthSlope, currentDepthSlope);

        // Accumulate the blur value into the output. Each sweep contributes half.
        // NOTE(review): presumably the caller runs two opposite sweeps that sum to a full weight - confirm.
        outputs[i] += currentValue * 0.5f;
    }
}