// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/ir/builder_impl.h"

#include "src/tint/ast/alias.h"
#include "src/tint/ast/binary_expression.h"
#include "src/tint/ast/block_statement.h"
#include "src/tint/ast/bool_literal_expression.h"
#include "src/tint/ast/break_if_statement.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/ast/continue_statement.h"
#include "src/tint/ast/float_literal_expression.h"
#include "src/tint/ast/for_loop_statement.h"
#include "src/tint/ast/function.h"
#include "src/tint/ast/if_statement.h"
#include "src/tint/ast/int_literal_expression.h"
#include "src/tint/ast/literal_expression.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/return_statement.h"
#include "src/tint/ast/statement.h"
#include "src/tint/ast/static_assert.h"
#include "src/tint/ast/switch_statement.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ir/function.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/loop.h"
#include "src/tint/ir/module.h"
#include "src/tint/ir/switch.h"
#include "src/tint/ir/terminator.h"
#include "src/tint/program.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/module.h"

namespace tint::ir {
namespace {

// Result type produced by BuilderImpl::Build(): wraps the built Module, or a failure.
using ResultType = utils::Result<Module>;

// Scope guard for the builder's flow stack: pushes `node` onto `impl->flow_stack` when
// constructed and pops the stack again when destroyed.
class FlowStackScope {
  public:
    // @param impl the builder implementation whose flow stack is manipulated
    // @param node the flow node to push for the lifetime of this scope
    FlowStackScope(BuilderImpl* impl, FlowNode* node) : impl_(impl) {
        impl_->flow_stack.Push(node);
    }

    ~FlowStackScope() { impl_->flow_stack.Pop(); }

    // A copy would pop the stack a second time for a single push, so copying is disallowed.
    FlowStackScope(const FlowStackScope&) = delete;
    FlowStackScope& operator=(const FlowStackScope&) = delete;

