spirv-reader: refactor swizzle creation

Change-Id: I6a09756026b7cbc436d5f232be9331255615e8c3
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/34040
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 5d1e5c4..15286e8 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -172,6 +172,8 @@
 
 namespace {
 
+constexpr uint32_t kMaxVectorLen = 4;
+
 // Gets the AST unary opcode for the given SPIR-V opcode, if any
 // @param opcode SPIR-V opcode
 // @param ast_unary_op return parameter
@@ -2874,6 +2876,16 @@
   return {ast_type, call};
 }
 
+ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
+  if (i >= kMaxVectorLen) {
+    Fail() << "vector component index is larger than " << kMaxVectorLen - 1
+           << ": " << i;
+    return nullptr;
+  }
+  const char* names[] = {"x", "y", "z", "w"};
+  return ast_module_.create<ast::IdentifierExpression>(names[i & 3]);
+}
+
 TypedExpression FunctionEmitter::MakeAccessChain(
     const spvtools::opt::Instruction& inst) {
   if (inst.NumInOperands() < 1) {
@@ -2888,7 +2900,6 @@
   // for the base, and then bury that inside nested indexing expressions.
   TypedExpression current_expr(MakeOperand(inst, 0));
   const auto constants = constant_mgr_->GetOperandConstants(&inst);
-  static const char* swizzles[] = {"x", "y", "z", "w"};
 
   const auto base_id = inst.GetSingleWordInOperand(0);
   auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id();
@@ -2981,16 +2992,12 @@
                    << num_elems << " elements";
             return {};
           }
-          if (uint64_t(index_const_val) >=
-              sizeof(swizzles) / sizeof(swizzles[0])) {
+          if (uint64_t(index_const_val) >= kMaxVectorLen) {
             Fail() << "internal error: swizzle index " << index_const_val
-                   << " is too big. Max handled index is "
-                   << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+                   << " is too big. Max handled index is " << kMaxVectorLen - 1;
           }
-          auto* letter_index =
-              create<ast::IdentifierExpression>(swizzles[index_const_val]);
-          next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
-                                                            letter_index);
+          next_expr = create<ast::MemberAccessorExpression>(
+              current_expr.expr, Swizzle(uint32_t(index_const_val)));
         } else {
           // Non-constant index. Use array syntax
           next_expr = create<ast::ArrayAccessorExpression>(
@@ -3072,7 +3079,6 @@
     return create<ast::ScalarConstructorExpression>(
         create<ast::UintLiteral>(&u32, literal));
   };
-  static const char* swizzles[] = {"x", "y", "z", "w"};
 
   const auto composite = inst.GetSingleWordInOperand(0);
   auto current_type_id = def_use_mgr_->GetDef(composite)->type_id();
@@ -3102,15 +3108,12 @@
                  << " elements";
           return {};
         }
-        if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+        if (index_val >= kMaxVectorLen) {
           Fail() << "internal error: swizzle index " << index_val
-                 << " is too big. Max handled index is "
-                 << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+                 << " is too big. Max handled index is " << kMaxVectorLen - 1;
         }
-        auto* letter_index =
-            create<ast::IdentifierExpression>(swizzles[index_val]);
         next_expr = create<ast::MemberAccessorExpression>(current_expr.expr,
-                                                          letter_index);
+                                                          Swizzle(index_val));
         // All vector components are the same type.
         current_type_id = current_type_inst->GetSingleWordInOperand(0);
         break;
@@ -3124,10 +3127,9 @@
                  << " elements";
           return {};
         }
