Factor out GetInsertionPoint to transform/utils
This function was copy-pasted in two transforms, and will be used in the
next one I'm writing.
Bug: tint:1080
Change-Id: Ic5ffe68a7e9d00b37722e8f5faff01e9e15fa6b1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85262
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 1646c85..a2541bd 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -489,6 +489,8 @@
"transform/unshadow.h",
"transform/unwind_discard_functions.cc",
"transform/unwind_discard_functions.h",
+ "transform/utils/get_insertion_point.cc",
+ "transform/utils/get_insertion_point.h",
"transform/utils/hoist_to_decl_before.cc",
"transform/utils/hoist_to_decl_before.h",
"transform/var_for_dynamic_index.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 1f5e90c..0a3ce1c 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -377,6 +377,8 @@
transform/wrap_arrays_in_structs.h
transform/zero_init_workgroup_memory.cc
transform/zero_init_workgroup_memory.h
+ transform/utils/get_insertion_point.cc
+ transform/utils/get_insertion_point.h
transform/utils/hoist_to_decl_before.cc
transform/utils/hoist_to_decl_before.h
sem/bool_type.cc
@@ -1039,6 +1041,7 @@
transform/vertex_pulling_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc
+ transform/utils/get_insertion_point_test.cc
transform/utils/hoist_to_decl_before_test.cc
)
endif()
diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc
index e1d5ab2..9fd19db 100644
--- a/src/tint/transform/promote_side_effects_to_decl.cc
+++ b/src/tint/transform/promote_side_effects_to_decl.cc
@@ -28,6 +28,7 @@
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/manager.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include "src/tint/utils/scoped_assignment.h"
@@ -556,33 +557,11 @@
});
}
- // For the input statement, returns the block and statement within that block
- // to insert before/after.
- std::pair<const sem::BlockStatement*, const ast::Statement*>
- GetInsertionPoint(const ast::Statement* stmt) {
- auto* sem_stmt = sem.Get(stmt);
- if (sem_stmt) {
- auto* parent = sem_stmt->Parent();
- if (auto* block = parent->As<sem::BlockStatement>()) {
- // Common case, just insert in the current block above the input
- // statement.
- return {block, stmt};
- }
- if (auto* fl = parent->As<sem::ForLoopStatement>()) {
- if (fl->Declaration()->initializer == stmt) {
- // For loop init, insert above the for loop itself.
- return {fl->Block(), fl->Declaration()};
- }
- }
- }
- return {};
- }
-
// Inserts statements in `stmts` before `stmt`
void InsertBefore(const ast::StatementList& stmts,
const ast::Statement* stmt) {
if (!stmts.empty()) {
- auto ip = GetInsertionPoint(stmt);
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
for (auto* s : stmts) {
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s);
}
diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc
index 80a47b1..b15f3a0 100644
--- a/src/tint/transform/unwind_discard_functions.cc
+++ b/src/tint/transform/unwind_discard_functions.cc
@@ -28,6 +28,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/if_statement.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
@@ -42,46 +43,6 @@
Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read
Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read
- // For the input statement, returns the block and statement within that
- // block to insert before/after.
- std::pair<const sem::BlockStatement*, const ast::Statement*>
- GetInsertionPoint(const ast::Statement* stmt) {
- using RetType =
- std::pair<const sem::BlockStatement*, const ast::Statement*>;
-
- if (auto* sem_stmt = sem.Get(stmt)) {
- auto* parent = sem_stmt->Parent();
- return Switch(
- parent,
- [&](const sem::BlockStatement* block) -> RetType {
- // Common case, just insert in the current block above the input
- // statement.
- return {block, stmt};
- },
- [&](const sem::ForLoopStatement* fl) -> RetType {
- // `stmt` is either the for loop initializer or the continuing
- // statement of a for-loop.
- if (fl->Declaration()->initializer == stmt) {
- // For loop init, insert above the for loop itself.
- return {fl->Block(), fl->Declaration()};
- }
-
- TINT_ICE(Transform, b.Diagnostics())
- << "cannot insert before or after continuing statement of a "
- "for-loop";
- return {};
- },
- [&](Default) -> RetType {
- TINT_ICE(Transform, b.Diagnostics())
- << "expected parent of statement to be either a block or for "
- "loop";
- return {};
- });
- }
-
- return {};
- }
-
// If `block`'s parent is of type TO, returns pointer to it.
template <typename TO>
const TO* ParentAs(const ast::BlockStatement* block) {
@@ -186,7 +147,7 @@
const sem::Expression* sem_expr) {
auto* expr = sem_expr->Declaration();
- auto ip = GetInsertionPoint(stmt);
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
auto var_name = b.Sym();
auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
@@ -239,7 +200,7 @@
return HoistAndInsertBefore(stmt, sem_expr);
}
- auto ip = GetInsertionPoint(stmt);
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
IfDiscardReturn(stmt));
return nullptr; // Don't replace current statement
@@ -269,7 +230,7 @@
to_insert = b.Assign(var_name, true);
}
- auto ip = GetInsertionPoint(stmt);
+ auto ip = utils::GetInsertionPoint(ctx, stmt);
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
return Return(stmt);
}
diff --git a/src/tint/transform/utils/get_insertion_point.cc b/src/tint/transform/utils/get_insertion_point.cc
new file mode 100644
index 0000000..0f00e0c
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point.cc
@@ -0,0 +1,58 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/utils/get_insertion_point.h"
+#include "src/tint/debug.h"
+#include "src/tint/diagnostic/diagnostic.h"
+#include "src/tint/sem/for_loop_statement.h"
+
+namespace tint::transform::utils {
+
+InsertionPoint GetInsertionPoint(CloneContext& ctx,
+ const ast::Statement* stmt) {
+ auto& sem = ctx.src->Sem();
+ auto& diag = ctx.dst->Diagnostics();
+ using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
+
+ if (auto* sem_stmt = sem.Get(stmt)) {
+ auto* parent = sem_stmt->Parent();
+ return Switch(
+ parent,
+ [&](const sem::BlockStatement* block) -> RetType {
+ // Common case, can insert in the current block above/below the input
+ // statement.
+ return {block, stmt};
+ },
+ [&](const sem::ForLoopStatement* fl) -> RetType {
+ // `stmt` is either the for loop initializer or the continuing
+ // statement of a for-loop.
+ if (fl->Declaration()->initializer == stmt) {
+ // For loop init, can insert above the for loop itself.
+ return {fl->Block(), fl->Declaration()};
+ }
+
+ // Cannot insert before or after continuing statement of a for-loop
+ return {};
+ },
+ [&](Default) -> RetType {
+ TINT_ICE(Transform, diag) << "expected parent of statement to be "
+ "either a block or for loop";
+ return {};
+ });
+ }
+
+ return {};
+}
+
+} // namespace tint::transform::utils
diff --git a/src/tint/transform/utils/get_insertion_point.h b/src/tint/transform/utils/get_insertion_point.h
new file mode 100644
index 0000000..85abcea
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point.h
@@ -0,0 +1,40 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+#define SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
+
+#include <utility>
+
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+
+namespace tint::transform::utils {
+
+/// InsertionPoint is a pair of the block (`first`) within which, and the
+/// statement (`second`) before or after which to insert.
+using InsertionPoint =
+ std::pair<const sem::BlockStatement*, const ast::Statement*>;
+
+/// For the input statement, returns the block and statement within that
+/// block to insert before/after. If `stmt` is a for-loop continue statement,
+/// the function returns {nullptr, nullptr} as we cannot insert before/after it.
+/// @param ctx the clone context
+/// @param stmt the statement to insert before or after
+/// @return the insertion point
+InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt);
+
+} // namespace tint::transform::utils
+
+#endif // SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
diff --git a/src/tint/transform/utils/get_insertion_point_test.cc b/src/tint/transform/utils/get_insertion_point_test.cc
new file mode 100644
index 0000000..48e358e
--- /dev/null
+++ b/src/tint/transform/utils/get_insertion_point_test.cc
@@ -0,0 +1,94 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "gtest/gtest-spi.h"
+#include "src/tint/debug.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/transform/test_helper.h"
+#include "src/tint/transform/utils/get_insertion_point.h"
+
+namespace tint::transform {
+namespace {
+
+using GetInsertionPointTest = ::testing::Test;
+
+TEST_F(GetInsertionPointTest, Block) {
+ // fn f() {
+ // var a = 1;
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* block = b.Block(var);
+ b.Func("f", {}, b.ty.void_(), {block});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can insert in block containing the variable, above or below the input
+ // statement.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), block);
+ ASSERT_EQ(ip.second, var);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopInit) {
+ // fn f() {
+ // for(var a = 1; true; ) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* fl = b.For(var, b.Expr(true), {}, b.Block());
+ auto* func_block = b.Block(fl);
+ b.Func("f", {}, b.ty.void_(), {func_block});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can insert in block containing for-loop above the for-loop itself.
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first->Declaration(), func_block);
+ ASSERT_EQ(ip.second, fl);
+}
+
+TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) {
+ // fn f() {
+ // for(; true; var a = 1) {
+ // }
+ // }
+ ProgramBuilder b;
+ auto* expr = b.Expr(1);
+ auto* var = b.Decl(b.Var("a", nullptr, expr));
+ auto* s = b.For({}, b.Expr(true), var, b.Block());
+ b.Func("f", {}, b.ty.void_(), {s});
+
+ Program original(std::move(b));
+ ProgramBuilder cloned_b;
+ CloneContext ctx(&cloned_b, &original);
+
+ // Can't insert before/after for loop continue statement (would ned to be
+ // converted to loop).
+ auto ip = utils::GetInsertionPoint(ctx, var);
+ ASSERT_EQ(ip.first, nullptr);
+ ASSERT_EQ(ip.second, nullptr);
+}
+
+} // namespace
+} // namespace tint::transform
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index 7f6ec7e..6bdd527 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -339,6 +339,7 @@
"../../src/tint/transform/transform_test.cc",
"../../src/tint/transform/unshadow_test.cc",
"../../src/tint/transform/unwind_discard_functions_test.cc",
+ "../../src/tint/transform/utils/get_insertion_point_test.cc",
"../../src/tint/transform/utils/hoist_to_decl_before_test.cc",
"../../src/tint/transform/var_for_dynamic_index_test.cc",
"../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",