  private:
    BuilderImpl* impl_;
};

// Returns true if the block `b` already has a branch target assigned.
bool IsBranched(const Block* b) {
    const bool has_target = (b->branch_target != nullptr);
    return has_target;
}

// Returns true if `b` is a function node, or if any chain of inbound branches leads back to
// one.
// NOTE(review): no visited set is kept, so termination relies on the shape of the inbound
// graph (e.g. the order of entries in `inbound_branches`) - confirm against the builder.
bool IsConnected(const FlowNode* b) {
    // A function node is the start of the flow graph, so it is connected by definition.
    bool connected = b->Is<ir::Function>();
    if (!connected) {
        for (auto* predecessor : b->inbound_branches) {
            if (IsConnected(predecessor)) {
                connected = true;
                break;
            }
        }
    }
    // If still false, none of the incoming branches are connected.
    return connected;
}

}  // namespace

// Constructs the implementation, handing `program` to the `builder` member.
BuilderImpl::BuilderImpl(const Program* program) : builder(program) {}

// Defaulted destructor.
BuilderImpl::~BuilderImpl() = default;

// Branches the current flow block to `node` and then clears the current flow block. The
// current block must be set and must not have a branch target yet.
void BuilderImpl::BranchTo(FlowNode* node) {
    TINT_ASSERT(IR, current_flow_block);
    TINT_ASSERT(IR, !IsBranched(current_flow_block));

    auto* source = current_flow_block;
    builder.Branch(source, node);

    // The block is finished; nothing further may be emitted into it.
    current_flow_block = nullptr;
}

// Branches the current flow block to `node`, but only when there is a current block and it
// has not been given a branch target already. Otherwise does nothing.
void BuilderImpl::BranchToIfNeeded(FlowNode* node) {
    if (current_flow_block && !IsBranched(current_flow_block)) {
        BranchTo(node);
    }
}

// Walks the flow stack from the most recently pushed node outwards and returns the first
// Loop found. Unless `flags` is ControlFlags::kExcludeSwitch, a Switch is accepted as well.
// Returns nullptr when no such node is on the stack.
FlowNode* BuilderImpl::FindEnclosingControl(ControlFlags flags) {
    const bool switch_allowed = (flags != ControlFlags::kExcludeSwitch);

    for (auto it = flow_stack.rbegin(); it != flow_stack.rend(); ++it) {
        auto* candidate = *it;
        if (candidate->Is<Loop>()) {
            return candidate;
        }
        if (switch_allowed && candidate->Is<Switch>()) {
            return candidate;
        }
    }
    return nullptr;
}

// Walks the program's declarations in dependency order and emits the ones the IR handles.
// Returns the built module, or a failure as soon as emitting a declaration fails.
ResultType BuilderImpl::Build() {
    auto* module_sem = builder.ir.program->Sem().Module();

    for (auto* decl : module_sem->DependencyOrderedDeclarations()) {
        const bool emitted = tint::Switch(
            decl,  //
            // [&](const ast::Struct* str) { },
            [&](const ast::Alias*) {
                // Aliases are folded away; nothing to emit.
                return true;
            },
            // [&](const ast::Variable* var) { },
            [&](const ast::Function* fn) { return EmitFunction(fn); },
            // [&](const ast::Enable*) { },
            [&](const ast::StaticAssert*) {
                // Already evaluated by the resolver; nothing to emit.
                return true;
            },
            [&](Default) {
                // Unhandled declarations only produce a warning and do not stop the build.
                const std::string message = "unknown type: " + std::string(decl->TypeInfo().name);
                diagnostics_.add_warning(tint::diag::System::IR, message, decl->source);
                return true;
            });

        if (!emitted) {
            return utils::Failure;
        }
    }

    return ResultType{std::move(builder.ir)};
}

// Creates the IR function for `ast_func`, registers it with the module (and as an entry point
// when applicable), then emits the function body. Returns false if emitting the body fails.
bool BuilderImpl::EmitFunction(const ast::Function* ast_func) {
    // The flow stack should have been emptied when the previous function finished building.
    TINT_ASSERT(IR, flow_stack.IsEmpty());

    auto* fn = builder.CreateFunction(ast_func);
    current_function_ = fn;
    builder.ir.functions.Push(fn);

    ast_to_flow_[ast_func] = fn;

    const bool is_entry_point = ast_func->IsEntryPoint();
    if (is_entry_point) {
        builder.ir.entry_points.Push(fn);
    }

    {
        FlowStackScope scope(this, fn);

        current_flow_block = fn->start_target;
        const bool body_ok = EmitStatements(ast_func->body->statements);
        if (!body_ok) {
            return false;
        }

        // A `return` will already have set the branch target; only branch to the function end
        // when the body fell off the end without one.
        BranchToIfNeeded(current_function_->end_target);
    }

    TINT_ASSERT(IR, flow_stack.IsEmpty());

    // Reset the per-function state.
    current_flow_block = nullptr;
    current_function_ = nullptr;

    return true;
}

// Emits each statement of `stmts` in order. Stops early, still reporting success, once the
// current flow block is gone or has branched. Returns false if any statement fails to emit.
bool BuilderImpl::EmitStatements(utils::VectorRef<const ast::Statement*> stmts) {
    for (auto* statement : stmts) {
        const bool emitted = EmitStatement(statement);
        if (!emitted) {
            return false;
        }

        // Once the current block has a branch target (or there is no current block), the
        // remaining statements of this list are dead code and are not emitted.
        const bool terminated = !current_flow_block || IsBranched(current_flow_block);
        if (terminated) {
            return true;
        }
    }
    return true;
}

// Dispatches `stmt` to the emitter matching its concrete AST type. Statement kinds without an
// emitter yet only produce a warning and are treated as success.
bool BuilderImpl::EmitStatement(const ast::Statement* stmt) {
    return tint::Switch(
        stmt,
        // [&](const ast::AssignmentStatement* a) { },
        [&](const ast::BlockStatement* block) { return EmitBlock(block); },
        [&](const ast::BreakStatement* brk) { return EmitBreak(brk); },
        [&](const ast::BreakIfStatement* brk_if) { return EmitBreakIf(brk_if); },
        // [&](const ast::CallStatement* c) { },
        // [&](const ast::CompoundAssignmentStatement* c) { },
        [&](const ast::ContinueStatement* cont) { return EmitContinue(cont); },
        // [&](const ast::DiscardStatement* d) { },
        [&](const ast::IfStatement* if_stmt) { return EmitIf(if_stmt); },
        [&](const ast::LoopStatement* loop) { return EmitLoop(loop); },
        [&](const ast::ForLoopStatement* for_loop) { return EmitForLoop(for_loop); },
        [&](const ast::WhileStatement* while_loop) { return EmitWhile(while_loop); },
        [&](const ast::ReturnStatement* ret) { return EmitReturn(ret); },
        [&](const ast::SwitchStatement* sw) { return EmitSwitch(sw); },
        [&](const ast::VariableDeclStatement* decl) { return EmitVariable(decl->variable); },
        [&](const ast::StaticAssert*) {
            // Static asserts produce no IR.
            return true;
        },
        [&](Default) {
            const std::string message =
                "unknown statement type: " + std::string(stmt->TypeInfo().name);
            diagnostics_.add_warning(tint::diag::System::IR, message, stmt->source);
            return true;
        });
}

// Emits the statements of `block` directly into the current flow block.
bool BuilderImpl::EmitBlock(const ast::BlockStatement* block) {
    // No dedicated Block node is created here: blocks are flattened into the current flow
    // block, and each control-flow construct adds whatever basic blocks it needs itself.
    const bool emitted = EmitStatements(block->statements);
    return emitted;
}

// Emits an `if` statement: the condition goes into the current block, which then branches to
// a new If node whose true and false targets receive the body and the optional else.
bool BuilderImpl::EmitIf(const ast::IfStatement* stmt) {
    auto* if_node = builder.CreateIf(stmt);

    // The condition is evaluated at the end of the preceding block.
    auto condition = EmitExpression(stmt->condition);
    if (!condition) {
        return false;
    }
    if_node->condition = condition.Get();

    BranchTo(if_node);

    ast_to_flow_[stmt] = if_node;

    {
        FlowStackScope scope(this, if_node);

        // True branch.
        current_flow_block = if_node->true_target;
        if (!EmitStatement(stmt->body)) {
            return false;
        }
        // Fall through to the merge target unless the true branch already branched away.
        BranchToIfNeeded(if_node->merge_target);

        // False branch (possibly empty).
        current_flow_block = if_node->false_target;
        if (stmt->else_statement) {
            if (!EmitStatement(stmt->else_statement)) {
                return false;
            }
        }
        // Fall through to the merge target unless the false branch already branched away.
        BranchToIfNeeded(if_node->merge_target);
    }

    // When both branches returned, continued or broke, nothing reaches the merge block, so
    // there is no block to continue emitting into.
    current_flow_block = IsConnected(if_node->merge_target) ? if_node->merge_target : nullptr;

    return true;
}

// Emits a `loop` statement: the body goes into the loop's start target and the optional
// continuing statement into the continuing target, which branches back to the start.
bool BuilderImpl::EmitLoop(const ast::LoopStatement* stmt) {
    auto* loop_node = builder.CreateLoop(stmt);

    BranchTo(loop_node);

    ast_to_flow_[stmt] = loop_node;

    {
        FlowStackScope scope(this, loop_node);

        current_flow_block = loop_node->start_target;
        const bool body_ok = EmitStatement(stmt->body);
        if (!body_ok) {
            return false;
        }

        // A body that did not `break`, `return` or `continue` falls into the continuing block.
        BranchToIfNeeded(loop_node->continuing_target);

        current_flow_block = loop_node->continuing_target;
        if (stmt->continuing && !EmitStatement(stmt->continuing)) {
            return false;
        }

        // Unless the continuing block branched out already, it loops back to the start.
        BranchToIfNeeded(loop_node->start_target);
    }

    // The merge block can end up disconnected: either the loop returns directly, or only the
    // continuing block leads to the merge and nothing branched to the continuing block.
    current_flow_block = IsConnected(loop_node->merge_target) ? loop_node->merge_target : nullptr;
    return true;
}

// Emits a `while` statement as a Loop node whose start target evaluates the condition and
// feeds an `if (cond) {} else { break; }` node; the body follows in that if's merge block.
bool BuilderImpl::EmitWhile(const ast::WhileStatement* stmt) {
    auto* loop_node = builder.CreateLoop(stmt);
    // A `while` has no continuing statements, so the continuing block goes straight back to
    // the start.
    builder.Branch(loop_node->continuing_target, loop_node->start_target);

    BranchTo(loop_node);

    ast_to_flow_[stmt] = loop_node;

    {
        FlowStackScope scope(this, loop_node);

        current_flow_block = loop_node->start_target;

        // The condition is evaluated at the top of every iteration.
        auto condition = EmitExpression(stmt->condition);
        if (!condition) {
            return false;
        }

        // Build the `if (cond) {} else {break;}` control flow.
        auto* cond_if = builder.CreateIf(nullptr);
        builder.Branch(cond_if->true_target, cond_if->merge_target);
        builder.Branch(cond_if->false_target, loop_node->merge_target);
        cond_if->condition = condition.Get();

        BranchTo(cond_if);

        // The loop body is emitted after the condition check.
        current_flow_block = cond_if->merge_target;
        const bool body_ok = EmitStatement(stmt->body);
        if (!body_ok) {
            return false;
        }

        BranchToIfNeeded(loop_node->continuing_target);
    }

    // The merge target is always reachable here, because the conditional break precedes
    // everything else inside the loop.
    current_flow_block = loop_node->merge_target;
    return true;
}

// Emits a `for` statement as a Loop node.
//
// The initializer (if any) is emitted into the block preceding the loop. The condition (if
// any) is emitted into the loop's start target as an `if (cond) {} else { break; }` node. The
// body follows, and the continuing statement (if any) is emitted into the continuing target,
// which branches back to the start target.
//
// @param stmt the for-loop statement to emit
// @returns false if emitting the initializer, condition, body or continuing statement fails
bool BuilderImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
    auto* loop_node = builder.CreateLoop(stmt);
    // The continuing block always goes back to the start of the loop.
    builder.Branch(loop_node->continuing_target, loop_node->start_target);

