CloneContext: Pass the vector to InsertBefore()
There's usually only ever one vector we want to insert into.
Inserting into *all* vectors that happen to contain the reference object is likely unintended, and is a foot-gun waiting to go off.
Change-Id: I533084ccad102fc998b851fd238fd6bea9299450
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46445
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 06a28e1..f449b0b 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -20,6 +20,9 @@
namespace tint {
+CloneContext::ListTransforms::ListTransforms() = default;
+CloneContext::ListTransforms::~ListTransforms() = default;
+
CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
: dst(to), src(from) {}
CloneContext::~CloneContext() = default;
diff --git a/src/clone_context.h b/src/clone_context.h
index 2853ad4..fd7a0ac 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -15,6 +15,7 @@
#ifndef SRC_CLONE_CONTEXT_H_
#define SRC_CLONE_CONTEXT_H_
+#include <algorithm>
#include <functional>
#include <unordered_map>
#include <utility>
@@ -178,14 +179,29 @@
std::vector<T*> Clone(const std::vector<T*>& v) {
std::vector<T*> out;
out.reserve(v.size());
- for (auto& el : v) {
- auto it = insert_before_.find(el);
- if (it != insert_before_.end()) {
- for (auto insert : it->second) {
- out.emplace_back(CheckedCast<T>(insert));
+
+ auto list_transform_it = list_transforms_.find(&v);
+ if (list_transform_it != list_transforms_.end()) {
+ const auto& transforms = list_transform_it->second;
+ for (auto& el : v) {
+ auto insert_before_it = transforms.insert_before_.find(el);
+ if (insert_before_it != transforms.insert_before_.end()) {
+ for (auto insert : insert_before_it->second) {
+ out.emplace_back(CheckedCast<T>(insert));
+ }
+ }
+ out.emplace_back(Clone(el));
+ auto insert_after_it = transforms.insert_after_.find(el);
+ if (insert_after_it != transforms.insert_after_.end()) {
+ for (auto insert : insert_after_it->second) {
+ out.emplace_back(CheckedCast<T>(insert));
+ }
}
}
- out.emplace_back(Clone(el));
+ } else {
+ for (auto& el : v) {
+ out.emplace_back(Clone(el));
+ }
}
return out;
}
@@ -293,15 +309,46 @@
return *this;
}
- /// Inserts `object` before `before` whenever a vector containing `object` is
- /// cloned.
+ /// Inserts `object` before `before` whenever `vector` is cloned.
+ /// @param vector the vector in #src
/// @param before a pointer to the object in #src
/// @param object a pointer to the object in #dst that will be inserted before
/// any occurrence of the clone of `before`
/// @returns this CloneContext so calls can be chained
- template <typename BEFORE, typename OBJECT>
- CloneContext& InsertBefore(BEFORE* before, OBJECT* object) {
- auto& list = insert_before_[before];
+ template <typename T, typename BEFORE, typename OBJECT>
+ CloneContext& InsertBefore(const std::vector<T>& vector,
+ BEFORE* before,
+ OBJECT* object) {
+ if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
+ TINT_ICE(Diagnostics())
+ << "CloneContext::InsertBefore() vector does not contain before";
+ return *this;
+ }
+
+ auto& transforms = list_transforms_[&vector];
+ auto& list = transforms.insert_before_[before];
+ list.emplace_back(object);
+ return *this;
+ }
+
+ /// Inserts `object` after `after` whenever `vector` is cloned.
+ /// @param vector the vector in #src
+ /// @param after a pointer to the object in #src
+ /// @param object a pointer to the object in #dst that will be inserted after
+ /// any occurrence of the clone of `after`
+ /// @returns this CloneContext so calls can be chained
+ template <typename T, typename AFTER, typename OBJECT>
+ CloneContext& InsertAfter(const std::vector<T>& vector,
+ AFTER* after,
+ OBJECT* object) {
+ if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
+ TINT_ICE(Diagnostics())
+ << "CloneContext::InsertAfter() vector does not contain after";
+ return *this;
+ }
+
+ auto& transforms = list_transforms_[&vector];
+ auto& list = transforms.insert_after_[after];
list.emplace_back(object);
return *this;
}
@@ -380,17 +427,33 @@
/// A vector of Cloneable*
using CloneableList = std::vector<Cloneable*>;
+ // Transformations to be applied to a list (vector)
+ struct ListTransforms {
+ /// Constructor
+ ListTransforms();
+ /// Destructor
+ ~ListTransforms();
+
+ /// A map of object in #src to the list of cloned objects in #dst.
+ /// Clone(const std::vector<T*>& v) will use this to insert the map-value
+ /// list into the target vector before cloning and inserting the map-key.
+ std::unordered_map<const Cloneable*, CloneableList> insert_before_;
+
+ /// A map of object in #src to the list of cloned objects in #dst.
+ /// Clone(const std::vector<T*>& v) will use this to insert the map-value
+ /// list into the target vector after cloning and inserting the map-key.
+ std::unordered_map<const Cloneable*, CloneableList> insert_after_;
+ };
+
/// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
- /// A map of object in #src to the list of cloned objects in #dst.
- /// Clone(const std::vector<T*>& v) will use this to insert the map-value list
- /// into the target vector/ before cloning and inserting the map-key.
- std::unordered_map<const Cloneable*, CloneableList> insert_before_;
-
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;
+ /// Map of std::vector pointer to transforms for that list
+ std::unordered_map<const void*, ListTransforms> list_transforms_;
+
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;
};
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 372da30..b09ccdc 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -287,9 +287,10 @@
ProgramBuilder cloned;
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
- auto* cloned_root = CloneContext(&cloned, &original)
- .InsertBefore(original_root->b, insertion)
- .Clone(original_root);
+ auto* cloned_root =
+ CloneContext(&cloned, &original)
+ .InsertBefore(original_root->vec, original_root->b, insertion)
+ .Clone(original_root);
EXPECT_EQ(cloned_root->vec.size(), 4u);
EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
@@ -303,6 +304,36 @@
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
}
+TEST(CloneContext, CloneWithInsertAfter) {
+ ProgramBuilder builder;
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+ original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
+ original_root->vec = {original_root->a, original_root->b, original_root->c};
+ Program original(std::move(builder));
+
+ ProgramBuilder cloned;
+ auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
+
+ auto* cloned_root =
+ CloneContext(&cloned, &original)
+ .InsertAfter(original_root->vec, original_root->b, insertion)
+ .Clone(original_root);
+
+ EXPECT_EQ(cloned_root->vec.size(), 4u);
+ EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
+ EXPECT_EQ(cloned_root->vec[1], cloned_root->b);
+ EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
+
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+ EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
+ EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("b"));
+ EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("insertion"));
+ EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
+}
+
TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index ca4af82..1c52066 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -18,6 +18,7 @@
#include "src/program_builder.h"
#include "src/semantic/function.h"
+#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
namespace tint {
@@ -119,7 +120,7 @@
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
- ctx.InsertBefore(*func->body()->begin(),
+ ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
@@ -134,7 +135,7 @@
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(new_struct_members,
ast::DecorationList{}));
- ctx.InsertBefore(func, in_struct);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
// Create a new function parameter using this struct type.
auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct,
@@ -177,12 +178,13 @@
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(new_struct_members,
ast::DecorationList{}));
- ctx.InsertBefore(func, out_struct);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
new_ret_type = out_struct;
// Replace all return statements.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
+ auto* ret_sem = ctx.src->Sem().Get(ret);
// Reconstruct the return value using the newly created struct.
auto* new_ret_value = ctx.Clone(ret->value());
ast::ExpressionList ret_values;
@@ -193,7 +195,7 @@
auto temp = ctx.dst->Symbols().New();
auto* temp_var = ctx.dst->Decl(
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value));
- ctx.InsertBefore(ret, temp_var);
+ ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
new_ret_value = ctx.dst->Expr(temp);
}
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 21b9bf7..0a08e24 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -96,7 +96,8 @@
auto* dst_ident = ctx.dst->Expr(dst_symbol);
// Insert the constant before the usage
- ctx.InsertBefore(src_stmt, dst_var_decl);
+ ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
+ dst_var_decl);
// Replace the inlined array with a reference to the constant
ctx.Replace(src_init, dst_ident);
}
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 1eb7a6a..d29cfa4 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -21,6 +21,7 @@
#include "src/ast/return_statement.h"
#include "src/program_builder.h"
#include "src/semantic/function.h"
+#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
namespace tint {
@@ -162,13 +163,15 @@
return_func_symbol, ast::VariableList{store_value},
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
ast::DecorationList{}, ast::DecorationList{});
- ctx.InsertBefore(func, return_func);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
// Replace all return statements with calls to the output function.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
+ auto* ret_sem = ctx.src->Sem().Get(ret);
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
- ctx.InsertBefore(ret, ctx.dst->create<ast::CallStatement>(call));
+ ctx.InsertBefore(ret_sem->Block()->statements(), ret,
+ ctx.dst->create<ast::CallStatement>(call));
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
}
}
@@ -247,7 +250,7 @@
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ast::StorageClass::kInput, nullptr, new_decorations);
- ctx.InsertBefore(func, global_var);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
return global_var_symbol;
}
@@ -269,7 +272,8 @@
// Create a function-scope variable for the struct.
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
- ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var));
+ ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
+ ctx.dst->WrapInStatement(func_var));
return func_var_symbol;
}
@@ -292,7 +296,7 @@
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ast::StorageClass::kOutput, nullptr, new_decorations);
- ctx.InsertBefore(func, global_var);
+ ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
// Create the assignment instruction.
ast::Expression* rhs = ctx.dst->Expr(store_value);