tint/uniformity: Fix compound assignment LHS eval

Only evaluate the LHS once, and then manually "load" from the
referenced variable to emulate the desugared implementation. Do the
same for increment/decrement statements.

Fixed: tint:1869
Change-Id: If0dc96bebd52485cfe222ae09305264ffc8b9329
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/123640
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 200221a..7a4d25e 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -557,10 +557,13 @@
                     auto [cf_r, _] = ProcessExpression(cf, a->rhs);
                     return cf_r;
                 }
-                auto [cf_l, v_l, apply] = ProcessLValueExpression(cf, a->lhs);
+                auto [cf_l, v_l, ident] = ProcessLValueExpression(cf, a->lhs);
                 auto [cf_r, v_r] = ProcessExpression(cf_l, a->rhs);
                 v_l->AddEdge(v_r);
-                apply();
+
+                // Update the variable node for the LHS variable.
+                current_function_->variables.Set(ident, v_l);
+
                 return cf_r;
             },
 
@@ -706,18 +709,28 @@
                 // The compound assignment statement `a += b` is equivalent to:
                 //   let p = &a;
                 //   *p = *p + b;
-                // Note: we set load_rule=true when evaluating the LHS, as the resolver does not add
-                // a load node for it.
-                auto [cf1, l1, apply] = ProcessLValueExpression(cf, c->lhs);
-                auto [cf2, v2] = ProcessExpression(cf1, c->lhs, /* load_rule */ true);
-                auto [cf3, v3] = ProcessExpression(cf2, c->rhs);
+
+                // Evaluate the LHS.
+                auto [cf1, l1, ident] = ProcessLValueExpression(cf, c->lhs);
+
+                // Get the current value loaded from the LHS reference before evaluating the RHS.
+                auto* lhs_load = current_function_->variables.Get(ident);
+
+                // Evaluate the RHS.
+                auto [cf2, v2] = ProcessExpression(cf1, c->rhs);
+
+                // Create a node for the resulting value.
                 auto* result = CreateNode({"binary_expr_result"});
                 result->AddEdge(v2);
-                result->AddEdge(v3);
+                if (lhs_load) {
+                    result->AddEdge(lhs_load);
+                }
 
+                // Update the variable node for the LHS variable.
                 l1->AddEdge(result);
-                apply();
-                return cf3;
+                current_function_->variables.Set(ident, l1);
+
+                return cf2;
             },
 
             [&](const ast::ContinueStatement* c) {
@@ -968,17 +981,25 @@
 
             [&](const ast::IncrementDecrementStatement* i) {
                 // The increment/decrement statement `i++` is equivalent to `i = i + 1`.
-                // Note: we set load_rule=true when evaluating the LHS the first time, as the
-                // resolver does not add a load node for it.
-                auto [cf1, v1] = ProcessExpression(cf, i->lhs, /* load_rule */ true);
-                auto* result = CreateNode({"incdec_result"});
-                result->AddEdge(v1);
-                result->AddEdge(cf1);
 
-                auto [cf2, l2, apply] = ProcessLValueExpression(cf1, i->lhs);
-                l2->AddEdge(result);
-                apply();
-                return cf2;
+                // Evaluate the LHS.
+                auto [cf1, l1, ident] = ProcessLValueExpression(cf, i->lhs);
+
+                // Get the current value loaded from the LHS reference.
+                auto* lhs_load = current_function_->variables.Get(ident);
+
+                // Create a node for the resulting value.
+                auto* result = CreateNode({"incdec_result"});
+                result->AddEdge(cf1);
+                if (lhs_load) {
+                    result->AddEdge(lhs_load);
+                }
+
+                // Update the variable node for the LHS variable.
+                l1->AddEdge(result);
+                current_function_->variables.Set(ident, l1);
+
+                return cf1;
             },
 
             [&](const ast::LoopStatement* l) {
@@ -1384,8 +1405,8 @@
         /// The new value node for an LValue expression
         Node* new_val = nullptr;
 
-        /// Updates the value node of the LValue expression to be #new_val.
-        std::function<void()> apply;
+        /// The root identifier for an LValue expression.
+        const sem::Variable* root_identifier = nullptr;
     };
 
     /// Process an LValue expression.
@@ -1401,13 +1422,11 @@
             [&](const ast::IdentifierExpression* i) {
                 auto* sem = sem_.GetVal(i)->UnwrapLoad()->As<sem::VariableUser>();
                 if (sem->Variable()->Is<sem::GlobalVariable>()) {
-                    return LValue{cf, current_function_->may_be_non_uniform, [] {}};
+                    return LValue{cf, current_function_->may_be_non_uniform, nullptr};
                 } else if (auto* local = sem->Variable()->As<sem::LocalVariable>()) {
                     // Create a new value node for this variable.
                     auto* value = CreateNode({NameFor(i), "_lvalue"});
 
-                    auto apply = [=] { current_function_->variables.Set(local, value); };
-
                     // If i is part of an expression that is a partial reference to a variable (e.g.
                     // index or member access), we link back to the variable's previous value. If
                     // the previous value was non-uniform, a partial assignment will not make it
@@ -1417,7 +1436,7 @@
                         value->AddEdge(old_value);
                     }
 
-                    return LValue{cf, value, apply};
+                    return LValue{cf, value, local};
                 } else {
                     TINT_ICE(Resolver, diagnostics_)
                         << "unknown lvalue identifier expression type: "
@@ -1427,11 +1446,11 @@
             },
 
             [&](const ast::IndexAccessorExpression* i) {
-                auto [cf1, l1, apply] =
+                auto [cf1, l1, root_ident] =
                     ProcessLValueExpression(cf, i->object, /*is_partial_reference*/ true);
                 auto [cf2, v2] = ProcessExpression(cf1, i->index);
                 l1->AddEdge(v2);
-                return LValue{cf2, l1, apply};
+                return LValue{cf2, l1, root_ident};
             },
 
             [&](const ast::MemberAccessorExpression* m) {
@@ -1445,8 +1464,6 @@
                     auto* root_ident = sem_.Get(u)->RootIdentifier();
                     auto* deref = CreateNode({NameFor(root_ident), "_deref"});
 
-                    auto apply = [=] { current_function_->variables.Set(root_ident, deref); };
-
                     if (auto* old_value = current_function_->variables.Get(root_ident)) {
                         // If dereferencing a partial reference or partial pointer, we link back to
                         // the variable's previous value. If the previous value was non-uniform, a
@@ -1455,7 +1472,7 @@
                             deref->AddEdge(old_value);
                         }
                     }
-                    return LValue{cf, deref, apply};
+                    return LValue{cf, deref, root_ident};
                 }
                 return ProcessLValueExpression(cf, u->expr, is_partial_reference);
             },
diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc
index a55d259..fe36da2 100644
--- a/src/tint/resolver/uniformity_test.cc
+++ b/src/tint/resolver/uniformity_test.cc
@@ -7402,6 +7402,128 @@
 )");
 }
 
+TEST_F(UniformityAnalysisTest, CompoundAssignment_Global) {
+    // Use compound assignment on a global variable.
+    // Tests that we do not assume there is always a variable node for the LHS, but we still process
+    // the expression.
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> rw : i32;
+
+var<private> v : array<i32, 4>;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+  return 0;
+}
+
+fn foo() {
+  var f = rw;
+  v[bar(&f)] += 1;
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+
+test:7:3 note: control flow depends on possibly non-uniform value
+  if (*p == 0) {
+  ^^
+
+test:7:8 note: parameter 'p' of 'bar' may be non-uniform
+  if (*p == 0) {
+       ^
+
+test:15:9 note: possibly non-uniform value passed via pointer here
+  v[bar(&f)] += 1;
+        ^
+
+test:14:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value
+  var f = rw;
+          ^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, IncDec_StillNonUniform) {
+    // Use increment on a variable that is already non-uniform.
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> rw : i32;
+
+fn foo() {
+  var v = rw;
+  v++;
+  if (v == 0) {
+    workgroupBarrier();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+
+test:7:3 note: control flow depends on possibly non-uniform value
+  if (v == 0) {
+  ^^
+
+test:5:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value
+  var v = rw;
+          ^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, IncDec_Global) {
+    // Use increment on a global variable.
+    // Tests that we do not assume there is always a variable node for the LHS, but we still process
+    // the expression.
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> rw : i32;
+
+var<private> v : array<i32, 4>;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+  return 0;
+}
+
+fn foo() {
+  var f = rw;
+  v[bar(&f)]++;
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+
+test:7:3 note: control flow depends on possibly non-uniform value
+  if (*p == 0) {
+  ^^
+
+test:7:8 note: parameter 'p' of 'bar' may be non-uniform
+  if (*p == 0) {
+       ^
+
+test:15:9 note: possibly non-uniform value passed via pointer here
+  v[bar(&f)]++;
+        ^
+
+test:14:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value
+  var f = rw;
+          ^^
+)");
+}
+
 TEST_F(UniformityAnalysisTest, ShortCircuiting_UniformLHS) {
     std::string src = R"(
 @group(0) @binding(0) var<storage, read> uniform_global : i32;
@@ -8649,5 +8771,108 @@
 )");
 }
 
+TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_RHS_Makes_LHS_NonUniform_After_Load) {
+    // Test that the LHS is loaded from before the RHS makes is evaluated.
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  *p = non_uniform;
+  return 0;
+}
+
+fn foo() {
+  var i = 0;
+  var arr : array<i32, 4>;
+  i += arr[bar(&i)];
+  if (i == 0) {
+    workgroupBarrier();
+  }
+}
+)";
+
+    RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_RHS_Makes_LHS_Uniform_After_Load) {
+    // Test that the LHS is loaded from before the RHS makes is evaluated.
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  *p = 0;
+  return 0;
+}
+
+fn foo() {
+  var i = non_uniform;
+  var arr : array<i32, 4>;
+  i += arr[bar(&i)];
+  if (i == 0) {
+    workgroupBarrier();
+  }
+}
+)";
+
+    RunTest(src, false);
+    EXPECT_EQ(error_,
+              R"(test:14:5 error: 'workgroupBarrier' must only be called from uniform control flow
+    workgroupBarrier();
+    ^^^^^^^^^^^^^^^^
+
+test:13:3 note: control flow depends on possibly non-uniform value
+  if (i == 0) {
+  ^^
+
+test:10:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
+  var i = non_uniform;
+          ^^^^^^^^^^^
+)");
+}
+
+TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_LHS_OnlyOnce) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+  *p = non_uniform;
+  return 0;
+}
+
+fn foo(){
+  var f : i32 = 0;
+  var arr : array<i32, 4>;
+  arr[bar(&f)] += 1;
+}
+)";
+
+    RunTest(src, true);
+}
+
+TEST_F(UniformityAnalysisTest, IncDec_LHS_OnlyOnce) {
+    std::string src = R"(
+@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
+
+fn bar(p : ptr<function, i32>) -> i32 {
+  if (*p == 0) {
+    workgroupBarrier();
+  }
+  *p = non_uniform;
+  return 0;
+}
+
+fn foo(){
+  var f : i32 = 0;
+  var arr : array<i32, 4>;
+  arr[bar(&f)]++;
+}
+)";
+
+    RunTest(src, true);
+}
+
 }  // namespace
 }  // namespace tint::resolver