    if (stmt->initializer) {
        // Emit the for initializer before branching to the loop
        if (!EmitStatement(stmt->initializer)) {
            return false;
        }
    }

    BranchTo(loop_node);

    ast_to_flow_[stmt] = loop_node;

    {
        FlowStackScope scope(this, loop_node);

        current_flow_block = loop_node->start_target;

        if (stmt->condition) {
            // Emit the condition into the start target of the loop
            auto reg = EmitExpression(stmt->condition);
            if (!reg) {
                return false;
            }

            // Create an if (cond) {} else {break;} control flow
            auto* if_node = builder.CreateIf(nullptr);
            builder.Branch(if_node->true_target, if_node->merge_target);
            builder.Branch(if_node->false_target, loop_node->merge_target);
            if_node->condition = reg.Get();

            BranchTo(if_node);
            current_flow_block = if_node->merge_target;
        }

        if (!EmitStatement(stmt->body)) {
            return false;
        }

        BranchToIfNeeded(loop_node->continuing_target);

        if (stmt->continuing) {
            current_flow_block = loop_node->continuing_target;
            if (!EmitStatement(stmt->continuing)) {
                return false;
            }
        }
    }

    // With a condition, the conditional break emitted above comes before anything else inside
    // the loop, so the merge target is treated as reachable. Without a condition there is no
    // such break: if the body never breaks (e.g. `for (;;) { return; }`) nothing branches to
    // the merge target, and, as in EmitLoop(), there is then no block to continue emitting into.
    current_flow_block = loop_node->merge_target;
    if (!stmt->condition && !IsConnected(loop_node->merge_target)) {
        current_flow_block = nullptr;
    }
    return true;
}

