Add Metal vertex pulling behind a flag

Implements vertex pulling on the Metal backend, hidden behind a flag
until ready for use (we are missing support for more complicated vertex
input types).

Bug: dawn:480
Change-Id: I38028b80673693ebf21309ad5336561fb99f40dc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/26522
Commit-Queue: Idan Raiter <idanr@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index 40ed372..4fdd708 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -295,6 +295,82 @@
                     << " binding " << static_cast<uint32_t>(binding);
             return ostream.str();
         }
+
+#ifdef DAWN_ENABLE_WGSL
+        tint::ast::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
+            switch (format) {
+                case wgpu::VertexFormat::UChar2:
+                    return tint::ast::transform::VertexFormat::kVec2U8;
+                case wgpu::VertexFormat::UChar4:
+                    return tint::ast::transform::VertexFormat::kVec4U8;
+                case wgpu::VertexFormat::Char2:
+                    return tint::ast::transform::VertexFormat::kVec2I8;
+                case wgpu::VertexFormat::Char4:
+                    return tint::ast::transform::VertexFormat::kVec4I8;
+                case wgpu::VertexFormat::UChar2Norm:
+                    return tint::ast::transform::VertexFormat::kVec2U8Norm;
+                case wgpu::VertexFormat::UChar4Norm:
+                    return tint::ast::transform::VertexFormat::kVec4U8Norm;
+                case wgpu::VertexFormat::Char2Norm:
+                    return tint::ast::transform::VertexFormat::kVec2I8Norm;
+                case wgpu::VertexFormat::Char4Norm:
+                    return tint::ast::transform::VertexFormat::kVec4I8Norm;
+                case wgpu::VertexFormat::UShort2:
+                    return tint::ast::transform::VertexFormat::kVec2U16;
+                case wgpu::VertexFormat::UShort4:
+                    return tint::ast::transform::VertexFormat::kVec4U16;
+                case wgpu::VertexFormat::Short2:
+                    return tint::ast::transform::VertexFormat::kVec2I16;
+                case wgpu::VertexFormat::Short4:
+                    return tint::ast::transform::VertexFormat::kVec4I16;
+                case wgpu::VertexFormat::UShort2Norm:
+                    return tint::ast::transform::VertexFormat::kVec2U16Norm;
+                case wgpu::VertexFormat::UShort4Norm:
+                    return tint::ast::transform::VertexFormat::kVec4U16Norm;
+                case wgpu::VertexFormat::Short2Norm:
+                    return tint::ast::transform::VertexFormat::kVec2I16Norm;
+                case wgpu::VertexFormat::Short4Norm:
+                    return tint::ast::transform::VertexFormat::kVec4I16Norm;
+                case wgpu::VertexFormat::Half2:
+                    return tint::ast::transform::VertexFormat::kVec2F16;
+                case wgpu::VertexFormat::Half4:
+                    return tint::ast::transform::VertexFormat::kVec4F16;
+                case wgpu::VertexFormat::Float:
+                    return tint::ast::transform::VertexFormat::kF32;
+                case wgpu::VertexFormat::Float2:
+                    return tint::ast::transform::VertexFormat::kVec2F32;
+                case wgpu::VertexFormat::Float3:
+                    return tint::ast::transform::VertexFormat::kVec3F32;
+                case wgpu::VertexFormat::Float4:
+                    return tint::ast::transform::VertexFormat::kVec4F32;
+                case wgpu::VertexFormat::UInt:
+                    return tint::ast::transform::VertexFormat::kU32;
+                case wgpu::VertexFormat::UInt2:
+                    return tint::ast::transform::VertexFormat::kVec2U32;
+                case wgpu::VertexFormat::UInt3:
+                    return tint::ast::transform::VertexFormat::kVec3U32;
+                case wgpu::VertexFormat::UInt4:
+                    return tint::ast::transform::VertexFormat::kVec4U32;
+                case wgpu::VertexFormat::Int:
+                    return tint::ast::transform::VertexFormat::kI32;
+                case wgpu::VertexFormat::Int2:
+                    return tint::ast::transform::VertexFormat::kVec2I32;
+                case wgpu::VertexFormat::Int3:
+                    return tint::ast::transform::VertexFormat::kVec3I32;
+                case wgpu::VertexFormat::Int4:
+                    return tint::ast::transform::VertexFormat::kVec4I32;
+            }
+        }
+
+        tint::ast::transform::InputStepMode ToTintInputStepMode(wgpu::InputStepMode mode) {
+            switch (mode) {
+                case wgpu::InputStepMode::Vertex:
+                    return tint::ast::transform::InputStepMode::kVertex;
+                case wgpu::InputStepMode::Instance:
+                    return tint::ast::transform::InputStepMode::kInstance;
+            }
+        }
+#endif
     }  // anonymous namespace
 
     MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) {
@@ -400,6 +476,75 @@
         std::vector<uint32_t> spirv = generator.result();
         return std::move(spirv);
     }
+
+    ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
+        const char* source,
+        const VertexStateDescriptor& vertexState,
+        const std::string& entryPoint,
+        uint32_t pullingBufferBindingSet) {
+        std::ostringstream errorStream;
+        errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
+
+        tint::Context context;
+        tint::reader::wgsl::Parser parser(&context, source);
+
+        // TODO: This is a duplicate parse with ValidateWGSL, need to store
+        // state between calls to avoid this.
+        if (!parser.Parse()) {
+            errorStream << "Parser: " << parser.error() << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        tint::ast::Module module = parser.module();
+        if (!module.IsValid()) {
+            errorStream << "Invalid module generated..." << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        tint::ast::transform::VertexPullingTransform transform(&context, &module);
+        auto state = std::make_unique<tint::ast::transform::VertexStateDescriptor>();
+        for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
+            auto& vertexBuffer = vertexState.vertexBuffers[i];
+            tint::ast::transform::VertexBufferLayoutDescriptor layout;
+            layout.array_stride = vertexBuffer.arrayStride;
+            layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
+
+            for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
+                auto& attribute = vertexBuffer.attributes[j];
+                tint::ast::transform::VertexAttributeDescriptor attr;
+                attr.format = ToTintVertexFormat(attribute.format);
+                attr.offset = attribute.offset;
+                attr.shader_location = attribute.shaderLocation;
+
+                layout.attributes.push_back(std::move(attr));
+            }
+
+            state->vertex_buffers.push_back(std::move(layout));
+        }
+        transform.SetVertexState(std::move(state));
+        transform.SetEntryPoint(entryPoint);
+        transform.SetPullingBufferBindingSet(pullingBufferBindingSet);
+
+        if (!transform.Run()) {
+            errorStream << "Vertex pulling transform: " << transform.GetError();
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        tint::TypeDeterminer type_determiner(&context, &module);
+        if (!type_determiner.Determine()) {
+            errorStream << "Type Determination: " << type_determiner.error();
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        tint::writer::spirv::Generator generator(std::move(module));
+        if (!generator.Generate()) {
+            errorStream << "Generator: " << generator.error() << std::endl;
+            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+        }
+
+        std::vector<uint32_t> spirv = generator.result();
+        return std::move(spirv);
+    }
 #endif  // DAWN_ENABLE_WGSL
 
     MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
@@ -1094,10 +1239,22 @@
         return mSpirv;
     }
 
+#ifdef DAWN_ENABLE_WGSL
+    ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
+        const VertexStateDescriptor& vertexState,
+        const std::string& entryPoint,
+        uint32_t pullingBufferBindingSet) const {
+        return ConvertWGSLToSPIRVWithPulling(mWgsl.c_str(), vertexState, entryPoint,
+                                             pullingBufferBindingSet);
+    }
+#endif
+
     shaderc_spvc::CompileOptions ShaderModuleBase::GetCompileOptions() const {
         shaderc_spvc::CompileOptions options;
         options.SetValidate(GetDevice()->IsValidationEnabled());
         options.SetRobustBufferAccessPass(GetDevice()->IsRobustnessEnabled());
+        options.SetSourceEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
+        options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_1);
         return options;
     }
 
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 4e771aa..336551e 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -91,6 +91,13 @@
         shaderc_spvc::Context* GetContext();
         const std::vector<uint32_t>& GetSpirv() const;
 
