[spirv-reader] Follow non-canonicalized SPIR-V type for composite extract

Follow the actual SPIR-V type when computing a composite extract
instad of the canonicalized view in the optimizer's type manager.
Do this so we can generate the correct member name for a struct,
rather than using the member name for the other representative
struct type. The optimizer's type canonicalizer is insensitive to
struct member names.

Prompted by tint:213, for which the original case was an
access chain.

Bug: tint:3, tint:213
Change-Id: I8705c7ee655fe47c8b7a3658db524fe18833efdb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27603
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 806748c..3687bf9 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2906,7 +2906,7 @@
               std::move(current_expr.expr),
               std::move(MakeOperand(inst, index).expr));
         }
-        // All vector components are the same type, so follow the first.
+        // All vector components are the same type.
         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
       case SpvOpTypeMatrix:
@@ -2914,7 +2914,7 @@
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr),
             std::move(MakeOperand(inst, index).expr));
-        // All matrix components are the same type, so follow the first.
+        // All matrix components are the same type.
         pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
       case SpvOpTypeArray:
@@ -2988,20 +2988,31 @@
   static const char* swizzles[] = {"x", "y", "z", "w"};
 
   const auto composite = inst.GetSingleWordInOperand(0);
-  const auto composite_type_id = def_use_mgr_->GetDef(composite)->type_id();
-  const auto* current_type = type_mgr_->GetType(composite_type_id);
+  auto current_type_id = def_use_mgr_->GetDef(composite)->type_id();
+  // Build up a nested expression for the access chain by walking down the type
+  // hierarchy, maintaining |current_type_id| as the SPIR-V ID of the type of
+  // the object pointed to after processing the previous indices.
   const auto num_in_operands = inst.NumInOperands();
   for (uint32_t index = 1; index < num_in_operands; ++index) {
     const uint32_t index_val = inst.GetSingleWordInOperand(index);
+
+    const auto* current_type_inst = def_use_mgr_->GetDef(current_type_id);
+    if (!current_type_inst) {
+      Fail() << "composite type %" << current_type_id
+             << " is invalid after following " << (index - 1)
+             << " indices: " << inst.PrettyPrint();
+      return {};
+    }
     std::unique_ptr<ast::Expression> next_expr;
-    switch (current_type->kind()) {
-      case spvtools::opt::analysis::Type::kVector: {
+    switch (current_type_inst->opcode()) {
+      case SpvOpTypeVector: {
         // Try generating a MemberAccessor expression. That result in something
         // like  "foo.z", which is more idiomatic than "foo[2]".
-        if (current_type->AsVector()->element_count() <= index_val) {
+        const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
+        if (num_elems <= index_val) {
           Fail() << "CompositeExtract %" << inst.result_id() << " index value "
-                 << index_val << " is out of bounds for vector of "
-                 << current_type->AsVector()->element_count() << " elements";
+                 << index_val << " is out of bounds for vector of " << num_elems
+                 << " elements";
           return {};
         }
         if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
@@ -3013,15 +3024,17 @@
             std::make_unique<ast::IdentifierExpression>(swizzles[index_val]);
         next_expr = std::make_unique<ast::MemberAccessorExpression>(
             std::move(current_expr.expr), std::move(letter_index));
-        current_type = current_type->AsVector()->element_type();
+        // All vector components are the same type.
+        current_type_id = current_type_inst->GetSingleWordInOperand(0);
         break;
       }
-      case spvtools::opt::analysis::Type::kMatrix:
+      case SpvOpTypeMatrix: {
         // Check bounds
-        if (current_type->AsMatrix()->element_count() <= index_val) {
+        const auto num_elems = current_type_inst->GetSingleWordInOperand(1);
+        if (num_elems <= index_val) {
           Fail() << "CompositeExtract %" << inst.result_id() << " index value "
-                 << index_val << " is out of bounds for matrix of "
-                 << current_type->AsMatrix()->element_count() << " elements";
+                 << index_val << " is out of bounds for matrix of " << num_elems
+                 << " elements";
           return {};
         }
         if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
@@ -3032,45 +3045,44 @@
         // Use array syntax.
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr), make_index(index_val));
-        current_type = current_type->AsMatrix()->element_type();
+        // All matrix components are the same type.
+        current_type_id = current_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kArray:
+      }
+      case SpvOpTypeArray:
         // The array size could be a spec constant, and so it's not always
         // statically checkable.  Instead, rely on a runtime index clamp
         // or runtime check to keep this safe.
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr), make_index(index_val));
-        current_type = current_type->AsArray()->element_type();
+        current_type_id = current_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kRuntimeArray:
+      case SpvOpTypeRuntimeArray:
         Fail() << "can't do OpCompositeExtract on a runtime array";
         return {};
-      case spvtools::opt::analysis::Type::kStruct: {
-        if (current_type->AsStruct()->element_types().size() <= index_val) {
+      case SpvOpTypeStruct: {
+        const auto num_members = current_type_inst->NumInOperands();
+        if (num_members <= index_val) {
           Fail() << "CompositeExtract %" << inst.result_id() << " index value "
                  << index_val << " is out of bounds for structure %"
-                 << type_mgr_->GetId(current_type) << " having "
-                 << current_type->AsStruct()->element_types().size()
-                 << " elements";
+                 << current_type_id << " having " << num_members << " members";
           return {};
         }
-        auto member_access =
-            std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
-                type_mgr_->GetId(current_type), uint32_t(index_val)));
+        auto member_access = std::make_unique<ast::IdentifierExpression>(
+            namer_.GetMemberName(current_type_id, uint32_t(index_val)));
 
         next_expr = std::make_unique<ast::MemberAccessorExpression>(
             std::move(current_expr.expr), std::move(member_access));
-        current_type = current_type->AsStruct()->element_types()[index_val];
+        current_type_id = current_type_inst->GetSingleWordInOperand(index_val);
         break;
       }
       default:
-        Fail() << "CompositeExtract with bad type %"
-               << type_mgr_->GetId(current_type) << " " << current_type->str();
+        Fail() << "CompositeExtract with bad type %" << current_type_id << ": "
+               << current_type_inst->PrettyPrint();
         return {};
     }
     current_expr.reset(TypedExpression(
-        parser_impl_.ConvertType(type_mgr_->GetId(current_type)),
-        std::move(next_expr)));
+        parser_impl_.ConvertType(current_type_id), std::move(next_expr)));
   }
   return current_expr;
 }
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index b399701..f9ffd29 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -451,6 +451,61 @@
       << ToString(fe.ast_body());
 }
 
