[tint][resolver] Fix short-circuiting of const identifiers

Add asserts to catch bad behavior in the future.

Fixed: tint:1961
Change-Id: I5a64662f678a54ef4b5aa93642f5ef54b3518fd3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/137120
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 4aeced9..3fa71cd 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -1820,16 +1820,18 @@
 ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty,
                                           utils::VectorRef<const constant::Value*> args,
                                           const Source& source) {
-    // Note: Due to short-circuiting, this function is only called if lhs is true, so we could
-    // technically only return the value of the rhs.
-    return CreateScalar(source, ty, args[0]->ValueAs<bool>() && args[1]->ValueAs<bool>());
+    // Due to short-circuiting, this function is only called if lhs is true, so we only return the
+    // value of the rhs.
+    TINT_ASSERT(Resolver, args[0]->ValueAs<bool>());
+    return CreateScalar(source, ty, args[1]->ValueAs<bool>());
 }
 
 ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
                                          utils::VectorRef<const constant::Value*> args,
                                          const Source& source) {
-    // Note: Due to short-circuiting, this function is only called if lhs is false, so we could
-    // technically only return the value of the rhs.
+    // Due to short-circuiting, this function is only called if lhs is false, so we only only return
+    // the value of the rhs.
+    TINT_ASSERT(Resolver, !args[0]->ValueAs<bool>());
     return CreateScalar(source, ty, args[1]->ValueAs<bool>());
 }
 
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 6bf5830..e61599f 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -2239,6 +2239,84 @@
 }
 
 ////////////////////////////////////////////////
