spirv: Handle sample_mask in shader IO transform

This is easy to do while we are processing builtins in the main
transform now that we use wrapper functions.

This is step towards removing the sanitizers completely.

Change-Id: If5472ce552e3cce1e5905916eeffa8fef90461c9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63585
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index 50d1548..91568cb 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -67,6 +67,12 @@
                        ast::InvariantDecoration, ast::LocationDecoration>();
 }
 
+// Returns true if `decos` contains a `sample_mask` builtin.
+bool HasSampleMask(const ast::DecorationList& decos) {
+  auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(decos);
+  return builtin && builtin->value() == ast::Builtin::kSampleMask;
+}
+
 }  // namespace
 
 /// State holds the current transform state for a single entry point.
@@ -166,9 +172,16 @@
 
       // Create the global variable and use its value for the shader input.
       auto var = ctx.dst->Symbols().New(name);
+      ast::Expression* value = ctx.dst->Expr(var);
+      if (HasSampleMask(attributes)) {
+        // Vulkan requires the type of a SampleMask builtin to be an array.
+        // Declare it as array<u32, 1> and then load the first element.
+        type = ctx.dst->ty.array(type, 1);
+        value = ctx.dst->IndexAccessor(value, 0);
+      }
       ctx.dst->Global(var, type, ast::StorageClass::kInput,
                       std::move(attributes));