-        if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) {
+        if (index_val >= kMaxVectorLen) {
           Fail() << "internal error: swizzle index " << index_val
-                 << " is too big. Max handled index is "
-                 << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1);
+                 << " is too big. Max handled index is " << kMaxVectorLen - 1;
         }
         // Use array syntax.
         next_expr = create<ast::ArrayAccessorExpression>(current_expr.expr,
@@ -3197,7 +3199,6 @@
       type_mgr_->GetType(vec1.type_id())->AsVector()->element_count();
 
   // Idiomatic vector accessors.
-  const char* swizzles[] = {"x", "y", "z", "w"};
 
   // Generate an ast::TypeConstructor expression.
   // Assume the literal indices are valid, and there is a valid number of them.
@@ -3207,16 +3208,13 @@
   for (uint32_t i = 2; i < inst.NumInOperands(); ++i) {
     const auto index = inst.GetSingleWordInOperand(i);
     if (index < vec0_len) {
-      assert(index < sizeof(swizzles) / sizeof(swizzles[0]));
       values.emplace_back(create<ast::MemberAccessorExpression>(
-          MakeExpression(vec0_id).expr,
-          create<ast::IdentifierExpression>(swizzles[index])));
+          MakeExpression(vec0_id).expr, Swizzle(index)));
     } else if (index < vec0_len + vec1_len) {
       const auto sub_index = index - vec0_len;
-      assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0]));
+      assert(sub_index < kMaxVectorLen);
       values.emplace_back(create<ast::MemberAccessorExpression>(
-          MakeExpression(vec1_id).expr,
-          create<ast::IdentifierExpression>(swizzles[sub_index])));
+          MakeExpression(vec1_id).expr, Swizzle(sub_index)));
     } else if (index == 0xFFFFFFFF) {
       // By rule, this maps to OpUndef.  Instead, make it zero.
       values.emplace_back(parser_impl_.MakeNullValue(result_type->type()));
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index cfebe80..f5c3872 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -672,6 +672,13 @@
   /// @returns the associated loop construct, or nullptr
   const Construct* SiblingLoopConstruct(const Construct* c) const;
 
+  /// Returns an identifier expression for the swizzle name of the given
+  /// index into a vector.  Emits an error and returns nullptr if the
+  /// index is out of range, i.e. 4 or higher.
+  /// @param i index of the subcomponent
+  /// @returns the identifier expression for the @p i'th component
+  ast::IdentifierExpression* Swizzle(uint32_t i);
+
  private:
   /// @returns the store type for the OpVariable instruction, or
   /// null on failure.
diff --git a/src/reader/spirv/function_misc_test.cc b/src/reader/spirv/function_misc_test.cc
index f5d1bd9..e215fdd 100644
--- a/src/reader/spirv/function_misc_test.cc
+++ b/src/reader/spirv/function_misc_test.cc
@@ -16,6 +16,7 @@
 #include <vector>
 
 #include "gmock/gmock.h"
+#include "src/ast/identifier_expression.h"
 #include "src/reader/spirv/function.h"
 #include "src/reader/spirv/parser_impl.h"
 #include "src/reader/spirv/parser_impl_test_helper.h"
@@ -295,6 +296,53 @@
 )")) << ToString(fe.ast_body());
 }
 
+// Test swizzle generation.
+
+struct SwizzleCase {
+  uint32_t index;
+  std::string expected_expr;
+  std::string expected_error;
+};
+using SpvParserSwizzleTest =
+    SpvParserTestBase<::testing::TestWithParam<SwizzleCase>>;
+
+TEST_P(SpvParserSwizzleTest, Sample) {
+  // We need a function so we can get a FunctionEmitter.
+  const auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %entry = OpLabel
+     OpReturn
+     OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+
+  auto* result = fe.Swizzle(GetParam().index);
+  if (GetParam().expected_error.empty()) {
+    EXPECT_TRUE(fe.success());
+    ASSERT_NE(result, nullptr);
+    std::ostringstream ss;
+    result->to_str(ss, 0);
+    EXPECT_THAT(ss.str(), Eq(GetParam().expected_expr));
+  } else {
+    EXPECT_EQ(result, nullptr);
+    EXPECT_FALSE(fe.success());
+    EXPECT_THAT(p->error(), Eq(GetParam().expected_error));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    ValidIndex,
+    SpvParserSwizzleTest,
+    ::testing::ValuesIn(std::vector<SwizzleCase>{
+        {0, "Identifier[not set]{x}\n", ""},
+        {1, "Identifier[not set]{y}\n", ""},
+        {2, "Identifier[not set]{z}\n", ""},
+        {3, "Identifier[not set]{w}\n", ""},
+        {4, "", "vector component index is larger than 3: 4"},
+        {99999, "", "vector component index is larger than 3: 99999"}}));
+
 // TODO(dneto): OpSizeof : requires Kernel (OpenCL)
 
 }  // namespace