[spirv] Add struct flag for explicit layout decorations

Instead of running a pass in the printer to determine which structs
will need explicit layout decorations, just have the
`ForkExplicitLayoutTypes` transform add a struct flag to signal that
information to the printer, since it already knows which structs will
need the decorations.

This currently requires a `const_cast` due to long-standing tech debt
that causes types to be immutable in IR transforms.

Bug: 42252012
Change-Id: Ifb0aa90a47fb92ec30559f0d76c9b92adc959e7c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/230134
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/disassembler.cc b/src/tint/lang/core/ir/disassembler.cc
index 3d915bf..bb7c3ff 100644
--- a/src/tint/lang/core/ir/disassembler.cc
+++ b/src/tint/lang/core/ir/disassembler.cc
@@ -940,8 +940,15 @@
 void Disassembler::EmitStructDecl(const core::type::Struct* str) {
     out_ << StyleType(str->Name().Name()) << " = " << StyleKeyword("struct") << " "
          << StyleAttribute("@align") << "(" << StyleLiteral(str->Align()) << ")";
-    if (str->StructFlags().Contains(core::type::StructFlag::kBlock)) {
-        out_ << ", " << StyleAttribute("@block");
+    for (auto flag : str->StructFlags()) {
+        switch (flag) {
+            case core::type::kBlock:
+                out_ << ", " << StyleAttribute("@block");
+                break;
+            case core::type::kSpirvExplicitLayout:
+                out_ << ", " << StyleAttribute("@spirv.explicit_layout");
+                break;
+        }
     }
     out_ << " {";
     EmitLine();
diff --git a/src/tint/lang/core/type/struct.h b/src/tint/lang/core/type/struct.h
index 0050d6e..90f8508 100644
--- a/src/tint/lang/core/type/struct.h
+++ b/src/tint/lang/core/type/struct.h
@@ -65,6 +65,8 @@
 enum StructFlag {
     /// The structure is a block-decorated structure (for SPIR-V or GLSL).
     kBlock,
+    /// The structure requires explicit layout decorations for SPIR-V.
+    kSpirvExplicitLayout,
 };
 
 /// An alias to tint::EnumSet<StructFlag>
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index e6cebb4..9e1feea 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -241,9 +241,6 @@
         }
     };
 
-    /// The set of structure types that require explicit layout decorations.
-    Hashset<const core::type::Struct*, 16> requires_layout_decorations_;
-
     /// The map of types to their result IDs.
     Hashmap<const core::type::Type*, uint32_t, 8> types_;
 
@@ -309,9 +306,6 @@
                                                              U32Operand(SpvMemoryModelGLSL450)});
         }
 
-        // Find types that require explicit layout decorations.
-        FindStructuresThatRequireLayoutDecorations();
-
         // Emit module-scope declarations.
         EmitRootBlock(ir_.root_block);
 
@@ -383,38 +377,6 @@
         return SpvBuiltInMax;
     }
 
-    /// Find all structure types that are used in host-shareable address spaces and mark them as
-    /// such so that we know to add explicit layout decorations when we emit them.
-    void FindStructuresThatRequireLayoutDecorations() {
-        // We only look at module-scope variable declarations, since this is where all
-        // host-shareable types are declared.
-        for (auto* decl : *ir_.root_block) {
-            if (auto* var = decl->As<core::ir::Var>()) {
-                auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
-                if (!core::IsHostShareable(ptr->AddressSpace())) {
-                    continue;
-                }
-
-                // Look for arrays and structures at any nesting depth of this type.
-                Vector<const core::type::Type*, 8> type_queue;
-                type_queue.Push(ptr->StoreType());
-                while (!type_queue.IsEmpty()) {
-                    auto* next = type_queue.Pop();
-                    if (auto* str = next->As<core::type::Struct>()) {
-                        // Record this structure as host-shareable and then check its members.
-                        requires_layout_decorations_.Add(str);
-                        for (auto* member : str->Members()) {
-                            type_queue.Push(member->Type());
-                        }
-                    } else if (auto* arr = next->As<core::type::Array>()) {
-                        // Check its element type.
-                        type_queue.Push(arr->ElemType());
-                    }
-                }
-            }
-        }
-    }
-
     /// Get the result ID of the constant `constant`, emitting its instruction if necessary.
     /// @param constant the constant to get the ID for
     /// @returns the result ID of the constant
