spirv-reader: handle locations on structure members

Bug: tint:912
Change-Id: Ia179a3152cfcdfed812f1673aaa4dba8b565dadf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56341
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 85ab480..03f45d9 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -980,11 +980,14 @@
     index_prefix.push_back(0);
     for (int i = 0; i < static_cast<int>(members.size()); ++i) {
       index_prefix.back() = i;
+      auto* location = parser_impl_.GetMemberLocation(*struct_type, i);
+      auto* saved_location = SetLocation(decos, location);
       if (!EmitPipelineInput(var_name, var_type, decos, index_prefix,
                              members[i], forced_param_type, params,
                              statements)) {
         return false;
       }
+      SetLocation(decos, saved_location);
     }
     return success();
   }
@@ -1036,6 +1039,12 @@
   statements->push_back(builder_.Assign(store_dest, param_value));
 
   // Increment the location attribute, in case more parameters will follow.
+  IncrementLocation(decos);
+
+  return success();
+}
+
+void FunctionEmitter::IncrementLocation(ast::DecorationList* decos) {
   for (auto*& deco : *decos) {
     if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
       // Replace this location decoration with a new one with one higher index.
@@ -1044,8 +1053,27 @@
       deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1);
     }
   }
+}
 
-  return success();
+ast::Decoration* FunctionEmitter::SetLocation(
+    ast::DecorationList* decos,
+    ast::Decoration* replacement) {
+  if (!replacement) {
+    return nullptr;
+  }
+  for (auto*& deco : *decos) {
+    if (deco->Is<ast::LocationDecoration>()) {
+      // Replace this location decoration with a new one with one higher index.
+      // The old one doesn't leak because it's kept in the builder's AST node
+      // list.
+      ast::Decoration* result = deco;
+      deco = replacement;
+      return result;
+    }
+  }
+  // The list didn't have a location. Add it.
+  decos->push_back(replacement);
+  return nullptr;
 }
 
 bool FunctionEmitter::EmitPipelineOutput(std::string var_name,
@@ -1095,11 +1123,14 @@
     index_prefix.push_back(0);
     for (int i = 0; i < static_cast<int>(members.size()); ++i) {
       index_prefix.back() = i;
+      auto* location = parser_impl_.GetMemberLocation(*struct_type, i);
+      auto* saved_location = SetLocation(decos, location);
       if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
                               members[i], forced_member_type, return_members,
                               return_exprs)) {
         return false;
       }
+      SetLocation(decos, saved_location);
     }
     return success();
   }
@@ -1150,14 +1181,7 @@
   return_exprs->push_back(load_source);
 
   // Increment the location attribute, in case more parameters will follow.
-  for (auto*& deco : *decos) {
-    if (auto* loc_deco = deco->As<ast::LocationDecoration>()) {
-      // Replace this location decoration with a new one with one higher index.
-      // The old one doesn't leak because it's kept in the builder's AST node
-      // list.
-      deco = builder_.Location(loc_deco->source(), loc_deco->value() + 1);
-    }
-  }
+  IncrementLocation(decos);
 
   return success();
 }
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 3a287a4..06f0450 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -466,6 +466,23 @@
                           ast::StructMemberList* return_members,
                           ast::ExpressionList* return_exprs);
 
+  /// Updates the decoration list, replacing an existing Location decoration
+  /// with another having one higher location value. Does nothing if no
+  /// location decoration exists.
+  /// Assumes the list contains at most one Location decoration.
+  /// @param decos the decoration list to modify
+  void IncrementLocation(ast::DecorationList* decos);
+
+  /// Updates the decoration list, placing a non-null location decoration into
+  /// the list, replacing an existing one if it exists. Does nothing if the
+  /// replacement is nullptr.
+  /// Assumes the list contains at most one Location decoration.
+  /// @param decos the decoration list to modify
+  /// @param replacement the location decoration to place into the list
+  /// @returns the location decoration that was replaced, if one was replaced.
+  ast::Decoration* SetLocation(ast::DecorationList* decos,
+                               ast::Decoration* replacement);
+
   /// Create an ast::BlockStatement representing the body of the function.
   /// This creates the statement stack, which is non-empty for the lifetime
   /// of the function.
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index cf2e9aa..9b07da7 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -2784,7 +2784,6 @@
     OpReturn
     OpFunctionEnd
 )";
