CloneContext: Add support for transforming symbols
Will be used by a Renamer transform
Change-Id: Ic0e9b69874f51103f0beec7745d32a9f8419e93a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42841
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/clone_context.cc b/src/clone_context.cc
index 1ff9ff2..edc6e8b 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -28,6 +28,9 @@
CloneContext::~CloneContext() = default;
Symbol CloneContext::Clone(const Symbol& s) const {
+ if (symbol_transform_) {
+ return symbol_transform_(s);
+ }
return dst->Symbols().Register(src->Symbols().NameFor(s));
}
@@ -48,4 +51,9 @@
return dst->Diagnostics();
}
+CloneContext::CloneableTransform::CloneableTransform() = default;
+CloneContext::CloneableTransform::CloneableTransform(
+ const CloneableTransform&) = default;
+CloneContext::CloneableTransform::~CloneableTransform() = default;
+
} // namespace tint
diff --git a/src/clone_context.h b/src/clone_context.h
index 31c18c7..94851c7 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -18,6 +18,7 @@
#include <cassert>
#include <functional>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "src/castable.h"
@@ -48,7 +49,18 @@
/// CloneContext holds the state used while cloning AST nodes and types.
class CloneContext {
+ /// ParamTypeIsPtrOf<F, T>::value is true iff the first parameter of
+ /// F is a pointer of (or derives from) type T.
+ template <typename F, typename T>
+ using ParamTypeIsPtrOf = traits::IsTypeOrDerived<
+ typename std::remove_pointer<traits::ParamTypeT<F, 0>>::type,
+ T>;
+
public:
+ /// SymbolTransform is a function that takes a symbol and returns a new
+ /// symbol.
+ using SymbolTransform = std::function<Symbol(Symbol)>;
+
/// Constructor
/// @param to the target ProgramBuilder to clone into
/// @param from the source Program to clone from
@@ -199,10 +211,8 @@
/// `replacer` must be function-like with the signature: `T* (T*)`
/// where `T` is a type deriving from Cloneable.
///
- /// If `replacer` returns a nullptr then Clone() will attempt the next
- /// registered replacer function that matches the object type. If no replacers
- /// match the object type, or all returned nullptr then Clone() will call
- /// `T::Clone()` to clone the object.
+ /// If `replacer` returns a nullptr then Clone() will call `T::Clone()` to
+ /// clone the object.
///
/// Example:
///
@@ -218,6 +228,9 @@
/// ctx.Clone();
/// ```
///
+ /// @warning a single handler can only be registered for any given type.
+ /// Attempting to register two handlers for the same type will result in an
+ /// ICE.
/// @warning The replacement object must be of the correct type for all
/// references of the original object. A type mismatch will result in an
/// assertion in debug builds, and undefined behavior in release builds.
@@ -225,13 +238,44 @@
/// `T* (T*)`, where `T` derives from Cloneable
/// @returns this CloneContext so calls can be chained
template <typename F>
- CloneContext& ReplaceAll(F&& replacer) {
+ traits::EnableIf<ParamTypeIsPtrOf<F, Cloneable>::value, CloneContext>&
+ ReplaceAll(F&& replacer) {
using TPtr = traits::ParamTypeT<F, 0>;
using T = typename std::remove_pointer<TPtr>::type;
- transforms_.emplace_back([=](Cloneable* in) {
- auto* in_as_t = in->As<T>();
- return in_as_t != nullptr ? replacer(in_as_t) : nullptr;
- });
+ for (auto& transform : transforms_) {
+ if (transform.typeinfo->Is(TypeInfo::Of<T>()) ||
+ TypeInfo::Of<T>().Is(*transform.typeinfo)) {
+ TINT_ICE(Diagnostics())
+ << "ReplaceAll() called with a handler for type "
+ << TypeInfo::Of<T>().name
+ << " that is already handled by a handler for type "
+ << transform.typeinfo->name;
+ return *this;
+ }
+ }
+ CloneableTransform transform;
+ transform.typeinfo = &TypeInfo::Of<T>();
+ transform.function = [=](Cloneable* in) { return replacer(in->As<T>()); };
+ transforms_.emplace_back(std::move(transform));
+ return *this;
+ }
+
+ /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
+ /// is called with a Symbol.
+ /// The returned symbol of `replacer` will be used as the replacement for
+ /// all references to the symbol that's being cloned. This returned Symbol
+ /// must be owned by the Program #dst.
+ /// @param replacer a function the signature `Symbol(Symbol)`.
+ /// @warning a SymbolTransform can only be registered once. Attempting to
+ /// register a SymbolTransform more than once will result in an ICE.
+ /// @returns this CloneContext so calls can be chained
+ CloneContext& ReplaceAll(const SymbolTransform& replacer) {
+ if (symbol_transform_) {
+ TINT_ICE(Diagnostics()) << "ReplaceAll(const SymbolTransform&) called "
+ "multiple times on the same CloneContext";
+ return *this;
+ }
+ symbol_transform_ = replacer;
return *this;
}
@@ -276,7 +320,19 @@
Program const* const src;
private:
- using Transform = std::function<Cloneable*(Cloneable*)>;
+ struct CloneableTransform {
+ /// Constructor
+ CloneableTransform();
+ /// Copy constructor
+ /// @param other the CloneableTransform to copy
+ CloneableTransform(const CloneableTransform& other);
+ /// Destructor
+ ~CloneableTransform();
+
+ // TypeInfo of the Cloneable that the transform operates on
+ const TypeInfo* typeinfo;
+ std::function<Cloneable*(Cloneable*)> function;
+ };
CloneContext(const CloneContext&) = delete;
CloneContext& operator=(const CloneContext&) = delete;
@@ -293,11 +349,16 @@
}
// Attempt to clone using the registered replacer functions.
- for (auto& f : transforms_) {
- if (Cloneable* c = f(a)) {
+ auto& typeinfo = a->TypeInfo();
+ for (auto& transform : transforms_) {
+ if (!typeinfo.Is(*transform.typeinfo)) {
+ continue;
+ }
+ if (Cloneable* c = transform.function(a)) {
cloned_.emplace(a, c);
return c;
}
+ break;
}
// No luck, Clone() will have to call T::Clone().
@@ -329,8 +390,11 @@
/// into the target vector/ before cloning and inserting the map-key.
std::unordered_map<Cloneable*, CloneableList> insert_before_;
- /// Transform functions registered with ReplaceAll()
- std::vector<Transform> transforms_;
+ /// Cloneable transform functions registered with ReplaceAll()
+ std::vector<CloneableTransform> transforms_;
+
+ /// Symbol transform registered with ReplaceAll()
+ SymbolTransform symbol_transform_;
};
} // namespace tint
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index 5f38c99..99e1137 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -26,16 +26,16 @@
namespace {
struct Node : public Castable<Node, ast::Node> {
- explicit Node(const Source& source, std::string n) : Base(source), name(n) {}
+ explicit Node(const Source& source, Symbol n) : Base(source), name(n) {}
- std::string name;
+ Symbol name;
Node* a = nullptr;
Node* b = nullptr;
Node* c = nullptr;
std::vector<Node*> vec;
Node* Clone(CloneContext* ctx) const override {
- auto* out = ctx->dst->create<Node>(name);
+ auto* out = ctx->dst->create<Node>(ctx->Clone(name));
out->a = ctx->Clone(a);
out->b = ctx->Clone(b);
out->c = ctx->Clone(c);
@@ -48,10 +48,10 @@
};
struct Replaceable : public Castable<Replaceable, Node> {
- explicit Replaceable(const Source& source, std::string n) : Base(source, n) {}
+ explicit Replaceable(const Source& source, Symbol n) : Base(source, n) {}
};
struct Replacement : public Castable<Replacement, Replaceable> {
- explicit Replacement(const Source& source, std::string n) : Base(source, n) {}
+ explicit Replacement(const Source& source, Symbol n) : Base(source, n) {}
};
struct NotANode : public Castable<NotANode, ast::Node> {
@@ -67,12 +67,15 @@
TEST(CloneContext, Clone) {
ProgramBuilder builder;
- auto* original_root = builder.create<Node>("root");
- original_root->a = builder.create<Node>("a");
- original_root->a->b = builder.create<Node>("a->b");
- original_root->b = builder.create<Node>("b");
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->a->b =
+ builder.create<Node>(builder.Symbols().Register("a->b"));
+ original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->b->a = original_root->a; // Aliased
- original_root->b->b = builder.create<Node>("b->b");
+ original_root->b->b =
+ builder.create<Node>(builder.Symbols().Register("b->b"));
original_root->c = original_root->b; // Aliased
Program original(std::move(builder));
@@ -106,22 +109,25 @@
EXPECT_NE(cloned_root->b->b, original_root->b->b);
EXPECT_NE(cloned_root->c, original_root->c);
- EXPECT_EQ(cloned_root->name, "root");
- EXPECT_EQ(cloned_root->a->name, "a");
- EXPECT_EQ(cloned_root->a->b->name, "a->b");
- EXPECT_EQ(cloned_root->b->name, "b");
- EXPECT_EQ(cloned_root->b->b->name, "b->b");
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+ EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+ EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("a->b"));
+ EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("b"));
+ EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("b->b"));
EXPECT_EQ(cloned_root->b->a, cloned_root->a); // Aliased
EXPECT_EQ(cloned_root->c, cloned_root->b); // Aliased
}
-TEST(CloneContext, CloneWithReplacements) {
+TEST(CloneContext, CloneWithReplaceAll_Cloneable) {
ProgramBuilder builder;
- auto* original_root = builder.create<Node>("root");
- original_root->a = builder.create<Node>("a");
- original_root->a->b = builder.create<Replaceable>("a->b");
- original_root->b = builder.create<Replaceable>("b");
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->a->b =
+ builder.create<Replaceable>(builder.Symbols().Register("a->b"));
+ original_root->b =
+ builder.create<Replaceable>(builder.Symbols().Register("b"));
original_root->b->a = original_root->a; // Aliased
original_root->c = original_root->b; // Aliased
Program original(std::move(builder));
@@ -141,8 +147,12 @@
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([&](Replaceable* in) {
- auto* out = cloned.create<Replacement>("replacement:" + in->name);
- out->b = cloned.create<Node>("replacement-child:" + in->name);
+ auto out_name = cloned.Symbols().Register(
+ "replacement:" + original.Symbols().NameFor(in->name));
+ auto b_name = cloned.Symbols().Register(
+ "replacement-child:" + original.Symbols().NameFor(in->name));
+ auto* out = cloned.create<Replacement>(out_name);
+ out->b = cloned.create<Node>(b_name);
out->c = ctx.Clone(in->a);
return out;
});
@@ -181,12 +191,14 @@
EXPECT_NE(cloned_root->b->a, original_root->b->a);
EXPECT_NE(cloned_root->c, original_root->c);
- EXPECT_EQ(cloned_root->name, "root");
- EXPECT_EQ(cloned_root->a->name, "a");
- EXPECT_EQ(cloned_root->a->b->name, "replacement:a->b");
- EXPECT_EQ(cloned_root->a->b->b->name, "replacement-child:a->b");
- EXPECT_EQ(cloned_root->b->name, "replacement:b");
- EXPECT_EQ(cloned_root->b->b->name, "replacement-child:b");
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+ EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+ EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("replacement:a->b"));
+ EXPECT_EQ(cloned_root->a->b->b->name,
+ cloned.Symbols().Get("replacement-child:a->b"));
+ EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement:b"));
+ EXPECT_EQ(cloned_root->b->b->name,
+ cloned.Symbols().Get("replacement-child:b"));
EXPECT_EQ(cloned_root->b->c, cloned_root->a); // Aliased
EXPECT_EQ(cloned_root->c, cloned_root->b); // Aliased
@@ -198,12 +210,53 @@
EXPECT_FALSE(cloned_root->b->b->Is<Replacement>());
}
+TEST(CloneContext, CloneWithReplaceAll_Symbols) {
+ ProgramBuilder builder;
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->a->b =
+ builder.create<Node>(builder.Symbols().Register("a->b"));
+ original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+ original_root->b->a = original_root->a; // Aliased
+ original_root->b->b =
+ builder.create<Node>(builder.Symbols().Register("b->b"));
+ original_root->c = original_root->b; // Aliased
+ Program original(std::move(builder));
+
+ // root
+ // ╭──────────────────┼──────────────────╮
+ // (a) (b) (c)
+ // N <──────┐ N <───────────────┘
+ // ╭────┼────╮ │ ╭────┼────╮
+ // (a) (b) (c) │ (a) (b) (c)
+ // N └───┘ N
+ //
+ // N: Node
+
+ ProgramBuilder cloned;
+ auto* cloned_root = CloneContext(&cloned, &original)
+ .ReplaceAll([&](Symbol sym) {
+ auto in = original.Symbols().NameFor(sym);
+ auto out = "transformed<" + in + ">";
+ return cloned.Symbols().Register(out);
+ })
+ .Clone(original_root);
+
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("transformed<root>"));
+ EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("transformed<a>"));
+ EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("transformed<a->b>"));
+ EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("transformed<b>"));
+ EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("transformed<b->b>"));
+}
+
TEST(CloneContext, CloneWithReplace) {
ProgramBuilder builder;
- auto* original_root = builder.create<Node>("root");
- original_root->a = builder.create<Node>("a");
- original_root->b = builder.create<Node>("b");
- original_root->c = builder.create<Node>("c");
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+ original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
Program original(std::move(builder));
// root
@@ -212,7 +265,8 @@
// Replaced
ProgramBuilder cloned;
- auto* replacement = cloned.create<Node>("replacement");
+ auto* replacement =
+ cloned.create<Node>(cloned.Symbols().Register("replacement"));
auto* cloned_root = CloneContext(&cloned, &original)
.Replace(original_root->b, replacement)
@@ -222,23 +276,24 @@
EXPECT_EQ(cloned_root->b, replacement);
EXPECT_NE(cloned_root->c, replacement);
- EXPECT_EQ(cloned_root->name, "root");
- EXPECT_EQ(cloned_root->a->name, "a");
- EXPECT_EQ(cloned_root->b->name, "replacement");
- EXPECT_EQ(cloned_root->c->name, "c");
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+ EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
+ EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement"));
+ EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
}
TEST(CloneContext, CloneWithInsertBefore) {
ProgramBuilder builder;
- auto* original_root = builder.create<Node>("root");
- original_root->a = builder.create<Node>("a");
- original_root->b = builder.create<Node>("b");
- original_root->c = builder.create<Node>("c");
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
+ original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
original_root->vec = {original_root->a, original_root->b, original_root->c};
Program original(std::move(builder));
ProgramBuilder cloned;
- auto* insertion = cloned.create<Node>("insertion");
+ auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
auto* cloned_root = CloneContext(&cloned, &original)
.InsertBefore(original_root->b, insertion)
@@ -249,21 +304,77 @@
EXPECT_EQ(cloned_root->vec[2], cloned_root->b);
EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
- EXPECT_EQ(cloned_root->name, "root");
- EXPECT_EQ(cloned_root->vec[0]->name, "a");
- EXPECT_EQ(cloned_root->vec[1]->name, "insertion");
- EXPECT_EQ(cloned_root->vec[2]->name, "b");
- EXPECT_EQ(cloned_root->vec[3]->name, "c");
+ EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
+ EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
+ EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("insertion"));
+ EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("b"));
+ EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
+}
+
+TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder cloned;
+ Program original;
+ CloneContext ctx(&cloned, &original);
+ ctx.ReplaceAll([](Node*) { return nullptr; });
+ ctx.ReplaceAll([](Node*) { return nullptr; });
+ },
+ "internal compiler error: ReplaceAll() called with a handler for type "
+ "Node that is already handled by a handler for type Node");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_BaseThenDerived) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder cloned;
+ Program original;
+ CloneContext ctx(&cloned, &original);
+ ctx.ReplaceAll([](Node*) { return nullptr; });
+ ctx.ReplaceAll([](Replaceable*) { return nullptr; });
+ },
+ "internal compiler error: ReplaceAll() called with a handler for type "
+ "Replaceable that is already handled by a handler for type Node");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_DerivedThenBase) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder cloned;
+ Program original;
+ CloneContext ctx(&cloned, &original);
+ ctx.ReplaceAll([](Replaceable*) { return nullptr; });
+ ctx.ReplaceAll([](Node*) { return nullptr; });
+ },
+ "internal compiler error: ReplaceAll() called with a handler for type "
+ "Node that is already handled by a handler for type Replaceable");
+}
+
+TEST(CloneContext, CloneWithReplaceAll_SymbolsTwice) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder cloned;
+ Program original;
+ CloneContext ctx(&cloned, &original);
+ ctx.ReplaceAll([](Symbol s) { return s; });
+ ctx.ReplaceAll([](Symbol s) { return s; });
+ },
+ "internal compiler error: ReplaceAll(const SymbolTransform&) called "
+ "multiple times on the same CloneContext");
}
TEST(CloneContext, CloneWithReplace_WithNotANode) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder builder;
- auto* original_root = builder.create<Node>("root");
- original_root->a = builder.create<Node>("a");
- original_root->b = builder.create<Node>("b");
- original_root->c = builder.create<Node>("c");
+ auto* original_root =
+ builder.create<Node>(builder.Symbols().Register("root"));
+ original_root->a =
+ builder.create<Node>(builder.Symbols().Register("a"));
+ original_root->b =
+ builder.create<Node>(builder.Symbols().Register("b"));
+ original_root->c =
+ builder.create<Node>(builder.Symbols().Register("c"));
Program original(std::move(builder));
// root