Validate that Symbols are all part of the same program

Assert in each AST constructor that symbols belong to the program of the parent.

Bug: tint:709
Change-Id: I82ae9b23c88e89714a44e057a0272f0293385aaf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47624
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index b656d6b..989d5e1 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -38,6 +38,7 @@
       body_(body),
       decorations_(std::move(decorations)),
       return_type_decorations_(std::move(return_type_decorations)) {
+  TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
   TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body, program_id);
   for (auto* param : params_) {
     TINT_ASSERT(param && param->is_const());
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 1ec2fbb..775afaa 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -79,6 +79,16 @@
       "internal compiler error");
 }
 
+TEST_F(FunctionTest, Assert_DifferentProgramID_Symbol) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder b1;
+        ProgramBuilder b2;
+        b1.Func(b2.Sym("func"), VariableList{}, b1.ty.void_(), StatementList{});
+      },
+      "internal compiler error");
+}
+
 TEST_F(FunctionTest, Assert_DifferentProgramID_Param) {
   EXPECT_FATAL_FAILURE(
       {
diff --git a/src/ast/identifier_expression.cc b/src/ast/identifier_expression.cc
index 95784e0..b737fd3 100644
--- a/src/ast/identifier_expression.cc
+++ b/src/ast/identifier_expression.cc
@@ -25,6 +25,7 @@
                                            const Source& source,
                                            Symbol sym)
     : Base(program_id, source), sym_(sym) {
+  TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(sym_, program_id);
   TINT_ASSERT(sym_.IsValid());
 }
 
diff --git a/src/ast/identifier_expression_test.cc b/src/ast/identifier_expression_test.cc
index cea616d..385c923 100644
--- a/src/ast/identifier_expression_test.cc
+++ b/src/ast/identifier_expression_test.cc
@@ -23,12 +23,12 @@
 
 TEST_F(IdentifierExpressionTest, Creation) {
   auto* i = Expr("ident");
-  EXPECT_EQ(i->symbol(), Symbol(1));
+  EXPECT_EQ(i->symbol(), Symbol(1, ID()));
 }
 
 TEST_F(IdentifierExpressionTest, Creation_WithSource) {
   auto* i = Expr(Source{Source::Location{20, 2}}, "ident");
-  EXPECT_EQ(i->symbol(), Symbol(1));
+  EXPECT_EQ(i->symbol(), Symbol(1, ID()));
 
   auto src = i->source();
   EXPECT_EQ(src.range.begin.line, 20u);
@@ -49,6 +49,16 @@
       "internal compiler error");
 }
 
+TEST_F(IdentifierExpressionTest, Assert_DifferentProgramID_Symbol) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder b1;
+        ProgramBuilder b2;
+        b1.Expr(b2.Sym(""));
+      },
+      "internal compiler error");
+}
+
 TEST_F(IdentifierExpressionTest, ToStr) {
   auto* i = Expr("ident");
   EXPECT_EQ(str(i), R"(Identifier[not set]{ident}
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index 882d0c1..1f229d0 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -32,6 +32,7 @@
       decorations_(std::move(decorations)) {
   TINT_ASSERT(type);
   TINT_ASSERT(symbol_.IsValid());
+  TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
   for (auto* deco : decorations_) {
     TINT_ASSERT(deco);
     TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
diff --git a/src/ast/struct_member_test.cc b/src/ast/struct_member_test.cc
index 458e259..d5570a0 100644
--- a/src/ast/struct_member_test.cc
+++ b/src/ast/struct_member_test.cc
@@ -23,7 +23,7 @@
 
 TEST_F(StructMemberTest, Creation) {
   auto* st = Member("a", ty.i32(), {MemberSize(4)});
-  EXPECT_EQ(st->symbol(), Symbol(1));
+  EXPECT_EQ(st->symbol(), Symbol(1, ID()));
   EXPECT_EQ(st->type(), ty.i32());
   EXPECT_EQ(st->decorations().size(), 1u);
   EXPECT_TRUE(st->decorations()[0]->Is<StructMemberSizeDecoration>());
@@ -37,7 +37,7 @@
   auto* st = Member(
       Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 8}}},
       "a", ty.i32());
-  EXPECT_EQ(st->symbol(), Symbol(1));
+  EXPECT_EQ(st->symbol(), Symbol(1, ID()));
   EXPECT_EQ(st->type(), ty.i32());
   EXPECT_EQ(st->decorations().size(), 0u);
   EXPECT_EQ(st->source().range.begin.line, 27u);
@@ -73,6 +73,16 @@
       "internal compiler error");
 }
 