// Emits a `switch` statement: the condition goes into the current block, which branches to a
// new Switch node; each case body is emitted into its own case block.
bool BuilderImpl::EmitSwitch(const ast::SwitchStatement* stmt) {
    auto* switch_node = builder.CreateSwitch(stmt);

    // The condition is evaluated at the end of the preceding block.
    auto condition = EmitExpression(stmt->condition);
    if (!condition) {
        return false;
    }
    switch_node->condition = condition.Get();

    BranchTo(switch_node);

    ast_to_flow_[stmt] = switch_node;

    {
        FlowStackScope scope(this, switch_node);

        for (const auto* case_stmt : stmt->body) {
            current_flow_block = builder.CreateCase(switch_node, case_stmt->selectors);
            const bool body_ok = EmitStatement(case_stmt->body);
            if (!body_ok) {
                return false;
            }
            // A case that did not branch elsewhere continues at the merge target.
            BranchToIfNeeded(switch_node->merge_target);
        }
    }

    // Only continue in the merge block when something actually reaches it.
    current_flow_block =
        IsConnected(switch_node->merge_target) ? switch_node->merge_target : nullptr;

    return true;
}

// Emits a `return` by branching the current block to the end target of the current function.
bool BuilderImpl::EmitReturn(const ast::ReturnStatement*) {
    // TODO(dsinclair): Emit the return value ....

    auto* function_end = current_function_->end_target;
    BranchTo(function_end);
    return true;
}

