transform::Msl: Rename reserved keywords

This change begins the work to move the reserved keyword remapping out of the writer and into the sanitizer transform.

If the transform::Renamer is in use, then these symbols should never have to be remapped - however for debugging purposes it is often nice to be able to emit code that isn't entirely mangled.

The logic in the msl writer will be removed as a followup change

Bug: tint:273
Change-Id: I76af03ff80388a48d9dd80a5b5fdfe21f3c8e7a0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/43982
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 8ec5e25..3f2df43 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -975,6 +975,7 @@
 
 source_set("tint_unittests_spv_writer_src") {
   sources = [
+    "src/transform/msl_test.cc",
     "src/transform/spirv_test.cc",
     "src/writer/spirv/binary_writer_test.cc",
     "src/writer/spirv/builder_accessor_expression_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b18564c..0093194 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -638,6 +638,7 @@
 
   if(${TINT_BUILD_SPV_WRITER})
     list(APPEND TINT_TEST_SRCS
+      transform/msl_test.cc
       transform/spirv_test.cc
       writer/spirv/binary_writer_test.cc
       writer/spirv/builder_accessor_expression_test.cc
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index bb6d8e9..a089fda 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -20,13 +20,254 @@
 
 namespace tint {
 namespace transform {
+namespace {
+const char* kReservedKeywords[] = {"access",
+                                   "alignas",
+                                   "alignof",
+                                   "and",
+                                   "and_eq",
+                                   "array",
+                                   "array_ref",
+                                   "as_type",
+                                   "asm",
+                                   "atomic",
+                                   "atomic_bool",
+                                   "atomic_int",
+                                   "atomic_uint",
+                                   "auto",
+                                   "bitand",
+                                   "bitor",
+                                   "bool",
+                                   "bool2",
+                                   "bool3",
+                                   "bool4",
+                                   "break",
+                                   "buffer",
+                                   "case",
+                                   "catch",
+                                   "char",
+                                   "char16_t",
+                                   "char2",
+                                   "char3",
+                                   "char32_t",
+                                   "char4",
+                                   "class",
+                                   "compl",
+                                   "const",
+                                   "const_cast",
+                                   "const_reference",
+                                   "constant",
+                                   "constexpr",
+                                   "continue",
+                                   "decltype",
+                                   "default",
+                                   "delete",
+                                   "depth2d",
+                                   "depth2d_array",
+                                   "depth2d_ms",
+                                   "depth2d_ms_array",
+                                   "depthcube",
+                                   "depthcube_array",
+                                   "device",
+                                   "discard_fragment",
+                                   "do",
+                                   "double",
+                                   "dynamic_cast",
+                                   "else",
+                                   "enum",
+                                   "explicit",
+                                   "extern",
+                                   "false",
+                                   "final",
+                                   "float",
+                                   "float2",
+                                   "float2x2",
+                                   "float2x3",
+                                   "float2x4",
+                                   "float3",
+                                   "float3x2",
+                                   "float3x3",
+                                   "float3x4",
+                                   "float4",
+                                   "float4x2",
+                                   "float4x3",
+                                   "float4x4",
+                                   "for",
+                                   "fragment",
+                                   "friend",
+                                   "goto",
+                                   "half",
+                                   "half2",
+                                   "half2x2",
+                                   "half2x3",
+                                   "half2x4",
+                                   "half3",
+                                   "half3x2",
+                                   "half3x3",
+                                   "half3x4",
+                                   "half4",
+                                   "half4x2",
+                                   "half4x3",
+                                   "half4x4",
+                                   "if",
+                                   "imageblock",
+                                   "inline",
+                                   "int",
+                                   "int16_t",
+                                   "int2",
+                                   "int3",
+                                   "int32_t",
+                                   "int4",
+                                   "int64_t",
+                                   "int8_t",
+                                   "kernel",
+                                   "long",
+                                   "long2",
+                                   "long3",
+                                   "long4",
+                                   "main",
+                                   "metal",
+                                   "mutable",
+                                   "namespace",
+                                   "new",
+                                   "noexcept",
+                                   "not",
+                                   "not_eq",
+                                   "nullptr",
+                                   "operator",
+                                   "or",
+                                   "or_eq",
+                                   "override",
+                                   "packed_bool2",
+                                   "packed_bool3",
+                                   "packed_bool4",
+                                   "packed_char2",
+                                   "packed_char3",
+                                   "packed_char4",
+                                   "packed_float2",
+                                   "packed_float3",
+                                   "packed_float4",
+                                   "packed_half2",
+                                   "packed_half3",
+                                   "packed_half4",
+                                   "packed_int2",
+                                   "packed_int3",
+                                   "packed_int4",
+                                   "packed_short2",
+                                   "packed_short3",
+                                   "packed_short4",
+                                   "packed_uchar2",
+                                   "packed_uchar3",
+                                   "packed_uchar4",
+                                   "packed_uint2",
+                                   "packed_uint3",
+                                   "packed_uint4",
+                                   "packed_ushort2",
+                                   "packed_ushort3",
+                                   "packed_ushort4",
+                                   "patch_control_point",
+                                   "private",
+                                   "protected",
+                                   "ptrdiff_t",
+                                   "public",
+                                   "r16snorm",
+                                   "r16unorm",
+                                   "r8unorm",
+                                   "reference",
+                                   "register",
+                                   "reinterpret_cast",
+                                   "return",
+                                   "rg11b10f",
+                                   "rg16snorm",
+                                   "rg16unorm",
+                                   "rg8snorm",
+                                   "rg8unorm",
+                                   "rgb10a2",
+                                   "rgb9e5",
+                                   "rgba16snorm",
+                                   "rgba16unorm",
+                                   "rgba8snorm",
+                                   "rgba8unorm",
+                                   "sampler",
+                                   "short",
+                                   "short2",
+                                   "short3",
+                                   "short4",
+                                   "signed",
+                                   "size_t",
+                                   "sizeof",
+                                   "srgba8unorm",
+                                   "static",
+                                   "static_assert",
+                                   "static_cast",
+                                   "struct",
+                                   "switch",
+                                   "template",
+                                   "texture",
+                                   "texture1d",
+                                   "texture1d_array",
+                                   "texture2d",
+                                   "texture2d_array",
+                                   "texture2d_ms",
+                                   "texture2d_ms_array",
+                                   "texture3d",
+                                   "texture_buffer",
+                                   "texturecube",
+                                   "texturecube_array",
+                                   "this",
+                                   "thread",
+                                   "thread_local",
+                                   "threadgroup",
+                                   "threadgroup_imageblock",
+                                   "throw",
+                                   "true",
+                                   "try",
+                                   "typedef",
+                                   "typeid",
+                                   "typename",
+                                   "uchar",
+                                   "uchar2",
+                                   "uchar3",
+                                   "uchar4",
+                                   "uint",
+                                   "uint16_t",
+                                   "uint2",
+                                   "uint3",
+                                   "uint32_t",
+                                   "uint4",
+                                   "uint64_t",
+                                   "uint8_t",
+                                   "ulong2",
+                                   "ulong3",
+                                   "ulong4",
+                                   "uniform",
+                                   "union",
+                                   "unsigned",
+                                   "ushort",
+                                   "ushort2",
+                                   "ushort3",
+                                   "ushort4",
+                                   "using",
+                                   "vec",
+                                   "vertex",
+                                   "virtual",
+                                   "void",
+                                   "volatile",
+                                   "wchar_t",
+                                   "while",
+                                   "xor",
+                                   "xor_eq"};
+}  // namespace
 
 Msl::Msl() = default;
 Msl::~Msl() = default;
 
 Transform::Output Msl::Run(const Program* in) {
   ProgramBuilder out;
-  CloneContext(&out, in).Clone();
+  CloneContext ctx(&out, in);
+  RenameReservedKeywords(&ctx, kReservedKeywords);
+  ctx.Clone();
+
   return Output{Program(std::move(out))};
 }
 
diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc
new file mode 100644
index 0000000..5390ab2
--- /dev/null
+++ b/src/transform/msl_test.cc
@@ -0,0 +1,334 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/msl.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using MslReservedKeywordTest = TransformTestWithParam<std::string>;
+
+TEST_F(MslReservedKeywordTest, Basic) {
+  auto* src = R"(
+struct class {
+  delete : i32;
+};
+
+[[stage(fragment)]]
+fn main() -> void {
+  var foo : i32;
+  var half : f32;
+  var half1 : f32;
+  var half2 : f32;
+  var _tint_half2 : f32;
+}
+)";
+
+  auto* expect = R"(
+struct _tint_class {
+  _tint_delete : i32;
+};
+
+[[stage(fragment)]]
+fn _tint_main() -> void {
+  var foo : i32;
+  var _tint_half : f32;
+  var half1 : f32;
+  var _tint_half2_0 : f32;
+  var _tint_half2 : f32;
+}
+)";
+
+  auto got = Transform<Msl>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_P(MslReservedKeywordTest, Keywords) {
+  auto keyword = GetParam();
+
+  auto src = R"(
+[[stage(fragment)]]
+fn main() -> void {
+  var )" + keyword +
+             R"( : i32;
+}
+)";
+
+  auto expect = R"(
+[[stage(fragment)]]
+fn _tint_main() -> void {
+  var _tint_)" + keyword +
+                R"( : i32;
+}
+)";
+
+  auto got = Transform<Msl>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+INSTANTIATE_TEST_SUITE_P(MslReservedKeywordTest,
+                         MslReservedKeywordTest,
+                         testing::Values(
+                             // c++14 spec
+                             "alignas",
+                             "alignof",
+                             "and",
+                             "and_eq",
+                             // "asm",  // Also reserved in WGSL
+                             "auto",
+                             "bitand",
+                             "bitor",
+                             // "bool",   // Also used in WGSL
+                             // "break",  // Also used in WGSL
+                             // "case",   // Also used in WGSL
+                             "catch",
+                             "char",
+                             "char16_t",
+                             "char32_t",
+                             "class",
+                             "compl",
+                             // "const",     // Also used in WGSL
+                             "const_cast",
+                             "constexpr",
+                             // "continue",  // Also used in WGSL
+                             "decltype",
+                             // "default",   // Also used in WGSL
+                             "delete",
+                             // "do",  // Also used in WGSL
+                             "double",
+                             "dynamic_cast",
+                             // "else",  // Also used in WGSL
+                             // "enum",  // Also used in WGSL
+                             "explicit",
+                             "extern",
+                             // "false",  // Also used in WGSL
+                             "final",
+                             "float",
+                             // "for",  // Also used in WGSL
+                             "friend",
+                             "goto",
+                             // "if",  // Also used in WGSL
+                             "inline",
+                             "int",
+                             "long",
+                             "mutable",
+                             "namespace",
+                             "new",
+                             "noexcept",
+                             "not",
+                             "not_eq",
+                             "nullptr",
+                             "operator",
+                             "or",
+                             "or_eq",
+                             "override",
+                             // "private",  // Also used in WGSL
+                             "protected",
+                             "public",
+                             "register",
+                             "reinterpret_cast",
+                             // "return",  // Also used in WGSL
+                             "short",
+                             "signed",
+                             "sizeof",
+                             "static",
+                             "static_assert",
+                             "static_cast",
+                             // "struct",  // Also used in WGSL
+                             // "switch",  // Also used in WGSL
+                             "template",
+                             "this",
+                             "thread_local",
+                             "throw",
+                             // "true",  // Also used in WGSL
+                             "try",
+                             // "typedef",  // Also used in WGSL
+                             "typeid",
+                             "typename",
+                             "union",
+                             "unsigned",
+                             "using",
+                             "virtual",
+                             // "void",  // Also used in WGSL
+                             "volatile",
+                             "wchar_t",
+                             "while",
+                             "xor",
+                             "xor_eq",
+
+                             // Metal Spec
+                             "access",
+                             // "array",  // Also used in WGSL
+                             "array_ref",
+                             "as_type",
+                             "atomic",
+                             "atomic_bool",
+                             "atomic_int",
+                             "atomic_uint",
+                             "bool2",
+                             "bool3",
+                             "bool4",
+                             "buffer",
+                             "char2",
+                             "char3",
+                             "char4",
+                             "const_reference",
+                             "constant",
+                             "depth2d",
+                             "depth2d_array",
+                             "depth2d_ms",
+                             "depth2d_ms_array",
+                             "depthcube",
+                             "depthcube_array",
+                             "device",
+                             "discard_fragment",
+                             "float2",
+                             "float2x2",
+                             "float2x3",
+                             "float2x4",
+                             "float3",
+                             "float3x2",
+                             "float3x3",
+                             "float3x4",
+                             "float4",
+                             "float4x2",
+                             "float4x3",
+                             "float4x4",
+                             "fragment",
+                             "half",
+                             "half2",
+                             "half2x2",
+                             "half2x3",
+                             "half2x4",
+                             "half3",
+                             "half3x2",
+                             "half3x3",
+                             "half3x4",
+                             "half4",
+                             "half4x2",
+                             "half4x3",
+                             "half4x4",
+                             "imageblock",
+                             "int16_t",
+                             "int2",
+                             "int3",
+                             "int32_t",
+                             "int4",
+                             "int64_t",
+                             "int8_t",
+                             "kernel",
+                             "long2",
+                             "long3",
+                             "long4",
+                             "main",   // No functions called main
+                             "metal",  // The namespace
+                             "packed_bool2",
+                             "packed_bool3",
+                             "packed_bool4",
+                             "packed_char2",
+                             "packed_char3",
+                             "packed_char4",
+                             "packed_float2",
+                             "packed_float3",
+                             "packed_float4",
+                             "packed_half2",
+                             "packed_half3",
+                             "packed_half4",
+                             "packed_int2",
+                             "packed_int3",
+                             "packed_int4",
+                             "packed_short2",
+                             "packed_short3",
+                             "packed_short4",
+                             "packed_uchar2",
+                             "packed_uchar3",
+                             "packed_uchar4",
+                             "packed_uint2",
+                             "packed_uint3",
+                             "packed_uint4",
+                             "packed_ushort2",
+                             "packed_ushort3",
+                             "packed_ushort4",
+                             "patch_control_point",
+                             "ptrdiff_t",
+                             "r16snorm",
+                             "r16unorm",
+                             // "r8unorm",  // Also used in WGSL
+                             "reference",
+                             "rg11b10f",
+                             "rg16snorm",
+                             "rg16unorm",
+                             // "rg8snorm",  // Also used in WGSL
+                             // "rg8unorm",  // Also used in WGSL
+                             "rgb10a2",
+                             "rgb9e5",
+                             "rgba16snorm",
+                             "rgba16unorm",
+                             // "rgba8snorm",  // Also used in WGSL
+                             // "rgba8unorm",  // Also used in WGSL
+                             // "sampler",  // Also used in WGSL
+                             "short2",
+                             "short3",
+                             "short4",
+                             "size_t",
+                             "srgba8unorm",
+                             "texture",
+                             "texture1d",
+                             "texture1d_array",
+                             "texture2d",
+                             "texture2d_array",
+                             "texture2d_ms",
+                             "texture2d_ms_array",
+                             "texture3d",
+                             "texture_buffer",
+                             "texturecube",
+                             "texturecube_array",
+                             "thread",
+                             "threadgroup",
+                             "threadgroup_imageblock",
+                             "uchar",
+                             "uchar2",
+                             "uchar3",
+                             "uchar4",
+                             "uint",
+                             "uint16_t",
+                             "uint2",
+                             "uint3",
+                             "uint32_t",
+                             "uint4",
+                             "uint64_t",
+                             "uint8_t",
+                             "ulong2",
+                             "ulong3",
+                             "ulong4",
+                             // "uniform",  // Also used in WGSL
+                             "ushort",
+                             "ushort2",
+                             "ushort3",
+                             "ushort4",
+                             "vec",
+                             "vertex"));
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h
index ae78d51..191f1d8 100644
--- a/src/transform/test_helper.h
+++ b/src/transform/test_helper.h
@@ -32,7 +32,8 @@
 namespace transform {
 
 /// Helper class for testing transforms
-class TransformTest : public testing::Test {
+template <typename BASE>
+class TransformTestBase : public BASE {
  public:
   /// Transforms and returns the WGSL source `in`, transformed using
   /// `transforms`.
@@ -42,8 +43,11 @@
   Transform::Output Transform(
       std::string in,
       std::vector<std::unique_ptr<transform::Transform>> transforms) {
-    Source::File file("test", in);
-    auto program = reader::wgsl::Parse(&file);
+    auto file = std::make_unique<Source::File>("test", in);
+    auto program = reader::wgsl::Parse(file.get());
+
+    // Keep this pointer alive after Transform() returns
+    files_.emplace_back(std::move(file));
 
     if (!program.IsValid()) {
       return Transform::Output(std::move(program));
@@ -108,8 +112,16 @@
     }
     return "\n" + res + "\n";
   }
+
+ private:
+  std::vector<std::unique_ptr<Source::File>> files_;
 };
 
+using TransformTest = TransformTestBase<testing::Test>;
+
+template <typename T>
+using TransformTestWithParam = TransformTestBase<testing::TestWithParam<T>>;
+
 }  // namespace transform
 }  // namespace tint
 
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index 36dbc8f..e0ab4c9 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -14,6 +14,8 @@
 
 #include "src/transform/transform.h"
 
+#include <algorithm>
+
 #include "src/ast/block_statement.h"
 #include "src/ast/function.h"
 #include "src/clone_context.h"
@@ -63,5 +65,23 @@
                                          body, decos);
 }
 
