Rework the FirstIndexOffset transform

So that it transforms more on clone than in-place.

Bug: dawn:548
Bug: tint:390
Change-Id: I0127bc02c4e0e88c924042c491d274363422cc52
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35420
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/module.h b/src/ast/module.h
index 6b08b4f..59a67b0 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -56,6 +56,10 @@
   Module Clone();
 
   /// Clone this module into `ctx->mod` using the provided CloneContext
+  /// The module will be cloned in this order:
+  /// * Constructed types
+  /// * Global variables
+  /// * Functions
   /// @param ctx the clone context
   void Clone(CloneContext* ctx);
 
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index fcca0c6..37014d0 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -25,6 +25,7 @@
 #include "src/ast/builtin_decoration.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/case_statement.h"
+#include "src/ast/clone_context.h"
 #include "src/ast/constructor_expression.h"
 #include "src/ast/decorated_variable.h"
 #include "src/ast/else_statement.h"
@@ -61,6 +62,20 @@
 constexpr char kFirstInstanceName[] = "tint_first_instance_index";
 constexpr char kIndexOffsetPrefix[] = "tint_first_index_offset_";
 
+ast::DecoratedVariable* clone_variable_with_new_name(ast::CloneContext* ctx,
+                                                     ast::DecoratedVariable* in,
+                                                     std::string new_name) {
+  auto* var = ctx->mod->create<ast::Variable>(ctx->Clone(in->source()),
+                                              new_name, in->storage_class(),
+                                              ctx->Clone(in->type()));
+  var->set_is_const(in->is_const());
+  var->set_constructor(ctx->Clone(in->constructor()));
+
+  auto* out = ctx->mod->create<ast::DecoratedVariable>(var);
+  out->set_decorations(ctx->Clone(in->decorations()));
+  return out;
+}
+
 }  // namespace
 
 FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t set)
