
ffx_variable_shading.h compute shader #4

@pemgithub

Can you explain how the compute shader and related threadgroups work?

From the C++ code:

FFX_VariableShading_GetDispatchInfo(data, AdditionalShadingRates(), w, h);
pCmdLst->Dispatch(w, h, 1);

// excerpt from inside FFX_VariableShading_GetDispatchInfo:
// coarse tiles are potentially 2x2, so each thread computes 2x2 pixels
if (cb->tileSize == 8)
{
    //each threadgroup computes 4 VRS tiles
    numThreadGroupsX = FFX_VariableShading_DivideRoundingUp(vrsImageWidth, 2);
    numThreadGroupsY = FFX_VariableShading_DivideRoundingUp(vrsImageHeight, 2);
}
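A minimal C++ sketch of the dispatch arithmetic in the excerpt above, covering only the tileSize == 8 branch. It assumes vrsImageWidth/vrsImageHeight are the render size divided by the tile size and rounded up (the excerpt does not show how they are computed). DivideRoundingUp and ThreadGroupsFor8x8Tiles are hypothetical stand-ins, not the FidelityFX API.

```cpp
#include <cstdint>

// Hypothetical equivalent of FFX_VariableShading_DivideRoundingUp:
// integer division that rounds up instead of truncating.
constexpr std::uint32_t DivideRoundingUp(std::uint32_t a, std::uint32_t b) {
    return (a + b - 1) / b;
}

struct DispatchSize {
    std::uint32_t x;
    std::uint32_t y;
};

// Assumption: 8x8 VRS tiles and coarse rates up to 2x2, so one 8x8-thread
// group covers 2x2 VRS tiles and the group count is the VRS image size / 2,
// rounded up.
constexpr DispatchSize ThreadGroupsFor8x8Tiles(std::uint32_t width, std::uint32_t height) {
    const std::uint32_t tileSize = 8;
    const std::uint32_t vrsImageWidth = DivideRoundingUp(width, tileSize);
    const std::uint32_t vrsImageHeight = DivideRoundingUp(height, tileSize);
    return {DivideRoundingUp(vrsImageWidth, 2), DivideRoundingUp(vrsImageHeight, 2)};
}

// 1920x1080 -> VRS image 240x135 -> 120x68 thread groups.
static_assert(ThreadGroupsFor8x8Tiles(1920, 1080).x == 120, "240 / 2");
static_assert(ThreadGroupsFor8x8Tiles(1920, 1080).y == 68, "135 / 2, rounded up");
```

Rounding up matters whenever the VRS image has an odd number of tiles in a dimension: the last group still has to be dispatched even though it only covers one tile.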

From VRSImageGenCS.hlsl (where static const uint FFX_VariableShading_ThreadCount1D = 8;):

[numthreads(FFX_VariableShading_ThreadCount1D, FFX_VariableShading_ThreadCount1D, 1)]
void mainCS(
    uint3 Gid  : SV_GroupID,
    uint3 Gtid : SV_GroupThreadID,
    uint  Gidx : SV_GroupIndex)
{
    FFX_VariableShading_GenerateVrsImage(Gid, Gtid, Gidx);
}
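How the system values passed to mainCS relate, as a C++ sketch for [numthreads(8, 8, 1)]. The SV_GroupIndex flattening follows the D3D definition; ThreadPixelOrigin is a hypothetical mapping that assumes each thread owns one 2x2 pixel block and that a group's 16x16-pixel region starts at SV_GroupID * 16. The real shader offsets its reads by baseOffset, which the excerpt does not show.

```cpp
#include <cstdint>

// Mirrors FFX_VariableShading_ThreadCount1D.
constexpr std::uint32_t kThreadCount1D = 8;

struct Uint2 {
    std::uint32_t x, y;
};

// SV_GroupIndex is the flattened SV_GroupThreadID: z*8*8 + y*8 + x (z is 0 here).
constexpr std::uint32_t GroupIndex(Uint2 gtid) {
    return gtid.y * kThreadCount1D + gtid.x;
}

// Hypothetical mapping: each thread owns a 2x2 pixel block, so a group covers
// (8*2)x(8*2) = 16x16 pixels, i.e. 2x2 VRS tiles of 8x8 pixels.
constexpr Uint2 ThreadPixelOrigin(Uint2 gid, Uint2 gtid) {
    return {gid.x * kThreadCount1D * 2 + gtid.x * 2,
            gid.y * kThreadCount1D * 2 + gtid.y * 2};
}

static_assert(GroupIndex({7, 7}) == 63, "last thread of an 8x8 group");
static_assert(ThreadPixelOrigin({1, 0}, {7, 7}).x == 30, "group offset 16 + thread offset 14");
```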

From ffx_variable_shading.h:

// sample source texture (using motion vectors)
while (index < FFX_VariableShading_SampleCount)
{
    int2 index2D = 2 * int2(index % FFX_VariableShading_SampleCount1D, index / FFX_VariableShading_SampleCount1D);
    float4 lum = 0;
    lum.x = FFX_VariableShading_GetLuminance(baseOffset + index2D + int2(0, 0));
    lum.y = FFX_VariableShading_GetLuminance(baseOffset + index2D + int2(1, 0));
    lum.z = FFX_VariableShading_GetLuminance(baseOffset + index2D + int2(0, 1));
    lum.w = FFX_VariableShading_GetLuminance(baseOffset + index2D + int2(1, 1));
    ...
    index += FFX_VariableShading_ThreadCount;
}
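The index2D arithmetic in the loop can be checked in isolation. This C++ sketch mirrors the constants quoted further down (8 threads per axis, 10 samples per axis) and computes the pixel footprint relative to baseOffset. Reading the extra 2 samples per axis as a border around the group's 16x16 pixels is an assumption; where baseOffset places that border is not shown in the excerpt.

```cpp
#include <cstdint>

// Mirrors the ffx_variable_shading.h constants quoted in the question.
constexpr std::uint32_t kThreadCount1D = 8;
constexpr std::uint32_t kSampleCount1D = kThreadCount1D + 2;            // 10
constexpr std::uint32_t kSampleCount = kSampleCount1D * kSampleCount1D;  // 100

struct Int2 {
    int x, y;
};

// Same arithmetic as index2D: sample `index` sits on a 10x10 grid with a
// 2-pixel stride; the loop then reads 4 luminance values at +(0,0), +(1,0),
// +(0,1), +(1,1) relative to that position.
constexpr Int2 SampleOffset(std::uint32_t index) {
    return {2 * static_cast<int>(index % kSampleCount1D),
            2 * static_cast<int>(index / kSampleCount1D)};
}

// Largest pixel offset (relative to baseOffset) read by any of the 100 samples.
constexpr int MaxPixelOffset() {
    int maxOffset = 0;
    for (std::uint32_t i = 0; i < kSampleCount; ++i) {
        const Int2 o = SampleOffset(i);
        const int extent = (o.x > o.y ? o.x : o.y) + 1;  // +1 for the 2x2 read
        if (extent > maxOffset) maxOffset = extent;
    }
    return maxOffset;
}

// Offsets 0..19 per axis: a 20x20 pixel footprint, i.e. the 16x16 pixels
// the group covers plus 4 extra pixels per axis.
static_assert(MaxPixelOffset() == 19, "20x20 footprint");
```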

For example, suppose we have an image that's 1080x3200 with 8x8 tiles, so the VRS image is 135x400. Suppose our example supports coarse pixel sizes up to 2x2, so numThreadGroupsX is 68 and numThreadGroupsY is 200. Why 68x200? Each threadgroup handles 4 VRS tiles (2x2 tiles). A threadgroup is 8x8 threads, and 2x2 VRS tiles of 8x8 pixels cover 16x16 pixels: each thread reads luminance for 2x2 pixels, and 2x2 pixels x 8x8 threads = 16x16 pixels. I think that sort of makes sense...
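The numbers in this example can be checked with a few constexpr lines. This C++ sketch assumes DivideRoundingUp behaves like FFX_VariableShading_DivideRoundingUp (integer division rounded up). It also shows that because 135 is odd, the last column of thread groups extends one VRS tile (8 pixels) past the right edge, which the shader presumably has to clamp or skip.

```cpp
#include <cstdint>

constexpr std::uint32_t DivideRoundingUp(std::uint32_t a, std::uint32_t b) {
    return (a + b - 1) / b;
}

// Worked numbers: 1080x3200 render target, 8x8 VRS tiles, coarse rates up
// to 2x2, so each 8x8-thread group handles 2x2 VRS tiles.
constexpr std::uint32_t kWidth = 1080, kHeight = 3200, kTileSize = 8;
constexpr std::uint32_t kVrsWidth = DivideRoundingUp(kWidth, kTileSize);    // 135
constexpr std::uint32_t kVrsHeight = DivideRoundingUp(kHeight, kTileSize);  // 400
constexpr std::uint32_t kGroupsX = DivideRoundingUp(kVrsWidth, 2);          // 68
constexpr std::uint32_t kGroupsY = DivideRoundingUp(kVrsHeight, 2);         // 200

// Pixels per group edge: 8 threads * 2 pixels = 16 = 2 tiles * 8 pixels.
constexpr std::uint32_t kGroupPixels1D = 8 * 2;

// 68 groups * 16 = 1088 pixels wide, 8 more than the image; 200 * 16 = 3200 exactly.
constexpr std::uint32_t kOverhangX = kGroupsX * kGroupPixels1D - kWidth;   // 8
constexpr std::uint32_t kOverhangY = kGroupsY * kGroupPixels1D - kHeight;  // 0

static_assert(kVrsWidth == 135 && kVrsHeight == 400, "VRS image size");
static_assert(kGroupsX == 68 && kGroupsY == 200, "thread-group count");
static_assert(kOverhangX == 8 && kOverhangY == 0, "edge overhang");
```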

However, what is the while (index < FFX_VariableShading_SampleCount) loop doing? With the constants below, it boils down to while (index < 100) { do stuff; index += 64; }. Is it okay to comment it out?

static const uint FFX_VariableShading_ThreadCount1D = 8;
static const uint FFX_VariableShading_SampleCount1D = FFX_VariableShading_ThreadCount1D + 2; // 10
static const uint FFX_VariableShading_SampleCount = FFX_VariableShading_SampleCount1D * FFX_VariableShading_SampleCount1D; // 100
static const uint FFX_VariableShading_ThreadCount = FFX_VariableShading_ThreadCount1D * FFX_VariableShading_ThreadCount1D; // 64
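With these constants, the loop looks like a cooperative, strided load: each of the 64 threads starts at its own index (assumed here to be SV_GroupIndex, since the initialization is elided in the excerpt) and steps by 64, so together the group visits all 100 samples exactly once. This C++ sketch counts the iterations per thread. Under that assumption, dropping the loop would not be safe: a single pass would load only 64 of the 100 samples, and removing the body entirely would load none.

```cpp
#include <cstdint>

constexpr std::uint32_t kThreadCount = 8 * 8;    // 64 threads per group
constexpr std::uint32_t kSampleCount = 10 * 10;  // 100 samples per group

// Iterations run by thread `gidx`, assuming index starts at SV_GroupIndex
// (0..63) and advances by kThreadCount, as in the excerpt.
constexpr std::uint32_t IterationsForThread(std::uint32_t gidx) {
    std::uint32_t n = 0;
    for (std::uint32_t index = gidx; index < kSampleCount; index += kThreadCount) {
        ++n;
    }
    return n;
}

// Total samples visited by the whole group with the loop in place.
// Indices t + 64k are distinct, so this also counts distinct samples.
constexpr std::uint32_t SamplesWithLoop() {
    std::uint32_t total = 0;
    for (std::uint32_t t = 0; t < kThreadCount; ++t) {
        total += IterationsForThread(t);
    }
    return total;
}

// Samples loaded if the body ran once per thread (loop header commented out).
constexpr std::uint32_t SamplesSinglePass() {
    return kThreadCount < kSampleCount ? kThreadCount : kSampleCount;
}

static_assert(IterationsForThread(0) == 2, "samples 0 and 64");
static_assert(IterationsForThread(35) == 2, "samples 35 and 99");
static_assert(IterationsForThread(36) == 1, "sample 36 only");
static_assert(SamplesWithLoop() == 100, "every sample loaded exactly once");
static_assert(SamplesSinglePass() == 64, "36 samples would be missing");
```

Threads 0 to 35 run two iterations and threads 36 to 63 run one (36 * 2 + 28 = 100), which is why the loop is a while rather than a single statement.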