+TEST_F(StructMemberTest, Assert_DifferentProgramID_Symbol) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder b1;
+        ProgramBuilder b2;
+        b1.Member(b2.Sym("a"), b1.ty.i32(), {b1.MemberSize(4)});
+      },
+      "internal compiler error");
+}
+
 TEST_F(StructMemberTest, Assert_DifferentProgramID_Decoration) {
   EXPECT_FATAL_FAILURE(
       {
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index 92057bc..83e66b2 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -39,6 +39,7 @@
       decorations_(std::move(decorations)),
       declared_storage_class_(declared_storage_class) {
   TINT_ASSERT(symbol_.IsValid());
+  TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(symbol_, program_id);
   // no type means we must have a constructor to infer it
   TINT_ASSERT(declared_type_ || constructor);
   TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(constructor, program_id);
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 9378bf6..c7bfd5f 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -25,7 +25,7 @@
 TEST_F(VariableTest, Creation) {
   auto* v = Var("my_var", ty.i32(), StorageClass::kFunction);
 
-  EXPECT_EQ(v->symbol(), Symbol(1));
+  EXPECT_EQ(v->symbol(), Symbol(1, ID()));
   EXPECT_EQ(v->declared_storage_class(), StorageClass::kFunction);
   EXPECT_EQ(v->declared_type(), ty.i32());
   EXPECT_EQ(v->source().range.begin.line, 0u);
@@ -39,7 +39,7 @@
       Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 5}}},
       "i", ty.f32(), StorageClass::kPrivate, nullptr, DecorationList{});
 
-  EXPECT_EQ(v->symbol(), Symbol(1));
+  EXPECT_EQ(v->symbol(), Symbol(1, ID()));
   EXPECT_EQ(v->declared_storage_class(), StorageClass::kPrivate);
   EXPECT_EQ(v->declared_type(), ty.f32());
   EXPECT_EQ(v->source().range.begin.line, 27u);
@@ -53,7 +53,7 @@
       Source{Source::Range{Source::Location{27, 4}, Source::Location{27, 7}}},
       "a_var", ty.i32(), StorageClass::kWorkgroup, nullptr, DecorationList{});
 
-  EXPECT_EQ(v->symbol(), Symbol(1));
+  EXPECT_EQ(v->symbol(), Symbol(1, ID()));
   EXPECT_EQ(v->declared_storage_class(), StorageClass::kWorkgroup);
   EXPECT_EQ(v->declared_type(), ty.i32());
   EXPECT_EQ(v->source().range.begin.line, 27u);
@@ -80,6 +80,16 @@
       "internal compiler error");
 }
 