@@ -69,17 +84,29 @@
 FirstIndexOffset::~FirstIndexOffset() = default;
 
 Transform::Output FirstIndexOffset::Run(ast::Module* in) {
-  Output out;
-  out.module = in->Clone();
-  auto* mod = &out.module;
+  // First do a quick check to see if the transform has already been applied.
+  for (ast::Variable* var : in->global_variables()) {
+    if (auto* dec_var = var->As<ast::DecoratedVariable>()) {
+      if (dec_var->name() == kBufferName) {
+        diag::Diagnostic err;
+        err.message = "First index offset transform has already been applied.";
+        err.severity = diag::Severity::Error;
+        Output out;
+        out.diagnostics.add(std::move(err));
+        return out;
+      }
+    }
+  }
 
   // Running TypeDeterminer as we require local_referenced_builtin_variables()
-  // to be populated
-  TypeDeterminer td(mod);
+  // to be populated. TODO(bclayton) - it should not be necessary to re-run the
+  // type determiner if semantic information is already generated. Remove.
+  TypeDeterminer td(in);
   if (!td.Determine()) {
     diag::Diagnostic err;
     err.severity = diag::Severity::Error;
     err.message = td.error();
+    Output out;
     out.diagnostics.add(std::move(err));
     return out;
   }
@@ -87,51 +114,69 @@
   std::string vertex_index_name;
   std::string instance_index_name;
 
-  for (ast::Variable* var : mod->global_variables()) {
-    if (auto* dec_var = var->As<ast::DecoratedVariable>()) {
-      if (dec_var->name() == kBufferName) {
-        diag::Diagnostic err;
-        err.message = "First index offset transform has already been applied.";
-        err.severity = diag::Severity::Error;
-        out.diagnostics.add(std::move(err));
-        return out;
-      }
+  Output out;
 
-      for (ast::VariableDecoration* dec : dec_var->decorations()) {
-        if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
-          ast::Builtin blt_type = blt_dec->value();
-          if (blt_type == ast::Builtin::kVertexIdx) {
-            vertex_index_name = var->name();
-            var->set_name(kIndexOffsetPrefix + var->name());
-            has_vertex_index_ = true;
-          } else if (blt_type == ast::Builtin::kInstanceIdx) {
-            instance_index_name = var->name();
-            var->set_name(kIndexOffsetPrefix + var->name());
-            has_instance_index_ = true;
-          }
+  // Lazilly construct the UniformBuffer on first call to
+  // maybe_create_buffer_var()
+  ast::Variable* buffer_var = nullptr;
+  auto maybe_create_buffer_var = [&] {
+    if (buffer_var == nullptr) {
+      buffer_var = AddUniformBuffer(&out.module);
+    }
+  };
+
+  // Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add
+  // a CreateFirstIndexOffset() statement to each function that uses one of
+  // these builtins.
+  ast::CloneContext ctx(&out.module);
+  ctx.ReplaceAll([&](ast::DecoratedVariable* var) -> ast::DecoratedVariable* {
+    for (ast::VariableDecoration* dec : var->decorations()) {
+      if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
+        ast::Builtin blt_type = blt_dec->value();
+        if (blt_type == ast::Builtin::kVertexIdx) {
+          vertex_index_name = var->name();
+          has_vertex_index_ = true;
+          return clone_variable_with_new_name(&ctx, var,
+                                              kIndexOffsetPrefix + var->name());
+        } else if (blt_type == ast::Builtin::kInstanceIdx) {
+          instance_index_name = var->name();
+          has_instance_index_ = true;
+          return clone_variable_with_new_name(&ctx, var,
+                                              kIndexOffsetPrefix + var->name());
         }
       }
     }
-  }
+    return nullptr;  // Just clone var
+  });
+  ctx.ReplaceAll(  // Note: This happens in the same pass as the rename above
+                   // which determines the original builtin variable names,
+                   // but this should be fine, as variables are cloned first.
+      [&](ast::Function* func) -> ast::Function* {
+        maybe_create_buffer_var();
+        if (buffer_var == nullptr) {
+          return nullptr;  // no transform need, just clone func
+        }
+        auto* body = ctx.mod->create<ast::BlockStatement>(
+            ctx.Clone(func->body()->source()));
+        for (const auto& data : func->local_referenced_builtin_variables()) {
+          if (data.second->value() == ast::Builtin::kVertexIdx) {
+            body->append(CreateFirstIndexOffset(
+                vertex_index_name, kFirstVertexName, buffer_var, ctx.mod));
+          } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
+            body->append(CreateFirstIndexOffset(
+                instance_index_name, kFirstInstanceName, buffer_var, ctx.mod));
+          }
+        }
+        for (auto* s : *func->body()) {
+          body->append(ctx.Clone(s));
+        }
+        return ctx.mod->create<ast::Function>(
+            ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()),
+            ctx.Clone(func->return_type()), ctx.Clone(body),
+            ctx.Clone(func->decorations()));
+      });
 
-  if (!has_vertex_index_ && !has_instance_index_) {
-    return out;
-  }
-
-  ast::Variable* buffer_var = AddUniformBuffer(mod);
-
-  for (ast::Function* func : mod->functions()) {
-    for (const auto& data : func->local_referenced_builtin_variables()) {
-      if (data.second->value() == ast::Builtin::kVertexIdx) {
-        AddFirstIndexOffset(vertex_index_name, kFirstVertexName, buffer_var,
-                            func, mod);
-      } else if (data.second->value() == ast::Builtin::kInstanceIdx) {
-        AddFirstIndexOffset(instance_index_name, kFirstInstanceName, buffer_var,
-                            func, mod);
-      }
-    }
-  }
-
+  in->Clone(&ctx);
   return out;
 }
 
@@ -187,12 +232,10 @@
   auto* idx_var =
       mod->create<ast::DecoratedVariable>(mod->create<ast::Variable>(
           Source{}, kBufferName, ast::StorageClass::kUniform, struct_type));
