[hlsl] Handle access chains of unnamed access chains

This CL adds support for access chains which use other access chains as
their originating variable.

Bug: 349867642
Change-Id: I617a756d4372696de12242c55f049258f0cdf801
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/196316
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/hlsl/writer/access_test.cc b/src/tint/lang/hlsl/writer/access_test.cc
index 8410e2a..8c9b75c 100644
--- a/src/tint/lang/hlsl/writer/access_test.cc
+++ b/src/tint/lang/hlsl/writer/access_test.cc
@@ -628,5 +628,41 @@
 )");
 }
 
+TEST_F(HlslWriterTest, AccessChainFromUnnamedAccessChain) {
+    auto* Inner =
+        ty.Struct(mod.symbols.New("Inner"),
+                  {
+                      {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+                      {mod.symbols.New("d"), ty.u32(), core::type::StructMemberAttributes{}},
+                  });
+    auto* sb = ty.Struct(mod.symbols.New("SB"),
+                         {
+                             {mod.symbols.New("a"), ty.i32(), core::type::StructMemberAttributes{}},
+                             {mod.symbols.New("b"), Inner, core::type::StructMemberAttributes{}},
+                         });
+
+    auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+    var->SetBindingPoint(0, 0);
+    b.ir.root_block->Append(var);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        auto* x = b.Access(ty.ptr(storage, sb, core::Access::kReadWrite), var);
+        auto* y = b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u);
+        b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.u32(), core::Access::kReadWrite),
+                                   y->Result(0), 1_u)));
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
+    EXPECT_EQ(output_.hlsl, R"(
+RWByteAddressBuffer v : register(u0);
+void foo() {
+  uint b = v.Load(8u);
+}
+
+)");
+}
+
 }  // namespace
 }  // namespace tint::hlsl::writer
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
index 8476ff3..089baa8 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access.cc
@@ -106,9 +106,9 @@
                     inst,
                     [&](core::ir::LoadVectorElement* l) { LoadVectorElement(l, var, var_ty); },
                     [&](core::ir::StoreVectorElement* s) { StoreVectorElement(s, var, var_ty); },
-                    [&](core::ir::Store* s) { Store(s); },                 //
-                    [&](core::ir::Load* l) { Load(l, var); },              //
-                    [&](core::ir::Access* a) { Access(a, var, var_ty); },  //
+                    [&](core::ir::Store* s) { Store(s); },                                  //
+                    [&](core::ir::Load* l) { Load(l, var); },                               //
+                    [&](core::ir::Access* a) { Access(a, var, a->Object()->Type(), 0u); },  //
                     TINT_ICE_ON_NO_MATCH);
             }
 
@@ -309,14 +309,17 @@
         });
     }
 
