Add support for increment/decrement statements
Refactor the ExpandCompoundAssignment transform to also expand these
statements, which provides support for them in all of the non-WGSL backends.
Fixed: tint:1488
Change-Id: I96cdc31851c61f6d92d296447d0b0637907d5fe5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/86004
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/transform/expand_compound_assignment.cc b/src/tint/transform/expand_compound_assignment.cc
index 7c967d3..3a986a2 100644
--- a/src/tint/transform/expand_compound_assignment.cc
+++ b/src/tint/transform/expand_compound_assignment.cc
@@ -17,6 +17,7 @@
#include <utility>
#include "src/tint/ast/compound_assignment_statement.h"
+#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/expression.h"
@@ -36,113 +37,159 @@
bool ExpandCompoundAssignment::ShouldRun(const Program* program,
const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
- if (node->Is<ast::CompoundAssignmentStatement>()) {
+ // Run() rewrites both compound assignments (`a op= b`) and
+ // increment/decrement statements (`a++` / `a--`), so the presence of
+ // either kind of statement means the transform has work to do.
+ if (node->IsAnyOf<ast::CompoundAssignmentStatement,
+ ast::IncrementDecrementStatement>()) {
return true;
}
}
return false;
}
+namespace {
+
+/// Internal class used to collect statement expansions during the transform.
+///
+/// Declared in an anonymous namespace to give it internal linkage. `State` is
+/// a generic name and its member functions are defined inline. If another
+/// translation unit in `tint::transform` also declared an externally-visible
+/// class with this name, the two definitions would violate the One Definition
+/// Rule, and the linker could silently pick the wrong member function bodies.
+class State {
+ private:
+  /// The clone context.
+  CloneContext& ctx;
+
+  /// The program builder for the destination program (`*ctx.dst`).
+  ProgramBuilder& b;
+
+  /// The HoistToDeclBefore helper instance, used to insert the `let`
+  /// declarations that capture the LHS sub-expressions.
+  HoistToDeclBefore hoist_to_decl_before;
+
+ public:
+  /// Constructor
+  /// @param context the clone context
+  explicit State(CloneContext& context)
+      : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
+
+  /// Replace `stmt` with a regular assignment statement of the form:
+  ///   lhs = lhs op rhs
+  /// The LHS expression will only be evaluated once, and any side effects will
+  /// be hoisted to `let` declarations above the assignment statement.
+  /// @param stmt the statement to replace
+  /// @param lhs the lhs expression from the source statement
+  /// @param rhs the rhs expression in the destination module
+  /// @param op the binary operator
+  void Expand(const ast::Statement* stmt,
+              const ast::Expression* lhs,
+              const ast::Expression* rhs,
+              ast::BinaryOp op) {
+    // The semantic statement that hoisted declarations are inserted before.
+    // Looked up once here rather than inside every hoisting helper.
+    auto* sem_stmt = ctx.src->Sem().Get(stmt);
+
+    // Helper function to create the new LHS expression. This will be called
+    // twice when building the non-compound assignment statement, so must
+    // not produce expressions that cause side effects.
+    std::function<const ast::Expression*()> new_lhs;
+
+    // Helper function to create a variable that is a pointer to `expr`.
+    auto hoist_pointer_to = [&](const ast::Expression* expr) {
+      auto name = b.Sym();
+      auto* ptr = b.AddressOf(ctx.Clone(expr));
+      auto* decl = b.Decl(b.Const(name, nullptr, ptr));
+      hoist_to_decl_before.InsertBefore(sem_stmt, decl);
+      return name;
+    };
+
+    // Helper function to hoist `expr` to a let declaration.
+    auto hoist_expr_to_let = [&](const ast::Expression* expr) {
+      auto name = b.Sym();
+      auto* decl = b.Decl(b.Const(name, nullptr, ctx.Clone(expr)));
+      hoist_to_decl_before.InsertBefore(sem_stmt, decl);
+      return name;
+    };
+
+    // Helper function that returns `true` if the type of `expr` is a vector.
+    auto is_vec = [&](const ast::Expression* expr) {
+      auto* type = ctx.src->Sem().Get(expr)->Type();
+      return type->UnwrapRef()->Is<sem::Vector>();
+    };
+
+    // Hoist the LHS expression subtree into local constants to produce a new
+    // LHS that we can evaluate twice.
+    // We need to special case compound assignments to vector components since
+    // we cannot take the address of a vector component.
+    auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
+    auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
+    if (lhs->Is<ast::IdentifierExpression>() ||
+        (member_accessor &&
+         member_accessor->structure->Is<ast::IdentifierExpression>())) {
+      // This is the simple case with no side effects, so we can just use the
+      // original LHS expression directly.
+      // Before:
+      //     foo.bar += rhs;
+      // After:
+      //     foo.bar = foo.bar + rhs;
+      new_lhs = [&]() { return ctx.Clone(lhs); };
+    } else if (index_accessor && is_vec(index_accessor->object)) {
+      // This is the case for vector component via an array accessor. We need
+      // to capture a pointer to the vector and also the index value.
+      // Before:
+      //     v[idx()] += rhs;
+      // After:
+      //     let vec_ptr = &v;
+      //     let index = idx();
+      //     (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
+      auto lhs_ptr = hoist_pointer_to(index_accessor->object);
+      auto index = hoist_expr_to_let(index_accessor->index);
+      // `lhs_ptr` and `index` are captured by value: they go out of scope at
+      // the end of this branch, before `new_lhs` is invoked.
+      new_lhs = [&, lhs_ptr, index]() {
+        return b.IndexAccessor(b.Deref(lhs_ptr), index);
+      };
+    } else if (member_accessor && is_vec(member_accessor->structure)) {
+      // This is the case for vector component via a member accessor. We just
+      // need to capture a pointer to the vector.
+      // Before:
+      //     a[idx()].y += rhs;
+      // After:
+      //     let vec_ptr = &a[idx()];
+      //     (*vec_ptr).y = (*vec_ptr).y + rhs;
+      auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
+      new_lhs = [&, lhs_ptr]() {
+        return b.MemberAccessor(b.Deref(lhs_ptr),
+                                ctx.Clone(member_accessor->member));
+      };
+    } else {
+      // For all other statements that may have side-effecting expressions, we
+      // just need to capture a pointer to the whole LHS.
+      // Before:
+      //     a[idx()] += rhs;
+      // After:
+      //     let lhs_ptr = &a[idx()];
+      //     (*lhs_ptr) = (*lhs_ptr) + rhs;
+      auto lhs_ptr = hoist_pointer_to(lhs);
+      new_lhs = [&, lhs_ptr]() { return b.Deref(lhs_ptr); };
+    }
+
+    // Replace the statement with a regular assignment statement.
+    auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
+    ctx.Replace(stmt, b.Assign(new_lhs(), value));
+  }
+
+  /// Finalize the transformation: insert all of the hoisted declarations,
+  /// then clone the module. Call this after every statement to be rewritten
+  /// has been passed to Expand().
+  void Finalize() {
+    hoist_to_decl_before.Apply();
+    ctx.Clone();
+  }
+};
+
+}  // namespace
+
void ExpandCompoundAssignment::Run(CloneContext& ctx,
const DataMap&,
DataMap&) const {
- HoistToDeclBefore hoist_to_decl_before(ctx);
-
+ // Collects the statement replacements; the hoisted `let` declarations are
+ // only inserted when state.Finalize() is called below.
+ State state(ctx);
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
- auto* sem_assign = ctx.src->Sem().Get(assign);
-
- // Helper function to create the LHS expression. This will be called twice
- // when building the non-compound assignment statement, so must not
- // produce expressions that cause side effects.
- std::function<const ast::Expression*()> lhs;
-
- // Helper function to create a variable that is a pointer to `expr`.
- auto hoist_pointer_to = [&](const ast::Expression* expr) {
- auto name = ctx.dst->Sym();
- auto* ptr = ctx.dst->AddressOf(ctx.Clone(expr));
- auto* decl = ctx.dst->Decl(ctx.dst->Const(name, nullptr, ptr));
- hoist_to_decl_before.InsertBefore(sem_assign, decl);
- return name;
- };
-
- // Helper function to hoist `expr` to a let declaration.
- auto hoist_expr_to_let = [&](const ast::Expression* expr) {
- auto name = ctx.dst->Sym();
- auto* decl =
- ctx.dst->Decl(ctx.dst->Const(name, nullptr, ctx.Clone(expr)));
- hoist_to_decl_before.InsertBefore(sem_assign, decl);
- return name;
- };
-
- // Helper function that returns `true` if the type of `expr` is a vector.
- auto is_vec = [&](const ast::Expression* expr) {
- return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
- };
-
- // Hoist the LHS expression subtree into local constants to produce a new
- // LHS that we can evaluate twice.
- // We need to special case compound assignments to vector components since
- // we cannot take the address of a vector component.
- auto* index_accessor = assign->lhs->As<ast::IndexAccessorExpression>();
- auto* member_accessor = assign->lhs->As<ast::MemberAccessorExpression>();
- if (assign->lhs->Is<ast::IdentifierExpression>() ||
- (member_accessor &&
- member_accessor->structure->Is<ast::IdentifierExpression>())) {
- // This is the simple case with no side effects, so we can just use the
- // original LHS expression directly.
- // Before:
- // foo.bar += rhs;
- // After:
- // foo.bar = foo.bar + rhs;
- lhs = [&]() { return ctx.Clone(assign->lhs); };
- } else if (index_accessor && is_vec(index_accessor->object)) {
- // This is the case for vector component via an array accessor. We need
- // to capture a pointer to the vector and also the index value.
- // Before:
- // v[idx()] += rhs;
- // After:
- // let vec_ptr = &v;
- // let index = idx();
- // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
- auto lhs_ptr = hoist_pointer_to(index_accessor->object);
- auto index = hoist_expr_to_let(index_accessor->index);
- lhs = [&, lhs_ptr, index]() {
- return ctx.dst->IndexAccessor(ctx.dst->Deref(lhs_ptr), index);
- };
- } else if (member_accessor && is_vec(member_accessor->structure)) {
- // This is the case for vector component via a member accessor. We just
- // need to capture a pointer to the vector.
- // Before:
- // a[idx()].y += rhs;
- // After:
- // let vec_ptr = &a[idx()];
- // (*vec_ptr).y = (*vec_ptr).y + rhs;
- auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
- lhs = [&, lhs_ptr]() {
- return ctx.dst->MemberAccessor(ctx.dst->Deref(lhs_ptr),
- ctx.Clone(member_accessor->member));
- };
- } else {
- // For all other statements that may have side-effecting expressions, we
- // just need to capture a pointer to the whole LHS.
- // Before:
- // a[idx()] += rhs;
- // After:
- // let lhs_ptr = &a[idx()];
- // (*lhs_ptr) = (*lhs_ptr) + rhs;
- auto lhs_ptr = hoist_pointer_to(assign->lhs);
- lhs = [&, lhs_ptr]() { return ctx.dst->Deref(lhs_ptr); };
- }
-
- // Replace the compound assignment with a regular assignment.
- auto* rhs = ctx.dst->create<ast::BinaryExpression>(
- assign->op, lhs(), ctx.Clone(assign->rhs));
- ctx.Replace(assign, ctx.dst->Assign(lhs(), rhs));
+ // `lhs op= rhs` becomes `lhs = lhs op rhs`. The RHS is cloned here because
+ // State::Expand() expects an RHS that already belongs to the destination
+ // program, whereas the LHS is passed as the source-program expression.
+ state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
+ } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
+ // For increment/decrement statements, `i++` becomes `i = i + 1`.
+ // TODO(jrprice): Simplify this when we have untyped literals.
+ auto* sem_lhs = ctx.src->Sem().Get(inc_dec->lhs);
+ // The literal `1` must match the signedness of the LHS type: a signed
+ // literal for signed integer scalars, an unsigned (`1u`) one otherwise.
+ // The As<>() casts give both arms of the conditional the same pointer
+ // type.
+ const ast::IntLiteralExpression* one =
+ sem_lhs->Type()->UnwrapRef()->is_signed_integer_scalar()
+ ? ctx.dst->Expr(1)->As<ast::IntLiteralExpression>()
+ : ctx.dst->Expr(1u)->As<ast::IntLiteralExpression>();
+ auto op =
+ inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
+ state.Expand(inc_dec, inc_dec->lhs, one, op);
}
}
- hoist_to_decl_before.Apply();
- ctx.Clone();
+ // Insert the hoisted declarations, then clone the rest of the program.
+ state.Finalize();
}
} // namespace transform