// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/vertex_pulling.h"

#include <utility>

#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/program_builder.h"
#include "src/semantic/variable.h"

// NOTE(review): presumably emits the type-info definition that allows a
// VertexPulling::Config to be looked up by type (Run() fetches it with
// data.Get<Config>()) - confirm against the macro's definition.
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);

namespace tint {
namespace transform {
namespace {

// Names of the identifiers this transform injects into the output program.
// The unnamed namespace already gives them internal linkage.

// Prefix of every generated storage buffer variable; the buffer index is
// appended to it.
constexpr char kVertexBufferNamePrefix[] = "_tint_pulling_vertex_buffer_";
// Name of the single array member inside the generated struct.
constexpr char kStructBufferName[] = "_tint_vertex_data";
// Name of the generated struct type shared by all storage buffers.
constexpr char kStructName[] = "TintVertexData";
// Name of the function-scope variable holding the current read position.
constexpr char kPullingPosVarName[] = "_tint_pulling_pos";
// Name given to the vertex index builtin when the shader does not declare
// one itself.
constexpr char kDefaultVertexIndexName[] = "_tint_pulling_vertex_index";
// Name given to the instance index builtin when the shader does not declare
// one itself.
constexpr char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";

}  // namespace

// Default constructor: cfg_ is left default-initialized.
VertexPulling::VertexPulling() = default;
// Constructs the transform with a copy of |config| stored in cfg_. Run() uses
// that copy unless the DataMap it is given carries a Config.
VertexPulling::VertexPulling(const Config& config) : cfg_(config) {}

// Destructor.
VertexPulling::~VertexPulling() = default;

// Runs the transform on |in|. A Config found in |data| replaces the one given
// at construction. If no vertex-stage entry point with the configured name
// exists, an error diagnostic is added to the builder and it is returned
// without cloning anything. Otherwise the program is cloned with the
// location-decorated inputs replaced and the pulling preamble inserted at the
// start of the entry point.
Transform::Output VertexPulling::Run(const Program* in, const DataMap& data) {
  ProgramBuilder builder;

  // Start from the constructor-supplied config; one in the DataMap wins.
  auto active_cfg = cfg_;
  auto* override_cfg = data.Get<Config>();
  if (override_cfg != nullptr) {
    active_cfg = *override_cfg;
  }

  // Find entry point
  const auto entry_symbol = in->Symbols().Get(active_cfg.entry_point_name);
  auto* entry_point =
      in->AST().Functions().Find(entry_symbol, ast::PipelineStage::kVertex);
  if (!entry_point) {
    builder.Diagnostics().add_error("Vertex stage entry point not found");
    return Output(Program(std::move(builder)));
  }

  // TODO(idanr): Need to check shader locations in descriptor cover all
  // attributes

  // TODO(idanr): Make sure we covered all error cases, to guarantee the
  // following stages will pass

  CloneContext ctx(&builder, in);
  State state{ctx, active_cfg};
  state.FindOrInsertVertexIndexIfUsed();
  state.FindOrInsertInstanceIndexIfUsed();
  state.ConvertVertexInputVariablesToPrivate();
  state.AddVertexStorageBuffers();

  // Swap every location-decorated input for its replacement variable.
  for (auto& entry : state.location_replacements) {
    ctx.Replace(entry.from, entry.to);
  }

  // Only the entry point receives the preamble.
  ctx.ReplaceAll([&](ast::Function* fn) -> ast::Function* {
    if (fn != entry_point) {
      return nullptr;  // Just clone fn
    }
    return CloneWithStatementsAtStart(&ctx, fn,
                                      {state.CreateVertexPullingPreamble()});
  });
  ctx.Clone();

  return Output(Program(std::move(builder)));
}

// Special members of Config. All of them are compiler-generated; they are
// defined out-of-line in this file.
VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
    default;

// Initializes the per-run state: ctx from |context| and cfg from |c|. All
// other members start out default-initialized and are filled in by the
// helper methods below.
VertexPulling::State::State(CloneContext& context, const Config& c)
    : ctx(context), cfg(c) {}
// Compiler-generated copy constructor and destructor, defined out-of-line.
VertexPulling::State::State(const State&) = default;
VertexPulling::State::~State() = default;

// Returns the name of the storage buffer variable generated for vertex buffer
// |index|: kVertexBufferNamePrefix followed by the decimal index.
std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const {
  std::string buffer_name(kVertexBufferNamePrefix);
  buffer_name += std::to_string(index);
  return buffer_name;
}

void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
  bool uses_vertex_step_mode = false;
  for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
    if (buffer_layout.step_mode == InputStepMode::kVertex) {
      uses_vertex_step_mode = true;
      break;
    }
  }
  if (!uses_vertex_step_mode) {
    return;
  }

