Validate that Symbols are all part of the same program
Assert in each AST constructor that symbols belong to the program of the parent.
Bug: tint:709
Change-Id: I82ae9b23c88e89714a44e057a0272f0293385aaf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47624
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index b656d6b..989d5e1 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -38,6 +38,7 @@
body_(body),
decorations_(std::move(decorations)),
return_type_decorations_(std::move(return_type_decorations)) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body, program_id);
for (auto* param : params_) {
TINT_ASSERT(param && param->is_const());
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 1ec2fbb..775afaa 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -79,6 +79,16 @@
"internal compiler error");
}
+TEST_F(FunctionTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Func(b2.Sym("func"), VariableList{}, b1.ty.void_(), StatementList{});
+ },
+ "internal compiler error");
+}
+
TEST_F(FunctionTest, Assert_DifferentProgramID_Param) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/ast/identifier_expression.cc b/src/ast/identifier_expression.cc
index 95784e0..b737fd3 100644
--- a/src/ast/identifier_expression.cc
+++ b/src/ast/identifier_expression.cc
@@ -25,6 +25,7 @@
const Source& source,
Symbol sym)
: Base(program_id, source), sym_(sym) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(sym_, program_id);
TINT_ASSERT(sym_.IsValid());
}
diff --git a/src/ast/identifier_expression_test.cc b/src/ast/identifier_expression_test.cc
index cea616d..385c923 100644
--- a/src/ast/identifier_expression_test.cc
+++ b/src/ast/identifier_expression_test.cc
@@ -23,12 +23,12 @@
TEST_F(IdentifierExpressionTest, Creation) {
auto* i = Expr("ident");
- EXPECT_EQ(i->symbol(), Symbol(1));
+ EXPECT_EQ(i->symbol(), Symbol(1, ID()));
}
TEST_F(IdentifierExpressionTest, Creation_WithSource) {
auto* i = Expr(Source{Source::Location{20, 2}}, "ident");
- EXPECT_EQ(i->symbol(), Symbol(1));
+ EXPECT_EQ(i->symbol(), Symbol(1, ID()));
auto src = i->source();
EXPECT_EQ(src.range.begin.line, 20u);
@@ -49,6 +49,16 @@
"internal compiler error");
}
+TEST_F(IdentifierExpressionTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Expr(b2.Sym(""));
+ },
+ "internal compiler error");
+}
+
TEST_F(IdentifierExpressionTest, ToStr) {
auto* i = Expr("ident");
EXPECT_EQ(str(i), R"(Identifier[not set]{ident}
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index 882d0c1..1f229d0 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -32,6 +32,7 @@
decorations_(std::move(decorations)) {
TINT_ASSERT(type);
TINT_ASSERT(symbol_.IsValid());
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
for (auto* deco : decorations_) {
TINT_ASSERT(deco);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
diff --git a/src/ast/struct_member_test.cc b/src/ast/struct_member_test.cc
index 458e259..d5570a0 100644
--- a/src/ast/struct_member_test.cc
+++ b/src/ast/struct_member_test.cc
@@ -23,7 +23,7 @@
TEST_F(StructMemberTest, Creation) {
auto* st = Member("a", ty.i32(), {MemberSize(4)});
- EXPECT_EQ(st->symbol(), Symbol(1));
+ EXPECT_EQ(st->symbol(), Symbol(1, ID()));
EXPECT_EQ(st->type(), ty.i32());
EXPECT_EQ(st->decorations().size(), 1u);
EXPECT_TRUE(st->decorations()[0]->Is<StructMemberSizeDecoration>());
@@ -37,7 +37,7 @@
auto* st = Member(
Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}},
"a", ty.i32());
- EXPECT_EQ(st->symbol(), Symbol(1));
+ EXPECT_EQ(st->symbol(), Symbol(1, ID()));
EXPECT_EQ(st->type(), ty.i32());
EXPECT_EQ(st->decorations().size(), 0u);
EXPECT_EQ(st->source().range.begin.line, 27u);
@@ -73,6 +73,16 @@
"internal compiler error");
}
+TEST_F(StructMemberTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Member(b2.Sym("a"), b1.ty.i32(), {b1.MemberSize(4)});
+ },
+ "internal compiler error");
+}
+
TEST_F(StructMemberTest, Assert_DifferentProgramID_Decoration) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index 92057bc..83e66b2 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -39,6 +39,7 @@
decorations_(std::move(decorations)),
declared_storage_class_(declared_storage_class) {
TINT_ASSERT(symbol_.IsValid());
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
// no type means we must have a constructor to infer it
TINT_ASSERT(declared_type_ || constructor);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(constructor, program_id);
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 9378bf6..c7bfd5f 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -25,7 +25,7 @@
TEST_F(VariableTest, Creation) {
auto* v = Var("my_var", ty.i32(), StorageClass::kFunction);
- EXPECT_EQ(v->symbol(), Symbol(1));
+ EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction);
EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_EQ(v->source().range.begin.line, 0u);
@@ -39,7 +39,7 @@
Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 5}}},
"i", ty.f32(), StorageClass::kPrivate, nullptr, DecorationList{});
- EXPECT_EQ(v->symbol(), Symbol(1));
+ EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate);
EXPECT_EQ(v->declared_type(), ty.f32());
EXPECT_EQ(v->source().range.begin.line, 27u);
@@ -53,7 +53,7 @@
Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 7}}},
"a_var", ty.i32(), StorageClass::kWorkgroup, nullptr, DecorationList{});
- EXPECT_EQ(v->symbol(), Symbol(1));
+ EXPECT_EQ(v->symbol(), Symbol(1, ID()));
EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup);
EXPECT_EQ(v->declared_type(), ty.i32());
EXPECT_EQ(v->source().range.begin.line, 27u);
@@ -80,6 +80,16 @@
"internal compiler error");
}
+TEST_F(VariableTest, Assert_DifferentProgramID_Symbol) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Var(b2.Sym("x"), b1.ty.f32(), StorageClass::kNone);
+ },
+ "internal compiler error");
+}
+
TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/demangler.cc b/src/demangler.cc
index 66a2955..762db42 100644
--- a/src/demangler.cc
+++ b/src/demangler.cc
@@ -50,7 +50,7 @@
auto len = end_idx - start_idx;
auto id = str.substr(start_idx, len);
- Symbol sym(std::stoi(id));
+ Symbol sym(std::stoi(id), symbols.ProgramID());
out << symbols.NameFor(sym);
pos = end_idx;
diff --git a/src/demangler_test.cc b/src/demangler_test.cc
index f752bf8..e1ccdbf 100644
--- a/src/demangler_test.cc
+++ b/src/demangler_test.cc
@@ -23,7 +23,7 @@
using DemanglerTest = testing::Test;
TEST_F(DemanglerTest, NoSymbols) {
- SymbolTable t;
+ SymbolTable t{ProgramID::New()};
t.Register("sym1");
Demangler d;
@@ -31,7 +31,7 @@
}
TEST_F(DemanglerTest, Symbol) {
- SymbolTable t;
+ SymbolTable t{ProgramID::New()};
t.Register("sym1");
Demangler d;
@@ -39,7 +39,7 @@
}
TEST_F(DemanglerTest, MultipleSymbols) {
- SymbolTable t;
+ SymbolTable t{ProgramID::New()};
t.Register("sym1");
t.Register("sym2");
diff --git a/src/program.h b/src/program.h
index 27cb929..b7e5aaf 100644
--- a/src/program.h
+++ b/src/program.h
@@ -165,7 +165,7 @@
SemNodeAllocator sem_nodes_;
ast::Module* ast_ = nullptr;
semantic::Info sem_;
- SymbolTable symbols_;
+ SymbolTable symbols_{id_};
diag::List diagnostics_;
bool is_valid_ = false; // Not valid until it is built
bool moved_ = false;
diff --git a/src/program_builder.h b/src/program_builder.h
index 716c9f9..6919904 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -1435,7 +1435,7 @@
SemNodeAllocator sem_nodes_;
ast::Module* ast_;
semantic::Info sem_;
- SymbolTable symbols_;
+ SymbolTable symbols_{id_};
diag::List diagnostics_;
/// The source to use when creating AST nodes without providing a Source as
diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc
index c373359..85bd86b 100644
--- a/src/scope_stack_test.cc
+++ b/src/scope_stack_test.cc
@@ -23,7 +23,7 @@
TEST_F(ScopeStackTest, Global) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
+ Symbol sym(1, ID());
s.set_global(sym, 5);
uint32_t val = 0;
@@ -43,7 +43,7 @@
TEST_F(ScopeStackTest, Global_CanNotPop) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
+ Symbol sym(1, ID());
s.set_global(sym, 5);
s.pop_scope();
@@ -54,7 +54,7 @@
TEST_F(ScopeStackTest, Scope) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
+ Symbol sym(1, ID());
s.push_scope();
s.set(sym, 5);
@@ -65,7 +65,7 @@
TEST_F(ScopeStackTest, Get_MissingSymbol) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
+ Symbol sym(1, ID());
uint32_t ret = 0;
EXPECT_FALSE(s.get(sym, &ret));
EXPECT_EQ(ret, 0u);
@@ -73,8 +73,8 @@
TEST_F(ScopeStackTest, Has) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
- Symbol sym2(2);
+ Symbol sym(1, ID());
+ Symbol sym2(2, ID());
s.set_global(sym2, 3);
s.push_scope();
s.set(sym, 5);
@@ -85,7 +85,7 @@
TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) {
ScopeStack<uint32_t> s;
- Symbol sym(1);
+ Symbol sym(1, ID());
s.set_global(sym, 3);
s.push_scope();
s.set(sym, 5);
diff --git a/src/semantic/sem_call_target.cc b/src/semantic/sem_call_target.cc
index 263c797..5f78a3e 100644
--- a/src/semantic/sem_call_target.cc
+++ b/src/semantic/sem_call_target.cc
@@ -67,7 +67,7 @@
}
std::ostream& operator<<(std::ostream& out, Parameter parameter) {
- out << "[type: " << parameter.type->FriendlyName(SymbolTable{})
+ out << "[type: " << parameter.type->FriendlyName(SymbolTable{ProgramID{}})
<< ", usage: " << str(parameter.usage) << "]";
return out;
}
diff --git a/src/symbol.cc b/src/symbol.cc
index 2f4338e..13db168 100644
--- a/src/symbol.cc
+++ b/src/symbol.cc
@@ -18,7 +18,8 @@
Symbol::Symbol() = default;
-Symbol::Symbol(uint32_t val) : val_(val) {}
+Symbol::Symbol(uint32_t val, tint::ProgramID program_id)
+ : val_(val), program_id_(program_id) {}
Symbol::Symbol(const Symbol& o) = default;
@@ -31,6 +32,7 @@
Symbol& Symbol::operator=(Symbol&& o) = default;
bool Symbol::operator==(const Symbol& other) const {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(program_id_, other.program_id_);
return val_ == other.val_;
}
diff --git a/src/symbol.h b/src/symbol.h
index 3a92d6f..286e20d 100644
--- a/src/symbol.h
+++ b/src/symbol.h
@@ -17,6 +17,8 @@
#include <string>
+#include "src/program_id.h"
+
namespace tint {
/// A symbol representing a string in the system
@@ -27,7 +29,8 @@
Symbol();
/// Constructor
/// @param val the symbol value
- explicit Symbol(uint32_t val);
+ /// @param program_id the identifier of the program that owns this Symbol
+ Symbol(uint32_t val, tint::ProgramID program_id);
/// Copy constructor
/// @param o the symbol to copy
Symbol(const Symbol& o);
@@ -61,10 +64,20 @@
/// @return the string representation of the symbol
std::string to_str() const;
+ /// @returns the identifier of the Program that owns this symbol.
+ tint::ProgramID ProgramID() const { return program_id_; }
+
private:
uint32_t val_ = static_cast<uint32_t>(-1);
+ tint::ProgramID program_id_;
};
+/// @param sym the Symbol
+/// @returns the ProgramID that owns the given Symbol
+inline ProgramID ProgramIDOf(Symbol sym) {
+ return sym.IsValid() ? sym.ProgramID() : ProgramID();
+}
+
} // namespace tint
namespace std {
diff --git a/src/symbol_table.cc b/src/symbol_table.cc
index 5be19f3..a667e7d 100644
--- a/src/symbol_table.cc
+++ b/src/symbol_table.cc
@@ -14,9 +14,12 @@
#include "src/symbol_table.h"
+#include "src/debug.h"
+
namespace tint {
-SymbolTable::SymbolTable() = default;
+SymbolTable::SymbolTable(tint::ProgramID program_id)
+ : program_id_(program_id) {}
SymbolTable::SymbolTable(const SymbolTable&) = default;
@@ -36,7 +39,7 @@
if (it != name_to_symbol_.end())
return it->second;
- Symbol sym(next_symbol_);
+ Symbol sym(next_symbol_, program_id_);
++next_symbol_;
name_to_symbol_[name] = sym;
@@ -51,6 +54,7 @@
}
std::string SymbolTable::NameFor(const Symbol symbol) const {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL(program_id_, symbol);
auto it = symbol_to_name_.find(symbol);
if (it == symbol_to_name_.end()) {
return symbol.to_str();
diff --git a/src/symbol_table.h b/src/symbol_table.h
index 61940b9..10248c1 100644
--- a/src/symbol_table.h
+++ b/src/symbol_table.h
@@ -26,7 +26,9 @@
class SymbolTable {
public:
/// Constructor
- SymbolTable();
+ /// @param program_id the identifier of the program that owns this symbol
+ /// table
+ explicit SymbolTable(tint::ProgramID program_id);
/// Copy constructor
SymbolTable(const SymbolTable&);
/// Move Constructor
@@ -76,14 +78,24 @@
}
}
+ /// @returns the identifier of the Program that owns this symbol table.
+ tint::ProgramID ProgramID() const { return program_id_; }
+
private:
// The value to be associated to the next registered symbol table entry.
uint32_t next_symbol_ = 1;
std::unordered_map<Symbol, std::string> symbol_to_name_;
std::unordered_map<std::string, Symbol> name_to_symbol_;
+ tint::ProgramID program_id_;
};
+/// @param symbol_table the SymbolTable
+/// @returns the ProgramID that owns the given SymbolTable
+inline ProgramID ProgramIDOf(const SymbolTable& symbol_table) {
+ return symbol_table.ProgramID();
+}
+
} // namespace tint
#endif // SRC_SYMBOL_TABLE_H_
diff --git a/src/symbol_table_test.cc b/src/symbol_table_test.cc
index f3dda91..11741b8 100644
--- a/src/symbol_table_test.cc
+++ b/src/symbol_table_test.cc
@@ -22,31 +22,36 @@
using SymbolTableTest = testing::Test;
TEST_F(SymbolTableTest, GeneratesSymbolForName) {
- SymbolTable s;
- EXPECT_EQ(Symbol(1), s.Register("name"));
- EXPECT_EQ(Symbol(2), s.Register("another_name"));
+ auto program_id = ProgramID::New();
+ SymbolTable s{program_id};
+ EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
+ EXPECT_EQ(Symbol(2, program_id), s.Register("another_name"));
}
TEST_F(SymbolTableTest, DeduplicatesNames) {
- SymbolTable s;
- EXPECT_EQ(Symbol(1), s.Register("name"));
- EXPECT_EQ(Symbol(2), s.Register("another_name"));
- EXPECT_EQ(Symbol(1), s.Register("name"));
+ auto program_id = ProgramID::New();
+ SymbolTable s{program_id};
+ EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
+ EXPECT_EQ(Symbol(2, program_id), s.Register("another_name"));
+ EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
}
TEST_F(SymbolTableTest, ReturnsNameForSymbol) {
- SymbolTable s;
+ auto program_id = ProgramID::New();
+ SymbolTable s{program_id};
auto sym = s.Register("name");
EXPECT_EQ("name", s.NameFor(sym));
}
TEST_F(SymbolTableTest, ReturnsBlankForMissingSymbol) {
- SymbolTable s;
- EXPECT_EQ("$2", s.NameFor(Symbol(2)));
+ auto program_id = ProgramID::New();
+ SymbolTable s{program_id};
+ EXPECT_EQ("$2", s.NameFor(Symbol(2, program_id)));
}
TEST_F(SymbolTableTest, ReturnsInvalidForBlankString) {
- SymbolTable s;
+ auto program_id = ProgramID::New();
+ SymbolTable s{program_id};
EXPECT_FALSE(s.Register("").IsValid());
}
diff --git a/src/symbol_test.cc b/src/symbol_test.cc
index a69d2cc..3c13a92 100644
--- a/src/symbol_test.cc
+++ b/src/symbol_test.cc
@@ -22,12 +22,12 @@
using SymbolTest = testing::Test;
TEST_F(SymbolTest, ToStr) {
- Symbol sym(1);
+ Symbol sym(1, ProgramID::New());
EXPECT_EQ("$1", sym.to_str());
}
TEST_F(SymbolTest, CopyAssign) {
- Symbol sym1(1);
+ Symbol sym1(1, ProgramID::New());
Symbol sym2;
EXPECT_FALSE(sym2.IsValid());
@@ -37,9 +37,10 @@
}
TEST_F(SymbolTest, Comparison) {
- Symbol sym1(1);
- Symbol sym2(2);
- Symbol sym3(1);
+ auto program_id = ProgramID::New();
+ Symbol sym1(1, program_id);
+ Symbol sym2(2, program_id);
+ Symbol sym3(1, program_id);
EXPECT_TRUE(sym1 == sym3);
EXPECT_FALSE(sym1 == sym2);
diff --git a/src/type/alias_type_test.cc b/src/type/alias_type_test.cc
index 6e4354c..cb7e51e 100644
--- a/src/type/alias_type_test.cc
+++ b/src/type/alias_type_test.cc
@@ -24,7 +24,7 @@
TEST_F(AliasTest, Create) {
auto* a = ty.alias("a_type", ty.u32());
- EXPECT_EQ(a->symbol(), Symbol(1));
+ EXPECT_EQ(a->symbol(), Symbol(1, ID()));
EXPECT_EQ(a->type(), ty.u32());
}
@@ -58,7 +58,7 @@
TEST_F(AliasTest, UnwrapIfNeeded_Alias) {
auto* a = ty.alias("a_type", ty.u32());
- EXPECT_EQ(a->symbol(), Symbol(1));
+ EXPECT_EQ(a->symbol(), Symbol(1, ID()));
EXPECT_EQ(a->type(), ty.u32());
EXPECT_EQ(a->UnwrapIfNeeded(), ty.u32());
EXPECT_EQ(ty.u32()->UnwrapIfNeeded(), ty.u32());
@@ -74,7 +74,7 @@
auto* a = ty.alias("a_type", ty.u32());
auto* aa = ty.alias("aa_type", a);
- EXPECT_EQ(aa->symbol(), Symbol(2));
+ EXPECT_EQ(aa->symbol(), Symbol(2, ID()));
EXPECT_EQ(aa->type(), a);
EXPECT_EQ(aa->UnwrapIfNeeded(), ty.u32());
}
@@ -94,7 +94,7 @@
auto* apaa = ty.alias("paa_type", &paa);
auto* aapaa = ty.alias("aapaa_type", apaa);
- EXPECT_EQ(aapaa->symbol(), Symbol(4));
+ EXPECT_EQ(aapaa->symbol(), Symbol(4, ID()));
EXPECT_EQ(aapaa->type(), apaa);
EXPECT_EQ(aapaa->UnwrapAll(), ty.u32());
}