blob: 85b9bf69ed6f45c4569cc7c530192c0944c21dca [file] [log] [blame]
James Priceb9b6e692022-03-31 22:30:10 +00001// Copyright 2022 The Tint Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "src/tint/transform/expand_compound_assignment.h"

#include <functional>
#include <utility>

#include "src/tint/ast/compound_assignment_statement.h"
#include "src/tint/ast/increment_decrement_statement.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/statement.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
27
28TINT_INSTANTIATE_TYPEINFO(tint::transform::ExpandCompoundAssignment);
29
Ben Clayton0ce9ab02022-05-05 20:23:40 +000030using namespace tint::number_suffixes; // NOLINT
31
dan sinclairb5599d32022-04-07 16:55:14 +000032namespace tint::transform {
James Priceb9b6e692022-03-31 22:30:10 +000033
34ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
35
36ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
37
dan sinclair41e4d9a2022-05-01 14:40:55 +000038bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const {
39 for (auto* node : program->ASTNodes().Objects()) {
40 if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
41 return true;
42 }
James Priceb9b6e692022-03-31 22:30:10 +000043 }
dan sinclair41e4d9a2022-05-01 14:40:55 +000044 return false;
James Priceb9b6e692022-03-31 22:30:10 +000045}
46
James Price33563dc2022-06-16 08:28:22 +000047namespace {
48
James Priced68d3a92022-04-07 13:42:45 +000049/// Internal class used to collect statement expansions during the transform.
50class State {
dan sinclair41e4d9a2022-05-01 14:40:55 +000051 private:
52 /// The clone context.
53 CloneContext& ctx;
James Priced68d3a92022-04-07 13:42:45 +000054
dan sinclair41e4d9a2022-05-01 14:40:55 +000055 /// The program builder.
56 ProgramBuilder& b;
James Priced68d3a92022-04-07 13:42:45 +000057
dan sinclair41e4d9a2022-05-01 14:40:55 +000058 /// The HoistToDeclBefore helper instance.
59 HoistToDeclBefore hoist_to_decl_before;
James Priced68d3a92022-04-07 13:42:45 +000060
dan sinclair41e4d9a2022-05-01 14:40:55 +000061 public:
62 /// Constructor
63 /// @param context the clone context
64 explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
James Priced68d3a92022-04-07 13:42:45 +000065
dan sinclair41e4d9a2022-05-01 14:40:55 +000066 /// Replace `stmt` with a regular assignment statement of the form:
67 /// lhs = lhs op rhs
68 /// The LHS expression will only be evaluated once, and any side effects will
69 /// be hoisted to `let` declarations above the assignment statement.
70 /// @param stmt the statement to replace
71 /// @param lhs the lhs expression from the source statement
72 /// @param rhs the rhs expression in the destination module
73 /// @param op the binary operator
74 void Expand(const ast::Statement* stmt,
75 const ast::Expression* lhs,
76 const ast::Expression* rhs,
77 ast::BinaryOp op) {
78 // Helper function to create the new LHS expression. This will be called
79 // twice when building the non-compound assignment statement, so must
80 // not produce expressions that cause side effects.
81 std::function<const ast::Expression*()> new_lhs;
James Priced68d3a92022-04-07 13:42:45 +000082
dan sinclair41e4d9a2022-05-01 14:40:55 +000083 // Helper function to create a variable that is a pointer to `expr`.
84 auto hoist_pointer_to = [&](const ast::Expression* expr) {
85 auto name = b.Sym();
86 auto* ptr = b.AddressOf(ctx.Clone(expr));
Ben Clayton58794ae2022-08-19 17:28:53 +000087 auto* decl = b.Decl(b.Let(name, ptr));
dan sinclair41e4d9a2022-05-01 14:40:55 +000088 hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
89 return name;
90 };
James Priced68d3a92022-04-07 13:42:45 +000091
dan sinclair41e4d9a2022-05-01 14:40:55 +000092 // Helper function to hoist `expr` to a let declaration.
93 auto hoist_expr_to_let = [&](const ast::Expression* expr) {
94 auto name = b.Sym();
Ben Clayton58794ae2022-08-19 17:28:53 +000095 auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
dan sinclair41e4d9a2022-05-01 14:40:55 +000096 hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
97 return name;
98 };
James Priced68d3a92022-04-07 13:42:45 +000099
dan sinclair41e4d9a2022-05-01 14:40:55 +0000100 // Helper function that returns `true` if the type of `expr` is a vector.
101 auto is_vec = [&](const ast::Expression* expr) {
102 return ctx.src->Sem().Get(expr)->Type()->UnwrapRef()->Is<sem::Vector>();
103 };
James Priced68d3a92022-04-07 13:42:45 +0000104
dan sinclair41e4d9a2022-05-01 14:40:55 +0000105 // Hoist the LHS expression subtree into local constants to produce a new
106 // LHS that we can evaluate twice.
107 // We need to special case compound assignments to vector components since
108 // we cannot take the address of a vector component.
109 auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
110 auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
111 if (lhs->Is<ast::IdentifierExpression>() ||
112 (member_accessor && member_accessor->structure->Is<ast::IdentifierExpression>())) {
113 // This is the simple case with no side effects, so we can just use the
114 // original LHS expression directly.
115 // Before:
116 // foo.bar += rhs;
117 // After:
118 // foo.bar = foo.bar + rhs;
119 new_lhs = [&]() { return ctx.Clone(lhs); };
120 } else if (index_accessor && is_vec(index_accessor->object)) {
121 // This is the case for vector component via an array accessor. We need
122 // to capture a pointer to the vector and also the index value.
123 // Before:
124 // v[idx()] += rhs;
125 // After:
126 // let vec_ptr = &v;
127 // let index = idx();
128 // (*vec_ptr)[index] = (*vec_ptr)[index] + rhs;
129 auto lhs_ptr = hoist_pointer_to(index_accessor->object);
130 auto index = hoist_expr_to_let(index_accessor->index);
131 new_lhs = [&, lhs_ptr, index]() { return b.IndexAccessor(b.Deref(lhs_ptr), index); };
132 } else if (member_accessor && is_vec(member_accessor->structure)) {
133 // This is the case for vector component via a member accessor. We just
134 // need to capture a pointer to the vector.
135 // Before:
136 // a[idx()].y += rhs;
137 // After:
138 // let vec_ptr = &a[idx()];
139 // (*vec_ptr).y = (*vec_ptr).y + rhs;
140 auto lhs_ptr = hoist_pointer_to(member_accessor->structure);
141 new_lhs = [&, lhs_ptr]() {
142 return b.MemberAccessor(b.Deref(lhs_ptr), ctx.Clone(member_accessor->member));
143 };
144 } else {
145 // For all other statements that may have side-effecting expressions, we
146 // just need to capture a pointer to the whole LHS.
147 // Before:
148 // a[idx()] += rhs;
149 // After:
150 // let lhs_ptr = &a[idx()];
151 // (*lhs_ptr) = (*lhs_ptr) + rhs;
152 auto lhs_ptr = hoist_pointer_to(lhs);
153 new_lhs = [&, lhs_ptr]() { return b.Deref(lhs_ptr); };
154 }
155
156 // Replace the statement with a regular assignment statement.
157 auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
158 ctx.Replace(stmt, b.Assign(new_lhs(), value));
James Priced68d3a92022-04-07 13:42:45 +0000159 }
160
dan sinclair41e4d9a2022-05-01 14:40:55 +0000161 /// Finalize the transformation and clone the module.
162 void Finalize() {
163 hoist_to_decl_before.Apply();
164 ctx.Clone();
165 }
James Priced68d3a92022-04-07 13:42:45 +0000166};
167
James Price33563dc2022-06-16 08:28:22 +0000168} // namespace
169
dan sinclair41e4d9a2022-05-01 14:40:55 +0000170void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
171 State state(ctx);
172 for (auto* node : ctx.src->ASTNodes().Objects()) {
173 if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
174 state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
175 } else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
176 // For increment/decrement statements, `i++` becomes `i = i + 1`.
dan sinclair41e4d9a2022-05-01 14:40:55 +0000177 auto op = inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
James Pricec7f7ca32022-06-18 14:22:15 +0000178 state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
dan sinclair41e4d9a2022-05-01 14:40:55 +0000179 }
James Priceb9b6e692022-03-31 22:30:10 +0000180 }
dan sinclair41e4d9a2022-05-01 14:40:55 +0000181 state.Finalize();
James Priceb9b6e692022-03-31 22:30:10 +0000182}
183
dan sinclairb5599d32022-04-07 16:55:14 +0000184} // namespace tint::transform