+void Transform::RenameReservedKeywords(CloneContext* ctx,
+                                       const char* names[],
+                                       size_t count) {
+  ctx->ReplaceAll([=](Symbol in) {
+    auto name_in = ctx->src->Symbols().NameFor(in);
+    if (!std::binary_search(names, names + count, name_in)) {
+      return ctx->dst->Symbols().Register(name_in);
+    }
+    // Create a new unique name
+    auto base_name = "_tint_" + name_in;
+    auto name_out = base_name;
+    for (int i = 0; ctx->src->Symbols().Get(name_out).IsValid(); i++) {
+      name_out = base_name + "_" + std::to_string(i);
+    }
+    return ctx->dst->Symbols().Register(name_out);
+  });
+}
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 47a950a..c82bcf6 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -154,6 +154,25 @@
       CloneContext* ctx,
       ast::Function* in,
       ast::StatementList statements);
+
+  /// Registers a symbol renamer on `ctx` for any symbol that is found in the
+  /// list of reserved identifiers.
+  /// @param ctx the clone context
+  /// @param names the lexicographically sorted list of reserved identifiers
+  /// @param count the number of identifiers in the array `names`
+  static void RenameReservedKeywords(CloneContext* ctx,
+                                     const char* names[],
+                                     size_t count);
+
+  /// Registers a symbol renamer on `ctx` for any symbol that is found in the
+  /// list of reserved identifiers.
+  /// @param ctx the clone context
+  /// @param names the lexicographically sorted list of reserved identifiers
+  template <size_t N>
+  static void RenameReservedKeywords(CloneContext* ctx,
+                                     const char* (&names)[N]) {
+    RenameReservedKeywords(ctx, names, N);
+  }
 };
 
 }  // namespace transform