Add vertex pulling transform

Adds a first-pass version of vertex pulling. This is missing several important things such as buffer offsets, support for more types, and clamping.

Bug: dawn:480, tint:206
Change-Id: Ia8a3abc446bca4c5a40e064f85fb59de1c3f5af9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26260
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 9d589e7..28bf839 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -314,6 +314,8 @@
     "src/ast/struct_member_offset_decoration.h",
     "src/ast/switch_statement.cc",
     "src/ast/switch_statement.h",
+    "src/ast/transform/vertex_pulling_transform.cc",
+    "src/ast/transform/vertex_pulling_transform.h",
     "src/ast/type/alias_type.cc",
     "src/ast/type/alias_type.h",
     "src/ast/type/array_type.cc",
@@ -714,6 +716,7 @@
     "src/ast/struct_member_test.cc",
     "src/ast/struct_test.cc",
     "src/ast/switch_statement_test.cc",
+    "src/ast/transform/vertex_pulling_transform_test.cc",
     "src/ast/type/alias_type_test.cc",
     "src/ast/type/array_type_test.cc",
     "src/ast/type/bool_type_test.cc",
diff --git a/include/tint/tint.h b/include/tint/tint.h
index 8873612..d7e405a 100644
--- a/include/tint/tint.h
+++ b/include/tint/tint.h
@@ -25,6 +25,8 @@
 #include "src/validator.h"
 #include "src/writer/writer.h"
 
+#include "src/ast/transform/vertex_pulling_transform.h"
+
 #if TINT_BUILD_SPV_READER
 #include "src/reader/spirv/parser.h"
 #endif  // TINT_BUILD_SPV_READER
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c3b2e29..9c2c326 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -136,6 +136,8 @@
   ast/struct_member_offset_decoration.h
   ast/switch_statement.cc
   ast/switch_statement.h
+  ast/transform/vertex_pulling_transform.cc
+  ast/transform/vertex_pulling_transform.h
   ast/type_constructor_expression.h
   ast/type_constructor_expression.cc
   ast/type/alias_type.cc
@@ -325,6 +327,7 @@
   ast/struct_member_offset_decoration_test.cc
   ast/struct_test.cc
   ast/switch_statement_test.cc
+  ast/transform/vertex_pulling_transform_test.cc
   ast/type/alias_type_test.cc
   ast/type/array_type_test.cc
   ast/type/bool_type_test.cc
