[spirv-reader] Add OpCompositeExtract

Bug: tint:3
Change-Id: I9d8c1cf2545e28ef0ddf89e55ce45ec19c50022a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23161
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 05a9c9e..75276f2 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -45,6 +45,7 @@
 #include "src/ast/sint_literal.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/switch_statement.h"
+#include "src/ast/type/u32_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/uint_literal.h"
 #include "src/ast/unary_op.h"
@@ -2400,6 +2401,10 @@
                           ast_type, std::move(operands))};
   }
 
+  if (opcode == SpvOpCompositeExtract) {
+    return MakeCompositeExtract(inst);
+  }
+
   // builtin readonly function
   // glsl.std.450 readonly function
 
@@ -2462,7 +2467,7 @@
 
   // A SPIR-V access chain is a single instruction with multiple indices
   // walking down into composites.  The Tint AST represents this as
-  // ever-deeper nested indexing expresions. Start off with an expression
+  // ever-deeper nested indexing expressions. Start off with an expression
   // for the base, and then bury that inside nested indexing expressions.
   TypedExpression current_expr(MakeOperand(inst, 0));
 
@@ -2574,6 +2579,113 @@
   return current_expr;
 }
 
+TypedExpression FunctionEmitter::MakeCompositeExtract(
+    const spvtools::opt::Instruction& inst) {
+  // This is structurally similar to creating an access chain, but
+  // the SPIR-V instruction has literal indices instead of IDs for indices.
+
+  // A SPIR-V composite extract is a single instruction with multiple
+  // literal indices walking down into composites. The Tint AST represents
+  // this as ever-deeper nested indexing expressions. Start off with an
+  // expression for the composite, and then bury that inside nested indexing
+  // expressions.
+  TypedExpression current_expr(MakeOperand(inst, 0));
+
+  auto make_index = [](uint32_t literal) {
+    ast::type::U32Type u32;
+    return std::make_unique<ast::ScalarConstructorExpression>(
+        std::make_unique<ast::UintLiteral>(&u32, literal));
+  };
+  static const char* swizzles[] = {"x", "y", "z", "w"};
+
+  const auto composite = inst.GetSingleWordInOperand(0);
+  const auto composite_type_id = def_use_mgr_->GetDef(composite)->type_id();
+  const auto* current_type = type_mgr_->GetType(composite_type_id);
+  const auto num_in_operands = inst.NumInOperands();
+  for (uint32_t index = 1; index < num_in_operands; ++index) {
+    const uint32_t index_val = inst.GetSingleWordInOperand(index);
+    std::unique_ptr<ast::Expression> next_expr;
+    switch (current_type->kind()) {
+      case spvtools::opt::analysis::Type::kVector: {
+        // Try generating a MemberAccessor expression. That result in something
+        // like  "foo.z", which is more idiomatic than "foo[2]".
+        if (current_type->AsVector()->element_count() <= index_val) {
+          Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+                 << index_val << " is out of bounds for vector of "
+                 << current_type->AsVector()->element_count() << " elements";
+          return {};
+        }
+        if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+          Fail() << "internal error: swizzle index " << index_val
+                 << " is too big. Max handled index is "
+                 << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+        }
+        auto letter_index =
+            std::make_unique<ast::IdentifierExpression>(swizzles[index_val]);
+        next_expr = std::make_unique<ast::MemberAccessorExpression>(
+            std::move(current_expr.expr), std::move(letter_index));
+        current_type = current_type->AsVector()->element_type();
+        break;
+      }
+      case spvtools::opt::analysis::Type::kMatrix:
+        // Check bounds
+        if (current_type->AsMatrix()->element_count() <= index_val) {
+          Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+                 << index_val << " is out of bounds for matrix of "
+                 << current_type->AsMatrix()->element_count() << " elements";
+          return {};
+        }
+        if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+          Fail() << "internal error: swizzle index " << index_val
+                 << " is too big. Max handled index is "
+                 << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+        }
+        // Use array syntax.
+        next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+            std::move(current_expr.expr), make_index(index_val));
+        current_type = current_type->AsMatrix()->element_type();
+        break;
+      case spvtools::opt::analysis::Type::kArray:
+        // The array size could be a spec constant, and so it's not always
+        // statically checkable.  Instead, rely on a runtime index clamp
+        // or runtime check to keep this safe.
+        next_expr = std::make_unique<ast::ArrayAccessorExpression>(
+            std::move(current_expr.expr), make_index(index_val));
+        current_type = current_type->AsArray()->element_type();
+        break;
+      case spvtools::opt::analysis::Type::kRuntimeArray:
+        Fail() << "can't do OpCompositeExtract on a runtime array";
+        return {};
+      case spvtools::opt::analysis::Type::kStruct: {
+        if (current_type->AsStruct()->element_types().size() <= index_val) {
+          Fail() << "CompositeExtract %" << inst.result_id() << " index value "
+                 << index_val << " is out of bounds for structure %"
+                 << type_mgr_->GetId(current_type) << " having "
+                 << current_type->AsStruct()->element_types().size()
+                 << " elements";
+          return {};
+        }
+        auto member_access =
+            std::make_unique<ast::IdentifierExpression>(namer_.GetMemberName(
+                type_mgr_->GetId(current_type), uint32_t(index_val)));
+
+        next_expr = std::make_unique<ast::MemberAccessorExpression>(
+            std::move(current_expr.expr), std::move(member_access));
+        current_type = current_type->AsStruct()->element_types()[index_val];
+        break;
+      }
+      default:
+        Fail() << "CompositeExtract with bad type %"
+               << type_mgr_->GetId(current_type) << " " << current_type->str();
+        return {};
+    }
+    current_expr.reset(TypedExpression(
+        parser_impl_.ConvertType(type_mgr_->GetId(current_type)),
+        std::move(next_expr)));
+  }
+  return current_expr;
+}
+
 }  // namespace spirv
 }  // namespace reader
 }  // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 0e92907..686c924 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -433,6 +433,11 @@
   /// @returns an AST expression for the instruction, or nullptr.
   TypedExpression EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst);
 
