// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/zero_init_workgroup_memory.h"

#include <algorithm>
#include <functional>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/ast/workgroup_decoration.h"
#include "src/program_builder.h"
#include "src/sem/atomic_type.h"
#include "src/sem/function.h"
#include "src/sem/variable.h"
#include "src/utils/get_or_create.h"
#include "src/utils/unique_vector.h"

// Instantiates the type info for the transform class. NOTE(review): presumably
// required by Tint's TypeInfo machinery for this type — the macro is defined
// elsewhere, confirm there.
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);

namespace tint {
namespace transform {

/// PIMPL state for the ZeroInitWorkgroupMemory transform
struct ZeroInitWorkgroupMemory::State {
  /// The clone context. Held by reference, so it must outlive this State.
  CloneContext& ctx;

  /// An alias to *ctx.dst.
  /// The default member initializer reads #ctx, which is declared before this
  /// member and is therefore already bound when the initializer runs.
  ProgramBuilder& b = *ctx.dst;

  /// The constant size of the workgroup. If 0, then #workgroup_size_expr should
  /// be used instead. Assigned by CalculateWorkgroupSize().
  uint32_t workgroup_size_const = 0;
  /// The size of the workgroup as an expression generator. Use if
  /// #workgroup_size_const is 0. Assigned by CalculateWorkgroupSize(); every
  /// call of the generator builds a new expression.
  std::function<ast::Expression*()> workgroup_size_expr;

  /// ArrayIndex describes how an array index is derived from the iteration
  /// index (the local invocation index, or the loop variable when Run() emits a
  /// for-loop). The index has the form:
  ///   `array_index = (iteration_index % modulo) / division`
  struct ArrayIndex {
    /// Right-hand operand of the modulus in the index expression
    uint32_t modulo = 1;
    /// Right-hand operand of the division in the index expression
    uint32_t division = 1;

    /// Equality operator
    /// @param rhs the ArrayIndex to compare against this ArrayIndex
    /// @returns true if `rhs` has the same modulo and division as this
    /// ArrayIndex
    bool operator==(const ArrayIndex& rhs) const {
      if (modulo != rhs.modulo) {
        return false;
      }
      return division == rhs.division;
    }

    /// Hash functor for the ArrayIndex type, for use with hashed containers
    struct Hasher {
      /// @param index the ArrayIndex to calculate a hash for
      /// @returns a hash built from the modulo and division of `index`
      size_t operator()(const ArrayIndex& index) const {
        return utils::Hash(index.modulo, index.division);
      }
    };
  };

  /// A list of unique ArrayIndex values, hashed with ArrayIndex::Hasher
  using ArrayIndices = UniqueVector<ArrayIndex, ArrayIndex::Hasher>;

  /// Expression holds information about an expression that is being built for a
  /// statement that will zero workgroup values.
  struct Expression {
    /// The AST expression node
    ast::Expression* expr = nullptr;
    /// The number of iterations required to zero the value
    uint32_t num_iterations = 0;
    /// All array indices used by this expression
    ArrayIndices array_indices;
  };

  /// Statement holds information about a statement that will zero workgroup
  /// values. Instances are created by BuildZeroingStatements() and grouped by
  /// #num_iterations in Run().
  struct Statement {
    /// The AST statement node
    ast::Statement* stmt;
    /// The number of iterations required to zero the value
    uint32_t num_iterations;
    /// All array indices used by this statement
    ArrayIndices array_indices;
  };

  /// All statements that zero workgroup memory. Appended to by
  /// BuildZeroingStatements() and consumed by Run().
  std::vector<Statement> statements;

  /// A map of ArrayIndex to the name reserved for the `let` declaration of that
  /// index. Entries are created on demand by BuildZeroingStatements() and
  /// looked up by DeclareArrayIndices().
  std::unordered_map<ArrayIndex, Symbol, ArrayIndex::Hasher> array_index_names;

