// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/transform/binding_remapper.h"

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "src/tint/ast/disable_validation_attribute.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/variable.h"

// Instantiate the Tint TypeInfo for the transform and for its Remappings input
// data type. Run() uses TypeInfo().name in its "missing transform data"
// diagnostic.
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper);
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings);

namespace tint {
namespace transform {

BindingRemapper::Remappings::Remappings(BindingPoints bp,
                                        AccessControls ac,
                                        bool may_collide)
    : binding_points(std::move(bp)),
      access_controls(std::move(ac)),
      allow_collisions(may_collide) {}

// Defaulted: member-wise copy and destruction of the remapping tables and the
// collision flag.
BindingRemapper::Remappings::Remappings(const Remappings&) = default;
BindingRemapper::Remappings::~Remappings() = default;

// Defaulted constructor and destructor. All configuration is supplied at run
// time through the Remappings stored in the DataMap passed to Run().
BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default;

/// Reports whether the transform has any work to do.
/// @param inputs the transform inputs
/// @returns true only when `inputs` holds a Remappings whose binding-point
///          table or access-control table is non-empty
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
  const auto* remappings = inputs.Get<Remappings>();
  if (remappings == nullptr) {
    return false;
  }
  const bool has_binding_remaps = !remappings->binding_points.empty();
  const bool has_access_remaps = !remappings->access_controls.empty();
  return has_binding_remaps || has_access_remaps;
}

/// Applies the binding-point and access-mode remappings held in the
/// Remappings input to every module-scope variable that has a binding point.
///
/// 1. If `allow_collisions` is set, every entry point is scanned. Each
///    post-remap binding point referenced by more than one global of the same
///    entry point is recorded.
/// 2. For each global with a binding point:
///    - its group/binding attributes are replaced (if remapped);
///    - a kBindingPointCollision DisableValidationAttribute is inserted (if
///      step 1 recorded its post-remap binding point);
///    - the variable is rebuilt with the new access mode (if remapped).
///
/// On a missing Remappings, an access mode beyond ast::Access::kLastValid, or
/// an access mode applied to a non-storage variable, an error is added to the
/// destination diagnostics and the function returns without cloning.
/// @param ctx the clone context
/// @param inputs the transform inputs; expected to contain a Remappings
void BindingRemapper::Run(CloneContext& ctx,
                          const DataMap& inputs,
                          DataMap&) const {
  auto* remappings = inputs.Get<Remappings>();
  if (!remappings) {
    ctx.dst->Diagnostics().add_error(
        diag::System::Transform,
        "missing transform data for " + std::string(TypeInfo().name));
    return;
  }

  // A set of post-remapped binding points that need to be decorated with a
  // DisableValidationAttribute to disable binding-point-collision validation
  std::unordered_set<sem::BindingPoint> add_collision_attr;

  if (remappings->allow_collisions) {
    // Scan for binding point collisions generated by this transform.
    // Populate all collisions in the `add_collision_attr` set.
    for (auto* func_ast : ctx.src->AST().Functions()) {
      if (!func_ast->IsEntryPoint()) {
        continue;
      }
      auto* func = ctx.src->Sem().Get(func_ast);
      // How many globals used by this entry point map to each post-remap
      // binding point.
      std::unordered_map<sem::BindingPoint, int> binding_point_counts;
      for (auto* var : func->TransitivelyReferencedGlobals()) {
        if (auto binding_point = var->Declaration()->BindingPoint()) {
          BindingPoint from{binding_point.group->value,
                            binding_point.binding->value};
          // The binding point after remapping (unchanged if not remapped).
          auto bp_it = remappings->binding_points.find(from);
          BindingPoint to = (bp_it != remappings->binding_points.end())
                                ? bp_it->second
                                : from;
          // A non-zero previous count means another global already uses `to`.
          if (binding_point_counts[to]++) {
            add_collision_attr.emplace(to);
          }
        }
      }
    }
  }

  for (auto* var : ctx.src->AST().GlobalVariables()) {
    if (auto binding_point = var->BindingPoint()) {
      // The original binding point
      BindingPoint from{binding_point.group->value,
                        binding_point.binding->value};

      // The binding point after remapping
      BindingPoint bp = from;

      // Replace any group or binding attributes.
      // Note: This has to be performed *before* remapping access controls, as
      // `ctx.Clone(var->attributes)` depend on these replacements.
      auto bp_it = remappings->binding_points.find(from);
      if (bp_it != remappings->binding_points.end()) {
        BindingPoint to = bp_it->second;
        auto* new_group = ctx.dst->create<ast::GroupAttribute>(to.group);
        auto* new_binding = ctx.dst->create<ast::BindingAttribute>(to.binding);

        ctx.Replace(binding_point.group, new_group);
        ctx.Replace(binding_point.binding, new_binding);
        bp = to;
      }

      // Add `DisableValidationAttribute`s if required.
      // Note: Like the group/binding replacements above, this must be
      // registered *before* remapping access controls. That path calls
      // `ctx.Clone(var->attributes)` immediately and replaces `var`, so an
      // insertion registered afterwards would never reach the new variable.
      if (add_collision_attr.count(bp)) {
        auto* attribute =
            ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
        ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
      }

      // Replace any access controls.
      auto ac_it = remappings->access_controls.find(from);
      if (ac_it != remappings->access_controls.end()) {
        ast::Access ac = ac_it->second;
        if (ac > ast::Access::kLastValid) {
          ctx.dst->Diagnostics().add_error(
              diag::System::Transform,
              "invalid access mode (" +
                  std::to_string(static_cast<uint32_t>(ac)) + ")");
          return;
        }
        auto* sem = ctx.src->Sem().Get(var);
        if (sem->StorageClass() != ast::StorageClass::kStorage) {
          ctx.dst->Diagnostics().add_error(
              diag::System::Transform,
              "cannot apply access control to variable with storage class " +
                  std::string(ast::ToString(sem->StorageClass())));
          return;
        }
        auto* ty = sem->Type()->UnwrapRef();
        const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
        // The attributes are cloned here, which picks up the group/binding
        // replacements and any collision attribute registered above.
        auto* new_var = ctx.dst->create<ast::Variable>(
            ctx.Clone(var->source), ctx.Clone(var->symbol),
            var->declared_storage_class, ac, inner_ty, false, false,
            ctx.Clone(var->constructor), ctx.Clone(var->attributes));
        ctx.Replace(var, new_var);
      }
    }
  }

  ctx.Clone();
}

}  // namespace transform
}  // namespace tint