+  /// Creates an expression for OpCompositeExtract
+  /// @param inst an OpCompositeExtract instruction.
+  /// @returns an AST expression for the instruction, or nullptr.
+  TypedExpression MakeCompositeExtract(const spvtools::opt::Instruction& inst);
+
   /// Gets the block info for a block ID, if any exists
   /// @param id the SPIR-V ID of the OpLabel instruction starting the block
   /// @returns the block info for the given ID, if it exists, or nullptr
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index 6030a90..944db61 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -27,6 +27,7 @@
 namespace spirv {
 namespace {
 
+using ::testing::Eq;
 using ::testing::HasSubstr;
 
 std::string Preamble() {
@@ -54,6 +55,7 @@
   %v2float = OpTypeVector %float 2
 
   %m3v2float = OpTypeMatrix %v2float 3
+  %m3v2float_0 = OpConstantNull %m3v2float
 
   %s_v2f_u_i = OpTypeStruct %v2float %uint %int
   %a_u_5 = OpTypeArray %uint %uint_5
@@ -229,6 +231,283 @@
       << ToString(fe.ast_body());
 }
 
+using SpvParserTest_CompositeExtract = SpvParserTest;
+
+TEST_F(SpvParserTest_CompositeExtract, Vector) {
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCompositeExtract %float %v2float_50_60 1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_1
+    none
+    __f32
+    {
+      MemberAccessor{
+        TypeConstructor{
+          __vec_2__f32
+          ScalarConstructor{50.000000}
+          ScalarConstructor{60.000000}
+        }
+        Identifier{y}
+      }
+    }
+  })")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Vector_IndexTooBigError) {
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCompositeExtract %float %v2float_50_60 900
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_FALSE(fe.EmitBody());
+  EXPECT_THAT(p->error(), Eq("CompositeExtract %1 index value 900 is out of "
+                             "bounds for vector of 2 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %m3v2float
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %m3v2float %var
+     %2 = OpCompositeExtract %v2float %1 2
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __vec_2__f32
+    {
+      ArrayAccessor{
+        Identifier{x_1}
+        ScalarConstructor{2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %m3v2float
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %m3v2float %var
+     %2 = OpCompositeExtract %v2float %1 3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_FALSE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 3 is out of "
+                             "bounds for matrix of 3 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %m3v2float
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %m3v2float %var
+     %2 = OpCompositeExtract %float %1 2 1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __f32
+    {
+      MemberAccessor{
+        ArrayAccessor{
+          Identifier{x_1}
+          ScalarConstructor{2}
+        }
+        Identifier{y}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Array) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %a_u_5
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %a_u_5 %var
+     %2 = OpCompositeExtract %uint %1 3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __u32
+    {
+      ArrayAccessor{
+        Identifier{x_1}
+        ScalarConstructor{3}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) {
+  const auto assembly = Preamble() + R"(
+     %rtarr = OpTypeRuntimeArray %uint
+     %ptr = OpTypePointer Function %rtarr
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %rtarr %var
+     %2 = OpCompositeExtract %uint %1 3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_FALSE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(p->error(), Eq("can't do OpCompositeExtract on a runtime array"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %s_v2f_u_i
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %s_v2f_u_i %var
+     %2 = OpCompositeExtract %int %1 2
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __i32
+    {
+      MemberAccessor{
+        Identifier{x_1}
+        Identifier{field2}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) {
+  const auto assembly = Preamble() + R"(
+     %ptr = OpTypePointer Function %s_v2f_u_i
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %s_v2f_u_i %var
+     %2 = OpCompositeExtract %int %1 40
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_FALSE(fe.EmitBody());
+  EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of "
+                             "bounds for structure %23 having 3 elements"));
+}
+
+TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {
+  const auto assembly = Preamble() + R"(
+     %a_mat = OpTypeArray %m3v2float %uint_3
+     %s = OpTypeStruct %uint %a_mat
+     %ptr = OpTypePointer Function %s
+     %var = OpVariable %ptr Function
+
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpLoad %s %var
+     %2 = OpCompositeExtract %float %1 1 2 0 1
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(
+  Variable{
+    x_2
+    none
+    __f32
+    {
+      MemberAccessor{
+        ArrayAccessor{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{x_1}
+              Identifier{field1}
+            }
+            ScalarConstructor{2}
+          }
+          ScalarConstructor{0}
+        }
+        Identifier{y}
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader