[spirv-writer] Add support for sample index/mask builtins

Use a sanitizing transform to convert scalar `sample_mask_{in,out}`
variables to single element arrays.

Add the `SampleRateShading` capability if the `sample_index` builtin
is used.

Bug: tint:372
Change-Id: Id7280e3ddb21e0a098d83587d123c97e3c34fa1b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41662
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 03c2274..0d68e41 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -355,8 +355,6 @@
     "src/clone_context.h",
     "src/demangler.cc",
     "src/demangler.h",
-    "src/intrinsic_table.cc",
-    "src/intrinsic_table.h",
     "src/diagnostic/diagnostic.cc",
     "src/diagnostic/diagnostic.h",
     "src/diagnostic/formatter.cc",
@@ -369,6 +367,8 @@
     "src/inspector/inspector.h",
     "src/inspector/scalar.cc",
     "src/inspector/scalar.h",
+    "src/intrinsic_table.cc",
+    "src/intrinsic_table.h",
     "src/namer.cc",
     "src/namer.h",
     "src/program.cc",
@@ -953,6 +953,7 @@
 
 source_set("tint_unittests_spv_writer_src") {
   sources = [
+    "src/transform/spirv_test.cc",
     "src/writer/spirv/binary_writer_test.cc",
     "src/writer/spirv/builder_accessor_expression_test.cc",
     "src/writer/spirv/builder_assign_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index a42d8e3..d1648b9 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -616,6 +616,7 @@
 
   if(${TINT_BUILD_SPV_WRITER})
     list(APPEND TINT_TEST_SRCS
+      transform/spirv_test.cc
       writer/spirv/binary_writer_test.cc
       writer/spirv/builder_accessor_expression_test.cc
       writer/spirv/builder_assign_test.cc
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 07797d9..0f3d7fb 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -14,9 +14,13 @@
 
 #include "src/transform/spirv.h"
 
+#include <string>
 #include <utility>
 
+#include "src/ast/variable.h"
 #include "src/program_builder.h"
+#include "src/semantic/expression.h"
+#include "src/semantic/variable.h"
 
 namespace tint {
 namespace transform {
@@ -26,9 +30,60 @@
 
 Transform::Output Spirv::Run(const Program* in) {
   ProgramBuilder out;
-  CloneContext(&out, in).Clone();
+  CloneContext ctx(&out, in);
+  HandleSampleMaskBuiltins(ctx);
+  ctx.Clone();
   return Output{Program(std::move(out))};
 }
 
+void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
+  // Find global variables decorated with [[builtin(sample_mask_{in,out})]] and
+  // change their type from `u32` to `array<u32, 1>`, as required by Vulkan.
+  //
+  // Before:
+  // ```
+  // [[builtin(sample_mask_out)]] var<out> mask_out : u32;
+  // fn main() -> void {
+  //   mask_out = 1u;
+  // }
+  // ```
+  // After:
+  // ```
+  // [[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+  // fn main() -> void {
+  //   mask_out[0] = 1u;
+  // }
+  // ```
+
+  for (auto* var : ctx.src->AST().GlobalVariables()) {
+    for (auto* deco : var->decorations()) {
+      if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
+        if (builtin->value() != ast::Builtin::kSampleMaskIn &&
+            builtin->value() != ast::Builtin::kSampleMaskOut) {
+          continue;
+        }
+
+        // Use the same name as the old variable.
+        std::string var_name = ctx.src->Symbols().NameFor(var->symbol());
+        // Use `array<u32, 1>` for the new variable.
+        auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
+        // Create the new variable.
+        auto* var_arr =
+            ctx.dst->Var(var->source(), var_name, var->declared_storage_class(),
+                         type, nullptr, ctx.Clone(var->decorations()));
+        // Replace the variable with the arrayed version.
+        ctx.Replace(var, var_arr);
+
+        // Replace all uses of the old variable with `var_arr[0]`.
+        for (auto* user : ctx.src->Sem().Get(var)->Users()) {
+          auto* new_ident = ctx.dst->IndexAccessor(
+              ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0));
+          ctx.Replace<ast::Expression>(user->Declaration(), new_ident);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index 3644445..5f79d81 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -18,6 +18,10 @@
 #include "src/transform/transform.h"
 
 namespace tint {
+
+// Forward declarations
+class CloneContext;
+
 namespace transform {
 
 /// Spirv is a transform used to sanitize a Program for use with the Spirv
@@ -36,6 +40,10 @@
   /// @param program the source program to transform
   /// @returns the transformation result
   Output Run(const Program* program) override;
+
+ private:
+  /// Change type of sample mask builtin variables to single element arrays.
+  void HandleSampleMaskBuiltins(CloneContext& ctx) const;
 };
 
 }  // namespace transform
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
new file mode 100644
index 0000000..53bdc52
--- /dev/null
+++ b/src/transform/spirv_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/spirv.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using SpirvTest = TransformTest;
+
+TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
+  auto* src = R"(
+[[builtin(sample_index)]] var<in> sample_index : u32;
+
+[[builtin(sample_mask_in)]] var<in> mask_in : u32;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : u32;
+
+[[stage(fragment)]]
+fn main() -> void {
+  mask_out = mask_in;
+}
+)";
+
+  auto* expect = R"(
+[[builtin(sample_index)]] var<in> sample_index : u32;
+
+[[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+
+[[stage(fragment)]]
+fn main() -> void {
+  mask_out[0] = mask_in[0];
+}
+)";
+
+  auto got = Transform<Spirv>(src);
+
+  EXPECT_EQ(expect, got);
+}
+
+TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
+  auto* src = R"(
+[[builtin(sample_mask_in)]] var<in> mask_in : u32;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : u32;
+
+fn filter(mask: u32) -> u32 {
+  return (mask & 3u);
+}
+
+fn set_mask(input : u32) -> void {
+  mask_out = input;
+}
+
+[[stage(fragment)]]
+fn main() -> void {
+  set_mask(filter(mask_in));
+}
+)";
+
+  auto* expect = R"(
+[[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+
+[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+
+fn filter(mask : u32) -> u32 {
+  return (mask & 3u);
+}
+
+fn set_mask(input : u32) -> void {
+  mask_out[0] = input;
+}
+
+[[stage(fragment)]]
+fn main() -> void {
+  set_mask(filter(mask_in[0]));
+}
+)";
+
+  auto got = Transform<Spirv>(src);
+
+  EXPECT_EQ(expect, got);
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 6b545e3..bd75938 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -3172,7 +3172,7 @@
   return SpvStorageClassMax;
 }
 
-SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin) const {
+SpvBuiltIn Builder::ConvertBuiltin(ast::Builtin builtin) {
   switch (builtin) {
     case ast::Builtin::kPosition:
       return SpvBuiltInPosition;
@@ -3194,9 +3194,13 @@
       return SpvBuiltInGlobalInvocationId;
     case ast::Builtin::kPointSize:
       return SpvBuiltInPointSize;
-    case ast::Builtin::kSampleIndex:    // TODO(dneto)
-    case ast::Builtin::kSampleMaskIn:   // TODO(dneto)
-    case ast::Builtin::kSampleMaskOut:  // TODO(dneto)
+    case ast::Builtin::kSampleIndex:
+      push_capability(SpvCapabilitySampleRateShading);
+      return SpvBuiltInSampleId;
+    case ast::Builtin::kSampleMaskIn:
+      return SpvBuiltInSampleMask;
+    case ast::Builtin::kSampleMaskOut:
+      return SpvBuiltInSampleMask;
     case ast::Builtin::kNone:
       break;
   }
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index b1d38b3..bba8bd5 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -216,10 +216,10 @@
   /// @param klass the storage class to convert
   /// @returns the SPIR-V storage class or SpvStorageClassMax on error.
   SpvStorageClass ConvertStorageClass(ast::StorageClass klass) const;
-  /// Converts a builtin to a SPIR-V builtin
+  /// Converts a builtin to a SPIR-V builtin and pushes a capability if needed.
   /// @param builtin the builtin to convert
   /// @returns the SPIR-V builtin or SpvBuiltInMax on error.
-  SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const;
+  SpvBuiltIn ConvertBuiltin(ast::Builtin builtin);
 
   /// Generates a label for the given id. Emits an error and returns false if
   /// we're currently outside a function.
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 4a9517b..c39c281 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -24,6 +24,7 @@
 #include "src/ast/group_decoration.h"
 #include "src/ast/location_decoration.h"
 #include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/stage_decoration.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/struct.h"
 #include "src/ast/type_constructor_expression.h"
@@ -401,7 +402,10 @@
         BuiltinData{ast::Builtin::kLocalInvocationIndex,
                     SpvBuiltInLocalInvocationIndex},
         BuiltinData{ast::Builtin::kGlobalInvocationId,
-                    SpvBuiltInGlobalInvocationId}));
+                    SpvBuiltInGlobalInvocationId},
+        BuiltinData{ast::Builtin::kSampleIndex, SpvBuiltInSampleId},
+        BuiltinData{ast::Builtin::kSampleMaskIn, SpvBuiltInSampleMask},
+        BuiltinData{ast::Builtin::kSampleMaskOut, SpvBuiltInSampleMask}));
 
 TEST_F(BuilderTest, GlobalVar_DeclReadOnly) {
   // struct A {
@@ -629,6 +633,95 @@
 )");
 }
 
+TEST_F(BuilderTest, SampleIndex) {
+  auto* var =
+      Global("sample_index", ast::StorageClass::kInput, ty.u32(), nullptr,
+             ast::VariableDecorationList{
+                 create<ast::BuiltinDecoration>(ast::Builtin::kSampleIndex),
+             });
+
+  spirv::Builder& b = Build();
+
+  EXPECT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.capabilities()),
+            "OpCapability SampleRateShading\n");
+  EXPECT_EQ(DumpInstructions(b.annots()), "OpDecorate %1 BuiltIn SampleId\n");
+  EXPECT_EQ(DumpInstructions(b.types()),
+            "%3 = OpTypeInt 32 0\n"
+            "%2 = OpTypePointer Input %3\n"
+            "%1 = OpVariable %2 Input\n");
+}
+
+TEST_F(BuilderTest, SampleMask) {
+  // Input:
+  // [[builtin(sample_mask_in)]] var<in> mask_in : u32;
+  // [[builtin(sample_mask_out)]] var<out> mask_out : u32;
+  // [[stage(fragment)]]
+  // fn main() -> void {
+  //   mask_out = mask_in;
+  // }
+
+  // After sanitization:
+  // [[builtin(sample_mask_in)]] var<in> mask_in : array<u32, 1>;
+  // [[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
+  // [[stage(fragment)]]
+  // fn main() -> void {
+  //   mask_out[0] = mask_in[0];
+  // }
+
+  Global("mask_in", ast::StorageClass::kInput, ty.u32(), nullptr,
+         ast::VariableDecorationList{
+             create<ast::BuiltinDecoration>(ast::Builtin::kSampleMaskIn),
+         });
+  Global("mask_out", ast::StorageClass::kOutput, ty.u32(), nullptr,
+         ast::VariableDecorationList{
+             create<ast::BuiltinDecoration>(ast::Builtin::kSampleMaskOut),
+         });
+  Func("main", ast::VariableList{}, ty.void_(),
+       ast::StatementList{
+           create<ast::AssignmentStatement>(Expr("mask_out"), Expr("mask_in")),
+       },
+       ast::FunctionDecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kCompute),
+       });
+
+  spirv::Builder& b = SanitizeAndBuild();
+
+  ASSERT_TRUE(b.Build());
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %11 "main" %6 %1
+OpExecutionMode %11 LocalSize 1 1 1
+OpName %1 "mask_in"
+OpName %6 "mask_out"
+OpName %11 "main"
+OpDecorate %1 BuiltIn SampleMask
+OpDecorate %6 BuiltIn SampleMask
+%4 = OpTypeInt 32 0
+%5 = OpConstant %4 1
+%3 = OpTypeArray %4 %5
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%7 = OpTypePointer Output %3
+%8 = OpConstantNull %3
+%6 = OpVariable %7 Output %8
+%10 = OpTypeVoid
+%9 = OpTypeFunction %10
+%13 = OpTypeInt 32 1
+%14 = OpConstant %13 0
+%15 = OpTypePointer Output %4
+%17 = OpTypePointer Input %4
+%11 = OpFunction %10 None %9
+%12 = OpLabel
+%16 = OpAccessChain %15 %6 %14
+%18 = OpAccessChain %17 %1 %14
+%19 = OpLoad %4 %18
+OpStore %16 %19
+OpReturn
+OpFunctionEnd
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/src/writer/spirv/test_helper.h b/src/writer/spirv/test_helper.h
index aafc15d..ba5e51a 100644
--- a/src/writer/spirv/test_helper.h
+++ b/src/writer/spirv/test_helper.h
@@ -24,6 +24,7 @@
 #include "src/ast/module.h"
 #include "src/diagnostic/formatter.h"
 #include "src/program_builder.h"
