HoistToDeclBefore: add Prepare(const sem::Expression*) function

Used to signal that we plan on hoisting a decl before `before_expr`.
This will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s
if needed.

Bug: tint:1300
Change-Id: I6fed790564f05a9db110866f946af4a66a1311db
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/83101
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index 508b974..ff345eb 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -56,29 +56,39 @@
   /// For-loops that need to be decomposed to loops.
   std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
 
-  /// If statements with 'else if's that need to be decomposed to 'else { if
-  /// }'
+  /// If statements with 'else if's that need to be decomposed to 'else {if}'
   std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
 
-  // Inserts `decl` before `sem_expr`, possibly marking a for-loop to be
-  // converted to a loop, or an else-if to an else { if }.
-  bool InsertBefore(const sem::Expression* sem_expr,
+  // Inserts `decl` before `before_expr`, possibly marking a for-loop to be
+  // converted to a loop, or an else-if to an else { if }. If `decl` is nullptr,
+  // for-loop and else-if conversions are marked, but no hoisting takes place.
+  bool InsertBefore(const sem::Expression* before_expr,
                     const ast::VariableDeclStatement* decl) {
-    auto* sem_stmt = sem_expr->Stmt();
+    auto* sem_stmt = before_expr->Stmt();
     auto* stmt = sem_stmt->Declaration();
 
     if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
       // Expression used in 'else if' condition.
       // Need to convert 'else if' to 'else { if }'.
       auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
-      if_info.else_ifs[else_if].cond_decls.push_back(decl);
+
+      // Index the map to convert this else if, even if `decl` is nullptr.
+      auto& decls = if_info.else_ifs[else_if].cond_decls;
+      if (decl) {
+        decls.emplace_back(decl);
+      }
       return true;
     }
 
     if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
       // Expression used in for-loop condition.
       // For-loop needs to be decomposed to a loop.
-      loops[fl].cond_decls.emplace_back(decl);
+
+      // Index the map to convert this for-loop, even if `decl` is nullptr.
+      auto& decls = loops[fl].cond_decls;
+      if (decl) {
+        decls.emplace_back(decl);
+      }
       return true;
     }
 
@@ -86,7 +96,9 @@
     if (auto* block = parent->As<sem::BlockStatement>()) {
       // Expression's statement sits in a block. Simple case.
       // Insert the decl before the parent statement
-      ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
+      if (decl) {
+        ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
+      }
       return true;
     }
 
@@ -95,15 +107,22 @@
       if (fl->Declaration()->initializer == stmt) {
         // Expression used in for-loop initializer.
         // Insert the let above the for-loop.
-        ctx.InsertBefore(fl->Block()->Declaration()->statements,
-                         fl->Declaration(), decl);
+        if (decl) {
+          ctx.InsertBefore(fl->Block()->Declaration()->statements,
+                           fl->Declaration(), decl);
+        }
         return true;
       }
 
       if (fl->Declaration()->continuing == stmt) {
         // Expression used in for-loop continuing.
         // For-loop needs to be decomposed to a loop.
-        loops[fl].cont_decls.emplace_back(decl);
+
+        // Index the map to convert this for-loop, even if `decl` is nullptr.
+        auto& decls = loops[fl].cont_decls;
+        if (decl) {
+          decls.emplace_back(decl);
+        }
         return true;
       }
 
@@ -263,10 +282,10 @@
   /// @param as_const hoist to `let` if true, otherwise to `var`
   /// @param decl_name optional name to use for the variable/constant name
   /// @return true on success
