| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/transform/inline_pointer_lets.h" |
| |
| #include <memory> |
| #include <unordered_map> |
| #include <utility> |
| |
| #include "src/program_builder.h" |
| #include "src/sem/block_statement.h" |
| #include "src/sem/function.h" |
| #include "src/sem/statement.h" |
| #include "src/sem/variable.h" |
| #include "src/utils/scoped_assignment.h" |
| |
// NOTE(review): the macro is defined outside this file; judging by its name it
// instantiates the type-info for the InlinePointerLets transform - confirm
// against the macro's definition.
TINT_INSTANTIATE_TYPEINFO(tint::transform::InlinePointerLets);
| |
| namespace tint { |
| namespace transform { |
| namespace { |
| |
| /// Traverses the expression `expr` looking for non-literal array indexing |
| /// expressions that would affect the computed address of a pointer expression. |
| /// The function-like argument `cb` is called for each found. |
| /// @param program the program that owns all the expression nodes |
| /// @param expr the expression to traverse |
| /// @param cb a function-like object with the signature |
| /// `void(const ast::Expression*)`, which is called for each array index |
| /// expression |
| template <typename F> |
| void CollectSavedArrayIndices(const Program* program, |
| ast::Expression* expr, |
| F&& cb) { |
| if (auto* a = expr->As<ast::ArrayAccessorExpression>()) { |
| CollectSavedArrayIndices(program, a->array(), cb); |
| |
| if (!a->idx_expr()->Is<ast::ScalarConstructorExpression>()) { |
| cb(a->idx_expr()); |
| } |
| return; |
| } |
| |
| if (auto* m = expr->As<ast::MemberAccessorExpression>()) { |
| CollectSavedArrayIndices(program, m->structure(), cb); |
| return; |
| } |
| |
| if (auto* u = expr->As<ast::UnaryOpExpression>()) { |
| CollectSavedArrayIndices(program, u->expr(), cb); |
| return; |
| } |
| |
| // Note: Other ast::Expression types can be safely ignored as they cannot be |
| // used to generate a reference or pointer. |
| // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers |
| } |
| |
| // PtrLet represents a `let` declaration of a pointer type. |
| struct PtrLet { |
| // A map of ptr-let initializer sub-expression to the name of generated |
| // variable that holds the saved value of this sub-expression, when resolved |
| // at the point of the ptr-let declaration. |
| std::unordered_map<const ast::Expression*, Symbol> saved_vars; |
| }; |
| |
| } // namespace |
| |
// Constructor: uses the compiler-generated implementation, defined here
// out-of-line.
InlinePointerLets::InlinePointerLets() = default;

// Destructor: uses the compiler-generated implementation, defined here
// out-of-line.
InlinePointerLets::~InlinePointerLets() = default;
| |
| void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) { |
| // If not null, current_ptr_let is the current PtrLet being operated on. |
| PtrLet* current_ptr_let = nullptr; |
| // A map of the AST `let` variable to the PtrLet |
| std::unordered_map<const ast::Variable*, std::unique_ptr<PtrLet>> ptr_lets; |
| |
| // Register the ast::Expression transform handler. |
| // This performs two different transformations: |
| // * Identifiers that resolve to the pointer-typed `let` declarations are |
| // replaced with the inlined (and recursively transformed) initializer |
| // expression for the `let` declaration. |
| // * Sub-expressions inside the pointer-typed `let` initializer expression |
| // that have been hoisted to a saved variable are replaced with the saved |
| // variable identifier. |
| ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* { |
| if (current_ptr_let) { |
| // We're currently processing the initializer expression of a |
| // pointer-typed `let` declaration. Look to see if we need to swap this |
| // Expression with a saved variable. |
| auto it = current_ptr_let->saved_vars.find(expr); |
| if (it != current_ptr_let->saved_vars.end()) { |
| return ctx.dst->Expr(it->second); |
| } |
| } |
| if (auto* ident = expr->As<ast::IdentifierExpression>()) { |
| if (auto* vu = ctx.src->Sem().Get<sem::VariableUser>(ident)) { |
| auto* var = vu->Variable()->Declaration(); |
| auto it = ptr_lets.find(var); |
| if (it != ptr_lets.end()) { |
| // We've found an identifier that resolves to a `let` declaration. |
| // We need to replace this identifier with the initializer expression |
| // of the `let` declaration. Clone the initializer expression to make |
| // a copy. Note that this will call back into this ReplaceAll() |
| // handler for sub-expressions of the initializer. |
| auto* ptr_let = it->second.get(); |
| // TINT_SCOPED_ASSIGNMENT provides a stack of PtrLet*, this is |
| // required to handle the 'chaining' of inlined `let`s. |
| TINT_SCOPED_ASSIGNMENT(current_ptr_let, ptr_let); |
| return ctx.Clone(var->constructor()); |
| } |
| } |
| } |
| return nullptr; |
| }); |
| |
| // Find all the pointer-typed `let` declarations. |
| // Note that these must be function-scoped, as module-scoped `let`s are not |
| // permitted. |
| for (auto* node : ctx.src->ASTNodes().Objects()) { |
| if (auto* let = node->As<ast::VariableDeclStatement>()) { |
| if (!let->variable()->is_const()) { |
| continue; // Not a `let` declaration. Ignore. |
| } |
| |
| auto* var = ctx.src->Sem().Get(let->variable()); |
| if (!var->Type()->Is<sem::Pointer>()) { |
| continue; // Not a pointer type. Ignore. |
| } |
| |
| // We're dealing with a pointer-typed `let` declaration. |
| auto ptr_let = std::make_unique<PtrLet>(); |
| TINT_SCOPED_ASSIGNMENT(current_ptr_let, ptr_let.get()); |
| |
| auto* block = ctx.src->Sem().Get(let)->Block()->Declaration(); |
| |
| // Scan the initializer expression for array index expressions that need |
| // to be hoist to temporary "saved" variables. |
| CollectSavedArrayIndices( |
| ctx.src, var->Declaration()->constructor(), |
| [&](ast::Expression* idx_expr) { |
| // We have a sub-expression that needs to be saved. |
| // Create a new variable |
| auto saved_name = ctx.dst->Symbols().New( |
| ctx.src->Symbols().NameFor(var->Declaration()->symbol()) + |
| "_save"); |
| auto* saved = ctx.dst->Decl( |
| ctx.dst->Const(saved_name, nullptr, ctx.Clone(idx_expr))); |
| // Place this variable after the pointer typed let. Order here is |
| // important as order-of-operations needs to be preserved. |
| // CollectSavedArrayIndices() visits the LHS of an array accessor |
| // before the index expression. |
| // Note that repeated calls to InsertAfter() with the same `after` |
| // argument will result in nodes to inserted in the order the calls |
| // are made (last call is inserted last). |
| ctx.InsertAfter(block->statements(), let, saved); |
| // Record the substitution of `idx_expr` to the saved variable with |
| // the symbol `saved_name`. This will be used by the ReplaceAll() |
| // handler above. |
| ptr_let->saved_vars.emplace(idx_expr, saved_name); |
| }); |
| |
| // Record the pointer-typed `let` declaration. |
| // This will be used by the ReplaceAll() handler above. |
| ptr_lets.emplace(let->variable(), std::move(ptr_let)); |
| // As the original `let` declaration will be fully inlined, there's no |
| // need for the original declaration to exist. Remove it. |
| RemoveStatement(ctx, let); |
| } |
| } |
| |
| ctx.Clone(); |
| } |
| |
| } // namespace transform |
| } // namespace tint |