[hlsl] Add array access to uniform loading
This CL adds the ability to load arrays and array elements in the
decompose uniform access transform.
Bug: 349867642
Change-Id: I816cb9f60f278896c68a8b899ef9466307b415b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197135
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index f155d44..7185c51 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -951,7 +951,7 @@
uint4 v[4];
};
float4x4 v_1(uint start_byte_offset) {
- float4 v_2 = asfloat(v[((0u + start_byte_offset) / 16u)]);
+ float4 v_2 = asfloat(v[(start_byte_offset / 16u)]);
float4 v_3 = asfloat(v[((16u + start_byte_offset) / 16u)]);
float4 v_4 = asfloat(v[((32u + start_byte_offset) / 16u)]);
return float4x4(v_2, v_3, v_4, asfloat(v[((48u + start_byte_offset) / 16u)]));
@@ -986,7 +986,7 @@
uint4 v[2];
};
float2x3 v_1(uint start_byte_offset) {
- float3 v_2 = asfloat(v[((0u + start_byte_offset) / 16u)].xyz);
+ float3 v_2 = asfloat(v[(start_byte_offset / 16u)].xyz);
return float2x3(v_2, asfloat(v[((16u + start_byte_offset) / 16u)].xyz));
}
@@ -1019,8 +1019,8 @@
uint4 v[2];
};
float3x2 v_1(uint start_byte_offset) {
- uint4 v_2 = v[((0u + start_byte_offset) / 16u)];
- float2 v_3 = asfloat(((((((0u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
+ uint4 v_2 = v[(start_byte_offset / 16u)];
+ float2 v_3 = asfloat((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
uint4 v_4 = v[((8u + start_byte_offset) / 16u)];
float2 v_5 = asfloat(((((((8u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_4.zw) : (v_4.xy)));
uint4 v_6 = v[((16u + start_byte_offset) / 16u)];
@@ -1056,8 +1056,8 @@
uint4 v[1];
};
float2x2 v_1(uint start_byte_offset) {
- uint4 v_2 = v[((0u + start_byte_offset) / 16u)];
- float2 v_3 = asfloat(((((((0u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
+ uint4 v_2 = v[(start_byte_offset / 16u)];
+ float2 v_3 = asfloat((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
uint4 v_4 = v[((8u + start_byte_offset) / 16u)];
return float2x2(v_3, asfloat(((((((8u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_4.zw) : (v_4.xy))));
}
@@ -1071,7 +1071,7 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_AccessUniformArray) {
+TEST_F(HlslWriterTest, AccessUniformArray) {
auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -1088,28 +1088,37 @@
cbuffer cbuffer_v : register(b0) {
uint4 v[5];
};
-
-typedef float4 v_load_ret[5];
-v_load_ret v_load(uint offset) {
- float4 arr[5] = (float4[5])0;
+typedef float3 ary_ret[5];
+ary_ret v_1(uint start_byte_offset) {
+ float3 a[5] = (float3[5])0;
{
- for(uint i = 0u; (i < 5u); i = (i + 1u)) {
- const uint scalar_offset = ((offset + (i * 16u))) / 4;
- arr[i] = asfloat(v[scalar_offset / 4]);
+ uint v_2 = 0u;
+ v_2 = 0u;
+ while(true) {
+ uint v_3 = v_2;
+ if ((v_3 >= 5u)) {
+ break;
+ }
+ a[v_3] = asfloat(v[((start_byte_offset + (v_3 * 16u)) / 16u)].xyz);
+ {
+ v_2 = (v_3 + 1u);
+ }
+ continue;
}
}
- return arr;
+ float3 v_4[5] = a;
+ return v_4;
}
void foo() {
- float4 a[5] = v_load(0u);
- float4 b = asfloat(v[3]);
+ float3 a[5] = v_1(0u);
+ float3 b = asfloat(v[3u].xyz);
}
)");
}
-TEST_F(HlslWriterTest, DISABLED_AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
+TEST_F(HlslWriterTest, AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -1126,22 +1135,31 @@
cbuffer cbuffer_v : register(b0) {
uint4 v[42];
};
-
-typedef float4 v_load_ret[42];
-v_load_ret v_load(uint offset) {
- float4 arr[42] = (float4[42])0;
+typedef float3 ary_ret[42];
+ary_ret v_1(uint start_byte_offset) {
+ float3 a[42] = (float3[42])0;
{
- for(uint i = 0u; (i < 42u); i = (i + 1u)) {
- const uint scalar_offset = ((offset + (i * 16u))) / 4;
- arr[i] = asfloat(v[scalar_offset / 4]);
+ uint v_2 = 0u;
+ v_2 = 0u;
+ while(true) {
+ uint v_3 = v_2;
+ if ((v_3 >= 42u)) {
+ break;
+ }
+ a[v_3] = asfloat(v[((start_byte_offset + (v_3 * 16u)) / 16u)].xyz);
+ {
+ v_2 = (v_3 + 1u);
+ }
+ continue;
}
}
- return arr;
+ float3 v_4[42] = a;
+ return v_4;
}
void foo() {
- float4 a[42] = v_load(0u);
- float4 b = asfloat(v[3]);
+ float3 a[42] = v_1(0u);
+ float3 b = asfloat(v[3u].xyz);
}
)");
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 375c48c..f635e65 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -130,10 +130,25 @@
// Note, must be called inside a builder insert block (Append, InsertBefore, etc)
core::ir::Value* OffsetToValue(OffsetData offset) {
- core::ir::Value* val = b.Value(u32(offset.byte_offset));
- for (core::ir::Value* expr : offset.byte_offset_expr) {
- val = b.Add(ty.u32(), val, expr)->Result(0);
+ core::ir::Value* val = nullptr;
+
+ // If the offset is zero, skip setting val. This way, we won't add `0 +` and create useless
+ // addition expressions, but if the offset is zero, and there are no expressions, make sure
+ // we return the 0 value.
+ if (offset.byte_offset != 0) {
+ val = b.Value(u32(offset.byte_offset));
+ } else if (offset.byte_offset_expr.IsEmpty()) {
+ return b.Value(0_u);
}
+
+ for (core::ir::Value* expr : offset.byte_offset_expr) {
+ if (!val) {
+ val = expr;
+ } else {
+ val = b.Add(ty.u32(), val, expr)->Result(0);
+ }
+ }
+
return val;
}
@@ -292,10 +307,10 @@
auto* fn = GetLoadFunctionFor(inst, var, m);
return b.Call(fn, byte_idx);
},
- // [&](const core::type::Array* a) {
- // auto* fn = GetLoadFunctionFor(inst, var, a);
- // return b.Call(fn, vec_idx);
- // },
+ [&](const core::type::Array* a) {
+ auto* fn = GetLoadFunctionFor(inst, var, a);
+ return b.Call(fn, byte_idx);
+ },
TINT_ICE_ON_NO_MATCH);
}
@@ -402,6 +417,51 @@
return fn;
});
}
+
+ // Creates a load function for the given `var` and `array` combination. Essentially creates
+ // a function similar to:
+ //
+ // fn custom_load_A(offset: u32) {
+ // A a = A();
+ // u32 i = 0;
+ // loop {
+ // if (i >= A length) {
+ // break;
+ // }
+ // offset = (offset + (i * A->Stride())) / 16
+ // a[i] = cast(v[offset].xyz)
+ // i = i + 1;
+ // }
+ // return a;
+ // }
+ core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
+ core::ir::Var* var,
+ const core::type::Array* arr) {
+ return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, [&] {
+ auto* start_byte_offset = b.FunctionParam("start_byte_offset", ty.u32());
+ auto* fn = b.Function(arr);
+ fn->SetParams({start_byte_offset});
+
+ b.Append(fn->Block(), [&] {
+ auto* result_arr = b.Var<function>("a", b.Zero(arr));
+
+ auto* count = arr->Count()->As<core::type::ConstantArrayCount>();
+ TINT_ASSERT(count);
+
+ b.LoopRange(ty, 0_u, u32(count->value), 1_u, [&](core::ir::Value* idx) {
+ auto* stride = b.Multiply<u32>(idx, u32(arr->Stride()))->Result(0);
+ OffsetData od{0, {start_byte_offset, stride}};
+ auto* byte_idx = OffsetToValue(od);
+ auto* access = b.Access(ty.ptr<function>(arr->ElemType()), result_arr, idx);
+ b.Store(access, MakeLoad(inst, var, arr->ElemType(), byte_idx));
+ });
+
+ b.Return(fn, b.Load(result_arr));
+ });
+
+ return fn;
+ });
+ }
};
} // namespace
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index f3455c2..589dc8a 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -421,36 +421,35 @@
}
%4 = func(%start_byte_offset:u32):mat4x4<f32> {
$B3: {
- %15:u32 = add 0u, %start_byte_offset
- %16:u32 = div %15, 16u
- %17:u32 = mod %15, 16u
- %18:u32 = div %17, 4u
- %19:ptr<uniform, vec4<u32>, read> = access %v, %16
- %20:vec4<u32> = load %19
- %21:vec4<f32> = bitcast %20
- %22:u32 = add 16u, %start_byte_offset
- %23:u32 = div %22, 16u
- %24:u32 = mod %22, 16u
- %25:u32 = div %24, 4u
- %26:ptr<uniform, vec4<u32>, read> = access %v, %23
- %27:vec4<u32> = load %26
- %28:vec4<f32> = bitcast %27
- %29:u32 = add 32u, %start_byte_offset
- %30:u32 = div %29, 16u
- %31:u32 = mod %29, 16u
- %32:u32 = div %31, 4u
- %33:ptr<uniform, vec4<u32>, read> = access %v, %30
- %34:vec4<u32> = load %33
- %35:vec4<f32> = bitcast %34
- %36:u32 = add 48u, %start_byte_offset
- %37:u32 = div %36, 16u
- %38:u32 = mod %36, 16u
- %39:u32 = div %38, 4u
- %40:ptr<uniform, vec4<u32>, read> = access %v, %37
- %41:vec4<u32> = load %40
- %42:vec4<f32> = bitcast %41
- %43:mat4x4<f32> = construct %21, %28, %35, %42
- ret %43
+ %15:u32 = div %start_byte_offset, 16u
+ %16:u32 = mod %start_byte_offset, 16u
+ %17:u32 = div %16, 4u
+ %18:ptr<uniform, vec4<u32>, read> = access %v, %15
+ %19:vec4<u32> = load %18
+ %20:vec4<f32> = bitcast %19
+ %21:u32 = add 16u, %start_byte_offset
+ %22:u32 = div %21, 16u
+ %23:u32 = mod %21, 16u
+ %24:u32 = div %23, 4u
+ %25:ptr<uniform, vec4<u32>, read> = access %v, %22
+ %26:vec4<u32> = load %25
+ %27:vec4<f32> = bitcast %26
+ %28:u32 = add 32u, %start_byte_offset
+ %29:u32 = div %28, 16u
+ %30:u32 = mod %28, 16u
+ %31:u32 = div %30, 4u
+ %32:ptr<uniform, vec4<u32>, read> = access %v, %29
+ %33:vec4<u32> = load %32
+ %34:vec4<f32> = bitcast %33
+ %35:u32 = add 48u, %start_byte_offset
+ %36:u32 = div %35, 16u
+ %37:u32 = mod %35, 16u
+ %38:u32 = div %37, 4u
+ %39:ptr<uniform, vec4<u32>, read> = access %v, %36
+ %40:vec4<u32> = load %39
+ %41:vec4<f32> = bitcast %40
+ %42:mat4x4<f32> = construct %20, %27, %34, %41
+ ret %42
}
}
)";
@@ -458,7 +457,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessArray) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessArray) {
auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -490,20 +489,22 @@
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+ %v:ptr<uniform, array<vec4<u32>, 5>, read> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
%3:array<vec3<f32>, 5> = call %4, 0u
%a:array<vec3<f32>, 5> = let %3
- %6:vec3<u32> = %v.Load3 48u
- %7:vec3<f32> = bitcast %6
- %b:vec3<f32> = let %7
+ %6:ptr<uniform, vec4<u32>, read> = access %v, 3u
+ %7:vec4<u32> = load %6
+ %8:vec3<u32> = swizzle %7, xyz
+ %9:vec3<f32> = bitcast %8
+ %b:vec3<f32> = let %9
ret
}
}
-%4 = func(%offset:u32):array<vec3<f32>, 5> {
+%4 = func(%start_byte_offset:u32):array<vec3<f32>, 5> {
$B3: {
%a_1:ptr<function, array<vec3<f32>, 5>, read_write> = var, array<vec3<f32>, 5>(vec3<f32>(0.0f)) # %a_1: 'a'
loop [i: $B4, b: $B5, c: $B6] { # loop_1
@@ -511,27 +512,32 @@
next_iteration 0u # -> $B5
}
$B5 (%idx:u32): { # body
- %12:bool = gte %idx, 5u
- if %12 [t: $B7] { # if_1
+ %14:bool = gte %idx, 5u
+ if %14 [t: $B7] { # if_1
$B7: { # true
exit_loop # loop_1
}
}
- %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
- %14:u32 = mul %idx, 16u
- %15:u32 = add %offset, %14
- %16:vec3<u32> = %v.Load3 %15
- %17:vec3<f32> = bitcast %16
- store %13, %17
+ %15:u32 = mul %idx, 16u
+ %16:u32 = add %start_byte_offset, %15
+ %17:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+ %18:u32 = div %16, 16u
+ %19:u32 = mod %16, 16u
+ %20:u32 = div %19, 4u
+ %21:ptr<uniform, vec4<u32>, read> = access %v, %18
+ %22:vec4<u32> = load %21
+ %23:vec3<u32> = swizzle %22, xyz
+ %24:vec3<f32> = bitcast %23
+ store %17, %24
continue # -> $B6
}
$B6: { # continuing
- %18:u32 = add %idx, 1u
- next_iteration %18 # -> $B5
+ %25:u32 = add %idx, 1u
+ next_iteration %25 # -> $B5
}
}
- %19:array<vec3<f32>, 5> = load %a_1
- ret %19
+ %26:array<vec3<f32>, 5> = load %a_1
+ ret %26
}
}
)";
@@ -539,8 +545,7 @@
EXPECT_EQ(expect, str());
}
-TEST_F(HlslWriterDecomposeUniformAccessTest,
- DISABLED_UniformAccessArrayWhichCanHaveSizesOtherThenFive) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessArrayWhichCanHaveSizesOtherThenFive) {
auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
var->SetBindingPoint(0, 0);
@@ -572,20 +577,22 @@
auto* expect = R"(
$B1: { # root
- %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+ %v:ptr<uniform, array<vec4<u32>, 42>, read> = var @binding_point(0, 0)
}
%foo = @fragment func():void {
$B2: {
%3:array<vec3<f32>, 42> = call %4, 0u
%a:array<vec3<f32>, 42> = let %3
- %6:vec3<u32> = %v.Load3 48u
- %7:vec3<f32> = bitcast %6
- %b:vec3<f32> = let %7
+ %6:ptr<uniform, vec4<u32>, read> = access %v, 3u
+ %7:vec4<u32> = load %6
+ %8:vec3<u32> = swizzle %7, xyz
+ %9:vec3<f32> = bitcast %8
+ %b:vec3<f32> = let %9
ret
}
}
-%4 = func(%offset:u32):array<vec3<f32>, 42> {
+%4 = func(%start_byte_offset:u32):array<vec3<f32>, 42> {
$B3: {
%a_1:ptr<function, array<vec3<f32>, 42>, read_write> = var, array<vec3<f32>, 42>(vec3<f32>(0.0f)) # %a_1: 'a'
loop [i: $B4, b: $B5, c: $B6] { # loop_1
@@ -593,27 +600,32 @@
next_iteration 0u # -> $B5
}
$B5 (%idx:u32): { # body
- %12:bool = gte %idx, 42u
- if %12 [t: $B7] { # if_1
+ %14:bool = gte %idx, 42u
+ if %14 [t: $B7] { # if_1
$B7: { # true
exit_loop # loop_1
}
}
- %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
- %14:u32 = mul %idx, 16u
- %15:u32 = add %offset, %14
- %16:vec3<u32> = %v.Load3 %15
- %17:vec3<f32> = bitcast %16
- store %13, %17
+ %15:u32 = mul %idx, 16u
+ %16:u32 = add %start_byte_offset, %15
+ %17:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+ %18:u32 = div %16, 16u
+ %19:u32 = mod %16, 16u
+ %20:u32 = div %19, 4u
+ %21:ptr<uniform, vec4<u32>, read> = access %v, %18
+ %22:vec4<u32> = load %21
+ %23:vec3<u32> = swizzle %22, xyz
+ %24:vec3<f32> = bitcast %23
+ store %17, %24
continue # -> $B6
}
$B6: { # continuing
- %18:u32 = add %idx, 1u
- next_iteration %18 # -> $B5
+ %25:u32 = add %idx, 1u
+ next_iteration %25 # -> $B5
}
}
- %19:array<vec3<f32>, 42> = load %a_1
- ret %19
+ %26:array<vec3<f32>, 42> = load %a_1
+ ret %26
}
}
)";