-
-  ast::VariableDecorationList decorations;
-  decorations.push_back(
-      mod->create<ast::BindingDecoration>(binding_, Source{}));
-  decorations.push_back(mod->create<ast::SetDecoration>(set_, Source{}));
-  idx_var->set_decorations(std::move(decorations));
+  idx_var->set_decorations({
+      mod->create<ast::BindingDecoration>(binding_, Source{}),
+      mod->create<ast::SetDecoration>(set_, Source{}),
+  });
 
   mod->AddGlobalVariable(idx_var);
 
@@ -201,11 +244,11 @@
   return idx_var;
 }
 
-void FirstIndexOffset::AddFirstIndexOffset(const std::string& original_name,
-                                           const std::string& field_name,
-                                           ast::Variable* buffer_var,
-                                           ast::Function* func,
-                                           ast::Module* mod) {
+ast::VariableDeclStatement* FirstIndexOffset::CreateFirstIndexOffset(
+    const std::string& original_name,
+    const std::string& field_name,
+    ast::Variable* buffer_var,
+    ast::Module* mod) {
   auto* buffer = mod->create<ast::IdentifierExpression>(buffer_var->name());
   auto* var = mod->create<ast::Variable>(Source{}, original_name,
                                          ast::StorageClass::kNone,
@@ -217,8 +260,7 @@
       mod->create<ast::IdentifierExpression>(kIndexOffsetPrefix + var->name()),
       mod->create<ast::MemberAccessorExpression>(
           buffer, mod->create<ast::IdentifierExpression>(field_name))));
-  func->body()->insert(0,
-                       mod->create<ast::VariableDeclStatement>(std::move(var)));
+  return mod->create<ast::VariableDeclStatement>(var);
 }
 
 }  // namespace transform
diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h
index 163bfbb..873ffc8 100644
--- a/src/transform/first_index_offset.h
+++ b/src/transform/first_index_offset.h
@@ -18,6 +18,7 @@
 #include <string>
 
 #include "src/ast/module.h"
+#include "src/ast/variable_decl_statement.h"
 #include "src/transform/transform.h"
 
 namespace tint {
@@ -94,12 +95,12 @@
   /// @param original_name the name of the original builtin used in function
   /// @param field_name name of field in firstVertex/Instance buffer
   /// @param buffer_var variable of firstVertex/Instance buffer
-  /// @param func function to modify
-  void AddFirstIndexOffset(const std::string& original_name,
-                           const std::string& field_name,
-                           ast::Variable* buffer_var,
-                           ast::Function* func,
-                           ast::Module* module);
+  /// @param module the target module to contain the new ast nodes
+  ast::VariableDeclStatement* CreateFirstIndexOffset(
+      const std::string& original_name,
+      const std::string& field_name,
+      ast::Variable* buffer_var,
+      ast::Module* module);
 
   uint32_t binding_;
   uint32_t set_;
diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc
index 3e20748..b1dbdd0 100644
--- a/src/transform/first_index_offset_test.cc
+++ b/src/transform/first_index_offset_test.cc
@@ -76,6 +76,8 @@
   struct Builder : public ModuleBuilder {
     void Build() override {
       AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
+      AddFunction("test")->body()->append(create<ast::ReturnStatement>(
+          Source{}, create<ast::IdentifierExpression>("vert_idx")));
     }
   };
 
@@ -106,15 +108,16 @@
   ASSERT_FALSE(result.diagnostics.contains_errors())
       << diag::Formatter().format(result.diagnostics);
 
-  EXPECT_EQ("Module{\n}\n", result.module.to_str());
+  auto got = result.module.to_str();
+  auto* expected = "Module{\n}\n";
+  EXPECT_EQ(got, expected);
 }
 
 TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
   struct Builder : public ModuleBuilder {
     void Build() override {
       AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
-      ast::Function* func = AddFunction("test");
-      func->body()->append(create<ast::ReturnStatement>(
+      AddFunction("test")->body()->append(create<ast::ReturnStatement>(
           Source{}, create<ast::IdentifierExpression>("vert_idx")));
     }
   };
@@ -131,7 +134,9 @@
   ASSERT_FALSE(result.diagnostics.contains_errors())
       << diag::Formatter().format(result.diagnostics);
 
-  EXPECT_EQ(R"(Module{
+  auto got = result.module.to_str();
+  auto* expected =
+      R"(Module{
   TintFirstIndexOffsetData Struct{
     [[block]]
     StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -180,14 +185,16 @@
     }
   }
 }
-)",
-            result.module.to_str());
+)";
+  EXPECT_EQ(got, expected);
 }
 
 TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
   struct Builder : public ModuleBuilder {
     void Build() override {
       AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx);
+      AddFunction("test")->body()->append(create<ast::ReturnStatement>(
+          Source{}, create<ast::IdentifierExpression>("inst_idx")));
     }
   };
 
