// Copyright 2021 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "src/tint/lang/wgsl/ast/transform/array_length_from_uniform.h"

#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/unary_op.h"
#include "src/tint/lang/wgsl/ast/expression.h"
#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
#include "src/tint/lang/wgsl/ast/unary_op_expression.h"
#include "src/tint/lang/wgsl/ast/variable.h"
#include "src/tint/lang/wgsl/builtin_fn.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/array.h"
#include "src/tint/lang/wgsl/sem/builtin_fn.h"
#include "src/tint/lang/wgsl/sem/call.h"
#include "src/tint/lang/wgsl/sem/expression.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/containers/unique_vector.h"
#include "src/tint/utils/diagnostic/diagnostic.h"
#include "src/tint/utils/ice/ice.h"
#include "src/tint/utils/rtti/switch.h"
#include "src/tint/utils/text/text_style.h"

// Instantiate the tint::TypeInfo RTTI data for the transform and its Config / Result data types.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform::Config);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ArrayLengthFromUniform::Result);

using namespace tint::core::fluent_types;  // NOLINT

namespace tint::ast::transform {
namespace {

/// Scans every function of @p program for a direct call to the arrayLength() builtin.
/// @param program the program to inspect
/// @returns true as soon as one function directly calls arrayLength(), false if none do
bool ShouldRun(const Program& program) {
    for (auto* func_decl : program.AST().Functions()) {
        auto* func_sem = program.Sem().Get(func_decl);
        if (!func_sem) {
            // No semantic info for this function; nothing to inspect.
            continue;
        }
        for (auto* called : func_sem->DirectlyCalledBuiltins()) {
            if (called->Fn() == wgsl::BuiltinFn::kArrayLength) {
                return true;
            }
        }
    }
    return false;
}

}  // namespace

// Constructor and destructor, defaulted out-of-line in this source file.
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;

/// PIMPL state for the transform.
/// Replaces arrayLength() calls on storage buffers that have an entry in
/// Config::bindpoint_to_size_index with an expression that loads the buffer size from a generated
/// uniform buffer. Functions that receive the array through a pointer parameter get an extra u32
/// length parameter, and their call sites are updated to pass the length along.
struct ArrayLengthFromUniform::State {
    /// Constructor
    /// @param program the source program
    /// @param in the input transform data
    /// @param out the output transform data
    State(const Program& program, const DataMap& in, DataMap& out)
        : src(program), outputs(out), cfg(in.Get<Config>()) {}

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        if (cfg == nullptr) {
            b.Diagnostics().AddError(Source{}) << "missing transform data for "
                                               << tint::TypeInfo::Of<ArrayLengthFromUniform>().name;
            return resolver::Resolve(b);
        }

        // Nothing to do if no buffer is mapped to a size index, or arrayLength() is never called.
        if (cfg->bindpoint_to_size_index.empty() || !ShouldRun(src)) {
            return SkipTransform;
        }

        // Create the name of the array lengths uniform variable.
        array_lengths_var = b.Symbols().New("tint_array_lengths");

        // Replace all the arrayLength() calls.
        for (auto* fn : src.AST().Functions()) {
            if (auto* sem_fn = sem.Get(fn)) {
                for (auto* call : sem_fn->DirectCalls()) {
                    if (auto* target = call->Target()->As<sem::BuiltinFn>()) {
                        if (target->Fn() == wgsl::BuiltinFn::kArrayLength) {
                            ReplaceArrayLengthCall(call);
                        }
                    }
                }
            }
        }

        // Add the necessary array-length arguments to all the newly created array-length
        // parameters. Processing one parameter may push further parameters onto the worklist
        // (when the argument is itself a pointer parameter of the caller).
        while (!len_params_needing_args.IsEmpty()) {
            AddArrayLengthArguments(len_params_needing_args.Pop());
        }

        // Add the tint_array_lengths module-scope uniform variable.
        AddArrayLengthsUniformVar();

        // used_size_indices is not read after this point, so hand it to the Result without a copy.
        outputs.Add<Result>(std::move(used_size_indices));

        ctx.Clone();
        return resolver::Resolve(b);
    }

  private:
    // Replaces the arrayLength() builtin call with an array-length expression passed via a uniform
    // buffer. The call is left untouched if no replacement expression could be produced.
    void ReplaceArrayLengthCall(const sem::Call* call) {
        if (auto* replacement = ArrayLengthOf(call->Arguments()[0])) {
            ctx.Replace(call->Declaration(), replacement);
        }
    }

    /// Walks the pointer expression @p expr through member accessors, address-of and indirection
    /// operators until it reaches the variable (module-scope variable or parameter) it is rooted
    /// at.
    /// @returns an AST expression that is equal to the arrayLength() of the runtime-sized array
    /// accessed by the pointer expression @p expr, or nullptr on error or if the array is not in
    /// the Config::bindpoint_to_size_index map.
    const ast::Expression* ArrayLengthOf(const sem::Expression* expr) {
        const ast::Expression* len = nullptr;
        while (expr) {
            expr = Switch(
                expr,  //
                [&](const sem::VariableUser* user) {
                    len = ArrayLengthOf(user->Variable());
                    return nullptr;
                },
                [&](const sem::MemberAccessorExpression* access) {
                    return access->Object();  // Follow the object
                },
                [&](const sem::Expression* e) {
                    return Switch(
                        e->Declaration(),  //
                        [&](const ast::UnaryOpExpression* unary) -> const sem::Expression* {
                            switch (unary->op) {
                                case core::UnaryOp::kAddressOf:
                                case core::UnaryOp::kIndirection:
                                    return sem.Get(unary->expr);  // Follow the object
                                default:
                                    TINT_ICE() << "unexpected unary op: " << unary->op;
                                    return nullptr;
                            }
                        },
                        TINT_ICE_ON_NO_MATCH);
                },
                TINT_ICE_ON_NO_MATCH);
        }
        return len;
    }

    /// @returns an AST expression that is equal to the arrayLength() of the runtime-sized array
    /// held by the module-scope variable or parameter @p var, or nullptr on error or if the array
    /// is not in the Config::bindpoint_to_size_index map.
    const ast::Expression* ArrayLengthOf(const sem::Variable* var) {
        return Switch(
            var,  //
            [&](const sem::GlobalVariable* global) { return ArrayLengthOf(global); },
            [&](const sem::Parameter* param) { return ArrayLengthOf(param); },
            TINT_ICE_ON_NO_MATCH);
    }