+TEST_F(SpvParserTest_CompositeExtract, Struct_DifferOnlyInMemberName) {
+  const auto assembly =
+      R"(
+      OpMemberName %s0 0 "algo"
+      OpMemberName %s1 0 "rithm"
+)" + Preamble() +
+      R"(
+     %s0 = OpTypeStruct %uint
+     %s1 = OpTypeStruct %uint
+     %ptr0 = OpTypePointer Function %s0
+     %ptr1 = OpTypePointer Function %s1
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %var0 = OpVariable %ptr0 Function
+     %var1 = OpVariable %ptr1 Function
+     %1 = OpLoad %s0 %var0
+     %2 = OpCompositeExtract %uint %1 0
+     %3 = OpLoad %s1 %var1
+     %4 = OpCompositeExtract %uint %3 0
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __u32
+    {
+      MemberAccessor{
+        Identifier{x_1}
+        Identifier{algo}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_4
+    none
+    __u32
+    {
+      MemberAccessor{
+        Identifier{x_3}
+        Identifier{rithm}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
 TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
   const auto assembly = Preamble() + R"(
      %ptr = OpTypePointer Function %s_v2f_u_i
@@ -468,7 +523,7 @@
   FunctionEmitter fe(p, *spirv_function(100));
   EXPECT_FALSE(fe.EmitBody());
   EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of "
-                             "bounds for structure %25 having 3 elements"));
+                             "bounds for structure %25 having 3 members"));
 }
 
 TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {