[tint][resolver] Fix short-circuiting of const identifiers
Add asserts to catch bad behavior in the future.
Fixed: tint:1961
Change-Id: I5a64662f678a54ef4b5aa93642f5ef54b3518fd3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/137120
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 4aeced9..3fa71cd 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1820,16 +1820,18 @@
ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- // Note: Due to short-circuiting, this function is only called if lhs is true, so we could
- // technically only return the value of the rhs.
- return CreateScalar(source, ty, args[0]->ValueAs<bool>() && args[1]->ValueAs<bool>());
+ // Due to short-circuiting, this function is only called if lhs is true, so we only return the
+ // value of the rhs.
+ TINT_ASSERT(Resolver, args[0]->ValueAs<bool>());
+ return CreateScalar(source, ty, args[1]->ValueAs<bool>());
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- // Note: Due to short-circuiting, this function is only called if lhs is false, so we could
- // technically only return the value of the rhs.
+ // Due to short-circuiting, this function is only called if lhs is false, so we only only return
+ // the value of the rhs.
+ TINT_ASSERT(Resolver, !args[0]->ValueAs<bool>());
return CreateScalar(source, ty, args[1]->ValueAs<bool>());
}
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 6bf5830..e61599f 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -2239,6 +2239,84 @@
}
////////////////////////////////////////////////
+// Short-Circuit with RHS Variable Access
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSConstDecl) {
+ // const FALSE = false;
+ // const result = FALSE && FALSE;
+ GlobalConst("FALSE", Expr(false));
+ auto* binary = LogicalAnd("FALSE", "FALSE");
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSConstDecl) {
+ // const TRUE = true;
+ // const result = TRUE || TRUE;
+ GlobalConst("TRUE", Expr(true));
+ auto* binary = LogicalOr("TRUE", "TRUE");
+ GlobalConst("result", binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSLetDecl) {
+ // fn f() {
+ // let b = false;
+ // let result = false && b;
+ // }
+ auto* binary = LogicalAnd(false, "b");
+ WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSLetDecl) {
+ // fn f() {
+ // let b = false;
+ // let result = true || b;
+ // }
+ auto* binary = LogicalOr(true, "b");
+ WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSVarDecl) {
+ // fn f() {
+ // var b = false;
+ // let result = false && b;
+ // }
+ auto* binary = LogicalAnd(false, "b");
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSVarDecl) {
+ // fn f() {
+ // var b = false;
+ // let result = true || b;
+ // }
+ auto* binary = LogicalOr(true, "b");
+ WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+ EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+////////////////////////////////////////////////
// Short-Circuit Swizzle
////////////////////////////////////////////////
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 66791ca..8a57ec8 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -3057,9 +3057,16 @@
return Switch(
resolved_node, //
[&](sem::Variable* variable) -> sem::VariableUser* {
- auto symbol = ident->symbol;
- auto* user =
- builder_->create<sem::VariableUser>(expr, current_statement_, variable);
+ auto stage = variable->Stage();
+ const constant::Value* value = variable->ConstantValue();
+ if (skip_const_eval_.Contains(expr)) {
+ // This expression is short-circuited by an ancestor expression.
+ // Do not const-eval.
+ stage = sem::EvaluationStage::kNotEvaluated;
+ value = nullptr;
+ }
+ auto* user = builder_->create<sem::VariableUser>(expr, stage, current_statement_,
+ value, variable);
if (current_statement_) {
// If identifier is part of a loop continuing block, make sure it
@@ -3073,6 +3080,7 @@
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
+ auto symbol = ident->symbol;
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
@@ -3118,7 +3126,7 @@
// Note: The spec is currently vague around the rules here. See
// https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when
// resolved.
- std::string desc = "var '" + symbol.Name() + "' ";
+ std::string desc = "var '" + ident->symbol.Name() + "' ";
AddError(desc + "cannot be referenced at module-scope", expr->source);
AddNote(desc + "declared here", variable->Declaration()->source);
return nullptr;
@@ -3368,8 +3376,19 @@
if (!lhs || !rhs) {
return nullptr;
}
- auto* lhs_ty = lhs->Type()->UnwrapRef();
- auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+ // Load arguments if they are references
+ lhs = Load(lhs);
+ if (!lhs) {
+ return nullptr;
+ }
+ rhs = Load(rhs);
+ if (!rhs) {
+ return nullptr;
+ }
+
+ auto* lhs_ty = lhs->Type();
+ auto* rhs_ty = rhs->Type();
auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, stage, expr->source, false);
@@ -3389,16 +3408,6 @@
}
}
- // Load arguments if they are references
- lhs = Load(lhs);
- if (!lhs) {
- return nullptr;
- }
- rhs = Load(rhs);
- if (!rhs) {
- return nullptr;
- }
-
const constant::Value* value = nullptr;
if (skip_const_eval_.Contains(expr)) {
// This expression is short-circuited by an ancestor expression.
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index 2e3cb88..a220a80 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -86,13 +86,15 @@
Parameter::~Parameter() = default;
VariableUser::VariableUser(const ast::IdentifierExpression* declaration,
+ EvaluationStage stage,
Statement* statement,
+ const constant::Value* constant,
sem::Variable* variable)
: Base(declaration,
variable->Type(),
- variable->Stage(),
+ stage,
statement,
- variable->ConstantValue(),
+ constant,
/* has_side_effects */ false),
variable_(variable) {
auto* type = variable->Type();
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 4cc0b1e..7660df3 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -259,10 +259,14 @@
public:
/// Constructor
/// @param declaration the AST identifier node
+ /// @param stage the evaluation stage for an expression of this variable type
/// @param statement the statement that owns this expression
+ /// @param constant the constant value of the expression. May be null
/// @param variable the semantic variable
VariableUser(const ast::IdentifierExpression* declaration,
+ EvaluationStage stage,
Statement* statement,
+ const constant::Value* constant,
sem::Variable* variable);
~VariableUser() override;
diff --git a/test/tint/bug/tint/1961.wgsl b/test/tint/bug/tint/1961.wgsl
new file mode 100644
index 0000000..62097de
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl
@@ -0,0 +1,10 @@
+const TRUE = true;
+const FALSE = false;
+const_assert(true || FALSE);
+const_assert(!(false && true));
+
+fn f() {
+ var x = false;
+ var y = false;
+ if (x && (true || y)) { }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..3fb56d3
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
@@ -0,0 +1,19 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void f() {
+ bool x = false;
+ bool y = false;
+ bool tint_tmp = x;
+ if (tint_tmp) {
+ bool tint_tmp_1 = true;
+ if (!tint_tmp_1) {
+ tint_tmp_1 = y;
+ }
+ tint_tmp = (tint_tmp_1);
+ }
+ if ((tint_tmp)) {
+ }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..3fb56d3
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
@@ -0,0 +1,19 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+ return;
+}
+
+void f() {
+ bool x = false;
+ bool y = false;
+ bool tint_tmp = x;
+ if (tint_tmp) {
+ bool tint_tmp_1 = true;
+ if (!tint_tmp_1) {
+ tint_tmp_1 = y;
+ }
+ tint_tmp = (tint_tmp_1);
+ }
+ if ((tint_tmp)) {
+ }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.glsl b/test/tint/bug/tint/1961.wgsl.expected.glsl
new file mode 100644
index 0000000..4b62f21
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.glsl
@@ -0,0 +1,21 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+ return;
+}
+void f() {
+ bool x = false;
+ bool y = false;
+ bool tint_tmp = x;
+ if (tint_tmp) {
+ bool tint_tmp_1 = true;
+ if (!tint_tmp_1) {
+ tint_tmp_1 = y;
+ }
+ tint_tmp = (tint_tmp_1);
+ }
+ if ((tint_tmp)) {
+ }
+}
+
diff --git a/test/tint/bug/tint/1961.wgsl.expected.msl b/test/tint/bug/tint/1961.wgsl.expected.msl
new file mode 100644
index 0000000..cdc071a
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.msl
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+ bool x = false;
+ bool y = false;
+ if ((x && (true || y))) {
+ }
+}
+
diff --git a/test/tint/bug/tint/1961.wgsl.expected.spvasm b/test/tint/bug/tint/1961.wgsl.expected.spvasm
new file mode 100644
index 0000000..633c7d4
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.spvasm
@@ -0,0 +1,50 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 23
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+ OpExecutionMode %unused_entry_point LocalSize 1 1 1
+ OpName %unused_entry_point "unused_entry_point"
+ OpName %f "f"
+ OpName %x "x"
+ OpName %y "y"
+ %void = OpTypeVoid
+ %1 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %8 = OpConstantNull %bool
+%_ptr_Function_bool = OpTypePointer Function %bool
+ %true = OpConstantTrue %bool
+%unused_entry_point = OpFunction %void None %1
+ %4 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %f = OpFunction %void None %1
+ %6 = OpLabel
+ %x = OpVariable %_ptr_Function_bool Function %8
+ %y = OpVariable %_ptr_Function_bool Function %8
+ OpStore %x %8
+ OpStore %y %8
+ %12 = OpLoad %bool %x
+ OpSelectionMerge %13 None
+ OpBranchConditional %12 %14 %13
+ %14 = OpLabel
+ OpSelectionMerge %16 None
+ OpBranchConditional %true %16 %17
+ %17 = OpLabel
+ %18 = OpLoad %bool %y
+ OpBranch %16
+ %16 = OpLabel
+ %19 = OpPhi %bool %true %14 %18 %17
+ OpBranch %13
+ %13 = OpLabel
+ %20 = OpPhi %bool %12 %6 %19 %16
+ OpSelectionMerge %21 None
+ OpBranchConditional %20 %22 %21
+ %22 = OpLabel
+ OpBranch %21
+ %21 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/bug/tint/1961.wgsl.expected.wgsl b/test/tint/bug/tint/1961.wgsl.expected.wgsl
new file mode 100644
index 0000000..dd8dc1a
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.wgsl
@@ -0,0 +1,14 @@
+const TRUE = true;
+
+const FALSE = false;
+
+const_assert (true || FALSE);
+
+const_assert !((false && true));
+
+fn f() {
+ var x = false;
+ var y = false;
+ if ((x && (true || y))) {
+ }
+}