[tint][resolver] Mark all short-circuited RHS expressions as not-evaluated

Not just those that would otherwise be const-expression.

The rules around this were patchy, and there were situations where if
the RHS was a runtime-expression builtin call, the binary op would be
classed as EvaluationStage::kRuntime, the LHS expression
EvaluationStage::kConstant, and the RHS classed as EvaluationStage::kRuntime, but the RHS sub-expressions were marked EvaluationStage::kNotEvaluated.

This would lead to bad assumptions downstream.

With this change, a short-circuited binary expression with a
const-expression LHS will always become EvaluationStage::kConstant, with the LHS being EvaluationStage::kConstant and the RHS always being EvaluationStage::kNotEvaluated.

Fixed: 341124493
Change-Id: I47d4b7e3ce1a74f74db4a51e816f114aa3bec015
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/190400
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/constant/eval_binary_op_test.cc b/src/tint/lang/core/constant/eval_binary_op_test.cc
index df37060..2f47723 100644
--- a/src/tint/lang/core/constant/eval_binary_op_test.cc
+++ b/src/tint/lang/core/constant/eval_binary_op_test.cc
@@ -27,6 +27,7 @@
 
 #include "src/tint/lang/core/constant/eval_test.h"
 
+#include "src/tint/lang/wgsl/builtin_fn.h"
 #include "src/tint/utils/result/result.h"
 
 #if TINT_BUILD_WGSL_READER
@@ -2420,9 +2421,10 @@
     WrapInFunction(Decl(Var("b", Expr(false))), binary);
 
     EXPECT_TRUE(r()->Resolve()) << r()->error();
-    EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+    ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), false);
     EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
-    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kRuntime);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
 }
 
 TEST_F(ConstEvalTest, ShortCircuit_Or_RHSVarDecl) {
@@ -2434,9 +2436,40 @@
     WrapInFunction(Decl(Var("b", Expr(false))), binary);
 
     EXPECT_TRUE(r()->Resolve()) << r()->error();
-    EXPECT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kRuntime);
+    ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), true);
     EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
-    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kRuntime);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+}
+
+TEST_F(ConstEvalTest, ShortCircuit_And_RHSRuntimeBuiltin) {
+    // fn f() {
+    //   var b = false;
+    //   let result = false && any(b);
+    // }
+    auto* binary = LogicalAnd(false, Call(wgsl::BuiltinFn::kAny, "b"));
+    WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), false);
+    EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
+}
+
+TEST_F(ConstEvalTest, ShortCircuit_Or_RHSRuntimeBuiltin) {
+    // fn f() {
+    //   var b = false;
+    //   let result = true || any(b);
+    // }
+    auto* binary = LogicalOr(true, Call(wgsl::BuiltinFn::kAny, "b"));
+    WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ASSERT_EQ(Sem().Get(binary)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().Get(binary)->ConstantValue()->ValueAs<bool>(), true);
+    EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), core::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), core::EvaluationStage::kNotEvaluated);
 }
 
 ////////////////////////////////////////////////
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 56b70e5..2f14dd4 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -1594,7 +1594,7 @@
                     // Mark entire expression tree to not const-evaluate
                     auto r = ast::TraverseExpressions(  //
                         (*binary)->rhs, [&](const ast::Expression* e) {
-                            skip_const_eval_.Add(e);
+                            not_evaluated_.Add(e);
                             return ast::TraverseAction::Descend;
                         });
                     if (!r) {
@@ -1878,7 +1878,7 @@
         return expr;
     }
 
-    auto* load = b.create<sem::Load>(expr, current_statement_);
+    auto* load = b.create<sem::Load>(expr, current_statement_, expr->Stage());
     load->Behaviors() = expr->Behaviors();
     b.Sem().Replace(expr->Declaration(), load);
 
@@ -1909,7 +1909,7 @@
     }
 
     const core::constant::Value* materialized_val = nullptr;