// Emits a `break` by branching the current block to the merge target of the nearest enclosing
// loop or switch. Returns false if the enclosing control node is neither.
bool BuilderImpl::EmitBreak(const ast::BreakStatement*) {
    auto* current_control = FindEnclosingControl(ControlFlags::kNone);
    TINT_ASSERT(IR, current_control);

    if (auto* enclosing_loop = current_control->As<Loop>()) {
        BranchTo(enclosing_loop->merge_target);
        return true;
    }
    if (auto* enclosing_switch = current_control->As<Switch>()) {
        BranchTo(enclosing_switch->merge_target);
        return true;
    }

    TINT_UNREACHABLE(IR, diagnostics_);
    return false;
}

// Emits a `continue` by branching the current block to the continuing target of the nearest
// enclosing loop. Switch nodes on the flow stack are skipped when searching.
//
// @returns true on success, false if the enclosing control node is not a loop
bool BuilderImpl::EmitContinue(const ast::ContinueStatement*) {
    auto* current_control = FindEnclosingControl(ControlFlags::kExcludeSwitch);
    TINT_ASSERT(IR, current_control);

    if (auto* c = current_control->As<Loop>()) {
        BranchTo(c->continuing_target);
    } else {
        TINT_UNREACHABLE(IR, diagnostics_);
        // Report the failure to the caller, matching the unreachable branch of EmitBreak().
        return false;
    }

    return true;
}

// Emits a `break if` statement as an If node: the true target branches to the enclosing loop's
// merge target, the false target falls through to the if's merge block, and that merge block
// branches back to the loop's start target.
bool BuilderImpl::EmitBreakIf(const ast::BreakIfStatement* stmt) {
    auto* break_if = builder.CreateIf(stmt);

    // The condition is evaluated at the end of the preceding block.
    auto condition = EmitExpression(stmt->condition);
    if (!condition) {
        return false;
    }
    break_if->condition = condition.Get();

    BranchTo(break_if);

    ast_to_flow_[stmt] = break_if;

    auto* current_control = FindEnclosingControl(ControlFlags::kExcludeSwitch);
    TINT_ASSERT(IR, current_control);
    TINT_ASSERT(IR, current_control->Is<Loop>());

    auto* enclosing_loop = current_control->As<Loop>();

    // Condition true: leave the loop.
    current_flow_block = break_if->true_target;
    BranchTo(enclosing_loop->merge_target);

    // Condition false: carry on to the if's merge block.
    current_flow_block = break_if->false_target;
    BranchTo(break_if->merge_target);

    current_flow_block = break_if->merge_target;

    // The `break-if` has to be the last item in the continuing block, so when the loop is not
    // exited the only thing left to do is go back to the start of the loop.
    BranchTo(enclosing_loop->start_target);

    return true;
}

