// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/lang/msl/writer/ast_raise/pixel_local.h"

#include <utility>

#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/module.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/lang/wgsl/sem/struct.h"
#include "src/tint/utils/containers/transform.h"

// Type-info instantiations for the transform itself and for its nested Attachment and Config
// types.
// NOTE(review): presumably required once per type, in a single translation unit, by Tint's
// runtime type-information machinery - confirm against the macro's definition.
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::PixelLocal);
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::PixelLocal::Attachment);
TINT_INSTANTIATE_TYPEINFO(tint::msl::writer::PixelLocal::Config);

// File-wide using-directives for the core number-suffix and fluent-type namespaces.
// The trailing NOLINT markers suppress the lint warning about using-directives.
using namespace tint::core::number_suffixes;  // NOLINT
using namespace tint::core::fluent_types;     // NOLINT

namespace tint::msl::writer {

/// PIMPL state for the transform.
/// Owns the builder for the output program and the clone context used to map the source program
/// into it, and implements the transform in Run().
struct PixelLocal::State {
    /// The source program
    const Program* const src;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context. Its initializer reads both `b` and `src`, so it must be declared after
    /// them.
    program::CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
    /// The transform config. Held by reference: the config must outlive this State.
    const Config& cfg;

    /// Constructor
    /// @param program the source program
    /// @param config the transform config
    State(const Program* program, const Config& config) : src(program), cfg(config) {}

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        auto& sem = src->Sem();

        // If the pixel local extension isn't enabled, then there must be no use of pixel_local
        // variables, and so there's nothing for this transform to do.
        if (!sem.Module()->Extensions().Contains(
                core::Extension::kChromiumExperimentalPixelLocal)) {
            return SkipTransform;
        }

        // Set when any replacement / insertion has been registered with the clone context.
        bool made_changes = false;

        // Change all module scope `var<pixel_local>` variables to `var<private>`.
        // We need to do this even if the variable is not referenced by the entry point as later
        // stages do not understand the pixel_local address space.
        for (auto* global : src->AST().GlobalVariables()) {
            if (auto* var = global->As<ast::Var>()) {
                if (sem.Get(var)->AddressSpace() == core::AddressSpace::kPixelLocal) {
                    // Change the 'var<pixel_local>' to 'var<private>'
                    ctx.Replace(var->declared_address_space, b.Expr(core::AddressSpace::kPrivate));
                    made_changes = true;
                }
            }
        }

        // Find the single entry point
        const sem::Function* entry_point = nullptr;
        for (auto* fn : src->AST().Functions()) {
            if (fn->IsEntryPoint()) {
                if (entry_point != nullptr) {
                    // A second entry point was found: this transform only supports one.
                    TINT_ICE() << "PixelLocal transform requires that the SingleEntryPoint "
                                  "transform has already been run";
                    return SkipTransform;
                }
                entry_point = sem.Get(fn);

                // Look for a `var<pixel_local>` used by the entry point...
                for (auto* global : entry_point->TransitivelyReferencedGlobals()) {
                    if (global->AddressSpace() != core::AddressSpace::kPixelLocal) {
                        continue;
                    }

                    // Obtain struct of the pixel local.
                    auto* pixel_local_str = global->Type()->UnwrapRef()->As<sem::Struct>();

                    // Add an attachment decoration to each member of the pixel_local structure.
                    // The member is also given a disabled-validation attribute for
                    // kEntryPointParameter.
                    for (auto* member : pixel_local_str->Members()) {
                        ctx.InsertBack(member->Declaration()->attributes,
                                       Attachment(AttachmentIndex(member->Index())));
                        ctx.InsertBack(member->Declaration()->attributes,
                                       b.Disable(ast::DisabledValidation::kEntryPointParameter));
                    }

                    TransformEntryPoint(entry_point, global, pixel_local_str);
                    made_changes = true;

                    break;  // Only a single `var<pixel_local>` can be used by an entry point.
                }
            }
        }

        if (!made_changes) {
            return SkipTransform;
        }

        // Clone the source program into the builder, applying everything registered above, then
        // resolve the result.
        ctx.Clone();
        return resolver::Resolve(b);
    }