-    if (!skip_const_eval_.Contains(decl)) {
+    if (!not_evaluated_.Contains(decl)) {
         auto expr_val = expr->ConstantValue();
         if (TINT_UNLIKELY(!expr_val)) {
             ICE(decl->source) << "Materialize(" << decl->TypeInfo().name
@@ -2046,7 +2046,7 @@
 
     const core::constant::Value* val = nullptr;
     auto stage = core::EarliestStage(obj->Stage(), idx->Stage());
-    if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+    if (not_evaluated_.Contains(expr)) {
         stage = core::EvaluationStage::kNotEvaluated;
     } else {
         if (auto* idx_val = idx->ConstantValue()) {
@@ -2139,7 +2139,7 @@
 
         const core::constant::Value* value = nullptr;
         auto stage = core::EarliestStage(overload_stage, args_stage);
-        if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+        if (not_evaluated_.Contains(expr)) {
             stage = core::EvaluationStage::kNotEvaluated;
         }
         if (stage == core::EvaluationStage::kConstant) {
@@ -2165,7 +2165,7 @@
                                const sem::CallTarget* call_target) -> sem::Call* {
         auto stage = args_stage;                       // The evaluation stage of the call
         const core::constant::Value* value = nullptr;  // The constant value for the call
-        if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+        if (not_evaluated_.Contains(expr)) {
             stage = core::EvaluationStage::kNotEvaluated;
         }
         if (stage == core::EvaluationStage::kConstant) {
@@ -2432,7 +2432,7 @@
     // now.
     const core::constant::Value* value = nullptr;
     auto stage = core::EarliestStage(arg_stage, target->Stage());
-    if (stage == core::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
+    if (not_evaluated_.Contains(expr)) {
         stage = core::EvaluationStage::kNotEvaluated;
     }
     if (stage == core::EvaluationStage::kConstant) {
@@ -3109,8 +3109,8 @@
         return nullptr;
     }
 
-    auto stage = skip_const_eval_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
-                                                 : core::EvaluationStage::kRuntime;
+    auto stage = not_evaluated_.Contains(expr) ? core::EvaluationStage::kNotEvaluated
+                                               : core::EvaluationStage::kRuntime;
 
     // TODO(crbug.com/tint/1420): For now, assume all function calls have side effects.
     bool has_side_effects = true;
@@ -3238,7 +3238,7 @@
 
     const core::constant::Value* val = nullptr;
     auto stage = core::EvaluationStage::kConstant;
-    if (skip_const_eval_.Contains(literal)) {
+    if (not_evaluated_.Contains(literal)) {
         stage = core::EvaluationStage::kNotEvaluated;
     }
     if (stage == core::EvaluationStage::kConstant) {
@@ -3292,7 +3292,7 @@
 
                 auto stage = variable->Stage();
                 const core::constant::Value* value = variable->ConstantValue();
-                if (skip_const_eval_.Contains(expr)) {
+                if (not_evaluated_.Contains(expr)) {
                     // This expression is short-circuited by an ancestor expression.
                     // Do not const-eval.
                     stage = core::EvaluationStage::kNotEvaluated;
@@ -3632,7 +3632,7 @@
     }
 
     const core::constant::Value* value = nullptr;
-    if (skip_const_eval_.Contains(expr)) {
+    if (not_evaluated_.Contains(expr)) {
         // This expression is short-circuited by an ancestor expression.
         // Do not const-eval.
         stage = core::EvaluationStage::kNotEvaluated;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index a1abde7..80394c2 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -722,7 +722,7 @@
     uint32_t current_scoping_depth_ = 0;
     Hashset<TypeAndAddressSpace, 8> valid_type_storage_layouts_;
     Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8> logical_binary_lhs_to_parent_;
-    Hashset<const ast::Expression*, 8> skip_const_eval_;
+    Hashset<const ast::Expression*, 8> not_evaluated_;
     Hashmap<const core::type::Type*, size_t, 8> nest_depth_;
     Hashmap<std::pair<core::intrinsic::Overload, wgsl::BuiltinFn>, sem::BuiltinFn*, 64> builtins_;
     Hashmap<core::intrinsic::Overload, sem::ValueConstructor*, 16> constructors_;
diff --git a/src/tint/lang/wgsl/sem/load.cc b/src/tint/lang/wgsl/sem/load.cc
index 1f3c18d..636c24e 100644
--- a/src/tint/lang/wgsl/sem/load.cc
+++ b/src/tint/lang/wgsl/sem/load.cc
@@ -33,10 +33,10 @@
 TINT_INSTANTIATE_TYPEINFO(tint::sem::Load);
 
 namespace tint::sem {
-Load::Load(const ValueExpression* ref, const Statement* statement)
+Load::Load(const ValueExpression* ref, const Statement* statement, core::EvaluationStage stage)
     : Base(/* declaration */ ref->Declaration(),
            /* type */ ref->Type()->UnwrapRef(),
-           /* stage */ core::EvaluationStage::kRuntime,  // Loads can only be runtime
+           /* stage */ stage,
            /* statement */ statement,
            /* constant */ nullptr,
            /* has_side_effects */ ref->HasSideEffects(),
diff --git a/src/tint/lang/wgsl/sem/load.h b/src/tint/lang/wgsl/sem/load.h
index 035682d01..1711a25 100644
--- a/src/tint/lang/wgsl/sem/load.h
+++ b/src/tint/lang/wgsl/sem/load.h
@@ -41,7 +41,8 @@
     /// Constructor
     /// @param reference the reference expression being loaded
     /// @param statement the statement that owns this expression
-    Load(const ValueExpression* reference, const Statement* statement);
+    /// @param stage the earliest evaluation stage for the expression
+    Load(const ValueExpression* reference, const Statement* statement, core::EvaluationStage stage);
 
     /// Destructor
     ~Load() override;
diff --git a/test/tint/bug/chromium/341124493.wgsl b/test/tint/bug/chromium/341124493.wgsl
new file mode 100644
index 0000000..5e6c50e
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl
@@ -0,0 +1,5 @@
+fn F() {
+  var b : bool;
+  if false && select(!b, true, true) {
+  }
+}
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/341124493.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..8eba3f0
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.dxc.hlsl
@@ -0,0 +1,10 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void F() {
+  bool b = false;
+  if (false) {
+  }
+}
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/341124493.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..8eba3f0
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.fxc.hlsl
@@ -0,0 +1,10 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void F() {
+  bool b = false;
+  if (false) {
+  }
+}
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.glsl b/test/tint/bug/chromium/341124493.wgsl.expected.glsl
new file mode 100644
index 0000000..9ce0b04
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.glsl
@@ -0,0 +1,12 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+  return;
+}
+void F() {
+  bool b = false;
+  if (false) {
+  }
+}
+
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.msl b/test/tint/bug/chromium/341124493.wgsl.expected.msl
new file mode 100644
index 0000000..da3f5fe
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.msl
@@ -0,0 +1,9 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void F() {
+  bool b = false;
+  if (false) {
+  }
+}
+
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.spvasm b/test/tint/bug/chromium/341124493.wgsl.expected.spvasm
new file mode 100644
index 0000000..82db5bb
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.spvasm
@@ -0,0 +1,31 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 13
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %F "F"
+               OpName %b "b"
+       %void = OpTypeVoid
+          %1 = OpTypeFunction %void
+       %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+         %10 = OpConstantNull %bool
+%unused_entry_point = OpFunction %void None %1
+          %4 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %F = OpFunction %void None %1
+          %6 = OpLabel
+          %b = OpVariable %_ptr_Function_bool Function %10
+               OpSelectionMerge %11 None
+               OpBranchConditional %10 %12 %11
+         %12 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/chromium/341124493.wgsl.expected.wgsl b/test/tint/bug/chromium/341124493.wgsl.expected.wgsl
new file mode 100644
index 0000000..279caab
--- /dev/null
+++ b/test/tint/bug/chromium/341124493.wgsl.expected.wgsl
@@ -0,0 +1,5 @@
+fn F() {
+  var b : bool;
+  if ((false && select(!(b), true, true))) {
+  }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
index 3fb56d3..66be2a1 100644
--- a/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
+++ b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
@@ -8,11 +8,7 @@
   bool y = false;
   bool tint_tmp = x;
   if (tint_tmp) {
-    bool tint_tmp_1 = true;
-    if (!tint_tmp_1) {
-      tint_tmp_1 = y;
-    }
-    tint_tmp = (tint_tmp_1);
+    tint_tmp = true;
   }
   if ((tint_tmp)) {
   }
diff --git a/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
index 3fb56d3..66be2a1 100644
--- a/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
+++ b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
@@ -8,11 +8,7 @@
   bool y = false;
   bool tint_tmp = x;
   if (tint_tmp) {
-    bool tint_tmp_1 = true;
-    if (!tint_tmp_1) {
-      tint_tmp_1 = y;
-    }
-    tint_tmp = (tint_tmp_1);
+    tint_tmp = true;
   }
   if ((tint_tmp)) {
   }
diff --git a/test/tint/bug/tint/1961.wgsl.expected.glsl b/test/tint/bug/tint/1961.wgsl.expected.glsl
index 4b62f21..658f318 100644
--- a/test/tint/bug/tint/1961.wgsl.expected.glsl
+++ b/test/tint/bug/tint/1961.wgsl.expected.glsl
@@ -9,11 +9,7 @@
   bool y = false;
   bool tint_tmp = x;
   if (tint_tmp) {
-    bool tint_tmp_1 = true;
-    if (!tint_tmp_1) {
-      tint_tmp_1 = y;
-    }
-    tint_tmp = (tint_tmp_1);
+    tint_tmp = true;
   }
   if ((tint_tmp)) {
   }
diff --git a/test/tint/bug/tint/1961.wgsl.expected.msl b/test/tint/bug/tint/1961.wgsl.expected.msl
index cdc071a..07c9a52 100644
--- a/test/tint/bug/tint/1961.wgsl.expected.msl
+++ b/test/tint/bug/tint/1961.wgsl.expected.msl
@@ -4,7 +4,7 @@
 void f() {
   bool x = false;
   bool y = false;
-  if ((x && (true || y))) {
+  if ((x && true)) {
   }
 }
 
diff --git a/test/tint/bug/tint/1961.wgsl.expected.spvasm b/test/tint/bug/tint/1961.wgsl.expected.spvasm
index 633c7d4..c206710 100644
--- a/test/tint/bug/tint/1961.wgsl.expected.spvasm
+++ b/test/tint/bug/tint/1961.wgsl.expected.spvasm
@@ -1,7 +1,7 @@
 ; SPIR-V
 ; Version: 1.3
 ; Generator: Google Tint Compiler; 0
-; Bound: 23
+; Bound: 19
 ; Schema: 0
                OpCapability Shader
                OpMemoryModel Logical GLSL450
@@ -31,20 +31,13 @@
                OpSelectionMerge %13 None
                OpBranchConditional %12 %14 %13
          %14 = OpLabel
-               OpSelectionMerge %16 None
-               OpBranchConditional %true %16 %17
-         %17 = OpLabel
-         %18 = OpLoad %bool %y
-               OpBranch %16
-         %16 = OpLabel
-         %19 = OpPhi %bool %true %14 %18 %17
                OpBranch %13
          %13 = OpLabel
-         %20 = OpPhi %bool %12 %6 %19 %16
-               OpSelectionMerge %21 None
-               OpBranchConditional %20 %22 %21
-         %22 = OpLabel
-               OpBranch %21
-         %21 = OpLabel
+         %16 = OpPhi %bool %12 %6 %true %14
+               OpSelectionMerge %17 None
+               OpBranchConditional %16 %18 %17
+         %18 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
                OpReturn
                OpFunctionEnd