Validate that Symbols are all part of the same program Assert in each AST constructor that symbols belong to the program of the parent. Bug: tint:709 Change-Id: I82ae9b23c88e89714a44e057a0272f0293385aaf Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47624 Commit-Queue: Ben Clayton <bclayton@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: James Price <jrprice@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc index b656d6b..989d5e1 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc
@@ -38,6 +38,7 @@ body_(body), decorations_(std::move(decorations)), return_type_decorations_(std::move(return_type_decorations)) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body, program_id); for (auto* param : params_) { TINT_ASSERT(param && param->is_const());
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 1ec2fbb..775afaa 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc
@@ -79,6 +79,16 @@ "internal compiler error"); } +TEST_F(FunctionTest, Assert_DifferentProgramID_Symbol) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b1; + ProgramBuilder b2; + b1.Func(b2.Sym("func"), VariableList{}, b1.ty.void_(), StatementList{}); + }, + "internal compiler error"); +} + TEST_F(FunctionTest, Assert_DifferentProgramID_Param) { EXPECT_FATAL_FAILURE( {
diff --git a/src/ast/identifier_expression.cc b/src/ast/identifier_expression.cc index 95784e0..b737fd3 100644 --- a/src/ast/identifier_expression.cc +++ b/src/ast/identifier_expression.cc
@@ -25,6 +25,7 @@ const Source& source, Symbol sym) : Base(program_id, source), sym_(sym) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(sym_, program_id); TINT_ASSERT(sym_.IsValid()); }
diff --git a/src/ast/identifier_expression_test.cc b/src/ast/identifier_expression_test.cc index cea616d..385c923 100644 --- a/src/ast/identifier_expression_test.cc +++ b/src/ast/identifier_expression_test.cc
@@ -23,12 +23,12 @@ TEST_F(IdentifierExpressionTest, Creation) { auto* i = Expr("ident"); - EXPECT_EQ(i->symbol(), Symbol(1)); + EXPECT_EQ(i->symbol(), Symbol(1, ID())); } TEST_F(IdentifierExpressionTest, Creation_WithSource) { auto* i = Expr(Source{Source::Location{20, 2}}, "ident"); - EXPECT_EQ(i->symbol(), Symbol(1)); + EXPECT_EQ(i->symbol(), Symbol(1, ID())); auto src = i->source(); EXPECT_EQ(src.range.begin.line, 20u); @@ -49,6 +49,16 @@ "internal compiler error"); } +TEST_F(IdentifierExpressionTest, Assert_DifferentProgramID_Symbol) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b1; + ProgramBuilder b2; + b1.Expr(b2.Sym("")); + }, + "internal compiler error"); +} + TEST_F(IdentifierExpressionTest, ToStr) { auto* i = Expr("ident"); EXPECT_EQ(str(i), R"(Identifier[not set]{ident}
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc index 882d0c1..1f229d0 100644 --- a/src/ast/struct_member.cc +++ b/src/ast/struct_member.cc
@@ -32,6 +32,7 @@ decorations_(std::move(decorations)) { TINT_ASSERT(type); TINT_ASSERT(symbol_.IsValid()); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id); for (auto* deco : decorations_) { TINT_ASSERT(deco); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
diff --git a/src/ast/struct_member_test.cc b/src/ast/struct_member_test.cc index 458e259..d5570a0 100644 --- a/src/ast/struct_member_test.cc +++ b/src/ast/struct_member_test.cc
@@ -23,7 +23,7 @@ TEST_F(StructMemberTest, Creation) { auto* st = Member("a", ty.i32(), {MemberSize(4)}); - EXPECT_EQ(st->symbol(), Symbol(1)); + EXPECT_EQ(st->symbol(), Symbol(1, ID())); EXPECT_EQ(st->type(), ty.i32()); EXPECT_EQ(st->decorations().size(), 1u); EXPECT_TRUE(st->decorations()[0]->Is<StructMemberSizeDecoration>()); @@ -37,7 +37,7 @@ auto* st = Member( Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}}, "a", ty.i32()); - EXPECT_EQ(st->symbol(), Symbol(1)); + EXPECT_EQ(st->symbol(), Symbol(1, ID())); EXPECT_EQ(st->type(), ty.i32()); EXPECT_EQ(st->decorations().size(), 0u); EXPECT_EQ(st->source().range.begin.line, 27u); @@ -73,6 +73,16 @@ "internal compiler error"); } +TEST_F(StructMemberTest, Assert_DifferentProgramID_Symbol) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b1; + ProgramBuilder b2; + b1.Member(b2.Sym("a"), b1.ty.i32(), {b1.MemberSize(4)}); + }, + "internal compiler error"); +} + TEST_F(StructMemberTest, Assert_DifferentProgramID_Decoration) { EXPECT_FATAL_FAILURE( {
diff --git a/src/ast/variable.cc b/src/ast/variable.cc index 92057bc..83e66b2 100644 --- a/src/ast/variable.cc +++ b/src/ast/variable.cc
@@ -39,6 +39,7 @@ decorations_(std::move(decorations)), declared_storage_class_(declared_storage_class) { TINT_ASSERT(symbol_.IsValid()); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id); // no type means we must have a constructor to infer it TINT_ASSERT(declared_type_ || constructor); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(constructor, program_id);
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc index 9378bf6..c7bfd5f 100644 --- a/src/ast/variable_test.cc +++ b/src/ast/variable_test.cc
@@ -25,7 +25,7 @@ TEST_F(VariableTest, Creation) { auto* v = Var("my_var", ty.i32(), StorageClass::kFunction); - EXPECT_EQ(v->symbol(), Symbol(1)); + EXPECT_EQ(v->symbol(), Symbol(1, ID())); EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction); EXPECT_EQ(v->declared_type(), ty.i32()); EXPECT_EQ(v->source().range.begin.line, 0u); @@ -39,7 +39,7 @@ Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 5}}}, "i", ty.f32(), StorageClass::kPrivate, nullptr, DecorationList{}); - EXPECT_EQ(v->symbol(), Symbol(1)); + EXPECT_EQ(v->symbol(), Symbol(1, ID())); EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate); EXPECT_EQ(v->declared_type(), ty.f32()); EXPECT_EQ(v->source().range.begin.line, 27u); @@ -53,7 +53,7 @@ Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 7}}}, "a_var", ty.i32(), StorageClass::kWorkgroup, nullptr, DecorationList{}); - EXPECT_EQ(v->symbol(), Symbol(1)); + EXPECT_EQ(v->symbol(), Symbol(1, ID())); EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup); EXPECT_EQ(v->declared_type(), ty.i32()); EXPECT_EQ(v->source().range.begin.line, 27u); @@ -80,6 +80,16 @@ "internal compiler error"); } +TEST_F(VariableTest, Assert_DifferentProgramID_Symbol) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b1; + ProgramBuilder b2; + b1.Var(b2.Sym("x"), b1.ty.f32(), StorageClass::kNone); + }, + "internal compiler error"); +} + TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) { EXPECT_FATAL_FAILURE( {
diff --git a/src/demangler.cc b/src/demangler.cc index 66a2955..762db42 100644 --- a/src/demangler.cc +++ b/src/demangler.cc
@@ -50,7 +50,7 @@ auto len = end_idx - start_idx; auto id = str.substr(start_idx, len); - Symbol sym(std::stoi(id)); + Symbol sym(std::stoi(id), symbols.ProgramID()); out << symbols.NameFor(sym); pos = end_idx;
diff --git a/src/demangler_test.cc b/src/demangler_test.cc index f752bf8..e1ccdbf 100644 --- a/src/demangler_test.cc +++ b/src/demangler_test.cc
@@ -23,7 +23,7 @@ using DemanglerTest = testing::Test; TEST_F(DemanglerTest, NoSymbols) { - SymbolTable t; + SymbolTable t{ProgramID::New()}; t.Register("sym1"); Demangler d; @@ -31,7 +31,7 @@ } TEST_F(DemanglerTest, Symbol) { - SymbolTable t; + SymbolTable t{ProgramID::New()}; t.Register("sym1"); Demangler d; @@ -39,7 +39,7 @@ } TEST_F(DemanglerTest, MultipleSymbols) { - SymbolTable t; + SymbolTable t{ProgramID::New()}; t.Register("sym1"); t.Register("sym2");
diff --git a/src/program.h b/src/program.h index 27cb929..b7e5aaf 100644 --- a/src/program.h +++ b/src/program.h
@@ -165,7 +165,7 @@ SemNodeAllocator sem_nodes_; ast::Module* ast_ = nullptr; semantic::Info sem_; - SymbolTable symbols_; + SymbolTable symbols_{id_}; diag::List diagnostics_; bool is_valid_ = false; // Not valid until it is built bool moved_ = false;
diff --git a/src/program_builder.h b/src/program_builder.h index 716c9f9..6919904 100644 --- a/src/program_builder.h +++ b/src/program_builder.h
@@ -1435,7 +1435,7 @@ SemNodeAllocator sem_nodes_; ast::Module* ast_; semantic::Info sem_; - SymbolTable symbols_; + SymbolTable symbols_{id_}; diag::List diagnostics_; /// The source to use when creating AST nodes without providing a Source as
diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc index c373359..85bd86b 100644 --- a/src/scope_stack_test.cc +++ b/src/scope_stack_test.cc
@@ -23,7 +23,7 @@ TEST_F(ScopeStackTest, Global) { ScopeStack<uint32_t> s; - Symbol sym(1); + Symbol sym(1, ID()); s.set_global(sym, 5); uint32_t val = 0; @@ -43,7 +43,7 @@ TEST_F(ScopeStackTest, Global_CanNotPop) { ScopeStack<uint32_t> s; - Symbol sym(1); + Symbol sym(1, ID()); s.set_global(sym, 5); s.pop_scope(); @@ -54,7 +54,7 @@ TEST_F(ScopeStackTest, Scope) { ScopeStack<uint32_t> s; - Symbol sym(1); + Symbol sym(1, ID()); s.push_scope(); s.set(sym, 5); @@ -65,7 +65,7 @@ TEST_F(ScopeStackTest, Get_MissingSymbol) { ScopeStack<uint32_t> s; - Symbol sym(1); + Symbol sym(1, ID()); uint32_t ret = 0; EXPECT_FALSE(s.get(sym, &ret)); EXPECT_EQ(ret, 0u); @@ -73,8 +73,8 @@ TEST_F(ScopeStackTest, Has) { ScopeStack<uint32_t> s; - Symbol sym(1); - Symbol sym2(2); + Symbol sym(1, ID()); + Symbol sym2(2, ID()); s.set_global(sym2, 3); s.push_scope(); s.set(sym, 5); @@ -85,7 +85,7 @@ TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) { ScopeStack<uint32_t> s; - Symbol sym(1); + Symbol sym(1, ID()); s.set_global(sym, 3); s.push_scope(); s.set(sym, 5);
diff --git a/src/semantic/sem_call_target.cc b/src/semantic/sem_call_target.cc index 263c797..5f78a3e 100644 --- a/src/semantic/sem_call_target.cc +++ b/src/semantic/sem_call_target.cc
@@ -67,7 +67,7 @@ } std::ostream& operator<<(std::ostream& out, Parameter parameter) { - out << "[type: " << parameter.type->FriendlyName(SymbolTable{}) + out << "[type: " << parameter.type->FriendlyName(SymbolTable{ProgramID{}}) << ", usage: " << str(parameter.usage) << "]"; return out; }
diff --git a/src/symbol.cc b/src/symbol.cc index 2f4338e..13db168 100644 --- a/src/symbol.cc +++ b/src/symbol.cc
@@ -18,7 +18,8 @@ Symbol::Symbol() = default; -Symbol::Symbol(uint32_t val) : val_(val) {} +Symbol::Symbol(uint32_t val, tint::ProgramID program_id) + : val_(val), program_id_(program_id) {} Symbol::Symbol(const Symbol& o) = default; @@ -31,6 +32,7 @@ Symbol& Symbol::operator=(Symbol&& o) = default; bool Symbol::operator==(const Symbol& other) const { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(program_id_, other.program_id_); return val_ == other.val_; }
diff --git a/src/symbol.h b/src/symbol.h index 3a92d6f..286e20d 100644 --- a/src/symbol.h +++ b/src/symbol.h
@@ -17,6 +17,8 @@ #include <string> +#include "src/program_id.h" + namespace tint { /// A symbol representing a string in the system @@ -27,7 +29,8 @@ Symbol(); /// Constructor /// @param val the symbol value - explicit Symbol(uint32_t val); + /// @param program_id the identifier of the program that owns this Symbol + Symbol(uint32_t val, tint::ProgramID program_id); /// Copy constructor /// @param o the symbol to copy Symbol(const Symbol& o); @@ -61,10 +64,20 @@ /// @return the string representation of the symbol std::string to_str() const; + /// @returns the identifier of the Program that owns this symbol. + tint::ProgramID ProgramID() const { return program_id_; } + private: uint32_t val_ = static_cast<uint32_t>(-1); + tint::ProgramID program_id_; }; +/// @param sym the Symbol +/// @returns the ProgramID that owns the given Symbol +inline ProgramID ProgramIDOf(Symbol sym) { + return sym.IsValid() ? sym.ProgramID() : ProgramID(); +} + } // namespace tint namespace std {
diff --git a/src/symbol_table.cc b/src/symbol_table.cc index 5be19f3..a667e7d 100644 --- a/src/symbol_table.cc +++ b/src/symbol_table.cc
@@ -14,9 +14,12 @@ #include "src/symbol_table.h" +#include "src/debug.h" + namespace tint { -SymbolTable::SymbolTable() = default; +SymbolTable::SymbolTable(tint::ProgramID program_id) + : program_id_(program_id) {} SymbolTable::SymbolTable(const SymbolTable&) = default; @@ -36,7 +39,7 @@ if (it != name_to_symbol_.end()) return it->second; - Symbol sym(next_symbol_); + Symbol sym(next_symbol_, program_id_); ++next_symbol_; name_to_symbol_[name] = sym; @@ -51,6 +54,7 @@ } std::string SymbolTable::NameFor(const Symbol symbol) const { + TINT_ASSERT_PROGRAM_IDS_EQUAL(program_id_, symbol); auto it = symbol_to_name_.find(symbol); if (it == symbol_to_name_.end()) { return symbol.to_str();
diff --git a/src/symbol_table.h b/src/symbol_table.h index 61940b9..10248c1 100644 --- a/src/symbol_table.h +++ b/src/symbol_table.h
@@ -26,7 +26,9 @@ class SymbolTable { public: /// Constructor - SymbolTable(); + /// @param program_id the identifier of the program that owns this symbol + /// table + explicit SymbolTable(tint::ProgramID program_id); /// Copy constructor SymbolTable(const SymbolTable&); /// Move Constructor @@ -76,14 +78,24 @@ } } + /// @returns the identifier of the Program that owns this symbol table. + tint::ProgramID ProgramID() const { return program_id_; } + private: // The value to be associated to the next registered symbol table entry. uint32_t next_symbol_ = 1; std::unordered_map<Symbol, std::string> symbol_to_name_; std::unordered_map<std::string, Symbol> name_to_symbol_; + tint::ProgramID program_id_; }; +/// @param symbol_table the SymbolTable +/// @returns the ProgramID that owns the given SymbolTable +inline ProgramID ProgramIDOf(const SymbolTable& symbol_table) { + return symbol_table.ProgramID(); +} + } // namespace tint #endif // SRC_SYMBOL_TABLE_H_
diff --git a/src/symbol_table_test.cc b/src/symbol_table_test.cc index f3dda91..11741b8 100644 --- a/src/symbol_table_test.cc +++ b/src/symbol_table_test.cc
@@ -22,31 +22,36 @@ using SymbolTableTest = testing::Test; TEST_F(SymbolTableTest, GeneratesSymbolForName) { - SymbolTable s; - EXPECT_EQ(Symbol(1), s.Register("name")); - EXPECT_EQ(Symbol(2), s.Register("another_name")); + auto program_id = ProgramID::New(); + SymbolTable s{program_id}; + EXPECT_EQ(Symbol(1, program_id), s.Register("name")); + EXPECT_EQ(Symbol(2, program_id), s.Register("another_name")); } TEST_F(SymbolTableTest, DeduplicatesNames) { - SymbolTable s; - EXPECT_EQ(Symbol(1), s.Register("name")); - EXPECT_EQ(Symbol(2), s.Register("another_name")); - EXPECT_EQ(Symbol(1), s.Register("name")); + auto program_id = ProgramID::New(); + SymbolTable s{program_id}; + EXPECT_EQ(Symbol(1, program_id), s.Register("name")); + EXPECT_EQ(Symbol(2, program_id), s.Register("another_name")); + EXPECT_EQ(Symbol(1, program_id), s.Register("name")); } TEST_F(SymbolTableTest, ReturnsNameForSymbol) { - SymbolTable s; + auto program_id = ProgramID::New(); + SymbolTable s{program_id}; auto sym = s.Register("name"); EXPECT_EQ("name", s.NameFor(sym)); } TEST_F(SymbolTableTest, ReturnsBlankForMissingSymbol) { - SymbolTable s; - EXPECT_EQ("$2", s.NameFor(Symbol(2))); + auto program_id = ProgramID::New(); + SymbolTable s{program_id}; + EXPECT_EQ("$2", s.NameFor(Symbol(2, program_id))); } TEST_F(SymbolTableTest, ReturnsInvalidForBlankString) { - SymbolTable s; + auto program_id = ProgramID::New(); + SymbolTable s{program_id}; EXPECT_FALSE(s.Register("").IsValid()); }
diff --git a/src/symbol_test.cc b/src/symbol_test.cc index a69d2cc..3c13a92 100644 --- a/src/symbol_test.cc +++ b/src/symbol_test.cc
@@ -22,12 +22,12 @@ using SymbolTest = testing::Test; TEST_F(SymbolTest, ToStr) { - Symbol sym(1); + Symbol sym(1, ProgramID::New()); EXPECT_EQ("$1", sym.to_str()); } TEST_F(SymbolTest, CopyAssign) { - Symbol sym1(1); + Symbol sym1(1, ProgramID::New()); Symbol sym2; EXPECT_FALSE(sym2.IsValid()); @@ -37,9 +37,10 @@ } TEST_F(SymbolTest, Comparison) { - Symbol sym1(1); - Symbol sym2(2); - Symbol sym3(1); + auto program_id = ProgramID::New(); + Symbol sym1(1, program_id); + Symbol sym2(2, program_id); + Symbol sym3(1, program_id); EXPECT_TRUE(sym1 == sym3); EXPECT_FALSE(sym1 == sym2);
diff --git a/src/type/alias_type_test.cc b/src/type/alias_type_test.cc index 6e4354c..cb7e51e 100644 --- a/src/type/alias_type_test.cc +++ b/src/type/alias_type_test.cc
@@ -24,7 +24,7 @@ TEST_F(AliasTest, Create) { auto* a = ty.alias("a_type", ty.u32()); - EXPECT_EQ(a->symbol(), Symbol(1)); + EXPECT_EQ(a->symbol(), Symbol(1, ID())); EXPECT_EQ(a->type(), ty.u32()); } @@ -58,7 +58,7 @@ TEST_F(AliasTest, UnwrapIfNeeded_Alias) { auto* a = ty.alias("a_type", ty.u32()); - EXPECT_EQ(a->symbol(), Symbol(1)); + EXPECT_EQ(a->symbol(), Symbol(1, ID())); EXPECT_EQ(a->type(), ty.u32()); EXPECT_EQ(a->UnwrapIfNeeded(), ty.u32()); EXPECT_EQ(ty.u32()->UnwrapIfNeeded(), ty.u32()); @@ -74,7 +74,7 @@ auto* a = ty.alias("a_type", ty.u32()); auto* aa = ty.alias("aa_type", a); - EXPECT_EQ(aa->symbol(), Symbol(2)); + EXPECT_EQ(aa->symbol(), Symbol(2, ID())); EXPECT_EQ(aa->type(), a); EXPECT_EQ(aa->UnwrapIfNeeded(), ty.u32()); } @@ -94,7 +94,7 @@ auto* apaa = ty.alias("paa_type", &paa); auto* aapaa = ty.alias("aapaa_type", apaa); - EXPECT_EQ(aapaa->symbol(), Symbol(4)); + EXPECT_EQ(aapaa->symbol(), Symbol(4, ID())); EXPECT_EQ(aapaa->type(), apaa); EXPECT_EQ(aapaa->UnwrapAll(), ty.u32()); }