Add ExpandCompoundAssignment transform

This transform converts compound assignment statements into regular
assignments, hoisting LHS expressions and converting for-loops and
else-if statements if necessary.

The vector-component case needs particular care, as we cannot take the
address of a vector component. We need to capture a pointer to the
whole vector and also the component index expression:

// Before
vector_array[foo()][bar()] *= 2.0;

// After:
let _vec = &vector_array[foo()];
let _idx = bar();
(*_vec)[_idx] = (*_vec)[_idx] * 2.0;

Bug: tint:1325
Change-Id: I8b9b31fc9ac4b3697f954100ceb4be24d063bca6
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85282
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 7a5ca93..6c525c8 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -457,6 +457,8 @@
     "transform/fold_trivial_single_use_lets.h",
     "transform/for_loop_to_loop.cc",
     "transform/for_loop_to_loop.h",
+    "transform/expand_compound_assignment.cc",
+    "transform/expand_compound_assignment.h",
     "transform/localize_struct_array_assignment.cc",
     "transform/localize_struct_array_assignment.h",
     "transform/loop_to_for_loop.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 731c19f..207b988 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -335,6 +335,8 @@
   transform/localize_struct_array_assignment.h
   transform/for_loop_to_loop.cc
   transform/for_loop_to_loop.h
+  transform/expand_compound_assignment.cc
+  transform/expand_compound_assignment.h
   transform/glsl.cc
   transform/glsl.h
   transform/loop_to_for_loop.cc
@@ -1026,6 +1028,7 @@
       transform/fold_constants_test.cc
       transform/fold_trivial_single_use_lets_test.cc
       transform/for_loop_to_loop_test.cc
