Factor out GetInsertionPoint to transform/utils

This function was copy-pasted in two transforms, and will be used in the
next one I'm writing.

Bug: tint:1080
Change-Id: Ic5ffe68a7e9d00b37722e8f5faff01e9e15fa6b1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85262
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 1646c85..a2541bd 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -489,6 +489,8 @@
     "transform/unshadow.h",
     "transform/unwind_discard_functions.cc",
     "transform/unwind_discard_functions.h",
+    "transform/utils/get_insertion_point.cc",
+    "transform/utils/get_insertion_point.h",
     "transform/utils/hoist_to_decl_before.cc",
     "transform/utils/hoist_to_decl_before.h",
     "transform/var_for_dynamic_index.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 1f5e90c..0a3ce1c 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -377,6 +377,8 @@
   transform/wrap_arrays_in_structs.h
   transform/zero_init_workgroup_memory.cc
   transform/zero_init_workgroup_memory.h
+  transform/utils/get_insertion_point.cc
+  transform/utils/get_insertion_point.h
   transform/utils/hoist_to_decl_before.cc
   transform/utils/hoist_to_decl_before.h
   sem/bool_type.cc
@@ -1039,6 +1041,7 @@
       transform/vertex_pulling_test.cc
       transform/wrap_arrays_in_structs_test.cc
       transform/zero_init_workgroup_memory_test.cc
+      transform/utils/get_insertion_point_test.cc
       transform/utils/hoist_to_decl_before_test.cc
     )
   endif()
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index e1d5ab2..9fd19db 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -28,6 +28,7 @@
 #include "src/tint/sem/member_accessor_expression.h"
 #include "src/tint/sem/variable.h"
 #include "src/tint/transform/manager.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
 #include "src/tint/transform/utils/hoist_to_decl_before.h"
 #include "src/tint/utils/scoped_assignment.h"
 
@@ -556,33 +557,11 @@
         });
   }
 
-  // For the input statement, returns the block and statement within that block
-  // to insert before/after.
-  std::pair<const sem::BlockStatement*, const ast::Statement*>
-  GetInsertionPoint(const ast::Statement* stmt) {
-    auto* sem_stmt = sem.Get(stmt);
-    if (sem_stmt) {
-      auto* parent = sem_stmt->Parent();
-      if (auto* block = parent->As<sem::BlockStatement>()) {
-        // Common case, just insert in the current block above the input
-        // statement.
-        return {block, stmt};
-      }
-      if (auto* fl = parent->As<sem::ForLoopStatement>()) {
-        if (fl->Declaration()->initializer == stmt) {
-          // For loop init, insert above the for loop itself.
-          return {fl->Block(), fl->Declaration()};
-        }
-      }
-    }
-    return {};
-  }
-
   // Inserts statements in `stmts` before `stmt`
   void InsertBefore(const ast::StatementList& stmts,
                     const ast::Statement* stmt) {
     if (!stmts.empty()) {
-      auto ip = GetInsertionPoint(stmt);
+      auto ip = utils::GetInsertionPoint(ctx, stmt);
       for (auto* s : stmts) {
         ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s);
       }
diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc
index 80a47b1..b15f3a0 100644
--- a/src/tint/transform/unwind_discard_functions.cc
+++ b/src/tint/transform/unwind_discard_functions.cc
@@ -28,6 +28,7 @@
 #include "src/tint/sem/for_loop_statement.h"
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/if_statement.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
 
@@ -42,46 +43,6 @@
   Symbol module_discard_var_name;   // Use ModuleDiscardVarName() to read
   Symbol module_discard_func_name;  // Use ModuleDiscardFuncName() to read
 
-  // For the input statement, returns the block and statement within that
-  // block to insert before/after.
-  std::pair<const sem::BlockStatement*, const ast::Statement*>
-  GetInsertionPoint(const ast::Statement* stmt) {
-    using RetType =
-        std::pair<const sem::BlockStatement*, const ast::Statement*>;
-
-    if (auto* sem_stmt = sem.Get(stmt)) {
-      auto* parent = sem_stmt->Parent();
-      return Switch(
-          parent,
-          [&](const sem::BlockStatement* block) -> RetType {
-            // Common case, just insert in the current block above the input
-            // statement.
-            return {block, stmt};
-          },
-          [&](const sem::ForLoopStatement* fl) -> RetType {
-            // `stmt` is either the for loop initializer or the continuing
-            // statement of a for-loop.
-            if (fl->Declaration()->initializer == stmt) {
-              // For loop init, insert above the for loop itself.
-              return {fl->Block(), fl->Declaration()};
-            }
-
-            TINT_ICE(Transform, b.Diagnostics())
-                << "cannot insert before or after continuing statement of a "
-                   "for-loop";
-            return {};
-          },
-          [&](Default) -> RetType {
-            TINT_ICE(Transform, b.Diagnostics())
-                << "expected parent of statement to be either a block or for "
-                   "loop";
-            return {};
-          });
-    }
-
-    return {};
-  }
-
   // If `block`'s parent is of type TO, returns pointer to it.
   template <typename TO>
   const TO* ParentAs(const ast::BlockStatement* block) {
@@ -186,7 +147,7 @@
                                              const sem::Expression* sem_expr) {
     auto* expr = sem_expr->Declaration();
 
-    auto ip = GetInsertionPoint(stmt);
+    auto ip = utils::GetInsertionPoint(ctx, stmt);
     auto var_name = b.Sym();
     auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
     ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
@@ -239,7 +200,7 @@
       return HoistAndInsertBefore(stmt, sem_expr);
     }
 
-    auto ip = GetInsertionPoint(stmt);
+    auto ip = utils::GetInsertionPoint(ctx, stmt);
     ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
                     IfDiscardReturn(stmt));
     return nullptr;  // Don't replace current statement