@@ -202,7 +209,9 @@
 
   ASSERT_FALSE(result.diagnostics.contains_errors())
       << diag::Formatter().format(result.diagnostics);
-  EXPECT_EQ(R"(Module{
+
+  auto got = result.module.to_str();
+  auto* expected = R"(Module{
   TintFirstIndexOffsetData Struct{
     [[block]]
     StructMember{[[ offset 0 ]] tint_first_instance_index: __u32}
@@ -224,9 +233,35 @@
     uniform
     __struct_TintFirstIndexOffsetData
   }
+  Function test -> __u32
+  ()
+  {
+    VariableDeclStatement{
+      VariableConst{
+        inst_idx
+        none
+        __u32
+        {
+          Binary[__u32]{
+            Identifier[__ptr_in__u32]{tint_first_index_offset_inst_idx}
+            add
+            MemberAccessor[__ptr_uniform__u32]{
+              Identifier[__ptr_uniform__struct_TintFirstIndexOffsetData]{tint_first_index_data}
+              Identifier[not set]{tint_first_instance_index}
+            }
+          }
+        }
+      }
+    }
+    Return{
+      {
+        Identifier[__u32]{inst_idx}
+      }
+    }
+  }
 }
-)",
-            result.module.to_str());
+)";
+  EXPECT_EQ(got, expected);
 }
 
 TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
@@ -234,6 +269,8 @@
     void Build() override {
       AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx);
       AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx);
+      AddFunction("test")->body()->append(
+          create<ast::ReturnStatement>(Source{}, Expr(1u)));
     }
   };
 
@@ -251,7 +288,9 @@
 
   ASSERT_FALSE(result.diagnostics.contains_errors())
       << diag::Formatter().format(result.diagnostics);
-  EXPECT_EQ(R"(Module{
+
+  auto got = result.module.to_str();
+  auto* expected = R"(Module{
   TintFirstIndexOffsetData Struct{
     [[block]]
     StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -282,9 +321,18 @@
     uniform
     __struct_TintFirstIndexOffsetData
   }
+  Function test -> __u32
+  ()
+  {
+    Return{
+      {
+        ScalarConstructor[__u32]{1}
+      }
+    }
+  }
 }
-)",
-            result.module.to_str());
+)";
+  EXPECT_EQ(got, expected);
 
   EXPECT_TRUE(transform_ptr->HasVertexIndex());
   EXPECT_EQ(transform_ptr->GetFirstVertexOffset(), 0u);
@@ -321,7 +369,9 @@
 
   ASSERT_FALSE(result.diagnostics.contains_errors())
       << diag::Formatter().format(result.diagnostics);
-  EXPECT_EQ(R"(Module{
+
+  auto got = result.module.to_str();
+  auto* expected = R"(Module{
   TintFirstIndexOffsetData Struct{
     [[block]]
     StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32}
@@ -383,8 +433,8 @@
     }
   }
 }
-)",
-            result.module.to_str());
+)";
+  EXPECT_EQ(got, expected);
 }
 
 }  // namespace