HoistToDeclBefore: Add InsertBefore method

This can be used to insert a new statement before an existing
statement, and will take care of converting for-loop and else-if
statements as necessary.

Change-Id: I5ef20f33cf36bb48ea5dabe1048c9d9b3c61b3ee
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85281
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index ff345eb..05c56bb 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -59,84 +59,6 @@
   /// If statements with 'else if's that need to be decomposed to 'else {if}'
   std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
 
-  // Inserts `decl` before `before_expr`, possibly marking a for-loop to be
-  // converted to a loop, or an else-if to an else { if }. If `decl` is nullptr,
-  // for-loop and else-if conversions are marked, but no hoisting takes place.
-  bool InsertBefore(const sem::Expression* before_expr,
-                    const ast::VariableDeclStatement* decl) {
-    auto* sem_stmt = before_expr->Stmt();
-    auto* stmt = sem_stmt->Declaration();
-
-    if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
-      // Expression used in 'else if' condition.
-      // Need to convert 'else if' to 'else { if }'.
-      auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
-
-      // Index the map to convert this else if, even if `decl` is nullptr.
-      auto& decls = if_info.else_ifs[else_if].cond_decls;
-      if (decl) {
-        decls.emplace_back(decl);
-      }
-      return true;
-    }
-
-    if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
-      // Expression used in for-loop condition.
-      // For-loop needs to be decomposed to a loop.
-
-      // Index the map to convert this for-loop, even if `decl` is nullptr.
-      auto& decls = loops[fl].cond_decls;
-      if (decl) {
-        decls.emplace_back(decl);
-      }
-      return true;
-    }
-
-    auto* parent = sem_stmt->Parent();  // The statement's parent
-    if (auto* block = parent->As<sem::BlockStatement>()) {
-      // Expression's statement sits in a block. Simple case.
-      // Insert the decl before the parent statement
-      if (decl) {
-        ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
-      }
-      return true;
-    }
-
-    if (auto* fl = parent->As<sem::ForLoopStatement>()) {
-      // Expression is used in a for-loop. These require special care.
-      if (fl->Declaration()->initializer == stmt) {
-        // Expression used in for-loop initializer.
-        // Insert the let above the for-loop.
-        if (decl) {
-          ctx.InsertBefore(fl->Block()->Declaration()->statements,
-                           fl->Declaration(), decl);
-        }
-        return true;
-      }
-
-      if (fl->Declaration()->continuing == stmt) {
-        // Expression used in for-loop continuing.
-        // For-loop needs to be decomposed to a loop.
-
-        // Index the map to convert this for-loop, even if `decl` is nullptr.
-        auto& decls = loops[fl].cont_decls;
-        if (decl) {
-          decls.emplace_back(decl);
-        }
-        return true;
-      }
-
-      TINT_ICE(Transform, b.Diagnostics())
-          << "unhandled use of expression in for-loop";
-      return false;
-    }
-
-    TINT_ICE(Transform, b.Diagnostics())
-        << "unhandled expression parent statement type: "
-        << parent->TypeInfo().name;
-    return false;
-  }
-
   // Converts any for-loops marked for conversion to loops, inserting
   // registered declaration statements before the condition or continuing
   // statement.
@@ -293,7 +215,7 @@
                        : b.Var(name, nullptr, ctx.Clone(expr));
     auto* decl = b.Decl(v);
 
-    if (!InsertBefore(before_expr, decl)) {
+    if (!InsertBefore(before_expr->Stmt(), decl)) {
       return false;
     }
 
@@ -302,13 +224,95 @@
     return true;
   }
 
+  /// Inserts `stmt` before `before_stmt`, possibly marking a for-loop to be
+  /// converted to a loop, or an else-if to an else { if }. If `decl` is
+  /// nullptr, for-loop and else-if conversions are marked, but no hoisting
+  /// takes place.
+  /// @param before_stmt statement to insert `stmt` before
+  /// @param stmt statement to insert
+  /// @return true on success
+  bool InsertBefore(const sem::Statement* before_stmt,
+                    const ast::Statement* stmt) {
+    auto* ip = before_stmt->Declaration();
+
+    if (auto* else_if = before_stmt->As<sem::ElseStatement>()) {
+      // Insertion point is an 'else if' condition.
+      // Need to convert 'else if' to 'else { if }'.
+      auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
+
+      // Index the map to convert this else if, even if `stmt` is nullptr.
+      auto& decls = if_info.else_ifs[else_if].cond_decls;
+      if (stmt) {
+        decls.emplace_back(stmt);
+      }
+      return true;
+    }
+
+    if (auto* fl = before_stmt->As<sem::ForLoopStatement>()) {
+      // Insertion point is a for-loop condition.
+      // For-loop needs to be decomposed to a loop.
+
+      // Index the map to convert this for-loop, even if `stmt` is nullptr.
+      auto& decls = loops[fl].cond_decls;
+      if (stmt) {
+        decls.emplace_back(stmt);
+      }
+      return true;
+    }
+
+    auto* parent = before_stmt->Parent();  // The statement's parent
+    if (auto* block = parent->As<sem::BlockStatement>()) {
+      // Insert point sits in a block. Simple case.
+      // Insert the stmt before the parent statement.
+      if (stmt) {
+        ctx.InsertBefore(block->Declaration()->statements, ip, stmt);
+      }
+      return true;
+    }
+
+    if (auto* fl = parent->As<sem::ForLoopStatement>()) {
+      // Insertion point is a for-loop initializer or continuing statement.
+      // These require special care.
+      if (fl->Declaration()->initializer == ip) {
+        // Insertion point is a for-loop initializer.
+        // Insert the new statement above the for-loop.
+        if (stmt) {
+          ctx.InsertBefore(fl->Block()->Declaration()->statements,
+                           fl->Declaration(), stmt);
+        }
+        return true;
+      }
+
+      if (fl->Declaration()->continuing == ip) {
+        // Insertion point is a for-loop continuing statement.
+        // For-loop needs to be decomposed to a loop.
+
+        // Index the map to convert this for-loop, even if `stmt` is nullptr.
+        auto& decls = loops[fl].cont_decls;
+        if (stmt) {
+          decls.emplace_back(stmt);
+        }
+        return true;
+      }
+
+      TINT_ICE(Transform, b.Diagnostics())
+          << "unhandled use of expression in for-loop";
+      return false;
+    }
+
+    TINT_ICE(Transform, b.Diagnostics())
+        << "unhandled expression parent statement type: "
+        << parent->TypeInfo().name;
+    return false;
+  }
+
   /// Use to signal that we plan on hoisting a decl before `before_expr`. This
   /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
   /// needed.
   /// @param before_expr expression we would hoist a decl before
   /// @return true on success
   bool Prepare(const sem::Expression* before_expr) {
-    return InsertBefore(before_expr, nullptr);
+    return InsertBefore(before_expr->Stmt(), nullptr);
   }
 
   /// Applies any scheduled insertions from previous calls to Add() to
