[hlsl] Enable uniform vector load

Add support for `load_vector_element` to the uniform decompose
transform.

Bug: 349867642
Change-Id: Ie00cd71d82e27d3446b3f63c79986d3fd0b74f69
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/197074
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 3637485..8587d5d 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -789,17 +789,18 @@
 )");
 }
 
-TEST_F(HlslWriterTest, DISABLED_AccessUniformVectorLoad) {
+TEST_F(HlslWriterTest, AccessUniformVectorLoad) {
     auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
     b.ir.root_block->Append(var);
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
-        b.Let("a", b.LoadVectorElement(var, 0_u));
-        b.Let("b", b.LoadVectorElement(var, 1_u));
-        b.Let("c", b.LoadVectorElement(var, 2_u));
-        b.Let("d", b.LoadVectorElement(var, 3_u));
+        b.Let("a", b.Load(var));
+        b.Let("b", b.LoadVectorElement(var, 0_u));
+        b.Let("c", b.LoadVectorElement(var, 1_u));
+        b.Let("d", b.LoadVectorElement(var, 2_u));
+        b.Let("e", b.LoadVectorElement(var, 3_u));
         b.Return(func);
     });
 
@@ -808,13 +809,14 @@
 cbuffer cbuffer_v : register(b0) {
   uint4 v[1];
 };
