// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/transform/unwind_discard_functions.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "src/tint/ast/discard_statement.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/traverse_expressions.h"
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/transform/utils/get_insertion_point.h"

TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);

namespace tint::transform {
namespace {

/// State holds the per-run data of the UnwindDiscardFunctions transform and
/// performs the rewrite of the program being cloned by `ctx`.
class State {
  private:
    CloneContext& ctx;
    ProgramBuilder& b;
    const sem::Info& sem;
    Symbol module_discard_var_name;   // Use ModuleDiscardVarName() to read
    Symbol module_discard_func_name;  // Use ModuleDiscardFuncName() to read

    // Returns true if `sem_expr` is non-null and its behaviors include
    // sem::Behavior::kDiscard, i.e. evaluating it may (transitively) execute a
    // discard statement.
    bool MayDiscard(const sem::Expression* sem_expr) {
        return sem_expr && sem_expr->Behaviors().Contains(sem::Behavior::kDiscard);
    }

    // Lazily creates and returns the name of the module-scope private bool
    // variable that records whether to discard: "tint_discard". The variable is
    // initialized to false.
    Symbol ModuleDiscardVarName() {
        if (!module_discard_var_name.IsValid()) {
            module_discard_var_name = b.Symbols().New("tint_discard");
            b.GlobalVar(module_discard_var_name, b.ty.bool_(), b.Expr(false),
                        ast::StorageClass::kPrivate);
        }
        return module_discard_var_name;
    }

    // Lazily creates and returns the name of the function that contains a single
    // discard statement: "tint_discard_func".
    // We do this to avoid having multiple discard statements in a single program,
    // which causes problems in certain backends (see crbug.com/1118).
    Symbol ModuleDiscardFuncName() {
        if (!module_discard_func_name.IsValid()) {
            module_discard_func_name = b.Symbols().New("tint_discard_func");
            b.Func(module_discard_func_name, tint::utils::Empty, b.ty.void_(),
                   tint::utils::Vector{b.Discard()});
        }
        return module_discard_func_name;
    }

    // Creates "return <default return value>;" based on the return type of
    // `stmt`'s owning function. For a void function this is a bare "return;",
    // otherwise the value is a zero-argument construction of the return type.
    const ast::ReturnStatement* Return(const ast::Statement* stmt) {
        const ast::Expression* ret_val = nullptr;
        auto* ret_type = sem.Get(stmt)->Function()->Declaration()->return_type;
        if (!ret_type->Is<ast::Void>()) {
            ret_val = b.Construct(ctx.Clone(ret_type));
        }
        return b.Return(ret_val);
    }

    // Returns true if the function that contains `stmt` is an entry point.
    bool IsInEntryPointFunc(const ast::Statement* stmt) {
        return sem.Get(stmt)->Function()->Declaration()->IsEntryPoint();
    }

    // Creates "tint_discard_func();"
    const ast::CallStatement* CallDiscardFunc() {
        auto func_name = ModuleDiscardFuncName();
        return b.CallStmt(b.Call(func_name));
    }

    // Creates and returns a new if-statement of the form:
    //
    //    if (tint_discard) {
    //      return <default value>;
    //    }
    //
    // or if `stmt` is in an entry point function:
    //
    //    if (tint_discard) {
    //      tint_discard_func();
    //      return <default value>;
    //    }
    //
    const ast::IfStatement* IfDiscardReturn(const ast::Statement* stmt) {
        tint::utils::Vector<const ast::Statement*, 2> stmts;

        // For entry point functions, also emit the discard statement
        if (IsInEntryPointFunc(stmt)) {
            stmts.Push(CallDiscardFunc());
        }

        stmts.Push(Return(stmt));

        // Kept as a separate statement so that the lazy creation order of the
        // discard function / variable does not depend on argument evaluation
        // order.
        auto var_name = ModuleDiscardVarName();
        return b.If(var_name, b.Block(stmts));
    }

