/* * Copyright (c) Contributors to the Open 3D Engine Project. * For complete copyright and license terms please see the LICENSE at the root of this distribution. * * SPDX-License-Identifier: Apache-2.0 OR MIT * */ #include // --- Algorithm Overview --- // // This shader will upsample an input image using two input depth textures. // For simplicity, we call the source image 'sourceHalfRes', the low resolution depth 'depthHalfRes', // and the higher resolution depth 'depthFullRes' (which has the same resolution as the output image). // In order to reduce texture operations, each thread writes 2x2 pixels in the target output image. // This allows each thread to re-use results from texture gather operations between output pixels. // // To illustrate, consider the following texture (each number denotes a pixel) // // 00 01 02 03 04 05 06 07 08 09 // // 10 11 12 13 14 15 16 17 18 19 // // 20 21 22 23 24 25 26 27 28 29 // // 30 31 32 33 34 35 36 37 38 39 // // The downsampled version of this texture would have pixels at the following H* locations: // // 00 01 02 03 04 05 06 07 08 09 // H0 H1 H2 H3 H4 // 10 11 12 13 14 15 16 17 18 19 // // 20 21 22 23 24 25 26 27 28 29 // H5 H6 H7 H8 H9 // 30 31 32 33 34 35 36 37 38 39 // // To calculate the upsampled output pixel (11), we need four half-res depth values (H0, H1, H5, H6), // four half-res source values (H0, H1, H5, H6) and a full-res depth value (11). 
// Note that these same half-res depth and source values are also used to calculate output pixels (12, 21, 22) // Also note that pixels (H0, H1, H5, H6) can be fetched with a single gather, as can (11, 12, 21, 22) // // Thus, we can use a single thread to calculate and output upsampled pixels (11, 12, 21, 22) // For this, the thread would only need to perform three gathers (assuming source is a single channel texture) // Gather 1: half-res depth (H0, H1, H5, H6) // Gather 2: half-res source (H0, H1, H5, H6) // Gather 3: full-res depth (11, 12, 21, 22) // // Thus, we dispatch threads at the following T* locations // // T-00 T-01 T-02 T-03 T-04 T-05 // 00 01 02 03 04 05 06 07 08 09 // H0 H1 H2 H3 H4 // 10 11 12 13 14 15 16 17 18 19 // T-10 T-11 T-12 T-13 T-14 T-15 // 20 21 22 23 24 25 26 27 28 29 // H5 H6 H7 H8 H9 // 30 31 32 33 34 35 36 37 38 39 // T-20 T-21 T-22 T-23 T-24 T-25 // // Continuing our example, here the thread T-11 would calculate the full res output pixels (11, 12, 21, 22) // Two things to note about the thread dispatch: // // 1) the width and height of the thread dispatch are equal to the width and height of the half-res textures + 1. // The +1 is to have enough threads to output to row 3* and column *9. // Note that if the full-res texture has uneven dimensions this +1 is not necessary: // // T-00 T-01 T-02 T-03 T-04 // 00 01 02 03 04 05 06 07 08 // H0 H1 H2 H3 H4 // 10 11 12 13 14 15 16 17 18 // T-10 T-11 T-12 T-13 T-14 // 20 21 22 23 24 25 26 27 28 // H5 H6 H7 H8 H9 // // 2) While the thread dispatch has similar dimensions to the half-res textures, the positions are shifted by (-1, -1), // i.e. they are shifted up and to the left by the width of a full-res pixel. 
//    This is so the threads can properly use texture gather instructions on both the
//    downsampled depth/source and the full-res depth (see above example for T-11)
//

#define THREADS 16

ShaderResourceGroup PassSrg : SRG_PerPass
{
    Texture2D m_depthFullRes;
    Texture2D m_depthHalfRes;
    Texture2D m_sourceHalfRes;
    RWTexture2D m_outputFullRes;

    // Must match the struct in DepthDownsamplePasses.cpp
    struct UpsampleConstants
    {
        // The size of a pixel in the input image relative to screenspace UV
        // Calculated by taking the inverse of the texture dimensions
        float2 m_inputPixelSize;

        // The size of a pixel in the output image relative to screenspace UV
        // Calculated by taking the inverse of the texture dimensions
        float2 m_outputPixelSize;
    };
    UpsampleConstants m_constants;

    Sampler PointSampler
    {
        MinFilter = Point;
        MagFilter = Point;
        MipFilter = Point;
        AddressU = Clamp;
        AddressV = Clamp;
        AddressW = Clamp;
    };
}

// Returns a weight that grows as the two depth values get closer: 1 / (|depth1 - depth2| + epsilon)^2.
// The epsilon keeps the result finite when both depths are identical.
float GetDepthFactor(float depth1, float depth2)
{
    const float epsilon = 0.00001f;
    float distance = abs(depth1 - depth2) + epsilon;
    float distanceSq = distance * distance;
    return 1.0f / distanceSq;
}