    /// Transforms the entry point @p entry_point to handle the direct or transitive usage of the
    /// `var<pixel_local>` @p pixel_local_var.
    ///
    /// The original entry point loses its stage attribute and is renamed with an `_inner` suffix.
    /// A new fragment entry point with the original name is emitted which:
    ///  * takes the original parameters plus a trailing pixel local structure parameter,
    ///  * assigns that parameter to the module-scope variable,
    ///  * calls the inner function, and
    ///  * returns a new structure combining the pixel local members with the inner function's
    ///    outputs.
    /// @param entry_point the entry point
    /// @param pixel_local_var the `var<pixel_local>`
    /// @param pixel_local_str the struct type of the var
    void TransformEntryPoint(const sem::Function* entry_point,
                             const sem::GlobalVariable* pixel_local_var,
                             const sem::Struct* pixel_local_str) {
        auto* fn = entry_point->Declaration();
        auto fn_name = fn->name->symbol.Name();
        auto pixel_local_str_name = ctx.Clone(pixel_local_str->Name());
        auto pixel_local_var_name = ctx.Clone(pixel_local_var->Declaration()->name->symbol);

        // Remove the @fragment attribute from the entry point
        ctx.Remove(fn->attributes, ast::GetAttribute<ast::StageAttribute>(fn->attributes));
        // Rename the entry point
        auto inner_name = b.Symbols().New(fn_name + "_inner");
        ctx.Replace(fn->name, b.Ident(inner_name));

        // Create a new function that wraps the entry point.
        // This function has all the existing entry point parameters and an additional
        // parameter for the input pixel local structure.
        auto params = ctx.Clone(fn->params);
        auto pl_param = b.Symbols().New("pixel_local");
        params.Push(b.Param(pl_param, b.ty(pixel_local_str_name)));

        // Remove any entry-point attributes from the inner function.
        // This must come after `ctx.Clone(fn->params)` as we want these attributes on the outer
        // function.
        for (auto* param : fn->params) {
            for (auto* attr : param->attributes) {
                if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute,
                                  ast::InterpolateAttribute, ast::InvariantAttribute>()) {
                    ctx.Remove(param->attributes, attr);
                }
            }
        }

        // Build the outer function's statements, starting with an assignment of the pixel local
        // parameter to the module scope var.
        Vector<const ast::Statement*, 3> body{
            b.Assign(pixel_local_var_name, pl_param),
        };

        // Build the arguments to call the inner function
        auto call_args =
            tint::Transform(fn->params, [&](auto* p) { return b.Expr(ctx.Clone(p->name)); });

        // Create a structure to hold the combined flattened result of the entry point and the pixel
        // local structure.
        auto str_name = b.Symbols().New(fn_name + "_res");
        Vector<const ast::StructMember*, 8> members;
        Vector<const ast::Expression*, 8> return_args;  // arguments to the final `return` statement

        // Appends a member named `output_<N>` to the result structure, where N is the number of
        // members added so far.
        auto add_member = [&](const core::type::Type* ty, VectorRef<const ast::Attribute*> attrs) {
            members.Push(b.Member("output_" + std::to_string(members.Length()),
                                  CreateASTTypeFor(ctx, ty), std::move(attrs)));
        };
        // The pixel local members come first; each is given a @location taken from the configured
        // attachment index and is read back from the module-scope variable.
        for (auto* member : pixel_local_str->Members()) {
            add_member(member->Type(), Vector{
                                           b.Location(AInt(AttachmentIndex(member->Index()))),
                                       });
            return_args.Push(b.MemberAccessor(pixel_local_var_name, ctx.Clone(member->Name())));
        }
        if (fn->return_type) {
            Symbol call_result = b.Symbols().New("result");
            if (auto* str = entry_point->ReturnType()->As<sem::Struct>()) {
                // The entry point returned a structure.
                for (auto* member : str->Members()) {
                    auto& member_attrs = member->Declaration()->attributes;
                    add_member(member->Type(), ctx.Clone(member_attrs));
                    return_args.Push(b.MemberAccessor(call_result, ctx.Clone(member->Name())));
                    if (auto* location = ast::GetAttribute<ast::LocationAttribute>(member_attrs)) {
                        // Remove the @location attribute from the member of the inner function's
                        // output structure.
                        // Note: This will break other entry points that share the same output
                        // structure, however this transform assumes that the SingleEntryPoint
                        // transform will have already been run.
                        ctx.Remove(member_attrs, location);
                    }
                }
            } else {
                // The entry point returned a non-structure
                add_member(entry_point->ReturnType(), ctx.Clone(fn->return_type_attributes));
                return_args.Push(b.Expr(call_result));

                // Remove the entry-point IO attributes from the inner function's return type.
                // This must come after `ctx.Clone(fn->return_type_attributes)` as we want these
                // attributes on the member of the outer function's result structure.
                // The same set of attribute kinds is stripped as for the parameters above.
                // Iterating the attributes (instead of looking up the @location and removing it
                // unconditionally) also means ctx.Remove() is never handed a null attribute when
                // the return type has no @location, e.g. when it only carries a @builtin.
                for (auto* attr : fn->return_type_attributes) {
                    if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute,
                                      ast::InterpolateAttribute, ast::InvariantAttribute>()) {
                        ctx.Remove(fn->return_type_attributes, attr);
                    }
                }
            }
            body.Push(b.Decl(b.Let(call_result, b.Call(inner_name, std::move(call_args)))));
        } else {
            // The entry point returns nothing: just call the inner function.
            body.Push(b.CallStmt(b.Call(inner_name, std::move(call_args))));
        }

        // Declare the output structure
        b.Structure(str_name, std::move(members));

        // Return the output structure
        body.Push(b.Return(b.Call(str_name, std::move(return_args))));

        // Declare the new entry point that calls the inner function
        b.Func(fn_name, std::move(params), b.ty(str_name), body,
               Vector{b.Stage(ast::PipelineStage::kFragment)});
    }

    /// @returns a new Attachment attribute, created in the target program builder
    /// @param index the index of the attachment
    PixelLocal::Attachment* Attachment(uint32_t index) {
        return b.ASTNodes().Create<PixelLocal::Attachment>(b.ID(), b.AllocateNodeID(), index);
    }

    /// Looks up the attachment index for a pixel local field in the transform config.
    /// If the config has no entry for the field, an error diagnostic is added to the target
    /// program builder and 0 is returned.
    /// @returns the attachment index for the pixel local field with the given index
    /// @param field_index the pixel local field index
    uint32_t AttachmentIndex(uint32_t field_index) {
        auto idx = cfg.attachments.Get(field_index);
        if (TINT_UNLIKELY(!idx)) {
            b.Diagnostics().add_error(diag::System::Transform,
                                      "PixelLocal::Config::attachments missing entry for field " +
                                          std::to_string(field_index));
            return 0;
        }
        return *idx;
    }
};