    // Hoists `sem_expr` into a new local `var` declared before `stmt`, followed
    // by an `IfDiscardReturn`, and returns the statement that should replace
    // `stmt`. In that replacement, the hoisted expression is replaced by a
    // reference to the new var.
    // For example, if `stmt` is:
    //
    //    return f();
    //
    // This function will transform this to:
    //
    //    var t1 = f();
    //    if (tint_discard) {
    //      return <default value>;
    //    }
    //    return t1;
    //
    // If `stmt` is a call statement, the replacement is the phony assignment
    // "_ = t1;", as a call statement can only hold a call expression.
    //
    // NOTE(review): the hoisted declaration is always a `var` initialized with
    // the expression. For a call statement whose callee returns void (e.g. a
    // for-loop initializer `f()`), "var t1 = f();" would not be a valid
    // declaration. Presumably earlier transforms prevent this case; TODO
    // confirm.
    const ast::Statement* HoistAndInsertBefore(const ast::Statement* stmt,
                                               const sem::Expression* sem_expr) {
        auto* expr = sem_expr->Declaration();

        auto ip = utils::GetInsertionPoint(ctx, stmt);
        auto var_name = b.Sym();
        auto* decl = b.Decl(b.Var(var_name, ctx.Clone(expr)));
        ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);

        ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt));

        auto* var_expr = b.Expr(var_name);

        // Special handling for CallStatement as we can only replace its expression
        // with a CallExpression.
        if (stmt->Is<ast::CallStatement>()) {
            // We could replace the call statement with no statement, but we can't do
            // that with transforms (yet), so just return a phony assignment.
            return b.Assign(b.Phony(), var_expr);
        }