-
-void m() {
-  float a = asfloat(v[0].x);
-  float b = asfloat(v[0].y);
-  float c = asfloat(v[0].z);
-  float d = asfloat(v[0].w);
+void foo() {
+  float4 a = asfloat(v[0u]);
+  float b = asfloat(v[0u][0u]);
+  float c = asfloat(v[0u][1u]);
+  float d = asfloat(v[0u][2u]);
+  float e = asfloat(v[0u][3u]);
 }
+
 )");
 }
 
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
index 4087e53..607f3f6 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access.cc
@@ -86,40 +86,31 @@
         for (auto* var : var_worklist) {
             auto* result = var->Result(0);
 
-            // Find all the usages of the `var` which is loading.
-            Vector<core::ir::Instruction*, 4> usage_worklist;
-            for (auto& usage : result->Usages()) {
-                Switch(
-                    usage->instruction,
-                    [&](core::ir::LoadVectorElement* lve) { usage_worklist.Push(lve); },
-                    [&](core::ir::Load* ld) { usage_worklist.Push(ld); },
-                    [&](core::ir::Access* a) { usage_worklist.Push(a); },
-                    [&](core::ir::Let* l) { usage_worklist.Push(l); },  //
-                    TINT_ICE_ON_NO_MATCH);
-            }
-
+            auto usage_worklist = result->Usages().Vector();
             auto* var_ty = result->Type()->As<core::type::Pointer>();
             while (!usage_worklist.IsEmpty()) {
-                auto* inst = usage_worklist.Pop();
+                auto usage = usage_worklist.Pop();
+                auto* inst = usage.instruction;
+
                 // Load instructions can be destroyed by the replacing access function
                 if (!inst->Alive()) {
                     continue;
                 }
 
                 Switch(
-                    inst,
-                    [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
-                    [&](core::ir::Load* l) { Load(l, var); },  //
-                    [&](core::ir::Access* a) {
-                        OffsetData offset;
-                        Access(a, var, a->Object()->Type(), &offset);
-                    },
+                    inst,  //
+                    [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var); },
+                    [&](core::ir::Load* l) { Load(l, var); },
+                    //                    [&](core::ir::Access* a) {
+                    //                        OffsetData offset;
+                    //                        Access(a, var, a->Object()->Type(), &offset);
+                    //                    },
                     [&](core::ir::Let* let) {
                         // The `let` is, essentially, an alias for the `var` as it's assigned
                         // directly. Gather all the `let` usages into our worklist, and then replace
                         // the `let` with the `var` itself.
-                        for (auto& usage : let->Result(0)->Usages()) {
-                            usage_worklist.Push(usage->instruction);
+                        for (auto& use : let->Result(0)->Usages()) {
+                            usage_worklist.Push(use);
                         }
                         let->Result(0)->ReplaceAllUsesWith(result);
                         let->Destroy();
@@ -134,369 +125,26 @@
         }
     }
 
-    struct OffsetData {
-        uint32_t byte_offset = 0;
-        Vector<core::ir::Value*, 4> expr{};
-    };
-
-    // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
-    void CalculateVectorIndex(core::ir::Value* v,
-                              const core::type::Type* store_ty,
-                              OffsetData* offset) {
-        auto elm_size = store_ty->DeepestElement()->Size();
-        tint::Switch(
-            v,  //
-            [&](core::ir::Constant* idx_value) {
-                offset->byte_offset += idx_value->Value()->ValueAs<uint32_t>() * elm_size;
-            },
-            [&](core::ir::Value* val) {
-                offset->expr.Push(
-                    b.Multiply(ty.u32(), b.Convert(ty.u32(), val), u32(elm_size))->Result(0));
-            },
-            TINT_ICE_ON_NO_MATCH);
-    }
-
-    // Note, must be called inside a builder insert block (Append, InsertBefore, etc)
-    core::ir::Value* OffsetToValue(OffsetData offset) {
-        core::ir::Value* val = b.Value(u32(offset.byte_offset));
-        for (core::ir::Value* expr : offset.expr) {
-            val = b.Add(ty.u32(), val, expr)->Result(0);
-        }
-        return val;
-    }
-
-    // Creates the appropriate load instructions for the given result type.
-    core::ir::Call* MakeLoad(core::ir::Instruction* inst,
-                             core::ir::Var* var,
-                             const core::type::Type* result_ty,
-                             core::ir::Value* offset) {
-        if (result_ty->is_numeric_scalar_or_vector()) {
-            return MakeScalarOrVectorLoad(var, result_ty, offset);
-        }
-
-        return tint::Switch(
-            result_ty,  //
-            [&](const core::type::Struct* s) {
-                auto* fn = GetLoadFunctionFor(inst, var, s);
-                return b.Call(fn, offset);
-            },
-            [&](const core::type::Matrix* m) {
-                auto* fn = GetLoadFunctionFor(inst, var, m);
-                return b.Call(fn, offset);
-            },  //
-            [&](const core::type::Array* a) {
-                auto* fn = GetLoadFunctionFor(inst, var, a);
-                return b.Call(fn, offset);
-            },  //
-            TINT_ICE_ON_NO_MATCH);
-    }
-
-    // Creates a `v.Load{2,3,4} offset` call based on the provided type. The load returns a
-    // `u32` or vector of `u32` and then a `bitcast` is done to get back to the desired type.
-    //
-    // This only works for `u32`, `i32`, `f16`, `f32` and the vector sizes of those types.
-    //
-    // The `f16` type is special in that `f16` uses a templated load in HLSL `Load<float16_t>`
-    // and returns the correct type, so there is no bitcast.
-    core::ir::Call* MakeScalarOrVectorLoad(core::ir::Var* var,
-                                           const core::type::Type* result_ty,
-                                           core::ir::Value* offset) {
-        bool is_f16 = result_ty->DeepestElement()->Is<core::type::F16>();
-
-        const core::type::Type* load_ty = ty.u32();
-        // An `f16` load returns an `f16` instead of a `u32`
-        if (is_f16) {
-            load_ty = ty.f16();
-        }
-
-        auto fn = is_f16 ? BuiltinFn::kLoadF16 : BuiltinFn::kLoad;
-        if (auto* v = result_ty->As<core::type::Vector>()) {
-            load_ty = ty.vec(load_ty, v->Width());
-            switch (v->Width()) {
-                case 2:
-                    fn = is_f16 ? BuiltinFn::kLoad2F16 : BuiltinFn::kLoad2;
-                    break;
-                case 3:
-                    fn = is_f16 ? BuiltinFn::kLoad3F16 : BuiltinFn::kLoad3;
-                    break;
-                case 4:
-                    fn = is_f16 ? BuiltinFn::kLoad4F16 : BuiltinFn::kLoad4;
-                    break;
-                default:
-                    TINT_UNREACHABLE();
-            }
-        }
-
-        auto* builtin = b.MemberCall<hlsl::ir::MemberBuiltinCall>(load_ty, fn, var, offset);
-        core::ir::Call* res = nullptr;
-
-        // Do not bitcast the `f16` conversions as they need to be a templated Load instruction
-        if (is_f16) {
-            res = builtin;
-        } else {
-            res = b.Bitcast(result_ty, builtin->Result(0));
-        }
-        return res;
-    }
-
-    // Creates a load function for the given `var` and `struct` combination. Essentially creates
-    // a function similar to:
-    //
-    // fn custom_load_S(offset: u32) {
-    //   let a = <load S member 0>(offset + member 0 offset);
-    //   let b = <load S member 1>(offset + member 1 offset);
-    //   ...
-    //   let z = <load S member last>(offset + member last offset);
-    //   return S(a, b, ..., z);
-    // }
-    core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
-                                           core::ir::Var* var,
-                                           const core::type::Struct* s) {
-        return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, s}, [&] {
-            auto* p = b.FunctionParam("offset", ty.u32());
-            auto* fn = b.Function(s);
-            fn->SetParams({p});
-
-            b.Append(fn->Block(), [&] {
-                Vector<core::ir::Value*, 4> values;
-                for (const auto* mem : s->Members()) {
-                    values.Push(MakeLoad(inst, var, mem->Type(),
-                                         b.Add<u32>(p, u32(mem->Offset()))->Result(0))
-                                    ->Result(0));
-                }
-
-                b.Return(fn, b.Construct(s, values));
-            });
-
-            return fn;
-        });
-    }
-
-    // Creates a load function for the given `var` and `matrix` combination. Essentially creates
-    // a function similar to:
-    //
-    // fn custom_load_M(offset: u32) {
-    //   let a = <load M column 1>(offset + (1 * ColumnStride));
-    //   let b = <load M column 2>(offset + (2 * ColumnStride));
-    //   ...
-    //   let z = <load M column last>(offset + (last * ColumnStride));
-    //   return M(a, b, ... z);
-    // }
-    core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
-                                           core::ir::Var* var,
-                                           const core::type::Matrix* mat) {
-        return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, mat}, [&] {
-            auto* p = b.FunctionParam("offset", ty.u32());
-            auto* fn = b.Function(mat);
-            fn->SetParams({p});
-
-            b.Append(fn->Block(), [&] {
-                Vector<core::ir::Value*, 4> values;
-                for (size_t i = 0; i < mat->columns(); ++i) {
-                    auto* add = b.Add<u32>(p, u32(i * mat->ColumnStride()));
-                    auto* load = MakeLoad(inst, var, mat->ColumnType(), add->Result(0));
-                    values.Push(load->Result(0));
-                }
-
-                b.Return(fn, b.Construct(mat, values));
-            });
-
-            return fn;
-        });
-    }
-
-    // Creates a load function for the given `var` and `array` combination. Essentially creates
-    // a function similar to:
-    //
-    // fn custom_load_A(offset: u32) {
-    //   A a = A();
-    //   u32 i = 0;
-    //   loop {
-    //     if (i >= A length) {
-    //       break;
-    //     }
-    //     a[i] = <load array type>(offset + (i * A->Stride()));
-    //     i = i + 1;
-    //   }
-    //   return a;
-    // }
-    core::ir::Function* GetLoadFunctionFor(core::ir::Instruction* inst,
-                                           core::ir::Var* var,
-                                           const core::type::Array* arr) {
-        return var_and_type_to_load_fn_.GetOrAdd(VarTypePair{var, arr}, [&] {
-            auto* p = b.FunctionParam("offset", ty.u32());
-            auto* fn = b.Function(arr);
-            fn->SetParams({p});
-
-            b.Append(fn->Block(), [&] {
-                auto* result_arr = b.Var<function>("a", b.Zero(arr));
-
-                auto* count = arr->Count()->As<core::type::ConstantArrayCount>();
-                TINT_ASSERT(count);
-
-                b.LoopRange(ty, 0_u, u32(count->value), 1_u, [&](core::ir::Value* idx) {
-                    auto* access = b.Access(ty.ptr<function>(arr->ElemType()), result_arr, idx);
-                    auto* stride = b.Multiply<u32>(idx, u32(arr->Stride()));
-                    auto* byte_offset = b.Add<u32>(p, stride);
-                    b.Store(access, MakeLoad(inst, var, arr->ElemType(), byte_offset->Result(0)));
-                });
-
-                b.Return(fn, b.Load(result_arr));
-            });
-
-            return fn;
-        });
-    }
-
-    void InsertLoad(core::ir::Var* var, core::ir::Instruction* inst, OffsetData offset) {
-        b.InsertBefore(inst, [&] {
-            auto* call =
-                MakeLoad(inst, var, inst->Result(0)->Type()->UnwrapPtr(), OffsetToValue(offset));
-            inst->Result(0)->ReplaceAllUsesWith(call->Result(0));
-        });
-        inst->Destroy();
-    }
-
-    void Access(core::ir::Access* a,
-                core::ir::Var* var,
-                const core::type::Type* obj,
-                OffsetData* offset) {
-        TINT_ASSERT(offset);
-
-        // Note, because we recurse through the `access` helper, the object passed in isn't
-        // necessarily the originating `var` object, but maybe a partially resolved access chain
-        // object.
-        if (auto* view = obj->As<core::type::MemoryView>()) {
-            obj = view->StoreType();
-        }
-
-        const core::type::StructMember* last_member = nullptr;
-        for (auto* idx_value : a->Indices()) {
-            uint32_t size = 0;
-            auto* cur_obj = obj;
-            tint::Switch(
-                obj,  //
-                [&](const core::type::Vector* v) {
-                    size = v->type()->Size();
-                    obj = v->type();
-                },
-                [&](const core::type::Matrix* m) {
-                    size = m->type()->Size() * m->rows();
-                    obj = m->ColumnType();
-                },
-                [&](const core::type::Array* ary) {
-                    size = ary->Stride();
-                    obj = ary->ElemType();
-                },
-                [&](const core::type::Struct* s) {
-                    auto* cnst = idx_value->As<core::ir::Constant>();
-
-                    // A struct index must be a constant
-                    TINT_ASSERT(cnst);
-
-                    uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
-                    last_member = s->Members()[idx];
-                    obj = last_member->Type();
-                },
-                TINT_ICE_ON_NO_MATCH);
-
-            tint::Switch(
-                idx_value,  //
-                [&](core::ir::Constant* cnst) {
-                    uint32_t idx = cnst->Value()->ValueAs<uint32_t>();
-                    tint::Switch(
-                        cur_obj,  //
-                        [&](const core::type::Vector*) { offset->byte_offset += size * idx; },
-                        [&](const core::type::Matrix*) { offset->byte_offset += size * idx; },
-                        [&](const core::type::Array*) { offset->byte_offset += size * idx; },
-                        [&](const core::type::Struct* s) {
-                            auto* mem = s->Members()[idx];
-                            offset->byte_offset += mem->Offset();
-                        },
-                        TINT_ICE_ON_NO_MATCH);
-                },
-                [&](core::ir::Value* val) {
-                    b.InsertBefore(a, [&] {
-                        offset->expr.Push(
-                            b.Multiply(ty.u32(), u32(size), b.Convert(ty.u32(), val))->Result(0));
-                    });
-                },  //
-                TINT_ICE_ON_NO_MATCH);
-        }
-
-        // Copy the usages into a vector so we can remove items from the hashset.
-        auto usages = a->Result(0)->Usages().Vector();
-        while (!usages.IsEmpty()) {
-            auto usage = usages.Pop();
-            tint::Switch(
-                usage.instruction,
-                [&](core::ir::Let* let) {
-                    // The `let` is essentially an alias to the `access`. So, add the `let`
-                    // usages into the usage worklist, and replace the let with the access chain
-                    // directly.
-                    for (auto& u : let->Result(0)->Usages()) {
-                        usages.Push(u);
-                    }
-                    let->Result(0)->ReplaceAllUsesWith(a->Result(0));
-                    let->Destroy();
-                },
-                [&](core::ir::Access* sub_access) {
-                    // Treat an access chain of the access chain as a continuation of the outer
-                    // chain. Pass through the object we stopped at and the current byte_offset
-                    // and then restart the access chain replacement for the new access chain.
-                    Access(sub_access, var, obj, offset);
-                },
-
-                [&](core::ir::LoadVectorElement* lve) {
-                    a->Result(0)->RemoveUsage(usage);
-
-                    b.InsertBefore(lve, [&] { CalculateVectorIndex(lve->Index(), obj, offset); });
-                    InsertLoad(var, lve, *offset);
-                },
-                [&](core::ir::Load* ld) {
-                    a->Result(0)->RemoveUsage(usage);
-                    InsertLoad(var, ld, *offset);
-                },
-                TINT_ICE_ON_NO_MATCH);
-        }
-
-        a->Destroy();
-    }
-
-    // This should _only_ be handling a `var` parameter as any `access` parameters would have
-    // been replaced by the `access` being converted.
     void Load(core::ir::Load* ld, core::ir::Var* var) {
-        auto* result = ld->From()->As<core::ir::InstructionResult>();
-        TINT_ASSERT(result);
-
-        auto* inst = result->Instruction()->As<core::ir::Var>();
-        TINT_ASSERT(inst);
-
-        const core::type::Type* result_ty = inst->Result(0)->Type()->UnwrapPtr();
-
         b.InsertBefore(ld, [&] {
-            auto* call = MakeLoad(ld, var, result_ty, b.Value(0_u));
-            ld->Result(0)->ReplaceAllUsesWith(call->Result(0));
+            auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, 0_u);
+            auto* load = b.Load(access);
+            auto* bitcast = b.Bitcast(ld->Result(0)->Type(), load);
+            ld->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
         });
         ld->Destroy();
     }
 