@@ -659,7 +621,7 @@
         for (auto* member : str->Members()) {
             operands.push_back(Type(member->Type()));
 
-            if (requires_layout_decorations_.Contains(str)) {
+            if (str->StructFlags().Contains(core::type::kSpirvExplicitLayout)) {
                 // Generate struct member offset decoration.
                 module_.PushAnnot(spv::Op::OpMemberDecorate,
                                   {operands[0], member->Index(), U32Operand(SpvDecorationOffset),
diff --git a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
index b391cbc..e3f084b 100644
--- a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
+++ b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.cc
@@ -180,6 +180,9 @@
         // If no members were forked and the struct itself is not shared with other address spaces,
         // then the original struct can safely be reused.
         if (!must_emit_without_explicit_layout.Contains(original_struct) && !members_were_forked) {
+            // TODO(crbug.com/tint/745): Remove the const_cast.
+            const_cast<core::type::Struct*>(original_struct)
+                ->SetStructFlag(core::type::kSpirvExplicitLayout);
             return nullptr;
         }
 
@@ -190,6 +193,7 @@
                                                    original_struct->Align(),  //
                                                    original_struct->Size(),   //
                                                    original_struct->SizeNoPadding());
+        new_str->SetStructFlag(core::type::kSpirvExplicitLayout);
         for (auto flag : original_struct->StructFlags()) {
             new_str->SetStructFlag(flag);
         }
diff --git a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types_test.cc b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types_test.cc
index 29b08d7..d640d3a 100644
--- a/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types_test.cc
+++ b/src/tint/lang/spirv/writer/raise/fork_explicit_layout_types_test.cc
@@ -81,7 +81,39 @@
     EXPECT_EQ(src, str());
 }
 
-TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, NoModify_Struct_InHostShareable_NotShared) {
+TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, NoModify_Array_NotInHostShareable) {
+    auto* wg_buffer = b.Var("wg_buffer", ty.ptr<workgroup, array<u32, 4>>());
+    mod.root_block->Append(wg_buffer);
+
+    auto* func = b.Function("foo", ty.void_());
+    b.Append(func->Block(), [&] {
+        b.Let("let", b.Load(wg_buffer));
+        b.Return(func);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %wg_buffer:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%foo = func():void {
+  $B2: {
+    %3:array<u32, 4> = load %wg_buffer
+    %let:array<u32, 4> = let %3
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    Run(ForkExplicitLayoutTypes);
+
+    EXPECT_EQ(src, str());
+}
+
+// Test that we always add the struct flag to structures that require explicit layout decorations,
+// even if we are not forking them.
+TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, Struct_InHostShareable_NotShared) {
     auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
                                                                  {mod.symbols.New("a"), ty.u32()},
                                                                  {mod.symbols.New("b"), ty.u32()},
@@ -129,44 +161,38 @@
 )";
     EXPECT_EQ(src, str());
 
-    Run(ForkExplicitLayoutTypes);
-
-    EXPECT_EQ(src, str());
+    auto* expect = R"(
+MyStruct = struct @align(4), @spirv.explicit_layout {
+  a:u32 @offset(0)
+  b:u32 @offset(4)
 }
 
-TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, NoModify_Array_NotInHostShareable) {
-    auto* wg_buffer = b.Var("wg_buffer", ty.ptr<workgroup, array<u32, 4>>());
-    mod.root_block->Append(wg_buffer);
-
-    auto* func = b.Function("foo", ty.void_());
-    b.Append(func->Block(), [&] {
-        b.Let("let", b.Load(wg_buffer));
-        b.Return(func);
-    });
-
-    auto* src = R"(
 $B1: {  # root
-  %wg_buffer:ptr<workgroup, array<u32, 4>, read_write> = var undef
+  %buffer:ptr<storage, MyStruct, read_write> = var undef @binding_point(0, 0)
 }
 
-%foo = func():void {
+%foo = func(%param:MyStruct):void {
   $B2: {
-    %3:array<u32, 4> = load %wg_buffer
-    %let:array<u32, 4> = let %3
+    ret
+  }
+}
+%foo_1 = func():void {  # %foo_1: 'foo'
+  $B3: {
+    %5:MyStruct = load %buffer
+    %let:MyStruct = let %5
     ret
   }
 }
 )";
-    EXPECT_EQ(src, str());
 
     Run(ForkExplicitLayoutTypes);
 
-    EXPECT_EQ(src, str());
+    EXPECT_EQ(expect, str());
 }
 
 // Test that we always modify arrays that require explicit layout decorations, since the type is
 // used to signal to the printer that layout decorations are required.
-TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, Array_InHostShareable) {
+TEST_F(SpirvWriter_ForkExplicitLayoutTypesTest, Array_InHostShareable_NotShared) {
     auto* buffer = b.Var("buffer", ty.ptr<storage, array<u32, 4>>());
     buffer->SetBindingPoint(0, 0);
     mod.root_block->Append(buffer);
@@ -217,7 +243,7 @@
   a:u32 @offset(0)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
 }
 
@@ -261,7 +287,7 @@
   a:u32 @offset(0)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
 }
 
@@ -305,7 +331,7 @@
   a:u32 @offset(0)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
 }
 
@@ -349,7 +375,7 @@
   a:u32 @offset(0)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
 }
 
@@ -440,7 +466,7 @@
   a:u32 @offset(0)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
 }
 
