CloneContext: Don't create named symbols from unnamed
Registering a new Symbol with the NameFor() of the source symbol creates
a new *named* symbol. When mixing these with unnamed symbols we can have
collisions.
Update CloneContext::Clone(Symbol) to properly clone unnamed symbols.
Update (most) the transforms to ctx.Clone() the symbols instead of
registering the names directly.
Fix up the tests where the symbol IDs have changed.
Note: We can still have symbol collisions if a program is authored with
identifiers like 'tint_symbol_3'. This will be fixed up in a later
change.
Change-Id: I0ce559644da3d60e1060f2eef185fa55ae284521
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46866
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 4fa5746..901dd29 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -457,6 +457,7 @@
"type/vector_type.h",
"type/void_type.cc",
"type/void_type.h",
+ "utils/get_or_create.h",
"utils/hash.h",
"utils/math.h",
"utils/unique_vector.h",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 53810cb..8216dba 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -272,6 +272,7 @@
type/vector_type.h
type/void_type.cc
type/void_type.h
+ utils/get_or_create.h
utils/hash.h
utils/math.h
utils/unique_vector.h
@@ -519,6 +520,7 @@
type/vector_type_test.cc
utils/command_test.cc
utils/command.h
+ utils/get_or_create_test.cc
utils/hash_test.cc
utils/math_test.cc
utils/tmpfile_test.cc
diff --git a/src/clone_context.cc b/src/clone_context.cc
index f449b0b..a055f4d 100644
--- a/src/clone_context.cc
+++ b/src/clone_context.cc
@@ -15,6 +15,7 @@
#include "src/clone_context.h"
#include "src/program_builder.h"
+#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::Cloneable);
@@ -27,11 +28,16 @@
: dst(to), src(from) {}
CloneContext::~CloneContext() = default;
-Symbol CloneContext::Clone(const Symbol& s) const {
- if (symbol_transform_) {
- return symbol_transform_(s);
- }
- return dst->Symbols().Register(src->Symbols().NameFor(s));
+Symbol CloneContext::Clone(Symbol s) {
+ return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol {
+ if (symbol_transform_) {
+ return symbol_transform_(s);
+ }
+ if (!src->Symbols().HasName(s)) {
+ return dst->Symbols().New();
+ }
+ return dst->Symbols().Register(src->Symbols().NameFor(s));
+ });
}
void CloneContext::Clone() {
diff --git a/src/clone_context.h b/src/clone_context.h
index fd7a0ac..34e366b 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -148,7 +148,7 @@
///
/// @param s the Symbol to clone
/// @return the cloned source
- Symbol Clone(const Symbol& s) const;
+ Symbol Clone(Symbol s);
/// Clones each of the elements of the vector `v` into the ProgramBuilder
/// #dst.
@@ -448,6 +448,9 @@
/// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
+ /// A map of symbol in #src to their cloned equivalent in #dst
+ std::unordered_map<Symbol, Symbol> cloned_symbols_;
+
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;
diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc
index b09ccdc..3b30680 100644
--- a/src/clone_context_test.cc
+++ b/src/clone_context_test.cc
@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "gtest/gtest-spi.h"
+#include <unordered_set>
+#include "gtest/gtest-spi.h"
#include "src/program_builder.h"
namespace tint {
@@ -416,6 +417,27 @@
"internal compiler error");
}
+TEST(CloneContext, CloneUnnamedSymbols) {
+ ProgramBuilder builder;
+ Symbol old_a = builder.Symbols().New();
+ Symbol old_b = builder.Symbols().New();
+ Symbol old_c = builder.Symbols().New();
+
+ Program original(std::move(builder));
+
+ ProgramBuilder cloned;
+ CloneContext ctx(&cloned, &original);
+ Symbol new_a = ctx.Clone(old_a);
+ Symbol new_x = cloned.Symbols().New();
+ Symbol new_b = ctx.Clone(old_b);
+ Symbol new_y = cloned.Symbols().New();
+ Symbol new_c = ctx.Clone(old_c);
+ Symbol new_z = cloned.Symbols().New();
+
+ std::unordered_set<Symbol> all{new_a, new_x, new_b, new_y, new_c, new_z};
+ EXPECT_EQ(all.size(), 6u);
+}
+
} // namespace
TINT_INSTANTIATE_TYPEINFO(Node);
diff --git a/src/program_builder.h b/src/program_builder.h
index b063779..857bd4d 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -413,7 +413,7 @@
/// @param subtype the array element type
/// @param n the array size. 0 represents a runtime-array.
/// @return the tint AST type for a array of size `n` of type `T`
- type::Array* array(type::Type* subtype, uint32_t n) const {
+ type::Array* array(type::Type* subtype, uint32_t n = 0) const {
return builder->create<type::Array>(subtype, n, ast::DecorationList{});
}
@@ -490,6 +490,14 @@
// AST helper methods
//////////////////////////////////////////////////////////////////////////////
+ /// @param name the symbol string
+ /// @return a Symbol with the given name
+ Symbol Sym(const std::string& name) { return Symbols().Register(name); }
+
+ /// @param sym the symbol
+ /// @return `sym`
+ Symbol Sym(Symbol sym) { return sym; }
+
/// @param expr the expression
/// @return expr
template <typename T>
@@ -775,13 +783,14 @@
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type
- ast::Variable* Var(const std::string& name,
+ template <typename NAME>
+ ast::Variable* Var(NAME&& name,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
- return create<ast::Variable>(Symbols().Register(name), storage, type, false,
- constructor, decorations);
+ return create<ast::Variable>(Sym(std::forward<NAME>(name)), storage, type,
+ false, constructor, decorations);
}
/// @param source the variable source
@@ -791,58 +800,28 @@
/// @param constructor constructor expression
/// @param decorations variable decorations
/// @returns a `ast::Variable` with the given name, storage and type
+ template <typename NAME>
ast::Variable* Var(const Source& source,
- const std::string& name,
+ NAME&& name,
type::Type* type,
ast::StorageClass storage,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
- return create<ast::Variable>(source, Symbols().Register(name), storage,
+ return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), storage,
type, false, constructor, decorations);
}
- /// @param symbol the variable symbol
- /// @param type the variable type
- /// @param storage the variable storage class
- /// @param constructor constructor expression
- /// @param decorations variable decorations
- /// @returns a `ast::Variable` with the given symbol, storage and type
- ast::Variable* Var(Symbol symbol,
- type::Type* type,
- ast::StorageClass storage,
- ast::Expression* constructor = nullptr,
- ast::DecorationList decorations = {}) {
- return create<ast::Variable>(symbol, storage, type, false, constructor,
- decorations);
- }
-
- /// @param source the variable source
- /// @param symbol the variable symbol
- /// @param type the variable type
- /// @param storage the variable storage class
- /// @param constructor constructor expression
- /// @param decorations variable decorations
- /// @returns a `ast::Variable` with the given symbol, storage and type
- ast::Variable* Var(const Source& source,
- Symbol symbol,
- type::Type* type,
- ast::StorageClass storage,
- ast::Expression* constructor = nullptr,
- ast::DecorationList decorations = {}) {
- return create<ast::Variable>(source, symbol, storage, type, false,
- constructor, decorations);
- }
-
/// @param name the variable name
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type
- ast::Variable* Const(const std::string& name,
+ template <typename NAME>
+ ast::Variable* Const(NAME&& name,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
- return create<ast::Variable>(Symbols().Register(name),
+ return create<ast::Variable>(Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true,
constructor, decorations);
}
@@ -853,46 +832,17 @@
/// @param constructor optional constructor expression
/// @param decorations optional variable decorations
/// @returns a constant `ast::Variable` with the given name, storage and type
+ template <typename NAME>
ast::Variable* Const(const Source& source,
- const std::string& name,
+ NAME&& name,
type::Type* type,
ast::Expression* constructor = nullptr,
ast::DecorationList decorations = {}) {
- return create<ast::Variable>(source, Symbols().Register(name),
+ return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, type, true,
constructor, decorations);
}
- /// @param symbol the variable symbol
- /// @param type the variable type
- /// @param constructor optional constructor expression
- /// @param decorations optional variable decorations
- /// @returns a constant `ast::Variable` with the given symbol, storage and
- /// type
- ast::Variable* Const(Symbol symbol,
- type::Type* type,
- ast::Expression* constructor = nullptr,
- ast::DecorationList decorations = {}) {
- return create<ast::Variable>(symbol, ast::StorageClass::kNone, type, true,
- constructor, decorations);
- }
-
- /// @param source the variable source
- /// @param symbol the variable symbol
- /// @param type the variable type
- /// @param constructor optional constructor expression
- /// @param decorations optional variable decorations
- /// @returns a constant `ast::Variable` with the given symbol, storage and
- /// type
- ast::Variable* Const(const Source& source,
- Symbol symbol,
- type::Type* type,
- ast::Expression* constructor = nullptr,
- ast::DecorationList decorations = {}) {
- return create<ast::Variable>(source, symbol, ast::StorageClass::kNone, type,
- true, constructor, decorations);
- }
-
/// @param args the arguments to pass to Var()
/// @returns a `ast::Variable` constructed by calling Var() with the arguments
/// of `args`, which is automatically registered as a global variable with the
@@ -966,6 +916,16 @@
Expr(std::forward<RHS>(rhs)));
}
+ /// @param lhs the left hand argument to the division operation
+ /// @param rhs the right hand argument to the division operation
+ /// @returns a `ast::BinaryExpression` dividing `lhs` by `rhs`
+ template <typename LHS, typename RHS>
+ ast::Expression* Div(LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(ast::BinaryOp::kDivide,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param arr the array argument for the array accessor expression
/// @param idx the index argument for the array accessor expression
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
@@ -1027,19 +987,22 @@
/// @param params the function parameters
/// @param type the function return type
/// @param body the function body
- /// @param decorations the function decorations
- /// @param return_type_decorations the function return type decorations
+ /// @param decorations the optional function decorations
+ /// @param return_type_decorations the optional function return type
+ /// decorations
/// @returns the function pointer
+ template <typename NAME>
ast::Function* Func(Source source,
- std::string name,
+ NAME&& name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
- auto* func = create<ast::Function>(source, Symbols().Register(name), params,
- type, create<ast::BlockStatement>(body),
- decorations, return_type_decorations);
+ auto* func =
+ create<ast::Function>(source, Sym(std::forward<NAME>(name)), params,
+ type, create<ast::BlockStatement>(body),
+ decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}
@@ -1049,17 +1012,19 @@
/// @param params the function parameters
/// @param type the function return type
/// @param body the function body
- /// @param decorations the function decorations
- /// @param return_type_decorations the function return type decorations
+ /// @param decorations the optional function decorations
+ /// @param return_type_decorations the optional function return type
+ /// decorations
/// @returns the function pointer
- ast::Function* Func(std::string name,
+ template <typename NAME>
+ ast::Function* Func(NAME&& name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
ast::DecorationList decorations = {},
ast::DecorationList return_type_decorations = {}) {
- auto* func = create<ast::Function>(Symbols().Register(name), params, type,
- create<ast::BlockStatement>(body),
+ auto* func = create<ast::Function>(Sym(std::forward<NAME>(name)), params,
+ type, create<ast::BlockStatement>(body),
decorations, return_type_decorations);
AST().AddFunction(func);
return func;
@@ -1113,12 +1078,13 @@
/// @param type the struct member type
/// @param decorations the optional struct member decorations
/// @returns the struct member pointer
+ template <typename NAME>
ast::StructMember* Member(const Source& source,
- const std::string& name,
+ NAME&& name,
type::Type* type,
ast::DecorationList decorations = {}) {
- return create<ast::StructMember>(source, Symbols().Register(name), type,
- std::move(decorations));
+ return create<ast::StructMember>(source, Sym(std::forward<NAME>(name)),
+ type, std::move(decorations));
}
/// Creates a ast::StructMember
@@ -1126,11 +1092,12 @@
/// @param type the struct member type
/// @param decorations the optional struct member decorations
/// @returns the struct member pointer
- ast::StructMember* Member(const std::string& name,
+ template <typename NAME>
+ ast::StructMember* Member(NAME&& name,
type::Type* type,
ast::DecorationList decorations = {}) {
- return create<ast::StructMember>(source_, Symbols().Register(name), type,
- std::move(decorations));
+ return create<ast::StructMember>(source_, Sym(std::forward<NAME>(name)),
+ type, std::move(decorations));
}
/// Creates a ast::StructMember with the given byte offset
@@ -1138,11 +1105,10 @@
/// @param name the struct member name
/// @param type the struct member type
/// @returns the struct member pointer
- ast::StructMember* Member(uint32_t offset,
- const std::string& name,
- type::Type* type) {
+ template <typename NAME>
+ ast::StructMember* Member(uint32_t offset, NAME&& name, type::Type* type) {
return create<ast::StructMember>(
- source_, Symbols().Register(name), type,
+ source_, Sym(std::forward<NAME>(name)), type,
ast::DecorationList{
create<ast::StructMemberOffsetDecoration>(offset),
});
diff --git a/src/symbol_table.cc b/src/symbol_table.cc
index 16f28d2..5d7edba 100644
--- a/src/symbol_table.cc
+++ b/src/symbol_table.cc
@@ -50,6 +50,11 @@
return it != name_to_symbol_.end() ? it->second : Symbol();
}
+bool SymbolTable::HasName(const Symbol symbol) const {
+ auto it = symbol_to_name_.find(symbol);
+ return it != symbol_to_name_.end();
+}
+
std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end()) {
diff --git a/src/symbol_table.h b/src/symbol_table.h
index 107fe16..b3e1cae 100644
--- a/src/symbol_table.h
+++ b/src/symbol_table.h
@@ -53,6 +53,10 @@
/// @returns the symbol for the name or symbol::kInvalid if not found.
Symbol Get(const std::string& name) const;
+ /// @returns true if the symbol has a name
+ /// @param symbol the symbol to query
+ bool HasName(const Symbol symbol) const;
+
/// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found
diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc
index 1c52066..b263767 100644
--- a/src/transform/canonicalize_entry_point_io.cc
+++ b/src/transform/canonicalize_entry_point_io.cc
@@ -45,7 +45,7 @@
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
- ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()),
+ ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
@@ -70,10 +70,9 @@
auto new_struct_param_symbol = ctx.dst->Symbols().New();
ast::StructMemberList new_struct_members;
for (auto* param : func->params()) {
- auto param_name = ctx.src->Symbols().NameFor(param->symbol());
+ auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
- auto func_const_symbol = ctx.dst->Symbols().Register(param_name);
ast::Expression* func_const_initializer = nullptr;
if (auto* struct_ty =
@@ -90,7 +89,7 @@
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
- auto member_name = ctx.src->Symbols().NameFor(member->symbol());
+ auto member_name = ctx.Clone(member->symbol());
new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations));
init_values.push_back(
@@ -118,15 +117,15 @@
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter.
- auto* func_const = ctx.dst->Const(
- func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
+ auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
+ func_const_initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
- ctx.dst->Expr(func_const_symbol));
+ ctx.dst->Expr(param_name));
}
}
@@ -163,9 +162,9 @@
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
- auto member_name = ctx.src->Symbols().NameFor(member->symbol());
- new_struct_members.push_back(ctx.dst->Member(
- member_name, ctx.Clone(member->type()), new_decorations));
+ new_struct_members.push_back(
+ ctx.dst->Member(ctx.Clone(member->symbol()),
+ ctx.Clone(member->type()), new_decorations));
}
} else {
new_struct_members.push_back(
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index c9da68b..a4a07f2 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -121,7 +121,7 @@
->IsAnyOf<ast::BuiltinDecoration, ast::LocationDecoration>();
});
new_struct_members.push_back(
- ctx.dst->Member(ctx.src->Symbols().NameFor(member->symbol()),
+ ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
}
@@ -215,7 +215,7 @@
}
// Use the same name as the old variable.
- std::string var_name = ctx.src->Symbols().NameFor(var->symbol());
+ auto var_name = ctx.Clone(var->symbol());
// Use `array<u32, 1>` for the new variable.
auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u);
// Create the new variable.
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index efec401..6cfeff7 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -74,7 +74,7 @@
auto* expect = R"(
type myf32 = f32;
-[[location(1)]] var<in> tint_symbol_1 : myf32;
+[[location(1)]] var<in> tint_symbol_2 : myf32;
[[stage(fragment)]]
fn frag_main() -> void {
@@ -95,15 +95,15 @@
)";
auto* expect = R"(
-[[builtin(position)]] var<out> tint_symbol_2 : vec4<f32>;
+[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
-fn tint_symbol_3(tint_symbol_1 : vec4<f32>) -> void {
- tint_symbol_2 = tint_symbol_1;
+fn tint_symbol_2(tint_symbol_3 : vec4<f32>) -> void {
+ tint_symbol_1 = tint_symbol_3;
}
[[stage(vertex)]]
fn vert_main() -> void {
- tint_symbol_3(vec4<f32>(1.0, 2.0, 3.0, 0.0));
+ tint_symbol_2(vec4<f32>(1.0, 2.0, 3.0, 0.0));
return;
}
)";
@@ -127,19 +127,19 @@
auto* expect = R"(
[[location(0)]] var<in> tint_symbol_1 : u32;
-[[location(0)]] var<out> tint_symbol_3 : f32;
+[[location(0)]] var<out> tint_symbol_2 : f32;
-fn tint_symbol_4(tint_symbol_2 : f32) -> void {
- tint_symbol_3 = tint_symbol_2;
+fn tint_symbol_3(tint_symbol_4 : f32) -> void {
+ tint_symbol_2 = tint_symbol_4;
}
[[stage(fragment)]]
fn frag_main() -> void {
if ((tint_symbol_1 > 10u)) {
- tint_symbol_4(0.5);
+ tint_symbol_3(0.5);
return;
}
- tint_symbol_4(1.0);
+ tint_symbol_3(1.0);
return;
}
)";
@@ -165,21 +165,21 @@
auto* expect = R"(
type myf32 = f32;
-[[location(0)]] var<in> tint_symbol_1 : u32;
+[[location(0)]] var<in> tint_symbol_2 : u32;
[[location(0)]] var<out> tint_symbol_3 : myf32;
-fn tint_symbol_5(tint_symbol_2 : myf32) -> void {
- tint_symbol_3 = tint_symbol_2;
+fn tint_symbol_4(tint_symbol_5 : myf32) -> void {
+ tint_symbol_3 = tint_symbol_5;
}
[[stage(fragment)]]
fn frag_main() -> void {
- if ((tint_symbol_1 > 10u)) {
- tint_symbol_5(0.5);
+ if ((tint_symbol_2 > 10u)) {
+ tint_symbol_4(0.5);
return;
}
- tint_symbol_5(1.0);
+ tint_symbol_4(1.0);
return;
}
)";
@@ -214,8 +214,8 @@
[[stage(fragment)]]
fn frag_main() -> void {
- const tint_symbol_6 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
- var col : f32 = (tint_symbol_6.coord.x * tint_symbol_6.value);
+ const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_4, tint_symbol_5);
+ var col : f32 = (tint_symbol_7.coord.x * tint_symbol_7.value);
}
)";
@@ -275,23 +275,23 @@
value : f32;
};
-[[builtin(position)]] var<out> tint_symbol_5 : vec4<f32>;
+[[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
-[[location(1)]] var<out> tint_symbol_6 : f32;
+[[location(1)]] var<out> tint_symbol_5 : f32;
-fn tint_symbol_7(tint_symbol_4 : VertexOutput) -> void {
- tint_symbol_5 = tint_symbol_4.pos;
- tint_symbol_6 = tint_symbol_4.value;
+fn tint_symbol_6(tint_symbol_7 : VertexOutput) -> void {
+ tint_symbol_4 = tint_symbol_7.pos;
+ tint_symbol_5 = tint_symbol_7.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
if (false) {
- tint_symbol_7(VertexOutput());
+ tint_symbol_6(VertexOutput());
return;
}
var pos : vec4<f32> = vec4<f32>(1.0, 2.0, 3.0, 0.0);
- tint_symbol_7(VertexOutput(pos, 2.0));
+ tint_symbol_6(VertexOutput(pos, 2.0));
return;
}
)";
@@ -320,16 +320,16 @@
[[location(1)]] var<in> tint_symbol_3 : f32;
-[[location(1)]] var<out> tint_symbol_6 : f32;
+[[location(1)]] var<out> tint_symbol_4 : f32;
-fn tint_symbol_7(tint_symbol_5 : Interface) -> void {
- tint_symbol_6 = tint_symbol_5.value;
+fn tint_symbol_5(tint_symbol_6 : Interface) -> void {
+ tint_symbol_4 = tint_symbol_6.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
- const tint_symbol_4 : Interface = Interface(tint_symbol_3);
- tint_symbol_7(tint_symbol_4);
+ const tint_symbol_8 : Interface = Interface(tint_symbol_3);
+ tint_symbol_5(tint_symbol_8);
return;
}
)";
@@ -361,15 +361,15 @@
value : f32;
};
-[[location(1)]] var<out> tint_symbol_4 : f32;
+[[location(1)]] var<out> tint_symbol_3 : f32;
-fn tint_symbol_5(tint_symbol_3 : Interface) -> void {
- tint_symbol_4 = tint_symbol_3.value;
+fn tint_symbol_4(tint_symbol_5 : Interface) -> void {
+ tint_symbol_3 = tint_symbol_5.value;
}
[[stage(vertex)]]
fn vert_main() -> void {
- tint_symbol_5(Interface(42.0));
+ tint_symbol_4(Interface(42.0));
return;
}
@@ -377,8 +377,8 @@
[[stage(fragment)]]
fn frag_main() -> void {
- const tint_symbol_8 : Interface = Interface(tint_symbol_7);
- var x : f32 = tint_symbol_8.value;
+ const tint_symbol_9 : Interface = Interface(tint_symbol_7);
+ var x : f32 = tint_symbol_9.value;
}
)";
@@ -423,16 +423,16 @@
[[builtin(frag_coord)]] var<in> tint_symbol_6 : vec4<f32>;
-[[location(1)]] var<out> tint_symbol_9 : f32;
+[[location(1)]] var<out> tint_symbol_7 : f32;
-fn tint_symbol_10(tint_symbol_8 : FragmentOutput) -> void {
- tint_symbol_9 = tint_symbol_8.value;
+fn tint_symbol_8(tint_symbol_9 : FragmentOutput) -> void {
+ tint_symbol_7 = tint_symbol_9.value;
}
[[stage(fragment)]]
fn frag_main() -> void {
- const tint_symbol_7 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
- tint_symbol_10(FragmentOutput((tint_symbol_7.coord.x * tint_symbol_7.value)));
+ const tint_symbol_11 : FragmentInput = FragmentInput(tint_symbol_5, tint_symbol_6);
+ tint_symbol_8(FragmentOutput((tint_symbol_11.coord.x * tint_symbol_11.value)));
return;
}
)";
@@ -467,8 +467,8 @@
[[builtin(position)]] var<out> tint_symbol_4 : vec4<f32>;
-fn tint_symbol_5(tint_symbol_3 : VertexOutput) -> void {
- tint_symbol_4 = tint_symbol_3.Position;
+fn tint_symbol_5(tint_symbol_6 : VertexOutput) -> void {
+ tint_symbol_4 = tint_symbol_6.Position;
}
[[stage(vertex)]]
@@ -585,19 +585,19 @@
)";
auto* expect = R"(
-[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
+[[builtin(sample_index)]] var<in> tint_symbol_3 : u32;
-[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
+[[builtin(sample_mask_in)]] var<in> tint_symbol_1 : array<u32, 1>;
-[[builtin(sample_mask_out)]] var<out> tint_symbol_4 : array<u32, 1>;
+[[builtin(sample_mask_out)]] var<out> tint_symbol_2 : array<u32, 1>;
-fn tint_symbol_5(tint_symbol_3 : u32) -> void {
- tint_symbol_4[0] = tint_symbol_3;
+fn tint_symbol_4(tint_symbol_5 : u32) -> void {
+ tint_symbol_2[0] = tint_symbol_5;
}
[[stage(fragment)]]
fn main() -> void {
- tint_symbol_5(tint_symbol_2[0]);
+ tint_symbol_4(tint_symbol_1[0]);
return;
}
)";
diff --git a/src/utils/get_or_create.h b/src/utils/get_or_create.h
new file mode 100644
index 0000000..b5b86ef
--- /dev/null
+++ b/src/utils/get_or_create.h
@@ -0,0 +1,44 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_UTILS_GET_OR_CREATE_H_
+#define SRC_UTILS_GET_OR_CREATE_H_
+
+#include <unordered_map>
+
+namespace tint {
+namespace utils {
+
+/// GetOrCreate is a utility function for lazily adding to an unordered map.
+/// If the map already contains the key `key` then this is returned, otherwise
+/// `create()` is called and the result is added to the map and is returned.
+/// @param map the unordered_map
+/// @param key the map key of the item to query or add
+/// @param create a callable function-like object with the signature `V()`
+/// @return the value of the item with the given key, or the newly created item
+template <typename K, typename V, typename CREATE, typename H>
+V GetOrCreate(std::unordered_map<K, V, H>& map, K key, CREATE&& create) {
+ auto it = map.find(key);
+ if (it != map.end()) {
+ return it->second;
+ }
+ V value = create();
+ map.emplace(key, value);
+ return value;
+}
+
+} // namespace utils
+} // namespace tint
+
+#endif // SRC_UTILS_GET_OR_CREATE_H_
diff --git a/src/utils/get_or_create_test.cc b/src/utils/get_or_create_test.cc
new file mode 100644
index 0000000..ae0c499
--- /dev/null
+++ b/src/utils/get_or_create_test.cc
@@ -0,0 +1,49 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/utils/get_or_create.h"
+
+#include <unordered_map>
+
+#include "gtest/gtest.h"
+
+namespace tint {
+namespace utils {
+namespace {
+
+TEST(GetOrCreateTest, NewKey) {
+ std::unordered_map<int, int> map;
+ EXPECT_EQ(GetOrCreate(map, 1, [&] { return 2; }), 2);
+ EXPECT_EQ(map.size(), 1u);
+ EXPECT_EQ(map[1], 2);
+}
+
+TEST(GetOrCreateTest, ExistingKey) {
+ std::unordered_map<int, int> map;
+ map[1] = 2;
+ bool called = false;
+ EXPECT_EQ(GetOrCreate(map, 1,
+ [&] {
+ called = true;
+ return -2;
+ }),
+ 2);
+ EXPECT_EQ(called, false);
+ EXPECT_EQ(map.size(), 1u);
+ EXPECT_EQ(map[1], 2);
+}
+
+} // namespace
+} // namespace utils
+} // namespace tint
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index fb8fe9a..f9b1d6a 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -124,16 +124,16 @@
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_3 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float foo : TEXCOORD0;
};
-struct tint_symbol_5 {
+struct tint_symbol_3 {
float value : SV_Target1;
};
-tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
- const float foo = tint_symbol_1.foo;
- return tint_symbol_5(foo);
+tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
+ const float foo = tint_symbol_6.foo;
+ return tint_symbol_3(foo);
}
)");
@@ -157,16 +157,16 @@
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
- EXPECT_EQ(result(), R"(struct tint_symbol_3 {
+ EXPECT_EQ(result(), R"(struct tint_symbol_1 {
float4 coord : SV_Position;
};
-struct tint_symbol_5 {
+struct tint_symbol_3 {
float value : SV_Depth;
};
-tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1) {
- const float4 coord = tint_symbol_1.coord;
- return tint_symbol_5(coord.x);
+tint_symbol_3 frag_main(tint_symbol_1 tint_symbol_6) {
+ const float4 coord = tint_symbol_6.coord;
+ return tint_symbol_3(coord.x);
}
)");
@@ -217,18 +217,18 @@
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
-struct tint_symbol_9 {
+struct tint_symbol_7 {
float col1 : TEXCOORD1;
float col2 : TEXCOORD2;
};
tint_symbol_4 vert_main() {
- const Interface tint_symbol_5 = Interface(0.5f, 0.25f);
- return tint_symbol_4(tint_symbol_5.col1, tint_symbol_5.col2);
+ const Interface tint_symbol_6 = Interface(0.5f, 0.25f);
+ return tint_symbol_4(tint_symbol_6.col1, tint_symbol_6.col2);
}
-void frag_main(tint_symbol_9 tint_symbol_7) {
- const Interface colors = Interface(tint_symbol_7.col1, tint_symbol_7.col2);
+void frag_main(tint_symbol_7 tint_symbol_9) {
+ const Interface colors = Interface(tint_symbol_9.col1, tint_symbol_9.col2);
const float r = colors.col1;
const float g = colors.col2;
return;
@@ -281,10 +281,10 @@
EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos;
};
-struct tint_symbol_3 {
+struct tint_symbol_5 {
float4 pos : SV_Position;
};
-struct tint_symbol_7 {
+struct tint_symbol_8 {
float4 pos : SV_Position;
};
@@ -292,14 +292,14 @@
return VertexOutput(float4(x, x, x, 1.0f));
}
-tint_symbol_3 vert_main1() {
- const VertexOutput tint_symbol_5 = VertexOutput(foo(0.5f));
- return tint_symbol_3(tint_symbol_5.pos);
+tint_symbol_5 vert_main1() {
+ const VertexOutput tint_symbol_7 = VertexOutput(foo(0.5f));
+ return tint_symbol_5(tint_symbol_7.pos);
}
-tint_symbol_7 vert_main2() {
- const VertexOutput tint_symbol_8 = VertexOutput(foo(0.25f));
- return tint_symbol_7(tint_symbol_8.pos);
+tint_symbol_8 vert_main2() {
+ const VertexOutput tint_symbol_10 = VertexOutput(foo(0.25f));
+ return tint_symbol_8(tint_symbol_10.pos);
}
)");
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index 6c4b3b2..78248bd 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -134,9 +134,9 @@
OpEntryPoint Fragment %14 "frag_main" %1 %4
OpExecutionMode %14 OriginUpperLeft
OpName %1 "tint_symbol_1"
-OpName %4 "tint_symbol_3"
-OpName %10 "tint_symbol_4"
-OpName %11 "tint_symbol_2"
+OpName %4 "tint_symbol_2"
+OpName %10 "tint_symbol_3"
+OpName %11 "tint_symbol_4"
OpName %14 "frag_main"
OpDecorate %1 Location 0
OpDecorate %4 Location 0
@@ -220,16 +220,16 @@
OpEntryPoint Fragment %25 "frag_main" %5 %7
OpExecutionMode %25 OriginUpperLeft
OpExecutionMode %25 DepthReplacing
-OpName %1 "tint_symbol_4"
+OpName %1 "tint_symbol_3"
OpName %5 "tint_symbol_7"
-OpName %7 "tint_symbol_10"
+OpName %7 "tint_symbol_8"
OpName %10 "Interface"
OpMemberName %10 0 "value"
-OpName %11 "tint_symbol_5"
-OpName %12 "tint_symbol_3"
+OpName %11 "tint_symbol_4"
+OpName %12 "tint_symbol_5"
OpName %16 "vert_main"
-OpName %22 "tint_symbol_11"
-OpName %23 "tint_symbol_9"
+OpName %22 "tint_symbol_9"
+OpName %23 "tint_symbol_10"
OpName %25 "frag_main"
OpDecorate %1 Location 1
OpDecorate %5 Location 1
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 64dcf38..69f1017 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -220,6 +220,7 @@
"../src/type/vector_type_test.cc",
"../src/utils/command.h",
"../src/utils/command_test.cc",
+ "../src/utils/get_or_create_test.cc",
"../src/utils/hash_test.cc",
"../src/utils/math_test.cc",
"../src/utils/tmpfile.h",