You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
o3de/Gems/Atom/Feature/Common/Assets/Shaders/PostProcessing/Taa.azsl

272 lines
10 KiB
Plaintext

/*
* All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
* its licensors.
*
* For complete copyright and license terms please see the LICENSE at the root of this
* distribution (the "License"). All use of this software is governed by the License,
* or, if provided, by the license below or the license accompanying this file. Do not
* remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/
#include <scenesrg.srgi>
#include <Atom/Features/PostProcessing/PostProcessUtil.azsli>
// Thread-group dimensions for MainCS (used in its [numthreads] attribute): one 16x16 tile of threads per group.
#define TILE_DIM_X 16
#define TILE_DIM_Y 16
// Per-pass resources and tuning constants consumed by the TAA resolve in MainCS.
ShaderResourceGroup PassSrg : SRG_PerPass
{
// Current frame color. MainCS reads the 3x3 neighborhood around each pixel from it (rgb only).
Texture2D<float4> m_inputColor;
// Current frame depth. Only the .r channel is read, to decide which neighbor's motion vector to use.
Texture2D<float4> m_inputDepth;
// Per-pixel motion. MainCS negates it and adds it directly to a uv coordinate, i.e. it is consumed as a uv-space offset.
Texture2D<float2> m_motionVectors;
// Last frame's accumulation (history) buffer, sampled with LinearSampler at the reprojected uv.
Texture2D<float4> m_lastFrameAccumulation;
// Resolve target. MainCS writes only the rgb channels.
RWTexture2D<float4> m_outputColor;
// Linear-filtered, clamped sampler passed to SampleCatmullRom5Tap for the history lookup.
Sampler LinearSampler
{
MinFilter = Linear;
MagFilter = Linear;
MipFilter = Linear;
AddressU = Clamp;
AddressV = Clamp;
AddressW = Clamp;
};
// Current frame's default contribution to the history.
// MainCS scales it down for flicker dampening and overrides it with 1.0 when the reprojected uv is outside [0, 1].
float m_currentFrameContribution;
// Increase this value for weaker clamping, decrease for stronger clamping, default 1.0.
// Applied as a multiplier on the neighborhood standard deviation.
float m_clampGamma;
// Default 0.5, used for flicker reduction. Any sample further than this many standard deviations outside the neighborhood
// will have its weight decreased. The further outside the max deviation, the more its weight is reduced.
float m_maxDeviationBeforeDampening;
struct Constants
{
// Texture size in texels; used to turn the history uv into a texel position.
// NOTE(review): by name this is the size of m_inputColor - confirm against the code that fills it.
uint2 m_inputColorSize;
// Reciprocal of the size above; used to turn texel positions into uv coordinates.
float2 m_inputColorRcpSize;
// 3x3 filter weights
// 8 2 6
// 3 0 1
// 7 4 5
// NOTE(review): MainCS pairs weight i with offsets[i], whose entries 5..8 are (1,-1), (1,1), (-1,-1), (-1,1).
// Reading this grid with y pointing down matches the cross taps (1..4) but not the diagonals; reading it with
// y pointing up matches the diagonals but not taps 2 and 4. Confirm the intended layout against the code that
// computes the weights.
float4 m_weights1; // 0 1 2 3
float4 m_weights2; // 4 5 6 7
float4 m_weights3; // 8 x x x
};
Constants m_constantData;
}
// Texel offsets of the 3x3 neighborhood visited by MainCS. Entry i is paired with filter weight i
// (m_weights1.xyzw, m_weights2.xyzw, m_weights3.x), so the order of this table must not change
// independently of the code that produces those weights.
// NOTE(review): the index grid in the PassSrg::Constants comment does not agree with the diagonal
// entries below under a y-down reading - confirm which one the weight generation follows.
static const int2 offsets[9] =
{
// Center
int2(0, 0),
// Cross
int2( 1, 0),
int2( 0,-1),
int2(-1, 0),
int2( 0, 1),
// Diagonals
int2( 1,-1),
int2( 1, 1),
int2(-1,-1),
int2(-1, 1),
};
// Converts an RGB color to YCoCg.
//   Y  =  0.25 R + 0.5 G + 0.25 B
//   Co =  0.5  R         - 0.5  B
//   Cg = -0.25 R + 0.5 G - 0.25 B
// The result is returned as float3(Y, Co, Cg). YCoCgToRgb is the inverse transform.
float3 RgbToYCoCg(float3 rgb)
{
    // One row per output channel, in Y, Co, Cg order.
    const float3x3 rgbToYCoCg = float3x3(
        float3( 0.25, 0.50,  0.25),
        float3( 0.50, 0.00, -0.50),
        float3(-0.25, 0.50, -0.25));

    float3 yCoCg = mul(rgbToYCoCg, rgb);
    return yCoCg;
}
// Converts a YCoCg color (as produced by RgbToYCoCg) back to RGB.
//   R = Y + Co - Cg
//   G = Y      + Cg
//   B = Y - Co - Cg
float3 YCoCgToRgb(float3 yCoCg)
{
    // One row per output channel, in R, G, B order.
    const float3x3 yCoCgToRgb = float3x3(
        float3(1.0,  1.0, -1.0),
        float3(1.0,  0.0,  1.0),
        float3(1.0, -1.0, -1.0));

    float3 rgb = mul(yCoCgToRgb, yCoCg);
    return rgb;
}
// Approximate bicubic (Catmull-Rom style) lookup using five bilinear taps.
//
// A full Catmull-Rom filter reads a 4x4 texel footprint around uv. Because the two middle rows/columns can be
// fetched together with a single bilinear sample placed between them, the footprint collapses to 3x3 taps; the
// four corner taps are then dropped because they contribute very little, leaving a plus-shaped pattern of five.
// The result is renormalized by the sum of the five weights that were actually used.
//
//   texture         Texture to sample (mip 0 only).
//   linearSampler   Sampler used for every tap; the derivation assumes bilinear filtering.
//   uv              Lookup position in normalized texture coordinates.
//   textureSize     Texture size in texels.
//   rcpTextureSize  1 / textureSize.
//   sharpness       Cubic sharpness parameter; the caller in this file passes 0.5.
//
// Generic enough that it could be moved to a shared utility header.
float4 SampleCatmullRom5Tap(Texture2D<float4> texture, SamplerState linearSampler, float2 uv, float2 textureSize, float2 rcpTextureSize, float sharpness)
{
    // Number the 4x4 footprint from (0,0) top-left to (3,3) bottom-right. centerTexel is the center of
    // texel (1,1), and t is how far the sample position sits past that center.
    const float2 samplePos = uv * textureSize;
    const float2 centerTexel = floor(samplePos - 0.5) + 0.5;
    const float2 t = samplePos - centerTexel;

    // Per-axis cubic weights for footprint rows/columns 0..3.
    const float s = sharpness;
    const float2 weightOuterLo = t * (-s + t * (2.0 * s - s * t));
    const float2 weightInnerLo = 1.0 + t * t * (s -3.0 + (2.0 - s) * t);
    const float2 weightInnerHi = t * (s + t * ((3.0 - 2.0 * s) - (2.0 - s) * t));
    const float2 weightOuterHi = t * t * (s * t - s);

    // Rows/columns 1 and 2 are merged into one bilinear fetch with their combined weight.
    const float2 weightInner = weightInnerLo + weightInnerHi;

    // uv coordinates of row/column 0, the merged inner position, and row/column 3.
    const float2 uvOuterLo = (centerTexel - 1.0f) * rcpTextureSize;
    const float2 uvOuterHi = (centerTexel + 2.0f) * rcpTextureSize;
    const float2 uvInner = (centerTexel + weightInnerHi / weightInner) * rcpTextureSize;

    // Plus-shaped tap pattern: the four edge-midpoints of the 3x3 grid and its center.
    const float2 tapUv[5] =
    {
        float2(uvInner.x, uvOuterLo.y),
        float2(uvOuterLo.x, uvInner.y),
        float2(uvInner.x, uvInner.y),
        float2(uvOuterHi.x, uvInner.y),
        float2(uvInner.x, uvOuterHi.y),
    };
    const float tapWeight[5] =
    {
        weightInner.x * weightOuterLo.y,
        weightOuterLo.x * weightInner.y,
        weightInner.x * weightInner.y,
        weightOuterHi.x * weightInner.y,
        weightInner.x * weightOuterHi.y,
    };

    // Accumulate the weighted taps and the weight total used for normalization.
    float weightSum = 0.0f;
    float4 accumulated = 0.0f;
    [unroll] for (int tap = 0; tap < 5; ++tap)
    {
        weightSum += tapWeight[tap];
        accumulated += texture.SampleLevel(linearSampler, tapUv[tap], 0.0) * tapWeight[tap];
    }

    return accumulated / weightSum;
}
// Temporal anti-aliasing resolve. Each thread produces one pixel of PassSrg::m_outputColor:
//   1. Filters the current frame's 3x3 neighborhood with the supplied weights and gathers its mean/variance in YCoCg.
//   2. Picks the motion vector of the neighbor selected by the depth test and reprojects into the history buffer.
//   3. Clips the history sample towards the neighborhood statistics (variance clipping).
//   4. Blends the clipped history with the filtered current color and writes the result.
//
//   dispatchThreadID  .xy is used as the pixel coordinate.
//   groupID           Unused.
//   groupIndex        Unused.
[numthreads(TILE_DIM_X, TILE_DIM_Y, 1)]
void MainCS(
    uint3 dispatchThreadID : SV_DispatchThreadID,
    uint3 groupID : SV_GroupID,
    uint groupIndex : SV_GroupIndex)
{
    const uint2 pixelCoord = dispatchThreadID.xy;
    const uint2 inputColorSize = PassSrg::m_constantData.m_inputColorSize;
    const float2 rcpSize = PassSrg::m_constantData.m_inputColorRcpSize;

    // Tiles that only partially cover the image contain threads with no pixel to resolve; skip them so
    // nothing is written outside the output texture.
    if (any(pixelCoord >= inputColorSize))
    {
        return;
    }

    // Last valid texel coordinate, used to keep neighborhood reads inside the texture.
    const int2 maxPixelCoord = int2(inputColorSize) - 1;

    const float filterWeights[9] =
    {
        PassSrg::m_constantData.m_weights1.x,
        PassSrg::m_constantData.m_weights1.y,
        PassSrg::m_constantData.m_weights1.z,
        PassSrg::m_constantData.m_weights1.w,
        PassSrg::m_constantData.m_weights2.x,
        PassSrg::m_constantData.m_weights2.y,
        PassSrg::m_constantData.m_weights2.z,
        PassSrg::m_constantData.m_weights2.w,
        PassSrg::m_constantData.m_weights3.x,
    };

    float3 sum = 0.0;
    float3 sumOfSquares = 0.0;
    float nearestDepth = 1.0;
    // Start from the center pixel so the motion vector lookup below is well defined even when no
    // neighbor passes the depth test (every depth >= 1.0).
    uint2 nearestDepthPixelCoord = pixelCoord;
    float3 thisFrameColor = float3(0.0, 0.0, 0.0);

    // Sample the neighborhood to filter the current pixel, gather statistics about
    // its neighbors, and find the closest neighbor to choose a motion vector.
    [unroll] for (int i = 0; i < 9; ++i)
    {
        // Apply the offset in signed integers and clamp to the texture. Adding a negative offset to an
        // unsigned coordinate wraps around at x == 0 / y == 0, and +1 steps past the last row/column,
        // which made border pixels pull in out-of-bounds color and depth values.
        const int2 clampedCoord = clamp(int2(pixelCoord) + offsets[i], int2(0, 0), maxPixelCoord);
        const uint2 neighborhoodPixelCoord = uint2(clampedCoord);

        float3 neighborhoodColor = PassSrg::m_inputColor[neighborhoodPixelCoord].rgb;

        // Convert to YCoCg space for better clipping.
        neighborhoodColor = RgbToYCoCg(neighborhoodColor);

        sum += neighborhoodColor;
        sumOfSquares += neighborhoodColor * neighborhoodColor;
        thisFrameColor += neighborhoodColor * filterWeights[i];

        // Find the coordinate of the nearest depth.
        // NOTE(review): this treats the smallest depth value as nearest - confirm that matches the
        // depth convention of the buffer bound to m_inputDepth.
        float neighborhoodDepth = PassSrg::m_inputDepth[neighborhoodPixelCoord].r;
        if (neighborhoodDepth < nearestDepth)
        {
            nearestDepth = neighborhoodDepth;
            nearestDepthPixelCoord = neighborhoodPixelCoord;
        }
    }

    // Variance clipping, see http://developer.download.nvidia.com/gameworks/events/GDC2016/msalvi_temporal_supersampling.pdf
    float3 mean = sum / 9.0;
    // Rounding can push E[x^2] - E[x]^2 slightly below zero; clamp before the sqrt so no NaN is produced.
    float3 standardDeviation = sqrt(max(0.0, sumOfSquares / 9.0 - mean * mean));
    standardDeviation *= PassSrg::m_clampGamma;

    // Grab the motion vector from the closest pixel in the 3x3 neighborhood. This is done so that motion vectors correctly
    // track edges. For instance, if a pixel lies on the edge of a moving object, where the color is a blend of the
    // foreground and background, it's possible for the pixel center to hit the (not moving) background. However, the correct
    // history for this pixel will be the location this edge was the previous frame. By choosing the motion of the nearest
    // pixel in the neighborhood that edge will be correctly tracked.
    // Motion vectors store the direction of movement, so to look up where things were in the previous frame, it's negated.
    float2 previousPositionOffset = -PassSrg::m_motionVectors[nearestDepthPixelCoord];

    // Get the uv coordinate for the previous frame.
    float2 uvCoord = (pixelCoord + 0.5f) * rcpSize;
    float2 uvOld = uvCoord + previousPositionOffset;

    // Sample the last frame using a 5-tap Catmull-Rom
    float3 lastFrameColor = SampleCatmullRom5Tap(PassSrg::m_lastFrameAccumulation, PassSrg::LinearSampler, uvOld, inputColorSize, rcpSize, 0.5).rgb;
    lastFrameColor = RgbToYCoCg(lastFrameColor);

    // Last frame color relative to mean
    float3 centerColorOffset = lastFrameColor - mean;
    // A zero offset component divides by zero here (inf, or NaN when the deviation is also zero).
    // NOTE(review): the min() chain below is relied on to discard such components - confirm on all target platforms.
    float3 colorOffsetStandardDeviationRatio = abs(standardDeviation / centerColorOffset);

    // Clamp the color by the aabb of the standardDeviation. Can never be greater than 1, so will always be inside or on the bounds of the aabb.
    float clampedColorLength = min(min(min(1, colorOffsetStandardDeviationRatio.x), colorOffsetStandardDeviationRatio.y), colorOffsetStandardDeviationRatio.z);

    // Calculate the true clamped color by offsetting it back from the mean.
    float3 lastFrameClampedColor = mean + centerColorOffset * clampedColorLength;

    // Anti-flickering - Reduce current frame weight the more it deviates from the history based on the standard deviation of the neighborhood.
    // Start reducing weight at differences greater than m_maxDeviationBeforeDampening standard deviations in luminance.
    // (.r is the Y channel while the colors are in YCoCg.)
    float standardDeviationWeight = standardDeviation.r * PassSrg::m_maxDeviationBeforeDampening;
    float sdFromLastFrame = standardDeviationWeight / abs(lastFrameClampedColor.r - thisFrameColor.r);
    float currentFrameWeight = PassSrg::m_currentFrameContribution;
    currentFrameWeight *= saturate(sdFromLastFrame * sdFromLastFrame);

    // Back to Rgb space
    thisFrameColor = YCoCgToRgb(thisFrameColor);
    lastFrameClampedColor = YCoCgToRgb(lastFrameClampedColor);

    // Out of bounds protection: there is no valid history off-screen, so use the current frame only.
    if (any(uvOld > 1.0) || any(uvOld < 0.0))
    {
        currentFrameWeight = 1.0f;
    }

    // Blend should be in perceptual space, so tonemap first
    // NOTE(review): both colors are divided by the same (1 + luminance) and the blended result is multiplied
    // by it again, so algebraically the scale cancels and this equals a plain lerp. Confirm whether each
    // color was meant to be tonemapped with its own luminance.
    float luminance = GetLuminance(thisFrameColor);
    thisFrameColor = thisFrameColor / (1 + luminance);
    lastFrameClampedColor = lastFrameClampedColor / (1 + luminance);

    // Blend color with history
    float3 color = lerp(lastFrameClampedColor, thisFrameColor, currentFrameWeight);

    // Un-tonemap color
    color = color * (1.0 + luminance);

    // NaN protection (without this NaNs could get in the history buffer and quickly consume the frame)
    color = max(0.0, color);

    PassSrg::m_outputColor[pixelCoord].rgb = color;
}