-  std::cout << assembly << std::endl;
   auto p = parser(test::Assemble(assembly));
   ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
   auto fe = p->function_emitter(100);
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index db4383a..c5830c5 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -1098,6 +1098,8 @@
         // apply the ReadOnly access control to the containing struct if all
         // the members are non-writable.
         is_non_writable = true;
+      } else if (decoration[0] == SpvDecorationLocation) {
+        // Location decorations are handled when emitting the entry point.
       } else {
         auto* ast_member_decoration =
             ConvertMemberDecoration(type_id, member_index, decoration);
@@ -1134,7 +1136,7 @@
   // Now make the struct.
   auto sym = builder_.Symbols().Register(name);
   ast::DecorationList ast_struct_decorations;
-  if (is_block_decorated) {
+  if (is_block_decorated && struct_types_for_buffers_.count(type_id)) {
     ast_struct_decorations.emplace_back(
         create<ast::StructBlockDecoration>(Source{}));
   }
@@ -1208,6 +1210,37 @@
   if (!success_) {
     return false;
   }
+
+  // First record the structure types that should have a `block` decoration
+  // in WGSL. In particular, exclude user-defined pipeline IO in a
+  // block-decorated struct.
+  for (const auto& type_or_value : module_->types_values()) {
+    if (type_or_value.opcode() != SpvOpVariable) {
+      continue;
+    }
+    const auto& var = type_or_value;
+    const auto spirv_storage_class =
+        SpvStorageClass(var.GetSingleWordInOperand(0));
+    if ((spirv_storage_class != SpvStorageClassStorageBuffer) &&
+        (spirv_storage_class != SpvStorageClassUniform)) {
+      continue;
+    }
+    const auto* ptr_type = def_use_mgr_->GetDef(var.type_id());
+    if (ptr_type->opcode() != SpvOpTypePointer) {
+      return Fail() << "OpVariable type expected to be a pointer: "
+                    << var.PrettyPrint();
+    }
+    const auto* store_type =
+        def_use_mgr_->GetDef(ptr_type->GetSingleWordInOperand(1));
+    if (store_type->opcode() == SpvOpTypeStruct) {
+      struct_types_for_buffers_.insert(store_type->result_id());
+    } else {
+      Fail() << "WGSL does not support arrays of buffers: "
+             << var.PrettyPrint();
+    }
+  }
+
+  // Now convert each type.
   for (auto& type_or_const : module_->types_values()) {
     const auto* type = type_mgr_->GetType(type_or_const.result_id());
     if (type == nullptr) {
@@ -2630,6 +2663,22 @@
   return namer_.GetMemberName(where->second, member_index);
 }
 
+ast::Decoration* ParserImpl::GetMemberLocation(const Struct& struct_type,
+                                               int member_index) {
+  auto where = struct_id_for_symbol_.find(struct_type.name);
+  if (where == struct_id_for_symbol_.end()) {
+    Fail() << "no structure type registered for symbol";
+    return nullptr;
+  }
+  const auto type_id = where->second;
+  for (auto& deco : GetDecorationsForMember(type_id, member_index)) {
+    if ((deco.size() == 2) && (deco[0] == SpvDecorationLocation)) {
+      return create<ast::LocationDecoration>(Source{}, deco[1]);
+    }
+  }
+  return nullptr;
+}
+
 WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
 
 WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 3321fd9..b041605 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -379,6 +379,13 @@
   /// @returns the field name
   std::string GetMemberName(const Struct& struct_type, int member_index);
 
+  /// Returns the location decoration, if any on a struct member.
+  /// @param struct_type the parser's structure type.
+  /// @param member_index the member index
+  /// @returns a newly created location node, or nullptr
+  ast::Decoration* GetMemberLocation(const Struct& struct_type,
+                                     int member_index);
+
   /// Creates an AST Variable node for a SPIR-V ID, including any attached
   /// decorations, unless it's an ignorable builtin variable.
   /// @param id the SPIR-V result ID
@@ -765,6 +772,10 @@
   // "NonSemanticInfo." import is ignored.
   std::unordered_set<uint32_t> ignored_imports_;
 
+  // The SPIR-V IDs of structure types that are the store type for buffer
+  // variables, either UBO or SSBO.
+  std::unordered_set<uint32_t> struct_types_for_buffers_;
+
   // Bookkeeping for the gl_Position builtin.
   // In Vulkan SPIR-V, it's the 0 member of the gl_PerVertex structure.
   // But in WGSL we make a module-scope variable:
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 78e1ea9..beeef85 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -6528,7 +6528,7 @@
   EXPECT_EQ(got, expected) << got;
 }
 
-TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct) {
+TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct_LocOnVariable) {
   const std::string assembly = R"(
     OpCapability Shader
     OpMemoryModel Logical Simple
@@ -6993,7 +6993,7 @@
   EXPECT_EQ(got, expected) << got;
 }
 
-TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct) {
+TEST_F(SpvModuleScopeVarParserTest, Output_FlattenStruct_LocOnVariable) {
   const std::string assembly = R"(
     OpCapability Shader
     OpMemoryModel Logical Simple
@@ -7092,6 +7092,192 @@
   EXPECT_EQ(got, expected) << got;
 }
 
+TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnMembers) {
+  // Block-decorated struct may have its members decorated with Location.
+  const std::string assembly = R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint Vertex %main "main" %1 %2 %3
+
+    OpName %strct "Communicators"
+    OpMemberName %strct 0 "alice"
+    OpMemberName %strct 1 "bob"
+
+    OpMemberDecorate %strct 0 Location 9
+    OpMemberDecorate %strct 1 Location 11
+    OpDecorate %strct Block
+    OpDecorate %2 BuiltIn Position
+
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+    %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %strct = OpTypeStruct %float %v4float
+
+    %11 = OpTypePointer Input %strct
+    %13 = OpTypePointer Output %strct
+
+    %1 = OpVariable %11 Input
+    %3 = OpVariable %13 Output
+
+    %12 = OpTypePointer Output %v4float
+    %2 = OpVariable %12 Output
+
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    OpReturn
+    OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+
+  ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+
+  const auto got = p->program().to_str();
+  const std::string expected = R"(Module{
+  Struct Communicators {
+    StructMember{alice: __f32}
+    StructMember{bob: __vec_4__f32}
+  }
+  Struct main_out {
+    StructMember{[[ BuiltinDecoration{position}
+ ]] x_2_1: __vec_4__f32}
+    StructMember{[[ LocationDecoration{9}
+ ]] x_3_1: __f32}
+    StructMember{[[ LocationDecoration{11}
+ ]] x_3_2: __vec_4__f32}
+  }
+  Variable{
+    x_1
+    private
+    undefined
+    __type_name_Communicators
+  }
+  Variable{
+    x_3
+    private
+    undefined
+    __type_name_Communicators
+  }
+  Variable{
+    x_2
+    private
+    undefined
+    __vec_4__f32
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __type_name_main_out
+  StageDecoration{vertex}
+  (
+    VariableConst{
+      Decorations{
+        LocationDecoration{9}
+      }
+      x_1_param
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{11}
+      }
+      x_1_param_1
+      none
+      undefined
+      __vec_4__f32
+    }
+  )
+  {
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_1}
+        Identifier[not set]{alice}
+      }
+      Identifier[not set]{x_1_param}
+    }
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_1}
+        Identifier[not set]{bob}
+      }
+      Identifier[not set]{x_1_param_1}
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+    Return{
+      {
+        TypeConstructor[not set]{
+          __type_name_main_out
+          Identifier[not set]{x_2}
+          MemberAccessor[not set]{
+            Identifier[not set]{x_3}
+            Identifier[not set]{alice}
+          }
+          MemberAccessor[not set]{
+            Identifier[not set]{x_3}
+            Identifier[not set]{bob}
+          }
+        }
+      }
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, FlattenStruct_LocOnStruct) {
+  const std::string assembly = R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint Vertex %main "main" %1 %2 %3
+
+    OpName %strct "Communicators"
+    OpMemberName %strct 0 "alice"
+    OpMemberName %strct 1 "bob"
+
+    OpDecorate %strct Location 9
+    OpDecorate %strct Block
+    OpDecorate %2 BuiltIn Position
+
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+    %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %strct = OpTypeStruct %float %v4float
+
+    %11 = OpTypePointer Input %strct
+    %13 = OpTypePointer Output %strct
+
+    %1 = OpVariable %11 Input
+    %3 = OpVariable %13 Output
+
+    %12 = OpTypePointer Output %v4float
+    %2 = OpVariable %12 Output
+
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    OpReturn
+    OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+
+  // The validator rejects this because Location decorations
+  // can only go on OpVariable or members of a structure type.
+  ASSERT_FALSE(p->Parse()) << p->error() << assembly;
+  EXPECT_THAT(p->error(),
+              HasSubstr("Location decoration can only be applied to a variable "
+                        "or member of a structure type"));
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader