// Copyright 2022 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "src/tint/lang/wgsl/ast/transform/direct_variable_access.h"

#include <algorithm>
#include <string>
#include <utility>

#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/type/abstract_int.h"
#include "src/tint/lang/wgsl/ast/transform/hoist_to_decl_before.h"
#include "src/tint/lang/wgsl/ast/traverse_expressions.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/call.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/index_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/module.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/struct.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/containers/reverse.h"
#include "src/tint/utils/macros/scoped_assignment.h"
#include "src/tint/utils/text/string_stream.h"

TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::DirectVariableAccess::Config);

using namespace tint::core::number_suffixes;  // NOLINT
using namespace tint::core::fluent_types;     // NOLINT

namespace tint::ast::transform {

namespace {

/// AccessRoot describes the root of an AccessShape.
struct AccessRoot {
    /// The type of the *transformed* variable, with any pointer unwrapped.
    /// For pointers in the 'private' and 'function' address spaces this may differ from the
    /// input pointer type, because the pointer parameter type is to the *base object*.
    const tint::core::type::Type* type = nullptr;
    /// The variable in the source program that the access originates from: a module-scope
    /// variable ('private', 'storage', 'uniform', 'workgroup'), a function-scope variable
    /// ('function'), or a pointer parameter.
    const tint::sem::Variable* variable = nullptr;
    /// The address space of the variable or pointer type.
    tint::core::AddressSpace address_space = tint::core::AddressSpace::kUndefined;

    /// @return a hash code for this object
    /// @note Only #type and #variable contribute to the hash; #address_space does not.
    size_t HashCode() const {
        return Hash(type, variable);
    }
};

/// Inequality operator for AccessRoot
bool operator!=(const AccessRoot& a, const AccessRoot& b) {
    // Two roots match only when both the type and the variable are the same object.
    // The address space is not part of the comparison.
    const bool same_root = a.type == b.type && a.variable == b.variable;
    return !same_root;
}

/// DynamicIndex is used by DirectVariableAccess::State::AccessOp to indicate an array, matrix or
/// vector index.
struct DynamicIndex {
    /// @return a hash code for this object
    size_t HashCode() const {
        // DynamicIndex has no members, so every instance hashes to the same value.
        // The value itself is arbitrary.
        constexpr size_t kArbitraryHash = 42;
        return kArbitraryHash;
    }
};

/// Inequality operator for DynamicIndex
bool operator!=(const DynamicIndex&, const DynamicIndex&) {
    // DynamicIndex is an empty struct, so two DynamicIndex objects are always equal and this
    // operator never reports a difference.
    return false;
}

/// AccessOp describes a single access in an access chain.
/// The access is one of:
/// * tint::Symbol  - a struct member access, holding the name of the accessed member.
/// * DynamicIndex  - an index on an array, matrix column, or vector element. The index
///                   expression itself is not stored in the DynamicIndex.
using AccessOp = std::variant<tint::Symbol, DynamicIndex>;

/// A root variable and a vector of AccessOp. Describes the static "path" from the root
/// variable to an element within the variable. Array accessor index expressions are held
/// externally to the AccessShape, so two AccessShapes are considered equal even if their
/// array, matrix or vector index values differ.
///
/// For example, consider the following:
///
/// ```
///           struct A {
///               x : array<i32, 8>,
///               y : u32,
///           };
///           struct B {
///               x : i32,
///               y : array<A, 4>
///           };
///           var<workgroup> C : B;
/// ```
///
/// The table below lists some AccessShapes rooted at `C` and the expression each describes:
///
/// +==============================+===============+=================================+
/// | AccessShape                  | Type          |  Expression                     |
/// +==============================+===============+=================================+
/// | [ Variable 'C', Symbol 'x' ] | i32           |  C.x                            |
/// +------------------------------+---------------+---------------------------------+
/// | [ Variable 'C', Symbol 'y' ] | array<A, 4>   |  C.y                            |
/// +------------------------------+---------------+---------------------------------+
/// | [ Variable 'C', Symbol 'y',  | A             |  C.y[dyn_idx[0]]                |
/// |   DynamicIndex ]             |               |                                 |
/// +------------------------------+---------------+---------------------------------+
/// | [ Variable 'C', Symbol 'y',  | array<i32, 8> |  C.y[dyn_idx[0]].x              |
/// |   DynamicIndex, Symbol 'x' ] |               |                                 |
/// +------------------------------+---------------+---------------------------------+
/// | [ Variable 'C', Symbol 'y',  | i32           |  C.y[dyn_idx[0]].x[dyn_idx[1]]  |
/// |   DynamicIndex, Symbol 'x',  |               |                                 |
/// |   DynamicIndex ]             |               |                                 |
/// +------------------------------+---------------+---------------------------------+
/// | [ Variable 'C', Symbol 'y',  | u32           |  C.y[dyn_idx[0]].y              |
/// |   DynamicIndex, Symbol 'y' ] |               |                                 |
/// +------------------------------+---------------+---------------------------------+
///
/// Where: `dyn_idx` is the AccessChain::dynamic_indices.
struct AccessShape {
    /// The originating variable.
    AccessRoot root;
    /// The chain of access ops, starting at #root.
    tint::Vector<AccessOp, 8> ops;

    /// @returns the number of DynamicIndex operations in #ops.
    uint32_t NumDynamicIndices() const {
        uint32_t num_dynamic = 0;
        for (const auto& access : ops) {
            // Struct member accesses (Symbol) are not counted.
            num_dynamic += std::holds_alternative<DynamicIndex>(access) ? 1u : 0u;
        }
        return num_dynamic;
    }

    /// @return a hash code for this object
    size_t HashCode() const {
        return Hash(root, ops);
    }
};

/// Equality operator for AccessShape
bool operator==(const AccessShape& a, const AccessShape& b) {
    // The roots are compared through AccessRoot's inequality operator.
    if (a.root != b.root) {
        return false;
    }
    // Same root: the shapes are equal when the access ops match.
    return a.ops == b.ops;
}

/// Inequality operator for AccessShape
bool operator!=(const AccessShape& a, const AccessShape& b) {
    // Defined as the negation of operator== so that the two operators cannot disagree.
    if (a == b) {
        return false;
    }
    return true;
}

/// AccessChain describes a chain of access expressions originating from a variable.
struct AccessChain : AccessShape {
    /// The array accessor index expressions. This vector is indexed by the `DynamicIndex`s in
    /// #indices.
    Vector<const sem::ValueExpression*, 8> dynamic_indices;
    /// If true, then this access chain is used as an argument to call a variant.
    bool used_in_call = false;
};

}  // namespace

/// The PIMPL state for the DirectVariableAccess transform
struct DirectVariableAccess::State {
    /// Constructor
    /// @param src the source Program
    /// @param options the transform options
    State(const Program& src, const Options& options)
        // `ctx` is constructed over the builder `b` and the source program `src`.
        // `opts` is a reference member, so `options` is referenced rather than copied and must
        // outlive this State.
        : ctx{&b, &src, /* auto_clone_symbols */ true}, opts(options) {}

