// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/block_statement.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/rtti/switch.h"

TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SimplifyPointers);

namespace tint::ast::transform {

namespace {

/// PointerOp is the net pointer operation left on an expression once a chain
/// of indirection (`*`) and address-of (`&`) operators has been collapsed.
struct PointerOp {
    /// Net count of pointer operators applied to `expr`:
    ///  * Greater than zero: number of times `expr` was dereferenced (*expr)
    ///  * Less than zero: number of times `expr` was 'addressed-of' (&expr)
    ///  * Zero: no pointer operator applies to `expr`
    int indirections{0};
    /// The expression that the pointer operators act on
    const Expression* expr{nullptr};
};

}  // namespace

/// PIMPL state for the transform
struct SimplifyPointers::State {
    /// The source program
    const Program& src;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context. Declared after `src` and `b`, as its initializer
    /// takes the address of both.
    program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};

    /// Constructor
    /// @param program the source program
    explicit State(const Program& program) : src(program) {}

    /// Traverses the expression `expr` looking for non-literal array indexing
    /// expressions that would affect the computed address of a pointer
    /// expression. The function-like argument `cb` is called for each found.
    /// Only the `object` side of index / member accessors and the operand of
    /// unary operators are walked; the index expressions themselves are
    /// reported to `cb` but not traversed.
    /// @param expr the expression to traverse
    /// @param cb a function-like object with the signature
    /// `void(const Expression*)`, which is called for each array index
    /// expression
    template <typename F>
    static void CollectSavedArrayIndices(const Expression* expr, F&& cb) {
        if (auto* a = expr->As<IndexAccessorExpression>()) {
            // Visit the indexed object before reporting this index, so that `cb`
            // is called in left-to-right order along the accessor chain.
            CollectSavedArrayIndices(a->object, cb);
            if (!a->index->Is<LiteralExpression>()) {
                cb(a->index);
            }
            return;
        }

        if (auto* m = expr->As<MemberAccessorExpression>()) {
            CollectSavedArrayIndices(m->object, cb);
            return;
        }

        if (auto* u = expr->As<UnaryOpExpression>()) {
            CollectSavedArrayIndices(u->expr, cb);
            return;
        }

        // Note: Other Expression types can be safely ignored as they cannot be
        // used to generate a reference or pointer.
        // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
    }