/// Constructor (compiler-generated default, defined out-of-line here).
PixelLocal::PixelLocal() = default;

/// Destructor (compiler-generated default, defined out-of-line here).
PixelLocal::~PixelLocal() = default;

/// Applies the transform to @p src.
/// A PixelLocal::Config must be present in @p inputs. When it is missing, the result is produced
/// by resolving an otherwise empty builder that carries a "missing transform data" error.
/// @param src the source program
/// @param inputs the transform inputs, expected to hold a Config
/// @returns the result of running the transform state over @p src
ast::transform::Transform::ApplyResult PixelLocal::Apply(const Program* src,
                                                         const ast::transform::DataMap& inputs,
                                                         ast::transform::DataMap&) const {
    if (const auto* config = inputs.Get<Config>()) {
        // Config available: run the transform proper.
        State state{src, *config};
        return state.Run();
    }

    // No config was supplied: report the problem through a fresh builder's diagnostics.
    ProgramBuilder builder;
    builder.Diagnostics().add_error(diag::System::Transform,
                                    "missing transform data for " + std::string(TypeInfo().name));
    return resolver::Resolve(builder);
}

/// Constructor (compiler-generated default, defined out-of-line here).
PixelLocal::Config::Config() = default;

/// Copy constructor (compiler-generated default, defined out-of-line here).
PixelLocal::Config::Config(const Config&) = default;

/// Destructor (compiler-generated default, defined out-of-line here).
PixelLocal::Config::~Config() = default;

/// Constructor
/// @param pid the generation identifier, forwarded to the base class
/// @param nid the node identifier, forwarded to the base class
/// @param idx the attachment index, stored in #index
/// NOTE(review): `Empty` is passed as the third base-class argument; presumably an empty list of
/// dependencies - confirm against the base class constructor.
PixelLocal::Attachment::Attachment(GenerationID pid, ast::NodeID nid, uint32_t idx)
    : Base(pid, nid, Empty), index(idx) {}

/// Destructor (compiler-generated default, defined out-of-line here).
PixelLocal::Attachment::~Attachment() = default;

/// @returns the internal name of the attribute, of the form `attachment(<index>)`
std::string PixelLocal::Attachment::InternalName() const {
    std::string name = "attachment(";
    name += std::to_string(index);
    name += ")";
    return name;
}

/// Creates a copy of this attribute in the destination of @p ctx, carrying the same attachment
/// index and a newly allocated node ID.
/// @param ctx the clone context
/// @returns the newly created Attachment
const PixelLocal::Attachment* PixelLocal::Attachment::Clone(ast::CloneContext& ctx) const {
    auto& target = *ctx.dst;
    return target.ASTNodes().Create<Attachment>(target.ID(), target.AllocateNodeID(), index);
}

}  // namespace tint::msl::writer