    /// The main function for the transform.
    /// @returns the ApplyResult
    ApplyResult Run() {
        // Without any pointer parameters in the program there is nothing to transform.
        if (!AnyPointerParameters()) {
            return SkipTransform;
        }

        // Stage 1: collect access chains.
        // Visit every expression of the program, leaves first. An identifier that resolves to a
        // var, a pointer parameter, or a pointer let to another chain starts a new access chain.
        // Accessing a chain grows it and moves it up the expression tree. Once this loop is
        // done, we hold every expression access chain to a variable that may need transforming.
        for (auto* ast_node : ctx.src->ASTNodes().Objects()) {
            auto* value_expr = sem.GetVal(ast_node);
            if (value_expr != nullptr) {
                AppendAccessChain(value_expr);
            }
        }

        // Stage 2: build the function variants.
        // Functions are visited in dependency order, entry points first. The calls made by each
        // function determine the 'variants' of the call targets. A variant maps pointer
        // parameters to access chains, and has its pointer parameters replaced with an array of
        // u32s that is used to perform the pointer indexing inside the variant.
        // Pointer arguments of function calls are replaced with an array of these dynamic
        // indices.
        auto ordered_decls = sem.Module()->DependencyOrderedDeclarations();
        for (auto* global_decl : tint::Reverse(ordered_decls)) {
            auto* func = sem.Get<sem::Function>(global_decl);
            if (!func) {
                continue;  // Not a function declaration.
            }
            auto* func_info = FnInfoFor(func);
            ProcessFunction(func, func_info);
            TransformFunction(func, func_info);
        }

        // Stage 3: filter and prepare the access chains.
        // Drops the chains that do not need transforming, and makes sure that the dynamic index
        // expressions of the remaining chains are evaluated once, at the correct place.
        ProcessAccessChains();

        // Stage 4: rewrite the access chain expressions.
        // Every access chain expression in every function is replaced with an expression that is
        // rebuilt from the originating global variable and any dynamic indices passed in to the
        // function variant.
        TransformAccessChainExpressions();

        // Stage 5: run the clone.
        // clone_state points at a local for the duration of the clone.
        CloneState live_state;
        clone_state = &live_state;
        ctx.Clone();
        return resolver::Resolve(b);
    }

  private:
    /// Holds symbols of the transformed pointer parameter.
    /// If both symbols are valid, then #base_ptr and #indices are both program-unique symbols
    /// derived from the original parameter name.
    /// If only one symbol is valid, then this is the original parameter symbol.
    struct PtrParamSymbols {
        /// The symbol of the base pointer parameter.
        /// Left default-constructed when the variant does not take a base pointer parameter.
        Symbol base_ptr;
        /// The symbol of the dynamic indices parameter.
        /// Left default-constructed when the variant does not take a dynamic indices parameter.
        Symbol indices;
    };

    /// FnVariant describes a unique variant of a function, specialized by the AccessShape of the
    /// pointer arguments - also known as the variant's "signature".
    ///
    /// To help understand what a variant is, consider the following WGSL:
    ///
    /// ```
    /// fn F(a : ptr<storage, u32>, b : u32, c : ptr<storage, u32>) {
    ///    return *a + b + *c;
    /// }
    ///
    /// @group(0) @binding(0) var<storage> S0 : u32;
    /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
    ///
    /// fn x() {
    ///    F(&S0, 0, &S0);       // (A)
    ///    F(&S0, 0, &S0);       // (B)
    ///    F(&S1[0], 1, &S0);    // (C)
    ///    F(&S1[5], 2, &S0);    // (D)
    ///    F(&S1[5], 3, &S1[3]); // (E)
    ///    F(&S1[7], 4, &S1[2]); // (F)
    /// }
    /// ```
    ///
    /// Given the calls in x(), function F() will have 3 variants:
    /// (1) F<S0,S0>                   - called by (A) and (B).
    ///                                  Note that only 'uniform', 'storage' and 'workgroup' pointer
    ///                                  parameters are considered for a variant signature, and so
    ///                                  the argument for parameter 'b' is not included in the
    ///                                  signature.
    /// (2) F<S1[dyn_idx],S0>          - called by (C) and (D).
    ///                                  Note that the array index value is external to the
    ///                                  AccessShape, and so is not part of the variant signature.
    /// (3) F<S1[dyn_idx],S1[dyn_idx]> - called by (E) and (F).
    ///
    /// Each variant of the function will be emitted as a separate function by the transform, and
    /// would look something like:
    ///
    /// ```
    /// // variant F<S0,S0> (1)
    /// fn F_S0_S0(b : u32) {
    ///    return S0 + b + S0;
    /// }
    ///
    /// type S1_X = array<u32, 1>;
    ///
    /// // variant F<S1[dyn_idx],S0> (2)
    /// fn F_S1_X_S0(a : S1_X, b : u32) {
    ///    return S1[a[0]] + b + S0;
    /// }
    ///
    /// // variant F<S1[dyn_idx],S1[dyn_idx]> (3)
    /// fn F_S1_X_S1_X(a : S1_X, b : u32, c : S1_X) {
    ///    return S1[a[0]] + b + S1[c[0]];
    /// }
    ///
    /// @group(0) @binding(0) var<storage> S0 : u32;
    /// @group(0) @binding(0) var<storage> S1 : array<u32, 64>;
    ///
    /// fn x() {
    ///    F_S0_S0(0);                        // (A)
    ///    F_S0_S0(0);                        // (B)
    ///    F_S1_X_S0(S1_X(0), 1);             // (C)
    ///    F_S1_X_S0(S1_X(5), 2);             // (D)
    ///    F_S1_X_S1_X(S1_X(5), 3, S1_X(3));  // (E)
    ///    F_S1_X_S1_X(S1_X(7), 4, S1_X(2));  // (F)
    /// }
    /// ```
    struct FnVariant {
        /// The signature of the variant is a map of each of the function's 'uniform', 'storage' and
        /// 'workgroup' pointer parameters to the caller's AccessShape.
        /// An empty signature describes a variant that needs no argument changes.
        using Signature = Hashmap<const sem::Parameter*, AccessShape, 4>;

        /// The unique name of the variant.
        /// The symbol is in the `ctx.dst` program namespace.
        /// For a variant with an empty signature, this is the clone of the original function name.
        Symbol name;

