// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/writer/append_vector.h"

#include <utility>

#include "src/sem/expression.h"

namespace tint {
namespace writer {

namespace {

/// Checks whether `expr` is a type-constructor expression that builds a
/// vector.
/// @param b the program builder whose semantic info is used to look up the
/// resolved type of the constructor
/// @param expr the expression to inspect
/// @returns `expr` downcast to ast::TypeConstructorExpression when it is a
/// constructor whose resolved type is a sem::Vector, otherwise nullptr
ast::TypeConstructorExpression* AsVectorConstructor(ProgramBuilder* b,
                                                    ast::Expression* expr) {
  auto* ctor = expr->As<ast::TypeConstructorExpression>();
  // Guard clause: bail out on anything that is not a constructor, or on a
  // constructor of a non-vector type (e.g. a scalar conversion).
  if (ctor == nullptr || !b->TypeOf(ctor)->Is<sem::Vector>()) {
    return nullptr;
  }
  return ctor;
}

}  // namespace

namespace {

/// Creates a new AST type node for a scalar vector-element type.
/// @param b the program builder that will own the new node
/// @param ty the scalar element type; expected to be i32, u32, f32 or bool
/// @returns the new AST type node, or nullptr after raising an internal
/// compiler error if `ty` is not one of the supported scalar types
ast::Type* CreateASTTypeFor(ProgramBuilder* b, const sem::Type* ty) {
  if (ty->Is<sem::I32>()) {
    return b->create<ast::I32>();
  }
  if (ty->Is<sem::U32>()) {
    return b->create<ast::U32>();
  }
  if (ty->Is<sem::F32>()) {
    return b->create<ast::F32>();
  }
  if (ty->Is<sem::Bool>()) {
    return b->create<ast::Bool>();
  }
  TINT_UNREACHABLE(Writer, b->Diagnostics())
      << "unsupported vector element type: " << ty->TypeInfo().name;
  return nullptr;
}

/// Creates a zero-valued scalar literal of a scalar vector-element type.
/// @param b the program builder that will own the new node
/// @param ty the scalar element type; expected to be i32, u32, f32 or bool
/// @returns the new literal (`0`, `0u`, `0.0f` or `false` respectively), or
/// nullptr after raising an internal compiler error for an unsupported type
ast::ScalarConstructorExpression* CreateZeroFor(ProgramBuilder* b,
                                                const sem::Type* ty) {
  if (ty->Is<sem::I32>()) {
    return b->Expr(0);
  }
  if (ty->Is<sem::U32>()) {
    return b->Expr(0u);
  }
  if (ty->Is<sem::F32>()) {
    return b->Expr(0.0f);
  }
  if (ty->Is<sem::Bool>()) {
    return b->Expr(false);
  }
  TINT_UNREACHABLE(Writer, b->Diagnostics())
      << "unsupported vector element type: " << ty->TypeInfo().name;
  return nullptr;
}

}  // namespace

/// Builds a vector constructor whose components are those of `vector`
/// followed by `scalar`.
///
/// If `vector` is a scalar, it is treated as a one-element vector and the
/// result is a vec2 of that scalar's type. If `scalar`'s type differs from
/// the element type, it is wrapped in a conversion to the element type.
/// Semantic info is registered for every expression synthesized here.
/// @param b the program builder that owns the program being written
/// @param vector the vector (or scalar) expression to extend
/// @param scalar the scalar expression to append
/// @returns the new vector constructor expression
ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
                                             ast::Expression* vector,
                                             ast::Expression* scalar) {
  auto* vector_sem = b->Sem().Get(vector);
  auto* vector_ty = vector_sem->Type()->UnwrapRef();
  auto* statement = vector_sem->Stmt();

  // A scalar `vector` argument behaves like a one-element vector.
  uint32_t packed_size = 2;
  const sem::Type* packed_el_sem_ty = vector_ty;
  if (auto* vec = vector_ty->As<sem::Vector>()) {
    packed_size = vec->Width() + 1;
    packed_el_sem_ty = vec->type();
  }

  auto* packed_ty = b->create<ast::Vector>(
      CreateASTTypeFor(b, packed_el_sem_ty), packed_size);
  auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size);

  // Registers semantic info for an expression synthesized by this function.
  // Each one is attributed to the statement containing `vector` and carries
  // no constant value.
  auto register_sem = [&](ast::Expression* expr, const sem::Type* ty) {
    b->Sem().Add(expr, b->create<sem::Expression>(expr, ty, statement,
                                                  sem::Constant{}));
  };

  // If the coordinates are already passed in a vector constructor, with only
  // scalar components supplied, extract the elements into the new vector
  // instead of nesting a vector-in-vector.
  // If the coordinates are a zero-constructor of the vector, then expand that
  // to scalar zeros.
  // The other cases for a nested vector constructor are when it is used
  // to convert a vector of a different type, e.g. vec2<i32>(vec2<u32>()).
  // In that case, preserve the original argument, or you'll get a type error.
  ast::ExpressionList packed;
  if (auto* vc = AsVectorConstructor(b, vector)) {
    const auto num_supplied = vc->values().size();
    if (num_supplied == 0) {
      // Zero-value vector constructor: spell out one zero per component.
      for (uint32_t i = 0; i < packed_size - 1; i++) {
        auto* zero = CreateZeroFor(b, packed_el_sem_ty);
        register_sem(zero, packed_el_sem_ty);
        packed.emplace_back(zero);
      }
    } else if (num_supplied + 1 == packed_size) {
      // All vector components were supplied as scalars. Pass them through.
      packed = vc->values();
    }
  }
  if (packed.empty()) {
    // The special cases didn't occur. Use the vector argument as-is.
    packed.emplace_back(vector);
  }

  if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapRef()) {
    // Cast scalar to the vector element type. The cast gets its own AST type
    // node rather than reusing the element type owned by `packed_ty`, so no
    // AST node ends up with two parents.
    auto* scalar_cast =
        b->Construct(CreateASTTypeFor(b, packed_el_sem_ty), scalar);
    register_sem(scalar_cast, packed_el_sem_ty);
    packed.emplace_back(scalar_cast);
  } else {
    packed.emplace_back(scalar);
  }

  auto* constructor = b->Construct(packed_ty, std::move(packed));
  register_sem(constructor, packed_sem_ty);

  return constructor;
}

}  // namespace writer
}  // namespace tint
