spirv-reader: vertex_index always has u32 store-type

Fixed: tint:483
Change-Id: Ie26941ab751425dfbc0924ea21bee32dc0f92527
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/40623
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 41ac1a4..58fa33f 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2035,6 +2035,11 @@
       Fail() << "unhandled use of a pointer to the SampleId builtin, with ID: "
              << id;
       return {};
+    case SkipReason::kVertexIndexBuiltinPointer:
+      Fail()
+          << "unhandled use of a pointer to the VertexIndex builtin, with ID: "
+          << id;
+      return {};
     case SkipReason::kSampleMaskInBuiltinPointer:
       Fail()
           << "unhandled use of a pointer to the SampleMask builtin, with ID: "
@@ -3036,14 +3041,19 @@
       // Memory accesses must be issued in SPIR-V program order.
       // So represent a load by a new const definition.
       const auto ptr_id = inst.GetSingleWordInOperand(0);
-      switch (GetSkipReason(ptr_id)) {
+      const auto skip_reason = GetSkipReason(ptr_id);
+      switch (skip_reason) {
         case SkipReason::kPointSizeBuiltinPointer:
           GetDefInfo(inst.result_id())->skip =
               SkipReason::kPointSizeBuiltinValue;
           return true;
-        case SkipReason::kSampleIdBuiltinPointer: {
+        case SkipReason::kSampleIdBuiltinPointer:
+        case SkipReason::kVertexIndexBuiltinPointer: {
           // The SPIR-V variable is i32, but WGSL requires u32.
-          auto var_id = parser_impl_.IdForSpecialBuiltIn(SpvBuiltInSampleId);
+          auto var_id = parser_impl_.IdForSpecialBuiltIn(
+              (skip_reason == SkipReason::kSampleIdBuiltinPointer)
+                  ? SpvBuiltInSampleId
+                  : SpvBuiltInVertexIndex);
           auto name = namer_.Name(var_id);
           ast::Expression* id_expr = create<ast::IdentifierExpression>(
               Source{}, builder_.Symbols().Register(name));
@@ -3712,6 +3722,9 @@
       case SpvBuiltInSampleId:
         def->skip = SkipReason::kSampleIdBuiltinPointer;
         break;
+      case SpvBuiltInVertexIndex:
+        def->skip = SkipReason::kVertexIndexBuiltinPointer;
+        break;
       case SpvBuiltInSampleMask: {
         // Distinguish between input and output variable.
         const auto storage_class =
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index d58c66f..27ab12a 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -229,6 +229,10 @@
   /// variable.  Don't generate its address.
   kSampleIdBuiltinPointer,
 
+  /// `kVertexIndexBuiltinPointer`: the value is a pointer to the VertexIndex
+  /// builtin variable.  Don't generate its address.
+  kVertexIndexBuiltinPointer,
+
   /// `kSampleMaskInBuiltinPointer`: the value is a pointer to the SampleMaskIn
   /// builtin input variable.  Don't generate its address.
   kSampleMaskInBuiltinPointer,
@@ -344,6 +348,9 @@
     case SkipReason::kSampleIdBuiltinPointer:
       o << " skip:sampleid_pointer";
       break;
+    case SkipReason::kVertexIndexBuiltinPointer:
+      o << " skip:vertexindex_pointer";
+      break;
     case SkipReason::kSampleMaskInBuiltinPointer:
       o << " skip:samplemaskin_pointer";
       break;
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 8b8556f..e4a212e 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -1293,7 +1293,8 @@
           special_builtins_[id] = spv_builtin;
           return nullptr;
         case SpvBuiltInSampleId:
-          // The SPIR-V variable might is likely to be signed (because GLSL
+        case SpvBuiltInVertexIndex:
+          // The SPIR-V variable is likely to be signed (because GLSL
           // requires signed), but WGSL requires unsigned.  Handle specially
           // so we always perform the conversion at load and store.
           if (auto* forced_type = unsigned_type_for_[type]) {
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index 21ed781..4431509 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -192,6 +192,10 @@
 }
 
 TEST_F(SpvModuleScopeVarParserTest, BuiltinVertexIndex) {
+  // This is the simple case for the vertex_index builtin,
+  // where the SPIR-V uses the same store type as in WGSL.
+  // See later for tests where the SPIR-V store type is signed
+  // integer, as in GLSL.
   auto p = parser(test::Assemble(R"(
     OpDecorate %52 BuiltIn VertexIndex
     %uint = OpTypeInt 32 0
@@ -2965,6 +2969,386 @@
       << module_str;
 }
 
+// Returns the start of a shader for testing VertexIndex,
+// parameterized by store type of %int or %uint
+std::string VertexIndexPreamble(std::string store_type) {
+  return R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint Vertex %main "main" %1
+    OpDecorate %1 BuiltIn VertexIndex
+    %void = OpTypeVoid
+    %voidfn = OpTypeFunction %void
+    %float = OpTypeFloat 32
+    %uint = OpTypeInt 32 0
+    %int = OpTypeInt 32 1
+    %ptr_ty = OpTypePointer Input )" +
+         store_type + R"(
+    %1 = OpVariable %ptr_ty Input
+)";
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_I32_Load_Direct) {
+  const std::string assembly = VertexIndexPreamble("%int") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %2 = OpLoad %int %1
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __i32
+        {
+          TypeConstructor[not set]{
+            __i32
+            Identifier[not set]{x_1}
+          }
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_I32_Load_CopyObject) {
+  const std::string assembly = VertexIndexPreamble("%int") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %copy_ptr = OpCopyObject %ptr_ty %1
+    %2 = OpLoad %int %copy_ptr
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __i32
+        {
+          TypeConstructor[not set]{
+            __i32
+            Identifier[not set]{x_1}
+          }
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_I32_Load_AccessChain) {
+  const std::string assembly = VertexIndexPreamble("%int") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %copy_ptr = OpAccessChain %ptr_ty %1
+    %2 = OpLoad %int %copy_ptr
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __i32
+        {
+          TypeConstructor[not set]{
+            __i32
+            Identifier[not set]{x_1}
+          }
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_I32_FunctParam) {
+  const std::string assembly = VertexIndexPreamble("%int") + R"(
+    %helper_ty = OpTypeFunction %int %ptr_ty
+    %helper = OpFunction %int None %helper_ty
+    %param = OpFunctionParameter %ptr_ty
+    %helper_entry = OpLabel
+    %3 = OpLoad %int %param
+    OpReturnValue %3
+    OpFunctionEnd
+
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %result = OpFunctionCall %int %helper %1
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  // TODO(dneto): We can handle this if we make a shadow variable and mutate
+  // the parameter type.
+  ASSERT_FALSE(p->BuildAndParseInternalModule());
+  EXPECT_THAT(
+      p->error(),
+      HasSubstr(
+          "unhandled use of a pointer to the VertexIndex builtin, with ID: 1"));
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_U32_Load_Direct) {
+  const std::string assembly = VertexIndexPreamble("%uint") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %2 = OpLoad %uint %1
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __u32
+        {
+          Identifier[not set]{x_1}
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_U32_Load_CopyObject) {
+  const std::string assembly = VertexIndexPreamble("%uint") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %copy_ptr = OpCopyObject %ptr_ty %1
+    %2 = OpLoad %uint %copy_ptr
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_11
+        none
+        __ptr_in__u32
+        {
+          Identifier[not set]{x_1}
+        }
+      }
+    }
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __u32
+        {
+          Identifier[not set]{x_11}
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_U32_Load_AccessChain) {
+  const std::string assembly = VertexIndexPreamble("%uint") + R"(
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %copy_ptr = OpAccessChain %ptr_ty %1
+    %2 = OpLoad %uint %copy_ptr
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct body
+  EXPECT_THAT(module_str, HasSubstr(R"(
+    VariableDeclStatement{
+      VariableConst{
+        x_2
+        none
+        __u32
+        {
+          Identifier[not set]{x_1}
+        }
+      }
+    })"))
+      << module_str;
+}
+
+TEST_F(SpvModuleScopeVarParserTest, VertexIndex_U32_FunctParam) {
+  const std::string assembly = VertexIndexPreamble("%uint") + R"(
+    %helper_ty = OpTypeFunction %uint %ptr_ty
+    %helper = OpFunction %uint None %helper_ty
+    %param = OpFunctionParameter %ptr_ty
+    %helper_entry = OpLabel
+    %3 = OpLoad %uint %param
+    OpReturnValue %3
+    OpFunctionEnd
+
+    %main = OpFunction %void None %voidfn
+    %entry = OpLabel
+    %result = OpFunctionCall %uint %helper %1
+    OpReturn
+    OpFunctionEnd
+ )";
+  auto p = parser(test::Assemble(assembly));
+  // TODO(dneto): We can handle this if we make a shadow variable and mutate
+  // the parameter type.
+  ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly;
+  EXPECT_TRUE(p->error().empty());
+  const auto module_str = p->program().to_str();
+  // Correct declaration
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Variable{
+    Decorations{
+      BuiltinDecoration{vertex_index}
+    }
+    x_1
+    in
+    __u32
+  })"));
+
+  // Correct bodies
+  EXPECT_THAT(module_str, HasSubstr(R"(
+  Function x_11 -> __u32
+  (
+    VariableConst{
+      x_12
+      none
+      __ptr_in__u32
+    }
+  )
+  {
+    VariableDeclStatement{
+      VariableConst{
+        x_3
+        none
+        __u32
+        {
+          Identifier[not set]{x_12}
+        }
+      }
+    }
+    Return{
+      {
+        Identifier[not set]{x_3}
+      }
+    }
+  }
+  Function main -> __void
+  StageDecoration{vertex}
+  ()
+  {
+    VariableDeclStatement{
+      VariableConst{
+        x_15
+        none
+        __u32
+        {
+          Call[not set]{
+            Identifier[not set]{x_11}
+            (
+              Identifier[not set]{x_1}
+            )
+          }
+        }
+      }
+    }
+    Return{}
+  }
+})")) << module_str;
+}
+
 // TODO(dneto): Test passing pointer to SampleMask as function parameter,
 // both input case and output case.