[spirv-reader] Add vector shuffle

Use composite-construct from decomposed singly-named operands.

Bug: tint:3
Change-Id: I8536c5f8e87de312460c3d5c6164e090d79bb4a9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23380
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 0635a80..117dcc9 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -15,6 +15,7 @@
 #include "src/reader/spirv/function.h"
 
 #include <algorithm>
+#include <array>
 #include <sstream>
 #include <unordered_map>
 #include <unordered_set>
@@ -49,6 +50,7 @@
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/bool_type.h"
 #include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/uint_literal.h"
 #include "src/ast/unary_op.h"
@@ -615,6 +617,7 @@
 
   // TODO(dneto): register phis
   // TODO(dneto): register SSA values which need to be hoisted
+  RegisterValuesNeedingNamedDefinition();
 
   if (!EmitFunctionVariables()) {
     return false;
@@ -2394,9 +2397,11 @@
   // Handle combinatorial instructions first.
   auto combinatorial_expr = MaybeEmitCombinatorialValue(inst);
   if (combinatorial_expr.expr != nullptr) {
-    if (def_use_mgr_->NumUses(&inst) == 1) {
-      // If it's used once, then defer emitting the expression until it's
-      // used. Any supporting statements have already been emitted.
+    if ((needs_named_const_def_.count(inst.result_id()) == 0) &&
+        (def_use_mgr_->NumUses(&inst) == 1)) {
+      // If it's used once, and doesn't need a named constant definition,
+      // then defer emitting the expression until it's used. Any supporting
+      // statements have already been emitted.
       singly_used_values_.insert(
           std::make_pair(inst.result_id(), std::move(combinatorial_expr)));
       return success();
@@ -2525,6 +2530,10 @@
     return MakeCompositeExtract(inst);
   }
 
+  if (opcode == SpvOpVectorShuffle) {
+    return MakeVectorShuffle(inst);
+  }
+
   // builtin readonly function
   // glsl.std.450 readonly function
 
@@ -2817,6 +2826,73 @@
       std::make_unique<ast::BoolLiteral>(parser_impl_.BoolType(), false));
 }
 
+TypedExpression FunctionEmitter::MakeVectorShuffle(
+    const spvtools::opt::Instruction& inst) {
+  const auto vec0_id = inst.GetSingleWordInOperand(0);
+  const auto vec1_id = inst.GetSingleWordInOperand(1);
+  const spvtools::opt::Instruction& vec0 = *(def_use_mgr_->GetDef(vec0_id));
+  const spvtools::opt::Instruction& vec1 = *(def_use_mgr_->GetDef(vec1_id));
+  const auto vec0_len =
+      type_mgr_->GetType(vec0.type_id())->AsVector()->element_count();
+  const auto vec1_len =
+      type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
+
+  // Idiomatic vector accessors.
+  const char* swizzles[] = {"x", "y", "z", "w"};
+
+  // Generate an ast::TypeConstructor expression.
+  // Assume the literal indices are valid, and there is a valid number of them.
+  ast::type::VectorType* result_type =
+      parser_impl_.ConvertType(inst.type_id())->AsVector();
+  ast::ExpressionList values;
+  for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
+    const auto index = inst.GetSingleWordInOperand(i);
+    if (index < vec0_len) {
+      assert(index < sizeof(swizzles) / sizeof(swizzles[0]));
+      values.emplace_back(std::make_unique<ast::MemberAccessorExpression>(
+          MakeExpression(vec0_id).expr,
+          std::make_unique<ast::IdentifierExpression>(swizzles[index])));
+    } else if (index < vec0_len + vec1_len) {
+      const auto sub_index = index - vec0_len;
+      assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0]));
+      values.emplace_back(std::make_unique<ast::MemberAccessorExpression>(
+          MakeExpression(vec1_id).expr,
+          std::make_unique<ast::IdentifierExpression>(swizzles[sub_index])));
+    } else if (index == 0xFFFFFFFF) {
+      // By rule, this maps to OpUndef.  Instead, make it zero.
+      values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
+    } else {
+      Fail() << "invalid vectorshuffle ID %" << inst.result_id()
+             << ": index too large: " << index;
+      return {};
+    }
+  }
+  return {result_type, std::make_unique<ast::TypeConstructorExpression>(
+                           result_type, std::move(values))};
+}
+
+void FunctionEmitter::RegisterValuesNeedingNamedDefinition() {
+  for (auto& block : function_) {
+    for (const auto& inst : block) {
+      if (inst.opcode() == SpvOpVectorShuffle) {
+        // We might access the vector operands multiple times. Make sure they
+        // are evaluated only once.
+        for (auto index : std::array<uint32_t, 2>{0, 1}) {
+          auto id = inst.GetSingleWordInOperand(index);
+          if (constant_mgr_->FindDeclaredConstant(id)) {
+            // If it's constant, then avoid making a const definition
+            // in the wrong place; it would be wrong if it didn't
+            // dominate its uses.
+            continue;
+          }
+          // Othewrise, register it.
+          needs_named_const_def_.insert(id);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace spirv
 }  // namespace reader
 }  // namespace tint
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 169a1d1..3d16aa9 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -276,6 +276,13 @@
   /// @returns false if bad nesting has been detected.
   bool FindIfSelectionInternalHeaders();
 
+  /// Record the SPIR-V IDs of non-constants that should get a 'const'
+  /// definition in WGSL. This occurs when a SPIR-V instruction might use the
+  /// dynamically computed value only once, but the WGSL code might reference
+  /// it multiple times. For example, this occurs for the vector operands of
+  /// OpVectorShuffle.  Populates |needs_named_const_def_|
+  void RegisterValuesNeedingNamedDefinition();
+
   /// Emits declarations of function variables.
   /// @returns false if emission failed.
   bool EmitFunctionVariables();
@@ -451,6 +458,11 @@
   /// @returns an AST expression for the instruction, or nullptr.
   TypedExpression MakeCompositeExtract(const spvtools::opt::Instruction& inst);
 
+  /// Creates an expression for OpVectorShuffle
+  /// @param inst an OpVectorShuffle instruction.
+  /// @returns an AST expression for the instruction, or nullptr.
+  TypedExpression MakeVectorShuffle(const spvtools::opt::Instruction& inst);
+
   /// Gets the block info for a block ID, if any exists
   /// @param id the SPIR-V ID of the OpLabel instruction starting the block
   /// @returns the block info for the given ID, if it exists, or nullptr
@@ -583,6 +595,8 @@
   std::unordered_set<uint32_t> identifier_values_;
   // Mapping from SPIR-V ID that is used at most once, to its AST expression.
   std::unordered_map<uint32_t, TypedExpression> singly_used_values_;
+  // Set of SPIR-V IDs which should get a named const definition.
+  std::unordered_set<uint32_t> needs_named_const_def_;
 
   // The IDs of basic blocks, in reverse structured post-order (RSPO).
   // This is the output order for the basic blocks.
diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc
index af6d400..8d5cb70 100644
--- a/src/reader/spirv/function_composite_test.cc
+++ b/src/reader/spirv/function_composite_test.cc
@@ -51,6 +51,8 @@
   %float_70 = OpConstant %float 70
 
   %v2uint = OpTypeVector %uint 2
+  %v3uint = OpTypeVector %uint 3
+  %v4uint = OpTypeVector %uint 4
   %v2int = OpTypeVector %int 2
   %v2float = OpTypeVector %float 2
 
@@ -60,6 +62,8 @@
   %s_v2f_u_i = OpTypeStruct %v2float %uint %int
   %a_u_5 = OpTypeArray %uint %uint_5
 
+  %v2uint_3_4 = OpConstantComposite %v2uint %uint_3 %uint_4
+  %v2uint_4_3 = OpConstantComposite %v2uint %uint_4 %uint_3
   %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
   %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
   %v2float_70_70 = OpConstantComposite %v2float %float_70 %float_70
@@ -464,7 +468,7 @@
   FunctionEmitter fe(p, *spirv_function(100));
   EXPECT_FALSE(fe.EmitBody());
   EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of "
-                             "bounds for structure %23 having 3 elements"));
+                             "bounds for structure %25 having 3 elements"));
 }
 
 TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) {
@@ -584,6 +588,160 @@
 })")) << ToString(fe.ast_body());
 }
 