    /// @returns an AST expression that is equal to the arrayLength() of the runtime-sized array
    /// held by the module scope variable @p global, or nullptr on error or if the array is not in
    /// the Config::bindpoint_to_size_index map.
    const ast::Expression* ArrayLengthOf(const sem::GlobalVariable* global) {
        auto binding = global->Attributes().binding_point;
        TINT_ASSERT_OR_RETURN_VALUE(binding, nullptr);

        auto idx_it = cfg->bindpoint_to_size_index.find(*binding);
        if (idx_it == cfg->bindpoint_to_size_index.end()) {
            // If the bindpoint_to_size_index map does not contain an entry for the storage buffer,
            // then we preserve the arrayLength() call.
            return nullptr;
        }

        // Validate the variable's type before recording the size index or building any AST, so a
        // rejected variable contributes nothing to the output.
        if (TINT_UNLIKELY(global->Type()->Is<core::type::Pointer>())) {
            TINT_ICE() << "storage buffer variable should not be a pointer. "
                          "These should have been removed by the SimplifyPointers transform";
            return nullptr;
        }

        // The runtime-sized array is either the store type itself, or the last member of a
        // structure store type.
        auto* storage_buffer_type = global->Type()->UnwrapRef();
        auto* str = storage_buffer_type->As<core::type::Struct>();
        const core::type::Array* array_type = nullptr;
        if (str) {
            array_type = str->Members().Back()->Type()->As<core::type::Array>();
        } else {
            array_type = storage_buffer_type->As<core::type::Array>();
        }
        // Also rejects a structure whose last member is not an array, which would otherwise be
        // dereferenced as a null array type when reading the stride below.
        if (TINT_UNLIKELY(array_type == nullptr)) {
            TINT_ICE() << "expected form of arrayLength argument to be &array_var or "
                          "&struct_var.array_member";
            return nullptr;
        }

        uint32_t size_index = idx_it->second;
        used_size_indices.insert(size_index);

        // Load the total storage buffer size from the UBO. The sizes are packed four to a
        // vec4<u32> element.
        uint32_t array_index = size_index / 4;
        auto* vec_expr = b.IndexAccessor(
            b.MemberAccessor(array_lengths_var, kArrayLengthsMemberName), u32(array_index));
        uint32_t vec_index = size_index % 4;
        auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index));

        // Calculate actual array length
        //                total_storage_buffer_size - array_offset
        // array_length = ----------------------------------------
        //                             array_stride
        const Expression* total_size = total_storage_buffer_size;
        if (str) {
            // The variable is a struct, so subtract the byte offset of the array member.
            total_size =
                b.Sub(total_storage_buffer_size, u32(str->Members().Back()->Offset()));
        }
        return b.Div(total_size, u32(array_type->Stride()));
    }

    /// Adds (once per parameter) a new u32 array-length parameter directly after the pointer
    /// parameter @p param, and queues @p param so its call sites receive a matching argument.
    /// @returns an AST expression that is equal to the arrayLength() of the runtime-sized array
    /// held by the object pointed to by the pointer parameter @p param.
    const ast::Expression* ArrayLengthOf(const sem::Parameter* param) {
        // Pointer originates from a parameter.
        // Add a new array length parameter to the function, and use that.
        auto len_name = param_lengths.GetOrAdd(param, [&] {
            auto* fn = param->Owner()->As<sem::Function>();
            auto name = b.Symbols().New(param->Declaration()->name->symbol.Name() + "_length");
            auto* len_param = b.Param(name, b.ty.u32());
            ctx.InsertAfter(fn->Declaration()->params, param->Declaration(), len_param);
            len_params_needing_args.Add(param);
            return name;
        });
        return b.Expr(len_name);
    }

    /// Constructs the uniform buffer variable that will hold the array lengths.
    /// The array is sized to cover the highest recorded size index (one element if none were
    /// recorded).
    void AddArrayLengthsUniformVar() {
        // Calculate the highest index in the array lengths array
        uint32_t highest_index = 0;
        for (auto idx : used_size_indices) {
            if (idx > highest_index) {
                highest_index = idx;
            }
        }

        // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
        // We do this because UBOs require an element stride that is 16-byte aligned.
        auto* buffer_size_struct =
            b.Structure(b.Symbols().New("TintArrayLengths"),
                        tint::Vector{
                            b.Member(kArrayLengthsMemberName,
                                     b.ty.array(b.ty.vec4<u32>(), u32((highest_index / 4) + 1))),
                        });
        b.GlobalVar(array_lengths_var, b.ty.Of(buffer_size_struct), core::AddressSpace::kUniform,
                    b.Group(AInt(cfg->ubo_binding.group)),
                    b.Binding(AInt(cfg->ubo_binding.binding)));
    }

    /// Adds an additional array-length argument to all the calls to the function that owns the
    /// pointer parameter @p param. This may add new entries to #len_params_needing_args.
    void AddArrayLengthArguments(const sem::Parameter* param) {
        auto* fn = param->Owner()->As<sem::Function>();
        for (auto* call : fn->CallSites()) {
            auto* arg = call->Arguments()[param->Index()];
            const ast::Expression* len = ArrayLengthOf(arg);
            if (!len) {
                // Callee expects an array length, but there's no binding for it.
                // Call arrayLength() at the call-site.
                len = b.Call(wgsl::BuiltinFn::kArrayLength, ctx.Clone(arg->Declaration()));
            }
            ctx.InsertAfter(call->Declaration()->args, arg->Declaration(), len);
        }
    }

    /// Name of the array-lengths struct member that holds all the array lengths.
    static constexpr std::string_view kArrayLengthsMemberName = "array_lengths";

    /// The source program
    const Program& src;
    /// The transform outputs
    DataMap& outputs;
    /// The transform config (nullptr if missing from the inputs)
    const Config* const cfg;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context
    program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
    /// Alias to src.Sem()
    const sem::Info& sem = src.Sem();
    /// Name of the uniform buffer variable that holds the array lengths
    Symbol array_lengths_var;
    /// A map of pointer-parameter to the name of the new array-length parameter.
    Hashmap<const sem::Parameter*, Symbol, 8> param_lengths;
    /// Indices into the uniform buffer array indices that are statically used.
    std::unordered_set<uint32_t> used_size_indices;
    /// A vector of array-length parameters which need corresponding array-length arguments for all
    /// callsites.
    UniqueVector<const sem::Parameter*, 8> len_params_needing_args;
};

// Entry point: builds a fresh State for this invocation and lets it drive the transform.
Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program& src,
                                                     const DataMap& inputs,
                                                     DataMap& outputs) const {
    State state{src, inputs, outputs};
    return state.Run();
}

// Default constructor.
ArrayLengthFromUniform::Config::Config() = default;
// Constructs a Config whose uniform buffer of array lengths uses the binding point `bp`.
ArrayLengthFromUniform::Config::Config(BindingPoint bp) : ubo_binding(bp) {}
// Copy constructor.
ArrayLengthFromUniform::Config::Config(const Config&) = default;
// Copy assignment.
ArrayLengthFromUniform::Config& ArrayLengthFromUniform::Config::operator=(const Config&) = default;
// Destructor.
ArrayLengthFromUniform::Config::~Config() = default;

// Constructs a Result that takes ownership of the set of statically used size indices.
ArrayLengthFromUniform::Result::Result(std::unordered_set<uint32_t> indices)
    : used_size_indices(std::move(indices)) {}
// Copy constructor.
ArrayLengthFromUniform::Result::Result(const Result&) = default;
// Destructor.
ArrayLengthFromUniform::Result::~Result() = default;

}  // namespace tint::ast::transform
