// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/localize_struct_array_assignment.h"

#include <unordered_map>
#include <utility>

#include "src/ast/assignment_statement.h"
#include "src/ast/traverse_expressions.h"
#include "src/program_builder.h"
#include "src/sem/expression.h"
#include "src/sem/member_accessor_expression.h"
#include "src/sem/reference_type.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
#include "src/transform/simplify_pointers.h"
#include "src/utils/scoped_assignment.h"

// Instantiates the Tint type information for the transform class, via the
// TINT_INSTANTIATE_TYPEINFO macro (defined outside of this file).
TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment);

namespace tint {
namespace transform {

/// Private implementation of the LocalizeStructArrayAssignment transform.
///
/// Rewrites assignment statements whose left-hand side indexes, with a
/// non-constant index, an array that is accessed as a structure member, when
/// the originating variable holds a structure in the function or private
/// storage class. The array is copied into a local temporary variable, the
/// temporary is indexed by the assignment, and the temporary is written back
/// to the structure member afterwards.
class LocalizeStructArrayAssignment::State {
 private:
  /// The clone context, primed with the source program and destination builder
  CloneContext& ctx;
  /// The destination program builder (`*ctx.dst`)
  ProgramBuilder& b;

  /// Returns true if `expr` contains an index accessor expression to a
  /// structure member of array type.
  /// Only index accessors whose index has no valid constant value (i.e. the
  /// index is a runtime value) are considered.
  /// @param expr the root expression to search
  /// @returns true if a matching index accessor was found
  bool ContainsStructArrayIndex(const ast::Expression* expr) {
    bool result = false;
    ast::TraverseExpressions(
        expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
          // Indexing using a runtime value?
          auto* idx_sem = ctx.src->Sem().Get(ia->index);
          if (!idx_sem->ConstantValue().IsValid()) {
            // Indexing a member access expr?
            if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
              // That accesses an array?
              if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
                // Found one - no need to look any further.
                result = true;
                return ast::TraverseAction::Stop;
              }
            }
          }
          return ast::TraverseAction::Descend;
        });