        /// A map of direct calls made by this variant to the name of other function variants.
        Hashmap<const sem::Call*, Symbol, 4> calls;

        /// A map of input program parameter to output parameter symbols.
        Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;

        /// The declaration order of the variant, in relation to other variants of the same
        /// function. Used to ensure deterministic ordering of the transform, as map iteration is
        /// not deterministic between compilers.
        /// The unaltered (empty signature) variant uses order 0, so it comes first.
        size_t order = 0;
    };

    /// FnInfo holds information about a function in the input program.
    struct FnInfo {
        /// The variants of the function, keyed by variant signature.
        Hashmap<FnVariant::Signature, FnVariant, 8> variants;
        /// The expressions of the function that have been hoisted to a 'let' declaration, mapped
        /// to the name of that 'let'.
        Hashmap<const sem::ValueExpression*, Symbol, 8> hoisted_exprs;

        /// @returns the variants of the function as (signature, variant) pointer pairs, sorted by
        /// ascending FnVariant::order so that the result is deterministic.
        tint::Vector<std::pair<const FnVariant::Signature*, FnVariant*>, 8> SortedVariants() {
            using Entry = std::pair<const FnVariant::Signature*, FnVariant*>;

            tint::Vector<Entry, 8> sorted;
            sorted.Reserve(variants.Count());
            // NOTE(review): the entry is taken by value, as before. Presumably it only refers to
            // the key and value stored in the map, so the addresses taken here point into
            // #variants - confirm against Hashmap.
            for (auto entry : variants) {
                sorted.Push(Entry{&entry.key, &entry.value});
            }
            sorted.Sort([](const Entry& lhs, const Entry& rhs) {
                return lhs.second->order < rhs.second->order;
            });
            return sorted;
        }
    };

    /// The program builder
    /// The program builder
    /// Declared before #ctx, which is constructed with a pointer to it.
    ProgramBuilder b;
    /// The clone context
    program::CloneContext ctx;
    /// The transform options
    /// Held by reference: the options object passed to the constructor is not copied.
    const Options& opts;
    /// Alias to the semantic info in ctx.src
    /// Declared after #ctx, as the initializer reads ctx.src.
    const sem::Info& sem = ctx.src->Sem();
    /// Alias to the symbols in ctx.src
    /// Declared after #ctx, as the initializer reads ctx.src.
    const SymbolTable& sym = ctx.src->Symbols();
    /// Map of semantic function to the function info
    Hashmap<const sem::Function*, FnInfo*, 8> fns;
    /// Map of AccessShape to the name of a type alias for the array<u32, N> used for the
    /// dynamic indices of an access chain, passed down as the transformed type of a variant's
    /// pointer parameter.
    Hashmap<AccessShape, Symbol, 8> dynamic_index_array_aliases;
    /// Map of semantic expression to AccessChain
    Hashmap<const sem::ValueExpression*, AccessChain*, 32> access_chains;
    /// Allocator for FnInfo
    BlockAllocator<FnInfo> fn_info_allocator;
    /// Allocator for AccessChain
    BlockAllocator<AccessChain> access_chain_allocator;
    /// Helper used for hoisting expressions to lets
    /// Declared after #ctx, which it is constructed from.
    HoistToDeclBefore hoist{ctx};
    /// Map of string to unique symbol (no collisions in output program).
    Hashmap<std::string, Symbol, 8> unique_symbols;

    /// CloneState holds pointers to the current function, variant and variant's parameters.
    struct CloneState {
        /// The current function being cloned, or nullptr if none has been set.
        FnInfo* current_function = nullptr;
        /// The current function variant being built, or nullptr if none has been set.
        FnVariant* current_variant = nullptr;
        /// The signature of the current function variant being built, or nullptr if none has
        /// been set.
        const FnVariant::Signature* current_variant_sig = nullptr;
    };

    /// The clone state.
    /// Null until Run() points it at a CloneState local, immediately before calling
    /// program::CloneContext::Clone(). Run() does not reset it afterwards, so it is
    /// only valid during the lifetime of the program::CloneContext::Clone().
    CloneState* clone_state = nullptr;

    /// @returns true if any user functions have parameters of a pointer type.
    bool AnyPointerParameters() const {
        // Scan the parameters of every function declared in the source program, and stop at the
        // first parameter that has a pointer type.
        for (auto* func : ctx.src->AST().Functions()) {
            for (auto* parameter : func->params) {
                const auto* parameter_type = sem.Get(parameter)->Type();
                if (parameter_type->Is<core::type::Pointer>()) {
                    return true;
                }
            }
        }
        // No function takes a pointer.
        return false;
    }

    /// AppendAccessChain creates or extends an existing AccessChain for the given expression,
    /// modifying the #access_chains map.
    void AppendAccessChain(const sem::ValueExpression* expr) {
        // Run() calls this for every AST node that has a semantic value expression. Each call
        // either starts a chain at `expr`, moves an existing chain from a sub-expression up to
        // `expr` (optionally extending it), or does nothing.

        // take_chain moves the AccessChain from the expression `from` to the expression `expr`.
        // The AccessChain object itself is reused; only its key in #access_chains changes.
        // Returns nullptr if `from` did not hold an access chain.
        auto take_chain = [&](const sem::ValueExpression* from) -> AccessChain* {
            if (auto* chain = AccessChainFor(from)) {
                access_chains.Remove(from);
                access_chains.Add(expr, chain);
                return chain;
            }
            return nullptr;
        };

        // NOTE(review): the final case below takes the base sem::ValueExpression type, while the
        // earlier cases take more specific expression types. Presumably Switch tries the cases in
        // the order written, so the generic case must stay last - confirm against Switch.
        Switch(
            expr,
            [&](const sem::VariableUser* user) {
                // Expression resolves to a variable.
                auto* variable = user->Variable();

                // create_new_chain registers a fresh, op-less AccessChain for `expr`, rooted at
                // `variable`.
                auto create_new_chain = [&] {
                    auto* chain = access_chain_allocator.Create();
                    chain->root.variable = variable;
                    chain->root.type = variable->Type();
                    chain->root.address_space = variable->AddressSpace();
                    // If the variable's type is a pointer, the pointer's address space replaces
                    // the address space reported by the variable.
                    if (auto* ptr = chain->root.type->As<core::type::Pointer>()) {
                        chain->root.address_space = ptr->AddressSpace();
                    }
                    access_chains.Add(expr, chain);
                };

                Switch(
                    variable->Declaration(),
                    [&](const Var*) {
                        if (variable->AddressSpace() != core::AddressSpace::kHandle) {
                            // Start a new access chain for the non-handle 'var' access
                            create_new_chain();
                        }
                    },
                    [&](const Parameter*) {
                        if (variable->Type()->Is<core::type::Pointer>()) {
                            // Start a new access chain for the pointer parameter access
                            create_new_chain();
                        }
                    },
                    [&](const Let*) {
                        if (variable->Type()->Is<core::type::Pointer>()) {
                            // variable is a pointer-let.
                            auto* init = sem.GetVal(variable->Declaration()->initializer);
                            // Note: We do not use take_chain() here, as we need to preserve the
                            // AccessChain on the let's initializer, as the let needs its
                            // initializer updated, and the let may be used multiple times. Instead
                            // we copy the let's AccessChain into a new AccessChain.
                            if (auto* init_chain = AccessChainFor(init)) {
                                access_chains.Add(expr, access_chain_allocator.Create(*init_chain));
                            }
                        }
                    });
            },
            [&](const sem::StructMemberAccess* a) {
                // Structure member access.
                // Append the Symbol of the member name to the chain, and move the chain to the
                // member access expression.
                if (auto* chain = take_chain(a->Object())) {
                    chain->ops.Push(a->Member()->Name());
                }
            },
            [&](const sem::IndexAccessorExpression* a) {
                // Array, matrix or vector index.
                // Store the index expression into AccessChain::dynamic_indices, append a
                // DynamicIndex to the chain, and move the chain to the index accessor expression.
                // The two vectors are pushed together, so each DynamicIndex in `ops` has exactly
                // one matching entry in `dynamic_indices`.
                if (auto* chain = take_chain(a->Object())) {
                    chain->ops.Push(DynamicIndex{});
                    chain->dynamic_indices.Push(a->Index());
                }
            },
            [&](const sem::ValueExpression* e) {
                if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
                    // Unary op.
                    // If this is a '&' or '*', simply move the chain to the unary op expression.
                    // The chain is moved without adding any op to it.
                    if (unary->op == core::UnaryOp::kAddressOf ||
                        unary->op == core::UnaryOp::kIndirection) {
                        take_chain(sem.GetVal(unary->expr));
                    }
                }
            });
    }

