Add ExpandCompoundAssignment transform
This transform converts compound assignment statements into regular
assignments, hoisting LHS expressions and converting for-loops and
else-if statements if necessary.
The vector-component case needs particular care, as we cannot take the
address of a vector component. We need to capture a pointer to the
whole vector and also the component index expression:
// Before
vector_array[foo()][bar()] *= 2.0;
// After:
let _vec = &vector_array[foo()];
let _idx = bar();
(*_vec)[_idx] = (*_vec)[_idx] * 2.0;
Bug: tint:1325
Change-Id: I8b9b31fc9ac4b3697f954100ceb4be24d063bca6
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85282
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 7a5ca93..6c525c8 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -457,6 +457,8 @@
"transform/fold_trivial_single_use_lets.h",
"transform/for_loop_to_loop.cc",
"transform/for_loop_to_loop.h",
+ "transform/expand_compound_assignment.cc",
+ "transform/expand_compound_assignment.h",
"transform/localize_struct_array_assignment.cc",
"transform/localize_struct_array_assignment.h",
"transform/loop_to_for_loop.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 731c19f..207b988 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -335,6 +335,8 @@
transform/localize_struct_array_assignment.h
transform/for_loop_to_loop.cc
transform/for_loop_to_loop.h
+ transform/expand_compound_assignment.cc
+ transform/expand_compound_assignment.h
transform/glsl.cc
transform/glsl.h
transform/loop_to_for_loop.cc
@@ -1026,6 +1028,7 @@
transform/fold_constants_test.cc
transform/fold_trivial_single_use_lets_test.cc
transform/for_loop_to_loop_test.cc
+ transform/expand_compound_assignment.cc
transform/localize_struct_array_assignment_test.cc
transform/loop_to_for_loop_test.cc
transform/module_scope_var_to_entry_point_param_test.cc
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
new file mode 100644
index 0000000..7c967d3
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -0,0 +1,149 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/ast/compound_assignment_statement.h"
+#include "src/tint/program_builder.h"
+#include "src/tint/sem/block_statement.h"
+#include "src/tint/sem/expression.h"
+#include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/sem/statement.h"
+#include "src/tint/transform/utils/hoist_to_decl_before.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::ExpandCompoundAssignment);
+
+namespace tint {
+namespace transform {
+
+ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
+
+ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
+
+bool ExpandCompoundAssignment::ShouldRun(const Program* program,
+ const DataMap&) const {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (node->Is<ast::CompoundAssignmentStatement>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ExpandCompoundAssignment::Run(CloneContext& ctx,
+ const DataMap&,
+ DataMap&) const {
+ HoistToDeclBefore hoist_to_decl_before(ctx);
+
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
+ auto* sem_assign = ctx.src->Sem().Get(assign);
+
+ // Helper function to create the LHS expression. This will be called twice
+ // when building the non-compound assignment statement, so must not
+ // produce expressions that cause side effects.
+ std::function<const ast::Expression*()> lhs;
+
+ // Helper function to create a variable that is a pointer to `expr`.
+ auto hoist_pointer_to = [&](const ast::Expression* expr) {
+ auto name = ctx.dst->Sym();
+ auto* ptr = ctx.dst->AddressOf(ctx.Clone(expr));
+ auto* decl = ctx.dst->Decl(ctx.dst->Const(name, nullptr, ptr));
+ hoist_to_decl_before.InsertBefore(sem_assign, decl);
+ return name;
+ };
+
+ // Helper function to hoist `expr` to a let declaration.
+ auto hoist_expr_to_let = [&](const ast::Expression* expr) {
+ auto name = ctx.dst->Sym();
+ auto* decl =
+ ctx.dst->Decl(ctx.dst->Const(name, nullptr, ctx.Clone(expr)));
+ hoist_to_decl_before.InsertBefore(sem_assign, decl);
+ return name;
+ };
+
+ // Helper function that returns `true` if the type of `expr` is a vector.
+ auto is_vec = [&](const ast::Expression* expr) {
+ return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
+ };
+
+ // Hoist the LHS expression subtree into local constants to produce a new
+ // LHS that we can evaluate twice.
+ // We need to special case compound assignments to vector components since
+ // we cannot take the address of a vector component.
+ auto* index_accessor = assign->lhs->As<ast::IndexAccessorExpression>();
+ auto* member_accessor = assign->lhs->As<ast::MemberAccessorExpression>();
+ if (assign->lhs->Is<ast::IdentifierExpression>() ||
+ (member_accessor &&
+ member_accessor->structure->Is<ast::IdentifierExpression>())) {
+ // This is the simple case with no side effects, so we can just use the
+ // original LHS expression directly.
+ // Before:
+ // foo.bar += rhs;
+ // After:
+ // foo.bar = foo.bar + rhs;
+ lhs = [&]() { return ctx.Clone(assign->lhs); };
+ } else if (index_accessor && is_vec(index_accessor->object)) {
+ // This is the case for vector component via an array accessor. We need
+ // to capture a pointer to the vector and also the index value.
+ // Before:
+ // v[idx()] += rhs;
+ // After:
+ // let vec_ptr = &v;
+ // let index = idx();
+ // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
+ auto lhs_ptr = hoist_pointer_to(index_accessor->object);
+ auto index = hoist_expr_to_let(index_accessor->index);
+ lhs = [&, lhs_ptr, index]() {
+ return ctx.dst->IndexAccessor(ctx.dst->Deref(lhs_ptr), index);
+ };
+ } else if (member_accessor && is_vec(member_accessor->structure)) {
+ // This is the case for vector component via a member accessor. We just
+ // need to capture a pointer to the vector.
+ // Before:
+ // a[idx()].y += rhs;
+ // After:
+ // let vec_ptr = &a[idx()];
+ // (*vec_ptr).y = (*vec_ptr).y + rhs;
+ auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
+ lhs = [&, lhs_ptr]() {
+ return ctx.dst->MemberAccessor(ctx.dst->Deref(lhs_ptr),
+ ctx.Clone(member_accessor->member));
+ };
+ } else {
+ // For all other statements that may have side-effecting expressions, we
+ // just need to capture a pointer to the whole LHS.
+ // Before:
+ // a[idx()] += rhs;
+ // After:
+ // let lhs_ptr = &a[idx()];
+ // (*lhs_ptr) = (*lhs_ptr) + rhs;
+ auto lhs_ptr = hoist_pointer_to(assign->lhs);
+ lhs = [&, lhs_ptr]() { return ctx.dst->Deref(lhs_ptr); };
+ }
+
+ // Replace the compound assignment with a regular assignment.
+ auto* rhs = ctx.dst->create<ast::BinaryExpression>(
+ assign->op, lhs(), ctx.Clone(assign->rhs));
+ ctx.Replace(assign, ctx.dst->Assign(lhs(), rhs));
+ }
+ }
+ hoist_to_decl_before.Apply();
+ ctx.Clone();
+}
+
+} // namespace transform
+} // namespace tint
diff --git a/src/tint/transform/expand_compound_assignment.h b/src/tint/transform/expand_compound_assignment.h
new file mode 100644
index 0000000..73b0c83
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment.h
@@ -0,0 +1,68 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+#define SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
+
+#include "src/tint/transform/transform.h"
+
+namespace tint {
+namespace transform {
+
+/// Converts compound assignment statements to regular assignment statements,
+/// hoisting the LHS expression if necessary.
+///
+/// Before:
+/// ```
+/// a += 1;
+/// vector_array[foo()][bar()] *= 2.0;
+/// ```
+///
+/// After:
+/// ```
+/// a = a + 1;
+/// let _vec = &vector_array[foo()];
+/// let _idx = bar();
+/// (*_vec)[_idx] = (*_vec)[_idx] * 2.0;
+/// ```
+class ExpandCompoundAssignment
+ : public Castable<ExpandCompoundAssignment, Transform> {
+ public:
+ /// Constructor
+ ExpandCompoundAssignment();
+ /// Destructor
+ ~ExpandCompoundAssignment() override;
+
+ /// @param program the program to inspect
+ /// @param data optional extra transform-specific input data
+ /// @returns true if this transform should be run for the given program
+ bool ShouldRun(const Program* program,
+ const DataMap& data = {}) const override;
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx,
+ const DataMap& inputs,
+ DataMap& outputs) const override;
+};
+
+} // namespace transform
+} // namespace tint
+
+#endif // SRC_TINT_TRANSFORM_EXPAND_COMPOUND_ASSIGNMENT_H_
diff --git a/src/tint/transform/expand_compound_assignment_test.cc b/src/tint/transform/expand_compound_assignment_test.cc
new file mode 100644
index 0000000..4ad02d9
--- /dev/null
+++ b/src/tint/transform/expand_compound_assignment_test.cc
@@ -0,0 +1,457 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/transform/expand_compound_assignment.h"
+
+#include <utility>
+
+#include "src/tint/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using ExpandCompoundAssignmentTest = TransformTest;
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ShouldRunHasCompoundAssignment) {
+ auto* src = R"(
+fn foo() {
+ var v : i32;
+ v += 1;
+}
+)";
+
+ EXPECT_TRUE(ShouldRun<ExpandCompoundAssignment>(src));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, Basic) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ v += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ v = (v + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsPointer) {
+ auto* src = R"(
+fn main() {
+ var v : i32;
+ let p = &v;
+ *p += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : i32;
+ let p = &(v);
+ let tint_symbol = &(*(p));
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsStructMember) {
+ auto* src = R"(
+struct S {
+ m : f32,
+}
+
+fn main() {
+ var s : S;
+ s.m += 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ m : f32,
+}
+
+fn main() {
+ var s : S;
+ s.m = (s.m + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsArrayElement) {
+ auto* src = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ a[idx()] += 1;
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<i32, 4>;
+
+fn idx() -> i32 {
+ a[1] = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(a[idx()]);
+ *(tint_symbol) = (*(tint_symbol) + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_ArrayAccessor) {
+ auto* src = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ v[idx()] += 1;
+}
+)";
+
+ auto* expect = R"(
+var<private> v : vec4<i32>;
+
+fn idx() -> i32 {
+ v.y = 42;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(v);
+ let tint_symbol_1 = idx();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsVectorComponent_MemberAccessor) {
+ auto* src = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y += 1;
+}
+)";
+
+ auto* expect = R"(
+fn main() {
+ var v : vec4<i32>;
+ v.y = (v.y + 1);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixColumn) {
+ auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn main() {
+ m[idx()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(m[idx()]);
+ *(tint_symbol) = (*(tint_symbol) + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMatrixElement) {
+ auto* src = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn idx2() -> i32 {
+ m[1].z = 42.0;
+ return 1;
+}
+
+fn main() {
+ m[idx1()][idx2()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+var<private> m : mat4x4<f32>;
+
+fn idx1() -> i32 {
+ m[0].y = 42.0;
+ return 1;
+}
+
+fn idx2() -> i32 {
+ m[1].z = 42.0;
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(m[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, LhsMultipleSideEffects) {
+ auto* src = R"(
+struct S {
+ a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p += 1;
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p *= 3;
+ return 2;
+}
+
+fn idx3() -> i32 {
+ p -= 2;
+ return 1;
+}
+
+fn main() {
+ buffer[idx1()].a[idx2()][idx3()] += 1.0;
+}
+)";
+
+ auto* expect = R"(
+struct S {
+ a : array<vec4<f32>, 3>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : array<S>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn idx3() -> i32 {
+ p = (p - 2);
+ return 1;
+}
+
+fn main() {
+ let tint_symbol = &(buffer[idx1()].a[idx2()]);
+ let tint_symbol_1 = idx3();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1.0);
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopInit) {
+ auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ for (a[idx1()][idx2()] += 1; ; ) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ let tint_symbol = &(a[idx1()]);
+ let tint_symbol_1 = idx2();
+ for((*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); ; ) {
+ break;
+ }
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ExpandCompoundAssignmentTest, ForLoopCont) {
+ auto* src = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ for (; ; a[idx1()][idx2()] += 1) {
+ break;
+ }
+}
+)";
+
+ auto* expect = R"(
+var<private> a : array<vec4<i32>, 4>;
+
+var<private> p : i32;
+
+fn idx1() -> i32 {
+ p = (p + 1);
+ return 3;
+}
+
+fn idx2() -> i32 {
+ p = (p * 3);
+ return 2;
+}
+
+fn main() {
+ loop {
+ {
+ break;
+ }
+
+ continuing {
+ let tint_symbol = &(a[idx1()]);
+ let tint_symbol_1 = idx2();
+ (*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
+ }
+ }
+}
+)";
+
+ auto got = Run<ExpandCompoundAssignment>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace transform
+} // namespace tint
diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn
index 8d193da..b6bcdd4 100644
--- a/test/tint/BUILD.gn
+++ b/test/tint/BUILD.gn
@@ -324,6 +324,7 @@
"../../src/tint/transform/fold_constants_test.cc",
"../../src/tint/transform/fold_trivial_single_use_lets_test.cc",
"../../src/tint/transform/for_loop_to_loop_test.cc",
+ "../../src/tint/transform/expand_compound_assignment_test.cc",
"../../src/tint/transform/localize_struct_array_assignment_test.cc",
"../../src/tint/transform/loop_to_for_loop_test.cc",
"../../src/tint/transform/module_scope_var_to_entry_point_param_test.cc",