| // Copyright 2021 The Tint Authors. | 
 | // | 
 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
 | // you may not use this file except in compliance with the License. | 
 | // You may obtain a copy of the License at | 
 | // | 
 | //     http://www.apache.org/licenses/LICENSE-2.0 | 
 | // | 
 | // Unless required by applicable law or agreed to in writing, software | 
 | // distributed under the License is distributed on an "AS IS" BASIS, | 
 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | // See the License for the specific language governing permissions and | 
 | // limitations under the License. | 
 |  | 
 | #include "src/transform/inline_pointer_lets.h" | 
 |  | 
 | #include <memory> | 
 | #include <unordered_map> | 
 | #include <utility> | 
 |  | 
 | #include "src/program_builder.h" | 
 | #include "src/sem/block_statement.h" | 
 | #include "src/sem/function.h" | 
 | #include "src/sem/statement.h" | 
 | #include "src/sem/variable.h" | 
 | #include "src/utils/scoped_assignment.h" | 
 |  | 
 | TINT_INSTANTIATE_TYPEINFO(tint::transform::InlinePointerLets); | 
 |  | 
 | namespace tint { | 
 | namespace transform { | 
 | namespace { | 
 |  | 
 | /// Traverses the expression `expr` looking for non-literal array indexing | 
 | /// expressions that would affect the computed address of a pointer expression. | 
 | /// The function-like argument `cb` is called for each found. | 
 | /// @param program the program that owns all the expression nodes | 
 | /// @param expr the expression to traverse | 
 | /// @param cb a function-like object with the signature | 
 | /// `void(const ast::Expression*)`, which is called for each array index | 
 | /// expression | 
 | template <typename F> | 
 | void CollectSavedArrayIndices(const Program* program, | 
 |                               ast::Expression* expr, | 
 |                               F&& cb) { | 
 |   if (auto* a = expr->As<ast::ArrayAccessorExpression>()) { | 
 |     CollectSavedArrayIndices(program, a->array(), cb); | 
 |  | 
 |     if (!a->idx_expr()->Is<ast::ScalarConstructorExpression>()) { | 
 |       cb(a->idx_expr()); | 
 |     } | 
 |     return; | 
 |   } | 
 |  | 
 |   if (auto* m = expr->As<ast::MemberAccessorExpression>()) { | 
 |     CollectSavedArrayIndices(program, m->structure(), cb); | 
 |     return; | 
 |   } | 
 |  | 
 |   if (auto* u = expr->As<ast::UnaryOpExpression>()) { | 
 |     CollectSavedArrayIndices(program, u->expr(), cb); | 
 |     return; | 
 |   } | 
 |  | 
 |   // Note: Other ast::Expression types can be safely ignored as they cannot be | 
 |   // used to generate a reference or pointer. | 
 |   // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers | 
 | } | 
 |  | 
 | // PtrLet represents a `let` declaration of a pointer type. | 
 | struct PtrLet { | 
 |   // A map of ptr-let initializer sub-expression to the name of generated | 
 |   // variable that holds the saved value of this sub-expression, when resolved | 
 |   // at the point of the ptr-let declaration. | 
 |   std::unordered_map<const ast::Expression*, Symbol> saved_vars; | 
 | }; | 
 |  | 
 | }  // namespace | 
 |  | 
// Constructor. Defaulted; the transform holds no state of its own between
// calls to Run().
InlinePointerLets::InlinePointerLets() = default;

// Destructor. Defaulted.
InlinePointerLets::~InlinePointerLets() = default;
 |  | 
// Runs the transform. Every `let` declaration of pointer type is removed, and
// each identifier that resolves to one is replaced with a copy of that `let`'s
// initializer expression.
//
// Before removal, each non-literal array index inside such an initializer is
// hoisted into a new `let` inserted after the original declaration. Inlined
// copies of the initializer then refer to that saved variable instead of
// re-evaluating the index expression.
//
// @param ctx the clone context from the source to the destination program
// The two DataMap parameters are unnamed and unused.
void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  // If not null, current_ptr_let is the current PtrLet being operated on.
  PtrLet* current_ptr_let = nullptr;
  // A map of the AST `let` variable to the PtrLet
  std::unordered_map<const ast::Variable*, std::unique_ptr<PtrLet>> ptr_lets;

  // Register the ast::Expression transform handler.
  // This performs two different transformations:
  // * Identifiers that resolve to the pointer-typed `let` declarations are
  // replaced with the inlined (and recursively transformed) initializer
  // expression for the `let` declaration.
  // * Sub-expressions inside the pointer-typed `let` initializer expression
  // that have been hoisted to a saved variable are replaced with the saved
  // variable identifier.
  // Ordering: this handler must be registered before the declaration scan
  // below, because that scan eagerly calls ctx.Clone() on index expressions,
  // and those clones run through this handler.
  ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* {
    if (current_ptr_let) {
      // We're currently processing the initializer expression of a
      // pointer-typed `let` declaration. Look to see if we need to swap this
      // Expression with a saved variable.
      auto it = current_ptr_let->saved_vars.find(expr);
      if (it != current_ptr_let->saved_vars.end()) {
        return ctx.dst->Expr(it->second);
      }
    }
    if (auto* ident = expr->As<ast::IdentifierExpression>()) {
      if (auto* vu = ctx.src->Sem().Get<sem::VariableUser>(ident)) {
        auto* var = vu->Variable()->Declaration();
        auto it = ptr_lets.find(var);
        if (it != ptr_lets.end()) {
          // We've found an identifier that resolves to a `let` declaration.
          // We need to replace this identifier with the initializer expression
          // of the `let` declaration. Clone the initializer expression to make
          // a copy. Note that this will call back into this ReplaceAll()
          // handler for sub-expressions of the initializer.
          auto* ptr_let = it->second.get();
          // TINT_SCOPED_ASSIGNMENT provides a stack of PtrLet*, this is
          // required to handle the 'chaining' of inlined `let`s.
          TINT_SCOPED_ASSIGNMENT(current_ptr_let, ptr_let);
          // NOTE(review): the copied initializer names its variables by symbol
          // at this use site. If a declaration with the same name shadows one
          // of them between the `let` and this use, the inlined expression may
          // resolve to the shadowing declaration. Confirm that shadowing is
          // excluded or handled elsewhere.
          return ctx.Clone(var->constructor());
        }
      }
    }
    return nullptr;
  });

  // Find all the pointer-typed `let` declarations.
  // Note that these must be function-scoped, as module-scoped `let`s are not
  // permitted.
  for (auto* node : ctx.src->ASTNodes().Objects()) {
    if (auto* let = node->As<ast::VariableDeclStatement>()) {
      if (!let->variable()->is_const()) {
        continue;  // Not a `let` declaration. Ignore.
      }

      auto* var = ctx.src->Sem().Get(let->variable());
      if (!var->Type()->Is<sem::Pointer>()) {
        continue;  // Not a pointer type. Ignore.
      }

      // We're dealing with a pointer-typed `let` declaration.
      // Make this PtrLet current for the eager index clones performed below.
      auto ptr_let = std::make_unique<PtrLet>();
      TINT_SCOPED_ASSIGNMENT(current_ptr_let, ptr_let.get());

      // The block containing this declaration; the saved variables are
      // inserted into its statement list.
      auto* block = ctx.src->Sem().Get(let)->Block()->Declaration();

      // Scan the initializer expression for array index expressions that need
      // to be hoisted to temporary "saved" variables.
      // NOTE(review): ctx.Clone(idx_expr) below runs immediately. If an index
      // expression itself names another pointer `let`, it is only inlined if
      // that `let` was already registered in `ptr_lets`. This presumably
      // relies on ASTNodes().Objects() yielding declarations before their
      // uses. Confirm.
      CollectSavedArrayIndices(
          ctx.src, var->Declaration()->constructor(),
          [&](ast::Expression* idx_expr) {
            // We have a sub-expression that needs to be saved.
            // Create a new variable
            auto saved_name = ctx.dst->Symbols().New(
                ctx.src->Symbols().NameFor(var->Declaration()->symbol()) +
                "_save");
            // idx_expr is cloned before it is recorded in saved_vars (below),
            // so the handler does not replace it with its own saved variable.
            auto* saved = ctx.dst->Decl(
                ctx.dst->Const(saved_name, nullptr, ctx.Clone(idx_expr)));
            // Place this variable after the pointer typed let. Order here is
            // important as order-of-operations needs to be preserved.
            // CollectSavedArrayIndices() visits the LHS of an array accessor
            // before the index expression.
            // Note that repeated calls to InsertAfter() with the same `after`
            // argument will result in nodes being inserted in the order the
            // calls are made (last call is inserted last).
            ctx.InsertAfter(block->statements(), let, saved);
            // Record the substitution of `idx_expr` to the saved variable with
            // the symbol `saved_name`. This will be used by the ReplaceAll()
            // handler above.
            ptr_let->saved_vars.emplace(idx_expr, saved_name);
          });

      // Record the pointer-typed `let` declaration.
      // This will be used by the ReplaceAll() handler above.
      ptr_lets.emplace(let->variable(), std::move(ptr_let));
      // As the original `let` declaration will be fully inlined, there's no
      // need for the original declaration to exist. Remove it.
      RemoveStatement(ctx, let);
    }
  }

  // Clone the source program into ctx.dst. This applies the replacement
  // handler, insertions and removals registered above.
  ctx.Clone();
}
 |  | 
 | }  // namespace transform | 
 | }  // namespace tint |