+#ifdef DAWN_ENABLE_WGSL
+        ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
+            const VertexStateDescriptor& vertexState,
+            const std::string& entryPoint,
+            uint32_t pullingBufferBindingSet) const;
+#endif
+
       protected:
         static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
         shaderc_spvc::CompileOptions GetCompileOptions() const;
diff --git a/src/dawn_native/Toggles.cpp b/src/dawn_native/Toggles.cpp
index 7f40ebd..d54dc09 100644
--- a/src/dawn_native/Toggles.cpp
+++ b/src/dawn_native/Toggles.cpp
@@ -138,7 +138,11 @@
                "Clear buffers on their first use. This is a temporary toggle only for the "
                "development of buffer lazy initialization and will be removed after buffer lazy "
                "initialization is completely implemented.",
-               "https://crbug.com/dawn/414"}}}};
+               "https://crbug.com/dawn/414"}},
+             {Toggle::MetalEnableVertexPulling,
+              {"metal_enable_vertex_pulling",
+               "Uses vertex pulling to protect out-of-bounds reads on Metal",
+               "https://crbug.com/dawn/480"}}}};
 
     }  // anonymous namespace
 
diff --git a/src/dawn_native/Toggles.h b/src/dawn_native/Toggles.h
index f9d66a2..5186cb5 100644
--- a/src/dawn_native/Toggles.h
+++ b/src/dawn_native/Toggles.h
@@ -44,6 +44,7 @@
         UseDXC,
         DisableRobustness,
         LazyClearBufferOnFirstUse,
+        MetalEnableVertexPulling,
 
         EnumCount,
         InvalidEnum = EnumCount,
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 7390b92..351090aa 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -263,7 +263,9 @@
             // MSL code generated by SPIRV-Cross expects.
             PerStage<std::array<uint32_t, kGenericMetalBufferSlots>> data;
 
-            void Apply(id<MTLRenderCommandEncoder> render, RenderPipeline* pipeline) {
+            void Apply(id<MTLRenderCommandEncoder> render,
+                       RenderPipeline* pipeline,
+                       bool enableVertexPulling) {
                 wgpu::ShaderStage stagesToApply =
                     dirtyStages & pipeline->GetStagesRequiringStorageBufferLength();
 
@@ -274,6 +276,11 @@
                 if (stagesToApply & wgpu::ShaderStage::Vertex) {
                     uint32_t bufferCount = ToBackend(pipeline->GetLayout())
                                                ->GetBufferBindingCount(SingleShaderStage::Vertex);
+
+                    if (enableVertexPulling) {
+                        bufferCount += pipeline->GetVertexStateDescriptor()->vertexBufferCount;
+                    }
+
                     [render setVertexBytes:data[SingleShaderStage::Vertex].data()
                                     length:sizeof(uint32_t) * bufferCount
                                    atIndex:kBufferLengthBufferSlot];
@@ -483,10 +490,17 @@
         // all the relevant state.
         class VertexBufferTracker {
           public:
+            explicit VertexBufferTracker(StorageBufferLengthTracker* lengthTracker)
+                : mLengthTracker(lengthTracker) {
+            }
+
             void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset) {
                 mVertexBuffers[slot] = buffer->GetMTLBuffer();
                 mVertexBufferOffsets[slot] = offset;
 
+                ASSERT(buffer->GetSize() < std::numeric_limits<uint32_t>::max());
+                mVertexBufferBindingSizes[slot] = static_cast<uint32_t>(buffer->GetSize() - offset);
+
                 // Use 64 bit masks and make sure there are no shift UB
                 static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
                 mDirtyVertexBuffers |= 1ull << slot;
@@ -499,13 +513,22 @@
                 mDirtyVertexBuffers |= pipeline->GetVertexBufferSlotsUsed();
             }
 
-            void Apply(id<MTLRenderCommandEncoder> encoder, RenderPipeline* pipeline) {
+            void Apply(id<MTLRenderCommandEncoder> encoder,
+                       RenderPipeline* pipeline,
+                       bool enableVertexPulling) {
                 std::bitset<kMaxVertexBuffers> vertexBuffersToApply =
                     mDirtyVertexBuffers & pipeline->GetVertexBufferSlotsUsed();
 
                 for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) {
                     uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex);
 
+                    if (enableVertexPulling) {
+                        // Insert lengths for vertex buffers bound as storage buffers
+                        mLengthTracker->data[SingleShaderStage::Vertex][metalIndex] =
+                            mVertexBufferBindingSizes[dawnIndex];
+                        mLengthTracker->dirtyStages |= wgpu::ShaderStage::Vertex;
+                    }
+
                     [encoder setVertexBuffers:&mVertexBuffers[dawnIndex]
                                       offsets:&mVertexBufferOffsets[dawnIndex]
                                     withRange:NSMakeRange(metalIndex, 1)];
@@ -519,6 +542,9 @@
             std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
             std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
             std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
+            std::array<uint32_t, kMaxVertexBuffers> mVertexBufferBindingSizes;
+
+            StorageBufferLengthTracker* mLengthTracker;
         };
 
     }  // anonymous namespace
@@ -949,11 +975,12 @@
                                                        MTLRenderPassDescriptor* mtlRenderPass,
                                                        uint32_t width,
                                                        uint32_t height) {
+        bool enableVertexPulling = GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling);
         RenderPipeline* lastPipeline = nullptr;
         id<MTLBuffer> indexBuffer = nil;
         uint32_t indexBufferBaseOffset = 0;
-        VertexBufferTracker vertexBuffers;
         StorageBufferLengthTracker storageBufferLengths = {};
+        VertexBufferTracker vertexBuffers(&storageBufferLengths);
         BindGroupTracker bindGroups(&storageBufferLengths);
 
         id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
@@ -963,9 +990,9 @@
                 case Command::Draw: {
                     DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
-                    vertexBuffers.Apply(encoder, lastPipeline);
+                    vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
                     bindGroups.Apply(encoder);
-                    storageBufferLengths.Apply(encoder, lastPipeline);
+                    storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
                     // The instance count must be non-zero, otherwise no-op
                     if (draw->instanceCount != 0) {
@@ -991,9 +1018,9 @@
                     size_t formatSize =
                         IndexFormatSize(lastPipeline->GetVertexStateDescriptor()->indexFormat);
 
-                    vertexBuffers.Apply(encoder, lastPipeline);
+                    vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
                     bindGroups.Apply(encoder);
-                    storageBufferLengths.Apply(encoder, lastPipeline);
+                    storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
                     // The index and instance count must be non-zero, otherwise no-op
                     if (draw->indexCount != 0 && draw->instanceCount != 0) {
@@ -1025,9 +1052,9 @@
                 case Command::DrawIndirect: {
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
-                    vertexBuffers.Apply(encoder, lastPipeline);
+                    vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
                     bindGroups.Apply(encoder);
-                    storageBufferLengths.Apply(encoder, lastPipeline);
+                    storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
@@ -1040,9 +1067,9 @@
                 case Command::DrawIndexedIndirect: {
                     DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
 
-                    vertexBuffers.Apply(encoder, lastPipeline);
+                    vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
                     bindGroups.Apply(encoder);
-                    storageBufferLengths.Apply(encoder, lastPipeline);
+                    storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index a09f9270..0867639 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -61,6 +61,11 @@
 
     MaybeError Device::Initialize() {
         InitTogglesFromDriver();
+
+        if (!IsRobustnessEnabled() || !IsToggleEnabled(Toggle::UseSpvc)) {
+            ForceSetToggle(Toggle::MetalEnableVertexPulling, false);
+        }
+
         mCommandQueue = [mMtlDevice newCommandQueue];
 
         return DeviceBase::Initialize(new Queue(this));
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 25c7869..1e9efe1 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -329,11 +329,24 @@
 
         MTLRenderPipelineDescriptor* descriptorMTL = [MTLRenderPipelineDescriptor new];
 
+        // TODO: MakeVertexDesc should be const in the future, so we don't need to call it here when
+        // vertex pulling is enabled
+        MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
+        descriptorMTL.vertexDescriptor = vertexDesc;
+        [vertexDesc release];
+
+        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
+            // Calling MakeVertexDesc first is important since it sets indices for packed bindings
+            MTLVertexDescriptor* emptyVertexDesc = [MTLVertexDescriptor new];
+            descriptorMTL.vertexDescriptor = emptyVertexDesc;
+            [emptyVertexDesc release];
+        }
+
         ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
         const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
         ShaderModule::MetalFunctionData vertexData;
         DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex,
-                                           ToBackend(GetLayout()), &vertexData));
+                                           ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this));
 
         descriptorMTL.vertexFunction = vertexData.function;
         if (vertexData.needsStorageBufferLength) {
@@ -377,11 +390,6 @@
         }
 
         descriptorMTL.inputPrimitiveTopology = MTLInputPrimitiveTopology(GetPrimitiveTopology());
-
-        MTLVertexDescriptor* vertexDesc = MakeVertexDesc();
-        descriptorMTL.vertexDescriptor = vertexDesc;
-        [vertexDesc release];
-
         descriptorMTL.sampleCount = GetSampleCount();
         descriptorMTL.alphaToCoverageEnabled = descriptor->alphaToCoverageEnabled;
 
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 7460727..9193b25 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -29,6 +29,7 @@
 
     class Device;
     class PipelineLayout;
+    class RenderPipeline;
 
     class ShaderModule final : public ShaderModuleBase {
       public:
@@ -47,7 +48,8 @@
                                SingleShaderStage functionStage,
                                const PipelineLayout* layout,
                                MetalFunctionData* out,
-                               uint32_t sampleMask = 0xFFFFFFFF);
+                               uint32_t sampleMask = 0xFFFFFFFF,
+                               const RenderPipeline* renderPipeline = nullptr);
 
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index e49cc95..208612e 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -17,6 +17,7 @@
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/metal/DeviceMTL.h"
 #include "dawn_native/metal/PipelineLayoutMTL.h"
+#include "dawn_native/metal/RenderPipelineMTL.h"
 
 #include <spirv_msl.hpp>
 
@@ -92,10 +93,24 @@
                                          SingleShaderStage functionStage,
                                          const PipelineLayout* layout,
                                          ShaderModule::MetalFunctionData* out,
-                                         uint32_t sampleMask) {
+                                         uint32_t sampleMask,
+                                         const RenderPipeline* renderPipeline) {
         ASSERT(!IsError());
         ASSERT(out);
-        const std::vector<uint32_t>& spirv = GetSpirv();
+        const std::vector<uint32_t>* spirv = &GetSpirv();
+
+        // Set 4 is above user-accessible sets. Outside the #ifdef: also used unconditionally below.
+        static const uint32_t kPullingBufferBindingSet = 4;
+#ifdef DAWN_ENABLE_WGSL
+        std::vector<uint32_t> pullingSpirv;
+        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+            functionStage == SingleShaderStage::Vertex) {
+            DAWN_TRY_ASSIGN(pullingSpirv,
+                            GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
+                                                 functionName, kPullingBufferBindingSet));
+            spirv = &pullingSpirv;
+        }
+#endif
 
         std::unique_ptr<spirv_cross::CompilerMSL> compilerImpl;
         spirv_cross::CompilerMSL* compiler;
