You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Assets/Shaders/PostProcessing/FastDepthAwareBlurHor.azsl

311 lines
11 KiB
Plaintext

/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#include <Atom/Features/SrgSemantics.azsli>
#include <Atom/RPI/Math.azsli>
// --- Algorithm Overview ---
// For an overview of this algorithm, please see FastDepthAwareBlurCommon.azsli
// Each thread will output 3 pixels (this is a performance optimization,
// we're essentially calculating the blur for 3 pixels at the same time)
#define OUTPUT_PIXEL_PER_THREAD 3
// How many pixels wide the blur will be
#define BLUR_RADIUS 8
// Number of threads in a group
// Testing showed 64 x 2 groups had best performance on nvidia hardware (not yet tested on AMD)
#define THREADS_X 64 // Must be at least 16 or shader will produce errors
#define THREADS_Y 2 // Must be at least 2 or shader will produce errors
// How many output pixels are computed per group
#define OUTPUT_SIZE_X (THREADS_X * OUTPUT_PIXEL_PER_THREAD)
#define OUTPUT_SIZE_Y THREADS_Y
// Number of pixels to blur before the 3 output pixels (-2 because we fetch two samples before the loop)
#define LOOP_COUNT (BLUR_RADIUS - 2)
// LDS offset for the start of the first loop going in the positive direction
#define LOOP_1_START 0
// LDS offset for the start of the second loop going in the negative direction
// (2 + LOOP_COUNT + OUTPUT_PIXEL_PER_THREAD samples are walked per loop, so starting here makes
// the second loop finish on the same OUTPUT_PIXEL_PER_THREAD pixels as the first one)
#define LOOP_2_START (BLUR_RADIUS * 2 + OUTPUT_PIXEL_PER_THREAD - 1)
// Dimensions of the LDS tile in pixels. Every thread fills one 2x2 quad with a Gather, and MainCS
// lays the quads of all THREADS_X * THREADS_Y threads side by side on two rows, which gives
// THREADS_X * 4 columns with the current THREADS_Y of 2.
#define LDS_WIDTH (THREADS_X * 4)
#define LDS_HEIGHT OUTPUT_SIZE_Y
// LDS_SIZE = LDS_WIDTH * LDS_HEIGHT = (64 * 4) * 2 = 512
// This is a literal, so it has to be updated by hand whenever THREADS_X or THREADS_Y change
#define LDS_SIZE 512
// Before any blur loops are executed, we accumulate all depth and value samples in LDS
// Both arrays hold LDS_SIZE floats and are addressed with the flat index returned by GetLdsIndex
groupshared float depthLDS[LDS_SIZE];
groupshared float valueLDS[LDS_SIZE];
// Flattens a 2d LDS position into an index usable with depthLDS / valueLDS.
// Rows are stored one after another, LDS_WIDTH entries per row.
int GetLdsIndex(int2 ldsPosition)
{
    const int rowStart = ldsPosition.y * LDS_WIDTH;
    return rowStart + ldsPosition.x;
}
// --- Common file start ---
// #include <Atom/Features/PostProcessing/FastDepthAwareBlurCommon.azsli>
// This include fails with the asset processor when generating the .shader for this file
// Everything below this is copy pasted from FastDepthAwareBlurCommon.azsli up until the
// "Common file end" marker below
// --- SRG ---
// Per-pass resources and constants used by the blur
ShaderResourceGroup PassSrg : SRG_PerPass
{
// Depth input; gathered into depthLDS and used for the depth slope comparison
// (presumably linear depth, as the name says - confirm against the pass that produces it)
Texture2D<float> m_linearDepth;
// The single channel values that get blurred; gathered into valueLDS
Texture2D<float> m_blurSource;
// Destination of the blurred values, written by MainCS
RWTexture2D<float> m_blurTarget;
// Must match the struct in FastDepthAwareBlurPasses.cpp
struct FastDepthAwareBlurPassConstants
{
// The texture dimensions of blur output
uint2 m_outputSize;
// The size of a pixel relative to screenspace UV
// Calculated by taking the inverse of the texture dimensions
float2 m_pixelSize;
// The size of half a pixel relative to screenspace UV
float2 m_halfPixelSize;
// How much a value is reduced from pixel to pixel on a perfectly flat surface
float m_constFalloff;
// Threshold used to reduce computed depth difference during blur and thus the depth falloff
// Can be thought of as a bias that blurs curved surfaces more like flat surfaces
// but generally not needed and can be set to 0.0f
float m_depthFalloffThreshold;
// How much the difference in depth slopes between pixels affects the blur falloff.
// The higher this value, the sharper edges will appear
float m_depthFalloffStrength;
// Unused; rounds the struct up from 36 to 48 bytes
// (presumably to keep the size a multiple of 16 bytes like the C++ side - confirm)
uint3 PADDING;
};
FastDepthAwareBlurPassConstants m_constants;
// Unfiltered sampler with clamped addressing, used for the Gather calls
Sampler PointSampler
{
MinFilter = Point;
MagFilter = Point;
MipFilter = Point;
AddressU = Clamp;
AddressV = Clamp;
AddressW = Clamp;
};
}
// --- Accessors for the values stored in PassSrg::m_constants ---