+TEST_F(VariableTest, Assert_DifferentProgramID_Symbol) {
+  EXPECT_FATAL_FAILURE(
+      {
+        ProgramBuilder b1;
+        ProgramBuilder b2;
+        b1.Var(b2.Sym("x"), b1.ty.f32(), StorageClass::kNone);
+      },
+      "internal compiler error");
+}
+
 TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) {
   EXPECT_FATAL_FAILURE(
       {
diff --git a/src/demangler.cc b/src/demangler.cc
index 66a2955..762db42 100644
--- a/src/demangler.cc
+++ b/src/demangler.cc
@@ -50,7 +50,7 @@
     auto len = end_idx - start_idx;
 
     auto id = str.substr(start_idx, len);
-    Symbol sym(std::stoi(id));
+    Symbol sym(std::stoi(id), symbols.ProgramID());
     out << symbols.NameFor(sym);
 
     pos = end_idx;
diff --git a/src/demangler_test.cc b/src/demangler_test.cc
index f752bf8..e1ccdbf 100644
--- a/src/demangler_test.cc
+++ b/src/demangler_test.cc
@@ -23,7 +23,7 @@
 using DemanglerTest = testing::Test;
 
 TEST_F(DemanglerTest, NoSymbols) {
-  SymbolTable t;
+  SymbolTable t{ProgramID::New()};
   t.Register("sym1");
 
   Demangler d;
@@ -31,7 +31,7 @@
 }
 
 TEST_F(DemanglerTest, Symbol) {
-  SymbolTable t;
+  SymbolTable t{ProgramID::New()};
   t.Register("sym1");
 
   Demangler d;
@@ -39,7 +39,7 @@
 }
 
 TEST_F(DemanglerTest, MultipleSymbols) {
-  SymbolTable t;
+  SymbolTable t{ProgramID::New()};
   t.Register("sym1");
   t.Register("sym2");
 
diff --git a/src/program.h b/src/program.h
index 27cb929..b7e5aaf 100644
--- a/src/program.h
+++ b/src/program.h
@@ -165,7 +165,7 @@
   SemNodeAllocator sem_nodes_;
   ast::Module* ast_ = nullptr;
   semantic::Info sem_;
-  SymbolTable symbols_;
+  SymbolTable symbols_{id_};
   diag::List diagnostics_;
   bool is_valid_ = false;  // Not valid until it is built
   bool moved_ = false;
diff --git a/src/program_builder.h b/src/program_builder.h
index 716c9f9..6919904 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -1435,7 +1435,7 @@
   SemNodeAllocator sem_nodes_;
   ast::Module* ast_;
   semantic::Info sem_;
-  SymbolTable symbols_;
+  SymbolTable symbols_{id_};
   diag::List diagnostics_;
 
   /// The source to use when creating AST nodes without providing a Source as
diff --git a/src/scope_stack_test.cc b/src/scope_stack_test.cc
index c373359..85bd86b 100644
--- a/src/scope_stack_test.cc
+++ b/src/scope_stack_test.cc
@@ -23,7 +23,7 @@
 
 TEST_F(ScopeStackTest, Global) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
+  Symbol sym(1, ID());
   s.set_global(sym, 5);
 
   uint32_t val = 0;
@@ -43,7 +43,7 @@
 
 TEST_F(ScopeStackTest, Global_CanNotPop) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
+  Symbol sym(1, ID());
   s.set_global(sym, 5);
   s.pop_scope();
 
@@ -54,7 +54,7 @@
 
 TEST_F(ScopeStackTest, Scope) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
+  Symbol sym(1, ID());
   s.push_scope();
   s.set(sym, 5);
 
@@ -65,7 +65,7 @@
 
 TEST_F(ScopeStackTest, Get_MissingSymbol) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
+  Symbol sym(1, ID());
   uint32_t ret = 0;
   EXPECT_FALSE(s.get(sym, &ret));
   EXPECT_EQ(ret, 0u);
@@ -73,8 +73,8 @@
 
 TEST_F(ScopeStackTest, Has) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
-  Symbol sym2(2);
+  Symbol sym(1, ID());
+  Symbol sym2(2, ID());
   s.set_global(sym2, 3);
   s.push_scope();
   s.set(sym, 5);
@@ -85,7 +85,7 @@
 
 TEST_F(ScopeStackTest, ReturnsScopeBeforeGlobalFirst) {
   ScopeStack<uint32_t> s;
-  Symbol sym(1);
+  Symbol sym(1, ID());
   s.set_global(sym, 3);
   s.push_scope();
   s.set(sym, 5);
diff --git a/src/semantic/sem_call_target.cc b/src/semantic/sem_call_target.cc
index 263c797..5f78a3e 100644
--- a/src/semantic/sem_call_target.cc
+++ b/src/semantic/sem_call_target.cc
@@ -67,7 +67,7 @@
 }
 
 std::ostream& operator<<(std::ostream& out, Parameter parameter) {
-  out << "[type: " << parameter.type->FriendlyName(SymbolTable{})
+  out << "[type: " << parameter.type->FriendlyName(SymbolTable{ProgramID{}})
       << ", usage: " << str(parameter.usage) << "]";
   return out;
 }
diff --git a/src/symbol.cc b/src/symbol.cc
index 2f4338e..13db168 100644
--- a/src/symbol.cc
+++ b/src/symbol.cc
@@ -18,7 +18,8 @@
 
 Symbol::Symbol() = default;
 
-Symbol::Symbol(uint32_t val) : val_(val) {}
+Symbol::Symbol(uint32_t val, tint::ProgramID program_id)
+    : val_(val), program_id_(program_id) {}
 
 Symbol::Symbol(const Symbol& o) = default;
 
@@ -31,6 +32,7 @@
 Symbol& Symbol::operator=(Symbol&& o) = default;
 
 bool Symbol::operator==(const Symbol& other) const {
+  TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(program_id_, other.program_id_);
   return val_ == other.val_;
 }
 
diff --git a/src/symbol.h b/src/symbol.h
index 3a92d6f..286e20d 100644
--- a/src/symbol.h
+++ b/src/symbol.h
@@ -17,6 +17,8 @@
 
 #include <string>
 
+#include "src/program_id.h"
+
 namespace tint {
 
 /// A symbol representing a string in the system
@@ -27,7 +29,8 @@
   Symbol();
   /// Constructor
   /// @param val the symbol value
-  explicit Symbol(uint32_t val);
+  /// @param program_id the identifier of the program that owns this Symbol
+  Symbol(uint32_t val, tint::ProgramID program_id);
   /// Copy constructor
   /// @param o the symbol to copy
   Symbol(const Symbol& o);
@@ -61,10 +64,20 @@
   /// @return the string representation of the symbol
   std::string to_str() const;
 
+  /// @returns the identifier of the Program that owns this symbol.
+  tint::ProgramID ProgramID() const { return program_id_; }
+
  private:
   uint32_t val_ = static_cast<uint32_t>(-1);
+  tint::ProgramID program_id_;
 };
 
+/// @param sym the Symbol
+/// @returns the ProgramID that owns the given Symbol
+inline ProgramID ProgramIDOf(Symbol sym) {
+  return sym.IsValid() ? sym.ProgramID() : ProgramID();
+}
+
 }  // namespace tint
 
 namespace std {
diff --git a/src/symbol_table.cc b/src/symbol_table.cc
index 5be19f3..a667e7d 100644
--- a/src/symbol_table.cc
+++ b/src/symbol_table.cc
@@ -14,9 +14,12 @@
 
 #include "src/symbol_table.h"
 
+#include "src/debug.h"
+
 namespace tint {
 
-SymbolTable::SymbolTable() = default;
+SymbolTable::SymbolTable(tint::ProgramID program_id)
+    : program_id_(program_id) {}
 
 SymbolTable::SymbolTable(const SymbolTable&) = default;
 
@@ -36,7 +39,7 @@
   if (it != name_to_symbol_.end())
     return it->second;
 
-  Symbol sym(next_symbol_);
+  Symbol sym(next_symbol_, program_id_);
   ++next_symbol_;
 
   name_to_symbol_[name] = sym;
@@ -51,6 +54,7 @@
 }
 
 std::string SymbolTable::NameFor(const Symbol symbol) const {
+  TINT_ASSERT_PROGRAM_IDS_EQUAL(program_id_, symbol);
   auto it = symbol_to_name_.find(symbol);
   if (it == symbol_to_name_.end()) {
     return symbol.to_str();
diff --git a/src/symbol_table.h b/src/symbol_table.h
index 61940b9..10248c1 100644
--- a/src/symbol_table.h
+++ b/src/symbol_table.h
@@ -26,7 +26,9 @@
 class SymbolTable {
  public:
   /// Constructor
-  SymbolTable();
+  /// @param program_id the identifier of the program that owns this symbol
+  /// table
+  explicit SymbolTable(tint::ProgramID program_id);
   /// Copy constructor
   SymbolTable(const SymbolTable&);
   /// Move Constructor
@@ -76,14 +78,24 @@
     }
   }
 
+  /// @returns the identifier of the Program that owns this symbol table.
+  tint::ProgramID ProgramID() const { return program_id_; }
+
  private:
   // The value to be associated to the next registered symbol table entry.
   uint32_t next_symbol_ = 1;
 
   std::unordered_map<Symbol, std::string> symbol_to_name_;
   std::unordered_map<std::string, Symbol> name_to_symbol_;
+  tint::ProgramID program_id_;
 };
 
+/// @param symbol_table the SymbolTable
+/// @returns the ProgramID that owns the given SymbolTable
+inline ProgramID ProgramIDOf(const SymbolTable& symbol_table) {
+  return symbol_table.ProgramID();
+}
+
 }  // namespace tint
 
 #endif  // SRC_SYMBOL_TABLE_H_
diff --git a/src/symbol_table_test.cc b/src/symbol_table_test.cc
index f3dda91..11741b8 100644
--- a/src/symbol_table_test.cc
+++ b/src/symbol_table_test.cc
@@ -22,31 +22,36 @@
 using SymbolTableTest = testing::Test;
 
 TEST_F(SymbolTableTest, GeneratesSymbolForName) {
-  SymbolTable s;
-  EXPECT_EQ(Symbol(1), s.Register("name"));
-  EXPECT_EQ(Symbol(2), s.Register("another_name"));
+  auto program_id = ProgramID::New();
+  SymbolTable s{program_id};
+  EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
+  EXPECT_EQ(Symbol(2, program_id), s.Register("another_name"));
 }
 
 TEST_F(SymbolTableTest, DeduplicatesNames) {
-  SymbolTable s;
-  EXPECT_EQ(Symbol(1), s.Register("name"));
-  EXPECT_EQ(Symbol(2), s.Register("another_name"));
-  EXPECT_EQ(Symbol(1), s.Register("name"));
+  auto program_id = ProgramID::New();
+  SymbolTable s{program_id};
+  EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
+  EXPECT_EQ(Symbol(2, program_id), s.Register("another_name"));
+  EXPECT_EQ(Symbol(1, program_id), s.Register("name"));
 }
 
 TEST_F(SymbolTableTest, ReturnsNameForSymbol) {
-  SymbolTable s;
+  auto program_id = ProgramID::New();
+  SymbolTable s{program_id};
   auto sym = s.Register("name");
   EXPECT_EQ("name", s.NameFor(sym));
 }
 
 TEST_F(SymbolTableTest, ReturnsBlankForMissingSymbol) {
-  SymbolTable s;
-  EXPECT_EQ("$2", s.NameFor(Symbol(2)));
+  auto program_id = ProgramID::New();
+  SymbolTable s{program_id};
+  EXPECT_EQ("$2", s.NameFor(Symbol(2, program_id)));
 }
 
 TEST_F(SymbolTableTest, ReturnsInvalidForBlankString) {
-  SymbolTable s;
+  auto program_id = ProgramID::New();
+  SymbolTable s{program_id};
   EXPECT_FALSE(s.Register("").IsValid());
 }
 
diff --git a/src/symbol_test.cc b/src/symbol_test.cc
index a69d2cc..3c13a92 100644
--- a/src/symbol_test.cc
+++ b/src/symbol_test.cc
@@ -22,12 +22,12 @@
 using SymbolTest = testing::Test;
 
 TEST_F(SymbolTest, ToStr) {
-  Symbol sym(1);
+  Symbol sym(1, ProgramID::New());
   EXPECT_EQ("$1", sym.to_str());
 }
 
 TEST_F(SymbolTest, CopyAssign) {
-  Symbol sym1(1);
+  Symbol sym1(1, ProgramID::New());
   Symbol sym2;
 
   EXPECT_FALSE(sym2.IsValid());
@@ -37,9 +37,10 @@
 }
 
 TEST_F(SymbolTest, Comparison) {
-  Symbol sym1(1);
-  Symbol sym2(2);
-  Symbol sym3(1);
+  auto program_id = ProgramID::New();
+  Symbol sym1(1, program_id);
+  Symbol sym2(2, program_id);
+  Symbol sym3(1, program_id);
 
   EXPECT_TRUE(sym1 == sym3);
   EXPECT_FALSE(sym1 == sym2);
diff --git a/src/type/alias_type_test.cc b/src/type/alias_type_test.cc
index 6e4354c..cb7e51e 100644
--- a/src/type/alias_type_test.cc
+++ b/src/type/alias_type_test.cc
@@ -24,7 +24,7 @@
 
 TEST_F(AliasTest, Create) {
   auto* a = ty.alias("a_type", ty.u32());
-  EXPECT_EQ(a->symbol(), Symbol(1));
+  EXPECT_EQ(a->symbol(), Symbol(1, ID()));
   EXPECT_EQ(a->type(), ty.u32());
 }
 
@@ -58,7 +58,7 @@
 
 TEST_F(AliasTest, UnwrapIfNeeded_Alias) {
   auto* a = ty.alias("a_type", ty.u32());
-  EXPECT_EQ(a->symbol(), Symbol(1));
+  EXPECT_EQ(a->symbol(), Symbol(1, ID()));
   EXPECT_EQ(a->type(), ty.u32());
   EXPECT_EQ(a->UnwrapIfNeeded(), ty.u32());
   EXPECT_EQ(ty.u32()->UnwrapIfNeeded(), ty.u32());
@@ -74,7 +74,7 @@
   auto* a = ty.alias("a_type", ty.u32());
   auto* aa = ty.alias("aa_type", a);
 
-  EXPECT_EQ(aa->symbol(), Symbol(2));
+  EXPECT_EQ(aa->symbol(), Symbol(2, ID()));
   EXPECT_EQ(aa->type(), a);
   EXPECT_EQ(aa->UnwrapIfNeeded(), ty.u32());
 }
@@ -94,7 +94,7 @@
   auto* apaa = ty.alias("paa_type", &paa);
   auto* aapaa = ty.alias("aapaa_type", apaa);
 
-  EXPECT_EQ(aapaa->symbol(), Symbol(4));
+  EXPECT_EQ(aapaa->symbol(), Symbol(4, ID()));
   EXPECT_EQ(aapaa->type(), apaa);
   EXPECT_EQ(aapaa->UnwrapAll(), ty.u32());
 }