// Dispatches `expr` to the emitter matching its concrete AST type. Expression kinds without an
// emitter produce a warning and a failure result.
utils::Result<const Value*> BuilderImpl::EmitExpression(const ast::Expression* expr) {
    return tint::Switch(
        expr,
        // [&](const ast::IndexAccessorExpression* a) { return EmitIndexAccessor(a); },
        [&](const ast::BinaryExpression* binary) { return EmitBinary(binary); },
        // [&](const ast::BitcastExpression* b) { return EmitBitcast(b); },
        // [&](const ast::CallExpression* c) { return EmitCall(c); },
        // [&](const ast::IdentifierExpression* i) { return EmitIdentifier(i); },
        [&](const ast::LiteralExpression* literal) { return EmitLiteral(literal); },
        // [&](const ast::MemberAccessorExpression* m) { return EmitMemberAccessor(m); },
        // [&](const ast::PhonyExpression*) { return true; },
        // [&](const ast::UnaryOpExpression* u) { return EmitUnaryOp(u); },
        [&](Default) {
            const std::string message =
                "unknown expression type: " + std::string(expr->TypeInfo().name);
            diagnostics_.add_warning(tint::diag::System::IR, message, expr->source);
            return utils::Failure;
        });
}

// Placeholder dispatcher for variable declarations. No variable kind has an emitter yet, so
// every variable produces a warning and the function returns false.
bool BuilderImpl::EmitVariable(const ast::Variable* var) {
    return tint::Switch(  //
        var,
        // [&](const ast::Var* var) {},
        // [&](const ast::Let*) {},
        // [&](const ast::Override*) { },
        // [&](const ast::Const* c) { },
        [&](Default) {
            const std::string message = "unknown variable: " + std::string(var->TypeInfo().name);
            diagnostics_.add_warning(tint::diag::System::IR, message, var->source);
            return false;
        });
}