+      transform/expand_compound_assignment.cc
       transform/localize_struct_array_assignment_test.cc
       transform/loop_to_for_loop_test.cc
       transform/module_scope_var_to_entry_point_param_test.cc
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
new file mode 100644
index 0000000..7c967d3
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -0,0 +1,149 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/ast/compound_assignment_statement.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/transform/utils/hoist_to_decl_before.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::ExpandCompoundAssignment);
+
+namespace tint {
+namespace transform {
+
+ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
+
+ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
+
+bool ExpandCompoundAssignment::ShouldRun(const Program* program,
+                                         const DataMap&) const {
+  for (auto* node : program->ASTNodes().Objects()) {
+    if (node->Is<ast::CompoundAssignmentStatement>()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void ExpandCompoundAssignment::Run(CloneContext& ctx,
+                                   const DataMap&,
+                                   DataMap&) const {
+  HoistToDeclBefore hoist_to_decl_before(ctx);
+
+  for (auto* node : ctx.src->ASTNodes().Objects()) {
+    if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
+      auto* sem_assign = ctx.src->Sem().Get(assign);
+
+      // Helper function to create the LHS expression. This will be called twice
+      // when building the non-compound assignment statement, so must not
+      // produce expressions that cause side effects.
+      std::function<const ast::Expression*()> lhs;
+
+      // Helper function to create a variable that is a pointer to `expr`.
+      auto hoist_pointer_to = [&](const ast::Expression* expr) {
+        auto name = ctx.dst->Sym();
+        auto* ptr = ctx.dst->AddressOf(ctx.Clone(expr));
+        auto* decl = ctx.dst->Decl(ctx.dst->Const(name, nullptr, ptr));
+        hoist_to_decl_before.InsertBefore(sem_assign, decl);
+        return name;
+      };
+
+      // Helper function to hoist `expr` to a let declaration.
+      auto hoist_expr_to_let = [&](const ast::Expression* expr) {
+        auto name = ctx.dst->Sym();
+        auto* decl =
+            ctx.dst->Decl(ctx.dst->Const(name, nullptr, ctx.Clone(expr)));
+        hoist_to_decl_before.InsertBefore(sem_assign, decl);
+        return name;
+      };
+
+      // Helper function that returns `true` if the type of `expr` is a vector.
+      auto is_vec = [&](const ast::Expression* expr) {
+        return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
+      };
+
+      // Hoist the LHS expression subtree into local constants to produce a new
+      // LHS that we can evaluate twice.
+      // We need to special case compound assignments to vector components since
+      // we cannot take the address of a vector component.
+      auto* index_accessor = assign->lhs->As<ast::IndexAccessorExpression>();
+      auto* member_accessor = assign->lhs->As<ast::MemberAccessorExpression>();
+      if (assign->lhs->Is<ast::IdentifierExpression>() ||
+          (member_accessor &&
+           member_accessor->structure->Is<ast::IdentifierExpression>())) {
+        // This is the simple case with no side effects, so we can just use the
+        // original LHS expression directly.
+        // Before:
+        //     foo.bar += rhs;
+        // After:
+        //     foo.bar = foo.bar + rhs;
+        lhs = [&]() { return ctx.Clone(assign->lhs); };
+      } else if (index_accessor && is_vec(index_accessor->object)) {
+        // This is the case for vector component via an array accessor. We need
+        // to capture a pointer to the vector and also the index value.
+        // Before:
+        //     v[idx()] += rhs;
+        // After:
+        //     let vec_ptr = &v;
+        //     let index = idx();
+        //     (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
+        auto lhs_ptr = hoist_pointer_to(index_accessor->object);
+        auto index = hoist_expr_to_let(index_accessor->index);
+        lhs = [&, lhs_ptr, index]() {
+          return ctx.dst->IndexAccessor(ctx.dst->Deref(lhs_ptr), index);
+        };
+      } else if (member_accessor && is_vec(member_accessor->structure)) {
+        // This is the case for vector component via a member accessor. We just
+        // need to capture a pointer to the vector.
+        // Before:
+        //     a[idx()].y += rhs;
+        // After:
+        //     let vec_ptr = &a[idx()];
+        //     (*vec_ptr).y = (*vec_ptr).y + rhs;
+        auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
+        lhs = [&, lhs_ptr]() {
+          return ctx.dst->MemberAccessor(ctx.dst->Deref(lhs_ptr),
+                                         ctx.Clone(member_accessor->member));
+        };
+      } else {
+        // For all other statements that may have side-effecting expressions, we
+        // just need to capture a pointer to the whole LHS.
+        // Before:
+        //     a[idx()] += rhs;
+        // After:
+        //     let lhs_ptr = &a[idx()];
+        //     (*lhs_ptr) = (*lhs_ptr) + rhs;
+        auto lhs_ptr = hoist_pointer_to(assign->lhs);
+        lhs = [&, lhs_ptr]() { return ctx.dst->Deref(lhs_ptr); };
+      }
+
+      // Replace the compound assignment with a regular assignment.
+      auto* rhs = ctx.dst->create<ast::BinaryExpression>(
+          assign->op, lhs(), ctx.Clone(assign->rhs));
+      ctx.Replace(assign, ctx.dst->Assign(lhs(), rhs));
+    }
+  }
+  hoist_to_decl_before.Apply();
+  ctx.Clone();
+}
+
+}  // namespace transform
+}  // namespace tint
diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h
new file mode 100644
index 0000000..73b0c83
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment.h
@@ -0,0 +1,68 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+#define SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+
+#include "src/tint/transform/transform.h"
+
+namespace tint {
+namespace transform {
+
+/// Converts compound assignment statements to regular assignment statements,
+/// hoisting the LHS expression if necessary.
+///
+/// Before:
+/// ```
+///   a += 1;
+///   vector_array[foo()][bar()] *= 2.0;
+/// ```
+///
+/// After:
+/// ```
+///   a = a + 1;
+///   let _vec = &vector_array[foo()];
+///   let _idx = bar();
+///   (*_vec)[_idx] = (*_vec)[_idx] * 2.0;
+/// ```
+class ExpandCompoundAssignment
+    : public Castable<ExpandCompoundAssignment, Transform> {
+ public:
+  /// Constructor
+  ExpandCompoundAssignment();
+  /// Destructor
+  ~ExpandCompoundAssignment() override;
+
+  /// @param program the program to inspect
+  /// @param data optional extra transform-specific input data
+  /// @returns true if this transform should be run for the given program
+  bool ShouldRun(const Program* program,
+                 const DataMap& data = {}) const override;
+
+ protected:
+  /// Runs the transform using the CloneContext built for transforming a
+  /// program. Run() is responsible for calling Clone() on the CloneContext.
+  /// @param ctx the CloneContext primed with the input program and
+  /// ProgramBuilder
+  /// @param inputs optional extra transform-specific input data
+  /// @param outputs optional extra transform-specific output data
+  void Run(CloneContext& ctx,
+           const DataMap& inputs,
+           DataMap& outputs) const override;
+};
+
+}  // namespace transform
+}  // namespace tint
+
+#endif  // SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
diff --git a/src/tint/transform/expand_compound_assignment_test.cc b/src/tint/transform/expand_compound_assignment_test.cc
new file mode 100644
index 0000000..4ad02d9
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment_test.cc
@@ -0,0 +1,457 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using ExpandCompoundAssignmentTest = TransformTest;
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunEmptyModule) {
+  auto* src = R"()";
+
+  EXPECT_FALSE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasCompoundAssignment) {
+  auto* src = R"(
+fn foo() {
+  var v : i32;
+  v += 1;
+}
+)";
+
+  EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Basic) {
+  auto* src = R"(
+fn main() {
+  var v : i32;
+  v += 1;
+}
+)";
+
+  auto* expect = R"(
+fn main() {
+  var v : i32;
+  v = (v + 1);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsPointer) {
+  auto* src = R"(
+fn main() {
+  var v : i32;
+  let p = &v;
+  *p += 1;
+}
+)";
+
+  auto* expect = R"(
+fn main() {
+  var v : i32;
+  let p = &(v);
+  let tint_symbol = &(*(p));
+  *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsStructMember) {
+  auto* src = R"(
+struct S {
+  m : f32,
+}
+
+fn main() {
+  var s : S;
+  s.m += 1.0;
+}
+)";
+
+  auto* expect = R"(
+struct S {
+  m : f32,
+}
+
+fn main() {
+  var s : S;
+  s.m = (s.m + 1.0);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsArrayElement) {
+  auto* src = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+  a[1] = 42;
+  return 1;
+}
+
+fn main() {
+  a[idx()] += 1;
+}
+)";
+
+  auto* expect = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+  a[1] = 42;
+  return 1;
+}
+
+fn main() {
+  let tint_symbol = &(a[idx()]);
+  *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_ArrayAccessor) {
+  auto* src = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+  v.y = 42;
+  return 1;
+}
+
+fn main() {
+  v[idx()] += 1;
+}
+)";
+
+  auto* expect = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+  v.y = 42;
+  return 1;
+}
+
+fn main() {
+  let tint_symbol = &(v);
+  let tint_symbol_1 = idx();
+  (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_MemberAccessor) {
+  auto* src = R"(
+fn main() {
+  var v : vec4<i32>;
+  v.y += 1;
+}
+)";
+
+  auto* expect = R"(
+fn main() {
+  var v : vec4<i32>;
+  v.y = (v.y + 1);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixColumn) {
+  auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+  m[0].y = 42.0;
+  return 1;
+}
+
+fn main() {
+  m[idx()] += 1.0;
+}
+)";
+
+  auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+  m[0].y = 42.0;
+  return 1;
+}
+
+fn main() {
+  let tint_symbol = &(m[idx()]);
+  *(tint_symbol) = (*(tint_symbol) + 1.0);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixElement) {
+  auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+  m[0].y = 42.0;
+  return 1;
+}
+
+fn idx2() -> i32 {
+  m[1].z = 42.0;
+  return 1;
+}
+
+fn main() {
+  m[idx1()][idx2()] += 1.0;
+}
+)";
+
+  auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+  m[0].y = 42.0;
+  return 1;
+}
+
+fn idx2() -> i32 {
+  m[1].z = 42.0;
+  return 1;
+}
+
+fn main() {
+  let tint_symbol = &(m[idx1()]);
+  let tint_symbol_1 = idx2();
+  (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMultipleSideEffects) {
+  auto* src = R"(
+struct S {
+  a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p += 1;
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p *= 3;
+  return 2;
+}
+
+fn idx3() -> i32 {
+  p -= 2;
+  return 1;
+}
+
+fn main() {
+  buffer[idx1()].a[idx2()][idx3()] += 1.0;
+}
+)";
+
+  auto* expect = R"(
+struct S {
+  a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p = (p + 1);
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p = (p * 3);
+  return 2;
+}
+
+fn idx3() -> i32 {
+  p = (p - 2);
+  return 1;
+}
+
+fn main() {
+  let tint_symbol = &(buffer[idx1()].a[idx2()]);
+  let tint_symbol_1 = idx3();
+  (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopInit) {
+  auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p = (p + 1);
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p = (p * 3);
+  return 2;
+}
+
+fn main() {
+  for (a[idx1()][idx2()] += 1; ; ) {
+    break;
+  }
+}
+)";
+
+  auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p = (p + 1);
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p = (p * 3);
+  return 2;
+}
+
+fn main() {
+  let tint_symbol = &(a[idx1()]);
+  let tint_symbol_1 = idx2();
+  for((*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); ; ) {
+    break;
+  }
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopCont) {
+  auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p = (p + 1);
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p = (p * 3);
+  return 2;
+}
+
+fn main() {
+  for (; ; a[idx1()][idx2()] += 1) {
+    break;
+  }
+}
+)";
+
+  auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+  p = (p + 1);
+  return 3;
+}
+
+fn idx2() -> i32 {
+  p = (p * 3);
+  return 2;
+}
+
+fn main() {
+  loop {
+    {
+      break;
+    }
+
+    continuing {
+      let tint_symbol = &(a[idx1()]);
+      let tint_symbol_1 = idx2();
+      (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+    }
+  }
+}
+)";
+
+  auto got = Run<ExpandCompoundAssignment>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index 8d193da..b6bcdd4 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -324,6 +324,7 @@
     "../../src/tint/transform/fold_constants_test.cc",
     "../../src/tint/transform/fold_trivial_single_use_lets_test.cc",
     "../../src/tint/transform/for_loop_to_loop_test.cc",
+    "../../src/tint/transform/expand_compound_assignment_test.cc",
     "../../src/tint/transform/localize_struct_array_assignment_test.cc",
     "../../src/tint/transform/loop_to_for_loop_test.cc",
     "../../src/tint/transform/module_scope_var_to_entry_point_param_test.cc",