[spirv-reader] Follow non-canonicalized SPIR-V type for access chains

Follow the actual SPIR-V type when computing an access chain expression,
instead of the canonicalized view in the optimizer's type manager.
Do this so we can generate the correct member name for a struct,
rather than using the member name for the other representative
struct type. The optimizer's type canonicalizer is insensitive to
struct member names.

Fixes tint:213

Bug: tint:3, tint:213
Change-Id: I88ec42a4cb049b011a59d5522e4cb39bc181a4fb
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27602
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 2578e06..806748c 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2851,30 +2851,43 @@
     }
   }
 
-  const auto* ptr_type = type_mgr_->GetType(ptr_ty_id);
-  if (!ptr_type || !ptr_type->AsPointer()) {
+  const auto* ptr_type_inst = def_use_mgr_->GetDef(ptr_ty_id);
+  if (!ptr_type_inst || (ptr_type_inst->opcode() != SpvOpTypePointer)) {
     Fail() << "Access chain %" << inst.result_id()
            << " base pointer is not of pointer type";
     return {};
   }
-  SpvStorageClass storage_class = ptr_type->AsPointer()->storage_class();
-  const auto* pointee_type = ptr_type->AsPointer()->pointee_type();
+  SpvStorageClass storage_class =
+      static_cast<SpvStorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
+  uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+
+  // Build up a nested expression for the access chain by walking down the type
+  // hierarchy, maintaining |pointee_type_id| as the SPIR-V ID of the type of
+  // the object pointed to after processing the previous indices.
   for (uint32_t index = first_index; index < num_in_operands; ++index) {
     const auto* index_const =
         constants[index] ? constants[index]->AsIntConstant() : nullptr;
     const int64_t index_const_val =
         index_const ? index_const->GetSignExtendedValue() : 0;
     std::unique_ptr<ast::Expression> next_expr;
-    switch (pointee_type->kind()) {
-      case spvtools::opt::analysis::Type::kVector:
+
+    const auto* pointee_type_inst = def_use_mgr_->GetDef(pointee_type_id);
+    if (!pointee_type_inst) {
+      Fail() << "pointee type %" << pointee_type_id
+             << " is invalid after following " << (index - first_index)
+             << " indices: " << inst.PrettyPrint();
+      return {};
+    }
+    switch (pointee_type_inst->opcode()) {
+      case SpvOpTypeVector:
         if (index_const) {
-          // Try generating a MemberAccessor expression.
-          if (index_const_val < 0 ||
-              pointee_type->AsVector()->element_count() <= index_const_val) {
+          // Try generating a MemberAccessor expression
+          const auto num_elems = pointee_type_inst->GetSingleWordInOperand(1);
+          if (index_const_val < 0 || num_elems <= index_const_val) {
             Fail() << "Access chain %" << inst.result_id() << " index %"
                    << inst.GetSingleWordInOperand(index) << " value "
                    << index_const_val << " is out of bounds for vector of "
-                   << pointee_type->AsVector()->element_count() << " elements";
+                   << num_elems << " elements";
             return {};
           }
           if (uint64_t(index_const_val) >=
@@ -2893,61 +2906,58 @@
               std::move(current_expr.expr),
               std::move(MakeOperand(inst, index).expr));
         }
-        pointee_type = pointee_type->AsVector()->element_type();
+        // All vector components are the same type, so follow the first.
+        pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kMatrix:
+      case SpvOpTypeMatrix:
         // Use array syntax.
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr),
             std::move(MakeOperand(inst, index).expr));
-        pointee_type = pointee_type->AsMatrix()->element_type();
+        // All matrix components are the same type, so follow the first.
+        pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kArray:
+      case SpvOpTypeArray:
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr),
             std::move(MakeOperand(inst, index).expr));
-        pointee_type = pointee_type->AsArray()->element_type();
+        pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kRuntimeArray:
+      case SpvOpTypeRuntimeArray:
         next_expr = std::make_unique<ast::ArrayAccessorExpression>(
             std::move(current_expr.expr),
             std::move(MakeOperand(inst, index).expr));
-        pointee_type = pointee_type->AsRuntimeArray()->element_type();
+        pointee_type_id = pointee_type_inst->GetSingleWordInOperand(0);
         break;
-      case spvtools::opt::analysis::Type::kStruct: {
+      case SpvOpTypeStruct: {
         if (!index_const) {
           Fail() << "Access chain %" << inst.result_id() << " index %"
                  << inst.GetSingleWordInOperand(index)
                  << " is a non-constant index into a structure %"
-                 << type_mgr_->GetId(pointee_type);
+                 << pointee_type_id;
           return {};
         }
-        if ((index_const_val < 0) ||
-            pointee_type->AsStruct()->element_types().size() <=
-                uint64_t(index_const_val)) {
+        const auto num_members = pointee_type_inst->NumInOperands();
+        if ((index_const_val < 0) || num_members <= uint64_t(index_const_val)) {
           Fail() << "Access chain %" << inst.result_id() << " index value "
                  << index_const_val << " is out of bounds for structure %"
-                 << type_mgr_->GetId(pointee_type) << " having "
-                 << pointee_type->AsStruct()->element_types().size()
-                 << " elements";
+                 << pointee_type_id << " having " << num_members << " members";
           return {};
         }
-        auto member_access =
-            std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
-                type_mgr_->GetId(pointee_type), uint32_t(index_const_val)));
+        auto member_access = std::make_unique<ast::IdentifierExpression>(
+            namer_.GetMemberName(pointee_type_id, uint32_t(index_const_val)));
 
         next_expr = std::make_unique<ast::MemberAccessorExpression>(
             std::move(current_expr.expr), std::move(member_access));
-        pointee_type =
-            pointee_type->AsStruct()->element_types()[index_const_val];
+        pointee_type_id = pointee_type_inst->GetSingleWordInOperand(
+            static_cast<uint32_t>(index_const_val));
         break;
       }
       default:
-        Fail() << "Access chain with unknown pointee type %"
-               << type_mgr_->GetId(pointee_type) << " " << pointee_type->str();
+        Fail() << "Access chain with unknown or invalid pointee type %"
+               << pointee_type_id << ": " << pointee_type_inst->PrettyPrint();
         return {};
     }