  // Look for an existing vertex index builtin
  for (auto* v : ctx.src->AST().GlobalVariables()) {
    auto* sem = ctx.src->Sem().Get(v);
    if (sem->StorageClass() != ast::StorageClass::kInput) {
      continue;
    }

    for (auto* d : v->decorations()) {
      if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
        if (builtin->value() == ast::Builtin::kVertexIndex) {
          vertex_index_name = ctx.src->Symbols().NameFor(v->symbol());
          return;
        }
      }
    }
  }

  // We didn't find a vertex index builtin, so create one
  vertex_index_name = kDefaultVertexIndexName;

  auto* var = ctx.dst->create<ast::Variable>(
      Source{},                                        // source
      ctx.dst->Symbols().Register(vertex_index_name),  // symbol
      ast::StorageClass::kInput,                       // storage_class
      GetU32Type(),                                    // type
      false,                                           // is_const
      nullptr,                                         // constructor
      ast::DecorationList{
          ctx.dst->create<ast::BuiltinDecoration>(Source{},
                                                  ast::Builtin::kVertexIndex),
      });

  ctx.dst->AST().AddGlobalVariable(var);
}

void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
  bool uses_instance_step_mode = false;
  for (const VertexBufferLayoutDescriptor& buffer_layout : cfg.vertex_state) {
    if (buffer_layout.step_mode == InputStepMode::kInstance) {
      uses_instance_step_mode = true;
      break;
    }
  }
  if (!uses_instance_step_mode) {
    return;
  }

  // Look for an existing instance index builtin
  for (auto* v : ctx.src->AST().GlobalVariables()) {
    auto* sem = ctx.src->Sem().Get(v);
    if (sem->StorageClass() != ast::StorageClass::kInput) {
      continue;
    }

    for (auto* d : v->decorations()) {
      if (auto* builtin = d->As<ast::BuiltinDecoration>()) {
        if (builtin->value() == ast::Builtin::kInstanceIndex) {
          instance_index_name = ctx.src->Symbols().NameFor(v->symbol());
          return;
        }
      }
    }
  }

  // We didn't find an instance index builtin, so create one
  instance_index_name = kDefaultInstanceIndexName;

  auto* var = ctx.dst->create<ast::Variable>(
      Source{},                                          // source
      ctx.dst->Symbols().Register(instance_index_name),  // symbol
      ast::StorageClass::kInput,                         // storage_class
      GetU32Type(),                                      // type
      false,                                             // is_const
      nullptr,                                           // constructor
      ast::DecorationList{
          ctx.dst->create<ast::BuiltinDecoration>(Source{},
                                                  ast::Builtin::kInstanceIndex),
      });
  ctx.dst->AST().AddGlobalVariable(var);
}

// For every input-storage-class global that carries a location decoration,
// creates an undecorated private variable of the same name and declared type.
// The new variable is recorded in location_to_var (keyed by location) and
// queued in location_replacements so that Run() can swap it in.
void VertexPulling::State::ConvertVertexInputVariablesToPrivate() {
  for (auto* global : ctx.src->AST().GlobalVariables()) {
    if (ctx.src->Sem().Get(global)->StorageClass() !=
        ast::StorageClass::kInput) {
      continue;
    }

    for (auto* deco : global->decorations()) {
      auto* location_deco = deco->As<ast::LocationDecoration>();
      if (location_deco == nullptr) {
        continue;
      }

      const uint32_t location = location_deco->value();
      // This is where the replacement is created. Expressions use identifier
      // strings instead of pointers, so reusing the original name means no
      // other place in the AST has to be updated.
      const auto original_name = ctx.src->Symbols().NameFor(global->symbol());
      auto* private_var = ctx.dst->create<ast::Variable>(
          Source{},                                    // source
          ctx.dst->Symbols().Register(original_name),  // symbol
          ast::StorageClass::kPrivate,                 // storage_class
          ctx.Clone(global->declared_type()),          // type
          false,                                       // is_const
          nullptr,                                     // constructor
          ast::DecorationList{});                      // decorations
      location_to_var[location] = private_var;
      location_replacements.emplace_back(
          LocationReplacement{global, private_var});
      // Only the first location decoration of a variable is considered.
      break;
    }
  }
}

