Automatically clone all symbols in CloneContext()

Almost all transforms should clone all symbols before doing any work,
to avoid any newly created symbols clashing with existing symbols the
source program and causing them to be renamed.

The Renamer is the exception to this, and so an optional flag is used
to prevent automatic cloning of symbols for this transform.

Bug: dawn:758
Change-Id: I84527a352825b2eaa43eabe225beb9e0999bf048
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48000
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 587e593..9023660 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -26,8 +26,18 @@
 CloneContext::ListTransforms::ListTransforms() = default;
 CloneContext::ListTransforms::~ListTransforms() = default;
 
-CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
-    : dst(to), src(from) {}
+CloneContext::CloneContext(ProgramBuilder* to,
+                           Program const* from,
+                           bool auto_clone_symbols)
+    : dst(to), src(from) {
+  if (auto_clone_symbols) {
+    // Almost all transforms will want to clone all symbols before doing any
+    // work, to avoid any newly created symbols clashing with existing symbols
+    // in the source program and causing them to be renamed.
+    from->Symbols().Foreach([&](Symbol s, const std::string&) { Clone(s); });
+  }
+}
+
 CloneContext::~CloneContext() = default;
 
 Symbol CloneContext::Clone(Symbol s) {
@@ -39,11 +49,6 @@
   });
 }
 
-CloneContext& CloneContext::CloneSymbols() {
-  src->Symbols().Foreach([&](Symbol s, const std::string&) { Clone(s); });
-  return *this;
-}
-
 void CloneContext::Clone() {
   dst->AST().Copy(this, &src->AST());
 }
diff --git a/src/clone_context.h b/src/clone_context.h
index 1243e34..2ff66ca 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -63,7 +63,10 @@
   /// Constructor
   /// @param to the target ProgramBuilder to clone into
   /// @param from the source Program to clone from
-  CloneContext(ProgramBuilder* to, Program const* from);
+  /// @param auto_clone_symbols clone all symbols in `from` before returning
+  CloneContext(ProgramBuilder* to,
+               Program const* from,
+               bool auto_clone_symbols = true);
 
   /// Destructor
   ~CloneContext();
@@ -150,12 +153,6 @@
   /// @return the cloned source
   Symbol Clone(Symbol s);
 
-  /// Clones all the Symbols in `src->Symbols()` into #dst.
-  /// This may be used to ensure that authored symbols are not suffixed with a
-  /// unique identifier if they collide with other symbols.
-  /// @returns this CloneContext so calls can be chained
-  CloneContext& CloneSymbols();
-
   /// Clones each of the elements of the vector `v` into the ProgramBuilder
   /// #dst.
   ///
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index ea6f902..c64ca9f 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -224,7 +224,7 @@
   // N: Node
 
   ProgramBuilder cloned;
-  auto* cloned_root = CloneContext(&cloned, &original)
+  auto* cloned_root = CloneContext(&cloned, &original, false)
                           .ReplaceAll([&](Symbol sym) {
                             auto in = original.Symbols().NameFor(sym);
                             auto out = "transformed<" + in + ">";
@@ -435,7 +435,7 @@
   Program original(std::move(builder));
 
   ProgramBuilder cloned;
-  CloneContext ctx(&cloned, &original);
+  CloneContext ctx(&cloned, &original, false);
   Symbol new_x = cloned.Symbols().New();
   Symbol new_a = ctx.Clone(old_a);
   Symbol new_y = cloned.Symbols().New();
@@ -463,7 +463,7 @@
   Program original(std::move(builder));
 
   ProgramBuilder cloned;
-  CloneContext ctx(&cloned, &original);
+  CloneContext ctx(&cloned, &original, false);
   Symbol new_x = cloned.Symbols().New("a");
   Symbol new_a = ctx.Clone(old_a);
   Symbol new_y = cloned.Symbols().New("b");
@@ -492,7 +492,6 @@
 
   ProgramBuilder cloned;
   CloneContext ctx(&cloned, &original);
-  ctx.CloneSymbols();
   Symbol new_x = cloned.Symbols().New("a");
   Symbol new_a = ctx.Clone(old_a);
   Symbol new_y = cloned.Symbols().New("b");
diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc
index 3f6416b..e66866d 100644
--- a/src/transform/bound_array_accessors.cc
+++ b/src/transform/bound_array_accessors.cc
@@ -29,11 +29,6 @@
 Output BoundArrayAccessors::Run(const Program* in, const DataMap&) {
   ProgramBuilder out;
   CloneContext ctx(&out, in);
-
-  // Start by cloning all the symbols. This ensures that the authored symbols
-  // won't get renamed if they collide with new symbols below.
-  ctx.CloneSymbols();
-
   ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
     return Transform(expr, &ctx);
   });
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index 622b6b2..06bb158 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -62,10 +62,6 @@
   ProgramBuilder out;
   CloneContext ctx(&out, in);
 
-  // Start by cloning all the symbols. This ensures that the authored symbols
-  // won't get renamed if they collide with new symbols below.
-  ctx.CloneSymbols();
-
   // Strip entry point IO decorations from struct declarations.
   // TODO(jrprice): This code is duplicated with the SPIR-V transform.
   for (auto* ty : ctx.src->AST().ConstructedTypes()) {
diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc
index d154514..a5c9e09 100644
--- a/src/transform/decompose_storage_access.cc
+++ b/src/transform/decompose_storage_access.cc
@@ -613,10 +613,6 @@
   ProgramBuilder out;
   CloneContext ctx(&out, in);
 
-  // Start by cloning all the symbols. This ensures that the authored symbols
-  // won't get renamed if they collide with new symbols below.
-  ctx.CloneSymbols();
-
   auto& sem = ctx.src->Sem();
 
   State state;
diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc
index 676d400..048d23a 100644
--- a/src/transform/emit_vertex_point_size.cc
+++ b/src/transform/emit_vertex_point_size.cc
@@ -35,10 +35,6 @@
 
   CloneContext ctx(&out, in);
 
-  // Start by cloning all the symbols. This ensures that the authored symbols
-  // won't get renamed if they collide with new symbols below.
-  ctx.CloneSymbols();
-
   Symbol pointsize = out.Symbols().New("tint_pointsize");
 
   // Declare the pointsize builtin output variable.
diff --git a/src/transform/renamer.cc b/src/transform/renamer.cc
index 9ed4743..a5a6856 100644
--- a/src/transform/renamer.cc
+++ b/src/transform/renamer.cc
@@ -848,7 +848,8 @@
 
 Output Renamer::Run(const Program* in, const DataMap&) {
   ProgramBuilder out;
-  CloneContext ctx(&out, in);
+  // Disable auto-cloning of symbols, since we want to rename them.
+  CloneContext ctx(&out, in, false);
 
   // Swizzles and intrinsic calls need to keep their symbols preserved.
   std::unordered_set<ast::IdentifierExpression*> preserve;
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index a34acf3..7b5d1a2 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -414,10 +414,6 @@
 
   CloneContext ctx(&out, in);
 
-  // Start by cloning all the symbols. This ensures that the authored symbols
-  // won't get renamed if they collide with new symbols below.
-  ctx.CloneSymbols();
-
   State state{ctx, cfg};
   state.FindOrInsertVertexIndexIfUsed();
   state.FindOrInsertInstanceIndexIfUsed();
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index cdadf42..c3fa8c4 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -126,10 +126,10 @@
   float value : SV_Target1;
 };
 
-tint_symbol_2 frag_main(tint_symbol_1 tint_symbol_3) {
-  const float foo = tint_symbol_3.foo;
-  const tint_symbol_2 tint_symbol = {foo};
-  return tint_symbol;
+tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
+  const float foo = tint_symbol.foo;
+  const tint_symbol_2 tint_symbol_3 = {foo};
+  return tint_symbol_3;
 }
 
 )");
@@ -160,10 +160,10 @@
   float value : SV_Depth;
 };
 
