// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/msl.h"

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/ast/disable_validation_decoration.h"
#include "src/program_builder.h"
#include "src/sem/call.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
#include "src/transform/array_length_from_uniform.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/pad_array_elements.h"
#include "src/transform/promote_initializers_to_const_var.h"
#include "src/transform/simplify.h"
#include "src/transform/wrap_arrays_in_structs.h"
#include "src/transform/zero_init_workgroup_memory.h"

// Type-info instantiations for the Msl transform and for its Config and Result
// data types.
// NOTE(review): the exact expansion of TINT_INSTANTIATE_TYPEINFO is defined
// outside this file - confirm there before relying on details.
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result);

namespace tint {
namespace transform {

// The constructor and destructor are compiler-generated; they are defined here
// (out of line) rather than in the class declaration.
Msl::Msl() = default;
Msl::~Msl() = default;

// Runs the MSL sanitizing pipeline over `in`.
//
// An optional Msl::Config may be supplied through `inputs`; when absent, the
// buffer-size uniform index falls back to kDefaultBufferSizeUniformIndex and
// the fixed sample mask to 0xFFFFFFFF.
//
// The program is first run through a chain of internal transforms. If that
// chain yields an invalid program its output is returned unchanged. Otherwise
// the result is cloned once more with module-scope variables rewritten (see
// HandleModuleScopeVariables()), and returned together with an Msl::Result
// that records whether ArrayLengthFromUniform reported needing buffer sizes.
Output Msl::Run(const Program* in, const DataMap& inputs) {
  // Start from the defaults, then apply any caller-provided overrides.
  uint32_t ubo_index = kDefaultBufferSizeUniformIndex;
  uint32_t sample_mask = 0xFFFFFFFF;
  if (auto* user_cfg = inputs.Get<Config>()) {
    ubo_index = user_cfg->buffer_size_ubo_index;
    sample_mask = user_cfg->fixed_sample_mask;
  }

  // Configurations for the internal transforms.
  ArrayLengthFromUniform::Config length_cfg(sem::BindingPoint{0, ubo_index});
  CanonicalizeEntryPointIO::Config io_cfg(
      CanonicalizeEntryPointIO::BuiltinStyle::kParameter, sample_mask);

  // Each storage buffer looks up its size at the index given by its own
  // binding number.
  for (auto* global : in->AST().GlobalVariables()) {
    auto* global_sem = in->Sem().Get(global);
    if (global_sem->StorageClass() != ast::StorageClass::kStorage) {
      continue;
    }
    const auto& bp = global_sem->BindingPoint();
    length_cfg.bindpoint_to_size_index.emplace(bp, bp.binding);
  }

  // Hand the configurations over to the data map given to the manager.
  DataMap internal_inputs;
  internal_inputs.Add<ArrayLengthFromUniform::Config>(std::move(length_cfg));
  internal_inputs.Add<CanonicalizeEntryPointIO::Config>(std::move(io_cfg));

  // The order of the transforms below is significant:
  //  * ZeroInitWorkgroupMemory may inject new builtin parameters, so it has to
  //    run ahead of CanonicalizeEntryPointIO.
  //  * ArrayLengthFromUniform assumes the array length argument has the form
  //    &var.array, so it has to run after InlinePointerLets and Simplify.
  Manager manager;
  manager.Add<ZeroInitWorkgroupMemory>();
  manager.Add<CanonicalizeEntryPointIO>();
  manager.Add<ExternalTextureTransform>();
  manager.Add<PromoteInitializersToConstVar>();
  manager.Add<WrapArraysInStructs>();
  manager.Add<PadArrayElements>();
  manager.Add<InlinePointerLets>();
  manager.Add<Simplify>();
  manager.Add<ArrayLengthFromUniform>();

  auto out = manager.Run(in, internal_inputs);
  if (!out.program.IsValid()) {
    // Nothing more can be done; forward the failed output as-is.
    return out;
  }

  // Final pass: clone the program while relocating module-scope variables.
  // TODO(jrprice): Consider making this a standalone transform, with target
  // storage class(es) as transform options.
  ProgramBuilder builder;
  CloneContext ctx(&builder, &out.program);
  HandleModuleScopeVariables(ctx);
  ctx.Clone();

  // Surface what ArrayLengthFromUniform reported to our own caller.
  auto* length_result = out.data.Get<ArrayLengthFromUniform::Result>();
  auto result = std::make_unique<Result>(length_result->needs_buffer_sizes);

  builder.SetTransformApplied(this);
  return Output{Program(std::move(builder)), std::move(result)};
}

void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
  // MSL does not allow private and workgroup variables at module-scope.
  // Their declarations are therefore moved into the entry point function, and
  // every other function that references them receives them as pointer
  // parameters. Texture and sampler variables are treated similarly: they
  // become entry point parameters and are passed by value to the functions
  // that need them.
  //
  // WGSL does not allow function-scope variables to have these storage
  // classes, so the new declarations carry an attribute that bypasses that
  // validation rule.
  //
  // Before:
  // ```
  // var<private> v : f32 = 2.0;
  //
  // fn foo() {
  //   v = v + 1.0;
  // }
  //
  // [[stage(compute), workgroup_size(1)]]
  // fn main() {
  //   foo();
  // }
  // ```
  //
  // After:
  // ```
  // fn foo(v : ptr<private, f32>) {
  //   *v = *v + 1.0;
  // }
  //
  // [[stage(compute), workgroup_size(1)]]
  // fn main() {
  //   var<private> v : f32 = 2.0;
  //   foo(&v);
  // }
  // ```

  // Returns true for the storage classes whose module-scope variables are
  // relocated by this function.
  auto is_relocated_class = [](ast::StorageClass storage_class) {
    return storage_class == ast::StorageClass::kPrivate ||
           storage_class == ast::StorageClass::kWorkgroup ||
           storage_class == ast::StorageClass::kUniformConstant;
  };

  const auto& src_ast = ctx.src->AST();
  const auto& src_sem = ctx.src->Sem();

  // Functions that (transitively) reference a relocated variable, in
  // declaration order.
  std::vector<ast::Function*> pending_functions;

  // For each calling function, the calls inside it that must be given extra
  // arguments.
  using CallList = std::vector<const ast::CallExpression*>;
  std::unordered_map<const ast::Function*, CallList> calls_by_caller;

  // Pass 1: work out which functions and which call expressions are affected.
  for (auto* fn : src_ast.Functions()) {
    auto* fn_sem = src_sem.Get(fn);

    bool touches_relocated_var = false;
    for (auto* referenced : fn_sem->ReferencedModuleVariables()) {
      if (is_relocated_class(referenced->StorageClass())) {
        touches_relocated_var = true;
        break;
      }
    }
    if (!touches_relocated_var) {
      continue;
    }

    pending_functions.push_back(fn);

    // Every call to this function needs rewriting; file each call under the
    // function that contains it.
    for (auto* call_site : fn_sem->CallSites()) {
      auto* enclosing_fn = src_sem.Get(call_site)->Stmt()->Function();
      calls_by_caller[enclosing_fn].push_back(call_site);
    }
  }

