[tint][resolver] Fix evaluation stage of function calls
A function call on the RHS of a constant short-circuited expression should be `core::EvaluationStage::kNotEvaluated`.
Change-Id: Idd70f908afb68d599d4e2b46071f7bd813f45fad
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/187690
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc b/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
index 3a3714a..7873d8f 100644
--- a/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
+++ b/src/tint/lang/wgsl/resolver/evaluation_stage_test.cc
@@ -29,6 +29,7 @@
#include "gmock/gmock.h"
#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+#include "src/tint/utils/containers/slice.h"
namespace tint::resolver {
namespace {
@@ -355,5 +356,58 @@
EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
}
+TEST_F(ResolverEvaluationStageTest, FnCall_Runtime) {
+ // fn f() -> bool { return true; }
+ // let l = false
+ // let result = l && f();
+ Func("f", Empty, ty.bool_(), Vector{Return(true)});
+ auto* let = Let("l", Expr(false));
+ auto* lhs = Expr(let);
+ auto* rhs = Call("f");
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Let("result", binary);
+ WrapInFunction(let, result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), core::EvaluationStage::kRuntime);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverEvaluationStageTest, FnCall_NotEvaluated) {
+ // fn f() -> bool { return true; }
+ // let result = false && f();
+ Func("f", Empty, ty.bool_(), Vector{Return(true)});
+ auto* rhs = Call("f");
+ auto* lhs = Expr(false);
+ auto* binary = LogicalAnd(lhs, rhs);
+ auto* result = Let("result", binary);
+ WrapInFunction(result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+}
+
+TEST_F(ResolverEvaluationStageTest, NestedFnCall_NotEvaluated) {
+ // fn f(b : bool) -> bool { return b; }
+ // let result = false && f(f(f(1 == 0)));
+ Func("f", Vector{Param("b", ty.bool_())}, ty.bool_(), Vector{Return("b")});
+ auto* cmp = Equal(0_i, 1_i);
+ auto* rhs_0 = Call("f", cmp);
+ auto* rhs_1 = Call("f", rhs_0);
+ auto* rhs_2 = Call("f", rhs_1);
+ auto* lhs = Expr(false);
+ auto* binary = LogicalAnd(lhs, rhs_2);
+ auto* result = Let("result", binary);
+ WrapInFunction(result);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(Sem().Get(cmp)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_0)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_1)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(rhs_2)->Stage(), core::EvaluationStage::kNotEvaluated);
+ EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index a400979..34d3af3 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -3072,10 +3072,12 @@
return nullptr;
}
+ auto stage = skip_const_eval_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
+ : core::EvaluationStage::kRuntime;
+
// TODO(crbug.com/tint/1420): For now, assume all function calls have side effects.
bool has_side_effects = true;
- auto* call = b.create<sem::Call>(expr, target, core::EvaluationStage::kRuntime, std::move(args),
- current_statement_,
+ auto* call = b.create<sem::Call>(expr, target, stage, std::move(args), current_statement_,
/* constant_value */ nullptr, has_side_effects);
target->AddCallSite(call);