| // Copyright 2022 The Dawn & Tint Authors |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are met: |
| // |
| // 1. Redistributions of source code must retain the above copyright notice, this |
| // list of conditions and the following disclaimer. |
| // |
| // 2. Redistributions in binary form must reproduce the above copyright notice, |
| // this list of conditions and the following disclaimer in the documentation |
| // and/or other materials provided with the distribution. |
| // |
| // 3. Neither the name of the copyright holder nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| #include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h" |
| |
| #include <algorithm> |
| #include <string> |
| #include <utility> |
| |
| #include "src/tint/lang/core/fluent_types.h" |
| #include "src/tint/lang/core/type/abstract_int.h" |
| #include "src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.h" |
| #include "src/tint/lang/wgsl/ast/traverse_expressions.h" |
| #include "src/tint/lang/wgsl/program/clone_context.h" |
| #include "src/tint/lang/wgsl/program/program_builder.h" |
| #include "src/tint/lang/wgsl/resolver/resolve.h" |
| #include "src/tint/lang/wgsl/sem/call.h" |
| #include "src/tint/lang/wgsl/sem/function.h" |
| #include "src/tint/lang/wgsl/sem/index_accessor_expression.h" |
| #include "src/tint/lang/wgsl/sem/member_accessor_expression.h" |
| #include "src/tint/lang/wgsl/sem/module.h" |
| #include "src/tint/lang/wgsl/sem/statement.h" |
| #include "src/tint/lang/wgsl/sem/struct.h" |
| #include "src/tint/lang/wgsl/sem/variable.h" |
| #include "src/tint/utils/containers/reverse.h" |
| #include "src/tint/utils/macros/scoped_assignment.h" |
| #include "src/tint/utils/text/string_stream.h" |
| |
// Instantiate the Tint type info for the DirectVariableAccess transform and its Config data.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess::Config);
| |
| using namespace tint::core::number_suffixes; // NOLINT |
| using namespace tint::core::fluent_types; // NOLINT |
| |
| namespace tint::ast::transform { |
| |
| namespace { |
| |
| /// AccessRoot describes the root of an AccessShape. |
| struct AccessRoot { |
| /// The pointer-unwrapped type of the *transformed* variable. |
| /// This may be different for pointers in 'private' and 'function' address space, as the pointer |
| /// parameter type is to the *base object* instead of the input pointer type. |
| tint::core::type::Type const* type = nullptr; |
| /// The originating module-scope variable ('private', 'storage', 'uniform', 'workgroup'), |
| /// function-scope variable ('function'), or pointer parameter in the source program. |
| tint::sem::Variable const* variable = nullptr; |
| /// The address space of the variable or pointer type. |
| tint::core::AddressSpace address_space = tint::core::AddressSpace::kUndefined; |
| |
| /// @return a hash code for this object |
| tint::HashCode HashCode() const { return Hash(type, variable); } |
| }; |
| |
| /// Inequality operator for AccessRoot |
| bool operator!=(const AccessRoot& a, const AccessRoot& b) { |
| return a.type != b.type || a.variable != b.variable; |
| } |
| |
| /// DynamicIndex is used by DirectVariableAccess::State::AccessOp to indicate an array, matrix or |
| /// vector index. |
| struct DynamicIndex { |
| /// @return a hash code for this object |
| tint::HashCode HashCode() const { return 42 /* empty struct: any number will do */; } |
| }; |
| |
| /// Inequality operator for DynamicIndex |
| bool operator!=(const DynamicIndex&, const DynamicIndex&) { |
| return false; // empty struct: two DynamicIndex objects are always equal |
| } |
| |
| /// AccessOp describes a single access in an access chain. |
| /// The access is one of: |
| /// Symbol - a struct member access. |
| /// DynamicIndex - a runtime index on an array, matrix column, or vector element. |
| using AccessOp = std::variant<tint::Symbol, DynamicIndex>; |
| |
| /// A vector of AccessOp. Describes the static "path" from a root variable to an element |
| /// within the variable. Array accessors index expressions are held externally to the |
| /// AccessShape, so AccessShape will be considered equal even if the array, matrix or vector |
| /// index values differ. |
| /// |
| /// For example, consider the following: |
| /// |
| /// ``` |
| /// struct A { |
| /// x : array<i32, 8>, |
| /// y : u32, |
| /// }; |
| /// struct B { |
| /// x : i32, |
| /// y : array<A, 4> |
| /// }; |
| /// var<workgroup> C : B; |
| /// ``` |
| /// |
| /// The following AccessShape would describe the following: |
| /// |
| /// +==============================+===============+=================================+ |
| /// | AccessShape | Type | Expression | |
| /// +==============================+===============+=================================+ |
| /// | [ Variable 'C', Symbol 'x' ] | i32 | C.x | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// | [ Variable 'C', Symbol 'y' ] | array<A, 4> | C.y | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// | [ Variable 'C', Symbol 'y', | A | C.y[dyn_idx[0]] | |
| /// | DynamicIndex ] | | | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// | [ Variable 'C', Symbol 'y', | array<i32, 8> | C.y[dyn_idx[0]].x | |
| /// | DynamicIndex, Symbol 'x' ] | | | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// | [ Variable 'C', Symbol 'y', | i32 | C.y[dyn_idx[0]].x[dyn_idx[1]] | |
| /// | DynamicIndex, Symbol 'x', | | | |
| /// | DynamicIndex ] | | | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// | [ Variable 'C', Symbol 'y', | u32 | C.y[dyn_idx[0]].y | |
| /// | DynamicIndex, Symbol 'y' ] | | | |
| /// +------------------------------+---------------+---------------------------------+ |
| /// |
| /// Where: `dyn_idx` is the AccessChain::dynamic_indices. |
| struct AccessShape { |
| // The originating variable. |
| AccessRoot root; |
| /// The chain of access ops. |
| tint::Vector<AccessOp, 8> ops; |
| |
| /// @returns the number of DynamicIndex operations in #ops. |
| uint32_t NumDynamicIndices() const { |
| uint32_t count = 0; |
| for (auto& op : ops) { |
| if (std::holds_alternative<DynamicIndex>(op)) { |
| count++; |
| } |
| } |
| return count; |
| } |
| |
| /// @return a hash code for this object |
| tint::HashCode HashCode() const { return Hash(root, ops); } |
| }; |
| |
| /// Equality operator for AccessShape |
| bool operator==(const AccessShape& a, const AccessShape& b) { |
| return !(a.root != b.root) && a.ops == b.ops; |
| } |
| |
| /// Inequality operator for AccessShape |
| bool operator!=(const AccessShape& a, const AccessShape& b) { |
| return !(a == b); |
| } |
| |
| /// AccessChain describes a chain of access expressions originating from a variable. |
| struct AccessChain : AccessShape { |
| /// The array accessor index expressions. This vector is indexed by the `DynamicIndex`s in |
| /// #indices. |
| Vector<const sem::ValueExpression*, 8> dynamic_indices; |
| /// If true, then this access chain is used as an argument to call a variant. |
| bool used_in_call = false; |
| }; |
| |
| } // namespace |
| |
| /// The PIMPL state for the DirectVariableAccess transform |
| struct DirectVariableAccess::State { |
| /// Constructor |
| /// @param src the source Program |
| /// @param options the transform options |
| State(const Program& src, const Options& options) |
| : ctx{&b, &src, /* auto_clone_symbols */ true}, opts(options) {} |
| |
| /// The main function for the transform. |
| /// @returns the ApplyResult |
| ApplyResult Run() { |
| // If there are no functions with pointer parameters, then this transform can be skipped. |
| if (!AnyPointerParameters()) { |
| return SkipTransform; |
| } |
| |
| // Stage 1: |
| // Walk all the expressions of the program, starting with the expression leaves. |
| // Whenever we find an identifier resolving to a var, pointer parameter or pointer let to |
| // another chain, start constructing an access chain. When chains are accessed, these chains |
| // are grown and moved up the expression tree. After this stage, we are left with all the |
| // expression access chains to variables that we may need to transform. |
| for (auto* node : ctx.src->ASTNodes().Objects()) { |
| if (auto* expr = sem.GetVal(node)) { |
| AppendAccessChain(expr); |
| } |
| } |
| |
| // Stage 2: |
| // Walk the functions in dependency order, starting with the entry points. |
| // Construct the set of function 'variants' by examining the calls made by each function to |
| // their call target. Each variant holds a map of pointer parameter to access chains, and |
| // will have the pointer parameters replaced with an array of u32s, used to perform the |
| // pointer indexing in the variant. |
| // Function call pointer arguments are replaced with an array of these dynamic indices. |
| auto decls = sem.Module()->DependencyOrderedDeclarations(); |
| for (auto* decl : tint::Reverse(decls)) { |
| if (auto* fn = sem.Get<sem::Function>(decl)) { |
| auto* fn_info = FnInfoFor(fn); |
| ProcessFunction(fn, fn_info); |
| TransformFunction(fn, fn_info); |
| } |
| } |
| |
| // Stage 3: |
| // Filter out access chains that do not need transforming. |
| // Ensure that chain dynamic index expressions are evaluated once at the correct place |
| ProcessAccessChains(); |
| |
| // Stage 4: |
| // Replace all the access chain expressions in all functions with reconstructed expression |
| // using the originating global variable, and any dynamic indices passed in to the function |
| // variant. |
| TransformAccessChainExpressions(); |
| |
| // Stage 5: |
| // Actually kick the clone. |
| CloneState state; |
| clone_state = &state; |
| ctx.Clone(); |
| return resolver::Resolve(b); |
| } |
| |
| private: |
| /// Holds symbols of the transformed pointer parameter. |
| /// If both symbols are valid, then #base_ptr and #indices are both program-unique symbols |
| /// derived from the original parameter name. |
| /// If only one symbol is valid, then this is the original parameter symbol. |
| struct PtrParamSymbols { |
| /// The symbol of the base pointer parameter. |
| Symbol base_ptr; |
| /// The symbol of the dynamic indicies parameter. |
| Symbol indices; |
| }; |
| |
| /// FnVariant describes a unique variant of a function, specialized by the AccessShape of the |
| /// pointer arguments - also known as the variant's "signature". |
| /// |
| /// To help understand what a variant is, consider the following WGSL: |
| /// |
| /// ``` |
| /// fn F(a : ptr<storage, u32>, b : u32, c : ptr<storage, u32>) { |
| /// return *a + b + *c; |
| /// } |
| /// |
| /// @group(0) @binding(0) var<storage> S0 : u32; |
| /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>; |
| /// |
| /// fn x() { |
| /// F(&S0, 0, &S0); // (A) |
| /// F(&S0, 0, &S0); // (B) |
| /// F(&S1[0], 1, &S0); // (C) |
| /// F(&S1[5], 2, &S0); // (D) |
| /// F(&S1[5], 3, &S1[3]); // (E) |
| /// F(&S1[7], 4, &S1[2]); // (F) |
| /// } |
| /// ``` |
| /// |
| /// Given the calls in x(), function F() will have 3 variants: |
| /// (1) F<S0,S0> - called by (A) and (B). |
| /// Note that only 'uniform', 'storage' and 'workgroup' pointer |
| /// parameters are considered for a variant signature, and so |
| /// the argument for parameter 'b' is not included in the |
| /// signature. |
| /// (2) F<S1[dyn_idx],S0> - called by (C) and (D). |
| /// Note that the array index value is external to the |
| /// AccessShape, and so is not part of the variant signature. |
| /// (3) F<S1[dyn_idx],S1[dyn_idx]> - called by (E) and (F). |
| /// |
| /// Each variant of the function will be emitted as a separate function by the transform, and |
| /// would look something like: |
| /// |
| /// ``` |
| /// // variant F<S0,S0> (1) |
| /// fn F_S0_S0(b : u32) { |
| /// return S0 + b + S0; |
| /// } |
| /// |
| /// type S1_X = array<u32, 1>; |
| /// |
| /// // variant F<S1[dyn_idx],S0> (2) |
| /// fn F_S1_X_S0(a : S1_X, b : u32) { |
| /// return S1[a[0]] + b + S0; |
| /// } |
| /// |
| /// // variant F<S1[dyn_idx],S1[dyn_idx]> (3) |
| /// fn F_S1_X_S1_X(a : S1_X, b : u32, c : S1_X) { |
| /// return S1[a[0]] + b + S1[c[0]]; |
| /// } |
| /// |
| /// @group(0) @binding(0) var<storage> S0 : u32; |
| /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>; |
| /// |
| /// fn x() { |
| /// F_S0_S0(0); // (A) |
| /// F(&S0, 0, &S0); // (B) |
| /// F_S1_X_S0(S1_X(0), 1); // (C) |
| /// F_S1_X_S0(S1_X(5), 2); // (D) |
| /// F_S1_X_S1_X(S1_X(5), 3, S1_X(3)); // (E) |
| /// F_S1_X_S1_X(S1_X(7), 4, S1_X(2)); // (F) |
| /// } |
| /// ``` |
| struct FnVariant { |
| /// The signature of the variant is a map of each of the function's 'uniform', 'storage' and |
| /// 'workgroup' pointer parameters to the caller's AccessShape. |
| using Signature = Hashmap<const sem::Parameter*, AccessShape, 4>; |
| |
| /// The unique name of the variant. |
| /// The symbol is in the `ctx.dst` program namespace. |
| Symbol name; |
| |
| /// A map of direct calls made by this variant to the name of other function variants. |
| Hashmap<const sem::Call*, Symbol, 4> calls; |
| |
| /// A map of input program parameter to output parameter symbols. |
| Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols; |
| |
| /// The declaration order of the variant, in relation to other variants of the same |
| /// function. Used to ensure deterministic ordering of the transform, as map iteration is |
| /// not deterministic between compilers. |
| size_t order = 0; |
| }; |
| |
| /// FnInfo holds information about a function in the input program. |
| struct FnInfo { |
| /// A map of variant signature to the variant data. |
| Hashmap<FnVariant::Signature, FnVariant, 8> variants; |
| /// A map of expressions that have been hoisted to a 'let' declaration in the function. |
| Hashmap<const sem::ValueExpression*, Symbol, 8> hoisted_exprs; |
| |
| /// @returns the variants of the function in a deterministically ordered vector. |
| tint::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> SortedVariants() { |
| tint::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> out; |
| out.Reserve(variants.Count()); |
| for (auto& it : variants) { |
| out.Push({&it.key.Value(), &it.value}); |
| } |
| out.Sort([&](auto& va, auto& vb) { return va.second->order < vb.second->order; }); |
| return out; |
| } |
| }; |
| |
| /// The program builder |
| ProgramBuilder b; |
| /// The clone context |
| program::CloneContext ctx; |
| /// The transform options |
| const Options& opts; |
| /// Alias to the semantic info in ctx.src |
| const sem::Info& sem = ctx.src->Sem(); |
| /// Alias to the symbols in ctx.src |
| const SymbolTable& sym = ctx.src->Symbols(); |
| /// Map of semantic function to the function info |
| Hashmap<const sem::Function*, FnInfo*, 8> fns; |
| /// Map of AccessShape to the name of a type alias for the an array<u32, N> used for the |
| /// dynamic indices of an access chain, passed down as the transformed type of a variant's |
| /// pointer parameter. |
| Hashmap<AccessShape, Symbol, 8> dynamic_index_array_aliases; |
| /// Map of semantic expression to AccessChain |
| Hashmap<const sem::ValueExpression*, AccessChain*, 32> access_chains; |
| /// Allocator for FnInfo |
| BlockAllocator<FnInfo> fn_info_allocator; |
| /// Allocator for AccessChain |
| BlockAllocator<AccessChain> access_chain_allocator; |
| /// Helper used for hoisting expressions to lets |
| HoistToDeclBefore hoist{ctx}; |
| /// Map of string to unique symbol (no collisions in output program). |
| Hashmap<std::string, Symbol, 8> unique_symbols; |
| |
| /// CloneState holds pointers to the current function, variant and variant's parameters. |
| struct CloneState { |
| /// The current function being cloned |
| FnInfo* current_function = nullptr; |
| /// The current function variant being built |
| FnVariant* current_variant = nullptr; |
| /// The signature of the current function variant being built |
| const FnVariant::Signature* current_variant_sig = nullptr; |
| }; |
| |
| /// The clone state. |
| /// Only valid during the lifetime of the program::CloneContext::Clone(). |
| CloneState* clone_state = nullptr; |
| |
| /// @returns true if any user functions have parameters of a pointer type. |
| bool AnyPointerParameters() const { |
| for (auto* fn : ctx.src->AST().Functions()) { |
| for (auto* param : fn->params) { |
| if (sem.Get(param)->Type()->Is<core::type::Pointer>()) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| /// AppendAccessChain creates or extends an existing AccessChain for the given expression, |
| /// modifying the #access_chains map. |
| void AppendAccessChain(const sem::ValueExpression* expr) { |
| // take_chain moves the AccessChain from the expression `from` to the expression `expr`. |
| // Returns nullptr if `from` did not hold an access chain. |
| auto take_chain = [&](const sem::ValueExpression* from) -> AccessChain* { |
| if (auto* chain = AccessChainFor(from)) { |
| access_chains.Remove(from); |
| access_chains.Add(expr, chain); |
| return chain; |
| } |
| return nullptr; |
| }; |
| |
| Switch( |
| expr, |
| [&](const sem::VariableUser* user) { |
| // Expression resolves to a variable. |
| auto* variable = user->Variable(); |
| |
| auto create_new_chain = [&] { |
| auto* chain = access_chain_allocator.Create(); |
| chain->root.variable = variable; |
| chain->root.type = variable->Type(); |
| chain->root.address_space = variable->AddressSpace(); |
| if (auto* ptr = chain->root.type->As<core::type::Pointer>()) { |
| chain->root.address_space = ptr->AddressSpace(); |
| } |
| access_chains.Add(expr, chain); |
| }; |
| |
| Switch( |
| variable->Declaration(), |
| [&](const Var*) { |
| if (variable->AddressSpace() != core::AddressSpace::kHandle) { |
| // Start a new access chain for the non-handle 'var' access |
| create_new_chain(); |
| } |
| }, |
| [&](const Parameter*) { |
| if (variable->Type()->Is<core::type::Pointer>()) { |
| // Start a new access chain for the pointer parameter access |
| create_new_chain(); |
| } |
| }, |
| [&](const Let*) { |
| if (variable->Type()->Is<core::type::Pointer>()) { |
| // variable is a pointer-let. |
| auto* init = sem.GetVal(variable->Declaration()->initializer); |
| // Note: We do not use take_chain() here, as we need to preserve the |
| // AccessChain on the let's initializer, as the let needs its |
| // initializer updated, and the let may be used multiple times. Instead |
| // we copy the let's AccessChain into a a new AccessChain. |
| if (auto* init_chain = AccessChainFor(init)) { |
| access_chains.Add(expr, access_chain_allocator.Create(*init_chain)); |
| } |
| } |
| }); |
| }, |
| [&](const sem::StructMemberAccess* a) { |
| // Structure member access. |
| // Append the Symbol of the member name to the chain, and move the chain to the |
| // member access expression. |
| if (auto* chain = take_chain(a->Object())) { |
| chain->ops.Push(a->Member()->Name()); |
| } |
| }, |
| [&](const sem::IndexAccessorExpression* a) { |
| // Array, matrix or vector index. |
| // Store the index expression into AccessChain::dynamic_indices, append a |
| // DynamicIndex to the chain, and move the chain to the index accessor expression. |
| if (auto* chain = take_chain(a->Object())) { |
| chain->ops.Push(DynamicIndex{}); |
| chain->dynamic_indices.Push(a->Index()); |
| } |
| }, |
| [&](const sem::ValueExpression* e) { |
| if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) { |
| // Unary op. |
| // If this is a '&' or '*', simply move the chain to the unary op expression. |
| if (unary->op == core::UnaryOp::kAddressOf || |
| unary->op == core::UnaryOp::kIndirection) { |
| take_chain(sem.GetVal(unary->expr)); |
| } |
| } |
| }); |
| } |
| |
| /// MaybeHoistDynamicIndices examines the AccessChain::dynamic_indices member of @p chain, |
| /// hoisting all expressions to their own uniquely named 'let' if none of the following are |
| /// true: |
| /// 1. The index expression is a constant value. |
| /// 2. The index expression's statement is the same as @p usage. |
| /// 3. The index expression is an identifier resolving to a 'let', 'const' or parameter, AND |
| /// that identifier resolves to the same variable at @p usage. |
| /// |
| /// A dynamic index will only be hoisted once. The hoisting applies to all variants of the |
| /// function that holds the dynamic index expression. |
| void MaybeHoistDynamicIndices(AccessChain* chain, const sem::Statement* usage) { |
| for (auto& idx : chain->dynamic_indices) { |
| if (idx->ConstantValue()) { |
| // Dynamic index is constant. |
| continue; // Hoisting not required. |
| } |
| |
| if (idx->Stmt() == usage) { |
| // The index expression is owned by the statement of usage. |
| continue; // Hoisting not required |
| } |
| |
| if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) { |
| auto* idx_variable = idx_variable_user->Variable(); |
| if (idx_variable->Declaration()->IsAnyOf<Let, Parameter>()) { |
| // Dynamic index is an immutable variable |
| continue; // Hoisting not required. |
| } |
| } |
| |
| // The dynamic index needs to be hoisted (if it hasn't been already). |
| auto fn = FnInfoFor(idx->Stmt()->Function()); |
| fn->hoisted_exprs.GetOrAdd(idx, [this, idx] { |
| // Create a name for the new 'let' |
| auto name = b.Symbols().New("ptr_index_save"); |
| // Insert a new 'let' just above the dynamic index statement. |
| hoist.InsertBefore(idx->Stmt(), [this, idx, name] { |
| return b.Decl(b.Let(name, ctx.CloneWithoutTransform(idx->Declaration()))); |
| }); |
| return name; |
| }); |
| } |
| } |
| |
| /// BuildDynamicIndex builds the AST expression node for the dynamic index expression used in an |
| /// AccessChain. This is similar to just cloning the expression, but BuildDynamicIndex() |
| /// also: |
| /// * Collapses constant value index expressions down to the computed value. This acts as an |
| /// constant folding optimization and reduces noise from the transform. |
| /// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type |
| /// isn't implicitly usable as a u32. This is to help feed the expression into a |
| /// `array<u32, N>` argument passed to a callee variant function. |
| const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) { |
| if (auto* val = idx->ConstantValue()) { |
| // Expression evaluated to a constant value. Just emit that constant. |
| return b.Expr(val->ValueAs<AInt>()); |
| } |
| |
| // Expression is not a constant, clone the expression. |
| // Note: If the dynamic index expression was hoisted to a let, then cloning will return an |
| // identifier expression to the hoisted let. |
| auto* expr = ctx.Clone(idx->Declaration()); |
| |
| if (cast_to_u32) { |
| // The index may be fed to a dynamic index array<u32, N> argument, so the index |
| // expression may need casting to u32. |
| if (!idx->UnwrapMaterialize() |
| ->Type() |
| ->UnwrapRef() |
| ->IsAnyOf<core::type::U32, core::type::AbstractInt>()) { |
| expr = b.Call<u32>(expr); |
| } |
| } |
| |
| return expr; |
| } |
| |
| /// ProcessFunction scans the direct calls made by the function @p fn, adding new variants to |
| /// the callee functions and transforming the call expression to pass dynamic indices instead of |
| /// true pointers. |
| /// If the function @p fn has pointer parameters that must be transformed to a caller variant, |
| /// and the function is not called, then the function is dropped from the output of the |
| /// transform, as it cannot be generated. |
| /// @note ProcessFunction must be called in dependency order for the program, starting with the |
| /// entry points. |
| void ProcessFunction(const sem::Function* fn, FnInfo* fn_info) { |
| if (fn_info->variants.IsEmpty()) { |
| // Function has no variants pre-generated by callers. |
| if (MustBeCalled(fn)) { |
| // Drop the function, as it wasn't called and cannot be generated. |
| ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn->Declaration()); |
| return; |
| } |
| |
| // Function was not called. Create a single variant with an empty signature. |
| FnVariant variant; |
| variant.name = ctx.Clone(fn->Declaration()->name->symbol); |
| variant.order = 0; // Unaltered comes first. |
| fn_info->variants.Add(FnVariant::Signature{}, std::move(variant)); |
| } |
| |
| // Process each of the direct calls made by this function. |
| for (auto* call : fn->DirectCalls()) { |
| ProcessCall(fn_info, call); |
| } |
| } |
| |
/// ProcessCall creates new variants of the callee function by permuting the call for each of
/// the variants of @p caller. ProcessCall also registers the clone callback to transform the
/// call expression to pass dynamic indices instead of true pointers.
/// Calls whose target is not a user-declared function, or is a user function without a pointer
/// parameter, are ignored.
/// @param caller the FnInfo of the calling function. A target variant is looked up (or
///        created) for each of the caller's variants.
/// @param call the direct call made by the caller
void ProcessCall(FnInfo* caller, const sem::Call* call) {
    auto* target = call->Target()->As<sem::Function>();
    if (!target) {
        // Call target is not a user-declared function.
        return;  // Not interested in this call.
    }

    if (!HasPointerParameter(target)) {
        return;  // Not interested in this call.
    }

    // Set to true if any caller variant calls a target variant with a non-empty signature.
    bool call_needs_transforming = false;

    // Build the call target function variant for each variant of the caller.
    // Caller variants are visited in their deterministic order, so any new target variants are
    // also created (and numbered) deterministically.
    for (auto caller_variant_it : caller->SortedVariants()) {
        auto& caller_signature = *caller_variant_it.first;
        auto& caller_variant = *caller_variant_it.second;

        // Build the target variant's signature.
        FnVariant::Signature target_signature;
        for (size_t i = 0; i < call->Arguments().Length(); i++) {
            const auto* arg = call->Arguments()[i];
            const auto* param = target->Parameters()[i];
            const auto* param_ty = param->Type()->As<core::type::Pointer>();
            if (!param_ty) {
                continue;  // Parameter type is not a pointer.
            }

            // Fetch the access chain for the argument.
            auto* arg_chain = AccessChainFor(arg);
            if (!arg_chain) {
                continue;  // Argument does not have an access chain
            }

            // Construct the absolute AccessShape by considering the AccessShape of the caller
            // variant's argument. This will propagate back through pointer parameters, to the
            // outermost caller.
            auto absolute = AbsoluteAccessShape(caller_signature, *arg_chain);

            // If the address space of the root variable of the access chain does not require
            // transformation, then there's nothing to do.
            if (!AddressSpaceRequiresTransform(absolute.root.address_space)) {
                continue;
            }

            // Record that this chain was used in a function call.
            // This preserves the chain during the access chain filtering stage.
            arg_chain->used_in_call = true;

            if (IsPrivateOrFunction(absolute.root.address_space)) {
                // Pointers in 'private' and 'function' address spaces need to be passed by
                // pointer argument, so the root of the shape becomes the target's parameter.
                absolute.root.variable = param;
            }

            // Add the parameter's absolute AccessShape to the target's signature.
            target_signature.Add(param, std::move(absolute));
        }

        // Construct a new FnVariant if this is the first caller of the target signature
        auto* target_info = FnInfoFor(target);
        auto& target_variant = target_info->variants.GetOrAdd(target_signature, [&] {
            if (target_signature.IsEmpty()) {
                // Call target does not require any argument changes.
                FnVariant variant;
                variant.name = ctx.Clone(target->Declaration()->name->symbol);
                variant.order = 0;  // Unaltered comes first.
                return variant;
            }

            // Build an appropriate variant function name.
            // This is derived from the original function name and the pointer parameter
            // chains.
            StringStream ss;
            ss << target->Declaration()->name->symbol.Name();
            for (auto* param : target->Parameters()) {
                if (auto indices = target_signature.Get(param)) {
                    ss << "_" << AccessShapeName(*indices);
                }
            }

            // Build the pointer parameter symbols.
            Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
            for (auto& param_it : target_signature) {
                auto* param = param_it.key.Value();
                auto& shape = param_it.value;

                // Parameter needs replacing with either zero, one or two parameters:
                // If the parameter is in the 'private' or 'function' address space, then the
                // originating pointer is always passed down. This always comes first.
                // If the access chain has dynamic indices, then we create an array<u32, N>
                // parameter to hold the dynamic indices.
                bool requires_base_ptr_param = IsPrivateOrFunction(shape.root.address_space);
                bool requires_indices_param = shape.NumDynamicIndices() > 0;

                // When both parameters are needed, each gets a new unique name derived from the
                // original. When only one is needed, it reuses the original name. When neither
                // is needed, both symbols are left invalid.
                PtrParamSymbols symbols;
                if (requires_base_ptr_param && requires_indices_param) {
                    auto original_name = param->Declaration()->name->symbol;
                    symbols.base_ptr = UniqueSymbolWithSuffix(original_name, "_base");
                    symbols.indices = UniqueSymbolWithSuffix(original_name, "_indices");
                } else if (requires_base_ptr_param) {
                    symbols.base_ptr = ctx.Clone(param->Declaration()->name->symbol);
                } else if (requires_indices_param) {
                    symbols.indices = ctx.Clone(param->Declaration()->name->symbol);
                }

                // Remember the replacement symbols for this parameter.
                ptr_param_symbols.Add(param, symbols);
            }

            // Build the variant.
            FnVariant variant;
            variant.name = b.Symbols().New(ss.str());
            // Always greater than the unaltered variant's order of 0, and grows as variants are
            // added. NOTE(review): whether GetOrAdd has already counted the new entry at this
            // point depends on the Hashmap implementation; the ordering holds either way.
            variant.order = target_info->variants.Count() + 1;
            variant.ptr_param_symbols = std::move(ptr_param_symbols);
            return variant;
        });

        // Record the call made by caller variant to the target variant.
        caller_variant.calls.Add(call, target_variant.name);
        if (!target_signature.IsEmpty()) {
            // The call expression will need transforming for at least one caller variant.
            call_needs_transforming = true;
        }
    }

    if (call_needs_transforming) {
        // Register the clone callback to correctly transform the call expression into the
        // appropriate variant calls.
        TransformCall(call);
    }
}
| |
| /// @returns true if the address space @p address_space requires transforming given the |
| /// transform's options. |
| bool AddressSpaceRequiresTransform(core::AddressSpace address_space) const { |
| switch (address_space) { |
| case core::AddressSpace::kUniform: |
| case core::AddressSpace::kStorage: |
| case core::AddressSpace::kWorkgroup: |
| return true; |
| case core::AddressSpace::kPrivate: |
| return opts.transform_private; |
| case core::AddressSpace::kFunction: |
| return opts.transform_function; |
| default: |
| return false; |
| } |
| } |
| |
| /// @returns the AccessChain for the expression @p expr, or nullptr if the expression does |
| /// not hold an access chain. |
| AccessChain* AccessChainFor(const sem::ValueExpression* expr) const { |
| if (auto chain = access_chains.Get(expr)) { |
| return *chain; |
| } |
| return nullptr; |
| } |
| |
| /// @returns the absolute AccessShape for @p indices, by replacing the originating pointer |
| /// parameter with the AccessChain of variant's signature. |
| AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature, |
| const AccessShape& shape) const { |
| if (auto* root_param = shape.root.variable->As<sem::Parameter>()) { |
| if (auto incoming_chain = signature.Get(root_param)) { |
| // Access chain originates from a parameter, which will be transformed into an array |
| // of dynamic indices. Concatenate the signature's AccessShape for the parameter |
| // to the chain's indices, skipping over the chain's initial parameter index. |
| auto absolute = *incoming_chain; |
| for (auto& op : shape.ops) { |
| absolute.ops.Push(op); |
| } |
| return absolute; |
| } |
| } |
| |
| // Chain does not originate from a parameter, so is already absolute. |
| return shape; |
| } |
| |
/// TransformFunction registers the clone callback to transform the function @p fn into the
/// (potentially multiple) function's variants. TransformFunction will assign the current
/// function and variant to #clone_state, which can be used by the other clone callbacks.
/// Every variant except the last is added to the module explicitly with b.AST().AddFunction().
/// The last variant is returned from the callback and takes the place of the original
/// declaration.
/// NOTE(review): if fn_info->SortedVariants() is empty, the callback returns nullptr; confirm
/// that callers only register functions that have at least one variant.
/// @param fn the semantic function to transform
/// @param fn_info the collected variant information for @p fn
void TransformFunction(const sem::Function* fn, FnInfo* fn_info) {
    // Register a custom handler for the specific function
    ctx.Replace(fn->Declaration(), [this, fn, fn_info] {
        // For the scope of this lambda, assign current_function to fn_info.
        TINT_SCOPED_ASSIGNMENT(clone_state->current_function, fn_info);

        // This callback expects a single function returned. As we're generating potentially
        // many variant functions, keep a record of the last created variant, and explicitly add
        // this to the module if it isn't the last. We'll return the last created variant,
        // taking the place of the original function.
        const Function* pending_variant = nullptr;

        // For each variant of fn (in SortedVariants() order)...
        for (auto variant_it : fn_info->SortedVariants()) {
            // The previously built variant is not the last one, so add it to the module now.
            if (pending_variant) {
                b.AST().AddFunction(pending_variant);
            }

            auto& variant_sig = *variant_it.first;
            auto& variant = *variant_it.second;

            // For the rest of this scope, assign the current variant and variant signature.
            // The other clone callbacks read these while the body is cloned below.
            TINT_SCOPED_ASSIGNMENT(clone_state->current_variant_sig, &variant_sig);
            TINT_SCOPED_ASSIGNMENT(clone_state->current_variant, &variant);

            // Build the variant's parameters.
            // Each pointer parameter that appears in the variant's signature is replaced by:
            //  * a base pointer parameter, if symbols.base_ptr is valid, followed by
            //  * an array of dynamic indices, if symbols.indices is valid.
            // If neither symbol is valid, the parameter is dropped entirely.
            // Parameters not in the signature are cloned unchanged.
            tint::Vector<const Parameter*, 8> params;
            for (auto* param : fn->Parameters()) {
                if (auto incoming_shape = variant_sig.Get(param)) {
                    auto& symbols = *variant.ptr_param_symbols.Get(param);
                    if (symbols.base_ptr.IsValid()) {
                        // Pointer to the unwrapped type of the access root, in the root's
                        // address space.
                        auto base_ptr_ty = b.ty.ptr(
                            incoming_shape->root.address_space,
                            CreateASTTypeFor(ctx, incoming_shape->root.type->UnwrapPtrOrRef()));
                        params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
                    }
                    if (symbols.indices.IsValid()) {
                        // The variant takes dynamic indices for this parameter: pass them as
                        // the array type aliased for the incoming shape.
                        auto dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
                        params.Push(b.Param(symbols.indices, dyn_idx_arr_type));
                    }
                } else {
                    // Just a regular parameter. Just clone the original parameter.
                    params.Push(ctx.Clone(param->Declaration()));
                }
            }

            // Build the variant by cloning the source function. The other clone callbacks will
            // use clone_state->current_variant and clone_state->current_variant_sig to produce
            // the variant.
            auto ret_ty = ctx.Clone(fn->Declaration()->return_type);
            auto body = ctx.Clone(fn->Declaration()->body);
            auto attrs = ctx.Clone(fn->Declaration()->attributes);
            auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
            pending_variant =
                b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
                                   std::move(attrs), std::move(ret_attrs));
        }

        // The last variant replaces the original function declaration.
        return pending_variant;
    });
}
| |
/// TransformCall registers the clone callback to transform the call expression @p call to call
/// the correct target variant, and to replace pointer arguments with an array of dynamic
/// indices. For pointer parameters in the 'private' or 'function' address space, a pointer to
/// the root object is passed first, ahead of the dynamic-index array.
/// If the current variant has no recorded target for @p call, the call is cloned without
/// transformation.
/// @param call the semantic call expression to transform
void TransformCall(const sem::Call* call) {
    // Register a custom handler for the specific call expression
    ctx.Replace(call->Declaration(), [this, call] {
        auto target_variant = clone_state->current_variant->calls.Get(call);
        if (!target_variant) {
            // The current variant does not need to transform this call.
            return ctx.CloneWithoutTransform(call->Declaration());
        }

        // Build the new call expression's arguments.
        // The argument order produced here (base pointer, then dynamic-index array) mirrors
        // the parameter order built by TransformFunction().
        // NOTE(review): this assumes every pointer argument that still has an AccessChain maps
        // to a parameter present in the target variant's signature; confirm against where
        // chains are created and pruned.
        tint::Vector<const Expression*, 8> new_args;
        for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
            auto* arg = call->Arguments()[arg_idx];
            auto* param = call->Target()->Parameters()[arg_idx];
            auto* param_ty = param->Type()->As<core::type::Pointer>();
            if (!param_ty) {
                // Parameter is not a pointer.
                // Just clone the unaltered argument.
                new_args.Push(ctx.Clone(arg->Declaration()));
                continue;  // Parameter is not a pointer
            }

            auto* chain = AccessChainFor(arg);
            if (!chain) {
                // No access chain means the argument is not a pointer that needs transforming.
                // Just clone the unaltered argument.
                new_args.Push(ctx.Clone(arg->Declaration()));
                continue;
            }

            // Construct the absolute AccessShape by considering the AccessShape of the caller
            // variant's argument. This will propagate back through pointer parameters, to the
            // outermost caller.
            auto full_indices = AbsoluteAccessShape(*clone_state->current_variant_sig, *chain);

            // If the parameter is a pointer in the 'private' or 'function' address space, then
            // we need to pass an additional pointer argument to the base object.
            if (IsPrivateOrFunction(param_ty->AddressSpace())) {
                auto* root_expr = BuildAccessRootExpr(chain->root, /* deref */ false);
                // A parameter root already yields a pointer expression; a variable root yields
                // the variable itself, so its address must be taken.
                if (!chain->root.variable->Is<sem::Parameter>()) {
                    root_expr = b.AddressOf(root_expr);
                }
                new_args.Push(root_expr);
            }

            // Get or create the dynamic indices array.
            // DynamicIndexArrayType() yields an invalid Type when the shape has no dynamic
            // indices, in which case no index-array argument is passed.
            if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
                // Build an array of dynamic indices to pass as the replacement for the pointer.
                // Indices inherited from the caller's own incoming pointer come first, followed
                // by the indices of this call's access chain.
                tint::Vector<const Expression*, 8> dyn_idx_args;
                if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
                    // Access chain originates from a pointer parameter.
                    if (auto incoming_chain =
                            clone_state->current_variant_sig->Get(root_param)) {
                        auto indices =
                            clone_state->current_variant->ptr_param_symbols.Get(root_param)
                                ->indices;

                        // This pointer parameter will have been replaced with a array<u32, N>
                        // holding the variant's dynamic indices for the pointer. Unpack these
                        // directly into the array constructor's arguments.
                        auto N = incoming_chain->NumDynamicIndices();
                        for (uint32_t i = 0; i < N; i++) {
                            dyn_idx_args.Push(b.IndexAccessor(indices, u32(i)));
                        }
                    }
                }
                // Pass the dynamic indices of the access chain into the array constructor.
                for (auto& dyn_idx : chain->dynamic_indices) {
                    dyn_idx_args.Push(BuildDynamicIndex(dyn_idx, /* cast_to_u32 */ true));
                }
                // Construct the dynamic index array, and push as an argument.
                new_args.Push(b.Call(dyn_idx_arr_ty, std::move(dyn_idx_args)));
            }
        }

        // Make the call to the target's variant.
        return b.Call(*target_variant, std::move(new_args));
    });
}
| |
| /// ProcessAccessChains performs the following: |
| /// * Removes all AccessChains from expressions that are not either used as a pointer argument |
| /// in a call, or originates from a pointer parameter. |
| /// * Hoists the dynamic index expressions of AccessChains to 'let' statements, to prevent |
| /// multiple evaluation of the expressions, and avoid expressions resolving to different |
| /// variables based on lexical scope. |
| void ProcessAccessChains() { |
| auto chain_exprs = access_chains.Keys(); |
| chain_exprs.Sort([](const auto& expr_a, const auto& expr_b) { |
| return expr_a->Declaration()->node_id.value < expr_b->Declaration()->node_id.value; |
| }); |
| |
| for (auto* expr : chain_exprs) { |
| auto* chain = *access_chains.Get(expr); |
| if (!chain->used_in_call && !chain->root.variable->Is<sem::Parameter>()) { |
| // Chain was not used in a function call, and does not originate from a |
| // parameter. This chain does not need transforming. Drop it. |
| access_chains.Remove(expr); |
| continue; |
| } |
| |
| // Chain requires transforming. |
| |
| // We need to be careful that the chain does not use expressions with side-effects which |
| // cannot be repeatedly evaluated. In this situation we can hoist the dynamic index |
| // expressions to their own uniquely named lets (if required). |
| MaybeHoistDynamicIndices(chain, expr->Stmt()); |
| } |
| } |
| |
/// TransformAccessChainExpressions registers the clone callback to:
/// * Rebuild every expression whose AccessChain originates from a pointer parameter that is
///   part of the current variant's signature. Arguments to function calls are handled by
///   TransformCall(). The rebuilt expression is: the variant's access root (see
///   BuildAccessRootExpr(), dereferenced), then the incoming shape's accesses (dynamic indices
///   read from the variant's indices parameter), then the chain's own accesses.
/// * Replace expressions that have been hoisted to a let, with an identifier expression to that
///   let.
/// Returning nullptr from the callback leaves the expression to be cloned normally.
void TransformAccessChainExpressions() {
    // Register a custom handler for all non-function call expressions
    ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
        if (!clone_state->current_variant) {
            // Expression does not belong to a function variant.
            return nullptr;  // Just clone the expression.
        }

        auto* expr = sem.GetVal(ast_expr);
        if (!expr) {
            // No semantic node for the expression.
            return nullptr;  // Just clone the expression.
        }

        // If the expression has been hoisted to a 'let', then replace the expression with an
        // identifier to the hoisted let.
        if (auto hoisted = clone_state->current_function->hoisted_exprs.Get(expr)) {
            return b.Expr(*hoisted);
        }

        auto* chain = AccessChainFor(expr);
        if (!chain) {
            // The expression does not have an AccessChain.
            return nullptr;  // Just clone the expression.
        }

        auto* root_param = chain->root.variable->As<sem::Parameter>();
        if (!root_param) {
            // The expression has an access chain, but does not originate with a pointer
            // parameter. We don't need to change anything here.
            return nullptr;  // Just clone the expression.
        }

        auto incoming_shape = clone_state->current_variant_sig->Get(root_param);
        if (!incoming_shape) {
            // The root parameter of the access chain is not part of the variant's signature.
            return nullptr;  // Just clone the expression.
        }

        // Expression holds an access chain to a pointer parameter that needs transforming.
        // Reconstruct the expression using the variant's incoming shape.

        auto* chain_expr = BuildAccessRootExpr(incoming_shape->root, /* deref */ true);

        // Chain starts with a pointer parameter.
        // Replace this with the variant's incoming shape. This will bring the expression up to
        // the incoming pointer. Each dynamic access in the incoming shape consumes the next
        // element of the variant's indices parameter, in order.
        size_t next_dyn_idx_from_indices = 0;
        auto& indices =
            clone_state->current_variant->ptr_param_symbols.Get(root_param)->indices;
        for (auto param_access : incoming_shape->ops) {
            chain_expr = BuildAccessExpr(chain_expr, param_access, [&] {
                return b.IndexAccessor(indices, AInt(next_dyn_idx_from_indices++));
            });
        }

        // Now build the expression chain within the function.

        // For each access in the chain (excluding the pointer parameter)...
        // Each dynamic access consumes the next entry of chain->dynamic_indices, in order.
        size_t next_dyn_idx_from_chain = 0;
        for (auto& op : chain->ops) {
            chain_expr = BuildAccessExpr(chain_expr, op, [&] {
                return BuildDynamicIndex(chain->dynamic_indices[next_dyn_idx_from_chain++],
                                         /* cast_to_u32 */ false);
            });
        }

        // BuildAccessExpr() always returns a non-pointer.
        // If the expression we're replacing is a pointer, take the address.
        if (expr->Type()->Is<core::type::Pointer>()) {
            chain_expr = b.AddressOf(chain_expr);
        }

        return chain_expr;
    });
}
| |
| /// @returns the FnInfo for the given function, constructing a new FnInfo if @p fn doesn't |
| /// already have one. |
| FnInfo* FnInfoFor(const sem::Function* fn) { |
| return fns.GetOrAdd(fn, [this] { return fn_info_allocator.Create(); }); |
| } |
| |
| /// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias |
| /// if this is the first call for the given shape. |
| Type DynamicIndexArrayType(const AccessShape& shape) { |
| auto name = dynamic_index_array_aliases.GetOrAdd(shape, [&] { |
| // Count the number of dynamic indices |
| uint32_t num_dyn_indices = shape.NumDynamicIndices(); |
| if (num_dyn_indices == 0) { |
| return Symbol{}; |
| } |
| auto symbol = b.Symbols().New(AccessShapeName(shape)); |
| b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices))); |
| return symbol; |
| }); |
| return name.IsValid() ? b.ty(name) : Type{}; |
| } |
| |
| /// @returns a name describing the given shape |
| std::string AccessShapeName(const AccessShape& shape) { |
| StringStream ss; |
| |
| if (IsPrivateOrFunction(shape.root.address_space)) { |
| ss << "F"; |
| } else { |
| ss << shape.root.variable->Declaration()->name->symbol.Name(); |
| } |
| |
| for (auto& op : shape.ops) { |
| ss << "_"; |
| |
| if (std::holds_alternative<DynamicIndex>(op)) { |
| /// The op uses a dynamic (runtime-expression) index. |
| ss << "X"; |
| continue; |
| } |
| |
| auto* member = std::get_if<Symbol>(&op); |
| if (TINT_LIKELY(member)) { |
| ss << member->Name(); |
| continue; |
| } |
| |
| TINT_ICE() << "unhandled variant for access chain"; |
| } |
| return ss.str(); |
| } |
| |
| /// Builds an expresion to the root of an access, returning the new expression. |
| /// @param root the AccessRoot |
| /// @param deref if true, the returned expression will always be a reference type. |
| const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) { |
| if (auto* param = root.variable->As<sem::Parameter>()) { |
| if (auto symbols = clone_state->current_variant->ptr_param_symbols.Get(param)) { |
| if (deref) { |
| return b.Deref(b.Expr(symbols->base_ptr)); |
| } |
| return b.Expr(symbols->base_ptr); |
| } |
| } |
| |
| const Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol)); |
| if (deref) { |
| if (root.variable->Type()->Is<core::type::Pointer>()) { |
| expr = b.Deref(expr); |
| } |
| } |
| return expr; |
| } |
| |
| /// Builds a single access in an access chain, returning the new expression. |
| /// The returned expression will always be of a reference type. |
| /// @param expr the input expression |
| /// @param access the access to perform on the current expression |
| /// @param dynamic_index a function that obtains the next dynamic index |
| const Expression* BuildAccessExpr(const Expression* expr, |
| const AccessOp& access, |
| std::function<const Expression*()> dynamic_index) { |
| if (std::holds_alternative<DynamicIndex>(access)) { |
| /// The access uses a dynamic (runtime-expression) index. |
| auto* idx = dynamic_index(); |
| return b.IndexAccessor(expr, idx); |
| } |
| |
| auto* member = std::get_if<Symbol>(&access); |
| if (TINT_LIKELY(member)) { |
| /// The access is a member access. |
| return b.MemberAccessor(expr, ctx.Clone(*member)); |
| } |
| |
| TINT_ICE() << "unhandled variant type for access chain"; |
| } |
| |
| /// @returns a new Symbol starting with @p symbol concatenated with @p suffix, and possibly an |
| /// underscore and number, if the symbol is already taken. |
| Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) { |
| auto str = symbol.Name() + suffix; |
| return unique_symbols.GetOrAdd(str, [&] { return b.Symbols().New(str); }); |
| } |
| |
| /// @returns true if the function @p fn has at least one pointer parameter. |
| static bool HasPointerParameter(const sem::Function* fn) { |
| for (auto* param : fn->Parameters()) { |
| if (param->Type()->Is<core::type::Pointer>()) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| /// @returns true if the function @p fn has at least one pointer parameter in an address space |
| /// that must be replaced. If this function is not called, then the function cannot be sensibly |
| /// generated, and must be stripped. |
| static bool MustBeCalled(const sem::Function* fn) { |
| for (auto* param : fn->Parameters()) { |
| if (auto* ptr = param->Type()->As<core::type::Pointer>()) { |
| switch (ptr->AddressSpace()) { |
| case core::AddressSpace::kUniform: |
| case core::AddressSpace::kStorage: |
| case core::AddressSpace::kWorkgroup: |
| return true; |
| default: |
| return false; |
| } |
| } |
| } |
| return false; |
| } |
| |
| /// @returns true if the given address space is 'private' or 'function'. |
| static bool IsPrivateOrFunction(const core::AddressSpace sc) { |
| return sc == core::AddressSpace::kPrivate || sc == core::AddressSpace::kFunction; |
| } |
| }; |
| |
// Defaulted constructor for the transform's input configuration.
DirectVariableAccess::Config::Config() = default;
// Constructs a configuration holding a copy of @p opt.
DirectVariableAccess::Config::Config(const Options& opt) : options(opt) {}

// Defaulted destructor.
DirectVariableAccess::Config::~Config() = default;

// Defaulted constructor for the transform.
DirectVariableAccess::DirectVariableAccess() = default;

// Defaulted destructor for the transform.
DirectVariableAccess::~DirectVariableAccess() = default;
| |
| Transform::ApplyResult DirectVariableAccess::Apply(const Program& program, |
| const DataMap& inputs, |
| DataMap&) const { |
| Options options; |
| if (auto* cfg = inputs.Get<Config>()) { |
| options = cfg->options; |
| } |
| return State(program, options).Run(); |
| } |
| |
| } // namespace tint::ast::transform |