-      return ctx.dst->Expr(var);
+      return value;
     } else if (cfg.shader_style == ShaderStyle::kMsl &&
                ast::HasDecoration<ast::BuiltinDecoration>(attributes)) {
       // If this input is a builtin and we are targeting MSL, then add it to the
@@ -303,9 +316,7 @@
   void AddFixedSampleMask() {
     // Check the existing output values for a sample mask builtin.
     for (auto& outval : wrapper_output_values) {
-      auto* builtin =
-          ast::GetDecoration<ast::BuiltinDecoration>(outval.attributes);
-      if (builtin && builtin->value() == ast::Builtin::kSampleMask) {
+      if (HasSampleMask(outval.attributes)) {
         // Combine the authored sample mask with the fixed mask.
         outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask);
         return;
@@ -390,7 +401,7 @@
   }
 
   /// Create and assign the wrapper function's output variables.
-  void CreateOutputVariables() {
+  void CreateSpirvOutputVariables() {
     for (auto& outval : wrapper_output_values) {
       // Disable validation for use of the `output` storage class.
       ast::DecorationList attributes = std::move(outval.attributes);
@@ -400,9 +411,17 @@
 
       // Create the global variable and assign it the output value.
       auto name = ctx.dst->Symbols().New(outval.name);
-      ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput,
+      auto* type = outval.type;
+      ast::Expression* lhs = ctx.dst->Expr(name);
+      if (HasSampleMask(attributes)) {
+        // Vulkan requires the type of a SampleMask builtin to be an array.
+        // Declare it as array<u32, 1> and then store to the first element.
+        type = ctx.dst->ty.array(type, 1);
+        lhs = ctx.dst->IndexAccessor(lhs, 0);
+      }
+      ctx.dst->Global(name, type, ast::StorageClass::kOutput,
                       std::move(attributes));
-      wrapper_body.push_back(ctx.dst->Assign(name, outval.value));
+      wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value));
     }
   }
 
@@ -498,7 +517,7 @@
     // Produce the entry point outputs, if necessary.
     if (!wrapper_output_values.empty()) {
       if (cfg.shader_style == ShaderStyle::kSpirv) {
-        CreateOutputVariables();
+        CreateSpirvOutputVariables();
       } else {
         auto* output_struct = CreateOutputStruct();
         wrapper_ret_type = [&, output_struct] {
diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc
index e4d8413..1dc93b1 100644
--- a/src/transform/canonicalize_entry_point_io_test.cc
+++ b/src/transform/canonicalize_entry_point_io_test.cc
@@ -491,7 +491,7 @@
 
 [[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var<out> depth_1 : f32;
 
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : u32;
+[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> mask_1 : array<u32, 1>;
 
 struct FragOutput {
   color : vec4<f32>;
@@ -512,7 +512,7 @@
   let inner_result = frag_main_inner();
   color_1 = inner_result.color;
   depth_1 = inner_result.depth;
-  mask_1 = inner_result.mask;
+  mask_1[0] = inner_result.mask;
 }
 )";
 
@@ -2163,6 +2163,42 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) {
+  auto* src = R"(
+[[stage(fragment)]]
+fn main([[builtin(sample_index)]] sample_index : u32,
+        [[builtin(sample_mask)]] mask_in : u32
+        ) -> [[builtin(sample_mask)]] u32 {
+  return mask_in;
+}
+)";
+
+  auto* expect = R"(
+[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
+
+[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1>;
+
+[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1>;
+
+fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
+  return mask_in;
+}
+
+[[stage(fragment)]]
+fn main() {
+  let inner_result = main_inner(sample_index_1, mask_in_1[0]);
+  value[0] = inner_result;
+}
+)";
+
+  DataMap data;
+  data.Add<CanonicalizeEntryPointIO::Config>(
+      CanonicalizeEntryPointIO::ShaderStyle::kSpirv);
+  auto got = Run<CanonicalizeEntryPointIO>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 6bd220a..76fca09 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -68,61 +68,13 @@
 
   ProgramBuilder builder;
   CloneContext ctx(&builder, &transformedInput.program);
-  HandleSampleMaskBuiltins(ctx);
+  // TODO(jrprice): Move the sanitizer into the backend.
   ctx.Clone();
 
   builder.SetTransformApplied(this);
   return Output{Program(std::move(builder))};
 }
 
-void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const {
-  // Find global variables decorated with [[builtin(sample_mask)]] and
-  // change their type from `u32` to `array<u32, 1>`, as required by Vulkan.
-  //
-  // Before:
-  // ```
-  // [[builtin(sample_mask)]] var<out> mask_out : u32;
-  // fn main() {
-  //   mask_out = 1u;
-  // }
-  // ```
-  // After:
-  // ```
-  // [[builtin(sample_mask)]] var<out> mask_out : array<u32, 1>;
-  // fn main() {
-  //   mask_out[0] = 1u;
-  // }
-  // ```
-
-  for (auto* var : ctx.src->AST().GlobalVariables()) {
-    for (auto* deco : var->decorations()) {
-      if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
-        if (builtin->value() != ast::Builtin::kSampleMask) {
-          continue;
-        }
-
-        // Use the same name as the old variable.
-        auto var_name = ctx.Clone(var->symbol());
-        // Use `array<u32, 1>` for the new variable.
-        auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
-        // Create the new variable.
-        auto* var_arr = ctx.dst->Var(var->source(), var_name, type,
-                                     var->declared_storage_class(), nullptr,
-                                     ctx.Clone(var->decorations()));
-        // Replace the variable with the arrayed version.
-        ctx.Replace(var, var_arr);
-
-        // Replace all uses of the old variable with `var_arr[0]`.
-        for (auto* user : ctx.src->Sem().Get(var)->Users()) {
-          auto* new_ident = ctx.dst->IndexAccessor(
-              ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0));
-          ctx.Replace<ast::Expression>(user->Declaration(), new_ident);
-        }
-      }
-    }
-  }
-}
-
 Spirv::Config::Config(bool emit_vps, bool disable_wi)
     : emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {}
 
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index faf5b96..7e633ab 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -67,10 +67,6 @@
   /// @param data optional extra transform-specific input data
   /// @returns the transformation result
   Output Run(const Program* program, const DataMap& data = {}) override;
-
- private:
-  /// Change type of sample mask builtin variables to single element arrays.
-  void HandleSampleMaskBuiltins(CloneContext& ctx) const;
 };
 
 }  // namespace transform
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index 26ed757..e2d4c3c 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -22,140 +22,7 @@
 
 using SpirvTest = TransformTest;
 
-TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
-  auto* src = R"(
-[[stage(fragment)]]
-fn main([[builtin(sample_index)]] sample_index : u32,
-        [[builtin(sample_mask)]] mask_in : u32
-        ) -> [[builtin(sample_mask)]] u32 {
-  return mask_in;
-}
-)";
-
-  auto* expect = R"(
-[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
-
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
-
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
-
-fn main_inner(sample_index : u32, mask_in : u32) -> u32 {
-  return mask_in;
-}
-
-[[stage(fragment)]]
-fn main() {
-  let inner_result = main_inner(sample_index_1, mask_in_1[0]);
-  value[0] = inner_result;
-}
-)";
-
-  auto got = Run<Spirv>(src);
-
-  EXPECT_EQ(expect, str(got));
-}
-
-TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
-  auto* src = R"(
-fn filter(mask: u32) -> u32 {
-  return (mask & 3u);
-}
-
-fn set_mask(input : u32) -> u32 {
-  return input;
-}
-
-[[stage(fragment)]]
-fn main([[builtin(sample_mask)]] mask_in : u32
-        ) -> [[builtin(sample_mask)]] u32 {
-  return set_mask(filter(mask_in));
-}
-)";
-
-  auto* expect = R"(
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
-
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value : array<u32, 1u>;
-
-fn filter(mask : u32) -> u32 {
-  return (mask & 3u);
-}
-
-fn set_mask(input : u32) -> u32 {
-  return input;
-}
-
-fn main_inner(mask_in : u32) -> u32 {
-  return set_mask(filter(mask_in));
-}
-
-[[stage(fragment)]]
-fn main() {
-  let inner_result = main_inner(mask_in_1[0]);
-  value[0] = inner_result;
-}
-)";
-
-  auto got = Run<Spirv>(src);
-
-  EXPECT_EQ(expect, str(got));
-}
-
-// Test that different transforms within the sanitizer interact correctly.
-TEST_F(SpirvTest, MultipleTransforms) {
-  auto* src = R"(
-[[stage(vertex)]]
-fn vert_main() -> [[builtin(position)]] vec4<f32> {
-  return vec4<f32>();
-}
-
-[[stage(fragment)]]
-fn frag_main([[builtin(sample_index)]] sample_index : u32,
-        [[builtin(sample_mask)]] mask_in : u32)
-        -> [[builtin(sample_mask)]] u32 {
-  return mask_in;
-}
-)";
-
-  auto* expect = R"(
-[[builtin(position), internal(disable_validation__ignore_storage_class)]] var<out> value : vec4<f32>;
-
-[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var<out> vertex_point_size : f32;
-
-[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var<in> sample_index_1 : u32;
-
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<in> mask_in_1 : array<u32, 1u>;
-
-[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var<out> value_1 : array<u32, 1u>;
-
-fn vert_main_inner() -> vec4<f32> {
-  return vec4<f32>();
-}
-
-[[stage(vertex)]]
-fn vert_main() {
-  let inner_result = vert_main_inner();
-  value = inner_result;
-  vertex_point_size = 1.0;
-}
-
-fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 {
-  return mask_in;
-}
-
-[[stage(fragment)]]
-fn frag_main() {
-  let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]);
-  value_1[0] = inner_result_1;
-}
-)";
-
-  DataMap data;
-  data.Add<Spirv::Config>(true);
-  auto got = Run<Spirv>(src, data);
-
-  EXPECT_EQ(expect, str(got));
-}
+// TODO(jrprice): Remove this file when we remove the sanitizers.
 
 }  // namespace
 }  // namespace transform