[hlsl] Add array access to uniform loading

This CL adds support for loading arrays and array elements in the
decompose uniform access transform.
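
For reference, the access pattern this covers is roughly what the new
tests below build with the IR builder (a sketch only; "foo", "a" and
"b" are illustrative names taken from the tests, not part of the
transform itself):

  auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
  var->SetBindingPoint(0, 0);

  auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
  b.Append(func->Block(), [&] {
      // A whole-array load is rewritten into a call to a generated helper
      // that loops over the uint4 cbuffer rows and rebuilds the array.
      b.Let("a", b.Load(var));
      // A single-element load becomes an indexed access of the uint4 rows
      // followed by a swizzle and bitcast.
      auto* el = b.Access(ty.ptr<uniform, vec3<f32>, core::Access::kRead>(), var, 3_u);
      b.Let("b", b.Load(el));
      b.Return(func);
  });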

Bug: 349867642
Change-Id: I816cb9f60f278896c68a8b899ef9466307b415b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197135
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index f155d44..7185c51 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -951,7 +951,7 @@
   uint4 v[4];
 };
 float4x4 v_1(uint start_byte_offset) {
-  float4 v_2 = asfloat(v[((0u + start_byte_offset) / 16u)]);
+  float4 v_2 = asfloat(v[(start_byte_offset / 16u)]);
   float4 v_3 = asfloat(v[((16u + start_byte_offset) / 16u)]);
   float4 v_4 = asfloat(v[((32u + start_byte_offset) / 16u)]);
   return float4x4(v_2, v_3, v_4, asfloat(v[((48u + start_byte_offset) / 16u)]));
@@ -986,7 +986,7 @@
   uint4 v[2];
 };
 float2x3 v_1(uint start_byte_offset) {
-  float3 v_2 = asfloat(v[((0u + start_byte_offset) / 16u)].xyz);
+  float3 v_2 = asfloat(v[(start_byte_offset / 16u)].xyz);
   return float2x3(v_2, asfloat(v[((16u + start_byte_offset) / 16u)].xyz));
 }
 
@@ -1019,8 +1019,8 @@
   uint4 v[2];
 };
 float3x2 v_1(uint start_byte_offset) {
-  uint4 v_2 = v[((0u + start_byte_offset) / 16u)];
-  float2 v_3 = asfloat(((((((0u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
+  uint4 v_2 = v[(start_byte_offset / 16u)];
+  float2 v_3 = asfloat((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
   uint4 v_4 = v[((8u + start_byte_offset) / 16u)];
   float2 v_5 = asfloat(((((((8u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_4.zw) : (v_4.xy)));
   uint4 v_6 = v[((16u + start_byte_offset) / 16u)];
@@ -1056,8 +1056,8 @@
   uint4 v[1];
 };
 float2x2 v_1(uint start_byte_offset) {
-  uint4 v_2 = v[((0u + start_byte_offset) / 16u)];
-  float2 v_3 = asfloat(((((((0u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
+  uint4 v_2 = v[(start_byte_offset / 16u)];
+  float2 v_3 = asfloat((((((start_byte_offset % 16u) / 4u) == 2u)) ? (v_2.zw) : (v_2.xy)));
   uint4 v_4 = v[((8u + start_byte_offset) / 16u)];
   return float2x2(v_3, asfloat(((((((8u + start_byte_offset) % 16u) / 4u) == 2u)) ? (v_4.zw) : (v_4.xy))));
 }
@@ -1071,7 +1071,7 @@
 )");
 }
 
-TEST_F(HlslWriterTest, DISABLED_AccessUniformArray) {
+TEST_F(HlslWriterTest, AccessUniformArray) {
     auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -1088,28 +1088,37 @@
 cbuffer cbuffer_v : register(b0) {
   uint4 v[5];
 };
-
-typedef float4 v_load_ret[5];
-v_load_ret v_load(uint offset) {
-  float4 arr[5] = (float4[5])0;
+typedef float3 ary_ret[5];
+ary_ret v_1(uint start_byte_offset) {
+  float3 a[5] = (float3[5])0;
   {
-    for(uint i = 0u; (i < 5u); i = (i + 1u)) {
-      const uint scalar_offset = ((offset + (i * 16u))) / 4;
-      arr[i] = asfloat(v[scalar_offset / 4]);
+    uint v_2 = 0u;
+    v_2 = 0u;
+    while(true) {
+      uint v_3 = v_2;
+      if ((v_3 >= 5u)) {
+        break;
+      }
+      a[v_3] = asfloat(v[((start_byte_offset + (v_3 * 16u)) / 16u)].xyz);
+      {
+        v_2 = (v_3 + 1u);
+      }
+      continue;
     }
   }
-  return arr;
+  float3 v_4[5] = a;
+  return v_4;
 }
 
 void foo() {
-  float4 a[5] = v_load(0u);
-  float4 b = asfloat(v[3]);
+  float3 a[5] = v_1(0u);
+  float3 b = asfloat(v[3u].xyz);
 }
 
 )");
 }
 
-TEST_F(HlslWriterTest, DISABLED_AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
+TEST_F(HlslWriterTest, AccessUniformArrayWhichCanHaveSizesOtherThenFive) {
     auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -1126,22 +1135,31 @@
 cbuffer cbuffer_v : register(b0) {
   uint4 v[42];
 };
-
-typedef float4 v_load_ret[42];
-v_load_ret v_load(uint offset) {
-  float4 arr[42] = (float4[42])0;
+typedef float3 ary_ret[42];
+ary_ret v_1(uint start_byte_offset) {
+  float3 a[42] = (float3[42])0;
   {
-    for(uint i = 0u; (i < 42u); i = (i + 1u)) {
-      const uint scalar_offset = ((offset + (i * 16u))) / 4;
-      arr[i] = asfloat(v[scalar_offset / 4]);
+    uint v_2 = 0u;
+    v_2 = 0u;
+    while(true) {
+      uint v_3 = v_2;
+      if ((v_3 >= 42u)) {
+        break;
+      }
+      a[v_3] = asfloat(v[((start_byte_offset + (v_3 * 16u)) / 16u)].xyz);
+      {
+        v_2 = (v_3 + 1u);
+      }
+      continue;
     }
   }
-  return arr;
+  float3 v_4[42] = a;
+  return v_4;
 }
 
 void foo() {
-  float4 a[42] = v_load(0u);
-  float4 b = asfloat(v[3]);
+  float3 a[42] = v_1(0u);
+  float3 b = asfloat(v[3u].xyz);
 }
 
 )");
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 375c48c..f635e65 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -130,10 +130,25 @@
 
     // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
     core::ir::Value* OffsetToValue(OffsetData offset) {
-        core::ir::Value* val = b.Value(u32(offset.byte_offset));
-        for (core::ir::Value* expr : offset.byte_offset_expr) {
-            val = b.Add(ty.u32(), val, expr)->Result(0);
+        core::ir::Value* val = nullptr;
+
+        // If the constant offset is zero, skip setting `val` so we don't emit useless `0 + expr`
+        // additions. If the offset is zero and there are no offset expressions either, return the
+        // constant 0 value directly.
+        if (offset.byte_offset != 0) {
+            val = b.Value(u32(offset.byte_offset));
+        } else if (offset.byte_offset_expr.IsEmpty()) {
+            return b.Value(0_u);
         }
+
+        for (core::ir::Value* expr : offset.byte_offset_expr) {
+            if (!val) {
+                val = expr;
+            } else {
+                val = b.Add(ty.u32(), val, expr)->Result(0);
+            }
+        }
+
         return val;
     }
 
@@ -292,10 +307,10 @@
                 auto* fn = GetLoadFunctionFor(inst, var, m);
                 return b.Call(fn, byte_idx);
             },
-            //            [&](const core::type::Array* a) {
-            //                auto* fn = GetLoadFunctionFor(inst, var, a);
-            //                return b.Call(fn, vec_idx);
-            //            },
+            [&](const core::type::Array* a) {
+                auto* fn = GetLoadFunctionFor(inst, var, a);
+                return b.Call(fn, byte_idx);
+            },
             TINT_ICE_ON_NO_MATCH);
     }
 
@@ -402,6 +417,51 @@
             return fn;
         });
     }
+
+    // Creates a load function for the given `var` and `array` combination. Essentially creates
+    // a function similar to:
+    //
+    // fn custom_load_A(offset: u32) {
+    //   A a = A();
+    //   u32 i = 0;
+    //   loop {
+    //     if (i >= A length) {
+    //       break;
+    //     }
+    //     idx = (offset + (i * A->Stride())) / 16
+    //     a[i] = cast(v[idx].xyz)
+    //     i = i + 1;
+    //   }
+    //   return a;
+    // }
+    core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
+                                           core::ir::Var* var,
+                                           const core::type::Array* arr) {
+        return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, [&] {
+            auto* start_byte_offset = b.FunctionParam("start_byte_offset", ty.u32());
+            auto* fn = b.Function(arr);
+            fn->SetParams({start_byte_offset});
+
+            b.Append(fn->Block(), [&] {
+                auto* result_arr = b.Var<function>("a", b.Zero(arr));
+
+                auto* count = arr->Count()->As<core::type::ConstantArrayCount>();
+                TINT_ASSERT(count);
+
+                b.LoopRange(ty, 0_u, u32(count->value), 1_u, [&](core::ir::Value* idx) {
+                    auto* stride = b.Multiply<u32>(idx, u32(arr->Stride()))->Result(0);
+                    OffsetData od{0, {start_byte_offset, stride}};
+                    auto* byte_idx = OffsetToValue(od);
+                    auto* access = b.Access(ty.ptr<function>(arr->ElemType()), result_arr, idx);
+                    b.Store(access, MakeLoad(inst, var, arr->ElemType(), byte_idx));
+                });
+
+                b.Return(fn, b.Load(result_arr));
+            });
+
+            return fn;
+        });
+    }
 };
 
 }  // namespace
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index f3455c2..589dc8a 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -421,36 +421,35 @@
 }
 %4 = func(%start_byte_offset:u32):mat4x4<f32> {
   $B3: {
-    %15:u32 = add 0u, %start_byte_offset
-    %16:u32 = div %15, 16u
-    %17:u32 = mod %15, 16u
-    %18:u32 = div %17, 4u
-    %19:ptr<uniform, vec4<u32>, read> = access %v, %16
-    %20:vec4<u32> = load %19
-    %21:vec4<f32> = bitcast %20
-    %22:u32 = add 16u, %start_byte_offset
-    %23:u32 = div %22, 16u
-    %24:u32 = mod %22, 16u
-    %25:u32 = div %24, 4u
-    %26:ptr<uniform, vec4<u32>, read> = access %v, %23
-    %27:vec4<u32> = load %26
-    %28:vec4<f32> = bitcast %27
-    %29:u32 = add 32u, %start_byte_offset
-    %30:u32 = div %29, 16u
-    %31:u32 = mod %29, 16u
-    %32:u32 = div %31, 4u
-    %33:ptr<uniform, vec4<u32>, read> = access %v, %30
-    %34:vec4<u32> = load %33
-    %35:vec4<f32> = bitcast %34
-    %36:u32 = add 48u, %start_byte_offset
-    %37:u32 = div %36, 16u
-    %38:u32 = mod %36, 16u
-    %39:u32 = div %38, 4u
-    %40:ptr<uniform, vec4<u32>, read> = access %v, %37
-    %41:vec4<u32> = load %40
-    %42:vec4<f32> = bitcast %41
-    %43:mat4x4<f32> = construct %21, %28, %35, %42
-    ret %43
+    %15:u32 = div %start_byte_offset, 16u
+    %16:u32 = mod %start_byte_offset, 16u
+    %17:u32 = div %16, 4u
+    %18:ptr<uniform, vec4<u32>, read> = access %v, %15
+    %19:vec4<u32> = load %18
+    %20:vec4<f32> = bitcast %19
+    %21:u32 = add 16u, %start_byte_offset
+    %22:u32 = div %21, 16u
+    %23:u32 = mod %21, 16u
+    %24:u32 = div %23, 4u
+    %25:ptr<uniform, vec4<u32>, read> = access %v, %22
+    %26:vec4<u32> = load %25
+    %27:vec4<f32> = bitcast %26
+    %28:u32 = add 32u, %start_byte_offset
+    %29:u32 = div %28, 16u
+    %30:u32 = mod %28, 16u
+    %31:u32 = div %30, 4u
+    %32:ptr<uniform, vec4<u32>, read> = access %v, %29
+    %33:vec4<u32> = load %32
+    %34:vec4<f32> = bitcast %33
+    %35:u32 = add 48u, %start_byte_offset
+    %36:u32 = div %35, 16u
+    %37:u32 = mod %35, 16u
+    %38:u32 = div %37, 4u
+    %39:ptr<uniform, vec4<u32>, read> = access %v, %36
+    %40:vec4<u32> = load %39
+    %41:vec4<f32> = bitcast %40
+    %42:mat4x4<f32> = construct %20, %27, %34, %41
+    ret %42
   }
 }
 )";
@@ -458,7 +457,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessArray) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessArray) {
     auto* var = b.Var<uniform, array<vec3<f32>, 5>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -490,20 +489,22 @@
 
     auto* expect = R"(
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+  %v:ptr<uniform, array<vec4<u32>, 5>, read> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
     %3:array<vec3<f32>, 5> = call %4, 0u
     %a:array<vec3<f32>, 5> = let %3
-    %6:vec3<u32> = %v.Load3 48u
-    %7:vec3<f32> = bitcast %6
-    %b:vec3<f32> = let %7
+    %6:ptr<uniform, vec4<u32>, read> = access %v, 3u
+    %7:vec4<u32> = load %6
+    %8:vec3<u32> = swizzle %7, xyz
+    %9:vec3<f32> = bitcast %8
+    %b:vec3<f32> = let %9
     ret
   }
 }
-%4 = func(%offset:u32):array<vec3<f32>, 5> {
+%4 = func(%start_byte_offset:u32):array<vec3<f32>, 5> {
   $B3: {
     %a_1:ptr<function, array<vec3<f32>, 5>, read_write> = var, array<vec3<f32>, 5>(vec3<f32>(0.0f))  # %a_1: 'a'
     loop [i: $B4, b: $B5, c: $B6] {  # loop_1
@@ -511,27 +512,32 @@
         next_iteration 0u  # -> $B5
       }
       $B5 (%idx:u32): {  # body
-        %12:bool = gte %idx, 5u
-        if %12 [t: $B7] {  # if_1
+        %14:bool = gte %idx, 5u
+        if %14 [t: $B7] {  # if_1
           $B7: {  # true
             exit_loop  # loop_1
           }
         }
-        %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
-        %14:u32 = mul %idx, 16u
-        %15:u32 = add %offset, %14
-        %16:vec3<u32> = %v.Load3 %15
-        %17:vec3<f32> = bitcast %16
-        store %13, %17
+        %15:u32 = mul %idx, 16u
+        %16:u32 = add %start_byte_offset, %15
+        %17:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+        %18:u32 = div %16, 16u
+        %19:u32 = mod %16, 16u
+        %20:u32 = div %19, 4u
+        %21:ptr<uniform, vec4<u32>, read> = access %v, %18
+        %22:vec4<u32> = load %21
+        %23:vec3<u32> = swizzle %22, xyz
+        %24:vec3<f32> = bitcast %23
+        store %17, %24
         continue  # -> $B6
       }
       $B6: {  # continuing
-        %18:u32 = add %idx, 1u
-        next_iteration %18  # -> $B5
+        %25:u32 = add %idx, 1u
+        next_iteration %25  # -> $B5
       }
     }
-    %19:array<vec3<f32>, 5> = load %a_1
-    ret %19
+    %26:array<vec3<f32>, 5> = load %a_1
+    ret %26
   }
 }
 )";
@@ -539,8 +545,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(HlslWriterDecomposeUniformAccessTest,
-       DISABLED_UniformAccessArrayWhichCanHaveSizesOtherThenFive) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessArrayWhichCanHaveSizesOtherThenFive) {
     auto* var = b.Var<uniform, array<vec3<f32>, 42>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -572,20 +577,22 @@
 
     auto* expect = R"(
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+  %v:ptr<uniform, array<vec4<u32>, 42>, read> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
     %3:array<vec3<f32>, 42> = call %4, 0u
     %a:array<vec3<f32>, 42> = let %3
-    %6:vec3<u32> = %v.Load3 48u
-    %7:vec3<f32> = bitcast %6
-    %b:vec3<f32> = let %7
+    %6:ptr<uniform, vec4<u32>, read> = access %v, 3u
+    %7:vec4<u32> = load %6
+    %8:vec3<u32> = swizzle %7, xyz
+    %9:vec3<f32> = bitcast %8
+    %b:vec3<f32> = let %9
     ret
   }
 }
-%4 = func(%offset:u32):array<vec3<f32>, 42> {
+%4 = func(%start_byte_offset:u32):array<vec3<f32>, 42> {
   $B3: {
     %a_1:ptr<function, array<vec3<f32>, 42>, read_write> = var, array<vec3<f32>, 42>(vec3<f32>(0.0f))  # %a_1: 'a'
     loop [i: $B4, b: $B5, c: $B6] {  # loop_1
@@ -593,27 +600,32 @@
         next_iteration 0u  # -> $B5
       }
       $B5 (%idx:u32): {  # body
-        %12:bool = gte %idx, 42u
-        if %12 [t: $B7] {  # if_1
+        %14:bool = gte %idx, 42u
+        if %14 [t: $B7] {  # if_1
           $B7: {  # true
             exit_loop  # loop_1
           }
         }
-        %13:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
-        %14:u32 = mul %idx, 16u
-        %15:u32 = add %offset, %14
-        %16:vec3<u32> = %v.Load3 %15
-        %17:vec3<f32> = bitcast %16
-        store %13, %17
+        %15:u32 = mul %idx, 16u
+        %16:u32 = add %start_byte_offset, %15
+        %17:ptr<function, vec3<f32>, read_write> = access %a_1, %idx
+        %18:u32 = div %16, 16u
+        %19:u32 = mod %16, 16u
+        %20:u32 = div %19, 4u
+        %21:ptr<uniform, vec4<u32>, read> = access %v, %18
+        %22:vec4<u32> = load %21
+        %23:vec3<u32> = swizzle %22, xyz
+        %24:vec3<f32> = bitcast %23
+        store %17, %24
         continue  # -> $B6
       }
       $B6: {  # continuing
-        %18:u32 = add %idx, 1u
-        next_iteration %18  # -> $B5
+        %25:u32 = add %idx, 1u
+        next_iteration %25  # -> $B5
       }
     }
-    %19:array<vec3<f32>, 42> = load %a_1
-    ret %19
+    %26:array<vec3<f32>, 42> = load %a_1
+    ret %26
   }
 }
 )";