// Dimensions of the blur output (stored as uint2, handed out as float2)
float2 GetOutputSize()
{
    return float2(PassSrg::m_constants.m_outputSize);
}

// Size of one pixel in screenspace UV
float2 GetPixelSize()
{
    return PassSrg::m_constants.m_pixelSize;
}

// Size of half a pixel in screenspace UV
float2 GetHalfPixelSize()
{
    return PassSrg::m_constants.m_halfPixelSize;
}

// Per-pixel falloff applied on a perfectly flat surface
float GetConstFalloff()
{
    return PassSrg::m_constants.m_constFalloff;
}

// Bias subtracted from the depth slope difference before it is turned into a falloff
float GetDepthFalloffThreshold()
{
    return PassSrg::m_constants.m_depthFalloffThreshold;
}

// Scale applied to the depth slope difference when it is turned into a falloff
float GetDepthFalloffStrength()
{
    return PassSrg::m_constants.m_depthFalloffStrength;
}
// --- LDS FUNCTIONS ---
// Fetches the depth sample stored in LDS at the given 2d position
float ReadDepthFromLDS(int2 ldsPosition)
{
    const int flatIndex = GetLdsIndex(ldsPosition);
    return depthLDS[flatIndex];
}
// Fetches the blur source sample stored in LDS at the given 2d position
float ReadValueFromLDS(int2 ldsPosition)
{
    const int flatIndex = GetLdsIndex(ldsPosition);
    return valueLDS[flatIndex];
}
// Stores the four depth samples of one Gather into depthLDS.
// 'ldsPosition' addresses the quad's (x, y) corner; the components are placed as
//   w -> (x, y)       z -> (x + 1, y)
//   x -> (x, y + 1)   y -> (x + 1, y + 1)
void WriteDepthGatherToLDS(int2 ldsPosition, float4 depthGather)
{
    const int2 quadOrigin = ldsPosition;

    depthLDS[ GetLdsIndex(quadOrigin) ] = depthGather.w;
    depthLDS[ GetLdsIndex(quadOrigin + int2(1, 0)) ] = depthGather.z;
    depthLDS[ GetLdsIndex(quadOrigin + int2(1, 1)) ] = depthGather.y;
    depthLDS[ GetLdsIndex(quadOrigin + int2(0, 1)) ] = depthGather.x;
}
// Stores the four blur source samples of one Gather into valueLDS.
// 'ldsPosition' addresses the quad's (x, y) corner; the components are placed as
//   w -> (x, y)       z -> (x + 1, y)
//   x -> (x, y + 1)   y -> (x + 1, y + 1)
void WriteValueGatherToLDS(int2 ldsPosition, float4 valueGather)
{
    const int2 quadOrigin = ldsPosition;

    valueLDS[ GetLdsIndex(quadOrigin) ] = valueGather.w;
    valueLDS[ GetLdsIndex(quadOrigin + int2(1, 0)) ] = valueGather.z;
    valueLDS[ GetLdsIndex(quadOrigin + int2(1, 1)) ] = valueGather.y;
    valueLDS[ GetLdsIndex(quadOrigin + int2(0, 1)) ] = valueGather.x;
}
// Gathers one 2x2 quad of depth and one 2x2 quad of blur source values and stores both in LDS.
// 'ldsPosition' is the quad's position inside the LDS tile, 'ldsOffsetUV' the UV the tile starts at.
// Ends with a group sync, so every thread of the group has to call this.
void GatherDepthAndValueToLDS(int2 ldsPosition, float2 ldsOffsetUV)
{
    // LDS position scaled to UV space and shifted by the start of the LDS tile
    const float2 quadUV = mad(float2(ldsPosition), GetPixelSize(), ldsOffsetUV);

    // Fetch both quads first...
    const float4 depthQuad = PassSrg::m_linearDepth.Gather(PassSrg::PointSampler, quadUV);
    const float4 valueQuad = PassSrg::m_blurSource.Gather(PassSrg::PointSampler, quadUV);

    // ...then store them
    WriteDepthGatherToLDS(ldsPosition, depthQuad);
    WriteValueGatherToLDS(ldsPosition, valueQuad);

    // All LDS writes of the group must have landed before anyone reads
    GroupMemoryBarrierWithGroupSync();
}
// --- BLUR FUNCTIONS ---
// Turns the change between two consecutive depth slopes into a falloff in the range 0-1.
// 1 is returned while the slopes differ by no more than the threshold; the larger the
// difference beyond it, the closer the result gets to 0.
float CalculateDepthFalloff(float previousDepthSlope, float currentDepthSlope)
{
    const float slopeDelta = abs(previousDepthSlope - currentDepthSlope);
    const float biasedDelta = slopeDelta - GetDepthFalloffThreshold();

    // 1.0f - biasedDelta * depthFalloffStrength
    const float unclampedFalloff = mad(-biasedDelta, GetDepthFalloffStrength(), 1.0f);
    return saturate(unclampedFalloff);
}
// One iteration of the blur loop: advances 'ldsPosition' by 'direction', reads the sample there
// and blends the value carried over from the previous sample into it. How much is carried over
// depends on the constant falloff and on how much the depth slope changed.
// All 'previous*' / 'current*' parameters are running state that is updated in place.
void BlurStep(inout int2 ldsPosition, int2 direction,
              inout float previousValue, inout float currentValue,
              inout float previousDepth, inout float currentDepth,
              inout float previousDepthSlope, inout float currentDepthSlope)
{
    // What was current on the last step is now previous
    previousDepthSlope = currentDepthSlope;
    previousDepth = currentDepth;
    previousValue = currentValue;

    // Read the next sample
    ldsPosition += direction;
    const float sampledValue = ReadValueFromLDS(ldsPosition);
    currentDepth = ReadDepthFromLDS(ldsPosition);
    currentDepthSlope = currentDepth - previousDepth;

    // Blend the carried-over value into the new sample
    const float depthFalloff = CalculateDepthFalloff(previousDepthSlope, currentDepthSlope);
    const float falloff = depthFalloff * GetConstFalloff();
    currentValue = lerp(sampledValue, previousValue, falloff);
}
// Executes one blur loop in param 'direction' starting at param 'ldsPosition'.
// The loop walks 2 + LOOP_COUNT + OUTPUT_PIXEL_PER_THREAD samples in LDS. The blurred values of
// the last OUTPUT_PIXEL_PER_THREAD samples are added to 'outputs' in the order they are visited,
// each weighted by 0.5 so that two opposite directions sum up to the final result.
void BlurDirection(int2 ldsPosition, int2 direction, inout float outputs[OUTPUT_PIXEL_PER_THREAD])
{
    float previousDepth = ReadDepthFromLDS(ldsPosition);
    float previousValue = ReadValueFromLDS(ldsPosition);

    ldsPosition += direction;
    float currentDepth = ReadDepthFromLDS(ldsPosition);
    float currentValue = ReadValueFromLDS(ldsPosition);

    // Depth slope represents how much the depth changes from one pixel to the next (kind of like computing a normal from depth)
    // By comparing two consecutive depth slopes we get pretty accurate edge detection
    float previousDepthSlope = currentDepth - previousDepth;
    float currentDepthSlope = previousDepthSlope;

    // We don't have a previous depth slope for the first sample so just apply the constant falloff.
    // Same operand order as in BlurStep: the falloff is the share of the previous value that is
    // carried over into the current one.
    currentValue = lerp(currentValue, previousValue, GetConstFalloff());

    // Blur the pixels before the ones we write to
    [unroll]
    for (int leadIn = 0; leadIn < LOOP_COUNT; ++leadIn)
    {
        BlurStep(ldsPosition, direction,
                 previousValue, currentValue,
                 previousDepth, currentDepth,
                 previousDepthSlope, currentDepthSlope);
    }

    // Blur the pixels we write to
    [unroll]
    for (int outputIndex = 0; outputIndex < OUTPUT_PIXEL_PER_THREAD; ++outputIndex)
    {
        BlurStep(ldsPosition, direction,
                 previousValue, currentValue,
                 previousDepth, currentDepth,
                 previousDepthSlope, currentDepthSlope);

        // Accumulate the blur value into the output
        outputs[outputIndex] += currentValue * 0.5f;
    }
}
// --- Common file end ---
// Horizontal pass of the blur. A group first gathers its tile of depth / source values into LDS,
// then every thread blurs towards its OUTPUT_PIXEL_PER_THREAD pixels from both sides and writes them.
[numthreads(THREADS_X, THREADS_Y, 1)]
void MainCS(uint3 thread_id : SV_GroupThreadID, uint3 group_id : SV_GroupID, uint3 dispatch_id: SV_DispatchThreadID, uint linear_id : SV_GroupIndex)
{
    // The LDS tile starts BLUR_RADIUS pixels before the first pixel this group outputs
    const float2 groupOutputSize = float2(OUTPUT_SIZE_X, OUTPUT_SIZE_Y);
    const float2 ldsOffsetPixel = mad(float2(group_id.xy), groupOutputSize, float2(-BLUR_RADIUS, 0));

    // Same offset expressed as screen space UV
    const float2 ldsOffsetUV = mad(ldsOffsetPixel, GetPixelSize(), GetHalfPixelSize());

    // --- Fill LDS with depth and blur source ---
    {
        // Each thread owns one 2x2 quad. Threads of odd rows continue where the even row ends,
        // so the quads of both thread rows sit side by side on LDS rows 0 and 1.
        int2 quadPosition;
        quadPosition.x = mad(THREADS_X, thread_id.y & 1, thread_id.x);
        quadPosition.x <<= 1;
        quadPosition.y = thread_id.y & 0xFFFFFFFE;

        GatherDepthAndValueToLDS(quadPosition, ldsOffsetUV);
    }

    // Blur results for this thread's three pixels, accumulated over both directions
    float outputs[OUTPUT_PIXEL_PER_THREAD] = {0.0f, 0.0f, 0.0f};

    // --- Left to right ---
    {
        const int2 startPosition = int2(mad(thread_id.x, OUTPUT_PIXEL_PER_THREAD, LOOP_1_START), thread_id.y);
        BlurDirection(startPosition, int2(1, 0), outputs);
    }

    // The second loop visits the three pixels in reverse, so flip the outputs for it
    swap(outputs[0], outputs[2]);

    // --- Right to left ---
    {
        const int2 startPosition = int2(mad(thread_id.x, OUTPUT_PIXEL_PER_THREAD, LOOP_2_START), thread_id.y);
        BlurDirection(startPosition, int2(-1, 0), outputs);
    }

    // Flip back so the outputs are in left to right order again
    swap(outputs[0], outputs[2]);

    // --- Write the three pixels ---
    const uint firstOutputX = dispatch_id.x * OUTPUT_PIXEL_PER_THREAD;

    [unroll]
    for (uint pixel = 0; pixel < OUTPUT_PIXEL_PER_THREAD; ++pixel)
    {
        const uint2 outputPixel = uint2(firstOutputX + pixel, dispatch_id.y);
        PassSrg::m_blurTarget[outputPixel] = outputs[pixel];
    }
}