@@ -103,7 +118,7 @@
             // Initializing the compiler is needed every call, because this method uses reflection
             // to mutate the compiler's IR.
             DAWN_TRY(
-                CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv.data(), spirv.size(),
+                CheckSpvcSuccess(mSpvcContext.InitializeForMsl(spirv->data(), spirv->size(),
                                                                GetMSLCompileOptions(sampleMask)),
                                  "Unable to initialize instance of spvc"));
             DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetCompiler(reinterpret_cast<void**>(&compiler)),
@@ -126,7 +141,7 @@
 
             options_msl.additional_fixed_sample_mask = sampleMask;
 
-            compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(spirv);
+            compilerImpl = std::make_unique<spirv_cross::CompilerMSL>(*spirv);
             compiler = compilerImpl.get();
             compiler->set_msl_options(options_msl);
         }
@@ -172,6 +187,22 @@
             }
         }
 
+        // Add vertex buffers bound as storage buffers
+        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+            functionStage == SingleShaderStage::Vertex) {
+            for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
+                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
+
+                shaderc_spvc_msl_resource_binding mslBinding;
+                mslBinding.stage = ToSpvcExecutionModel(SingleShaderStage::Vertex);
+                mslBinding.desc_set = kPullingBufferBindingSet;
+                mslBinding.binding = dawnIndex;
+                mslBinding.msl_buffer = metalIndex;
+                DAWN_TRY(CheckSpvcSuccess(mSpvcContext.AddMSLResourceBinding(mslBinding),
+                                          "Unable to add MSL Resource Binding"));
+            }
+        }
+
         {
             if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
                 shaderc_spvc_execution_model executionModel = ToSpvcExecutionModel(functionStage);
@@ -245,6 +276,11 @@
             out->needsStorageBufferLength = compiler->needs_buffer_size_buffer();
         }
 
+        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
+            functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) {
+            out->needsStorageBufferLength = true;
+        }
+
         return {};
     }
 
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 956f3d4..7817d8a 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -337,6 +337,10 @@
     frameworks = [ "IOSurface.framework" ]
   }
 
+  if (dawn_enable_wgsl) {
+    sources += [ "end2end/VertexBufferRobustnessTests.cpp" ]
+  }
+
   if (dawn_enable_opengl) {
     assert(dawn_supports_glfw_for_windowing)
   }
diff --git a/src/tests/end2end/VertexBufferRobustnessTests.cpp b/src/tests/end2end/VertexBufferRobustnessTests.cpp
new file mode 100644
index 0000000..59d090b
--- /dev/null
+++ b/src/tests/end2end/VertexBufferRobustnessTests.cpp
@@ -0,0 +1,231 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "common/Assert.h"
+#include "common/Constants.h"
+#include "common/Math.h"
+#include "tests/DawnTest.h"
+#include "utils/ComboRenderPipelineDescriptor.h"
+#include "utils/WGPUHelpers.h"
+
+// Vertex buffer robustness tests: they verify that clamping is applied to vertex attribute
+// reads. This happens on backends where vertex pulling is enabled, such as Metal.
+
+class VertexBufferRobustnessTest : public DawnTest {
+  protected:
+    void SetUp() override {
+        DawnTest::SetUp();
+        // Skip unless SPVC is in use: these tests currently rely on its robustness pass.
+        DAWN_SKIP_TEST_IF(!IsSpvcBeingUsed());
+    }
+
+    // Builds a WGSL vertex module that declares |attributes| and evaluates |successExpression|.
+    // If the expression holds, the drawn point is moved out of the viewport; otherwise it is
+    // kept inside the viewport.
+    wgpu::ShaderModule CreateVertexModule(const std::string& attributes,
+                                          const std::string& successExpression) {
+        return utils::CreateShaderModuleFromWGSL(device, (R"(
+                entry_point vertex as "main" = vtx_main;
+
+                )" + attributes + R"(
+                [[builtin position]] var<out> Position : vec4<f32>;
+
+                fn vtx_main() -> void {
+                    if ()" + successExpression + R"() {
+                        # Success case, move the vertex out of the viewport
+                        Position = vec4<f32>(-10.0, 0.0, 0.0, 1.0);
+                    } else {
+                        # Failure case, move the vertex inside the viewport
+                        Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+                    }
+                    return;
+                }
+            )")
+                                                             .c_str());
+    }
+
+    // Draws 1000 points; |expectation| true => expect the cleared pixel, false => a white pixel.
+    void DoTest(const std::string& attributes,
+                const std::string& successExpression,
+                utils::ComboVertexStateDescriptor vertexState,
+                wgpu::Buffer vertexBuffer,
+                uint64_t bufferOffset,
+                bool expectation) {
+        wgpu::ShaderModule vsModule = CreateVertexModule(attributes, successExpression);
+        wgpu::ShaderModule fsModule = utils::CreateShaderModuleFromWGSL(device, R"(
+                entry_point fragment as "main" = frag_main;
+                [[location 0]] var<out> outColor : vec4<f32>;
+                fn frag_main() -> void {
+                    outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+                    return;
+                }
+            )");
+
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1);
+
+        utils::ComboRenderPipelineDescriptor descriptor(device);
+        descriptor.vertexStage.module = vsModule;
+        descriptor.cFragmentStage.module = fsModule;
+        descriptor.primitiveTopology = wgpu::PrimitiveTopology::PointList;
+        descriptor.cVertexState = std::move(vertexState);
+        descriptor.cColorStates[0].format = renderPass.colorFormat;
+        renderPass.renderPassInfo.cColorAttachments[0].clearColor = {0, 0, 0, 1};  // opaque black
+
+        wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&descriptor);
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetPipeline(pipeline);
+        pass.SetVertexBuffer(0, vertexBuffer, bufferOffset);
+        pass.Draw(1000);  // Callers bind tiny buffers, so most vertices index past the data.
+        pass.EndPass();
+
+        wgpu::CommandBuffer commands = encoder.Finish();
+        queue.Submit(1, &commands);
+
+        RGBA8 noOutput(0, 0, 0, 255);  // same as the clear color
+        RGBA8 someOutput(255, 255, 255, 255);  // the white written by the fragment shader
+        EXPECT_PIXEL_RGBA8_EQ(expectation ? noOutput : someOutput, renderPass.color, 0, 0);
+    }
+};
+
+TEST_P(VertexBufferRobustnessTest, DetectInvalidValues) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(float);
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Float;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at offset 0 so vertex 0 reads 111.0: the check fails and a point is drawn in view.
+    float kVertices[] = {111.0, 473.0, 473.0};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(vertexState), vertexBuffer, 0,
+           false);  // false: expect visible (white) output
+}
+
+TEST_P(VertexBufferRobustnessTest, FloatClamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(float);
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Float;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at a 4-byte offset to skip 111.0; every clamped read must then return 473.0.
+    float kVertices[] = {111.0, 473.0, 473.0};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : f32;", "a == 473.0", std::move(vertexState), vertexBuffer, 4,
+           true);  // true: expect no output in the viewport
+}
+
+TEST_P(VertexBufferRobustnessTest, IntClamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(int32_t);
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Int;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at a 4-byte offset to skip 111; every clamped read must then return 473.
+    int32_t kVertices[] = {111, 473, 473};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : i32;", "a == 473", std::move(vertexState), vertexBuffer, 4,
+           true);  // true: expect no output in the viewport
+}
+
+TEST_P(VertexBufferRobustnessTest, UIntClamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(uint32_t);
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::UInt;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at a 4-byte offset to skip 111; every clamped read must then return 473.
+    uint32_t kVertices[] = {111, 473, 473};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : u32;", "a == 473", std::move(vertexState), vertexBuffer, 4,
+           true);  // true: expect no output in the viewport
+}
+
+TEST_P(VertexBufferRobustnessTest, Float2Clamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 2;
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Float2;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at an 8-byte offset to skip the 111.0 pair; clamped reads must then return 473.0.
+    float kVertices[] = {111.0, 111.0, 473.0, 473.0};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : vec2<f32>;", "a[0] == 473.0 && a[1] == 473.0",
+           std::move(vertexState), vertexBuffer, 8, true);  // true: expect no output
+}
+
+TEST_P(VertexBufferRobustnessTest, Float3Clamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 3;
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Float3;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at a 12-byte offset to skip the 111.0 triple; clamped reads must then return 473.0.
+    float kVertices[] = {111.0, 111.0, 111.0, 473.0, 473.0, 473.0};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : vec3<f32>;",
+           "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0", std::move(vertexState), vertexBuffer,
+           12, true);  // true: expect no output in the viewport
+}
+
+TEST_P(VertexBufferRobustnessTest, Float4Clamp) {
+    utils::ComboVertexStateDescriptor vertexState;
+    vertexState.vertexBufferCount = 1;
+    vertexState.cVertexBuffers[0].arrayStride = sizeof(float) * 4;
+    vertexState.cVertexBuffers[0].attributeCount = 1;
+    vertexState.cAttributes[0].format = wgpu::VertexFormat::Float4;
+    vertexState.cAttributes[0].offset = 0;
+    vertexState.cAttributes[0].shaderLocation = 0;
+
+    // Bind at a 16-byte offset to skip the 111.0 quad; clamped reads must then return 473.0.
+    float kVertices[] = {111.0, 111.0, 111.0, 111.0, 473.0, 473.0, 473.0, 473.0};
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices),
+                                                            wgpu::BufferUsage::Vertex);
+
+    DoTest("[[location 0]] var<in> a : vec4<f32>;",
+           "a[0] == 473.0 && a[1] == 473.0 && a[2] == 473.0 && a[3] == 473.0",
+           std::move(vertexState), vertexBuffer, 16, true);  // true: expect no output
+}
+
+DAWN_INSTANTIATE_TEST(VertexBufferRobustnessTest, MetalBackend({"metal_enable_vertex_pulling"}));
diff --git a/src/utils/WGPUHelpers.cpp b/src/utils/WGPUHelpers.cpp
index b9a51bb..94f34c8 100644
--- a/src/utils/WGPUHelpers.cpp
+++ b/src/utils/WGPUHelpers.cpp
@@ -144,6 +144,14 @@
         return CreateShaderModuleFromResult(device, result);
     }
 
+    wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source) {
+        wgpu::ShaderModuleWGSLDescriptor wgslDesc;  // pointed to by |descriptor| below
+        wgslDesc.source = source;  // WGSL text, passed through unmodified
+        wgpu::ShaderModuleDescriptor descriptor;
+        descriptor.nextInChain = &wgslDesc;  // chain the WGSL descriptor onto the base descriptor
+        return device.CreateShaderModule(&descriptor);
+    }
+
     std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source) {
         shaderc_shader_kind kind = ShadercShaderKind(stage);
 
diff --git a/src/utils/WGPUHelpers.h b/src/utils/WGPUHelpers.h
index 03e618c..d0a66fd 100644
--- a/src/utils/WGPUHelpers.h
+++ b/src/utils/WGPUHelpers.h
@@ -34,6 +34,8 @@
                                           SingleShaderStage stage,
                                           const char* source);
     wgpu::ShaderModule CreateShaderModuleFromASM(const wgpu::Device& device, const char* source);
+    wgpu::ShaderModule CreateShaderModuleFromWGSL(const wgpu::Device& device, const char* source);
+
     std::vector<uint32_t> CompileGLSLToSpirv(SingleShaderStage stage, const char* source);
 
     wgpu::Buffer CreateBufferFromData(const wgpu::Device& device,