    /// Reduce walks the expression chain, collapsing all address-of and
    /// indirection ops into a PointerOp. Identifiers that resolve to a
    /// function-scope, pointer-typed `let` are followed into the initializer of
    /// that `let`, so the returned expression never names such a `let`.
    /// @param in the expression to walk
    /// @returns the reduced PointerOp
    PointerOp Reduce(const Expression* in) const {
        PointerOp op{0, in};
        while (true) {
            if (auto* unary = op.expr->As<UnaryOpExpression>()) {
                switch (unary->op) {
                    case core::UnaryOp::kIndirection:
                        op.indirections++;
                        op.expr = unary->expr;
                        continue;
                    case core::UnaryOp::kAddressOf:
                        op.indirections--;
                        op.expr = unary->expr;
                        continue;
                    default:
                        // Any other unary operator ends the `*` / `&` chain. Fall
                        // through to the variable check below.
                        break;
                }
            }
            if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
                auto* var = user->Variable();
                if (var->Is<sem::LocalVariable>() &&  //
                    var->Declaration()->Is<Let>() &&  //
                    var->Type()->Is<core::type::Pointer>()) {
                    // Pointer-typed `let`: keep reducing through its initializer.
                    op.expr = var->Declaration()->initializer;
                    continue;
                }
            }
            return op;
        }
    }

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        // A map of saved expressions to their saved variable name
        Hashmap<const Expression*, Symbol, 8> saved_vars;

        // Register the Expression transform handler.
        // This performs two different transformations:
        // * Identifiers that resolve to the pointer-typed `let` declarations are
        // replaced with the recursively inlined initializer expression for the
        // `let` declaration.
        // * Sub-expressions inside the pointer-typed `let` initializer expression
        // that have been hoisted to a saved variable are replaced with the saved
        // variable identifier.
        //
        // The handler has to be registered before the scan below. The scan clones
        // each hoisted index expression with ctx.Clone(), and those clones must
        // also go through this handler: an index expression may itself refer to a
        // pointer-typed `let` (whose declaration is removed by the scan) or to an
        // expression that was already hoisted for an earlier `let`.
        ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
            // Look to see if we need to swap this Expression with a saved variable.
            if (auto saved_var = saved_vars.Find(expr)) {
                return ctx.dst->Expr(*saved_var);
            }

            // Reduce the expression, folding away chains of address-of / indirections
            auto op = Reduce(expr);

            // Clone the reduced root expression. The root itself is cloned without
            // the handler being re-applied to it.
            expr = ctx.CloneWithoutTransform(op.expr);

            // And reapply the minimum number of address-of / indirections.
            // At most one of these two loops executes, depending on the sign of
            // `op.indirections`.
            for (int i = 0; i < op.indirections; i++) {
                expr = ctx.dst->Deref(expr);
            }
            for (int i = 0; i > op.indirections; i--) {
                expr = ctx.dst->AddressOf(expr);
            }
            return expr;
        });

        bool needs_transform = false;
        for (auto* ty : ctx.src->Types()) {
            if (ty->Is<core::type::Pointer>()) {
                // Program contains pointers which need removing.
                needs_transform = true;
                break;
            }
        }

        // Find all the pointer-typed `let` declarations.
        // Note that these must be function-scoped, as module-scoped `let`s are not
        // permitted.
        for (auto* node : ctx.src->ASTNodes().Objects()) {
            Switch(
                node,  //
                [&](const VariableDeclStatement* let) {
                    if (!let->variable->Is<Let>()) {
                        return;  // Not a `let` declaration. Ignore.
                    }

                    auto* var = ctx.src->Sem().Get(let->variable);
                    if (!var->Type()->Is<core::type::Pointer>()) {
                        return;  // Not a pointer type. Ignore.
                    }

                    // We're dealing with a pointer-typed `let` declaration.

                    // Scan the initializer expression for array index expressions that need
                    // to be hoisted to temporary "saved" variables.
                    tint::Vector<const VariableDeclStatement*, 8> saved;
                    CollectSavedArrayIndices(
                        var->Declaration()->initializer, [&](const Expression* idx_expr) {
                            // We have a sub-expression that needs to be saved.
                            // Create a new variable
                            auto saved_name = ctx.dst->Symbols().New(
                                var->Declaration()->name->symbol.Name() + "_save");
                            auto* decl =
                                ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
                            saved.Push(decl);
                            // Record the substitution of `idx_expr` to the saved variable
                            // with the symbol `saved_name`. This will be used by the
                            // ReplaceAll() handler registered above. The entry is added
                            // only after `idx_expr` has been cloned, so the clone held by
                            // the saved variable is the expression itself and not a
                            // reference to the saved variable.
                            saved_vars.Add(idx_expr, saved_name);
                        });

                    // Find the place to insert the saved declarations.
                    // Special care needs to be made for lets declared as the initializer
                    // part of for-loops. In this case the block will hold the for-loop
                    // statement, not the let.
                    if (!saved.IsEmpty()) {
                        auto* stmt = ctx.src->Sem().Get(let);
                        auto* block = stmt->Block();
                        // Find the statement owned by the block (either the let decl or a
                        // for-loop)
                        while (block != stmt->Parent()) {
                            stmt = stmt->Parent();
                        }
                        // Declare the stored variables just before stmt. Order here is
                        // important as order-of-operations needs to be preserved.
                        // CollectSavedArrayIndices() visits the LHS of an index accessor
                        // before the index expression.
                        for (auto* decl : saved) {
                            // Note that repeated calls to InsertBefore() with the same `before`
                            // argument will result in nodes being inserted in the order the
                            // calls are made (last call is inserted last).
                            ctx.InsertBefore(block->Declaration()->statements, stmt->Declaration(),
                                             decl);
                        }
                    }

                    // As the original `let` declaration will be fully inlined, there's no
                    // need for the original declaration to exist. Remove it.
                    RemoveStatement(ctx, let);
                },
                [&](const UnaryOpExpression* op) {
                    if (op->op == core::UnaryOp::kAddressOf) {
                        // Transform can be skipped if no address-of operator is used, as there
                        // will be no pointers that can be inlined.
                        needs_transform = true;
                    }
                });
        }

        if (!needs_transform) {
            return SkipTransform;
        }

        ctx.Clone();
        return resolver::Resolve(b);
    }
};

/// Constructor. Defaulted, and defined out-of-line in this file.
SimplifyPointers::SimplifyPointers() = default;

/// Destructor. Defaulted, and defined out-of-line in this file.
SimplifyPointers::~SimplifyPointers() = default;

/// Applies the transform to `src`. A fresh State is built for each call and
/// all of the work is delegated to State::Run(). Neither DataMap argument is
/// used.
/// @param src the source program to transform
/// @returns whatever State::Run() returns: the new program, or SkipTransform
/// if the transform is not required
Transform::ApplyResult SimplifyPointers::Apply(const Program& src, const DataMap&, DataMap&) const {
    State state(src);
    return state.Run();
}

}  // namespace tint::ast::transform