// Computes one full-res output value as the normalized, depth-weighted blend of four half-res taps.
//
// fullDepth   - depth of the full-res pixel being produced
// tapDepths   - half-res depths, ordered: nearest tap (x), horizontal neighbor (y),
//               vertical neighbor (z), diagonal tap (w)
// tapValues   - half-res source values in the same order as tapDepths
//
// 0.75 and 0.25 here is how far this full-res pixel is from the half-res pixels we are sampling.
// Consider two half-res pixels and two full-res pixels in between. The full-res pixel on the right
// is 3x closer to the half-res pixel on the right than the half-res pixel on the left, thus the
// weights become 3/4 and 1/4 or 0.75 and 0.25
float UpsampleOutputPixel(float fullDepth, float4 tapDepths, float4 tapValues)
{
    float weightedSum = 0.0f;
    float totalWeight = 0.0f;
    float weight = 0.0f;

    // Nearest half-res pixel
    weight = (0.75f * 0.75f) * GetDepthFactor(fullDepth, tapDepths.x);
    weightedSum += tapValues.x * weight;
    totalWeight += weight;

    // Horizontal neighbor
    weight = (0.25f * 0.75f) * GetDepthFactor(fullDepth, tapDepths.y);
    weightedSum += tapValues.y * weight;
    totalWeight += weight;

    // Vertical neighbor
    weight = (0.75f * 0.25f) * GetDepthFactor(fullDepth, tapDepths.z);
    weightedSum += tapValues.z * weight;
    totalWeight += weight;

    // Diagonal half-res pixel
    weight = (0.25f * 0.25f) * GetDepthFactor(fullDepth, tapDepths.w);
    weightedSum += tapValues.w * weight;
    totalWeight += weight;

    // totalWeight is always positive because GetDepthFactor never returns zero
    return weightedSum / totalWeight;
}

// Stores 'value' at 'pixel' only when the pixel lies inside the output image.
// Threads on the first row/column compute a coordinate of -1, which wraps to a very large
// unsigned value and is therefore rejected by the same comparison as coordinates past the far edge.
void WriteOutputPixel(uint2 pixel, uint2 outputSize, float value)
{
    if (all(pixel < outputSize))
    {
        PassSrg::m_outputFullRes[pixel] = value;
    }
}

// Each thread gathers 2x2 half-res depth/source values and 2x2 full-res depths,
// then writes the corresponding 2x2 block of upsampled output pixels.
[numthreads(THREADS, THREADS, 1)]
void MainCS(uint3 dispatch_id: SV_DispatchThreadID)
{
    float2 position = dispatch_id.xy;

    // Gather half res depth and source values
    float2 halfResGatherUV = position * PassSrg::m_constants.m_inputPixelSize;
    float4 halfDepths = PassSrg::m_depthHalfRes.Gather(PassSrg::PointSampler, halfResGatherUV);
    float4 sourceValues = PassSrg::m_sourceHalfRes.Gather(PassSrg::PointSampler, halfResGatherUV);

    // Gather full res depth
    float2 fullResGatherUV = position * 2.0f * PassSrg::m_constants.m_outputPixelSize;
    float4 fullDepths = PassSrg::m_depthFullRes.Gather(PassSrg::PointSampler, fullResGatherUV);

    // Gather operation retrieves values with the following layout:
    //
    //      W Z
    //      X Y
    //
    // For every output pixel the swizzle lists the half-res taps in the order
    // (nearest, horizontal neighbor, vertical neighbor, diagonal).
    float4 outputValues = (float4)0.0f;
    outputValues.w = UpsampleOutputPixel(fullDepths.w, halfDepths.wzxy, sourceValues.wzxy);
    outputValues.z = UpsampleOutputPixel(fullDepths.z, halfDepths.zwyx, sourceValues.zwyx);
    outputValues.y = UpsampleOutputPixel(fullDepths.y, halfDepths.yxzw, sourceValues.yxzw);
    outputValues.x = UpsampleOutputPixel(fullDepths.x, halfDepths.xywz, sourceValues.xywz);

    // Threads along the borders of the dispatch cover pixels outside the output image,
    // so every store is checked against the actual output dimensions.
    uint2 outputSize;
    PassSrg::m_outputFullRes.GetDimensions(outputSize.x, outputSize.y);

    // To understand the -1, read the last paragraph of the Algorithm Overview section at the start
    uint2 outputPixel = mad(dispatch_id.xy, 2, -1);

    // Top-left
    WriteOutputPixel(outputPixel, outputSize, outputValues.w);

    // Top-right
    ++outputPixel.x;
    WriteOutputPixel(outputPixel, outputSize, outputValues.z);

    // Bottom-right
    ++outputPixel.y;
    WriteOutputPixel(outputPixel, outputSize, outputValues.y);

    // Bottom-left
    --outputPixel.x;
    WriteOutputPixel(outputPixel, outputSize, outputValues.x);
}