// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_AST_TRAVERSE_EXPRESSIONS_H_
#define SRC_AST_TRAVERSE_EXPRESSIONS_H_

#include <vector>

#include "src/ast/array_accessor_expression.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/call_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/phony_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/utils/reverse.h"

namespace tint {
namespace ast {

/// The value returned by the TraverseExpressions() callback, telling the
/// traversal how to proceed after the callback has seen an expression.
enum class TraverseAction {
  /// End the whole traversal now. No further expressions are visited.
  Stop,
  /// Continue the traversal with the children of this expression.
  Descend,
  /// Continue the traversal, but without visiting the children of this
  /// expression.
  Skip,
};

/// Selects the order in which TraverseExpressions() visits the children of an
/// expression.
enum class TraverseOrder {
  /// Children are visited left-most first
  LeftToRight,
  /// Children are visited right-most first
  RightToLeft,
};

/// Walks the expression tree rooted at `root` depth-first, calling `callback`
/// in pre-order (a parent before any of its children) for every visited
/// expression whose type matches the callback's parameter type. Expressions
/// that do not match are not handed to `callback`, but their children are
/// still traversed.
/// The target of a call expression and the member name of a member accessor
/// expression are currently not traversed (see crbug.com/tint/1257).
/// @tparam ORDER the order in which the children of an expression are visited
/// @param root the expression at which the traversal starts
/// @param diags the diagnostics that receive an internal compiler error when
///        an unhandled expression type is encountered
/// @param callback the function called for each matching expression. Must be
///        of the signature `TraverseAction(const T*)` where T is an
///        ast::Expression type.
/// @return true if the traversal ran to completion or was ended by the
///         callback returning TraverseAction::Stop, false if an unhandled
///         expression type was encountered
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
bool TraverseExpressions(const ast::Expression* root,
                         diag::List& diags,
                         CALLBACK&& callback) {
  using EXPR_TYPE = std::remove_pointer_t<traits::ParamTypeT<CALLBACK, 0>>;

  // Explicit stack of expressions awaiting a visit. The back is visited next,
  // so children are pushed in the reverse of the order they should be visited.
  std::vector<const ast::Expression*> pending;
  pending.push_back(root);

  // Schedules two sibling expressions so that they are visited in ORDER.
  auto enqueue_two = [&](const ast::Expression* lhs,
                         const ast::Expression* rhs) {
    const bool left_first = ORDER == TraverseOrder::LeftToRight;
    pending.push_back(left_first ? rhs : lhs);
    pending.push_back(left_first ? lhs : rhs);
  };

  // Schedules a list of sibling expressions so that they are visited in ORDER.
  auto enqueue_all =
      [&](const std::vector<const ast::Expression*>& children) {
        if (ORDER == TraverseOrder::LeftToRight) {
          pending.insert(pending.end(), children.rbegin(), children.rend());
        } else {
          pending.insert(pending.end(), children.begin(), children.end());
        }
      };

  while (!pending.empty()) {
    const ast::Expression* current = pending.back();
    pending.pop_back();

    // Only expressions of the callback's parameter type reach the callback.
    if (auto* matched = current->As<EXPR_TYPE>()) {
      const TraverseAction action = callback(matched);
      if (action == TraverseAction::Stop) {
        return true;
      }
      if (action == TraverseAction::Skip) {
        continue;
      }
      // TraverseAction::Descend: carry on and schedule the children.
    }

    if (auto* accessor = current->As<ast::ArrayAccessorExpression>()) {
      enqueue_two(accessor->array, accessor->index);
      continue;
    }
    if (auto* binary = current->As<ast::BinaryExpression>()) {
      enqueue_two(binary->lhs, binary->rhs);
      continue;
    }
    if (auto* cast = current->As<ast::BitcastExpression>()) {
      pending.push_back(cast->expr);
      continue;
    }
    if (auto* call = current->As<ast::CallExpression>()) {
      // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
      // function name in the traversal.
      // pending.push_back(call->func);
      enqueue_all(call->args);
      continue;
    }
    if (auto* ctor = current->As<ast::TypeConstructorExpression>()) {
      enqueue_all(ctor->values);
      continue;
    }
    if (auto* member = current->As<ast::MemberAccessorExpression>()) {
      // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
      // member name in the traversal.
      // enqueue_two(member->structure, member->member);
      pending.push_back(member->structure);
      continue;
    }
    if (auto* unary = current->As<ast::UnaryOpExpression>()) {
      pending.push_back(unary->expr);
      continue;
    }
    if (current->IsAnyOf<ast::ScalarConstructorExpression,
                         ast::IdentifierExpression,
                         ast::PhonyExpression>()) {
      // Leaf expression: there are no children to schedule.
      continue;
    }

    // Every known expression type was handled above, so reaching this point
    // means the expression is of a type this function does not know about.
    TINT_ICE(AST, diags) << "unhandled expression type: "
                         << current->TypeInfo().name;
    return false;
  }
  return true;
}

}  // namespace ast
}  // namespace tint

#endif  // SRC_AST_TRAVERSE_EXPRESSIONS_H_
