// blob: 722dc75f1be9aac5d6c6b1ed04df97878a05a9e9 [file] [log] [blame] [edit]
// Uniform blur parameters, bound as |params| at @group(0) @binding(1).
// Field order and types define the uniform buffer layout the host must match.
struct Params {
// Number of taps in the box filter; main() weights every tap by 1 / filterDim.
filterDim : u32,
// Stride, in texels along the blur axis, between consecutive workgroups
// (main() uses WorkGroupID.x * blockDim). NOTE(review): presumably set by the
// host to 256 - (filterDim - 1) — confirm against the dispatch code.
blockDim : u32,
};
// Sampler used by textureSampleLevel() when loading input texels.
@group(0) @binding(0) var samp : sampler;
// Blur parameters (filter width and per-workgroup stride).
@group(0) @binding(1) var<uniform> params : Params;
// Source image; always sampled at mip level 0.
@group(1) @binding(1) var inputTex : texture_2d<f32>;
// Destination image; blurred RGB is written with alpha forced to 1.0.
@group(1) @binding(2) var outputTex : texture_storage_2d<rgba8unorm, write>;
// Direction selector for one blur pass.
struct Flip {
// 0 blurs along texture x; any non-zero value swaps the x/y coordinates so
// the same shader blurs along texture y.
value : u32,
};
// Per-pass blur direction flag (zero = along x, non-zero = along y).
@group(1) @binding(3) var<uniform> flip : Flip;
// This shader blurs the input texture in one direction: along x when
// |flip.value| is 0, and along y (by swapping coordinates) when it is non-zero.
// It does so by running (256 / 4) = 64 threads per workgroup to load a
// 256 x 4 tile of texels (4 rows of 256) into shared workgroup memory.
// Each thread loads a 4 x 4 block of texels to take advantage of the
// texture sampling hardware.
// Then, each thread computes the blur result by averaging the filterDim
// neighboring texel values in its row of shared memory.
// Because we're operating on a subset of the texture, we cannot compute all of the
// results since not all of the neighbors are available in shared memory.
// Specifically, with 256-texel-wide tiles, each workgroup can only compute and
// write out a block 256 - (filterDim - 1) texels wide (and 4 rows tall). We
// compute the number of blocks needed in JavaScript and dispatch that amount.
// Workgroup-shared tile: 4 rows of 256 RGB texels. It is filled by the load
// phase of main() and read by its filter phase after workgroupBarrier().
var<workgroup> tile : array<array<vec3<f32>, 256>, 4>;
// Compute entry point: each workgroup box-blurs one 256 x 4 tile along the
// tile's x axis (texture x, or texture y when flip.value != 0).
//
// Phase 1: each of the 64 invocations samples a 4 x 4 texel block into |tile|.
// Phase 2: after the barrier, each invocation averages params.filterDim
//          neighbors in its tile row for the 4 x 4 outputs it owns, skipping
//          columns whose full filter window does not fit inside the tile.
@stage(compute) @workgroup_size(64, 1, 1)
fn main(
  @builtin(workgroup_id) WorkGroupID : vec3<u32>,
  @builtin(local_invocation_id) LocalInvocationID : vec3<u32>
) {
  // Taps to the left of the center texel. For even filterDim the window is
  // asymmetric and extends one tap further to the right, hence filterReach.
  let filterOffset : u32 = (params.filterDim - 1u) / 2u;
  let filterReach : u32 = params.filterDim - 1u - filterOffset;
  let dims : vec2<i32> = textureDimensions(inputTex, 0);

  // First texel (in blur-axis space) of this invocation's 4 x 4 block.
  // Workgroups step by blockDim along x and by 4 rows along y; the tile is
  // shifted left by filterOffset so the left-hand taps are loaded too.
  let baseIndex = vec2<i32>(
    WorkGroupID.xy * vec2<u32>(params.blockDim, 4u) +
    LocalInvocationID.xy * vec2<u32>(4u, 1u)
  ) - vec2<i32>(i32(filterOffset), 0);

  // Phase 1: load this invocation's 4 x 4 block into the shared tile.
  for (var r : u32 = 0u; r < 4u; r = r + 1u) {
    for (var c : u32 = 0u; c < 4u; c = c + 1u) {
      var loadIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        // Swap axes so the same code blurs along texture y.
        loadIndex = loadIndex.yx;
      }
      // NOTE(review): the sample point is 0.25 texel into the texel rather
      // than its center (0.5); presumably intentional given the sampler's
      // filtering mode — confirm before changing.
      tile[r][4u * LocalInvocationID.x + c] =
        textureSampleLevel(inputTex, samp,
          (vec2<f32>(loadIndex) + vec2<f32>(0.25, 0.25)) / vec2<f32>(dims), 0.0).rgb;
    }
  }

  // All loads must land in |tile| before any invocation reads its neighbors.
  workgroupBarrier();

  // Phase 2: box-filter each owned texel along its tile row.
  for (var r : u32 = 0u; r < 4u; r = r + 1u) {
    for (var c : u32 = 0u; c < 4u; c = c + 1u) {
      var writeIndex = baseIndex + vec2<i32>(i32(c), i32(r));
      if (flip.value != 0u) {
        writeIndex = writeIndex.yx;
      }
      let center : u32 = 4u * LocalInvocationID.x + c;
      // Only write texels whose whole window
      // [center - filterOffset, center + filterReach] lies inside the
      // 256-texel tile row. Bounding the right side by filterReach (not
      // filterOffset) stops an even filterDim from reading tile[r][256] and,
      // assuming the host sets blockDim = 256 - (filterDim - 1), from writing
      // a column that the next workgroup also writes.
      // `&&` short-circuits, so filterDim == 0 (huge filterOffset) writes nothing.
      if (center >= filterOffset &&
          center + filterReach < 256u &&
          all(writeIndex < dims)) {
        // Loop-invariant box-filter weight.
        let weight : f32 = 1.0 / f32(params.filterDim);
        var acc : vec3<f32> = vec3<f32>(0.0, 0.0, 0.0);
        for (var f : u32 = 0u; f < params.filterDim; f = f + 1u) {
          let i : u32 = center + f - filterOffset;
          acc = acc + weight * tile[r][i];
        }
        textureStore(outputTex, writeIndex, vec4<f32>(acc, 1.0));
      }
    }
  }
}