-  bool HoistToDeclBefore(const sem::Expression* before_expr,
-                         const ast::Expression* expr,
-                         bool as_const,
-                         const char* decl_name) {
+  bool Add(const sem::Expression* before_expr,
+           const ast::Expression* expr,
+           bool as_const,
+           const char* decl_name) {
     auto name = b.Symbols().New(decl_name);
 
     // Construct the let/var that holds the hoisted expr
@@ -283,6 +302,15 @@
     return true;
   }
 
+  /// Use to signal that we plan on hoisting a decl before `before_expr`. This
+  /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
+  /// needed.
+  /// @param before_expr expression we would hoist a decl before
+  /// @return true on success
+  bool Prepare(const sem::Expression* before_expr) {
+    return InsertBefore(before_expr, nullptr);
+  }
+
   /// Applies any scheduled insertions from previous calls to Add() to
   /// CloneContext. Call this once before ctx.Clone().
   /// @return true on success
@@ -302,7 +330,11 @@
                             const ast::Expression* expr,
                             bool as_const,
                             const char* decl_name) {
-  return state_->HoistToDeclBefore(before_expr, expr, as_const, decl_name);
+  return state_->Add(before_expr, expr, as_const, decl_name);
+}
+
+bool HoistToDeclBefore::Prepare(const sem::Expression* before_expr) {
+  return state_->Prepare(before_expr);
 }
 
 bool HoistToDeclBefore::Apply() {
diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h
index 8f35a09..583896d 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.h
+++ b/src/tint/transform/utils/hoist_to_decl_before.h
@@ -23,8 +23,8 @@
 namespace tint::transform {
 
 /// Utility class that can be used to hoist expressions before other
-/// expressions, possibly converting 'for' loops to 'loop's and 'else if to
-// 'else if'.
+/// expressions, possibly converting 'for-loop's to 'loop's and 'else-if's to
+// 'else {if}'s.
 class HoistToDeclBefore {
  public:
   /// Constructor
@@ -46,6 +46,13 @@
            bool as_const,
            const char* decl_name = "");
 
+  /// Use to signal that we plan on hoisting a decl before `before_expr`. This
+  /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
+  /// needed.
+  /// @param before_expr expression we would hoist a decl before
+  /// @return true on success
+  bool Prepare(const sem::Expression* before_expr);
+
   /// Applies any scheduled insertions from previous calls to Add() to
   /// CloneContext. Call this once before ctx.Clone().
   /// @return true on success
diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc
index 70dc8bc..91e17a5 100644
--- a/src/tint/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc
@@ -287,5 +287,130 @@
   EXPECT_EQ(expect, str(cloned));
 }
 
+TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCond) {
+  // fn f() {
+  //     var a : bool;
+  //     for(; a; ) {
+  //     }
+  // }
+  ProgramBuilder b;
+  auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+  auto* expr = b.Expr("a");
+  auto* s = b.For({}, expr, {}, b.Block());
+  b.Func("f", {}, b.ty.void_(), {var, s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* sem_expr = ctx.src->Sem().Get(expr);
+  hoistToDeclBefore.Prepare(sem_expr);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn f() {
+  var a : bool;
+  loop {
+    if (!(a)) {
+      break;
+    }
+    {
+    }
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCont) {
+  // fn f() {
+  //     for(; true; var a = 1) {
+  //     }
+  // }
+  ProgramBuilder b;
+  auto* expr = b.Expr(1);
+  auto* s =
+      b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block());
+  b.Func("f", {}, b.ty.void_(), {s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* sem_expr = ctx.src->Sem().Get(expr);
+  hoistToDeclBefore.Prepare(sem_expr);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn f() {
+  loop {
+    if (!(true)) {
+      break;
+    }
+    {
+    }
+
+    continuing {
+      var a = 1;
+    }
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Prepare_ElseIf) {
+  // fn f() {
+  //     var a : bool;
+  //     if (true) {
+  //     } else if (a) {
+  //     } else {
+  //     }
+  // }
+  ProgramBuilder b;
+  auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+  auto* expr = b.Expr("a");
+  auto* s = b.If(b.Expr(true), b.Block(),  //
+                 b.Else(expr, b.Block()),  //
+                 b.Else(b.Block()));
+  b.Func("f", {}, b.ty.void_(), {var, s});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* sem_expr = ctx.src->Sem().Get(expr);
+  hoistToDeclBefore.Prepare(sem_expr);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn f() {
+  var a : bool;
+  if (true) {
+  } else {
+    if (a) {
+    } else {
+    }
+  }
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
 }  // namespace
 }  // namespace tint::transform