@@ -503,7 +529,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -591,12 +617,12 @@
   b_1:Inner @offset(8)
 }
 
-Inner_tint_explicit_layout = struct @align(4) {
+Inner_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-Outer_tint_explicit_layout = struct @align(4) {
+Outer_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a_1:Inner_tint_explicit_layout @offset(0)
   b_1:Inner_tint_explicit_layout @offset(8)
 }
@@ -682,7 +708,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -773,12 +799,12 @@
   b_1:Inner @offset(8)
 }
 
-Inner_tint_explicit_layout = struct @align(4) {
+Inner_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-Outer_tint_explicit_layout = struct @align(4) {
+Outer_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a_1:Inner_tint_explicit_layout @offset(0)
   b_1:Inner_tint_explicit_layout @offset(8)
 }
@@ -881,12 +907,12 @@
   b_1:Inner @offset(8)
 }
 
-Inner_tint_explicit_layout = struct @align(4) {
+Inner_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-Outer_tint_explicit_layout = struct @align(4) {
+Outer_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a_1:Inner_tint_explicit_layout @offset(0)
   b_1:Inner_tint_explicit_layout @offset(8)
 }
@@ -995,7 +1021,7 @@
   b:u32 @offset(4)
 }
 
-InnerNotShared = struct @align(4) {
+InnerNotShared = struct @align(4), @spirv.explicit_layout {
   a_1:u32 @offset(0)
   b_1:u32 @offset(4)
 }
@@ -1005,12 +1031,12 @@
   b_2:InnerNotShared @offset(8)
 }
 
-InnerShared_tint_explicit_layout = struct @align(4) {
+InnerShared_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-Outer_tint_explicit_layout = struct @align(4) {
+Outer_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a_2:InnerShared_tint_explicit_layout @offset(0)
   b_2:InnerNotShared @offset(8)
 }
@@ -1101,7 +1127,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -1195,7 +1221,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -1285,7 +1311,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -1372,12 +1398,12 @@
   b_1:Inner @offset(8)
 }
 
-Inner_tint_explicit_layout = struct @align(4) {
+Inner_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-Outer_tint_explicit_layout = struct @align(4) {
+Outer_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a_1:Inner_tint_explicit_layout @offset(0)
   b_1:Inner_tint_explicit_layout @offset(8)
 }
@@ -1479,7 +1505,7 @@
   b:u32 @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
@@ -1610,17 +1636,17 @@
   f:i32 @offset(4)
 }
 
-S_0_tint_explicit_layout = struct @align(4) {
+S_0_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   a:u32 @offset(0)
   b:u32 @offset(4)
 }
 
-S_1_tint_explicit_layout = struct @align(4) {
+S_1_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   c:f32 @offset(0)
   d:f32 @offset(4)
 }
 
-S_2_tint_explicit_layout = struct @align(4) {
+S_2_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   e:i32 @offset(0)
   f:i32 @offset(4)
 }
@@ -2459,7 +2485,7 @@
   arr:array<u32> @offset(4)
 }
 
-MyStruct_tint_explicit_layout = struct @align(4) {
+MyStruct_tint_explicit_layout = struct @align(4), @spirv.explicit_layout {
   i:u32 @offset(0)
   arr:spirv.explicit_layout_array<u32, > @offset(4)
 }