-    const auto pointee_type_id = type_mgr_->GetId(pointee_type);
     const auto pointer_type_id =
         type_mgr_->FindPointerToType(pointee_type_id, storage_class);
     auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id);
diff --git a/src/reader/spirv/function_memory_test.cc b/src/reader/spirv/function_memory_test.cc
index 149b304..5d4a436 100644
--- a/src/reader/spirv/function_memory_test.cc
+++ b/src/reader/spirv/function_memory_test.cc
@@ -347,7 +347,7 @@
     Identifier{z}
   }
   ScalarConstructor{42}
-})"));
+})")) << ToString(fe.ast_body());
 }
 
 TEST_F(SpvParserTest, EmitStatement_AccessChain_VectorConstOutOfBounds) {
@@ -535,6 +535,61 @@
 })"));
 }
 
+TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_DifferOnlyMemberName) {
+  // The spirv-opt internal representation will map both structs to the
+  // same canonicalized type, because it doesn't care about member names.
+  // But we care about member names when producing a member-access expression.
+  // crbug.com/tint/213
+  const std::string assembly = R"(
+     OpName %1 "myvar"
+     OpName %10 "myvar2"
+     OpMemberName %strct 1 "age"
+     OpMemberName %strct2 1 "ancientness"
+     %void = OpTypeVoid
+     %voidfn = OpTypeFunction %void
+     %float = OpTypeFloat 32
+     %float_42 = OpConstant %float 42
+     %float_420 = OpConstant %float 420
+     %strct = OpTypeStruct %float %float
+     %strct2 = OpTypeStruct %float %float
+     %elem_ty = OpTypePointer Workgroup %float
+     %var_ty = OpTypePointer Workgroup %strct
+     %var2_ty = OpTypePointer Workgroup %strct2
+     %uint = OpTypeInt 32 0
+     %uint_1 = OpConstant %uint 1
+
+     %1 = OpVariable %var_ty Workgroup
+     %10 = OpVariable %var2_ty Workgroup
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %2 = OpAccessChain %elem_ty %1 %uint_1
+     OpStore %2 %float_42
+     %20 = OpAccessChain %elem_ty %10 %uint_1
+     OpStore %20 %float_420
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
+      << assembly << p->error();
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody());
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Assignment{
+  MemberAccessor{
+    Identifier{myvar}
+    Identifier{age}
+  }
+  ScalarConstructor{42.000000}
+}
+Assignment{
+  MemberAccessor{
+    Identifier{myvar2}
+    Identifier{ancientness}
+  }
+  ScalarConstructor{420.000000}
+})")) << ToString(fe.ast_body());
+}
+
 TEST_F(SpvParserTest, EmitStatement_AccessChain_StructNonConstIndex) {
   const std::string assembly = R"(
      OpName %1 "myvar"
@@ -597,7 +652,7 @@
   FunctionEmitter fe(p, *spirv_function(100));
   EXPECT_FALSE(fe.EmitBody());
   EXPECT_THAT(p->error(), Eq("Access chain %2 index value 99 is out of bounds "
-                             "for structure %55 having 2 elements"));
+                             "for structure %55 having 2 members"));
 }
 
 TEST_F(SpvParserTest, EmitStatement_AccessChain_Struct_RuntimeArray) {
@@ -705,7 +760,8 @@
   FunctionEmitter fe(p, *spirv_function(100));
   EXPECT_FALSE(fe.EmitBody());
   EXPECT_THAT(p->error(),
-              HasSubstr("Access chain with unknown pointee type %60 void"));
+              HasSubstr("Access chain with unknown or invalid pointee type "
+                        "%60: %60 = OpTypePointer Workgroup %55"));
 }
 
 std::string OldStorageBufferPreamble() {