+// Short-Circuit with RHS Variable Access
+////////////////////////////////////////////////
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSConstDecl) {
+    // const FALSE = false;
+    // const result = FALSE && FALSE;
+    GlobalConst("FALSE", Expr(false));
+    auto* binary = LogicalAnd("FALSE", "FALSE");
+    GlobalConst("result", binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSConstDecl) {
+    // const TRUE = true;
+    // const result = TRUE || TRUE;
+    GlobalConst("TRUE", Expr(true));
+    auto* binary = LogicalOr("TRUE", "TRUE");
+    GlobalConst("result", binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSLetDecl) {
+    // fn f() {
+    //   let b = false;
+    //   let result = false && b;
+    // }
+    auto* binary = LogicalAnd(false, "b");
+    WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ValidateAnd(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSLetDecl) {
+    // fn f() {
+    //   let b = false;
+    //   let result = true || b;
+    // }
+    auto* binary = LogicalOr(true, "b");
+    WrapInFunction(Decl(Let("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    ValidateOr(Sem(), binary);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_And_RHSVarDecl) {
+    // fn f() {
+    //   var b = false;
+    //   let result = false && b;
+    // }
+    auto* binary = LogicalAnd(false, "b");
+    WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+    EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+TEST_F(ResolverConstEvalTest, ShortCircuit_Or_RHSVarDecl) {
+    // fn f() {
+    //   var b = false;
+    //   let result = true || b;
+    // }
+    auto* binary = LogicalOr(true, "b");
+    WrapInFunction(Decl(Var("b", Expr(false))), binary);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+    EXPECT_EQ(Sem().Get(binary)->Stage(), sem::EvaluationStage::kRuntime);
+    EXPECT_EQ(Sem().GetVal(binary->lhs)->Stage(), sem::EvaluationStage::kConstant);
+    EXPECT_EQ(Sem().GetVal(binary->rhs)->Stage(), sem::EvaluationStage::kRuntime);
+}
+
+////////////////////////////////////////////////
 // Short-Circuit Swizzle
 ////////////////////////////////////////////////
 
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 66791ca..8a57ec8 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -3057,9 +3057,16 @@
         return Switch(
             resolved_node,  //
             [&](sem::Variable* variable) -> sem::VariableUser* {
-                auto symbol = ident->symbol;
-                auto* user =
-                    builder_->create<sem::VariableUser>(expr, current_statement_, variable);
+                auto stage = variable->Stage();
+                const constant::Value* value = variable->ConstantValue();
+                if (skip_const_eval_.Contains(expr)) {
+                    // This expression is short-circuited by an ancestor expression.
+                    // Do not const-eval.
+                    stage = sem::EvaluationStage::kNotEvaluated;
+                    value = nullptr;
+                }
+                auto* user = builder_->create<sem::VariableUser>(expr, stage, current_statement_,
+                                                                 value, variable);
 
                 if (current_statement_) {
                     // If identifier is part of a loop continuing block, make sure it
@@ -3073,6 +3080,7 @@
                         if (loop_block->FirstContinue()) {
                             // If our identifier is in loop_block->decls, make sure its index is
                             // less than first_continue
+                            auto symbol = ident->symbol;
                             if (auto decl = loop_block->Decls().Find(symbol)) {
                                 if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
                                     AddError("continue statement bypasses declaration of '" +
@@ -3118,7 +3126,7 @@
                     // Note: The spec is currently vague around the rules here. See
                     // https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when
                     // resolved.
-                    std::string desc = "var '" + symbol.Name() + "' ";
+                    std::string desc = "var '" + ident->symbol.Name() + "' ";
                     AddError(desc + "cannot be referenced at module-scope", expr->source);
                     AddNote(desc + "declared here", variable->Declaration()->source);
                     return nullptr;
@@ -3368,8 +3376,19 @@
     if (!lhs || !rhs) {
         return nullptr;
     }
-    auto* lhs_ty = lhs->Type()->UnwrapRef();
-    auto* rhs_ty = rhs->Type()->UnwrapRef();
+
+    // Load arguments if they are references
+    lhs = Load(lhs);
+    if (!lhs) {
+        return nullptr;
+    }
+    rhs = Load(rhs);
+    if (!rhs) {
+        return nullptr;
+    }
+
+    auto* lhs_ty = lhs->Type();
+    auto* rhs_ty = rhs->Type();
 
     auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
     auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, stage, expr->source, false);
@@ -3389,16 +3408,6 @@
         }
     }
 
-    // Load arguments if they are references
-    lhs = Load(lhs);
-    if (!lhs) {
-        return nullptr;
-    }
-    rhs = Load(rhs);
-    if (!rhs) {
-        return nullptr;
-    }
-
     const constant::Value* value = nullptr;
     if (skip_const_eval_.Contains(expr)) {
         // This expression is short-circuited by an ancestor expression.
diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc
index 2e3cb88..a220a80 100644
--- a/src/tint/sem/variable.cc
+++ b/src/tint/sem/variable.cc
@@ -86,13 +86,15 @@
 Parameter::~Parameter() = default;
 
 VariableUser::VariableUser(const ast::IdentifierExpression* declaration,
+                           EvaluationStage stage,
                            Statement* statement,
+                           const constant::Value* constant,
                            sem::Variable* variable)
     : Base(declaration,
            variable->Type(),
-           variable->Stage(),
+           stage,
            statement,
-           variable->ConstantValue(),
+           constant,
            /* has_side_effects */ false),
       variable_(variable) {
     auto* type = variable->Type();
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 4cc0b1e..7660df3 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -259,10 +259,14 @@
   public:
     /// Constructor
     /// @param declaration the AST identifier node
+    /// @param stage the evaluation stage for an expression of this variable type
     /// @param statement the statement that owns this expression
+    /// @param constant the constant value of the expression. May be null
     /// @param variable the semantic variable
     VariableUser(const ast::IdentifierExpression* declaration,
+                 EvaluationStage stage,
                  Statement* statement,
+                 const constant::Value* constant,
                  sem::Variable* variable);
     ~VariableUser() override;
 
diff --git a/test/tint/bug/tint/1961.wgsl b/test/tint/bug/tint/1961.wgsl
new file mode 100644
index 0000000..62097de
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl
@@ -0,0 +1,10 @@
+const TRUE = true;
+const FALSE = false;
+const_assert(true || FALSE);
+const_assert(!(false && true));
+
+fn f() {
+  var x = false;
+  var y = false;
+  if (x && (true || y)) { }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..3fb56d3
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.dxc.hlsl
@@ -0,0 +1,19 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void f() {
+  bool x = false;
+  bool y = false;
+  bool tint_tmp = x;
+  if (tint_tmp) {
+    bool tint_tmp_1 = true;
+    if (!tint_tmp_1) {
+      tint_tmp_1 = y;
+    }
+    tint_tmp = (tint_tmp_1);
+  }
+  if ((tint_tmp)) {
+  }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..3fb56d3
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.fxc.hlsl
@@ -0,0 +1,19 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void f() {
+  bool x = false;
+  bool y = false;
+  bool tint_tmp = x;
+  if (tint_tmp) {
+    bool tint_tmp_1 = true;
+    if (!tint_tmp_1) {
+      tint_tmp_1 = y;
+    }
+    tint_tmp = (tint_tmp_1);
+  }
+  if ((tint_tmp)) {
+  }
+}
diff --git a/test/tint/bug/tint/1961.wgsl.expected.glsl b/test/tint/bug/tint/1961.wgsl.expected.glsl
new file mode 100644
index 0000000..4b62f21
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.glsl
@@ -0,0 +1,21 @@
+#version 310 es
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void unused_entry_point() {
+  return;
+}
+void f() {
+  bool x = false;
+  bool y = false;
+  bool tint_tmp = x;
+  if (tint_tmp) {
+    bool tint_tmp_1 = true;
+    if (!tint_tmp_1) {
+      tint_tmp_1 = y;
+    }
+    tint_tmp = (tint_tmp_1);
+  }
+  if ((tint_tmp)) {
+  }
+}
+
diff --git a/test/tint/bug/tint/1961.wgsl.expected.msl b/test/tint/bug/tint/1961.wgsl.expected.msl
new file mode 100644
index 0000000..cdc071a
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.msl
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+  bool x = false;
+  bool y = false;
+  if ((x && (true || y))) {
+  }
+}
+
diff --git a/test/tint/bug/tint/1961.wgsl.expected.spvasm b/test/tint/bug/tint/1961.wgsl.expected.spvasm
new file mode 100644
index 0000000..633c7d4
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.spvasm
@@ -0,0 +1,50 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 23
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %f "f"
+               OpName %x "x"
+               OpName %y "y"
+       %void = OpTypeVoid
+          %1 = OpTypeFunction %void
+       %bool = OpTypeBool
+          %8 = OpConstantNull %bool
+%_ptr_Function_bool = OpTypePointer Function %bool
+       %true = OpConstantTrue %bool
+%unused_entry_point = OpFunction %void None %1
+          %4 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %f = OpFunction %void None %1
+          %6 = OpLabel
+          %x = OpVariable %_ptr_Function_bool Function %8
+          %y = OpVariable %_ptr_Function_bool Function %8
+               OpStore %x %8
+               OpStore %y %8
+         %12 = OpLoad %bool %x
+               OpSelectionMerge %13 None
+               OpBranchConditional %12 %14 %13
+         %14 = OpLabel
+               OpSelectionMerge %16 None
+               OpBranchConditional %true %16 %17
+         %17 = OpLabel
+         %18 = OpLoad %bool %y
+               OpBranch %16
+         %16 = OpLabel
+         %19 = OpPhi %bool %true %14 %18 %17
+               OpBranch %13
+         %13 = OpLabel
+         %20 = OpPhi %bool %12 %6 %19 %16
+               OpSelectionMerge %21 None
+               OpBranchConditional %20 %22 %21
+         %22 = OpLabel
+               OpBranch %21
+         %21 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/1961.wgsl.expected.wgsl b/test/tint/bug/tint/1961.wgsl.expected.wgsl
new file mode 100644
index 0000000..dd8dc1a
--- /dev/null
+++ b/test/tint/bug/tint/1961.wgsl.expected.wgsl
@@ -0,0 +1,14 @@
+const TRUE = true;
+
+const FALSE = false;
+
+const_assert (true || FALSE);
+
+const_assert !((false && true));
+
+fn f() {
+  var x = false;
+  var y = false;
+  if ((x && (true || y))) {
+  }
+}