@@ -269,7 +230,7 @@
       to_insert = b.Assign(var_name, true);
     }
 
-    auto ip = GetInsertionPoint(stmt);
+    auto ip = utils::GetInsertionPoint(ctx, stmt);
     ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
     return Return(stmt);
   }
diff --git a/src/tint/transform/utils/get_insertion_point.cc b/src/tint/transform/utils/get_insertion_point.cc
new file mode 100644
index 0000000..0f00e0c
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point.cc
@@ -0,0 +1,58 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/utils/get_insertion_point.h"
+#include "src/tint/debug.h"
+#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/sem/for_loop_statement.h"
+
+namespace tint::transform::utils {
+
+InsertionPoint GetInsertionPoint(CloneContext& ctx,
+                                 const ast::Statement* stmt) {
+  auto& sem = ctx.src->Sem();
+  auto& diag = ctx.dst->Diagnostics();
+  using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
+
+  if (auto* sem_stmt = sem.Get(stmt)) {
+    auto* parent = sem_stmt->Parent();
+    return Switch(
+        parent,
+        [&](const sem::BlockStatement* block) -> RetType {
+          // Common case, can insert in the current block above/below the input
+          // statement.
+          return {block, stmt};
+        },
+        [&](const sem::ForLoopStatement* fl) -> RetType {
+          // `stmt` is either the for loop initializer or the continuing
+          // statement of a for-loop.
+          if (fl->Declaration()->initializer == stmt) {
+            // For loop init, can insert above the for loop itself.
+            return {fl->Block(), fl->Declaration()};
+          }
+
+          // Cannot insert before or after continuing statement of a for-loop
+          return {};
+        },
+        [&](Default) -> RetType {
+          TINT_ICE(Transform, diag) << "expected parent of statement to be "
+                                       "either a block or for loop";
+          return {};
+        });
+  }
+
+  return {};
+}
+
+}  // namespace tint::transform::utils
diff --git a/src/tint/transform/utils/get_insertion_point.h b/src/tint/transform/utils/get_insertion_point.h
new file mode 100644
index 0000000..85abcea
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point.h
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+#define SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+
+namespace tint::transform::utils {
+
+/// InsertionPoint is a pair of the block (`first`) within which, and the
+/// statement (`second`) before or after which to insert.
+using InsertionPoint =
+    std::pair<const sem::BlockStatement*, const ast::Statement*>;
+
+/// For the input statement, returns the block and statement within that
+/// block to insert before/after. If `stmt` is a for-loop continue statement,
+/// the function returns {nullptr, nullptr} as we cannot insert before/after it.
+/// @param ctx the clone context
+/// @param stmt the statement to insert before or after
+/// @return the insertion point
+InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt);
+
+}  // namespace tint::transform::utils
+
+#endif  // SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
diff --git a/src/tint/transform/utils/get_insertion_point_test.cc b/src/tint/transform/utils/get_insertion_point_test.cc
new file mode 100644
index 0000000..48e358e
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point_test.cc
@@ -0,0 +1,94 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/debug.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/transform/test_helper.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
+
+namespace tint::transform {
+namespace {
+
+using GetInsertionPointTest = ::testing::Test;
+
+TEST_F(GetInsertionPointTest, Block) {
+  // fn f() {
+  //     var a = 1;
+  // }
+  ProgramBuilder b;
+  auto* expr = b.Expr(1);
+  auto* var = b.Decl(b.Var("a", nullptr, expr));
+  auto* block = b.Block(var);
+  b.Func("f", {}, b.ty.void_(), {block});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  // Can insert in block containing the variable, above or below the input
+  // statement.
+  auto ip = utils::GetInsertionPoint(ctx, var);
+  ASSERT_EQ(ip.first->Declaration(), block);
+  ASSERT_EQ(ip.second, var);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopInit) {
+  // fn f() {
+  //     for(var a = 1; true; ) {
+  //     }
+  // }
+  ProgramBuilder b;
+  auto* expr = b.Expr(1);
+  auto* var = b.Decl(b.Var("a", nullptr, expr));
+  auto* fl = b.For(var, b.Expr(true), {}, b.Block());
+  auto* func_block = b.Block(fl);
+  b.Func("f", {}, b.ty.void_(), {func_block});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  // Can insert in block containing for-loop above the for-loop itself.
+  auto ip = utils::GetInsertionPoint(ctx, var);
+  ASSERT_EQ(ip.first->Declaration(), func_block);
+  ASSERT_EQ(ip.second, fl);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) {
+  // fn f() {
+  //     for(; true; var a = 1) {
+  //     }
+  // }
+  ProgramBuilder b;
+  auto* expr = b.Expr(1);
+  auto* var = b.Decl(b.Var("a", nullptr, expr));
+  auto* s = b.For({}, b.Expr(true), var, b.Block());
+  b.Func("f", {}, b.ty.void_(), {s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  // Can't insert before/after for loop continue statement (would ned to be
+  // converted to loop).
+  auto ip = utils::GetInsertionPoint(ctx, var);
+  ASSERT_EQ(ip.first, nullptr);
+  ASSERT_EQ(ip.second, nullptr);
+}
+
+}  // namespace
+}  // namespace tint::transform
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index 7f6ec7e..6bdd527 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -339,6 +339,7 @@
     "../../src/tint/transform/transform_test.cc",
     "../../src/tint/transform/unshadow_test.cc",
     "../../src/tint/transform/unwind_discard_functions_test.cc",
+    "../../src/tint/transform/utils/get_insertion_point_test.cc",
     "../../src/tint/transform/utils/hoist_to_decl_before_test.cc",
     "../../src/tint/transform/var_for_dynamic_index_test.cc",
     "../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",