-    void Access(core::ir::Access* a, core::ir::Var* var, const core::type::Pointer*) {
-        const core::type::Type* obj = a->Object()->Type();
-        auto* view = obj->As<core::type::MemoryView>();
-        TINT_ASSERT(view);
+    void Access(core::ir::Access* a,
+                core::ir::Var* var,
+                const core::type::Type* obj,
+                uint32_t byte_offset) {
+        // Note, because we recurse through the `access` helper, the object passed in isn't
+        // necessarily the originating `var` object, but maybe a partially resolved access chain
+        // object.
+        if (auto* view = obj->As<core::type::MemoryView>()) {
+            obj = view->StoreType();
+        }
 
-        obj = view->StoreType();
-
-        uint32_t byte_offset = 0;
         for (auto* idx_value : a->Indices()) {
             auto* cnst = idx_value->As<core::ir::Constant>();
 
@@ -362,8 +365,11 @@
                 [&](core::ir::Let*) {
                     // TODO(dsinclair): handle let
                 },
-                [&](core::ir::Access*) {
-                    // TODO(dsinclair): Handle access
+                [&](core::ir::Access* sub_access) {
+                    // Treat an access chain of the access chain as a continuation of the outer
+                    // chain. Pass through the object we stopped at and the current byte_offset and
+                    // then restart the access chain replacement for the new access chain.
+                    Access(sub_access, var, obj, byte_offset);
                 },
 
                 [&](core::ir::LoadVectorElement* lve) {
diff --git a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
index da0b618..db557f4 100644
--- a/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
+++ b/src/tint/lang/hlsl/writer/raise/decompose_memory_access_test.cc
@@ -177,11 +177,12 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(HlslWriterDecomposeMemoryAccessTest, DISABLED_AccessChainFromUnnamedAccessChain) {
+TEST_F(HlslWriterDecomposeMemoryAccessTest, AccessChainFromUnnamedAccessChain) {
     auto* Inner =
         ty.Struct(mod.symbols.New("Inner"),
                   {
                       {mod.symbols.New("c"), ty.f32(), core::type::StructMemberAttributes{}},
+                      {mod.symbols.New("d"), ty.u32(), core::type::StructMemberAttributes{}},
                   });
     auto* sb = ty.Struct(mod.symbols.New("SB"),
                          {
@@ -190,20 +191,22 @@
                          });
 
     auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+    var->SetBindingPoint(0, 0);
     b.ir.root_block->Append(var);
 
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
         auto* x = b.Access(ty.ptr(storage, sb, core::Access::kReadWrite), var);
         auto* y = b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u);
-        b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
-                                   y->Result(0), 0_u)));
+        b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.u32(), core::Access::kReadWrite),
+                                   y->Result(0), 1_u)));
         b.Return(func);
     });
 
     auto* src = R"(
 Inner = struct @align(4) {
   c:f32 @offset(0)
+  d:u32 @offset(4)
 }
 
 SB = struct @align(4) {
@@ -212,16 +215,16 @@
 }
 
 $B1: {  # root
-  %v:ptr<storage, SB, read_write> = var
+  %v:ptr<storage, SB, read_write> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
     %3:ptr<storage, SB, read_write> = access %v
     %4:ptr<storage, Inner, read_write> = access %3, 1u
-    %5:ptr<storage, f32, read_write> = access %4, 0u
-    %6:f32 = load %5
-    %b:f32 = let %6
+    %5:ptr<storage, u32, read_write> = access %4, 1u
+    %6:u32 = load %5
+    %b:u32 = let %6
     ret
   }
 }
@@ -229,20 +232,25 @@
     ASSERT_EQ(src, str());
 
     auto* expect = R"(
-SB = struct @align(16) {
+Inner = struct @align(4) {
+  c:f32 @offset(0)
+  d:u32 @offset(4)
+}
+
+SB = struct @align(4) {
   a:i32 @offset(0)
-  b:vec3<f32> @offset(16)
+  b:Inner @offset(4)
 }
 
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read_write> = var
+  %v:hlsl.byte_address_buffer<read_write> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
   $B2: {
-    %3:vec3<u32> = %v.Load3 16u
-    %a:vec3<f32> = bitcast %3
-    %b:f32 = %a 1u
+    %3:u32 = %v.Load 8u
+    %4:u32 = bitcast %3
+    %b:u32 = let %4
     ret
   }
 }
@@ -265,13 +273,16 @@
                          });
 
     auto* var = b.Var("v", storage, sb, core::Access::kReadWrite);
+    var->SetBindingPoint(0, 0);
     b.ir.root_block->Append(var);
 
     auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
     b.Append(func->Block(), [&] {
-        auto* a = b.Let("a", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var, 1_u));
-        b.Let("b", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
-                                   a->Result(0), 0_u)));
+        auto* x = b.Let("x", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), var));
+        auto* y = b.Let(
+            "y", b.Access(ty.ptr(storage, Inner, core::Access::kReadWrite), x->Result(0), 1_u));
+        b.Let("z", b.Load(b.Access(ty.ptr(storage, ty.f32(), core::Access::kReadWrite),
+                                   y->Result(0), 0_u)));
         b.Return(func);
     });
 
@@ -286,7 +297,7 @@
 }
 
 $B1: {  # root
-  %v:ptr<storage, SB, read_write> = var
+  %v:ptr<storage, SB, read_write> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {
@@ -309,7 +320,7 @@
 }
 
 $B1: {  # root
-  %v:hlsl.byte_address_buffer<read_write> = var
+  %v:hlsl.byte_address_buffer<read_write> = var @binding_point(0, 0)
 }
 
 %foo = @fragment func():void {