@@ -333,6 +337,11 @@
   return state_->Add(before_expr, expr, as_const, decl_name);
 }
 
+bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
+                                     const ast::Statement* stmt) {
+  return state_->InsertBefore(before_stmt, stmt);
+}
+
 bool HoistToDeclBefore::Prepare(const sem::Expression* before_expr) {
   return state_->Prepare(before_expr);
 }
diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h
index 583896d..2d94f52 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.h
+++ b/src/tint/transform/utils/hoist_to_decl_before.h
@@ -46,6 +46,14 @@
            bool as_const,
            const char* decl_name = "");
 
+  /// Inserts `stmt` before `before_stmt`, possibly converting 'for-loop's to
+  /// 'loop's if necessary.
+  /// @param before_stmt statement to insert `stmt` before
+  /// @param stmt statement to insert
+  /// @return true on success
+  bool InsertBefore(const sem::Statement* before_stmt,
+                    const ast::Statement* stmt);
+
   /// Use to signal that we plan on hoisting a decl before `before_expr`. This
   /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
   /// needed.
diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc
index 91e17a5..589eb9a 100644
--- a/src/tint/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc
@@ -16,6 +16,8 @@
 
 #include "gtest/gtest-spi.h"
 #include "src/tint/program_builder.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/statement.h"
 #include "src/tint/transform/test_helper.h"
 #include "src/tint/transform/utils/hoist_to_decl_before.h"
 
@@ -412,5 +414,185 @@
   EXPECT_EQ(expect, str(cloned));
 }
 
+TEST_F(HoistToDeclBeforeTest, InsertBefore_Block) {
+  // fn foo() {
+  // }
+  // fn f() {
+  //     var a = 1;
+  // }
+  ProgramBuilder b;
+  b.Func("foo", {}, b.ty.void_(), {});
+  auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+  b.Func("f", {}, b.ty.void_(), {var});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* before_stmt = ctx.src->Sem().Get(var);
+  auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+  hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+  foo();
+  var a = 1;
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopInit) {
+  // fn foo() {
+  // }
+  // fn f() {
+  //     for(var a = 1; true;) {
+  //     }
+  // }
+  ProgramBuilder b;
+  b.Func("foo", {}, b.ty.void_(), {});
+  auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+  auto* s = b.For(var, b.Expr(true), {}, b.Block());
+  b.Func("f", {}, b.ty.void_(), {s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* before_stmt = ctx.src->Sem().Get(var);
+  auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+  hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+  foo();
+  for(var a = 1; true; ) {
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
+  // fn foo() {
+  // }
+  // fn f() {
+  //     var a = 1;
+  //     for(; true; a+=1) {
+  //     }
+  // }
+  ProgramBuilder b;
+  b.Func("foo", {}, b.ty.void_(), {});
+  auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+  auto* cont = b.CompoundAssign("a", b.Expr(1), ast::BinaryOp::kAdd);
+  auto* s = b.For({}, b.Expr(true), cont, b.Block());
+  b.Func("f", {}, b.ty.void_(), {var, s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
+  auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+  hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+  var a = 1;
+  loop {
+    if (!(true)) {
+      break;
+    }
+    {
+    }
+
+    continuing {
+      foo();
+      a += 1;
+    }
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ElseIf) {
+  // fn foo() {
+  // }
+  // fn f() {
+  //     var a : bool;
+  //     if (true) {
+  //     } else if (a) {
+  //     } else {
+  //     }
+  // }
+  ProgramBuilder b;
+  b.Func("foo", {}, b.ty.void_(), {});
+  auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+  auto* elseif = b.Else(b.Expr("a"), b.Block());
+  auto* s = b.If(b.Expr(true), b.Block(),  //
+                 elseif,                   //
+                 b.Else(b.Block()));
+  b.Func("f", {}, b.ty.void_(), {var, s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* before_stmt = ctx.src->Sem().Get(elseif);
+  auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+  hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+  var a : bool;
+  if (true) {
+  } else {
+    foo();
+    if (a) {
+    } else {
+    }
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
 }  // namespace
 }  // namespace tint::transform