+using SpvParserTest_VectorShuffle = SpvParserTest;
+
+TEST_F(SpvParserTest_VectorShuffle, FunctionScopeOperands_UseBoth) {
+  // Note that variables are generated for the vector operands.
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %v2uint %v2uint_3_4
+     %2 = OpIAdd %v2uint %v2uint_4_3 %v2uint_3_4
+     %10 = OpVectorShuffle %v4uint %1 %2 3 2 1 0
+     OpReturn
+     OpFunctionEnd
+)";
+
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __vec_4__u32
+    {
+      TypeConstructor{
+        __vec_4__u32
+        MemberAccessor{
+          Identifier{x_2}
+          Identifier{y}
+        }
+        MemberAccessor{
+          Identifier{x_2}
+          Identifier{x}
+        }
+        MemberAccessor{
+          Identifier{x_1}
+          Identifier{y}
+        }
+        MemberAccessor{
+          Identifier{x_1}
+          Identifier{x}
+        }
+      }
+    }
+  }
+})")) << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_VectorShuffle, ConstantOperands_UseBoth) {
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %10 = OpVectorShuffle %v4uint %v2uint_3_4 %v2uint_4_3 3 2 1 0
+     OpReturn
+     OpFunctionEnd
+)";
+
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __vec_4__u32
+    {
+      TypeConstructor{
+        __vec_4__u32
+        MemberAccessor{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{4}
+            ScalarConstructor{3}
+          }
+          Identifier{y}
+        }
+        MemberAccessor{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{4}
+            ScalarConstructor{3}
+          }
+          Identifier{x}
+        }
+        MemberAccessor{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{3}
+            ScalarConstructor{4}
+          }
+          Identifier{y}
+        }
+        MemberAccessor{
+          TypeConstructor{
+            __vec_2__u32
+            ScalarConstructor{3}
+            ScalarConstructor{4}
+          }
+          Identifier{x}
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_VectorShuffle, ConstantOperands_AllOnesMapToNull) {
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %1 = OpCopyObject %v2uint %v2uint_4_3
+     %10 = OpVectorShuffle %v2uint %1 %1 0xFFFFFFFF 1
+     OpReturn
+     OpFunctionEnd
+)";
+
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
+    x_10
+    none
+    __vec_2__u32
+    {
+      TypeConstructor{
+        __vec_2__u32
+        ScalarConstructor{0}
+        MemberAccessor{
+          Identifier{x_1}
+          Identifier{y}
+        }
+      }
+    }
+  })"))
+      << ToString(fe.ast_body());
+}
+
+TEST_F(SpvParserTest_VectorShuffle, IndexTooBig_IsError) {
+  const auto assembly = Preamble() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     %10 = OpVectorShuffle %v4uint %v2uint_3_4 %v2uint_4_3 9 2 1 0
+     OpReturn
+     OpFunctionEnd
+)";
+
+  auto* p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
+  FunctionEmitter fe(p, *spirv_function(100));
+  EXPECT_FALSE(fe.EmitBody()) << p->error();
+  EXPECT_THAT(p->error(),
+              Eq("invalid vectorshuffle ID %10: index too large: 9"));
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader