// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/transform/pad_array_elements.h"

#include <unordered_map>
#include <utility>

#include "src/tint/program_builder.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/map.h"

// Instantiates Tint's TypeInfo for PadArrayElements; presumably required so
// the transform participates in Tint's Castable-style Is<>/As<> queries.
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);

namespace tint {
namespace transform {
namespace {

/// ArrayBuilder is a callable returning a `const ast::Array*`. The builders
/// produced by PadArray() below create a brand new ast::Array (and element
/// TypeName) in the destination program every time they are invoked, so each
/// use site receives its own, unshared AST nodes.
/// NOTE(review): std::function is declared in <functional>, which this file
/// does not include directly; it is presumably reached transitively.
using ArrayBuilder = std::function<const ast::Array*()>;

/// PadArray returns a function that constructs a new array in `ctx.dst` with
/// the element type padded to account for the explicit stride. PadArray will
/// recursively pad arrays-of-arrays. The padded element structure is declared
/// at module scope in `ctx.dst` the first time a given array type is seen;
/// later requests for the same array type reuse the cached builder.
/// @param ctx the CloneContext
/// @param create_ast_type_for Transform::CreateASTTypeFor()
/// @param padded_arrays a cache mapping a src array type to the ArrayBuilder
/// that produces its padded replacement
/// @param array the array type
/// @return an ArrayBuilder that creates a new AST array of the padded element
/// structure each time it is called, or nullptr if `array` has an implicit
/// stride and therefore needs no padding
template <typename CREATE_AST_TYPE_FOR>
ArrayBuilder PadArray(
    CloneContext& ctx,
    CREATE_AST_TYPE_FOR&& create_ast_type_for,
    std::unordered_map<const sem::Array*, ArrayBuilder>& padded_arrays,
    const sem::Array* array) {
  if (array->IsStrideImplicit()) {
    // We don't want to wrap arrays that have an implicit stride
    return nullptr;
  }

  return utils::GetOrCreate(padded_arrays, array, [&] {
    // Generate a unique name for the array element type
    auto name = ctx.dst->Symbols().New("tint_padded_array_element");

    // Examine the element type. Is it also an array?
    const ast::Type* el_ty = nullptr;
    if (auto* el_array = array->ElemType()->As<sem::Array>()) {
      // Array of array - call PadArray() on the element type. This recurses
      // into the same cache, so the inner element structure is declared
      // before the outer one that refers to it.
      if (auto p =
              PadArray(ctx, create_ast_type_for, padded_arrays, el_array)) {
        el_ty = p();
      }
    }

    // If the element wasn't a padded array, just create the typical AST type
    // for it
    if (el_ty == nullptr) {
      el_ty = create_ast_type_for(ctx, array->ElemType());
    }

    // Structure() will create and append the ast::Struct to the
    // global declarations of `ctx.dst`. As we haven't finished building the
    // current module-scope statement or function, this will be placed
    // immediately before the usage.
    // The single member `el` is sized to the array stride, which provides the
    // padding between consecutive elements.
    ctx.dst->Structure(
        name,
        {ctx.dst->Member("el", el_ty, {ctx.dst->MemberSize(array->Stride())})});

    // The builder captures the destination builder, the element struct name
    // and the src array by value; every call creates fresh AST nodes.
    auto* dst = ctx.dst;
    return [=] {
      if (array->IsRuntimeSized()) {
        return dst->ty.array(dst->create<ast::TypeName>(name));
      } else {
        return dst->ty.array(dst->create<ast::TypeName>(name), array->Count());
      }
    };
  });
}

}  // namespace

/// Constructor. Defaulted; the transform holds no state of its own.
PadArrayElements::PadArrayElements() = default;

/// Destructor. Defaulted out-of-line in this translation unit.
PadArrayElements::~PadArrayElements() = default;

bool PadArrayElements::ShouldRun(const Program* program, const DataMap&) const {
  // Only arrays whose stride was specified explicitly need padding, so the
  // transform has work to do iff at least one such array type appears in the
  // program's AST.
  for (auto* node : program->ASTNodes().Objects()) {
    auto* ast_ty = node->As<ast::Type>();
    if (ast_ty == nullptr) {
      continue;  // Not a type node.
    }
    auto* arr = program->Sem().Get<sem::Array>(ast_ty);
    if (arr == nullptr) {
      continue;  // No sem::Array for this type.
    }
    if (!arr->IsStrideImplicit()) {
      return true;
    }
  }
  return false;
}

void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
  auto& sem = ctx.src->Sem();

  // Cache of src array type -> builder for its padded replacement, so the
  // padded element structure for each array type is only declared once.
  std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;
  // Returns the builder for `array`, or nullptr if `array` has an implicit
  // stride and is therefore left untouched.
  auto pad = [&](const sem::Array* array) {
    return PadArray(ctx, CreateASTTypeFor, padded_arrays, array);
  };

  // Replace all array types with their corresponding padded array type
  ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
    auto* type = ctx.src->TypeOf(ast_type);
    if (type == nullptr) {
      // No semantic type is recorded for this AST type. Leave it to be cloned
      // unchanged rather than dereferencing a null pointer below.
      return nullptr;
    }
    if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
      if (auto p = pad(array)) {
        return p();
      }
    }
    return nullptr;
  });

  // Fix up index accessors so `a[1]` becomes `a[1].el`
  ctx.ReplaceAll([&](const ast::IndexAccessorExpression* accessor)
                     -> const ast::Expression* {
    if (auto* array = tint::As<sem::Array>(
            sem.Get(accessor->object)->Type()->UnwrapRef())) {
      if (pad(array)) {
        // Array element is wrapped in a structure. Emit a member accessor
        // to get to the actual array element. CloneWithoutTransform() avoids
        // re-entering this replacer for the accessor itself.
        auto* idx = ctx.CloneWithoutTransform(accessor);
        return ctx.dst->MemberAccessor(idx, "el");
      }
    }
    return nullptr;
  });

  // Fix up array constructors so `A(1,2)` becomes
  // `A(padded(1), padded(2))`, where `padded` is the element structure.
  ctx.ReplaceAll(
      [&](const ast::CallExpression* expr) -> const ast::Expression* {
        auto* call = sem.Get(expr);
        if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
          if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
            if (auto p = pad(array)) {
              auto* arr_ty = p();
              auto el_typename = arr_ty->type->As<ast::TypeName>()->name;

              ast::ExpressionList args;
              args.reserve(call->Arguments().size());
              for (auto* arg : call->Arguments()) {
                auto* val = ctx.Clone(arg->Declaration());
                // Each argument gets its own TypeName node; AST nodes must
                // not be shared between parents.
                args.emplace_back(ctx.dst->Construct(
                    ctx.dst->create<ast::TypeName>(el_typename), val));
              }

              return ctx.dst->Construct(arr_ty, args);
            }
          }
        }
        return nullptr;
      });

  ctx.Clone();
}

}  // namespace transform
}  // namespace tint