    /// MaybeHoistDynamicIndices examines the AccessChain::dynamic_indices member of @p chain,
    /// hoisting all expressions to their own uniquely named 'let' if none of the following are
    /// true:
    /// 1. The index expression is a constant value.
    /// 2. The index expression's statement is the same as @p usage.
    /// 3. The index expression is an identifier resolving to a 'let', 'const' or parameter, AND
    ///    that identifier resolves to the same variable at @p usage.
    ///
    /// A dynamic index will only be hoisted once. The hoisting applies to all variants of the
    /// function that holds the dynamic index expression.
    void MaybeHoistDynamicIndices(AccessChain* chain, const sem::Statement* usage) {
        for (const sem::ValueExpression* index : chain->dynamic_indices) {
            // A constant index does not need hoisting.
            if (index->ConstantValue()) {
                continue;
            }

            // An index that belongs to the statement of usage does not need hoisting.
            if (index->Stmt() == usage) {
                continue;
            }

            // An index that is an identifier of a 'let' or a parameter is immutable, and does
            // not need hoisting.
            auto* variable_user = index->UnwrapMaterialize()->As<sem::VariableUser>();
            if (variable_user &&
                variable_user->Variable()->Declaration()->IsAnyOf<Let, Parameter>()) {
                continue;
            }

            // Hoist the index, unless the owning function has already hoisted this expression.
            auto* owner_info = FnInfoFor(index->Stmt()->Function());
            owner_info->hoisted_exprs.GetOrCreate(index, [this, index] {
                // Pick the name of the new 'let'.
                auto let_name = b.Symbols().New("ptr_index_save");
                // Declare the 'let' immediately before the statement that holds the index.
                hoist.InsertBefore(index->Stmt(), [this, index, let_name] {
                    auto* initializer = ctx.CloneWithoutTransform(index->Declaration());
                    return b.Decl(b.Let(let_name, initializer));
                });
                return let_name;
            });
        }
    }

    /// BuildDynamicIndex builds the AST expression node for the dynamic index expression used in an
    /// AccessChain. This is similar to just cloning the expression, but BuildDynamicIndex()
    /// also:
    /// * Collapses constant value index expressions down to the computed value. This acts as an
    ///   constant folding optimization and reduces noise from the transform.
    /// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
    ///   isn't implicitly usable as a u32. This is to help feed the expression into a
    ///   `array<u32, N>` argument passed to a callee variant function.
    const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
        // A constant index is emitted as its computed value, instead of cloning the expression.
        if (auto* constant = idx->ConstantValue()) {
            return b.Expr(constant->ValueAs<AInt>());
        }

        // A non-constant index is cloned.
        // Note: For an index that was hoisted to a let, the clone is an identifier expression
        // that refers to the hoisted let.
        const Expression* index_expr = ctx.Clone(idx->Declaration());

        // The index may end up in an array<u32, N> argument of dynamic indices. When requested,
        // wrap the index in a u32 conversion, unless its type is already a u32 or an
        // abstract-int.
        if (cast_to_u32 && !idx->UnwrapMaterialize()
                                ->Type()
                                ->UnwrapRef()
                                ->IsAnyOf<core::type::U32, core::type::AbstractInt>()) {
            index_expr = b.Call<u32>(index_expr);
        }