  /// Constructor
  /// @param c the CloneContext used for the transform
  explicit State(CloneContext& c) : ctx(c) {}

  /// Run inserts the workgroup memory zero-initialization logic at the top of
  /// the given function, followed by a call to `workgroupBarrier()`.
  ///
  /// The steps are:
  ///  1. Calculate the linear workgroup size from the function's workgroup
  ///     decoration.
  ///  2. Build the zeroing statements for every workgroup storage class
  ///     variable referenced by `fn`. If there are none, nothing is emitted.
  ///  3. Find the `local_invocation_index` builtin among the entry point
  ///     parameters (either a direct parameter or a member of a struct
  ///     parameter), appending a new parameter if none exists.
  ///  4. Group the statements by the number of iterations they need, and emit
  ///     each group in ascending order as a for-loop, an if-statement or a
  ///     plain block.
  /// @param fn a compute shader entry point function
  void Run(ast::Function* fn) {
    auto& sem = ctx.src->Sem();

    CalculateWorkgroupSize(
        ast::GetDecoration<ast::WorkgroupDecoration>(fn->decorations));

    // Generate a list of statements to zero initialize each of the
    // workgroup storage variables used by `fn`. This will populate #statements.
    auto* func = sem.Get(fn);
    for (auto* var : func->ReferencedModuleVariables()) {
      if (var->StorageClass() == ast::StorageClass::kWorkgroup) {
        BuildZeroingStatements(
            var->Type()->UnwrapRef(), [&](uint32_t num_values) {
              auto var_name = ctx.Clone(var->Declaration()->symbol);
              return Expression{b.Expr(var_name), num_values, ArrayIndices{}};
            });
      }
    }

    if (statements.empty()) {
      return;  // No workgroup variables to initialize.
    }

    // Scan the entry point for an existing local_invocation_index builtin
    // parameter. The scan stops at the first match, whether it is a parameter
    // or a member of a struct parameter.
    std::function<ast::Expression*()> local_index;
    for (auto* param : fn->params) {
      if (auto* builtin =
              ast::GetDecoration<ast::BuiltinDecoration>(param->decorations)) {
        if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
          local_index = [this, param] {
            return b.Expr(ctx.Clone(param->symbol));
          };
          break;
        }
      }

      if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
        for (auto* member : str->Members()) {
          if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
                  member->Declaration()->decorations)) {
            if (builtin->builtin == ast::Builtin::kLocalInvocationIndex) {
              local_index = [this, param, member] {
                auto* param_expr = b.Expr(ctx.Clone(param->symbol));
                auto member_name = ctx.Clone(member->Declaration()->symbol);
                return b.MemberAccessor(param_expr, member_name);
              };
              break;
            }
          }
        }
      }