  // Pass 2: rewrite each affected function.
  for (auto* fn : pending_functions) {
    auto* fn_sem = src_sem.Get(fn);
    const bool in_entry_point = fn->IsEntryPoint();

    // Symbol of the function-scope replacement for each module-scope variable.
    std::unordered_map<const sem::Variable*, Symbol> replacement_symbols;

    for (auto* global : fn_sem->ReferencedModuleVariables()) {
      if (!is_relocated_class(global->StorageClass())) {
        continue;
      }

      // Name of whatever stands in for the module-scope variable inside `fn`.
      auto replacement = ctx.dst->Sym();

      auto* store_type = CreateASTTypeFor(ctx, global->Type()->UnwrapRef());
      const bool is_handle = store_type->is_handle();

      if (in_entry_point && is_handle) {
        // Entry point + texture/sampler: redeclare as an entry point
        // parameter, with entry point parameter validation disabled.
        auto* disable_validation =
            ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
                ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
        auto decos = ctx.Clone(global->Declaration()->decorations());
        decos.push_back(disable_validation);
        auto* param = ctx.dst->Param(replacement, store_type, decos);
        ctx.InsertFront(fn->params(), param);
      } else if (in_entry_point) {
        // Entry point + private/workgroup: redeclare at function scope, with
        // storage class validation disabled for the new variable.
        auto* disable_validation =
            ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
                ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
        auto* constructor = ctx.Clone(global->Declaration()->constructor());
        auto* local_var =
            ctx.dst->Var(replacement, store_type, global->StorageClass(),
                         constructor, ast::DecorationList{disable_validation});
        ctx.InsertFront(fn->body()->statements(), ctx.dst->Decl(local_var));
      } else {
        // Regular function: take the variable as a trailing parameter.
        // Handles are passed by value, everything else through a pointer.
        auto* param_type = store_type;
        if (!is_handle) {
          param_type = ctx.dst->ty.pointer(param_type, global->StorageClass());
        }
        ctx.InsertBack(fn->params(), ctx.dst->Param(replacement, param_type));
      }

      // Redirect every use of the module-scope variable inside `fn`. Pointer
      // parameters (non-handles in non-entry-points) need a dereference.
      for (auto* use : global->Users()) {
        if (use->Stmt()->Function() != fn) {
          continue;
        }
        ast::Expression* replacement_expr = ctx.dst->Expr(replacement);
        if (!(in_entry_point || is_handle)) {
          replacement_expr = ctx.dst->Deref(replacement_expr);
        }
        ctx.Replace(use->Declaration(), replacement_expr);
      }

      replacement_symbols[global] = replacement;
    }

    // Forward the replacements to every affected call made from `fn`.
    for (auto* call : calls_by_caller[fn]) {
      auto* callee = src_ast.Functions().Find(call->func()->symbol());
      auto* callee_sem = src_sem.Get(callee);

      // Append one argument per relocated variable the callee needs. Inside an
      // entry point the replacement is a local variable, so non-handle types
      // are passed by address.
      for (auto* needed : callee_sem->ReferencedModuleVariables()) {
        if (!is_relocated_class(needed->StorageClass())) {
          continue;
        }
        ast::Expression* arg = ctx.dst->Expr(replacement_symbols[needed]);
        if (in_entry_point && !needed->Type()->UnwrapRef()->is_handle()) {
          arg = ctx.dst->AddressOf(arg);
        }
        ctx.InsertBack(call->params(), arg);
      }
    }
  }

  // Finally, drop the original module-scope declarations.
  for (auto* global_ast : src_ast.GlobalVariables()) {
    if (is_relocated_class(src_sem.Get(global_ast)->StorageClass())) {
      ctx.Remove(src_ast.GlobalDeclarations(), global_ast);
    }
  }
}

// Constructs a Config.
// `buffer_size_ubo_idx` is stored in `buffer_size_ubo_index`; Msl::Run() uses
// it as the binding number of the binding point handed to
// ArrayLengthFromUniform.
// `sample_mask` is stored in `fixed_sample_mask`; Msl::Run() forwards it to
// the CanonicalizeEntryPointIO configuration.
Msl::Config::Config(uint32_t buffer_size_ubo_idx, uint32_t sample_mask)
    : buffer_size_ubo_index(buffer_size_ubo_idx),
      fixed_sample_mask(sample_mask) {}
// Compiler-generated copy constructor and destructor, defined out of line.
Msl::Config::Config(const Config&) = default;
Msl::Config::~Config() = default;

// Constructs a Result. `needs_buffer_sizes` is stored in
// `needs_storage_buffer_sizes`; Msl::Run() passes the value reported by the
// ArrayLengthFromUniform transform.
Msl::Result::Result(bool needs_buffer_sizes)
    : needs_storage_buffer_sizes(needs_buffer_sizes) {}
// Compiler-generated copy constructor and destructor, defined out of line.
Msl::Result::Result(const Result&) = default;
Msl::Result::~Result() = default;

}  // namespace transform
}  // namespace tint