+#include "src/transform/spirv.h"
 #include "src/type_determiner.h"
 #include "src/writer/spirv/binary_writer.h"
 #include "src/writer/spirv/builder.h"
@@ -60,6 +61,34 @@
     return *spirv_builder;
   }
 
+  /// Builds the program, runs the program through the transform::Spirv
+  /// sanitizer and returns a spirv::Builder from the sanitized program.
+  /// @note The spirv::Builder is only built once. Multiple calls to Build()
+  /// will return the same spirv::Builder without rebuilding.
+  /// @return the built spirv::Builder
+  spirv::Builder& SanitizeAndBuild() {
+    if (spirv_builder) {
+      return *spirv_builder;
+    }
+    [&]() {
+      ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
+                             << diag::Formatter().format(Diagnostics());
+    }();
+    program = std::make_unique<Program>(std::move(*this));
+    [&]() {
+      ASSERT_TRUE(program->IsValid())
+          << diag::Formatter().format(program->Diagnostics());
+    }();
+    auto result = transform::Spirv().Run(program.get());
+    [&]() {
+      ASSERT_FALSE(result.diagnostics.contains_errors())
+          << diag::Formatter().format(result.diagnostics);
+    }();
+    *program = std::move(result.program);
+    spirv_builder = std::make_unique<spirv::Builder>(program.get());
+    return *spirv_builder;
+  }
+
   /// Validate passes the generated SPIR-V of the builder `b` to the SPIR-V
   /// Tools Validator. If the validator finds problems the test will fail.
   /// @param b the spirv::Builder containing the built SPIR-V module