      if (local_index) {
        // Found as a struct member. The `break` above only left the member
        // loop, so leave the parameter loop too.
        break;
      }
    }
    if (!local_index) {
      // No existing local index parameter. Append one to the entry point.
      auto* param =
          b.Param(b.Symbols().New("local_invocation_index"), b.ty.u32(),
                  {b.Builtin(ast::Builtin::kLocalInvocationIndex)});
      ctx.InsertBack(fn->params, param);
      local_index = [this, param] { return b.Expr(param->symbol); };
    }

    // Take the zeroing statements and bin them by the number of iterations
    // required to zero the workgroup data. We then emit these in blocks,
    // possibly wrapped in if-statements or for-loops.
    // An ordered map is used so that the bins are visited in ascending
    // num_iterations order. Statements keep their original relative order
    // within each bin.
    std::map<uint32_t, std::vector<Statement>> stmts_by_num_iterations;
    for (auto& s : statements) {
      stmts_by_num_iterations[s.num_iterations].emplace_back(s);
    }

    // Loop over the statements, grouped by num_iterations.
    for (auto& bin : stmts_by_num_iterations) {
      auto num_iterations = bin.first;
      auto& stmts = bin.second;

      // Gather all the array indices used by all the statements in the block.
      ArrayIndices array_indices;
      for (auto& s : stmts) {
        for (auto& idx : s.array_indices) {
          array_indices.add(idx);
        }
      }

      // Builds the block holding the array index `let` declarations followed
      // by the zeroing statements of this bin. `iteration` generates the
      // expression for the index of the current iteration.
      auto build_block =
          [&](const std::function<ast::Expression*()>& iteration) {
            auto block =
                DeclareArrayIndices(num_iterations, array_indices, iteration);
            for (auto& s : stmts) {
              block.emplace_back(s.stmt);
            }
            return b.Block(block);
          };

      // Determine the block type used to emit these statements.

      if (workgroup_size_const == 0 || num_iterations > workgroup_size_const) {
        // Either the workgroup size is dynamic, or smaller than num_iterations.
        // In either case, we need to generate a for loop to ensure we
        // initialize all the array elements.
        //
        //  for (var idx : u32 = local_index;
        //           idx < num_iterations;
        //           idx += workgroup_size) {
        //    ...
        //  }
        auto idx = b.Symbols().New("idx");
        auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
        auto* cond = b.create<ast::BinaryExpression>(
            ast::BinaryOp::kLessThan, b.Expr(idx), b.Expr(num_iterations));
        auto* cont = b.Assign(
            idx, b.Add(idx, workgroup_size_const ? b.Expr(workgroup_size_const)
                                                 : workgroup_size_expr()));

        auto* body = build_block([&] { return b.Expr(idx); });
        auto* for_loop = b.For(init, cond, cont, body);
        ctx.InsertFront(fn->body->statements, for_loop);
      } else if (num_iterations < workgroup_size_const) {
        // Workgroup size is a known constant, but is greater than
        // num_iterations. Emit an if statement:
        //
        //  if (local_index < num_iterations) {
        //    ...
        //  }
        auto* cond = b.create<ast::BinaryExpression>(
            ast::BinaryOp::kLessThan, local_index(), b.Expr(num_iterations));
        auto* body = build_block([&] { return b.Expr(local_index()); });
        auto* if_stmt = b.If(cond, body);
        ctx.InsertFront(fn->body->statements, if_stmt);
      } else {
        // Workgroup size exactly equals num_iterations.
        // No need for any conditionals. Just emit a basic block:
        //
        // {
        //    ...
        // }
        auto* body = build_block([&] { return b.Expr(local_index()); });
        ctx.InsertFront(fn->body->statements, body);
      }
    }

    // Append a single workgroup barrier after the zero initialization.
    ctx.InsertFront(fn->body->statements,
                    b.create<ast::CallStatement>(b.Call("workgroupBarrier")));
  }

  /// BuildZeroingExpr is a function that builds a sub-expression used to zero
  /// workgroup values. `num_values` is the number of elements that the
  /// expression will be used to zero. Returns the expression, along with the
  /// number of iterations and the array indices it requires.
  using BuildZeroingExpr = std::function<Expression(uint32_t num_values)>;

  /// BuildZeroingStatements() generates the statements required to zero
  /// initialize the workgroup storage expression of type `ty`.
  /// The generated statements are appended to #statements. Structures and
  /// arrays that cannot be trivially zeroed are decomposed recursively, with
  /// each level wrapping `get_expr` to add a member or index accessor.
  /// @param ty the expression type
  /// @param get_expr a function that builds the AST nodes for the expression.
  void BuildZeroingStatements(const sem::Type* ty,
                              const BuildZeroingExpr& get_expr) {
    // Simple case: a single assignment of a type constructor without operands.
    if (CanTriviallyZero(ty)) {
      auto var = get_expr(1u);
      auto* zero_init = b.Construct(CreateASTTypeFor(ctx, ty));
      statements.emplace_back(Statement{b.Assign(var.expr, zero_init),
                                        var.num_iterations, var.array_indices});
      return;
    }

    // Atomics are zeroed with a call to atomicStore() on the address of the
    // expression, storing the zero value of the atomic's inner type.
    if (auto* atomic = ty->As<sem::Atomic>()) {
      auto* zero_init = b.Construct(CreateASTTypeFor(ctx, atomic->Type()));
      auto expr = get_expr(1u);
      auto* store = b.Call("atomicStore", b.AddressOf(expr.expr), zero_init);
      statements.emplace_back(Statement{b.create<ast::CallStatement>(store),
                                        expr.num_iterations,
                                        expr.array_indices});
      return;
    }

    // Structures that are not trivially zeroable are zeroed member by member.
    if (auto* str = ty->As<sem::Struct>()) {
      for (auto* member : str->Members()) {
        auto name = ctx.Clone(member->Declaration()->symbol);
        // The lambda captures `name` and `get_expr` by reference. It is only
        // invoked from within this recursive call, while both are still alive.
        BuildZeroingStatements(member->Type(), [&](uint32_t num_values) {
          auto s = get_expr(num_values);
          return Expression{b.MemberAccessor(s.expr, name), s.num_iterations,
                            s.array_indices};
        });
      }
      return;
    }

    // Arrays are zeroed element by element, each element being indexed by a
    // `let` value that is declared by DeclareArrayIndices().
    if (auto* arr = ty->As<sem::Array>()) {
      BuildZeroingStatements(arr->ElemType(), [&](uint32_t num_values) {
        // num_values is the number of values to zero for the element type.
        // The number of iterations required to zero the array and its elements
        // is:
        //      `num_values * arr->Count()`
        // The index for this array is:
        //      `(idx % modulo) / division`
        auto modulo = num_values * arr->Count();
        auto division = num_values;
        auto a = get_expr(modulo);
        auto array_indices = a.array_indices;
        array_indices.add(ArrayIndex{modulo, division});
        // Reuse the name already reserved for an identical (modulo, division)
        // pair, or reserve a new one.
        auto index =
            utils::GetOrCreate(array_index_names, ArrayIndex{modulo, division},
                               [&] { return b.Symbols().New("i"); });
        return Expression{b.IndexAccessor(a.expr, index), a.num_iterations,
                          array_indices};
      });
      return;
    }

    // None of the cases above matched the type.
    TINT_UNREACHABLE(Transform, b.Diagnostics())
        << "could not zero workgroup type: " << ty->type_name();
  }

  /// DeclareArrayIndices returns a list of statements that contain the `let`
  /// declarations for all of the ArrayIndices.
  /// Each declaration has the form:
  ///   `let <name> : u32 = (iteration % modulo) / division;`
  /// where the modulo and division operations are omitted when they are not
  /// required.
  /// Every index in `array_indices` must already have a name registered in
  /// #array_index_names.
  /// @param num_iterations the number of iterations for the block
  /// @param array_indices the list of array indices to generate `let`
  ///        declarations for
  /// @param iteration a function that returns the index of the current
  ///         iteration. It is called once per array index.
  /// @returns the list of `let` statements that declare the array indices
  ast::StatementList DeclareArrayIndices(
      uint32_t num_iterations,
      const ArrayIndices& array_indices,
      const std::function<ast::Expression*()>& iteration) {
    ast::StatementList stmts;
    for (auto index : array_indices) {
      auto name = array_index_names.at(index);
      // The modulo is only emitted when num_iterations exceeds index.modulo.
      // Otherwise the iteration index is used directly.
      auto* mod =
          (num_iterations > index.modulo)
              ? b.create<ast::BinaryExpression>(
                    ast::BinaryOp::kModulo, iteration(), b.Expr(index.modulo))
              : iteration();
      // Dividing by 1 is a no-op, so skip the division in that case.
      auto* div = (index.division != 1u) ? b.Div(mod, index.division) : mod;
      auto* decl = b.Decl(b.Const(name, b.ty.u32(), div));
      stmts.emplace_back(decl);
    }
    return stmts;
  }

  /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
  /// #workgroup_size_expr with the linear workgroup size.
  /// The linear size is the product of all the dimensions present in the
  /// decoration. If every dimension has a constant value, then
  /// #workgroup_size_const holds the product and #workgroup_size_expr is null.
  /// Otherwise #workgroup_size_const is set to 0 and #workgroup_size_expr
  /// generates a u32 expression for the product.
  /// @param deco the workgroup decoration applied to the entry point function.
  /// Must not be null, as it is dereferenced unconditionally.
  void CalculateWorkgroupSize(const ast::WorkgroupDecoration* deco) {
    workgroup_size_const = 1u;
    workgroup_size_expr = nullptr;
    for (auto* expr : deco->Values()) {
      if (!expr) {
        continue;  // Dimension was not specified.
      }
      auto* sem = ctx.src->Sem().Get(expr);
      if (auto c = sem->ConstantValue()) {
        if (c.ElementType()->Is<sem::I32>()) {
          workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
          continue;
        } else if (c.ElementType()->Is<sem::U32>()) {
          workgroup_size_const *= c.Elements()[0].u32;
          continue;
        }
      }
      // Constant value could not be found. Build expression instead.
      // The previous generator (if any) is captured by value, so that the new
      // generator multiplies this dimension with the dimensions before it.
      workgroup_size_expr = [this, expr, size = workgroup_size_expr] {
        auto* e = ctx.Clone(expr);
        if (ctx.src->TypeOf(expr)->UnwrapRef()->Is<sem::I32>()) {
          // Cast i32 dimensions to u32, so the product is always a u32.
          e = b.Construct<ProgramBuilder::u32>(e);
        }
        return size ? b.Mul(size(), e) : e;
      };
    }
    if (workgroup_size_expr) {
      if (workgroup_size_const != 1) {
        // Fold workgroup_size_const in to workgroup_size_expr.
        // The generated expression is always u32 (see the cast above), so the
        // constant is multiplied in as an unsigned value.
        workgroup_size_expr = [this, const_size = workgroup_size_const,
                               expr_size = workgroup_size_expr] {
          return b.Mul(expr_size(), const_size);
        };
      }
      // Indicate that workgroup_size_expr should be used instead of the
      // constant.
      workgroup_size_const = 0;
    }
  }

  /// Reports whether a variable with store type `ty` can be zeroed with a
  /// single assignment of a type constructor without operands. When this is
  /// not the case, the initialization has to be decomposed into multiple
  /// sub-initializations.
  /// @param ty the type to inspect
  /// @returns false for atomics, arrays, and structures with a member that
  /// cannot be trivially zeroed; true for every other type
  bool CanTriviallyZero(const sem::Type* ty) {
    if (ty->Is<sem::Atomic>() || ty->Is<sem::Array>()) {
      return false;
    }
    auto* str = ty->As<sem::Struct>();
    if (!str) {
      // True for all other storable types
      return true;
    }
    // A structure is trivially zeroable only if all of its members are.
    for (auto* member : str->Members()) {
      if (!CanTriviallyZero(member->Type())) {
        return false;
      }
    }
    return true;
  }
};

// Defaulted constructor, defined out of line.
ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;

// Defaulted destructor, defined out of line.
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;

// Applies the transform: zero-initialization logic is registered for every
// compute entry point of the source program, using a new State for each one,
// and then the program is cloned. The data map arguments are unused.
void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  for (auto* fn : ctx.src->AST().Functions()) {
    if (fn->PipelineStage() != ast::PipelineStage::kCompute) {
      continue;  // Only compute entry points are processed.
    }
    State state{ctx};
    state.Run(fn);
  }
  ctx.Clone();
}

}  // namespace transform
}  // namespace tint