-tint_symbol_2 frag_main(tint_symbol_1 tint_symbol_3) {
-  const float4 coord = tint_symbol_3.coord;
-  const tint_symbol_2 tint_symbol = {coord.x};
-  return tint_symbol;
+tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
+  const float4 coord = tint_symbol.coord;
+  const tint_symbol_2 tint_symbol_3 = {coord.x};
+  return tint_symbol_3;
 }
 
 )");
@@ -210,7 +210,7 @@
   float col1;
   float col2;
 };
-struct tint_symbol_1 {
+struct tint_symbol {
   float col1 : TEXCOORD1;
   float col2 : TEXCOORD2;
 };
@@ -219,10 +219,10 @@
   float col2 : TEXCOORD2;
 };
 
-tint_symbol_1 vert_main() {
-  const Interface tint_symbol_1_1 = {0.5f, 0.25f};
-  const tint_symbol_1 tint_symbol = {tint_symbol_1_1.col1, tint_symbol_1_1.col2};
-  return tint_symbol;
+tint_symbol vert_main() {
+  const Interface tint_symbol_1 = {0.5f, 0.25f};
+  const tint_symbol tint_symbol_4 = {tint_symbol_1.col1, tint_symbol_1.col2};
+  return tint_symbol_4;
 }
 
 void frag_main(tint_symbol_3 tint_symbol_2) {
@@ -278,28 +278,28 @@
   EXPECT_EQ(result(), R"(struct VertexOutput {
   float4 pos;
 };
-struct tint_symbol_2 {
+struct tint_symbol {
   float4 pos : SV_Position;
 };
-struct tint_symbol_2_1 {
+struct tint_symbol_2 {
   float4 pos : SV_Position;
 };
 
 VertexOutput foo(float x) {
-  const VertexOutput tint_symbol = {float4(x, x, x, 1.0f)};
-  return tint_symbol;
+  const VertexOutput tint_symbol_4 = {float4(x, x, x, 1.0f)};
+  return tint_symbol_4;
 }
 
-tint_symbol_2 vert_main1() {
-  const VertexOutput tint_symbol_1_1 = {foo(0.5f)};
-  const tint_symbol_2 tint_symbol_1 = {tint_symbol_1_1.pos};
-  return tint_symbol_1;
+tint_symbol vert_main1() {
+  const VertexOutput tint_symbol_1 = {foo(0.5f)};
+  const tint_symbol tint_symbol_5 = {tint_symbol_1.pos};
+  return tint_symbol_5;
 }
 
-tint_symbol_2_1 vert_main2() {
-  const VertexOutput tint_symbol_3_1 = {foo(0.25f)};
-  const tint_symbol_2_1 tint_symbol_3 = {tint_symbol_3_1.pos};
-  return tint_symbol_3;
+tint_symbol_2 vert_main2() {
+  const VertexOutput tint_symbol_3 = {foo(0.25f)};
+  const tint_symbol_2 tint_symbol_6 = {tint_symbol_3.pos};
+  return tint_symbol_6;
 }
 
 )");