-    // Converts to:
-    //
-    // %1:u32 = v.Load 0u
-    // %b:f32 = bitcast %1
-    void LoadVectorElement(core::ir::LoadVectorElement* lve,
-                           core::ir::Var* var,
-                           const core::type::Pointer* var_ty) {
+    // A direct vector load on the `var` means the `var` is a vector. Replace the `var` usage` with
+    // an access chain into the `var` array `0` element a `load_vector_element` to retrieve the item
+    // and a `bitcast` to the correct type.
+    void LoadVectorElement(core::ir::LoadVectorElement* lve, core::ir::Var* var) {
         b.InsertBefore(lve, [&] {
-            OffsetData offset{};
-            CalculateVectorIndex(lve->Index(), var_ty->StoreType(), &offset);
-
-            auto* result =
-                MakeScalarOrVectorLoad(var, lve->Result(0)->Type(), OffsetToValue(offset));
-            lve->Result(0)->ReplaceAllUsesWith(result->Result(0));
+            auto* access = b.Access(ty.ptr(uniform, ty.vec4<u32>()), var, 0_u);
+            auto* load = b.LoadVectorElement(access, lve->Index());
+            auto* bitcast = b.Bitcast(lve->Result(0)->Type(), load);
+            lve->Result(0)->ReplaceAllUsesWith(bitcast->Result(0));
         });
-
         lve->Destroy();
     }
 };
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
index ac7d72d..270c194 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_uniform_access_test.cc
@@ -228,7 +228,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(HlslWriterDecomposeUniformAccessTest, DISABLED_UniformAccessVectorLoad) {
+TEST_F(HlslWriterDecomposeUniformAccessTest, UniformAccessVectorLoad) {
     auto* var = b.Var<uniform, vec4<f32>, core::Access::kRead>("v");
     var->SetBindingPoint(0, 0);
 
@@ -268,26 +268,31 @@
 
     auto* expect = R"(
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read> = var @binding_point(0, 0)
+  %v:ptr<uniform, array<vec4<u32>, 1>, read> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
-    %3:vec4<u32> = %v.Load4 0u
-    %4:vec4<f32> = bitcast %3
-    %a:vec4<f32> = let %4
-    %6:u32 = %v.Load 0u
-    %7:f32 = bitcast %6
-    %b:f32 = let %7
-    %9:u32 = %v.Load 4u
-    %10:f32 = bitcast %9
-    %c:f32 = let %10
-    %12:u32 = %v.Load 8u
+    %3:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %4:vec4<u32> = load %3
+    %5:vec4<f32> = bitcast %4
+    %a:vec4<f32> = let %5
+    %7:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %8:u32 = load_vector_element %7, 0u
+    %9:f32 = bitcast %8
+    %b:f32 = let %9
+    %11:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %12:u32 = load_vector_element %11, 1u
     %13:f32 = bitcast %12
-    %d:f32 = let %13
-    %15:u32 = %v.Load 12u
-    %16:f32 = bitcast %15
-    %e:f32 = let %16
+    %c:f32 = let %13
+    %15:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %16:u32 = load_vector_element %15, 2u
+    %17:f32 = bitcast %16
+    %d:f32 = let %17
+    %19:ptr<uniform, vec4<u32>, read> = access %v, 0u
+    %20:u32 = load_vector_element %19, 3u
+    %21:f32 = bitcast %20
+    %e:f32 = let %21
     ret
   }
 }