spirv-reader: flatten struct pipeline inputs

Bug: tint:912
Change-Id: I01002f4996d3205af06edda092a1c18dcf6213e3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56220
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 43466e5..55c0a27 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -974,6 +974,18 @@
       }
     }
     return success();
+  } else if (auto* struct_type = tip_type->As<Struct>()) {
+    const auto& members = struct_type->members;
+    index_prefix.push_back(0);
+    for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+      index_prefix.back() = i;
+      if (!EmitInputParameter(var_name, var_type, decos, index_prefix,
+                              members[i], forced_param_type, params,
+                              statements)) {
+        return false;
+      }
+    }
+    return success();
   }
 
   const bool is_builtin = ast::HasDecoration<ast::BuiltinDecoration>(*decos);
@@ -1005,6 +1017,11 @@
     } else if (auto* array_type = current_type->As<Array>()) {
       store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
       current_type = array_type->type->UnwrapAlias();
+    } else if (auto* struct_type = current_type->As<Struct>()) {
+      store_dest = builder_.MemberAccessor(
+          store_dest,
+          builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
+      current_type = struct_type->members[index];
     }
   }
 
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 5272aa1..db4383a 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -1144,7 +1144,9 @@
     read_only_struct_types_.insert(ast_struct->name());
   }
   AddTypeDecl(sym, ast_struct);
-  return ty_.Struct(sym, std::move(ast_member_types));
+  const auto* result = ty_.Struct(sym, std::move(ast_member_types));
+  struct_id_for_symbol_[sym] = type_id;
+  return result;
 }
 
 void ParserImpl::AddTypeDecl(Symbol name, ast::TypeDecl* decl) {
@@ -2618,6 +2620,16 @@
   return def_use_mgr_ ? def_use_mgr_->GetDef(id) : nullptr;
 }
 
+std::string ParserImpl::GetMemberName(const Struct& struct_type,
+                                      int member_index) {
+  auto where = struct_id_for_symbol_.find(struct_type.name);
+  if (where == struct_id_for_symbol_.end()) {
+    Fail() << "no structure type registered for symbol";
+    return "";
+  }
+  return namer_.GetMemberName(where->second, member_index);
+}
+
 WorkgroupSizeInfo::WorkgroupSizeInfo() = default;
 
 WorkgroupSizeInfo::~WorkgroupSizeInfo() = default;
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 2b818dd..3321fd9 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -373,6 +373,12 @@
   /// @returns the integer constant for its array size, or nullptr.
   const spvtools::opt::analysis::IntConstant* GetArraySize(uint32_t var_id);
 
+  /// Returns the member name for the struct member.
+  /// @param struct_type the parser's structure type.
+  /// @param member_index the member index
+  /// @returns the field name
+  std::string GetMemberName(const Struct& struct_type, int member_index);
+
   /// Creates an AST Variable node for a SPIR-V ID, including any attached
   /// decorations, unless it's an ignorable builtin variable.
   /// @param id the SPIR-V result ID
@@ -809,6 +815,9 @@
   // adding duplicates.
   std::unordered_set<Symbol> declared_types_;
 
+  // Maps a struct type name to the SPIR-V ID for the structure type.
+  std::unordered_map<Symbol, uint32_t> struct_id_for_symbol_;
+
   /// Maps the SPIR-V ID of a module-scope builtin variable that should be
   /// ignored or type-converted, to its builtin kind.
   /// See also BuiltInPositionInfo which is a separate mechanism for a more
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 684ee82..4243eb5 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -6528,6 +6528,126 @@
   EXPECT_EQ(got, expected) << got;
 }
 
+TEST_F(SpvModuleScopeVarParserTest, Input_FlattenStruct) {
+  const std::string assembly = R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint Vertex %main "main" %1 %2
+
+    OpName %strct "Communicators"
+    OpMemberName %strct 0 "alice"
+    OpMemberName %strct 1 "bob"
+
+    OpDecorate %1 Location 9
+    OpDecorate %2 BuiltIn Position
+
+
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+    %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %strct = OpTypeStruct %float %v4float
+
+    %11 = OpTypePointer Input %strct
+
+    %1 = OpVariable %11 Input
+
+    %12 = OpTypePointer Output %v4float
+    %2 = OpVariable %12 Output
+
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    OpReturn
+    OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+
+  ASSERT_TRUE(p->Parse()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+
+  const auto got = p->program().to_str();
+  const std::string expected = R"(Module{
+  Struct Communicators {
+    StructMember{alice: __f32}
+    StructMember{bob: __vec_4__f32}
+  }
+  Struct main_out {
+    StructMember{[[ BuiltinDecoration{position}
+ ]] x_2: __vec_4__f32}
+  }
+  Variable{
+    x_1
+    private
+    undefined
+    __type_name_Communicators
+  }
+  Variable{
+    x_2
+    private
+    undefined
+    __vec_4__f32
+  }
+  Function main_1 -> __void
+  ()
+  {
+    Return{}
+  }
+  Function main -> __type_name_main_out
+  StageDecoration{vertex}
+  (
+    VariableConst{
+      Decorations{
+        LocationDecoration{9}
+      }
+      x_1_param
+      none
+      undefined
+      __f32
+    }
+    VariableConst{
+      Decorations{
+        LocationDecoration{10}
+      }
+      x_1_param_1
+      none
+      undefined
+      __vec_4__f32
+    }
+  )
+  {
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_1}
+        Identifier[not set]{alice}
+      }
+      Identifier[not set]{x_1_param}
+    }
+    Assignment{
+      MemberAccessor[not set]{
+        Identifier[not set]{x_1}
+        Identifier[not set]{bob}
+      }
+      Identifier[not set]{x_1_param_1}
+    }
+    Call[not set]{
+      Identifier[not set]{main_1}
+      (
+      )
+    }
+    Return{
+      {
+        TypeConstructor[not set]{
+          __type_name_main_out
+          Identifier[not set]{x_2}
+        }
+      }
+    }
+  }
+}
+)";
+  EXPECT_EQ(got, expected) << got;
+}
+
 TEST_F(SpvModuleScopeVarParserTest, Input_FlattenNested) {
   const std::string assembly = R"(
     OpCapability Shader