HoistToDeclBefore: Add InsertBefore method
This can be used to insert a new statement before an existing
statement, and will take care of converting for-loop and else-if
statements as necessary.
Change-Id: I5ef20f33cf36bb48ea5dabe1048c9d9b3c61b3ee
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85281
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc
index ff345eb..05c56bb 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before.cc
@@ -59,84 +59,6 @@
/// If statements with 'else if's that need to be decomposed to 'else {if}'
std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
- // Inserts `decl` before `before_expr`, possibly marking a for-loop to be
- // converted to a loop, or an else-if to an else { if }. If `decl` is nullptr,
- // for-loop and else-if conversions are marked, but no hoisting takes place.
- bool InsertBefore(const sem::Expression* before_expr,
- const ast::VariableDeclStatement* decl) {
- auto* sem_stmt = before_expr->Stmt();
- auto* stmt = sem_stmt->Declaration();
-
- if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
- // Expression used in 'else if' condition.
- // Need to convert 'else if' to 'else { if }'.
- auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
-
- // Index the map to convert this else if, even if `decl` is nullptr.
- auto& decls = if_info.else_ifs[else_if].cond_decls;
- if (decl) {
- decls.emplace_back(decl);
- }
- return true;
- }
-
- if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
- // Expression used in for-loop condition.
- // For-loop needs to be decomposed to a loop.
-
- // Index the map to convert this for-loop, even if `decl` is nullptr.
- auto& decls = loops[fl].cond_decls;
- if (decl) {
- decls.emplace_back(decl);
- }
- return true;
- }
-
- auto* parent = sem_stmt->Parent(); // The statement's parent
- if (auto* block = parent->As<sem::BlockStatement>()) {
- // Expression's statement sits in a block. Simple case.
- // Insert the decl before the parent statement
- if (decl) {
- ctx.InsertBefore(block->Declaration()->statements, stmt, decl);
- }
- return true;
- }
-
- if (auto* fl = parent->As<sem::ForLoopStatement>()) {
- // Expression is used in a for-loop. These require special care.
- if (fl->Declaration()->initializer == stmt) {
- // Expression used in for-loop initializer.
- // Insert the let above the for-loop.
- if (decl) {
- ctx.InsertBefore(fl->Block()->Declaration()->statements,
- fl->Declaration(), decl);
- }
- return true;
- }
-
- if (fl->Declaration()->continuing == stmt) {
- // Expression used in for-loop continuing.
- // For-loop needs to be decomposed to a loop.
-
- // Index the map to convert this for-loop, even if `decl` is nullptr.
- auto& decls = loops[fl].cont_decls;
- if (decl) {
- decls.emplace_back(decl);
- }
- return true;
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled use of expression in for-loop";
- return false;
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "unhandled expression parent statement type: "
- << parent->TypeInfo().name;
- return false;
- }
-
// Converts any for-loops marked for conversion to loops, inserting
// registered declaration statements before the condition or continuing
// statement.
@@ -293,7 +215,7 @@
: b.Var(name, nullptr, ctx.Clone(expr));
auto* decl = b.Decl(v);
- if (!InsertBefore(before_expr, decl)) {
+ if (!InsertBefore(before_expr->Stmt(), decl)) {
return false;
}
@@ -302,13 +224,95 @@
return true;
}
+ /// Inserts `stmt` before `before_stmt`, possibly marking a for-loop to be
+ /// converted to a loop, or an else-if to an else { if }. If `decl` is
+ /// nullptr, for-loop and else-if conversions are marked, but no hoisting
+ /// takes place.
+ /// @param before_stmt statement to insert `stmt` before
+ /// @param stmt statement to insert
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt,
+ const ast::Statement* stmt) {
+ auto* ip = before_stmt->Declaration();
+
+ if (auto* else_if = before_stmt->As<sem::ElseStatement>()) {
+ // Insertion point is an 'else if' condition.
+ // Need to convert 'else if' to 'else { if }'.
+ auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
+
+ // Index the map to convert this else if, even if `stmt` is nullptr.
+ auto& decls = if_info.else_ifs[else_if].cond_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ if (auto* fl = before_stmt->As<sem::ForLoopStatement>()) {
+ // Insertion point is a for-loop condition.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to convert this for-loop, even if `stmt` is nullptr.
+ auto& decls = loops[fl].cond_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ auto* parent = before_stmt->Parent(); // The statement's parent
+ if (auto* block = parent->As<sem::BlockStatement>()) {
+ // Insert point sits in a block. Simple case.
+ // Insert the stmt before the parent statement.
+ if (stmt) {
+ ctx.InsertBefore(block->Declaration()->statements, ip, stmt);
+ }
+ return true;
+ }
+
+ if (auto* fl = parent->As<sem::ForLoopStatement>()) {
+ // Insertion point is a for-loop initializer or continuing statement.
+ // These require special care.
+ if (fl->Declaration()->initializer == ip) {
+ // Insertion point is a for-loop initializer.
+ // Insert the new statement above the for-loop.
+ if (stmt) {
+ ctx.InsertBefore(fl->Block()->Declaration()->statements,
+ fl->Declaration(), stmt);
+ }
+ return true;
+ }
+
+ if (fl->Declaration()->continuing == ip) {
+ // Insertion point is a for-loop continuing statement.
+ // For-loop needs to be decomposed to a loop.
+
+ // Index the map to convert this for-loop, even if `stmt` is nullptr.
+ auto& decls = loops[fl].cont_decls;
+ if (stmt) {
+ decls.emplace_back(stmt);
+ }
+ return true;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled use of expression in for-loop";
+ return false;
+ }
+
+ TINT_ICE(Transform, b.Diagnostics())
+ << "unhandled expression parent statement type: "
+ << parent->TypeInfo().name;
+ return false;
+ }
+
/// Use to signal that we plan on hoisting a decl before `before_expr`. This
/// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
/// needed.
/// @param before_expr expression we would hoist a decl before
/// @return true on success
bool Prepare(const sem::Expression* before_expr) {
- return InsertBefore(before_expr, nullptr);
+ return InsertBefore(before_expr->Stmt(), nullptr);
}
/// Applies any scheduled insertions from previous calls to Add() to
@@ -333,6 +337,11 @@
return state_->Add(before_expr, expr, as_const, decl_name);
}
+bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
+ const ast::Statement* stmt) {
+ return state_->InsertBefore(before_stmt, stmt);
+}
+
bool HoistToDeclBefore::Prepare(const sem::Expression* before_expr) {
return state_->Prepare(before_expr);
}
diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h
index 583896d..2d94f52 100644
--- a/src/tint/transform/utils/hoist_to_decl_before.h
+++ b/src/tint/transform/utils/hoist_to_decl_before.h
@@ -46,6 +46,14 @@
bool as_const,
const char* decl_name = "");
+ /// Inserts `stmt` before `before_stmt`, possibly converting 'for-loop's to
+ /// 'loop's if necessary.
+ /// @param before_stmt statement to insert `stmt` before
+ /// @param stmt statement to insert
+ /// @return true on success
+ bool InsertBefore(const sem::Statement* before_stmt,
+ const ast::Statement* stmt);
+
/// Use to signal that we plan on hoisting a decl before `before_expr`. This
/// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if
/// needed.
diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc
index 91e17a5..589eb9a 100644
--- a/src/tint/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc
@@ -16,6 +16,8 @@
#include "gtest/gtest-spi.h"
#include "src/tint/program_builder.h"
+#include "src/tint/sem/if_statement.h"
+#include "src/tint/sem/statement.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@@ -412,5 +414,185 @@
EXPECT_EQ(expect, str(cloned));
}
+TEST_F(HoistToDeclBeforeTest, InsertBefore_Block) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1;
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ b.Func("f", {}, b.ty.void_(), {var});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+ var a = 1;
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopInit) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // for(var a = 1; true;) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ auto* s = b.For(var, b.Expr(true), {}, b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(var);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ foo();
+ for(var a = 1; true; ) {
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a = 1;
+ // for(; true; a+=1) {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", nullptr, b.Expr(1)));
+ auto* cont = b.CompoundAssign("a", b.Expr(1), ast::BinaryOp::kAdd);
+ auto* s = b.For({}, b.Expr(true), cont, b.Block());
+ b.Func("f", {}, b.ty.void_(), {var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a = 1;
+ loop {
+ if (!(true)) {
+ break;
+ }
+ {
+ }
+
+ continuing {
+ foo();
+ a += 1;
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, InsertBefore_ElseIf) {
+ // fn foo() {
+ // }
+ // fn f() {
+ // var a : bool;
+ // if (true) {
+ // } else if (a) {
+ // } else {
+ // }
+ // }
+ ProgramBuilder b;
+ b.Func("foo", {}, b.ty.void_(), {});
+ auto* var = b.Decl(b.Var("a", b.ty.bool_()));
+ auto* elseif = b.Else(b.Expr("a"), b.Block());
+ auto* s = b.If(b.Expr(true), b.Block(), //
+ elseif, //
+ b.Else(b.Block()));
+ b.Func("f", {}, b.ty.void_(), {var, s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ HoistToDeclBefore hoistToDeclBefore(ctx);
+ auto* before_stmt = ctx.src->Sem().Get(elseif);
+ auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
+ hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
+ hoistToDeclBefore.Apply();
+
+ ctx.Clone();
+ Program cloned(std::move(cloned_b));
+
+ auto* expect = R"(
+fn foo() {
+}
+
+fn f() {
+ var a : bool;
+ if (true) {
+ } else {
+ foo();
+ if (a) {
+ } else {
+ }
+ }
+}
+)";
+
+ EXPECT_EQ(expect, str(cloned));
+}
+
} // namespace
} // namespace tint::transform