// Emits a binary expression: both operands are emitted first (left, then right), then the
// matching binary instruction is created and appended to the current flow block.
//
// @param expr the binary expression to emit
// @returns the result value of the new instruction, or a failure if an operand fails to emit,
// semantic information for `expr` is missing, or the operator is not handled
utils::Result<const Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) {
    auto lhs = EmitExpression(expr->lhs);
    if (!lhs) {
        return utils::Failure;
    }

    auto rhs = EmitExpression(expr->rhs);
    if (!rhs) {
        return utils::Failure;
    }

    auto* sem = builder.ir.program->Sem().Get(expr);
    if (!sem) {
        // Same handling as EmitLiteral(): report the problem instead of dereferencing null.
        diagnostics_.add_error(
            tint::diag::System::IR,
            "Failed to get semantic information for node " + std::string(expr->TypeInfo().name),
            expr->source);
        return utils::Failure;
    }

    auto* ty = sem->Type();
    auto* lhs_val = lhs.Get();
    auto* rhs_val = rhs.Get();

    const Binary* instr = nullptr;
    switch (expr->op) {
        case ast::BinaryOp::kAnd:
            instr = builder.And(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kOr:
            instr = builder.Or(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kXor:
            instr = builder.Xor(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kLogicalAnd:
            instr = builder.LogicalAnd(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kLogicalOr:
            instr = builder.LogicalOr(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kEqual:
            instr = builder.Equal(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kNotEqual:
            instr = builder.NotEqual(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kLessThan:
            instr = builder.LessThan(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kGreaterThan:
            instr = builder.GreaterThan(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kLessThanEqual:
            instr = builder.LessThanEqual(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kGreaterThanEqual:
            instr = builder.GreaterThanEqual(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kShiftLeft:
            instr = builder.ShiftLeft(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kShiftRight:
            instr = builder.ShiftRight(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kAdd:
            instr = builder.Add(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kSubtract:
            instr = builder.Subtract(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kMultiply:
            instr = builder.Multiply(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kDivide:
            instr = builder.Divide(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kModulo:
            instr = builder.Modulo(ty, lhs_val, rhs_val);
            break;
        case ast::BinaryOp::kNone:
            TINT_ICE(IR, diagnostics_) << "missing binary operand type";
            return utils::Failure;
    }

    // The switch deliberately has no `default`. If `op` held a value matching none of the
    // cases, `instr` is still null and must not be dereferenced.
    if (!instr) {
        TINT_ICE(IR, diagnostics_) << "unhandled binary operator";
        return utils::Failure;
    }

    current_flow_block->instructions.Push(instr);
    return utils::Result<const Value*>(instr->Result());
}

// Emits a literal as an IR constant built from the literal's constant value. Returns a
// failure, with an error diagnostic, when the semantic node or its constant value is missing.
utils::Result<const Value*> BuilderImpl::EmitLiteral(const ast::LiteralExpression* lit) {
    // Adds an error of the form "<prefix><node type name>" at the literal's source location.
    const auto report = [&](const char* prefix) {
        diagnostics_.add_error(tint::diag::System::IR,
                               std::string(prefix) + std::string(lit->TypeInfo().name),
                               lit->source);
    };

    auto* sem = builder.ir.program->Sem().Get(lit);
    if (!sem) {
        report("Failed to get semantic information for node ");
        return utils::Failure;
    }

    auto* constant_value = sem->ConstantValue();
    if (!constant_value) {
        report("Failed to get constant value for node ");
        return utils::Failure;
    }

    return utils::Result<const Value*>(builder.Constant(constant_value));
}

// Placeholder dispatcher for AST types. No type has an emitter yet, so every type produces a
// warning and the function returns false.
bool BuilderImpl::EmitType(const ast::Type* ty) {
    return tint::Switch(
        ty,
        // [&](const ast::Array* ary) { },
        // [&](const ast::Bool* b) { },
        // [&](const ast::F32* f) { },
        // [&](const ast::F16* f) { },
        // [&](const ast::I32* i) { },
        // [&](const ast::U32* u) { },
        // [&](const ast::Vector* v) { },
        // [&](const ast::Matrix* mat) { },
        // [&](const ast::Pointer* ptr) { },
        // [&](const ast::Atomic* a) { },
        // [&](const ast::Sampler* s) { },
        // [&](const ast::ExternalTexture* t) { },
        // [&](const ast::Texture* t) {
        //      return tint::Switch(
        //          t,
        //          [&](const ast::DepthTexture*) { },
        //          [&](const ast::DepthMultisampledTexture*) { },
        //          [&](const ast::SampledTexture*) { },
        //          [&](const ast::MultisampledTexture*) { },
        //          [&](const ast::StorageTexture*) {  },
        //          [&](Default) {
        //              diagnostics_.add_warning(tint::diag::System::IR,
        //                  "unknown texture: " + std::string(t->TypeInfo().name), t->source);
        //              return false;
        //          });
        // },
        // [&](const ast::Void* v) { },
        // [&](const ast::TypeName* tn) { },
        [&](Default) {
            const std::string message = "unknown type: " + std::string(ty->TypeInfo().name);
            diagnostics_.add_warning(tint::diag::System::IR, message, ty->source);
            return false;
        });
}

// Emits every attribute in `attrs` in order, stopping at and reporting the first failure.
bool BuilderImpl::EmitAttributes(utils::VectorRef<const ast::Attribute*> attrs) {
    for (auto* attribute : attrs) {
        const bool emitted = EmitAttribute(attribute);
        if (!emitted) {
            return false;
        }
    }
    return true;
}

// Placeholder dispatcher for attributes. No attribute has an emitter yet, so every attribute
// produces a warning and the function returns false.
bool BuilderImpl::EmitAttribute(const ast::Attribute* attr) {
    return tint::Switch(  //
        attr,
        // [&](const ast::WorkgroupAttribute* wg) {},
        // [&](const ast::StageAttribute* s) {},
        // [&](const ast::BindingAttribute* b) {},
        // [&](const ast::GroupAttribute* g) {},
        // [&](const ast::LocationAttribute* l) {},
        // [&](const ast::BuiltinAttribute* b) {},
        // [&](const ast::InterpolateAttribute* i) {},
        // [&](const ast::InvariantAttribute* i) {},
        // [&](const ast::IdAttribute* i) {},
        // [&](const ast::StructMemberSizeAttribute* s) {},
        // [&](const ast::StructMemberAlignAttribute* a) {},
        // [&](const ast::StrideAttribute* s) {}
        // [&](const ast::InternalAttribute *i) {},
        [&](Default) {
            const std::string message =
                "unknown attribute: " + std::string(attr->TypeInfo().name);
            diagnostics_.add_warning(tint::diag::System::IR, message, attr->source);
            return false;
        });
}

}  // namespace tint::ir