void VertexPulling::State::AddVertexStorageBuffers() {
  // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
  // The array inside the struct definition
  auto* internal_array_type = ctx.dst->create<type::Array>(
      GetU32Type(), 0,
      ast::DecorationList{
          ctx.dst->create<ast::StrideDecoration>(Source{}, 4u),
      });

  // Creating the struct type
  ast::StructMemberList members;
  members.push_back(ctx.dst->create<ast::StructMember>(
      Source{}, ctx.dst->Symbols().Register(kStructBufferName),
      internal_array_type, ast::DecorationList{}));

  ast::DecorationList decos;
  decos.push_back(ctx.dst->create<ast::StructBlockDecoration>(Source{}));

  auto* struct_type = ctx.dst->create<type::Struct>(
      ctx.dst->Symbols().Register(kStructName),
      ctx.dst->create<ast::Struct>(Source{}, std::move(members),
                                   std::move(decos)));

  for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
    // The decorated variable with struct type
    std::string name = GetVertexBufferName(i);
    auto* var = ctx.dst->create<ast::Variable>(
        Source{},                           // source
        ctx.dst->Symbols().Register(name),  // symbol
        ast::StorageClass::kStorage,        // storage_class
        struct_type,                        // type
        false,                              // is_const
        nullptr,                            // constructor
        ast::DecorationList{
            ctx.dst->create<ast::BindingDecoration>(Source{}, i),
            ctx.dst->create<ast::GroupDecoration>(Source{}, cfg.pulling_group),
        });
    ctx.dst->AST().AddGlobalVariable(var);
  }
  ctx.dst->AST().AddConstructedType(struct_type);
}

// Builds the block that is inserted at the start of the vertex entry point.
//
// The block first declares the function-scope u32 variable
// |kPullingPosVarName|. Then, for each attribute in the vertex state whose
// shader location has a replacement variable in location_to_var, it emits:
//   pos = index * array_stride + attribute_offset;
//   replacement_var = <value read from the buffer according to the format>;
// where |index| is the vertex index or the instance index builtin, depending
// on the buffer's step mode. Attributes with no matching variable are skipped.
//
// Attributes whose format AccessByFormat() cannot read (it returns nullptr
// for them) produce an error diagnostic on the output program and no
// statements, rather than an assignment with a null right-hand side.
//
// @returns the preamble block, owned by the destination program builder
ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
  ast::StatementList stmts;

  // |kPullingPosVarName| refers to the byte location of the current read. We
  // declare a variable in the shader to avoid having to reuse Expression
  // objects.
  auto* pos_declaration = ctx.dst->create<ast::VariableDeclStatement>(
      Source{}, ctx.dst->create<ast::Variable>(
                    Source{},                                         // source
                    ctx.dst->Symbols().Register(kPullingPosVarName),  // symbol
                    ast::StorageClass::kFunction,  // storage_class
                    GetU32Type(),                  // type
                    false,                         // is_const
                    nullptr,                       // constructor
                    ast::DecorationList{}));       // decorations
  stmts.emplace_back(pos_declaration);

  for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
    const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[i];

    // The index builtin depends only on the buffer's step mode, so pick it
    // once per buffer and bind by reference instead of copying the string for
    // every attribute.
    const auto& index_name =
        buffer_layout.step_mode == InputStepMode::kVertex
            ? vertex_index_name
            : instance_index_name;

    for (const VertexAttributeDescriptor& attribute_desc :
         buffer_layout.attributes) {
      auto it = location_to_var.find(attribute_desc.shader_location);
      if (it == location_to_var.end()) {
        continue;
      }
      auto* v = it->second;

      // Build the read first: if the format is unsupported there is nothing
      // valid to assign, so report it and emit no statements for this
      // attribute.
      auto* value = AccessByFormat(i, attribute_desc.format);
      if (value == nullptr) {
        ctx.dst->Diagnostics().add_error(
            "VertexPulling: unsupported vertex format for attribute at "
            "shader location " +
            std::to_string(attribute_desc.shader_location));
        continue;
      }

      // Identifier to index by
      auto* index_identifier = ctx.dst->create<ast::IdentifierExpression>(
          Source{}, ctx.dst->Symbols().Register(index_name));

      // An expression for the start of the read in the buffer in bytes
      auto* pos_value = ctx.dst->create<ast::BinaryExpression>(
          Source{}, ast::BinaryOp::kAdd,
          ctx.dst->create<ast::BinaryExpression>(
              Source{}, ast::BinaryOp::kMultiply, index_identifier,
              GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
          GenUint(static_cast<uint32_t>(attribute_desc.offset)));

      // Update position of the read
      stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
          Source{}, CreatePullingPositionIdent(), pos_value));

      // Store the pulled value into the attribute's replacement variable
      stmts.emplace_back(ctx.dst->create<ast::AssignmentStatement>(
          Source{},
          ctx.dst->create<ast::IdentifierExpression>(Source{}, v->symbol()),
          value));
    }
  }

  return ctx.dst->create<ast::BlockStatement>(Source{}, stmts);
}

// Returns a scalar constructor expression wrapping a u32 literal equal to
// |value|.
ast::Expression* VertexPulling::State::GenUint(uint32_t value) const {
  auto* literal =
      ctx.dst->create<ast::UintLiteral>(Source{}, GetU32Type(), value);
  return ctx.dst->create<ast::ScalarConstructorExpression>(Source{}, literal);
}

// Returns a new identifier expression that names the |kPullingPosVarName|
// variable. A fresh node is created on every call.
ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const {
  auto pos_symbol = ctx.dst->Symbols().Register(kPullingPosVarName);
  return ctx.dst->create<ast::IdentifierExpression>(Source{}, pos_symbol);
}

// Returns an expression that reads a value of the given |format| from vertex
// buffer |buffer| at the position held in the pulling position variable.
// Scalar formats read one element; the f32 vector formats build a vector from
// consecutive f32 reads 4 bytes apart. Any other format yields nullptr.
ast::Expression* VertexPulling::State::AccessByFormat(
    uint32_t buffer,
    VertexFormat format) const {
  // TODO(idanr): this doesn't account for the format of the attribute in the
  // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
  // right now, we would try to assign a vec4<f32> to this attribute, but we
  // really need to assign a vec4<u32> by casting.
  // We could split this function to first do memory accesses and unpacking into
  // int/uint/float1-4/etc, then convert that variable to a var<in> with the
  // conversion defined in the WebGPU spec.
  uint32_t vec_width = 0;
  switch (format) {
    case VertexFormat::kU32:
      return AccessU32(buffer, CreatePullingPositionIdent());
    case VertexFormat::kI32:
      return AccessI32(buffer, CreatePullingPositionIdent());
    case VertexFormat::kF32:
      return AccessF32(buffer, CreatePullingPositionIdent());
    case VertexFormat::kVec2F32:
      vec_width = 2;
      break;
    case VertexFormat::kVec3F32:
      vec_width = 3;
      break;
    case VertexFormat::kVec4F32:
      vec_width = 4;
      break;
    default:
      return nullptr;
  }

  // All supported vector formats are made of f32 components.
  return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, vec_width);
}

// Returns the expression `<buffer var>.<kStructBufferName>[pos / 4]`, which
// reads one u32 from vertex buffer |buffer|. |pos| is a byte offset; it is
// divided by 4 because the buffer is an array of u32, not of bytes. Every
// other read helper ends up here as its base case and reinterprets or
// combines the result as needed.
ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer,
                                                 ast::Expression* pos) const {
  // The storage buffer variable for this vertex buffer.
  auto* buffer_ident = ctx.dst->create<ast::IdentifierExpression>(
      Source{}, ctx.dst->Symbols().Register(GetVertexBufferName(buffer)));

  // The array member inside the buffer's struct.
  auto* member_ident = ctx.dst->create<ast::IdentifierExpression>(
      Source{}, ctx.dst->Symbols().Register(kStructBufferName));
  auto* word_array = ctx.dst->create<ast::MemberAccessorExpression>(
      Source{}, buffer_ident, member_ident);

  // Convert the byte offset into a u32 element index.
  auto* word_index = ctx.dst->create<ast::BinaryExpression>(
      Source{}, ast::BinaryOp::kDivide, pos, GenUint(4));

  return ctx.dst->create<ast::ArrayAccessorExpression>(Source{}, word_array,
                                                       word_index);
}