        return index_expr;
    }

    /// ProcessFunction scans the direct calls made by the function @p fn, adding new variants to
    /// the callee functions and transforming the call expression to pass dynamic indices instead of
    /// true pointers.
    /// If the function @p fn has pointer parameters that must be transformed to a caller variant,
    /// and the function is not called, then the function is dropped from the output of the
    /// transform, as it cannot be generated.
    /// @note ProcessFunction must be called in dependency order for the program, starting with the
    /// entry points.
    void ProcessFunction(const sem::Function* fn, FnInfo* fn_info) {
        const bool has_caller_variants = !fn_info->variants.IsEmpty();
        if (!has_caller_variants) {
            // No caller has pre-generated a variant for this function.
            if (MustBeCalled(fn)) {
                // The function can only be generated from a call, and there was none.
                // Remove it from the output and skip its calls.
                ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn->Declaration());
                return;
            }

            // The function was not called. Give it one variant with an empty signature, which
            // keeps the original name.
            FnVariant unaltered;
            unaltered.name = ctx.Clone(fn->Declaration()->name->symbol);
            unaltered.order = 0;  // Unaltered comes first.
            fn_info->variants.Add(FnVariant::Signature{}, std::move(unaltered));
        }

        // Handle every direct call that this function makes.
        for (auto* direct_call : fn->DirectCalls()) {
            ProcessCall(fn_info, direct_call);
        }
    }

    /// ProcessCall creates new variants of the callee function by permuting the call for each of
    /// the variants of @p caller. ProcessCall also registers the clone callback to transform the
    /// call expression to pass dynamic indices instead of true pointers.
    void ProcessCall(FnInfo* caller, const sem::Call* call) {
        auto* target = call->Target()->As<sem::Function>();
        if (!target) {
            // Call target is not a user-declared function.
            return;  // Not interested in this call.
        }

        if (!HasPointerParameter(target)) {
            return;  // Not interested in this call.
        }

        bool call_needs_transforming = false;

        // Build the call target function variant for each variant of the caller.
        for (auto caller_variant_it : caller->SortedVariants()) {
            auto& caller_signature = *caller_variant_it.first;
            auto& caller_variant = *caller_variant_it.second;

            // Build the target variant's signature.
            FnVariant::Signature target_signature;
            for (size_t i = 0; i < call->Arguments().Length(); i++) {
                const auto* arg = call->Arguments()[i];
                const auto* param = target->Parameters()[i];
                const auto* param_ty = param->Type()->As<core::type::Pointer>();
                if (!param_ty) {
                    continue;  // Parameter type is not a pointer.
                }

                // Fetch the access chain for the argument.
                auto* arg_chain = AccessChainFor(arg);
                if (!arg_chain) {
                    continue;  // Argument does not have an access chain
                }

                // Construct the absolute AccessShape by considering the AccessShape of the caller
                // variant's argument. This will propagate back through pointer parameters, to the
                // outermost caller.
                auto absolute = AbsoluteAccessShape(caller_signature, *arg_chain);

                // If the address space of the root variable of the access chain does not require
                // transformation, then there's nothing to do.
                if (!AddressSpaceRequiresTransform(absolute.root.address_space)) {
                    continue;
                }

                // Record that this chain was used in a function call.
                // This preserves the chain during the access chain filtering stage.
                arg_chain->used_in_call = true;

                if (IsPrivateOrFunction(absolute.root.address_space)) {
                    // Pointers in 'private' and 'function' address spaces need to be passed by
                    // pointer argument.
                    absolute.root.variable = param;
                }

                // Add the parameter's absolute AccessShape to the target's signature.
                target_signature.Add(param, std::move(absolute));
            }

            // Construct a new FnVariant if this is the first caller of the target signature
            auto* target_info = FnInfoFor(target);
            auto& target_variant = target_info->variants.GetOrCreate(target_signature, [&] {
                if (target_signature.IsEmpty()) {
                    // Call target does not require any argument changes.
                    FnVariant variant;
                    variant.name = ctx.Clone(target->Declaration()->name->symbol);
                    variant.order = 0;  // Unaltered comes first.
                    return variant;
                }

                // Build an appropriate variant function name.
                // This is derived from the original function name and the pointer parameter
                // chains.
                StringStream ss;
                ss << target->Declaration()->name->symbol.Name();
                for (auto* param : target->Parameters()) {
                    if (auto indices = target_signature.Find(param)) {
                        ss << "_" << AccessShapeName(*indices);
                    }
                }

                // Build the pointer parameter symbols.
                Hashmap<const sem::Parameter*, PtrParamSymbols, 4> ptr_param_symbols;
                for (auto param_it : target_signature) {
                    auto* param = param_it.key;
                    auto& shape = param_it.value;

                    // Parameter needs replacing with either zero, one or two parameters:
                    // If the parameter is in the 'private' or 'function' address space, then the
                    // originating pointer is always passed down. This always comes first.
                    // If the access chain has dynamic indices, then we create an array<u32, N>
                    // parameter to hold the dynamic indices.
                    bool requires_base_ptr_param = IsPrivateOrFunction(shape.root.address_space);
                    bool requires_indices_param = shape.NumDynamicIndices() > 0;

                    PtrParamSymbols symbols;
                    if (requires_base_ptr_param && requires_indices_param) {
                        auto original_name = param->Declaration()->name->symbol;
                        symbols.base_ptr = UniqueSymbolWithSuffix(original_name, "_base");
                        symbols.indices = UniqueSymbolWithSuffix(original_name, "_indices");
                    } else if (requires_base_ptr_param) {
                        symbols.base_ptr = ctx.Clone(param->Declaration()->name->symbol);
                    } else if (requires_indices_param) {
                        symbols.indices = ctx.Clone(param->Declaration()->name->symbol);
                    }

                    // Remember this base pointer name.
                    ptr_param_symbols.Add(param, symbols);
                }

                // Build the variant.
                FnVariant variant;
                variant.name = b.Symbols().New(ss.str());
                variant.order = target_info->variants.Count() + 1;
                variant.ptr_param_symbols = std::move(ptr_param_symbols);
                return variant;
            });

            // Record the call made by caller variant to the target variant.
            caller_variant.calls.Add(call, target_variant.name);
            if (!target_signature.IsEmpty()) {
                // The call expression will need transforming for at least one caller variant.
                call_needs_transforming = true;
            }
        }

        if (call_needs_transforming) {
            // Register the clone callback to correctly transform the call expression into the
            // appropriate variant calls.
            TransformCall(call);
        }
    }

    /// Decides whether pointers into @p address_space are rewritten by this transform.
    /// 'uniform', 'storage' and 'workgroup' are always transformed; 'private' and 'function'
    /// are transformed only when the corresponding option is enabled.
    /// @param address_space the address space to query
    /// @returns true if the address space requires transforming under the current options
    bool AddressSpaceRequiresTransform(core::AddressSpace address_space) const {
        // The two opt-in address spaces are controlled by the transform's options.
        if (address_space == core::AddressSpace::kPrivate) {
            return opts.transform_private;
        }
        if (address_space == core::AddressSpace::kFunction) {
            return opts.transform_function;
        }

        // Everything else is either always transformed, or never transformed.
        return address_space == core::AddressSpace::kUniform ||
               address_space == core::AddressSpace::kStorage ||
               address_space == core::AddressSpace::kWorkgroup;
    }

    /// Looks up the AccessChain recorded for an expression.
    /// @param expr the semantic expression to look up in #access_chains
    /// @returns the AccessChain registered for @p expr, or nullptr if none is registered
    AccessChain* AccessChainFor(const sem::ValueExpression* expr) const {
        auto found = access_chains.Find(expr);
        return found ? *found : nullptr;
    }

    /// Resolves an AccessShape that may be rooted on a pointer parameter into an absolute shape.
    /// If the root of @p shape is a parameter that is present in @p signature, then the result
    /// is the signature's AccessShape for that parameter with all of the ops of @p shape
    /// appended. Otherwise @p shape is returned unchanged.
    /// @param signature the signature of the function variant that contains the access
    /// @param shape the (possibly parameter-relative) access shape
    /// @returns the absolute AccessShape
    AccessShape AbsoluteAccessShape(const FnVariant::Signature& signature,
                                    const AccessShape& shape) const {
        auto* root_param = shape.root.variable->As<sem::Parameter>();
        if (!root_param) {
            // Chain does not originate from a parameter, so is already absolute.
            return shape;
        }

        auto incoming = signature.Find(root_param);
        if (!incoming) {
            // The parameter is not part of the variant's signature. Nothing to prepend.
            return shape;
        }

        // Start from the shape that was passed in for the parameter, then continue it with the
        // accesses performed on the parameter inside the function.
        auto absolute = *incoming;
        for (auto& op : shape.ops) {
            absolute.ops.Push(op);
        }
        return absolute;
    }

    /// TransformFunction registers the clone callback to transform the function @p fn into the
    /// (potentially multiple) function's variants. While the callback runs it assigns the current
    /// function, variant and variant signature to #clone_state, which the other clone callbacks
    /// read.
    /// @param fn the semantic function to transform
    /// @param fn_info the FnInfo holding the variants to emit for @p fn
    void TransformFunction(const sem::Function* fn, FnInfo* fn_info) {
        // Register a custom handler for the specific function
        ctx.Replace(fn->Declaration(), [this, fn, fn_info] {
            // For the scope of this lambda, assign current_function to fn_info.
            TINT_SCOPED_ASSIGNMENT(clone_state->current_function, fn_info);

            // This callback expects a single function returned. As we're generating potentially
            // many variant functions, keep a record of the last created variant, and explicitly add
            // this to the module if it isn't the last. We'll return the last created variant,
            // taking the place of the original function.
            const Function* pending_variant = nullptr;

            // For each variant of fn...
            for (auto variant_it : fn_info->SortedVariants()) {
                // A variant built by the previous iteration is not the last one, so it has to be
                // added to the module explicitly.
                if (pending_variant) {
                    b.AST().AddFunction(pending_variant);
                }

                // Each entry is a (signature pointer, variant pointer) pair.
                auto& variant_sig = *variant_it.first;
                auto& variant = *variant_it.second;

                // For the rest of this scope, assign the current variant and variant signature.
                TINT_SCOPED_ASSIGNMENT(clone_state->current_variant_sig, &variant_sig);
                TINT_SCOPED_ASSIGNMENT(clone_state->current_variant, &variant);

                // Build the variant's parameters.
                // A parameter that appears in the variant's signature is replaced by zero, one or
                // two parameters, depending on which of its PtrParamSymbols are valid: a base
                // pointer parameter first, then an array of dynamic indices.
                tint::Vector<const Parameter*, 8> params;
                for (auto* param : fn->Parameters()) {
                    if (auto incoming_shape = variant_sig.Find(param)) {
                        auto& symbols = *variant.ptr_param_symbols.Find(param);
                        if (symbols.base_ptr.IsValid()) {
                            // The base pointer points at the root of the incoming shape, not at
                            // the type of the original parameter.
                            auto base_ptr_ty =
                                b.ty.ptr(incoming_shape->root.address_space,
                                         CreateASTTypeFor(ctx, incoming_shape->root.type));
                            params.Push(b.Param(symbols.base_ptr, base_ptr_ty));
                        }
                        if (symbols.indices.IsValid()) {
                            // The incoming shape has dynamic indices: pass them as an array.
                            auto dyn_idx_arr_type = DynamicIndexArrayType(*incoming_shape);
                            params.Push(b.Param(symbols.indices, dyn_idx_arr_type));
                        }
                    } else {
                        // Just a regular parameter. Just clone the original parameter.
                        params.Push(ctx.Clone(param->Declaration()));
                    }
                }

                // Build the variant by cloning the source function. The other clone callbacks will
                // use clone_state->current_variant and clone_state->current_variant_sig to produce
                // the variant.
                auto ret_ty = ctx.Clone(fn->Declaration()->return_type);
                auto body = ctx.Clone(fn->Declaration()->body);
                auto attrs = ctx.Clone(fn->Declaration()->attributes);
                auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
                pending_variant =
                    b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
                                       std::move(attrs), std::move(ret_attrs));
            }

            // The last variant built (nullptr if there were no variants) replaces the original
            // function declaration.
            return pending_variant;
        });
    }

    /// TransformCall registers the clone callback to transform the call expression @p call to call
    /// the correct target variant, and to replace pointer arguments with an optional base pointer
    /// argument followed by an optional array of dynamic indices.
    /// @param call the semantic call expression to transform
    void TransformCall(const sem::Call* call) {
        // Register a custom handler for the specific call expression
        ctx.Replace(call->Declaration(), [this, call] {
            // Look up which target variant the variant currently being cloned calls.
            auto target_variant = clone_state->current_variant->calls.Find(call);
            if (!target_variant) {
                // The current variant does not need to transform this call.
                return ctx.CloneWithoutTransform(call->Declaration());
            }

            // Build the new call expression's arguments.
            tint::Vector<const Expression*, 8> new_args;
            for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
                auto* arg = call->Arguments()[arg_idx];
                auto* param = call->Target()->Parameters()[arg_idx];
                auto* param_ty = param->Type()->As<core::type::Pointer>();
                if (!param_ty) {
                    // Parameter is not a pointer.
                    // Just clone the unaltered argument.
                    new_args.Push(ctx.Clone(arg->Declaration()));
                    continue;  // Parameter is not a pointer
                }

                auto* chain = AccessChainFor(arg);
                if (!chain) {
                    // No access chain means the argument is not a pointer that needs transforming.
                    // Just clone the unaltered argument.
                    new_args.Push(ctx.Clone(arg->Declaration()));
                    continue;
                }

                // Construct the absolute AccessShape by considering the AccessShape of the caller
                // variant's argument. This will propagate back through pointer parameters, to the
                // outermost caller.
                auto full_indices = AbsoluteAccessShape(*clone_state->current_variant_sig, *chain);

                // If the parameter is a pointer in the 'private' or 'function' address space, then
                // we need to pass an additional pointer argument to the base object.
                if (IsPrivateOrFunction(param_ty->AddressSpace())) {
                    auto* root_expr = BuildAccessRootExpr(chain->root, /* deref */ false);
                    // The address is only taken when the chain's root is not a parameter.
                    if (!chain->root.variable->Is<sem::Parameter>()) {
                        root_expr = b.AddressOf(root_expr);
                    }
                    new_args.Push(root_expr);
                }

                // Get or create the dynamic indices array.
                // No type is returned when the absolute shape has no dynamic indices, in which
                // case no indices argument is passed.
                if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
                    // Build an array of dynamic indices to pass as the replacement for the pointer.
                    tint::Vector<const Expression*, 8> dyn_idx_args;
                    if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
                        // Access chain originates from a pointer parameter.
                        if (auto incoming_chain =
                                clone_state->current_variant_sig->Find(root_param)) {
                            auto indices =
                                clone_state->current_variant->ptr_param_symbols.Find(root_param)
                                    ->indices;

                            // This pointer parameter will have been replaced with a array<u32, N>
                            // holding the variant's dynamic indices for the pointer. Unpack these
                            // directly into the array constructor's arguments.
                            auto N = incoming_chain->NumDynamicIndices();
                            for (uint32_t i = 0; i < N; i++) {
                                dyn_idx_args.Push(b.IndexAccessor(indices, u32(i)));
                            }
                        }
                    }
                    // Pass the dynamic indices of the access chain into the array constructor.
                    // These follow any indices forwarded from the incoming parameter.
                    for (auto& dyn_idx : chain->dynamic_indices) {
                        dyn_idx_args.Push(BuildDynamicIndex(dyn_idx, /* cast_to_u32 */ true));
                    }
                    // Construct the dynamic index array, and push as an argument.
                    new_args.Push(b.Call(dyn_idx_arr_ty, std::move(dyn_idx_args)));
                }
            }

            // Make the call to the target's variant.
            return b.Call(*target_variant, std::move(new_args));
        });
    }

    /// ProcessAccessChains walks every recorded AccessChain, ordered by the AST node id of its
    /// expression, and for each chain either:
    /// * Removes it from #access_chains, if it is neither used as a pointer argument in a call
    ///   nor rooted on a pointer parameter; or
    /// * Passes it to MaybeHoistDynamicIndices(), so that the chain's dynamic index expressions
    ///   can be hoisted to 'let' statements where required. This prevents multiple evaluation of
    ///   the expressions, and avoids expressions resolving to different variables based on
    ///   lexical scope.
    void ProcessAccessChains() {
        auto ordered_exprs = access_chains.Keys();
        ordered_exprs.Sort([](const auto& lhs, const auto& rhs) {
            return lhs->Declaration()->node_id.value < rhs->Declaration()->node_id.value;
        });

        for (auto* expr : ordered_exprs) {
            auto* chain = *access_chains.Get(expr);

            const bool requires_transform =
                chain->used_in_call || chain->root.variable->Is<sem::Parameter>();

            if (requires_transform) {
                // The chain may use expressions with side-effects which cannot be repeatedly
                // evaluated. Hoist the dynamic index expressions to their own uniquely named
                // lets (if required).
                MaybeHoistDynamicIndices(chain, expr->Stmt());
            } else {
                // Chain was not used in a function call, and does not originate from a
                // parameter. This chain does not need transforming. Drop it.
                access_chains.Remove(expr);
            }
        }
    }

    /// TransformAccessChainExpressions registers the clone callback to:
    /// * Transform all expressions that have an AccessChain (which aren't arguments to function
    ///   calls, these are handled by TransformCall()), into the equivalent expression built from
    ///   the root of the current variant's incoming shape.
    /// * Replace expressions that have been hoisted to a let, with an identifier expression to that
    ///   let.
    /// The callback returns nullptr whenever the expression should be cloned without changes.
    void TransformAccessChainExpressions() {
        // Register a custom handler for all non-function call expressions
        ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
            if (!clone_state->current_variant) {
                // Expression does not belong to a function variant.
                return nullptr;  // Just clone the expression.
            }

            auto* expr = sem.GetVal(ast_expr);
            if (!expr) {
                // No semantic node for the expression.
                return nullptr;  // Just clone the expression.
            }

            // If the expression has been hoisted to a 'let', then replace the expression with an
            // identifier to the hoisted let.
            if (auto hoisted = clone_state->current_function->hoisted_exprs.Find(expr)) {
                return b.Expr(*hoisted);
            }

            auto* chain = AccessChainFor(expr);
            if (!chain) {
                // The expression does not have an AccessChain.
                return nullptr;  // Just clone the expression.
            }

            auto* root_param = chain->root.variable->As<sem::Parameter>();
            if (!root_param) {
                // The expression has an access chain, but does not originate with a pointer
                // parameter. We don't need to change anything here.
                return nullptr;  // Just clone the expression.
            }

            auto incoming_shape = clone_state->current_variant_sig->Find(root_param);
            if (!incoming_shape) {
                // The root parameter of the access chain is not part of the variant's signature.
                return nullptr;  // Just clone the expression.
            }

            // Expression holds an access chain to a pointer parameter that needs transforming.
            // Reconstruct the expression using the variant's incoming shape.

            auto* chain_expr = BuildAccessRootExpr(incoming_shape->root, /* deref */ true);

            // Chain starts with a pointer parameter.
            // Replace this with the variant's incoming shape. This will bring the expression up to
            // the incoming pointer.
            // Each dynamic op of the incoming shape reads the next element of the variant's
            // indices parameter, in order.
            size_t next_dyn_idx_from_indices = 0;
            auto indices =
                clone_state->current_variant->ptr_param_symbols.Find(root_param)->indices;
            for (auto param_access : incoming_shape->ops) {
                chain_expr = BuildAccessExpr(chain_expr, param_access, [&] {
                    return b.IndexAccessor(indices, AInt(next_dyn_idx_from_indices++));
                });
            }

            // Now build the expression chain within the function.

            // For each access in the chain (excluding the pointer parameter)...
            // Each dynamic op consumes the next entry of the chain's own dynamic indices.
            size_t next_dyn_idx_from_chain = 0;
            for (auto& op : chain->ops) {
                chain_expr = BuildAccessExpr(chain_expr, op, [&] {
                    return BuildDynamicIndex(chain->dynamic_indices[next_dyn_idx_from_chain++],
                                             false);
                });
            }

            // BuildAccessExpr() always returns a non-pointer.
            // If the expression we're replacing is a pointer, take the address.
            if (expr->Type()->Is<core::type::Pointer>()) {
                chain_expr = b.AddressOf(chain_expr);
            }

            return chain_expr;
        });
    }

    /// Fetches the FnInfo associated with a function, lazily allocating it from
    /// #fn_info_allocator the first time the function is seen.
    /// @param fn the semantic function
    /// @returns the FnInfo for @p fn
    FnInfo* FnInfoFor(const sem::Function* fn) {
        auto allocate_info = [this] { return fn_info_allocator.Create(); };
        return fns.GetOrCreate(fn, allocate_info);
    }

    /// Obtains the array type used to pass the dynamic indices of @p shape between functions.
    /// The first request for a shape with at least one dynamic index declares an
    /// `array<u32, N>` alias named after the shape; later requests reuse that alias.
    /// @param shape the access shape
    /// @returns the alias type, or an empty Type if @p shape has no dynamic indices
    Type DynamicIndexArrayType(const AccessShape& shape) {
        auto declare_alias = [&] {
            // Count the number of dynamic indices
            uint32_t count = shape.NumDynamicIndices();
            if (count == 0) {
                // Nothing to pass. Cache an invalid symbol so no alias is ever declared.
                return Symbol{};
            }
            auto alias = b.Symbols().New(AccessShapeName(shape));
            b.Alias(alias, b.ty.array(b.ty.u32(), u32(count)));
            return alias;
        };

        auto alias_name = dynamic_index_array_aliases.GetOrCreate(shape, declare_alias);
        if (!alias_name.IsValid()) {
            return Type{};
        }
        return b.ty(alias_name);
    }

    /// @returns a name describing the given shape
    std::string AccessShapeName(const AccessShape& shape) {
        StringStream ss;

        if (IsPrivateOrFunction(shape.root.address_space)) {
            ss << "F";
        } else {
            ss << shape.root.variable->Declaration()->name->symbol.Name();
        }

        for (auto& op : shape.ops) {
            ss << "_";

            if (std::holds_alternative<DynamicIndex>(op)) {
                /// The op uses a dynamic (runtime-expression) index.
                ss << "X";
                continue;
            }

            auto* member = std::get_if<Symbol>(&op);
            if (TINT_LIKELY(member)) {
                ss << member->Name();
                continue;
            }

            TINT_ICE() << "unhandled variant for access chain";
            break;
        }
        return ss.str();
    }

    /// Builds an expression to the root of an access, returning the new expression.
    /// If the root is a parameter that has an entry in the current variant's pointer parameter
    /// symbols, the expression refers to that parameter's base pointer symbol. Otherwise the
    /// expression refers to the (cloned) name of the root variable.
    /// @param root the AccessRoot
    /// @param deref if true, pointer roots are dereferenced so that the returned expression is
    /// not a pointer
    /// @returns the root expression
    const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
        if (auto* param = root.variable->As<sem::Parameter>()) {
            if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
                // The parameter was replaced: use the base pointer parameter instead.
                const Expression* base_ptr = b.Expr(symbols->base_ptr);
                return deref ? b.Deref(base_ptr) : base_ptr;
            }
        }

        // Refer to the root variable directly, by name.
        const Expression* root_expr =
            b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
        if (deref && root.variable->Type()->Is<core::type::Pointer>()) {
            root_expr = b.Deref(root_expr);
        }
        return root_expr;
    }

    /// Builds a single access in an access chain, returning the new expression.
    /// A dynamic access produces an index accessor on @p expr whose index is obtained from
    /// @p dynamic_index; a member access produces a member accessor on @p expr.
    /// @param expr the input expression
    /// @param access the access to perform on the current expression
    /// @param dynamic_index a function that obtains the next dynamic index. Only called when
    /// @p access is a dynamic index.
    /// @returns the new access expression, or nullptr if @p access holds an unhandled type
    const Expression* BuildAccessExpr(const Expression* expr,
                                      const AccessOp& access,
                                      std::function<const Expression*()> dynamic_index) {
        if (std::get_if<DynamicIndex>(&access) != nullptr) {
            // The access uses a dynamic (runtime-expression) index.
            return b.IndexAccessor(expr, dynamic_index());
        }

        if (auto* member = std::get_if<Symbol>(&access); TINT_LIKELY(member)) {
            // The access is a member access.
            return b.MemberAccessor(expr, ctx.Clone(*member));
        }

        TINT_ICE() << "unhandled variant type for access chain";
        return nullptr;
    }

    /// Creates, or returns the previously created, symbol for the name of @p symbol with
    /// @p suffix appended. The symbol is created with Symbols().New() the first time a given
    /// combined name is requested, and cached in #unique_symbols for later requests.
    /// @param symbol the symbol providing the base name
    /// @param suffix the text to append to the base name
    /// @returns the symbol for the combined name
    Symbol UniqueSymbolWithSuffix(Symbol symbol, const std::string& suffix) {
        auto combined = symbol.Name() + suffix;
        auto make_symbol = [&] { return b.Symbols().New(combined); };
        return unique_symbols.GetOrCreate(combined, make_symbol);
    }

    /// @param fn the semantic function to inspect
    /// @returns true if any parameter of @p fn is of a pointer type.
    static bool HasPointerParameter(const sem::Function* fn) {
        auto params = fn->Parameters();
        return std::any_of(params.begin(), params.end(), [](const sem::Parameter* param) {
            return param->Type()->Is<core::type::Pointer>();
        });
    }

    /// @returns true if the function @p fn has at least one pointer parameter in an address space
    /// that must be replaced ('uniform', 'storage' or 'workgroup'). If this function is not
    /// called, then the function cannot be sensibly generated, and must be stripped.
    /// @param fn the semantic function to inspect
    static bool MustBeCalled(const sem::Function* fn) {
        for (auto* param : fn->Parameters()) {
            auto* ptr = param->Type()->As<core::type::Pointer>();
            if (!ptr) {
                continue;  // Not a pointer parameter.
            }
            switch (ptr->AddressSpace()) {
                case core::AddressSpace::kUniform:
                case core::AddressSpace::kStorage:
                case core::AddressSpace::kWorkgroup:
                    return true;
                default:
                    // A pointer in another address space does not decide the result: keep
                    // scanning, as a later parameter may still require replacement.
                    break;
            }
        }
        return false;
    }

    /// @param sc the address space to test
    /// @returns true if @p sc is the 'private' or the 'function' address space.
    static bool IsPrivateOrFunction(const core::AddressSpace sc) {
        switch (sc) {
            case core::AddressSpace::kPrivate:
            case core::AddressSpace::kFunction:
                return true;
            default:
                return false;
        }
    }
};

/// Constructs the transform configuration, storing a copy of @p opt in #options.
DirectVariableAccess::Config::Config(const Options& opt) : options(opt) {}

/// Destructor
DirectVariableAccess::Config::~Config() = default;

/// Constructor
DirectVariableAccess::DirectVariableAccess() = default;

/// Destructor
DirectVariableAccess::~DirectVariableAccess() = default;

/// Runs the transform on @p program.
/// The options are taken from the Config found in @p inputs, if there is one; otherwise
/// default-constructed Options are used. The output DataMap is not used.
/// @returns the result of running the transform's State on @p program
Transform::ApplyResult DirectVariableAccess::Apply(const Program& program,
                                                   const DataMap& inputs,
                                                   DataMap&) const {
    const auto* cfg = inputs.Get<Config>();
    Options options = cfg ? cfg->options : Options{};
    return State(program, options).Run();
}

}  // namespace tint::ast::transform
