ast: Validate that ASTs are all part of the same program
Assert in each AST constructor that child nodes belong to the program of the parent.
Bug: tint:709
Change-Id: Icc89b69691d099e358ff632a0ca6fd7943cb0193
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47623
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/array_accessor_expression.cc b/src/ast/array_accessor_expression.cc
index 212350b..c8d64d8 100644
--- a/src/ast/array_accessor_expression.cc
+++ b/src/ast/array_accessor_expression.cc
@@ -27,7 +27,9 @@
Expression* idx_expr)
: Base(program_id, source), array_(array), idx_expr_(idx_expr) {
TINT_ASSERT(array_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(array_, program_id);
TINT_ASSERT(idx_expr_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(idx_expr_, program_id);
}
ArrayAccessorExpression::ArrayAccessorExpression(ArrayAccessorExpression&&) =
diff --git a/src/ast/array_accessor_expression_test.cc b/src/ast/array_accessor_expression_test.cc
index 154cc69..ec2ad4b 100644
--- a/src/ast/array_accessor_expression_test.cc
+++ b/src/ast/array_accessor_expression_test.cc
@@ -49,7 +49,7 @@
EXPECT_TRUE(exp->Is<ArrayAccessorExpression>());
}
-TEST_F(ArrayAccessorExpressionTest, Assert_NullArray) {
+TEST_F(ArrayAccessorExpressionTest, Assert_Null_Array) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -58,7 +58,7 @@
"internal compiler error");
}
-TEST_F(ArrayAccessorExpressionTest, Assert_NullIndex) {
+TEST_F(ArrayAccessorExpressionTest, Assert_Null_Index) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -67,6 +67,26 @@
"internal compiler error");
}
+TEST_F(ArrayAccessorExpressionTest, Assert_DifferentProgramID_Array) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ArrayAccessorExpression>(b2.Expr("arr"), b1.Expr("idx"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(ArrayAccessorExpressionTest, Assert_DifferentProgramID_Index) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ArrayAccessorExpression>(b1.Expr("arr"), b2.Expr("idx"));
+ },
+ "internal compiler error");
+}
+
TEST_F(ArrayAccessorExpressionTest, ToStr) {
auto* ary = Expr("ary");
auto* idx = Expr("idx");
diff --git a/src/ast/assignment_statement.cc b/src/ast/assignment_statement.cc
index 68c2cc9..a83b3a3 100644
--- a/src/ast/assignment_statement.cc
+++ b/src/ast/assignment_statement.cc
@@ -27,7 +27,9 @@
Expression* rhs)
: Base(program_id, source), lhs_(lhs), rhs_(rhs) {
TINT_ASSERT(lhs_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(lhs_, program_id);
TINT_ASSERT(rhs_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(rhs_, program_id);
}
AssignmentStatement::AssignmentStatement(AssignmentStatement&&) = default;
diff --git a/src/ast/assignment_statement_test.cc b/src/ast/assignment_statement_test.cc
index 9ac7882..80d324e 100644
--- a/src/ast/assignment_statement_test.cc
+++ b/src/ast/assignment_statement_test.cc
@@ -51,7 +51,7 @@
EXPECT_TRUE(stmt->Is<AssignmentStatement>());
}
-TEST_F(AssignmentStatementTest, Assert_NullLHS) {
+TEST_F(AssignmentStatementTest, Assert_Null_LHS) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -60,7 +60,7 @@
"internal compiler error");
}
-TEST_F(AssignmentStatementTest, Assert_NullRHS) {
+TEST_F(AssignmentStatementTest, Assert_Null_RHS) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -69,6 +69,26 @@
"internal compiler error");
}
+TEST_F(AssignmentStatementTest, Assert_DifferentProgramID_LHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<AssignmentStatement>(b2.Expr("lhs"), b1.Expr("rhs"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(AssignmentStatementTest, Assert_DifferentProgramID_RHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<AssignmentStatement>(b1.Expr("lhs"), b2.Expr("rhs"));
+ },
+ "internal compiler error");
+}
+
TEST_F(AssignmentStatementTest, ToStr) {
auto* lhs = Expr("lhs");
auto* rhs = Expr("rhs");
diff --git a/src/ast/binary_expression.cc b/src/ast/binary_expression.cc
index 63a3320..967fc62 100644
--- a/src/ast/binary_expression.cc
+++ b/src/ast/binary_expression.cc
@@ -28,7 +28,9 @@
Expression* rhs)
: Base(program_id, source), op_(op), lhs_(lhs), rhs_(rhs) {
TINT_ASSERT(lhs_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(lhs_, program_id);
TINT_ASSERT(rhs_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(rhs_, program_id);
TINT_ASSERT(op_ != BinaryOp::kNone);
}
diff --git a/src/ast/binary_expression_test.cc b/src/ast/binary_expression_test.cc
index 8109e41..d5ef079 100644
--- a/src/ast/binary_expression_test.cc
+++ b/src/ast/binary_expression_test.cc
@@ -50,7 +50,7 @@
EXPECT_TRUE(r->Is<BinaryExpression>());
}
-TEST_F(BinaryExpressionTest, IsValid_Null_LHS) {
+TEST_F(BinaryExpressionTest, Assert_Null_LHS) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -59,7 +59,7 @@
"internal compiler error");
}
-TEST_F(BinaryExpressionTest, IsValid_Null_RHS) {
+TEST_F(BinaryExpressionTest, Assert_Null_RHS) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -68,6 +68,28 @@
"internal compiler error");
}
+TEST_F(BinaryExpressionTest, Assert_DifferentProgramID_LHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<BinaryExpression>(BinaryOp::kEqual, b2.Expr("lhs"),
+ b1.Expr("rhs"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(BinaryExpressionTest, Assert_DifferentProgramID_RHS) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<BinaryExpression>(BinaryOp::kEqual, b1.Expr("lhs"),
+ b2.Expr("rhs"));
+ },
+ "internal compiler error");
+}
+
TEST_F(BinaryExpressionTest, ToStr) {
auto* lhs = Expr("lhs");
auto* rhs = Expr("rhs");
diff --git a/src/ast/bitcast_expression.cc b/src/ast/bitcast_expression.cc
index c117c3c..3cfde4d 100644
--- a/src/ast/bitcast_expression.cc
+++ b/src/ast/bitcast_expression.cc
@@ -28,6 +28,7 @@
: Base(program_id, source), type_(type), expr_(expr) {
TINT_ASSERT(type_);
TINT_ASSERT(expr_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(expr, program_id);
}
BitcastExpression::BitcastExpression(BitcastExpression&&) = default;
diff --git a/src/ast/bitcast_expression_test.cc b/src/ast/bitcast_expression_test.cc
index 754fa36..573721c 100644
--- a/src/ast/bitcast_expression_test.cc
+++ b/src/ast/bitcast_expression_test.cc
@@ -48,7 +48,7 @@
EXPECT_TRUE(exp->Is<BitcastExpression>());
}
-TEST_F(BitcastExpressionTest, Assert_NullType) {
+TEST_F(BitcastExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -57,7 +57,7 @@
"internal compiler error");
}
-TEST_F(BitcastExpressionTest, Assert_NullExpr) {
+TEST_F(BitcastExpressionTest, Assert_Null_Expr) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -66,6 +66,16 @@
"internal compiler error");
}
+TEST_F(BitcastExpressionTest, Assert_DifferentProgramID_Expr) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<BitcastExpression>(b1.ty.f32(), b2.Expr("idx"));
+ },
+ "internal compiler error");
+}
+
TEST_F(BitcastExpressionTest, ToStr) {
auto* expr = Expr("expr");
diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc
index 38fdbbf..cae5e03 100644
--- a/src/ast/block_statement.cc
+++ b/src/ast/block_statement.cc
@@ -27,6 +27,7 @@
: Base(program_id, source), statements_(std::move(statements)) {
for (auto* stmt : *this) {
TINT_ASSERT(stmt);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(stmt, program_id);
}
}
diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc
index 4a5b2dd..6cca315 100644
--- a/src/ast/block_statement_test.cc
+++ b/src/ast/block_statement_test.cc
@@ -46,7 +46,7 @@
EXPECT_TRUE(b->Is<BlockStatement>());
}
-TEST_F(BlockStatementTest, Assert_NullStatement) {
+TEST_F(BlockStatementTest, Assert_Null_Statement) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -55,6 +55,17 @@
"internal compiler error");
}
+TEST_F(BlockStatementTest, Assert_DifferentProgramID_Statement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<BlockStatement>(
+ ast::StatementList{b2.create<DiscardStatement>()});
+ },
+ "internal compiler error");
+}
+
TEST_F(BlockStatementTest, ToStr) {
auto* b = create<BlockStatement>(ast::StatementList{
create<DiscardStatement>(),
diff --git a/src/ast/call_expression.cc b/src/ast/call_expression.cc
index 26908eb..47cb1f5 100644
--- a/src/ast/call_expression.cc
+++ b/src/ast/call_expression.cc
@@ -27,8 +27,10 @@
ExpressionList params)
: Base(program_id, source), func_(func), params_(params) {
TINT_ASSERT(func_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(func_, program_id);
for (auto* param : params_) {
TINT_ASSERT(param);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(param, program_id);
}
}
diff --git a/src/ast/call_expression_test.cc b/src/ast/call_expression_test.cc
index 4040153..35299eb 100644
--- a/src/ast/call_expression_test.cc
+++ b/src/ast/call_expression_test.cc
@@ -51,7 +51,7 @@
EXPECT_TRUE(stmt->Is<CallExpression>());
}
-TEST_F(CallExpressionTest, Assert_NullFunction) {
+TEST_F(CallExpressionTest, Assert_Null_Function) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -60,7 +60,7 @@
"internal compiler error");
}
-TEST_F(CallExpressionTest, Assert_NullParam) {
+TEST_F(CallExpressionTest, Assert_Null_Param) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -73,6 +73,27 @@
"internal compiler error");
}
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Function) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CallExpression>(b2.Expr("func"), ExpressionList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CallExpression>(b1.Expr("func"),
+ ExpressionList{b2.Expr("param1")});
+ },
+ "internal compiler error");
+}
+
TEST_F(CallExpressionTest, ToStr_NoParams) {
auto* func = Expr("func");
auto* stmt = create<CallExpression>(func, ExpressionList{});
diff --git a/src/ast/call_statement.cc b/src/ast/call_statement.cc
index 4432e1b..a6a838d 100644
--- a/src/ast/call_statement.cc
+++ b/src/ast/call_statement.cc
@@ -26,6 +26,7 @@
CallExpression* call)
: Base(program_id, source), call_(call) {
TINT_ASSERT(call_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(call_, program_id);
}
CallStatement::CallStatement(CallStatement&&) = default;
diff --git a/src/ast/call_statement_test.cc b/src/ast/call_statement_test.cc
index 092dc54..1112ba4 100644
--- a/src/ast/call_statement_test.cc
+++ b/src/ast/call_statement_test.cc
@@ -35,7 +35,7 @@
EXPECT_TRUE(c->Is<CallStatement>());
}
-TEST_F(CallStatementTest, Assert_NullCall) {
+TEST_F(CallStatementTest, Assert_Null_Call) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -44,6 +44,17 @@
"internal compiler error");
}
+TEST_F(CallStatementTest, Assert_DifferentProgramID_Call) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CallStatement>(
+ b2.create<CallExpression>(b2.Expr("func"), ExpressionList{}));
+ },
+ "internal compiler error");
+}
+
TEST_F(CallStatementTest, ToStr) {
auto* c = create<CallStatement>(
create<CallExpression>(Expr("func"), ExpressionList{}));
diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc
index eeba6c9..23e1564 100644
--- a/src/ast/case_statement.cc
+++ b/src/ast/case_statement.cc
@@ -27,6 +27,11 @@
BlockStatement* body)
: Base(program_id, source), selectors_(selectors), body_(body) {
TINT_ASSERT(body_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body_, program_id);
+ for (auto* selector : selectors) {
+ TINT_ASSERT(selector);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(selector, program_id);
+ }
}
CaseStatement::CaseStatement(CaseStatement&&) = default;
diff --git a/src/ast/case_statement_test.cc b/src/ast/case_statement_test.cc
index 6e64f39..57bfa72 100644
--- a/src/ast/case_statement_test.cc
+++ b/src/ast/case_statement_test.cc
@@ -90,7 +90,7 @@
EXPECT_TRUE(c->Is<CaseStatement>());
}
-TEST_F(CaseStatementTest, Assert_NullBody) {
+TEST_F(CaseStatementTest, Assert_Null_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -99,6 +99,39 @@
"internal compiler error");
}
+TEST_F(CaseStatementTest, Assert_Null_Selector) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CaseStatement>(CaseSelectorList{nullptr},
+ b.create<BlockStatement>(StatementList{}));
+ },
+ "internal compiler error");
+}
+
+TEST_F(CaseStatementTest, Assert_DifferentProgramID_Call) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CaseStatement>(CaseSelectorList{},
+ b2.create<BlockStatement>(StatementList{}));
+ },
+ "internal compiler error");
+}
+
+TEST_F(CaseStatementTest, Assert_DifferentProgramID_Selector) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CaseStatement>(
+ CaseSelectorList{b2.create<SintLiteral>(b2.ty.i32(), 2)},
+ b1.create<BlockStatement>(StatementList{}));
+ },
+ "internal compiler error");
+}
+
TEST_F(CaseStatementTest, ToStr_WithSelectors_i32) {
CaseSelectorList b;
b.push_back(create<SintLiteral>(ty.i32(), -2));
diff --git a/src/ast/else_statement.cc b/src/ast/else_statement.cc
index 38955c4..7df7e74 100644
--- a/src/ast/else_statement.cc
+++ b/src/ast/else_statement.cc
@@ -27,6 +27,8 @@
BlockStatement* body)
: Base(program_id, source), condition_(condition), body_(body) {
TINT_ASSERT(body_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body_, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(condition_, program_id);
}
ElseStatement::ElseStatement(ElseStatement&&) = default;
diff --git a/src/ast/else_statement_test.cc b/src/ast/else_statement_test.cc
index 175b09c..5036e80 100644
--- a/src/ast/else_statement_test.cc
+++ b/src/ast/else_statement_test.cc
@@ -38,32 +38,29 @@
TEST_F(ElseStatementTest, Creation_WithSource) {
auto* e = create<ElseStatement>(Source{Source::Location{20, 2}}, Expr(true),
- create<BlockStatement>(StatementList{}));
+ Block());
auto src = e->source();
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
}
TEST_F(ElseStatementTest, IsElse) {
- auto* e =
- create<ElseStatement>(nullptr, create<BlockStatement>(StatementList{}));
+ auto* e = create<ElseStatement>(nullptr, Block());
EXPECT_TRUE(e->Is<ElseStatement>());
}
TEST_F(ElseStatementTest, HasCondition) {
auto* cond = Expr(true);
- auto* e =
- create<ElseStatement>(cond, create<BlockStatement>(StatementList{}));
+ auto* e = create<ElseStatement>(cond, Block());
EXPECT_TRUE(e->HasCondition());
}
TEST_F(ElseStatementTest, HasContition_NullCondition) {
- auto* e =
- create<ElseStatement>(nullptr, create<BlockStatement>(StatementList{}));
+ auto* e = create<ElseStatement>(nullptr, Block());
EXPECT_FALSE(e->HasCondition());
}
-TEST_F(ElseStatementTest, Assert_NullBody) {
+TEST_F(ElseStatementTest, Assert_Null_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -72,6 +69,26 @@
"internal compiler error");
}
+TEST_F(ElseStatementTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ElseStatement>(b2.Expr(true), b1.Block());
+ },
+ "internal compiler error");
+}
+
+TEST_F(ElseStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ElseStatement>(b1.Expr(true), b2.Block());
+ },
+ "internal compiler error");
+}
+
TEST_F(ElseStatementTest, ToStr) {
auto* cond = Expr(true);
auto* body = create<BlockStatement>(StatementList{
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 8efbb64..b656d6b 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -38,11 +38,19 @@
body_(body),
decorations_(std::move(decorations)),
return_type_decorations_(std::move(return_type_decorations)) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body, program_id);
for (auto* param : params_) {
TINT_ASSERT(param && param->is_const());
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(param, program_id);
}
TINT_ASSERT(symbol_.IsValid());
TINT_ASSERT(return_type_);
+ for (auto* deco : decorations_) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
+ }
+ for (auto* deco : return_type_decorations_) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
+ }
}
Function::Function(Function&&) = default;
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index a9bdf0e..1ec2fbb 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -57,7 +57,7 @@
"internal compiler error");
}
-TEST_F(FunctionTest, Assert_NullReturnType) {
+TEST_F(FunctionTest, Assert_Null_ReturnType) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -66,7 +66,7 @@
"internal compiler error");
}
-TEST_F(FunctionTest, Assert_NullParam) {
+TEST_F(FunctionTest, Assert_Null_Param) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -79,6 +79,44 @@
"internal compiler error");
}
+TEST_F(FunctionTest, Assert_DifferentProgramID_Param) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Func("func", VariableList{b2.Param("var", b2.ty.i32())},
+ b1.ty.void_(), StatementList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_Deco) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
+ DecorationList{
+ b2.create<WorkgroupDecoration>(2, 4, 6),
+ });
+ },
+ "internal compiler error");
+}
+
+TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnDeco) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
+ DecorationList{},
+ DecorationList{
+ b2.create<WorkgroupDecoration>(2, 4, 6),
+ });
+ },
+ "internal compiler error");
+}
+
TEST_F(FunctionTest, Assert_NonConstParam) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/ast/if_statement.cc b/src/ast/if_statement.cc
index 3080ca4..d7c00c5 100644
--- a/src/ast/if_statement.cc
+++ b/src/ast/if_statement.cc
@@ -31,9 +31,12 @@
body_(body),
else_statements_(std::move(else_stmts)) {
TINT_ASSERT(condition_);
- TINT_ASSERT(body);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(condition_, program_id);
+ TINT_ASSERT(body_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body_, program_id);
for (auto* el : else_statements_) {
TINT_ASSERT(el);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(el, program_id);
}
}
diff --git a/src/ast/if_statement_test.cc b/src/ast/if_statement_test.cc
index dcc3342..c19ec8b 100644
--- a/src/ast/if_statement_test.cc
+++ b/src/ast/if_statement_test.cc
@@ -26,9 +26,8 @@
TEST_F(IfStatementTest, Creation) {
auto* cond = Expr("cond");
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
- auto* stmt = create<IfStatement>(Source{Source::Location{20, 2}}, cond, body,
+ auto* stmt = create<IfStatement>(Source{Source::Location{20, 2}}, cond,
+ Block(create<DiscardStatement>()),
ElseStatementList{});
auto src = stmt->source();
EXPECT_EQ(src.range.begin.line, 20u);
@@ -36,22 +35,20 @@
}
TEST_F(IfStatementTest, IsIf) {
- auto* stmt = create<IfStatement>(
- Expr(true), create<BlockStatement>(StatementList{}), ElseStatementList{});
+ auto* stmt = create<IfStatement>(Expr(true), Block(), ElseStatementList{});
EXPECT_TRUE(stmt->Is<IfStatement>());
}
-TEST_F(IfStatementTest, Assert_NullCondition) {
+TEST_F(IfStatementTest, Assert_Null_Condition) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- auto* body = b.create<BlockStatement>(StatementList{});
- b.create<IfStatement>(nullptr, body, ElseStatementList{});
+ b.create<IfStatement>(nullptr, b.Block(), ElseStatementList{});
},
"internal compiler error");
}
-TEST_F(IfStatementTest, Assert_NullBody) {
+TEST_F(IfStatementTest, Assert_Null_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -60,7 +57,7 @@
"internal compiler error");
}
-TEST_F(IfStatementTest, Assert_NullElseStatement) {
+TEST_F(IfStatementTest, Assert_Null_ElseStatement) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -70,11 +67,44 @@
"internal compiler error");
}
+TEST_F(IfStatementTest, Assert_DifferentProgramID_Cond) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<IfStatement>(b2.Expr(true), b1.Block(), ElseStatementList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<IfStatement>(b1.Expr(true), b2.Block(), ElseStatementList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(IfStatementTest, Assert_DifferentProgramID_ElseStatement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<IfStatement>(
+ b1.Expr(true), b1.Block(),
+ ElseStatementList{
+ b2.create<ElseStatement>(b2.Expr("ident"), b2.Block()),
+ });
+ },
+ "internal compiler error");
+}
+
TEST_F(IfStatementTest, ToStr) {
auto* cond = Expr("cond");
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
- auto* stmt = create<IfStatement>(cond, body, ElseStatementList{});
+ auto* stmt = create<IfStatement>(cond, Block(create<DiscardStatement>()),
+ ElseStatementList{});
EXPECT_EQ(str(stmt), R"(If{
(
@@ -89,12 +119,10 @@
TEST_F(IfStatementTest, ToStr_WithElseStatements) {
auto* cond = Expr("cond");
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
- auto* else_if_body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
- auto* else_body = create<BlockStatement>(
- StatementList{create<DiscardStatement>(), create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
+ auto* else_if_body = Block(create<DiscardStatement>());
+ auto* else_body =
+ Block(create<DiscardStatement>(), create<DiscardStatement>());
auto* stmt = create<IfStatement>(
cond, body,
ElseStatementList{
diff --git a/src/ast/loop_statement.cc b/src/ast/loop_statement.cc
index 10e5208..281907d 100644
--- a/src/ast/loop_statement.cc
+++ b/src/ast/loop_statement.cc
@@ -27,6 +27,8 @@
BlockStatement* continuing)
: Base(program_id, source), body_(body), continuing_(continuing) {
TINT_ASSERT(body_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(body_, program_id);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(continuing_, program_id);
}
LoopStatement::LoopStatement(LoopStatement&&) = default;
diff --git a/src/ast/loop_statement_test.cc b/src/ast/loop_statement_test.cc
index f1e42b1..548515e 100644
--- a/src/ast/loop_statement_test.cc
+++ b/src/ast/loop_statement_test.cc
@@ -26,12 +26,10 @@
using LoopStatementTest = TestHelper;
TEST_F(LoopStatementTest, Creation) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
auto* b = body->last();
- auto* continuing =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* continuing = Block(create<DiscardStatement>());
auto* l = create<LoopStatement>(body, continuing);
ASSERT_EQ(l->body()->size(), 1u);
@@ -41,11 +39,9 @@
}
TEST_F(LoopStatementTest, Creation_WithSource) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
- auto* continuing =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* continuing = Block(create<DiscardStatement>());
auto* l =
create<LoopStatement>(Source{Source::Location{20, 2}}, body, continuing);
@@ -55,31 +51,27 @@
}
TEST_F(LoopStatementTest, IsLoop) {
- auto* l = create<LoopStatement>(create<BlockStatement>(StatementList{}),
- create<BlockStatement>(StatementList{}));
+ auto* l = create<LoopStatement>(Block(), Block());
EXPECT_TRUE(l->Is<LoopStatement>());
}
TEST_F(LoopStatementTest, HasContinuing_WithoutContinuing) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
auto* l = create<LoopStatement>(body, nullptr);
EXPECT_FALSE(l->has_continuing());
}
TEST_F(LoopStatementTest, HasContinuing_WithContinuing) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
- auto* continuing =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* continuing = Block(create<DiscardStatement>());
auto* l = create<LoopStatement>(body, continuing);
EXPECT_TRUE(l->has_continuing());
}
-TEST_F(LoopStatementTest, Assert_NullBody) {
+TEST_F(LoopStatementTest, Assert_Null_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -88,9 +80,28 @@
"internal compiler error");
}
+TEST_F(LoopStatementTest, Assert_DifferentProgramID_Body) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<LoopStatement>(b2.Block(), b1.Block());
+ },
+ "internal compiler error");
+}
+
+TEST_F(LoopStatementTest, Assert_DifferentProgramID_Continuing) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<LoopStatement>(b1.Block(), b2.Block());
+ },
+ "internal compiler error");
+}
+
TEST_F(LoopStatementTest, ToStr) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
auto* l = create<LoopStatement>(body, nullptr);
EXPECT_EQ(str(l), R"(Loop{
@@ -100,11 +111,9 @@
}
TEST_F(LoopStatementTest, ToStr_WithContinuing) {
- auto* body =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* body = Block(create<DiscardStatement>());
- auto* continuing =
- create<BlockStatement>(StatementList{create<DiscardStatement>()});
+ auto* continuing = Block(create<DiscardStatement>());
auto* l = create<LoopStatement>(body, continuing);
EXPECT_EQ(str(l), R"(Loop{
diff --git a/src/ast/member_accessor_expression.cc b/src/ast/member_accessor_expression.cc
index 862d6f0..83bb539 100644
--- a/src/ast/member_accessor_expression.cc
+++ b/src/ast/member_accessor_expression.cc
@@ -26,8 +26,10 @@
Expression* structure,
IdentifierExpression* member)
: Base(program_id, source), struct_(structure), member_(member) {
- TINT_ASSERT(structure);
- TINT_ASSERT(member);
+ TINT_ASSERT(struct_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(struct_, program_id);
+ TINT_ASSERT(member_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(member_, program_id);
}
MemberAccessorExpression::MemberAccessorExpression(MemberAccessorExpression&&) =
diff --git a/src/ast/member_accessor_expression_test.cc b/src/ast/member_accessor_expression_test.cc
index 44c30d1..023ad0f 100644
--- a/src/ast/member_accessor_expression_test.cc
+++ b/src/ast/member_accessor_expression_test.cc
@@ -44,7 +44,7 @@
EXPECT_TRUE(stmt->Is<MemberAccessorExpression>());
}
-TEST_F(MemberAccessorExpressionTest, Assert_NullStruct) {
+TEST_F(MemberAccessorExpressionTest, Assert_Null_Struct) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -53,7 +53,7 @@
"internal compiler error");
}
-TEST_F(MemberAccessorExpressionTest, Assert_NullMember) {
+TEST_F(MemberAccessorExpressionTest, Assert_Null_Member) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -62,6 +62,28 @@
"internal compiler error");
}
+TEST_F(MemberAccessorExpressionTest, Assert_DifferentProgramID_Struct) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<MemberAccessorExpression>(b2.Expr("structure"),
+ b1.Expr("member"));
+ },
+ "internal compiler error");
+}
+
+TEST_F(MemberAccessorExpressionTest, Assert_DifferentProgramID_Member) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<MemberAccessorExpression>(b1.Expr("structure"),
+ b2.Expr("member"));
+ },
+ "internal compiler error");
+}
+
TEST_F(MemberAccessorExpressionTest, ToStr) {
auto* stmt =
create<MemberAccessorExpression>(Expr("structure"), Expr("member"));
diff --git a/src/ast/module.h b/src/ast/module.h
index efca69e..780988e 100644
--- a/src/ast/module.h
+++ b/src/ast/module.h
@@ -53,6 +53,7 @@
/// @param var the variable to add
void AddGlobalVariable(ast::Variable* var) {
TINT_ASSERT(var);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(var, program_id());
global_variables_.push_back(var);
global_declarations_.push_back(var);
}
@@ -81,6 +82,7 @@
/// @param func the function to add
void AddFunction(ast::Function* func) {
TINT_ASSERT(func);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(func, program_id());
functions_.push_back(func);
global_declarations_.push_back(func);
}
diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc
index 568790d..e53b9d8 100644
--- a/src/ast/module_test.cc
+++ b/src/ast/module_test.cc
@@ -73,6 +73,29 @@
"internal compiler error");
}
+TEST_F(ModuleTest, Assert_DifferentProgramID_Function) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.AST().AddFunction(b2.create<ast::Function>(
+ b2.Symbols().Register("func"), VariableList{}, b2.ty.f32(),
+ b2.Block(), DecorationList{}, DecorationList{}));
+ },
+ "internal compiler error");
+}
+
+TEST_F(ModuleTest, Assert_DifferentProgramID_GlobalVariable) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.AST().AddGlobalVariable(
+ b2.Var("var", b2.ty.i32(), ast::StorageClass::kPrivate));
+ },
+ "internal compiler error");
+}
+
TEST_F(ModuleTest, Assert_Null_Function) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/ast/node.h b/src/ast/node.h
index 827f1dc..0d4cf1e 100644
--- a/src/ast/node.h
+++ b/src/ast/node.h
@@ -78,6 +78,13 @@
};
} // namespace ast
+
+/// @param node a pointer to an AST node
+/// @returns the ProgramID of the given AST node.
+inline ProgramID ProgramIDOf(ast::Node* node) {
+ return node ? node->program_id() : ProgramID();
+}
+
} // namespace tint
#endif // SRC_AST_NODE_H_
diff --git a/src/ast/return_statement.cc b/src/ast/return_statement.cc
index a38aac0..797141b 100644
--- a/src/ast/return_statement.cc
+++ b/src/ast/return_statement.cc
@@ -27,7 +27,9 @@
ReturnStatement::ReturnStatement(ProgramID program_id,
const Source& source,
Expression* value)
- : Base(program_id, source), value_(value) {}
+ : Base(program_id, source), value_(value) {
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(value_, program_id);
+}
ReturnStatement::ReturnStatement(ReturnStatement&&) = default;
diff --git a/src/ast/return_statement_test.cc b/src/ast/return_statement_test.cc
index d55e360..89bfd21 100644
--- a/src/ast/return_statement_test.cc
+++ b/src/ast/return_statement_test.cc
@@ -14,6 +14,7 @@
#include "src/ast/return_statement.h"
+#include "gtest/gtest-spi.h"
#include "src/ast/test_helper.h"
namespace tint {
@@ -69,6 +70,16 @@
)");
}
+TEST_F(ReturnStatementTest, Assert_DifferentProgramID_Expr) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ReturnStatement>(b2.Expr(true));
+ },
+ "internal compiler error");
+}
+
} // namespace
} // namespace ast
} // namespace tint
diff --git a/src/ast/scalar_constructor_expression.cc b/src/ast/scalar_constructor_expression.cc
index 01413aa..5e74bd0 100644
--- a/src/ast/scalar_constructor_expression.cc
+++ b/src/ast/scalar_constructor_expression.cc
@@ -26,6 +26,7 @@
Literal* literal)
: Base(program_id, source), literal_(literal) {
TINT_ASSERT(literal);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(literal, program_id);
}
ScalarConstructorExpression::ScalarConstructorExpression(
diff --git a/src/ast/scalar_constructor_expression_test.cc b/src/ast/scalar_constructor_expression_test.cc
index 8baa906..ceac061 100644
--- a/src/ast/scalar_constructor_expression_test.cc
+++ b/src/ast/scalar_constructor_expression_test.cc
@@ -40,6 +40,17 @@
)");
}
+TEST_F(ScalarConstructorExpressionTest, Assert_DifferentProgramID_Literal) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<ScalarConstructorExpression>(
+ b2.create<BoolLiteral>(b2.ty.bool_(), true));
+ },
+ "internal compiler error");
+}
+
} // namespace
} // namespace ast
} // namespace tint
diff --git a/src/ast/struct.cc b/src/ast/struct.cc
index eb6d6c2..77abd2d 100644
--- a/src/ast/struct.cc
+++ b/src/ast/struct.cc
@@ -31,9 +31,11 @@
decorations_(std::move(decorations)) {
for (auto* mem : members_) {
TINT_ASSERT(mem);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(mem, program_id);
}
for (auto* deco : decorations_) {
TINT_ASSERT(deco);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
}
}
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index a36326e..882d0c1 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -34,6 +34,7 @@
TINT_ASSERT(symbol_.IsValid());
for (auto* deco : decorations_) {
TINT_ASSERT(deco);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(deco, program_id);
}
}
diff --git a/src/ast/struct_member_test.cc b/src/ast/struct_member_test.cc
index c760c85..458e259 100644
--- a/src/ast/struct_member_test.cc
+++ b/src/ast/struct_member_test.cc
@@ -46,7 +46,7 @@
EXPECT_EQ(st->source().range.end.column, 8u);
}
-TEST_F(StructMemberTest, Assert_EmptySymbol) {
+TEST_F(StructMemberTest, Assert_Empty_Symbol) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -55,7 +55,7 @@
"internal compiler error");
}
-TEST_F(StructMemberTest, Assert_NullType) {
+TEST_F(StructMemberTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -64,7 +64,7 @@
"internal compiler error");
}
-TEST_F(StructMemberTest, Assert_NullDecoration) {
+TEST_F(StructMemberTest, Assert_Null_Decoration) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -73,6 +73,16 @@
"internal compiler error");
}
+TEST_F(StructMemberTest, Assert_DifferentProgramID_Decoration) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Member("a", b1.ty.i32(), {b2.MemberSize(4)});
+ },
+ "internal compiler error");
+}
+
TEST_F(StructMemberTest, ToStr) {
auto* st = Member("a", ty.i32(), {MemberSize(4)});
EXPECT_EQ(str(st), "StructMember{[[ size 4 ]] a: __i32}\n");
diff --git a/src/ast/struct_test.cc b/src/ast/struct_test.cc
index 8676966..b7f9b1f 100644
--- a/src/ast/struct_test.cc
+++ b/src/ast/struct_test.cc
@@ -83,6 +83,28 @@
"internal compiler error");
}
+TEST_F(StructTest, Assert_DifferentProgramID_StructMember) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<Struct>(StructMemberList{b2.Member("a", b2.ty.i32())},
+ DecorationList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(StructTest, Assert_DifferentProgramID_Decoration) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<Struct>(StructMemberList{b1.Member("a", b1.ty.i32())},
+ DecorationList{b2.create<StructBlockDecoration>()});
+ },
+ "internal compiler error");
+}
+
TEST_F(StructTest, ToStr) {
DecorationList decos;
decos.push_back(create<StructBlockDecoration>());
diff --git a/src/ast/switch_statement.cc b/src/ast/switch_statement.cc
index 136236b..876ca1e 100644
--- a/src/ast/switch_statement.cc
+++ b/src/ast/switch_statement.cc
@@ -27,8 +27,10 @@
CaseStatementList body)
: Base(program_id, source), condition_(condition), body_(body) {
TINT_ASSERT(condition_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(condition_, program_id);
for (auto* stmt : body_) {
TINT_ASSERT(stmt);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(stmt, program_id);
}
}
diff --git a/src/ast/switch_statement_test.cc b/src/ast/switch_statement_test.cc
index 4373c2d..c1a7b56 100644
--- a/src/ast/switch_statement_test.cc
+++ b/src/ast/switch_statement_test.cc
@@ -29,8 +29,7 @@
auto* ident = Expr("ident");
CaseStatementList body;
- auto* case_stmt =
- create<CaseStatement>(lit, create<BlockStatement>(StatementList{}));
+ auto* case_stmt = create<CaseStatement>(lit, Block());
body.push_back(case_stmt);
auto* stmt = create<SwitchStatement>(ident, body);
@@ -55,8 +54,7 @@
auto* ident = Expr("ident");
CaseStatementList body;
- body.push_back(
- create<CaseStatement>(lit, create<BlockStatement>(StatementList{})));
+ body.push_back(create<CaseStatement>(lit, Block()));
auto* stmt = create<SwitchStatement>(ident, body);
EXPECT_TRUE(stmt->Is<SwitchStatement>());
@@ -68,8 +66,7 @@
ProgramBuilder b;
CaseStatementList cases;
cases.push_back(
- b.create<CaseStatement>(CaseSelectorList{b.Literal(1)},
- b.create<BlockStatement>(StatementList{})));
+ b.create<CaseStatement>(CaseSelectorList{b.Literal(1)}, b.Block()));
b.create<SwitchStatement>(nullptr, cases);
},
"internal compiler error");
@@ -84,6 +81,38 @@
"internal compiler error");
}
+TEST_F(SwitchStatementTest, Assert_DifferentProgramID_Condition) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<SwitchStatement>(b2.Expr(true), CaseStatementList{
+ b1.create<CaseStatement>(
+ CaseSelectorList{
+ b1.Literal(1),
+ },
+ b1.Block()),
+ });
+ },
+ "internal compiler error");
+}
+
+TEST_F(SwitchStatementTest, Assert_DifferentProgramID_CaseStatement) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<SwitchStatement>(b1.Expr(true), CaseStatementList{
+ b2.create<CaseStatement>(
+ CaseSelectorList{
+ b2.Literal(1),
+ },
+ b2.Block()),
+ });
+ },
+ "internal compiler error");
+}
+
TEST_F(SwitchStatementTest, ToStr_Empty) {
auto* ident = Expr("ident");
@@ -102,8 +131,7 @@
auto* ident = Expr("ident");
CaseStatementList body;
- body.push_back(
- create<CaseStatement>(lit, create<BlockStatement>(StatementList{})));
+ body.push_back(create<CaseStatement>(lit, Block()));
auto* stmt = create<SwitchStatement>(ident, body);
EXPECT_EQ(str(stmt), R"(Switch{
diff --git a/src/ast/type_constructor_expression.cc b/src/ast/type_constructor_expression.cc
index 2cf6ea3..6ae0772 100644
--- a/src/ast/type_constructor_expression.cc
+++ b/src/ast/type_constructor_expression.cc
@@ -26,9 +26,10 @@
type::Type* type,
ExpressionList values)
: Base(program_id, source), type_(type), values_(std::move(values)) {
- TINT_ASSERT(type);
+ TINT_ASSERT(type_);
for (auto* val : values_) {
TINT_ASSERT(val);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(val, program_id);
}
}
diff --git a/src/ast/type_constructor_expression_test.cc b/src/ast/type_constructor_expression_test.cc
index b6c0e28..335385f 100644
--- a/src/ast/type_constructor_expression_test.cc
+++ b/src/ast/type_constructor_expression_test.cc
@@ -50,7 +50,7 @@
EXPECT_TRUE(t->Is<TypeConstructorExpression>());
}
-TEST_F(TypeConstructorExpressionTest, Assert_NullType) {
+TEST_F(TypeConstructorExpressionTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -59,7 +59,7 @@
"internal compiler error");
}
-TEST_F(TypeConstructorExpressionTest, Assert_NullValue) {
+TEST_F(TypeConstructorExpressionTest, Assert_Null_Value) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -69,6 +69,17 @@
"internal compiler error");
}
+TEST_F(TypeConstructorExpressionTest, Assert_DifferentProgramID_Value) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<TypeConstructorExpression>(b1.ty.i32(),
+ ExpressionList{b2.Expr(1)});
+ },
+ "internal compiler error");
+}
+
TEST_F(TypeConstructorExpressionTest, ToStr) {
type::Vector vec(ty.f32(), 3);
ExpressionList expr;
diff --git a/src/ast/unary_op_expression.cc b/src/ast/unary_op_expression.cc
index 5f97221..3009c74 100644
--- a/src/ast/unary_op_expression.cc
+++ b/src/ast/unary_op_expression.cc
@@ -27,6 +27,7 @@
Expression* expr)
: Base(program_id, source), op_(op), expr_(expr) {
TINT_ASSERT(expr_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(expr_, program_id);
}
UnaryOpExpression::UnaryOpExpression(UnaryOpExpression&&) = default;
diff --git a/src/ast/unary_op_expression_test.cc b/src/ast/unary_op_expression_test.cc
index 2f1242a..6f8a253 100644
--- a/src/ast/unary_op_expression_test.cc
+++ b/src/ast/unary_op_expression_test.cc
@@ -46,7 +46,7 @@
EXPECT_TRUE(u->Is<UnaryOpExpression>());
}
-TEST_F(UnaryOpExpressionTest, Assert_NullExpression) {
+TEST_F(UnaryOpExpressionTest, Assert_Null_Expression) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -55,6 +55,16 @@
"internal compiler error");
}
+TEST_F(UnaryOpExpressionTest, Assert_DifferentProgramID_Expression) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<UnaryOpExpression>(UnaryOp::kNot, b2.Expr(true));
+ },
+ "internal compiler error");
+}
+
TEST_F(UnaryOpExpressionTest, ToStr) {
auto* ident = Expr("ident");
auto* u = create<UnaryOpExpression>(UnaryOp::kNot, ident);
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index 13a50c8..92057bc 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -41,6 +41,7 @@
TINT_ASSERT(symbol_.IsValid());
// no type means we must have a constructor to infer it
TINT_ASSERT(declared_type_ || constructor);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(constructor, program_id);
}
Variable::Variable(Variable&&) = default;
diff --git a/src/ast/variable_decl_statement.cc b/src/ast/variable_decl_statement.cc
index f082e46..e636fb5 100644
--- a/src/ast/variable_decl_statement.cc
+++ b/src/ast/variable_decl_statement.cc
@@ -26,6 +26,7 @@
Variable* variable)
: Base(program_id, source), variable_(variable) {
TINT_ASSERT(variable_);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(variable_, program_id);
}
VariableDeclStatement::VariableDeclStatement(VariableDeclStatement&&) = default;
diff --git a/src/ast/variable_decl_statement_test.cc b/src/ast/variable_decl_statement_test.cc
index 3447ca2..37776e2 100644
--- a/src/ast/variable_decl_statement_test.cc
+++ b/src/ast/variable_decl_statement_test.cc
@@ -47,7 +47,7 @@
EXPECT_TRUE(stmt->Is<VariableDeclStatement>());
}
-TEST_F(VariableDeclStatementTest, Assert_NullVariable) {
+TEST_F(VariableDeclStatementTest, Assert_Null_Variable) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -56,6 +56,17 @@
"internal compiler error");
}
+TEST_F(VariableDeclStatementTest, Assert_DifferentProgramID_Variable) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<VariableDeclStatement>(
+ b2.Var("a", b2.ty.f32(), StorageClass::kNone));
+ },
+ "internal compiler error");
+}
+
TEST_F(VariableDeclStatementTest, ToStr) {
auto* var = Var("a", ty.f32(), StorageClass::kNone);
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 929e922..9378bf6 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -71,7 +71,7 @@
"internal compiler error");
}
-TEST_F(VariableTest, Assert_NullType) {
+TEST_F(VariableTest, Assert_Null_Type) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
@@ -80,6 +80,16 @@
"internal compiler error");
}
+TEST_F(VariableTest, Assert_DifferentProgramID_Constructor) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.Var("x", b1.ty.f32(), StorageClass::kNone, b2.Expr(1.2f));
+ },
+ "internal compiler error");
+}
+
TEST_F(VariableTest, to_str) {
auto* v = Var("my_var", ty.f32(), StorageClass::kFunction);
EXPECT_EQ(str(v), R"(Variable{
diff --git a/src/program_id.h b/src/program_id.h
index 9b7240a..f6e0f36 100644
--- a/src/program_id.h
+++ b/src/program_id.h
@@ -16,9 +16,19 @@
#define SRC_PROGRAM_ID_H_
#include <stdint.h>
+#include <iostream>
+#include <utility>
+
+#include "src/debug.h"
namespace tint {
+/// If 1 then checks are enabled that AST nodes are not leaked from one program
+/// to another.
+/// TODO(bclayton): We'll want to disable this in production builds. For now we
+/// always check.
+#define TINT_CHECK_FOR_CROSS_PROGRAM_LEAKS 1
+
/// A ProgramID is a unique identifier of a Program.
/// ProgramID can be used to ensure that objects referenced by the Program are
/// owned exclusively by that Program and have accidentally not leaked from
@@ -28,7 +38,7 @@
/// Constructor
ProgramID();
- /// @returns a new ProgramID
+ /// @returns a new. globally unique ProgramID
static ProgramID New();
/// Equality operator
@@ -41,12 +51,85 @@
/// @returns true if the ProgramIDs are not equal
bool operator!=(const ProgramID& rhs) const { return val != rhs.val; }
+ /// @returns the numerical identifier value
+ uint32_t Value() const { return val; }
+
+ /// @returns true if this ProgramID is valid
+ operator bool() const { return val != 0; }
+
private:
explicit ProgramID(uint32_t);
uint32_t val = 0;
};
+/// A simple pass-through function for ProgramID. Intended to be overloaded for
+/// other types.
+/// @param id a ProgramID
+/// @returns id. Simple pass-through function
+inline ProgramID ProgramIDOf(ProgramID id) {
+ return id;
+}
+
+/// Writes the ProgramID to the std::ostream.
+/// @param out the std::ostream to write to
+/// @param id the program identifier to write
+/// @returns out so calls can be chained
+inline std::ostream& operator<<(std::ostream& out, ProgramID id) {
+ out << "Program<" << id.Value() << ">";
+ return out;
+}
+
+namespace detail {
+
+/// AssertProgramIDsEqual is called by TINT_ASSERT_PROGRAM_IDS_EQUAL() and
+/// TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID() to assert that the ProgramIDs of
+/// `a` and `b` are equal.
+template <typename A, typename B>
+void AssertProgramIDsEqual(A&& a,
+ B&& b,
+ bool if_valid,
+ const char* msg,
+ const char* file,
+ size_t line) {
+ auto a_id = ProgramIDOf(std::forward<A>(a));
+ auto b_id = ProgramIDOf(std::forward<B>(b));
+ if (a_id == b_id) {
+ return; // matched
+ }
+ if (if_valid && (!a_id || !b_id)) {
+ return; // a or b were not valid
+ }
+ diag::List diagnostics;
+ tint::InternalCompilerError(file, line, diagnostics) << msg;
+}
+
+} // namespace detail
+
+/// TINT_ASSERT_PROGRAM_IDS_EQUAL(A, B) is a macro that asserts that the program
+/// identifiers for A and B are equal.
+///
+/// TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(A, B) is a macro that asserts that
+/// the program identifiers for A and B are equal, if both A and B have valid
+/// program identifiers.
+#if TINT_CHECK_FOR_CROSS_PROGRAM_LEAKS
+#define TINT_ASSERT_PROGRAM_IDS_EQUAL(a, b) \
+ detail::AssertProgramIDsEqual( \
+ a, b, false, "TINT_ASSERT_PROGRAM_IDS_EQUAL(" #a ", " #b ")", __FILE__, \
+ __LINE__)
+#define TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(a, b) \
+ detail::AssertProgramIDsEqual( \
+ a, b, true, "TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(" #a ", " #b ")", \
+ __FILE__, __LINE__)
+#else
+#define TINT_ASSERT_PROGRAM_IDS_EQUAL(a, b) \
+ do { \
+ } while (false)
+#define TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(a, b) \
+ do { \
+ } while (false)
+#endif
+
} // namespace tint
#endif // SRC_PROGRAM_ID_H_