// Returns an expression that reads the u32 at byte offset |pos| of vertex
// buffer |buffer| and bitcasts it to i32 (the bits are reinterpreted, not
// converted).
ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer,
                                                 ast::Expression* pos) const {
  auto* raw_word = AccessU32(buffer, pos);
  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetI32Type(),
                                                 raw_word);
}

// Returns an expression that reads the u32 at byte offset |pos| of vertex
// buffer |buffer| and bitcasts it to f32 (the bits are reinterpreted, not
// converted).
ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer,
                                                 ast::Expression* pos) const {
  auto* raw_word = AccessU32(buffer, pos);
  return ctx.dst->create<ast::BitcastExpression>(Source{}, GetF32Type(),
                                                 raw_word);
}

// Reads a single scalar of the given |format| from vertex buffer |buffer| at
// the byte offset given by the expression |pos|. Unlike AccessByFormat(),
// which always reads at the position variable itself, this takes an explicit
// position expression, so callers can read at offsets relative to
// |kPullingPosVarName|. Returns nullptr for any non-scalar format.
ast::Expression* VertexPulling::State::AccessPrimitive(
    uint32_t buffer,
    ast::Expression* pos,
    VertexFormat format) const {
  if (format == VertexFormat::kU32) {
    return AccessU32(buffer, pos);
  }
  if (format == VertexFormat::kI32) {
    return AccessI32(buffer, pos);
  }
  if (format == VertexFormat::kF32) {
    return AccessF32(buffer, pos);
  }
  return nullptr;
}

// Returns a type constructor expression for a vector of |count| elements of
// |base_type|. Component n is read with AccessPrimitive() using |base_format|
// at the pulling position plus n * |element_stride| bytes.
ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer,
                                                 uint32_t element_stride,
                                                 type::Type* base_type,
                                                 VertexFormat base_format,
                                                 uint32_t count) const {
  ast::ExpressionList components;
  // Byte offset of the current component from the pulling position.
  uint32_t byte_offset = 0;
  for (uint32_t n = 0; n < count; ++n, byte_offset += element_stride) {
    auto* base_pos = CreatePullingPositionIdent();
    auto* component_pos = ctx.dst->create<ast::BinaryExpression>(
        Source{}, ast::BinaryOp::kAdd, base_pos, GenUint(byte_offset));
    components.emplace_back(
        AccessPrimitive(buffer, component_pos, base_format));
  }

  auto* vec_type = ctx.dst->create<type::Vector>(base_type, count);
  return ctx.dst->create<ast::TypeConstructorExpression>(
      Source{}, vec_type, std::move(components));
}

// Returns the u32 type, created through the destination program builder.
type::Type* VertexPulling::State::GetU32Type() const {
  type::Type* u32_type = ctx.dst->create<type::U32>();
  return u32_type;
}

// Returns the i32 type, created through the destination program builder.
type::Type* VertexPulling::State::GetI32Type() const {
  type::Type* i32_type = ctx.dst->create<type::I32>();
  return i32_type;
}

// Returns the f32 type, created through the destination program builder.
type::Type* VertexPulling::State::GetF32Type() const {
  type::Type* f32_type = ctx.dst->create<type::F32>();
  return f32_type;
}

// Default constructor (compiler-generated).
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;

// Constructs a layout from its parts.
// |in_array_stride| is stored in array_stride; the pulling preamble multiplies
// the vertex or instance index by it to find the byte position of an entry.
// |in_step_mode| is stored in step_mode and selects which of the two indices
// is used. |in_attributes| is taken by value and moved into attributes.
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    uint64_t in_array_stride,
    InputStepMode in_step_mode,
    std::vector<VertexAttributeDescriptor> in_attributes)
    : array_stride(in_array_stride),
      step_mode(in_step_mode),
      attributes(std::move(in_attributes)) {}

// Copy constructor (compiler-generated).
VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
    const VertexBufferLayoutDescriptor& other) = default;

// Copy assignment operator (compiler-generated).
VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
    const VertexBufferLayoutDescriptor& other) = default;

// Destructor (compiler-generated).
VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;

}  // namespace transform
}  // namespace tint