    return result;
  }

  /// Returns the type and storage class of the originating variable of the lhs
  /// of the assignment statement.
  /// See https://www.w3.org/TR/WGSL/#originating-variable-section
  /// @param assign_stmt the assignment statement to inspect
  /// @returns the store type and storage class of the pointer or reference
  /// type of the originating variable. If the variable could not be found, or
  /// is neither a pointer nor a reference, an internal compiler error is
  /// raised and a value-initialized pair (with a null type) is returned.
  std::pair<const sem::Type*, ast::StorageClass>
  GetOriginatingTypeAndStorageClass(
      const ast::AssignmentStatement* assign_stmt) {
    // Get first IdentifierExpr from lhs of assignment, which should resolve to
    // the pointer or reference of the originating variable of the assignment.
    // TraverseExpressions traverses left to right, and this code depends on the
    // fact that for an assignment statement, the variable will be the left-most
    // expression.
    // TODO(crbug.com/tint/1341): do this in the Resolver, setting the
    // originating variable on sem::Expression.
    const ast::IdentifierExpression* ident = nullptr;
    ast::TraverseExpressions(assign_stmt->lhs, b.Diagnostics(),
                             [&](const ast::IdentifierExpression* id) {
                               ident = id;
                               return ast::TraverseAction::Stop;
                             });
    auto* sem_var_user = ctx.src->Sem().Get<sem::VariableUser>(ident);
    if (!sem_var_user) {
      TINT_ICE(Transform, b.Diagnostics())
          << "Expected to find variable of lhs of assignment statement";
      return {};
    }

    // Pointer-typed variable: report what the pointer points at.
    auto* var = sem_var_user->Variable();
    if (auto* ptr = var->Type()->As<sem::Pointer>()) {
      return {ptr->StoreType(), ptr->StorageClass()};
    }

    // Otherwise the variable is expected to be of reference type.
    auto* ref = var->Type()->As<sem::Reference>();
    if (!ref) {
      TINT_ICE(Transform, b.Diagnostics())
          << "Expecting to find variable of type pointer or reference on lhs "
             "of assignment statement";
      return {};
    }

    return {ref->StoreType(), ref->StorageClass()};
  }

 public:
  /// Constructor
  /// @param ctx_in the CloneContext primed with the input program and
  /// ProgramBuilder
  explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}

  /// Runs the transform
  /// Registers the assignment statement and index accessor replacement
  /// handlers on the clone context, and then clones the program.
  void Run() {
    // State shared between the two replacement handlers below.
    struct Shared {
      // True only while the lhs of a matching assignment is being cloned. The
      // index accessor handler does nothing when this is false.
      bool process_nested_nodes = false;
      // Statements to emit before the rewritten assignment
      ast::StatementList insert_before_stmts;
      // Statements to emit after the rewritten assignment
      ast::StatementList insert_after_stmts;
    } s;

    ctx.ReplaceAll([&](const ast::AssignmentStatement* assign_stmt)
                       -> const ast::Statement* {
      // Process if it's an assignment statement to a dynamically indexed array
      // within a struct on a function or private storage variable. This
      // specific use-case is what FXC fails to compile with:
      // error X3500: array reference cannot be used as an l-value; not natively
      // addressable
      if (!ContainsStructArrayIndex(assign_stmt->lhs)) {
        return nullptr;
      }
      auto og = GetOriginatingTypeAndStorageClass(assign_stmt);
      if (!og.first) {
        // The originating variable could not be determined. An internal
        // compiler error has already been raised, so leave the statement
        // untouched instead of dereferencing a null type.
        return nullptr;
      }
      if (!(og.first->Is<sem::Struct>() &&
            (og.second == ast::StorageClass::kFunction ||
             og.second == ast::StorageClass::kPrivate))) {
        return nullptr;
      }

      // Reset shared state for this assignment statement
      s = Shared{};

      // Clone the lhs with the index accessor handler enabled. The handler
      // fills insert_before_stmts and insert_after_stmts as a side effect.
      const ast::Expression* new_lhs = nullptr;
      {
        TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
        new_lhs = ctx.Clone(assign_stmt->lhs);
      }

      auto* new_assign_stmt = b.Assign(new_lhs, ctx.Clone(assign_stmt->rhs));

      // Combine insert_before_stmts + new_assign_stmt + insert_after_stmts into
      // a block and return it
      ast::StatementList stmts = std::move(s.insert_before_stmts);
      // reserve() takes the total capacity, so the statements already held
      // have to be included in the count.
      stmts.reserve(stmts.size() + 1 + s.insert_after_stmts.size());
      stmts.emplace_back(new_assign_stmt);
      stmts.insert(stmts.end(), s.insert_after_stmts.begin(),
                   s.insert_after_stmts.end());

      return b.Block(std::move(stmts));
    });

    ctx.ReplaceAll([&](const ast::IndexAccessorExpression* index_access)
                       -> const ast::Expression* {
      // Only active while cloning the lhs of a matching assignment statement.
      if (!s.process_nested_nodes) {
        return nullptr;
      }

      // Indexing a member access expr?
      auto* mem_access =
          index_access->object->As<ast::MemberAccessorExpression>();
      if (!mem_access) {
        return nullptr;
      }

      // Process any nested IndexAccessorExpressions
      mem_access = ctx.Clone(mem_access);

      // Store the address of the member access into a let as we need to read
      // the value twice e.g. let tint_symbol = &(s.a1);
      auto mem_access_ptr = b.Sym();
      s.insert_before_stmts.push_back(
          b.Decl(b.Const(mem_access_ptr, nullptr, b.AddressOf(mem_access))));

      // Disable further transforms when cloning
      TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, false);

      // Copy entire array out of struct into local temp var
      // e.g. var tint_symbol_1 = *(tint_symbol);
      auto tmp_var = b.Sym();
      s.insert_before_stmts.push_back(
          b.Decl(b.Var(tmp_var, nullptr, b.Deref(mem_access_ptr))));

      // Replace input index_access with a clone of itself, but with its
      // .object replaced by the new temp var. This is returned from this
      // function to modify the original assignment statement. e.g.
      // tint_symbol_1[uniforms.i]
      auto* new_index_access =
          b.IndexAccessor(tmp_var, ctx.Clone(index_access->index));

      // Assign temp var back to array
      // e.g. *(tint_symbol) = tint_symbol_1;
      // Inserted at the front so that the write-backs are emitted in the
      // reverse order of the copies made before the assignment.
      auto* write_back_stmt = b.Assign(b.Deref(mem_access_ptr), tmp_var);
      s.insert_after_stmts.insert(s.insert_after_stmts.begin(),
                                  write_back_stmt);  // push_front

      return new_index_access;
    });

    ctx.Clone();
  }
};

// Out-of-line defaulted constructor.
LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;

// Out-of-line defaulted destructor.
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;

// Entry point of the transform. Nothing is done unless the
// Requires<SimplifyPointers>() check on the context passes. The two DataMap
// arguments are not used.
void LocalizeStructArrayAssignment::Run(CloneContext& ctx,
                                        const DataMap&,
                                        DataMap&) {
  if (Requires<SimplifyPointers>(ctx)) {
    // Hand the actual work over to the private implementation.
    State{ctx}.Run();

    // Pointers may have been introduced by the rewrite, so the destination
    // program is no longer flagged as having had SimplifyPointers applied.
    ctx.dst->UnsetTransformApplied<transform::SimplifyPointers>();
  }
}
}  // namespace transform
}  // namespace tint