        ctx.Replace(expr, var_expr);
        // Clone `stmt` itself without re-entering this transform's replacement for
        // it, while still applying the `expr` -> `var_expr` replacement inside it.
        return ctx.CloneWithoutTransform(stmt);
    }

    // Returns true if `stmt` is the initializer statement of a for-loop.
    bool IsForLoopInitStatement(const ast::Statement* stmt) {
        if (auto* sem_stmt = sem.Get(stmt)) {
            if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
                return sem_fl->Declaration()->initializer == stmt;
            }
        }
        return false;
    }

    // Inserts an `IfDiscardReturn` after `stmt` if possible (i.e. `stmt` is not
    // a for-loop initializer). Otherwise it falls back to HoistAndInsertBefore,
    // which hoists `sem_expr` into a new `var` followed by an `IfDiscardReturn`
    // before `stmt`.
    //
    // For example, if `stmt` is:
    //
    //    let r = f();
    //
    // This function will transform this to:
    //
    //    let r = f();
    //    if (tint_discard) {
    //      return <default value>;
    //    }
    //
    // Returns the replacement statement for `stmt`, or nullptr if `stmt` is kept
    // as-is.
    const ast::Statement* TryInsertAfter(const ast::Statement* stmt,
                                         const sem::Expression* sem_expr) {
        // If `stmt` is the init of a for-loop, hoist and insert before instead.
        if (IsForLoopInitStatement(stmt)) {
            return HoistAndInsertBefore(stmt, sem_expr);
        }

        auto ip = utils::GetInsertionPoint(ctx, stmt);
        ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt));
        return nullptr;  // Don't replace current statement
    }

    // Replaces the input discard statement with either setting the module level
    // discard bool ("tint_discard = true"), or calling the discard function
    // ("tint_discard_func()"), followed by a default return statement.
    //
    // Replaces "discard;" with:
    //
    //    tint_discard = true;
    //    return <default value>;
    //
    // Or if `stmt` is in an entry point function, replaces it with:
    //
    //    tint_discard_func();
    //    return <default value>;
    //
    const ast::Statement* ReplaceDiscardStatement(const ast::DiscardStatement* stmt) {
        const ast::Statement* to_insert = nullptr;
        if (IsInEntryPointFunc(stmt)) {
            to_insert = CallDiscardFunc();
        } else {
            auto var_name = ModuleDiscardVarName();
            to_insert = b.Assign(var_name, true);
        }

        auto ip = utils::GetInsertionPoint(ctx, stmt);
        ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
        return Return(stmt);
    }

    // Handles a single statement. Returns the statement that should replace
    // `stmt`, or nullptr if `stmt` should be cloned unchanged. It may also
    // register insertions before or after `stmt` with `ctx`. It raises an ICE
    // for constructs expected to have been simplified by
    // transform::PromoteSideEffectsToDecl.
    const ast::Statement* Statement(const ast::Statement* stmt) {
        return Switch(
            stmt,
            [&](const ast::DiscardStatement* s) -> const ast::Statement* {
                return ReplaceDiscardStatement(s);
            },
            [&](const ast::AssignmentStatement* s) -> const ast::Statement* {
                auto* sem_lhs = sem.Get(s->lhs);
                auto* sem_rhs = sem.Get(s->rhs);
                if (MayDiscard(sem_lhs)) {
                    if (MayDiscard(sem_rhs)) {
                        TINT_ICE(Transform, b.Diagnostics())
                            << "Unexpected: both sides of assignment statement may "
                               "discard. Make sure transform::PromoteSideEffectsToDecl "
                               "was run first.";
                    }
                    return TryInsertAfter(s, sem_lhs);
                } else if (MayDiscard(sem_rhs)) {
                    return TryInsertAfter(s, sem_rhs);
                }
                return nullptr;
            },
            [&](const ast::CallStatement* s) -> const ast::Statement* {
                auto* sem_expr = sem.Get(s->expr);
                if (!MayDiscard(sem_expr)) {
                    return nullptr;
                }
                return TryInsertAfter(s, sem_expr);
            },
            [&](const ast::ForLoopStatement* s) -> const ast::Statement* {
                if (MayDiscard(sem.Get(s->condition))) {
                    TINT_ICE(Transform, b.Diagnostics())
                        << "Unexpected ForLoopStatement condition that may discard. "
                           "Make sure transform::PromoteSideEffectsToDecl was run "
                           "first.";
                }
                return nullptr;
            },
            [&](const ast::WhileStatement* s) -> const ast::Statement* {
                if (MayDiscard(sem.Get(s->condition))) {
                    TINT_ICE(Transform, b.Diagnostics())
                        << "Unexpected WhileStatement condition that may discard. "
                           "Make sure transform::PromoteSideEffectsToDecl was run "
                           "first.";
                }
                return nullptr;
            },
            [&](const ast::IfStatement* s) -> const ast::Statement* {
                auto* sem_expr = sem.Get(s->condition);
                if (!MayDiscard(sem_expr)) {
                    return nullptr;
                }
                return HoistAndInsertBefore(s, sem_expr);
            },
            [&](const ast::ReturnStatement* s) -> const ast::Statement* {
                auto* sem_expr = sem.Get(s->value);
                if (!MayDiscard(sem_expr)) {
                    return nullptr;
                }
                return HoistAndInsertBefore(s, sem_expr);
            },
            [&](const ast::SwitchStatement* s) -> const ast::Statement* {
                auto* sem_expr = sem.Get(s->condition);
                if (!MayDiscard(sem_expr)) {
                    return nullptr;
                }
                return HoistAndInsertBefore(s, sem_expr);
            },
            [&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
                auto* var = s->variable;
                if (!var->constructor) {
                    return nullptr;
                }
                auto* sem_expr = sem.Get(var->constructor);
                if (!MayDiscard(sem_expr)) {
                    return nullptr;
                }
                return TryInsertAfter(s, sem_expr);
            });
    }

  public:
    /// Constructor
    /// @param ctx_in the context
    explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}

    /// Runs the transform
    void Run() {
        ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
            // Iterate block statements and replace them as needed.
            // NOTE(review): only statements in a block's statement list (plus a
            // for-loop's initializer / continuing) are visited. If the AST allows an
            // if-statement directly as another if's else branch (`else if`), that
            // statement would not be visited here; confirm that earlier transforms
            // rule out a discarding condition there.
            for (auto* stmt : block->statements) {
                if (auto* new_stmt = Statement(stmt)) {
                    ctx.Replace(stmt, new_stmt);
                }

                // Handle for loops, as they are the only other AST node that
                // contains statements outside of BlockStatements.
                if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
                    // Both the initializer and the continuing statement are
                    // optional, so skip them explicitly when absent.
                    if (fl->initializer) {
                        if (auto* new_stmt = Statement(fl->initializer)) {
                            ctx.Replace(fl->initializer, new_stmt);
                        }
                    }
                    if (fl->continuing) {
                        if (auto* new_stmt = Statement(fl->continuing)) {
                            // NOTE: Should never reach here as we cannot discard in a
                            // continuing block.
                            ctx.Replace(fl->continuing, new_stmt);
                        }
                    }
                }
            }

            // Returning nullptr keeps the block itself; the per-statement
            // replacements registered above are applied when it is cloned.
            return nullptr;
        });

        ctx.Clone();
    }
};

}  // namespace

/// Constructor (defaulted).
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
/// Destructor (defaulted).
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;

/// Applies the transform to the program being cloned by `ctx`.
/// All of the work is delegated to a State bound to `ctx`; neither the input
/// nor the output DataMap is used.
void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
    State{ctx}.Run();
}

bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const {
    auto& sem = program->Sem();
    for (auto* f : program->AST().Functions()) {
        if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
            return true;
        }
    }
    return false;
}

}  // namespace tint::transform