diff --git a/src/ast/transform/vertex_pulling_transform.cc b/src/ast/transform/vertex_pulling_transform.cc
new file mode 100644
index 0000000..fc571d5
--- /dev/null
+++ b/src/ast/transform/vertex_pulling_transform.cc
@@ -0,0 +1,403 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/transform/vertex_pulling_transform.h"
+
+#include "src/ast/array_accessor_expression.h"
+#include "src/ast/as_expression.h"
+#include "src/ast/assignment_statement.h"
+#include "src/ast/binary_expression.h"
+#include "src/ast/decorated_variable.h"
+#include "src/ast/expression.h"
+#include "src/ast/member_accessor_expression.h"
+#include "src/ast/module.h"
+#include "src/ast/scalar_constructor_expression.h"
+#include "src/ast/struct.h"
+#include "src/ast/struct_decoration.h"
+#include "src/ast/struct_member.h"
+#include "src/ast/struct_member_offset_decoration.h"
+#include "src/ast/type/array_type.h"
+#include "src/ast/type/f32_type.h"
+#include "src/ast/type/i32_type.h"
+#include "src/ast/type/struct_type.h"
+#include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
+#include "src/ast/type_constructor_expression.h"
+#include "src/ast/uint_literal.h"
+#include "src/ast/variable_decl_statement.h"
+#include "src/context.h"
+
+namespace tint {
+namespace ast {
+namespace transform {
+
+namespace {
+// TODO(idanr): What to do if these names are already used?
+static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
+static const char kStructBufferName[] = "data";
+static const char kPullingPosVarName[] = "tint_pulling_pos";
+static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index";
+}  // namespace
+
+VertexPullingTransform::VertexPullingTransform(Context* ctx, Module* mod)
+    : ctx_(ctx), mod_(mod) {}
+
+VertexPullingTransform::~VertexPullingTransform() = default;
+
+void VertexPullingTransform::SetVertexState(
+    std::unique_ptr<VertexStateDescriptor> vertex_state) {
+  vertex_state_ = std::move(vertex_state);
+}
+
+void VertexPullingTransform::SetEntryPoint(std::string entry_point) {
+  entry_point_name_ = std::move(entry_point);
+}
+
+bool VertexPullingTransform::Run() {
+  // Check SetVertexState was called
+  if (vertex_state_ == nullptr) {
+    SetError("SetVertexState not called");
+    return false;
+  }
+
+  // Find entry point
+  EntryPoint* entry_point = nullptr;
+  for (const auto& entry : mod_->entry_points()) {
+    if (entry->name() == entry_point_name_) {
+      entry_point = entry.get();
+      break;
+    }
+  }
+
+  if (entry_point == nullptr) {
+    SetError("Vertex stage entry point not found");
+    return false;
+  }
+
+  // Check entry point is the right stage
+  if (entry_point->stage() != PipelineStage::kVertex) {
+    SetError("Entry point is not for vertex stage");
+    return false;
+  }
+
+  // Save the vertex function
+  auto* vertex_func = mod_->FindFunctionByName(entry_point->function_name());
+
+  // TODO(idanr): Need to check shader locations in descriptor cover all
+  // attributes
+
+  // TODO(idanr): Make sure we covered all error cases, to guarantee the
+  // following stages will pass
+
+  FindOrInsertVertexIndex();
+  ConvertVertexInputVariablesToPrivate();
+  AddVertexStorageBuffers();
+  AddVertexPullingPreamble(vertex_func);
+
+  return true;
+}
+
+void VertexPullingTransform::SetError(const std::string& error) {
+  error_ = error;
+}
+
+std::string VertexPullingTransform::GetVertexBufferName(uint32_t index) {
+  return kVertexBufferNamePrefix + std::to_string(index);
+}
+
+void VertexPullingTransform::FindOrInsertVertexIndex() {
+  // Look for an existing vertex index builtin
+  for (auto& v : mod_->global_variables()) {
+    if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) {
+      continue;
+    }
+
+    for (auto& d : v->AsDecorated()->decorations()) {
+      if (d->IsBuiltin() && d->AsBuiltin()->value() == Builtin::kVertexIdx) {
+        vertex_index_name_ = v->name();
+        return;
+      }
+    }
+  }
+
+  // We didn't find a vertex index builtin, so create one
+  vertex_index_name_ = kDefaultVertexIndexName;
+
+  auto var = std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+      vertex_index_name_, StorageClass::kInput, GetI32Type()));
+
+  VariableDecorationList decorations;
+  decorations.push_back(
+      std::make_unique<BuiltinDecoration>(Builtin::kVertexIdx));
+
+  var->set_decorations(std::move(decorations));
+  mod_->AddGlobalVariable(std::move(var));
+}
+
+void VertexPullingTransform::ConvertVertexInputVariablesToPrivate() {
+  for (auto& v : mod_->global_variables()) {
+    if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) {
+      continue;
+    }
+
+    for (auto& d : v->AsDecorated()->decorations()) {
+      if (!d->IsLocation()) {
+        continue;
+      }
+
+      uint32_t location = d->AsLocation()->value();
+      // This is where the replacement happens. Expressions use identifier
+      // strings instead of pointers, so we don't need to update any other place
+      // in the AST.
+      v = std::make_unique<Variable>(v->name(), StorageClass::kPrivate,
+                                     v->type());
+      location_to_var_[location] = v.get();
+      break;
+    }
+  }
+}
+
+void VertexPullingTransform::AddVertexStorageBuffers() {
+  // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935
+  // The array inside the struct definition
+  auto internal_array = std::make_unique<type::ArrayType>(GetU32Type());
+  internal_array->set_array_stride(4u);
+  auto* internal_array_type = ctx_->type_mgr().Get(std::move(internal_array));
+
+  // Creating the struct type
+  StructMemberList members;
+  StructMemberDecorationList member_dec;
+  member_dec.push_back(std::make_unique<StructMemberOffsetDecoration>(0u));
+  members.push_back(std::make_unique<StructMember>(
+      kStructBufferName, internal_array_type, std::move(member_dec)));
+  auto* struct_type = ctx_->type_mgr().Get(std::make_unique<type::StructType>(
+      std::make_unique<Struct>(StructDecoration::kBlock, std::move(members))));
+
+  for (uint32_t i = 0; i < vertex_state_->vertex_buffers.size(); ++i) {
+    // The decorated variable with struct type
+    auto var = std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+        GetVertexBufferName(i), StorageClass::kStorageBuffer, struct_type));
+
+    // Add decorations
+    VariableDecorationList decorations;
+    decorations.push_back(std::make_unique<BindingDecoration>(i));
+    decorations.push_back(std::make_unique<SetDecoration>(0));
+    var->set_decorations(std::move(decorations));
+
+    mod_->AddGlobalVariable(std::move(var));
+  }
+}
+
+void VertexPullingTransform::AddVertexPullingPreamble(Function* vertex_func) {
+  // Assign by looking at the vertex descriptor to find attributes with matching
+  // location.
+
+  // A block statement allowing us to use append instead of insert
+  auto block = std::make_unique<BlockStatement>();
+
+  // Declare the |kPullingPosVarName| variable in the shader
+  auto pos_declaration =
+      std::make_unique<VariableDeclStatement>(std::make_unique<Variable>(
+          kPullingPosVarName, StorageClass::kFunction, GetI32Type()));
+
+  // |kPullingPosVarName| refers to the byte location of the current read. We
+  // declare a variable in the shader to avoid having to reuse Expression
+  // objects.
+  block->append(std::move(pos_declaration));
+
+  for (uint32_t i = 0; i < vertex_state_->vertex_buffers.size(); ++i) {
+    const VertexBufferLayoutDescriptor& buffer_layout =
+        vertex_state_->vertex_buffers[i];
+
+    for (const VertexAttributeDescriptor& attribute_desc :
+         buffer_layout.attributes) {
+      auto it = location_to_var_.find(attribute_desc.shader_location);
+      if (it == location_to_var_.end()) {
+        continue;
+      }
+      auto* v = it->second;
+
+      // An expression for the start of the read in the buffer in bytes
+      auto pos_value = std::make_unique<BinaryExpression>(
+          BinaryOp::kAdd,
+          std::make_unique<BinaryExpression>(
+              BinaryOp::kMultiply,
+              std::make_unique<IdentifierExpression>(vertex_index_name_),
+              GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
+          GenUint(static_cast<uint32_t>(attribute_desc.offset)));
+
+      // Update position of the read
+      auto set_pos_expr = std::make_unique<AssignmentStatement>(
+          CreatePullingPositionIdent(), std::move(pos_value));
+      block->append(std::move(set_pos_expr));
+
+      block->append(std::make_unique<AssignmentStatement>(
+          std::make_unique<IdentifierExpression>(v->name()),
+          AccessByFormat(i, attribute_desc.format)));
+    }
+  }
+
+  vertex_func->body()->insert(0, std::move(block));
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::GenUint(uint32_t value) {
+  return std::make_unique<ScalarConstructorExpression>(
+      std::make_unique<UintLiteral>(GetU32Type(), value));
+}
+
+std::unique_ptr<Expression>
+VertexPullingTransform::CreatePullingPositionIdent() {
+  return std::make_unique<IdentifierExpression>(kPullingPosVarName);
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessByFormat(
+    uint32_t buffer,
+    VertexFormat format) {
+  // TODO(idanr): this doesn't account for the format of the attribute in the
+  // shader. ex: vec<u32> in shader, and attribute claims VertexFormat::Float4
+  // right now, we would try to assign a vec4<f32> to this attribute, but we
+  // really need to assign a vec4<u32> by casting.
+  // We could split this function to first do memory accesses and unpacking into
+  // int/uint/float1-4/etc, then convert that variable to a var<in> with the
+  // conversion defined in the WebGPU spec.
+  switch (format) {
+    case VertexFormat::kU32:
+      return AccessU32(buffer, CreatePullingPositionIdent());
+    case VertexFormat::kI32:
+      return AccessI32(buffer, CreatePullingPositionIdent());
+    case VertexFormat::kF32:
+      return AccessF32(buffer, CreatePullingPositionIdent());
+    case VertexFormat::kVec2F32:
+      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 2);
+    case VertexFormat::kVec3F32:
+      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 3);
+    case VertexFormat::kVec4F32:
+      return AccessVec(buffer, 4, GetF32Type(), VertexFormat::kF32, 4);
+    default:
+      return nullptr;
+  }
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessU32(
+    uint32_t buffer,
+    std::unique_ptr<Expression> pos) {
+  // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer
+  // has byte offsets for each attribute, and we will convert it to u32 indexes
+  // by dividing. Then, that element is going to be read, and if needed,
+  // unpacked into an appropriate variable. All reads should end up here as a
+  // base case.
+  return std::make_unique<ArrayAccessorExpression>(
+      std::make_unique<MemberAccessorExpression>(
+          std::make_unique<IdentifierExpression>(GetVertexBufferName(buffer)),
+          std::make_unique<IdentifierExpression>(kStructBufferName)),
+      std::make_unique<BinaryExpression>(BinaryOp::kDivide, std::move(pos),
+                                         GenUint(4)));
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessI32(
+    uint32_t buffer,
+    std::unique_ptr<Expression> pos) {
+  // as<T> reinterprets bits
+  return std::make_unique<AsExpression>(GetI32Type(),
+                                        AccessU32(buffer, std::move(pos)));
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessF32(
+    uint32_t buffer,
+    std::unique_ptr<Expression> pos) {
+  // as<T> reinterprets bits
+  return std::make_unique<AsExpression>(GetF32Type(),
+                                        AccessU32(buffer, std::move(pos)));
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessPrimitive(
+    uint32_t buffer,
+    std::unique_ptr<Expression> pos,
+    VertexFormat format) {
+  // This function uses a position expression to read, rather than using the
+  // position variable. This allows us to read from offset positions relative to
+  // |kPullingPosVarName|. We can't call AccessByFormat because it reads only
+  // from the position variable.
+  switch (format) {
+    case VertexFormat::kU32:
+      return AccessU32(buffer, std::move(pos));
+    case VertexFormat::kI32:
+      return AccessI32(buffer, std::move(pos));
+    case VertexFormat::kF32:
+      return AccessF32(buffer, std::move(pos));
+    default:
+      return nullptr;
+  }
+}
+
+std::unique_ptr<Expression> VertexPullingTransform::AccessVec(
+    uint32_t buffer,
+    uint32_t element_stride,
+    type::Type* base_type,
+    VertexFormat base_format,
+    uint32_t count) {
+  ExpressionList expr_list;
+  for (uint32_t i = 0; i < count; ++i) {
+    // Offset read position by element_stride for each component
+    auto cur_pos = std::make_unique<BinaryExpression>(
+        BinaryOp::kAdd, CreatePullingPositionIdent(),
+        GenUint(element_stride * i));
+    expr_list.push_back(
+        AccessPrimitive(buffer, std::move(cur_pos), base_format));
+  }
+
+  return std::make_unique<TypeConstructorExpression>(
+      ctx_->type_mgr().Get(
+          std::make_unique<type::VectorType>(base_type, count)),
+      std::move(expr_list));
+}
+
+type::Type* VertexPullingTransform::GetU32Type() {
+  return ctx_->type_mgr().Get(std::make_unique<type::U32Type>());
+}
+
+type::Type* VertexPullingTransform::GetI32Type() {
+  return ctx_->type_mgr().Get(std::make_unique<type::I32Type>());
+}
+
+type::Type* VertexPullingTransform::GetF32Type() {
+  return ctx_->type_mgr().Get(std::make_unique<type::F32Type>());
+}
+
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
+    uint64_t in_array_stride,
+    InputStepMode in_step_mode,
+    std::vector<VertexAttributeDescriptor> in_attributes)
+    : array_stride(std::move(in_array_stride)),
+      step_mode(std::move(in_step_mode)),
+      attributes(std::move(in_attributes)) {}
+VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
+    const VertexBufferLayoutDescriptor& other)
+    : array_stride(other.array_stride),
+      step_mode(other.step_mode),
+      attributes(other.attributes) {}
+VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
+
+VertexStateDescriptor::VertexStateDescriptor() = default;
+VertexStateDescriptor::VertexStateDescriptor(
+    std::vector<VertexBufferLayoutDescriptor> in_vertex_buffers)
+    : vertex_buffers(std::move(in_vertex_buffers)) {}
+VertexStateDescriptor::VertexStateDescriptor(const VertexStateDescriptor& other)
+    : vertex_buffers(other.vertex_buffers) {}
+VertexStateDescriptor::~VertexStateDescriptor() = default;
+
+}  // namespace transform
+}  // namespace ast
+}  // namespace tint
diff --git a/src/ast/transform/vertex_pulling_transform.h b/src/ast/transform/vertex_pulling_transform.h
new file mode 100644
index 0000000..636ddc4
--- /dev/null
+++ b/src/ast/transform/vertex_pulling_transform.h
@@ -0,0 +1,249 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_AST_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_
+#define SRC_AST_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tint {
+
+class Context;
+
+namespace ast {
+
+class EntryPoint;
+class Expression;
+class Function;
+class Module;
+class Statement;
+class Variable;
+
+namespace type {
+class Type;
+}  // namespace type
+
+namespace transform {
+
+/// Describes the format of data in a vertex buffer
+enum class VertexFormat {
+  kVec2U8,
+  kVec4U8,
+  kVec2I8,
+  kVec4I8,
+  kVec2U8Norm,
+  kVec4U8Norm,
+  kVec2I8Norm,
+  kVec4I8Norm,
+  kVec2U16,
+  kVec4U16,
+  kVec2I16,
+  kVec4I16,
+  kVec2U16Norm,
+  kVec4U16Norm,
+  kVec2I16Norm,
+  kVec4I16Norm,
+  kVec2F16,
+  kVec4F16,
+  kF32,
+  kVec2F32,
+  kVec3F32,
+  kVec4F32,
+  kU32,
+  kVec2U32,
+  kVec3U32,
+  kVec4U32,
+  kI32,
+  kVec2I32,
+  kVec3I32,
+  kVec4I32
+};
+
+/// Describes if a vertex attribtes increments with vertex index or instance
+/// index
+enum class InputStepMode { kVertex, kInstance };
+
+/// Describes a vertex attribute within a buffer
+struct VertexAttributeDescriptor {
+  VertexFormat format;
+  uint64_t offset;
+  uint32_t shader_location;
+};
+
+/// Describes a buffer containing multiple vertex attributes
+struct VertexBufferLayoutDescriptor {
+  VertexBufferLayoutDescriptor();
+  ~VertexBufferLayoutDescriptor();
+  VertexBufferLayoutDescriptor(
+      uint64_t in_array_stride,
+      InputStepMode in_step_mode,
+      std::vector<VertexAttributeDescriptor> in_attributes);
+  VertexBufferLayoutDescriptor(const VertexBufferLayoutDescriptor& other);
+
+  uint64_t array_stride = 0u;
+  InputStepMode step_mode = InputStepMode::kVertex;
+  std::vector<VertexAttributeDescriptor> attributes;
+};
+
+/// Describes vertex state, which consists of many buffers containing vertex
+/// attributes
+struct VertexStateDescriptor {
+  VertexStateDescriptor();
+  ~VertexStateDescriptor();
+  VertexStateDescriptor(
+      std::vector<VertexBufferLayoutDescriptor> in_vertex_buffers);
+  VertexStateDescriptor(const VertexStateDescriptor& other);
+
+  std::vector<VertexBufferLayoutDescriptor> vertex_buffers;
+};
+
+/// Converts a module to use vertex pulling
+///
+/// Variables which accept vertex input are var<in> with a location decoration.
+/// This transform will convert those to be assigned from storage buffers
+/// instead. The intention is to allow vertex input to rely on a storage buffer
+/// clamping pass for out of bounds reads. We bind the storage buffers as arrays
+/// of u32, so any read to byte position |p| will actually need to read position
+/// |p / 4|, since sizeof(u32) == 4.
+///
+/// |VertexFormat| represents the input type of the attribute. This isn't
+/// related to the type of the variable in the shader. For example,
+/// |VertexFormat::vec2_f16| tells us that the buffer will contain f16 elements,
+/// to be read as vec2. In the shader, a user would make a vec2<f32> to be able
+/// to use them. The conversion between f16 and f32 will need to be handled by
+/// us (using unpack functions).
+///
+/// To be clear, there won't be types such as f16 or u8 anywhere in WGSL code,
+/// but these are types that the data may arrive as. We need to convert these
+/// smaller types into the base types such as f32 and u32 for the shader to use.
+class VertexPullingTransform {
+ public:
+  /// Constructor
+  /// @param ctx the tint context
+  /// @param mod the module to convert to vertex pulling
+  VertexPullingTransform(Context* ctx, Module* mod);
+  ~VertexPullingTransform();
+
+  /// Sets the vertex state descriptor, containing info about attributes
+  /// @param vertex_state the vertex state descriptor
+  void SetVertexState(std::unique_ptr<VertexStateDescriptor> vertex_state);
+
+  /// Sets the entry point to add assignments into
+  /// @param entry_point the vertex stage entry point
+  void SetEntryPoint(std::string entry_point);
+
+  /// @returns true if the transformation was successful
+  bool Run();
+
+  /// @returns error messages
+  const std::string& GetError() { return error_; }
+
+ private:
+  void SetError(const std::string& error);
+
+  /// Generate the vertex buffer binding name
+  /// @param index index to append to buffer name
+  std::string GetVertexBufferName(uint32_t index);
+
+  /// Inserts vertex_idx binding, or finds the existing one
+  void FindOrInsertVertexIndex();
+
+  /// Converts var<in> with a location decoration to var<private>
+  void ConvertVertexInputVariablesToPrivate();
+
+  /// Adds storage buffer decorated variables for the vertex buffers
+  void AddVertexStorageBuffers();
+
+  /// Adds assignment to the variables from the buffers
+  void AddVertexPullingPreamble(Function* vertex_func);
+
+  /// Generates an expression holding a constant uint
+  /// @param value uint value
+  std::unique_ptr<Expression> GenUint(uint32_t value);
+
+  /// Generates an expression to read the shader value |kPullingPosVarName|
+  std::unique_ptr<Expression> CreatePullingPositionIdent();
+
+  /// Generates an expression reading from a buffer a specific format.
+  /// This reads the value wherever |kPullingPosVarName| points to at the time
+  /// of the read.
+  /// @param buffer the index of the vertex buffer
+  /// @param format the format to read
+  std::unique_ptr<Expression> AccessByFormat(uint32_t buffer,
+                                             VertexFormat format);
+
+  /// Generates an expression reading a uint32 from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  std::unique_ptr<Expression> AccessU32(uint32_t buffer,
+                                        std::unique_ptr<Expression> pos);
+
+  /// Generates an expression reading an int32 from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  std::unique_ptr<Expression> AccessI32(uint32_t buffer,
+                                        std::unique_ptr<Expression> pos);
+
+  /// Generates an expression reading a float from a vertex buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  std::unique_ptr<Expression> AccessF32(uint32_t buffer,
+                                        std::unique_ptr<Expression> pos);
+
+  /// Generates an expression reading a basic type (u32, i32, f32) from a vertex
+  /// buffer
+  /// @param buffer the index of the vertex buffer
+  /// @param pos an expression for the position of the access, in bytes
+  /// @param format the underlying vertex format
+  std::unique_ptr<Expression> AccessPrimitive(uint32_t buffer,
+                                              std::unique_ptr<Expression> pos,
+                                              VertexFormat format);
+
+  /// Generates an expression reading a vec2/3/4 from a vertex buffer.
+  /// This reads the value wherever |kPullingPosVarName| points to at the time
+  /// of the read.
+  /// @param buffer the index of the vertex buffer
+  /// @param element_stride stride between elements, in bytes
+  /// @param base_type underlying AST type
+  /// @param base_format underlying vertex format
+  /// @param count how many elements the vector has
+  std::unique_ptr<Expression> AccessVec(uint32_t buffer,
+                                        uint32_t element_stride,
+                                        type::Type* base_type,
+                                        VertexFormat base_format,
+                                        uint32_t count);
+
+  // Used to grab corresponding types from the type manager
+  type::Type* GetU32Type();
+  type::Type* GetI32Type();
+  type::Type* GetF32Type();
+
+  Context* ctx_ = nullptr;
+  Module* mod_ = nullptr;
+  std::string entry_point_name_;
+  std::string error_;
+
+  std::string vertex_index_name_;
+
+  std::unordered_map<uint32_t, Variable*> location_to_var_;
+  std::unique_ptr<VertexStateDescriptor> vertex_state_;
+};
+
+}  // namespace transform
+}  // namespace ast
+}  // namespace tint
+
+#endif  // SRC_AST_TRANSFORM_VERTEX_PULLING_TRANSFORM_H_
diff --git a/src/ast/transform/vertex_pulling_transform_test.cc b/src/ast/transform/vertex_pulling_transform_test.cc
new file mode 100644
index 0000000..c0cc841
--- /dev/null
+++ b/src/ast/transform/vertex_pulling_transform_test.cc
@@ -0,0 +1,775 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/transform/vertex_pulling_transform.h"
+
+#include "gtest/gtest.h"
+#include "src/ast/decorated_variable.h"
+#include "src/ast/function.h"
+#include "src/ast/type/array_type.h"
+#include "src/ast/type/f32_type.h"
+#include "src/ast/type/i32_type.h"
+#include "src/ast/type/void_type.h"
+#include "src/type_determiner.h"
+#include "src/validator.h"
+
+namespace tint {
+namespace ast {
+namespace transform {
+namespace {
+
+class VertexPullingTransformHelper {
+ public:
+  VertexPullingTransformHelper() {
+    mod_ = std::make_unique<ast::Module>();
+    transform_ = std::make_unique<VertexPullingTransform>(&ctx_, mod_.get());
+  }
+
+  // Create basic module with an entry point and vertex function
+  void InitBasicModule() {
+    mod()->AddEntryPoint(std::make_unique<EntryPoint>(PipelineStage::kVertex,
+                                                      "main", "vtx_main"));
+    mod()->AddFunction(std::make_unique<Function>(
+        "vtx_main", VariableList{},
+        ctx_.type_mgr().Get(std::make_unique<type::VoidType>())));
+  }
+
+  // Set up the transformation, after building the module
+  void InitTransform(VertexStateDescriptor vertex_state) {
+    EXPECT_TRUE(mod_->IsValid());
+
+    tint::TypeDeterminer td(&ctx_, mod_.get());
+    EXPECT_TRUE(td.Determine());
+
+    tint::Validator v;
+    EXPECT_TRUE(v.Validate(mod_.get()));
+
+    transform_->SetVertexState(
+        std::make_unique<VertexStateDescriptor>(std::move(vertex_state)));
+    transform_->SetEntryPoint("main");
+  }
+
+  // Inserts a variable which will be converted to vertex pulling
+  void AddVertexInputVariable(uint32_t location,
+                              std::string name,
+                              type::Type* type) {
+    auto var = std::make_unique<DecoratedVariable>(
+        std::make_unique<Variable>(name, StorageClass::kInput, type));
+
+    VariableDecorationList decorations;
+    decorations.push_back(std::make_unique<LocationDecoration>(location));
+
+    var->set_decorations(std::move(decorations));
+    mod_->AddGlobalVariable(std::move(var));
+  }
+
+  ast::Module* mod() { return mod_.get(); }
+  VertexPullingTransform* transform() { return transform_.get(); }
+
+ private:
+  Context ctx_;
+  std::unique_ptr<ast::Module> mod_;
+  std::unique_ptr<VertexPullingTransform> transform_;
+};
+
+class VertexPullingTransformTest : public VertexPullingTransformHelper,
+                                   public testing::Test {};
+
+TEST_F(VertexPullingTransformTest, Error_NoVertexState) {
+  EXPECT_FALSE(transform()->Run());
+  EXPECT_EQ(transform()->GetError(), "SetVertexState not called");
+}
+
+TEST_F(VertexPullingTransformTest, Error_NoEntryPoint) {
+  transform()->SetVertexState(std::make_unique<VertexStateDescriptor>());
+  EXPECT_FALSE(transform()->Run());
+  EXPECT_EQ(transform()->GetError(), "Vertex stage entry point not found");
+}
+
+TEST_F(VertexPullingTransformTest, Error_InvalidEntryPoint) {
+  InitBasicModule();
+  InitTransform({});
+  transform()->SetEntryPoint("_");
+
+  EXPECT_FALSE(transform()->Run());
+  EXPECT_EQ(transform()->GetError(), "Vertex stage entry point not found");
+}
+
+TEST_F(VertexPullingTransformTest, Error_EntryPointWrongStage) {
+  InitBasicModule();
+  mod()->entry_points()[0]->set_pipeline_stage(PipelineStage::kFragment);
+
+  InitTransform({});
+  EXPECT_FALSE(transform()->Run());
+  EXPECT_EQ(transform()->GetError(), "Entry point is not for vertex stage");
+}
+
+TEST_F(VertexPullingTransformTest, BasicModule) {
+  InitBasicModule();
+  InitTransform({});
+  EXPECT_TRUE(transform()->Run());
+}
+
+TEST_F(VertexPullingTransformTest, OneAttribute) {
+  InitBasicModule();
+
+  type::F32Type f32;
+  AddVertexInputVariable(0, "var_a", &f32);
+
+  InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
+
+  EXPECT_TRUE(transform()->Run());
+
+  EXPECT_EQ(R"(Module{
+  Variable{
+    var_a
+    private
+    __f32
+  }
+  DecoratedVariable{
+    Decorations{
+      BuiltinDecoration{vertex_idx}
+    }
+    tint_pulling_vertex_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_
+  }
+  EntryPoint{vertex as main = vtx_main}
+  Function vtx_main -> __void
+  ()
+  {
+    Block{
+      VariableDeclStatement{
+        Variable{
+          tint_pulling_pos
+          function
+          __i32
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{4}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_a}
+        As<__f32>{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{tint_pulling_vertex_buffer_0}
+              Identifier{data}
+            }
+            Binary{
+              Identifier{tint_pulling_pos}
+              divide
+              ScalarConstructor{4}
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)",
+            mod()->to_str());
+}
+
+// We expect the transform to use an existing vertex_idx builtin variable if it
+// finds one
+TEST_F(VertexPullingTransformTest, ExistingVertexIndex) {
+  InitBasicModule();
+
+  type::F32Type f32;
+  AddVertexInputVariable(0, "var_a", &f32);
+
+  type::I32Type i32;
+  auto vertex_index_var =
+      std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+          "custom_vertex_index", StorageClass::kInput, &i32));
+
+  VariableDecorationList decorations;
+  decorations.push_back(
+      std::make_unique<BuiltinDecoration>(Builtin::kVertexIdx));
+
+  vertex_index_var->set_decorations(std::move(decorations));
+  mod()->AddGlobalVariable(std::move(vertex_index_var));
+
+  InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
+
+  EXPECT_TRUE(transform()->Run());
+
+  EXPECT_EQ(R"(Module{
+  Variable{
+    var_a
+    private
+    __f32
+  }
+  DecoratedVariable{
+    Decorations{
+      BuiltinDecoration{vertex_idx}
+    }
+    custom_vertex_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_
+  }
+  EntryPoint{vertex as main = vtx_main}
+  Function vtx_main -> __void
+  ()
+  {
+    Block{
+      VariableDeclStatement{
+        Variable{
+          tint_pulling_pos
+          function
+          __i32
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{custom_vertex_index}
+            multiply
+            ScalarConstructor{4}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_a}
+        As<__f32>{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{tint_pulling_vertex_buffer_0}
+              Identifier{data}
+            }
+            Binary{
+              Identifier{tint_pulling_pos}
+              divide
+              ScalarConstructor{4}
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)",
+            mod()->to_str());
+}
+
+TEST_F(VertexPullingTransformTest, TwoAttributesSameBuffer) {
+  InitBasicModule();
+
+  type::F32Type f32;
+  AddVertexInputVariable(0, "var_a", &f32);
+
+  type::ArrayType vec4_f32{&f32, 4u};
+  AddVertexInputVariable(1, "var_b", &vec4_f32);
+
+  InitTransform(
+      {{{16,
+         InputStepMode::kVertex,
+         {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
+
+  EXPECT_TRUE(transform()->Run());
+
+  EXPECT_EQ(R"(Module{
+  Variable{
+    var_a
+    private
+    __f32
+  }
+  Variable{
+    var_b
+    private
+    __array__f32_4
+  }
+  DecoratedVariable{
+    Decorations{
+      BuiltinDecoration{vertex_idx}
+    }
+    tint_pulling_vertex_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_
+  }
+  EntryPoint{vertex as main = vtx_main}
+  Function vtx_main -> __void
+  ()
+  {
+    Block{
+      VariableDeclStatement{
+        Variable{
+          tint_pulling_pos
+          function
+          __i32
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{16}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_a}
+        As<__f32>{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{tint_pulling_vertex_buffer_0}
+              Identifier{data}
+            }
+            Binary{
+              Identifier{tint_pulling_pos}
+              divide
+              ScalarConstructor{4}
+            }
+          }
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{16}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_b}
+        TypeConstructor{
+          __vec_4__f32
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{0}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{4}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{8}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{12}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)",
+            mod()->to_str());
+}
+
+TEST_F(VertexPullingTransformTest, FloatVectorAttributes) {
+  InitBasicModule();
+
+  type::F32Type f32;
+  type::ArrayType vec2_f32{&f32, 2u};
+  AddVertexInputVariable(0, "var_a", &vec2_f32);
+
+  type::ArrayType vec3_f32{&f32, 3u};
+  AddVertexInputVariable(1, "var_b", &vec3_f32);
+
+  type::ArrayType vec4_f32{&f32, 4u};
+  AddVertexInputVariable(2, "var_c", &vec4_f32);
+
+  InitTransform(
+      {{{8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
+        {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
+        {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
+
+  EXPECT_TRUE(transform()->Run());
+
+  EXPECT_EQ(R"(Module{
+  Variable{
+    var_a
+    private
+    __array__f32_2
+  }
+  Variable{
+    var_b
+    private
+    __array__f32_3
+  }
+  Variable{
+    var_c
+    private
+    __array__f32_4
+  }
+  DecoratedVariable{
+    Decorations{
+      BuiltinDecoration{vertex_idx}
+    }
+    tint_pulling_vertex_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{1}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_1
+    storage_buffer
+    __struct_
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{2}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_2
+    storage_buffer
+    __struct_
+  }
+  EntryPoint{vertex as main = vtx_main}
+  Function vtx_main -> __void
+  ()
+  {
+    Block{
+      VariableDeclStatement{
+        Variable{
+          tint_pulling_pos
+          function
+          __i32
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{8}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_a}
+        TypeConstructor{
+          __vec_2__f32
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{0}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_0}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{4}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{12}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_b}
+        TypeConstructor{
+          __vec_3__f32
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_1}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{0}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_1}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{4}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_1}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{8}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_vertex_index}
+            multiply
+            ScalarConstructor{16}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_c}
+        TypeConstructor{
+          __vec_4__f32
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_2}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{0}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_2}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{4}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_2}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{8}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+          As<__f32>{
+            ArrayAccessor{
+              MemberAccessor{
+                Identifier{tint_pulling_vertex_buffer_2}
+                Identifier{data}
+              }
+              Binary{
+                Binary{
+                  Identifier{tint_pulling_pos}
+                  add
+                  ScalarConstructor{12}
+                }
+                divide
+                ScalarConstructor{4}
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)",
+            mod()->to_str());
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace ast
+}  // namespace tint