Input State Descriptorization

This change also removes InputState object.

BUG=dawn:107

Change-Id: Ia3fd2d348658f5719de0279bfe7bb10a4f183523
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/5660
Commit-Queue: Yunchao He <yunchao.he@intel.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index d975e76..7d80221 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -121,8 +121,6 @@
     "src/dawn_native/FenceSignalTracker.cpp",
     "src/dawn_native/FenceSignalTracker.h",
     "src/dawn_native/Forward.h",
-    "src/dawn_native/InputState.cpp",
-    "src/dawn_native/InputState.h",
     "src/dawn_native/Instance.cpp",
     "src/dawn_native/Instance.h",
     "src/dawn_native/ObjectBase.cpp",
@@ -184,8 +182,6 @@
       "src/dawn_native/d3d12/DeviceD3D12.cpp",
       "src/dawn_native/d3d12/DeviceD3D12.h",
       "src/dawn_native/d3d12/Forward.h",
-      "src/dawn_native/d3d12/InputStateD3D12.cpp",
-      "src/dawn_native/d3d12/InputStateD3D12.h",
       "src/dawn_native/d3d12/NativeSwapChainImplD3D12.cpp",
       "src/dawn_native/d3d12/NativeSwapChainImplD3D12.h",
       "src/dawn_native/d3d12/PipelineLayoutD3D12.cpp",
@@ -235,8 +231,6 @@
       "src/dawn_native/metal/DeviceMTL.h",
       "src/dawn_native/metal/DeviceMTL.mm",
       "src/dawn_native/metal/Forward.h",
-      "src/dawn_native/metal/InputStateMTL.h",
-      "src/dawn_native/metal/InputStateMTL.mm",
       "src/dawn_native/metal/PipelineLayoutMTL.h",
       "src/dawn_native/metal/PipelineLayoutMTL.mm",
       "src/dawn_native/metal/QueueMTL.h",
@@ -279,8 +273,6 @@
       "src/dawn_native/opengl/DeviceGL.cpp",
       "src/dawn_native/opengl/DeviceGL.h",
       "src/dawn_native/opengl/Forward.h",
-      "src/dawn_native/opengl/InputStateGL.cpp",
-      "src/dawn_native/opengl/InputStateGL.h",
       "src/dawn_native/opengl/PersistentPipelineStateGL.cpp",
       "src/dawn_native/opengl/PersistentPipelineStateGL.h",
       "src/dawn_native/opengl/PipelineGL.cpp",
@@ -326,8 +318,6 @@
       "src/dawn_native/vulkan/FencedDeleter.cpp",
       "src/dawn_native/vulkan/FencedDeleter.h",
       "src/dawn_native/vulkan/Forward.h",
-      "src/dawn_native/vulkan/InputStateVk.cpp",
-      "src/dawn_native/vulkan/InputStateVk.h",
       "src/dawn_native/vulkan/MemoryAllocator.cpp",
       "src/dawn_native/vulkan/MemoryAllocator.h",
       "src/dawn_native/vulkan/NativeSwapChainImplVk.cpp",
diff --git a/dawn.json b/dawn.json
index 2e88953..0b132a1 100644
--- a/dawn.json
+++ b/dawn.json
@@ -449,10 +449,6 @@
                 "returns": "command encoder"
             },
             {
-                "name": "create input state builder",
-                "returns": "input state builder"
-            },
-            {
                 "name": "create compute pipeline",
                 "returns": "compute pipeline",
                 "args": [
@@ -622,28 +618,14 @@
             {"name": "step mode", "type": "input step mode"}
         ]
     },
-    "input state": {
-        "category": "object"
-    },
-    "input state builder": {
-        "category": "object",
-        "methods": [
-            {
-                "name": "get result",
-                "returns": "input state"
-            },
-            {
-                "name": "set attribute",
-                "args": [
-                    {"name": "attribute", "type": "vertex attribute descriptor", "annotation": "const*"}
-                ]
-            },
-            {
-                "name": "set input",
-                "args": [
-                    {"name": "input", "type": "vertex input descriptor", "annotation": "const*"}
-                ]
-            }
+    "input state descriptor": {
+        "category": "structure",
+        "extensible": true,
+        "members": [
+            {"name": "num attributes", "type": "uint32_t"},
+            {"name": "attributes", "type": "vertex attribute descriptor", "annotation": "const*", "length": "num attributes"},
+            {"name": "num inputs", "type": "uint32_t"},
+            {"name": "inputs", "type": "vertex input descriptor", "annotation": "const*", "length": "num inputs"}
         ]
     },
     "input step mode": {
@@ -881,7 +863,7 @@
             {"name": "layout", "type": "pipeline layout"},
             {"name": "vertex stage", "type": "pipeline stage descriptor", "annotation": "const*"},
             {"name": "fragment stage", "type": "pipeline stage descriptor", "annotation": "const*"},
-            {"name": "input state", "type": "input state"},
+            {"name": "input state", "type": "input state descriptor", "annotation": "const*"},
             {"name": "index format", "type": "index format"},
             {"name": "primitive topology", "type": "primitive topology"},
             {"name": "sample count", "type": "uint32_t"},
diff --git a/examples/CHelloTriangle.cpp b/examples/CHelloTriangle.cpp
index 51e84ca..2a981ed 100644
--- a/examples/CHelloTriangle.cpp
+++ b/examples/CHelloTriangle.cpp
@@ -93,9 +93,13 @@
         pl.bindGroupLayouts = nullptr;
         descriptor.layout = dawnDeviceCreatePipelineLayout(device, &pl);
 
-        DawnInputStateBuilder inputStateBuilder = dawnDeviceCreateInputStateBuilder(device);
-        descriptor.inputState = dawnInputStateBuilderGetResult(inputStateBuilder);
-        dawnInputStateBuilderRelease(inputStateBuilder);
+        DawnInputStateDescriptor inputState;
+        inputState.nextInChain = nullptr;
+        inputState.numInputs = 0;
+        inputState.inputs = nullptr;
+        inputState.numAttributes = 0;
+        inputState.attributes = nullptr;
+        descriptor.inputState = &inputState;
 
         descriptor.indexFormat = DAWN_INDEX_FORMAT_UINT32;
         descriptor.primitiveTopology = DAWN_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
@@ -103,8 +107,6 @@
         descriptor.depthStencilState = nullptr;
 
         pipeline = dawnDeviceCreateRenderPipeline(device, &descriptor);
-
-        dawnInputStateRelease(descriptor.inputState);
     }
 
     dawnShaderModuleRelease(vsModule);
diff --git a/examples/ComputeBoids.cpp b/examples/ComputeBoids.cpp
index 196fc98..06b5ca5 100644
--- a/examples/ComputeBoids.cpp
+++ b/examples/ComputeBoids.cpp
@@ -114,48 +114,41 @@
         }
     )");
 
-    dawn::VertexAttributeDescriptor attribute1;
-    attribute1.shaderLocation = 0;
-    attribute1.inputSlot = 0;
-    attribute1.offset = offsetof(Particle, pos);
-    attribute1.format = dawn::VertexFormat::Float2;
+    dawn::VertexAttributeDescriptor attribute[3];
+    attribute[0].shaderLocation = 0;
+    attribute[0].inputSlot = 0;
+    attribute[0].offset = offsetof(Particle, pos);
+    attribute[0].format = dawn::VertexFormat::Float2;
 
-    dawn::VertexAttributeDescriptor attribute2;
-    attribute2.shaderLocation = 1;
-    attribute2.inputSlot = 0;
-    attribute2.offset = offsetof(Particle, vel);
-    attribute2.format = dawn::VertexFormat::Float2;
+    attribute[1].shaderLocation = 1;
+    attribute[1].inputSlot = 0;
+    attribute[1].offset = offsetof(Particle, vel);
+    attribute[1].format = dawn::VertexFormat::Float2;
 
-    dawn::VertexInputDescriptor input1;
-    input1.inputSlot = 0;
-    input1.stride = sizeof(Particle);
-    input1.stepMode = dawn::InputStepMode::Instance;
+    attribute[2].shaderLocation = 2;
+    attribute[2].inputSlot = 1;
+    attribute[2].offset = 0;
+    attribute[2].format = dawn::VertexFormat::Float2;
 
-    dawn::VertexAttributeDescriptor attribute3;
-    attribute3.shaderLocation = 2;
-    attribute3.inputSlot = 1;
-    attribute3.offset = 0;
-    attribute3.format = dawn::VertexFormat::Float2;
+    dawn::VertexInputDescriptor input[2];
+    input[0].inputSlot = 0;
+    input[0].stride = sizeof(Particle);
+    input[0].stepMode = dawn::InputStepMode::Instance;
 
-    dawn::VertexInputDescriptor input2;
-    input2.inputSlot = 1;
-    input2.stride = sizeof(glm::vec2);
-    input2.stepMode = dawn::InputStepMode::Vertex;
-
-    dawn::InputState inputState = device.CreateInputStateBuilder()
-                                      .SetAttribute(&attribute1)
-                                      .SetAttribute(&attribute2)
-                                      .SetInput(&input1)
-                                      .SetAttribute(&attribute3)
-                                      .SetInput(&input2)
-                                      .GetResult();
+    input[1].inputSlot = 1;
+    input[1].stride = sizeof(glm::vec2);
+    input[1].stepMode = dawn::InputStepMode::Vertex;
 
     depthStencilView = CreateDefaultDepthStencilView(device);
 
     utils::ComboRenderPipelineDescriptor descriptor(device);
     descriptor.cVertexStage.module = vsModule;
     descriptor.cFragmentStage.module = fsModule;
-    descriptor.inputState = inputState;
+
+    descriptor.cInputState.numAttributes = 3;
+    descriptor.cInputState.attributes = attribute;
+    descriptor.cInputState.numInputs = 2;
+    descriptor.cInputState.inputs = input;
     descriptor.depthStencilState = &descriptor.cDepthStencilState;
     descriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
     descriptor.cColorStates[0]->format = GetPreferredSwapChainTextureFormat();
diff --git a/examples/CppHelloTriangle.cpp b/examples/CppHelloTriangle.cpp
index d431db2..ca756cf 100644
--- a/examples/CppHelloTriangle.cpp
+++ b/examples/CppHelloTriangle.cpp
@@ -122,9 +122,6 @@
     input.stride = 4 * sizeof(float);
     input.stepMode = dawn::InputStepMode::Vertex;
 
-    auto inputState =
-        device.CreateInputStateBuilder().SetAttribute(&attribute).SetInput(&input).GetResult();
-
     auto bgl = utils::MakeBindGroupLayout(
         device, {
                     {0, dawn::ShaderStageBit::Fragment, dawn::BindingType::Sampler},
@@ -139,7 +136,10 @@
     descriptor.layout = utils::MakeBasicPipelineLayout(device, &bgl);
     descriptor.cVertexStage.module = vsModule;
     descriptor.cFragmentStage.module = fsModule;
-    descriptor.inputState = inputState;
+    descriptor.cInputState.numAttributes = 1;
+    descriptor.cInputState.attributes = &attribute;
+    descriptor.cInputState.numInputs = 1;
+    descriptor.cInputState.inputs = &input;
     descriptor.depthStencilState = &descriptor.cDepthStencilState;
     descriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
     descriptor.cColorStates[0]->format = GetPreferredSwapChainTextureFormat();
diff --git a/examples/CubeReflection.cpp b/examples/CubeReflection.cpp
index bba0e52..7a7423c 100644
--- a/examples/CubeReflection.cpp
+++ b/examples/CubeReflection.cpp
@@ -156,28 +156,27 @@
             fragColor = vec4(mix(f_col, vec3(0.5, 0.5, 0.5), 0.5), 1.0);
         })");
 
-    dawn::VertexAttributeDescriptor attribute1;
-    attribute1.shaderLocation = 0;
-    attribute1.inputSlot = 0;
-    attribute1.offset = 0;
-    attribute1.format = dawn::VertexFormat::Float3;
+    dawn::VertexAttributeDescriptor attribute[2];
+    attribute[0].shaderLocation = 0;
+    attribute[0].inputSlot = 0;
+    attribute[0].offset = 0;
+    attribute[0].format = dawn::VertexFormat::Float3;
 
-    dawn::VertexAttributeDescriptor attribute2;
-    attribute2.shaderLocation = 1;
-    attribute2.inputSlot = 0;
-    attribute2.offset = 3 * sizeof(float);
-    attribute2.format = dawn::VertexFormat::Float3;
+    attribute[1].shaderLocation = 1;
+    attribute[1].inputSlot = 0;
+    attribute[1].offset = 3 * sizeof(float);
+    attribute[1].format = dawn::VertexFormat::Float3;
 
     dawn::VertexInputDescriptor input;
     input.inputSlot = 0;
     input.stride = 6 * sizeof(float);
     input.stepMode = dawn::InputStepMode::Vertex;
 
-    auto inputState = device.CreateInputStateBuilder()
-                          .SetAttribute(&attribute1)
-                          .SetAttribute(&attribute2)
-                          .SetInput(&input)
-                          .GetResult();
+    dawn::InputStateDescriptor inputState;
+    inputState.numAttributes = 2;
+    inputState.attributes = attribute;
+    inputState.numInputs = 1;
+    inputState.inputs = &input;
 
     auto bgl = utils::MakeBindGroupLayout(
         device, {
@@ -214,7 +213,7 @@
     descriptor.layout = pl;
     descriptor.cVertexStage.module = vsModule;
     descriptor.cFragmentStage.module = fsModule;
-    descriptor.inputState = inputState;
+    descriptor.inputState = &inputState;
     descriptor.depthStencilState = &descriptor.cDepthStencilState;
     descriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
     descriptor.cColorStates[0]->format = GetPreferredSwapChainTextureFormat();
@@ -227,7 +226,7 @@
     pDescriptor.layout = pl;
     pDescriptor.cVertexStage.module = vsModule;
     pDescriptor.cFragmentStage.module = fsModule;
-    pDescriptor.inputState = inputState;
+    pDescriptor.inputState = &inputState;
     pDescriptor.depthStencilState = &pDescriptor.cDepthStencilState;
     pDescriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
     pDescriptor.cColorStates[0]->format = GetPreferredSwapChainTextureFormat();
@@ -241,7 +240,7 @@
     rfDescriptor.layout = pl;
     rfDescriptor.cVertexStage.module = vsModule;
     rfDescriptor.cFragmentStage.module = fsReflectionModule;
-    rfDescriptor.inputState = inputState;
+    rfDescriptor.inputState = &inputState;
     rfDescriptor.depthStencilState = &rfDescriptor.cDepthStencilState;
     rfDescriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
     rfDescriptor.cColorStates[0]->format = GetPreferredSwapChainTextureFormat();
diff --git a/examples/glTFViewer/glTFViewer.cpp b/examples/glTFViewer/glTFViewer.cpp
index 41e836f..f3a8268 100644
--- a/examples/glTFViewer/glTFViewer.cpp
+++ b/examples/glTFViewer/glTFViewer.cpp
@@ -237,7 +237,11 @@
 
         auto oFSModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, hasTexture ? oFSSourceTextured : oFSSourceUntextured);
 
-        dawn::InputStateBuilder builder = device.CreateInputStateBuilder();
+        utils::ComboRenderPipelineDescriptor descriptor(device);
+        dawn::VertexAttributeDescriptor attributes[kMaxVertexAttributes];
+        dawn::VertexInputDescriptor inputs[kMaxVertexInputs];
+        uint32_t numAttributes = 0;
+        uint32_t numInputs = 0;
         std::bitset<3> slotsSet;
         for (const auto& a : iTechnique.attributes) {
             const auto iAttributeName = a.first;
@@ -247,60 +251,58 @@
                 fprintf(stderr, "unsupported technique parameter type %d\n", iParameter.type);
                 continue;
             }
-            dawn::VertexAttributeDescriptor attribute;
-            attribute.offset = 0;
-            attribute.format = format;
-            dawn::VertexInputDescriptor input;
-            input.stepMode = dawn::InputStepMode::Vertex;
+            attributes[numAttributes].offset = 0;
+            attributes[numAttributes].format = format;
+            inputs[numInputs].stepMode = dawn::InputStepMode::Vertex;
 
             if (iParameter.semantic == "POSITION") {
-                attribute.shaderLocation = 0;
-                attribute.inputSlot = 0;
-                input.inputSlot = 0;
-                input.stride = static_cast<uint32_t>(stridePos);
-                builder.SetAttribute(&attribute);
-                builder.SetInput(&input);
+                attributes[numAttributes].shaderLocation = 0;
+                attributes[numAttributes].inputSlot = 0;
+                inputs[numInputs].inputSlot = 0;
+                inputs[numInputs].stride = static_cast<uint32_t>(stridePos);
+                numAttributes++;
+                numInputs++;
                 slotsSet.set(0);
             } else if (iParameter.semantic == "NORMAL") {
-                attribute.shaderLocation = 1;
-                attribute.inputSlot = 1;
-                input.inputSlot = 1;
-                input.stride = static_cast<uint32_t>(strideNor);
-                builder.SetAttribute(&attribute);
-                builder.SetInput(&input);
+                attributes[numAttributes].shaderLocation = 1;
+                attributes[numAttributes].inputSlot = 1;
+                inputs[numInputs].inputSlot = 1;
+                inputs[numInputs].stride = static_cast<uint32_t>(strideNor);
+                numAttributes++;
+                numInputs++;
                 slotsSet.set(1);
             } else if (iParameter.semantic == "TEXCOORD_0") {
-                attribute.shaderLocation = 2;
-                attribute.inputSlot = 2;
-                input.inputSlot = 2;
-                input.stride = static_cast<uint32_t>(strideTxc);
-                builder.SetAttribute(&attribute);
-                builder.SetInput(&input);
+                attributes[numAttributes].shaderLocation = 2;
+                attributes[numAttributes].inputSlot = 2;
+                inputs[numInputs].inputSlot = 2;
+                inputs[numInputs].stride = static_cast<uint32_t>(strideTxc);
+                numAttributes++;
+                numInputs++;
                 slotsSet.set(2);
             } else {
                 fprintf(stderr, "unsupported technique attribute semantic %s\n", iParameter.semantic.c_str());
             }
-            // TODO: use iAttributeParameter.node?
         }
         for (uint32_t i = 0; i < slotsSet.size(); i++) {
             if (slotsSet[i]) {
                 continue;
             }
-            dawn::VertexAttributeDescriptor attribute;
-            attribute.offset = 0;
-            attribute.shaderLocation = i;
-            attribute.inputSlot = i;
-            attribute.format = dawn::VertexFormat::Float4;
+            attributes[numAttributes].offset = 0;
+            attributes[numAttributes].shaderLocation = i;
+            attributes[numAttributes].inputSlot = i;
+            attributes[numAttributes].format = dawn::VertexFormat::Float4;
 
-            dawn::VertexInputDescriptor input;
-            input.inputSlot = i;
-            input.stride = 0;
-            input.stepMode = dawn::InputStepMode::Vertex;
+            inputs[numInputs].inputSlot = i;
+            inputs[numInputs].stride = 0;
+            inputs[numInputs].stepMode = dawn::InputStepMode::Vertex;
 
-            builder.SetAttribute(&attribute);
-            builder.SetInput(&input);
+            numAttributes++;
+            numInputs++;
         }
-        auto inputState = builder.GetResult();
+        descriptor.cInputState.numAttributes = numAttributes;
+        descriptor.cInputState.attributes = attributes;
+        descriptor.cInputState.numInputs = numInputs;
+        descriptor.cInputState.inputs = inputs;
 
         constexpr dawn::ShaderStageBit kNoStages{};
         dawn::BindGroupLayout bindGroupLayout = utils::MakeBindGroupLayout(
@@ -313,11 +315,9 @@
 
         auto pipelineLayout = utils::MakeBasicPipelineLayout(device, &bindGroupLayout);
 
-        utils::ComboRenderPipelineDescriptor descriptor(device);
         descriptor.layout = pipelineLayout;
         descriptor.cVertexStage.module = oVSModule;
         descriptor.cFragmentStage.module = oFSModule;
-        descriptor.inputState = inputState;
         descriptor.indexFormat = dawn::IndexFormat::Uint16;
         descriptor.depthStencilState = &descriptor.cDepthStencilState;
         descriptor.cDepthStencilState.format = dawn::TextureFormat::D32FloatS8Uint;
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 9a7fdd9..77c14d2 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -19,7 +19,6 @@
 #include "dawn_native/BindGroup.h"
 #include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Forward.h"
-#include "dawn_native/InputState.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/RenderPipeline.h"
 
@@ -106,7 +105,7 @@
         if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
             ASSERT(mLastRenderPipeline != nullptr);
 
-            auto requiredInputs = mLastRenderPipeline->GetInputState()->GetInputsSetMask();
+            auto requiredInputs = mLastRenderPipeline->GetInputsSetMask();
             if ((mInputsSet & requiredInputs) == requiredInputs) {
                 mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
             }
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 0448464..44be88a 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -25,7 +25,6 @@
 #include "dawn_native/ErrorData.h"
 #include "dawn_native/Fence.h"
 #include "dawn_native/FenceSignalTracker.h"
-#include "dawn_native/InputState.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/Queue.h"
 #include "dawn_native/RenderPipeline.h"
@@ -159,9 +158,6 @@
 
         return result;
     }
-    InputStateBuilder* DeviceBase::CreateInputStateBuilder() {
-        return new InputStateBuilder(this);
-    }
     PipelineLayoutBase* DeviceBase::CreatePipelineLayout(
         const PipelineLayoutDescriptor* descriptor) {
         PipelineLayoutBase* result = nullptr;
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index d1c7361..188dd6c 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -60,7 +60,6 @@
         FenceSignalTracker* GetFenceSignalTracker() const;
 
         virtual CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) = 0;
-        virtual InputStateBase* CreateInputState(InputStateBuilder* builder) = 0;
 
         virtual Serial GetCompletedCommandSerial() const = 0;
         virtual Serial GetLastSubmittedCommandSerial() const = 0;
@@ -91,7 +90,6 @@
         BufferBase* CreateBuffer(const BufferDescriptor* descriptor);
         CommandEncoderBase* CreateCommandEncoder();
         ComputePipelineBase* CreateComputePipeline(const ComputePipelineDescriptor* descriptor);
-        InputStateBuilder* CreateInputStateBuilder();
         PipelineLayoutBase* CreatePipelineLayout(const PipelineLayoutDescriptor* descriptor);
         QueueBase* CreateQueue();
         RenderPipelineBase* CreateRenderPipeline(const RenderPipelineDescriptor* descriptor);
diff --git a/src/dawn_native/Forward.h b/src/dawn_native/Forward.h
index c11dc24..3f32c4a 100644
--- a/src/dawn_native/Forward.h
+++ b/src/dawn_native/Forward.h
@@ -28,8 +28,6 @@
     class CommandEncoderBase;
     class ComputePassEncoderBase;
     class FenceBase;
-    class InputStateBase;
-    class InputStateBuilder;
     class InstanceBase;
     class PipelineBase;
     class PipelineLayoutBase;
diff --git a/src/dawn_native/InputState.cpp b/src/dawn_native/InputState.cpp
deleted file mode 100644
index dce0a43..0000000
--- a/src/dawn_native/InputState.cpp
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dawn_native/InputState.h"
-
-#include "common/Assert.h"
-#include "dawn_native/Device.h"
-#include "dawn_native/ValidationUtils_autogen.h"
-
-namespace dawn_native {
-
-    // InputState helpers
-
-    size_t IndexFormatSize(dawn::IndexFormat format) {
-        switch (format) {
-            case dawn::IndexFormat::Uint16:
-                return sizeof(uint16_t);
-            case dawn::IndexFormat::Uint32:
-                return sizeof(uint32_t);
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    // TODO(shaobo.yan@intel.com): Add end2end test to cover all the formats.
-    uint32_t VertexFormatNumComponents(dawn::VertexFormat format) {
-        switch (format) {
-            case dawn::VertexFormat::UChar4:
-            case dawn::VertexFormat::Char4:
-            case dawn::VertexFormat::UChar4Norm:
-            case dawn::VertexFormat::Char4Norm:
-            case dawn::VertexFormat::UShort4:
-            case dawn::VertexFormat::Short4:
-            case dawn::VertexFormat::UShort4Norm:
-            case dawn::VertexFormat::Short4Norm:
-            case dawn::VertexFormat::Half4:
-            case dawn::VertexFormat::Float4:
-            case dawn::VertexFormat::UInt4:
-            case dawn::VertexFormat::Int4:
-                return 4;
-            case dawn::VertexFormat::Float3:
-            case dawn::VertexFormat::UInt3:
-            case dawn::VertexFormat::Int3:
-                return 3;
-            case dawn::VertexFormat::UChar2:
-            case dawn::VertexFormat::Char2:
-            case dawn::VertexFormat::UChar2Norm:
-            case dawn::VertexFormat::Char2Norm:
-            case dawn::VertexFormat::UShort2:
-            case dawn::VertexFormat::Short2:
-            case dawn::VertexFormat::UShort2Norm:
-            case dawn::VertexFormat::Short2Norm:
-            case dawn::VertexFormat::Half2:
-            case dawn::VertexFormat::Float2:
-            case dawn::VertexFormat::UInt2:
-            case dawn::VertexFormat::Int2:
-                return 2;
-            case dawn::VertexFormat::Float:
-            case dawn::VertexFormat::UInt:
-            case dawn::VertexFormat::Int:
-                return 1;
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    size_t VertexFormatComponentSize(dawn::VertexFormat format) {
-        switch (format) {
-            case dawn::VertexFormat::UChar2:
-            case dawn::VertexFormat::UChar4:
-            case dawn::VertexFormat::Char2:
-            case dawn::VertexFormat::Char4:
-            case dawn::VertexFormat::UChar2Norm:
-            case dawn::VertexFormat::UChar4Norm:
-            case dawn::VertexFormat::Char2Norm:
-            case dawn::VertexFormat::Char4Norm:
-                return sizeof(char);
-            case dawn::VertexFormat::UShort2:
-            case dawn::VertexFormat::UShort4:
-            case dawn::VertexFormat::UShort2Norm:
-            case dawn::VertexFormat::UShort4Norm:
-            case dawn::VertexFormat::Short2:
-            case dawn::VertexFormat::Short4:
-            case dawn::VertexFormat::Short2Norm:
-            case dawn::VertexFormat::Short4Norm:
-            case dawn::VertexFormat::Half2:
-            case dawn::VertexFormat::Half4:
-                return sizeof(uint16_t);
-            case dawn::VertexFormat::Float:
-            case dawn::VertexFormat::Float2:
-            case dawn::VertexFormat::Float3:
-            case dawn::VertexFormat::Float4:
-                return sizeof(float);
-            case dawn::VertexFormat::UInt:
-            case dawn::VertexFormat::UInt2:
-            case dawn::VertexFormat::UInt3:
-            case dawn::VertexFormat::UInt4:
-            case dawn::VertexFormat::Int:
-            case dawn::VertexFormat::Int2:
-            case dawn::VertexFormat::Int3:
-            case dawn::VertexFormat::Int4:
-                return sizeof(int32_t);
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    size_t VertexFormatSize(dawn::VertexFormat format) {
-        return VertexFormatNumComponents(format) * VertexFormatComponentSize(format);
-    }
-
-    // InputStateBase
-
-    InputStateBase::InputStateBase(InputStateBuilder* builder) : ObjectBase(builder->GetDevice()) {
-        mAttributesSetMask = builder->mAttributesSetMask;
-        mAttributeInfos = builder->mAttributeInfos;
-        mInputsSetMask = builder->mInputsSetMask;
-        mInputInfos = builder->mInputInfos;
-    }
-
-    const std::bitset<kMaxVertexAttributes>& InputStateBase::GetAttributesSetMask() const {
-        return mAttributesSetMask;
-    }
-
-    const VertexAttributeDescriptor& InputStateBase::GetAttribute(uint32_t location) const {
-        ASSERT(mAttributesSetMask[location]);
-        return mAttributeInfos[location];
-    }
-
-    const std::bitset<kMaxVertexInputs>& InputStateBase::GetInputsSetMask() const {
-        return mInputsSetMask;
-    }
-
-    const VertexInputDescriptor& InputStateBase::GetInput(uint32_t slot) const {
-        ASSERT(mInputsSetMask[slot]);
-        return mInputInfos[slot];
-    }
-
-    // InputStateBuilder
-
-    InputStateBuilder::InputStateBuilder(DeviceBase* device) : Builder(device) {
-    }
-
-    InputStateBase* InputStateBuilder::GetResultImpl() {
-        for (uint32_t location = 0; location < kMaxVertexAttributes; ++location) {
-            if (mAttributesSetMask[location] &&
-                !mInputsSetMask[mAttributeInfos[location].inputSlot]) {
-                HandleError("Attribute uses unset input");
-                return nullptr;
-            }
-        }
-
-        return GetDevice()->CreateInputState(this);
-    }
-
-    void InputStateBuilder::SetAttribute(const VertexAttributeDescriptor* attribute) {
-        if (attribute->shaderLocation >= kMaxVertexAttributes) {
-            HandleError("Setting attribute out of bounds");
-            return;
-        }
-        if (attribute->inputSlot >= kMaxVertexInputs) {
-            HandleError("Binding slot out of bounds");
-            return;
-        }
-        if (GetDevice()->ConsumedError(ValidateVertexFormat(attribute->format))) {
-            return;
-        }
-        // If attribute->offset is close to 0xFFFFFFFF, the validation below to add
-        // attribute->offset and VertexFormatSize(attribute->format) might overflow on a
-        // 32bit machine, then it can pass the validation incorrectly. We need to catch it.
-        if (attribute->offset >= kMaxVertexAttributeEnd) {
-            HandleError("Setting attribute offset out of bounds");
-            return;
-        }
-        if (attribute->offset + VertexFormatSize(attribute->format) > kMaxVertexAttributeEnd) {
-            HandleError("Setting attribute offset out of bounds");
-            return;
-        }
-        if (mAttributesSetMask[attribute->shaderLocation]) {
-            HandleError("Setting already set attribute");
-            return;
-        }
-
-        mAttributesSetMask.set(attribute->shaderLocation);
-        mAttributeInfos[attribute->shaderLocation] = *attribute;
-    }
-
-    void InputStateBuilder::SetInput(const VertexInputDescriptor* input) {
-        if (input->inputSlot >= kMaxVertexInputs) {
-            HandleError("Setting input out of bounds");
-            return;
-        }
-        if (input->stride > kMaxVertexInputStride) {
-            HandleError("Setting input stride out of bounds");
-            return;
-        }
-        if (mInputsSetMask[input->inputSlot]) {
-            HandleError("Setting already set input");
-            return;
-        }
-
-        mInputsSetMask.set(input->inputSlot);
-        mInputInfos[input->inputSlot] = *input;
-    }
-
-}  // namespace dawn_native
diff --git a/src/dawn_native/InputState.h b/src/dawn_native/InputState.h
deleted file mode 100644
index adf5d98..0000000
--- a/src/dawn_native/InputState.h
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DAWNNATIVE_INPUTSTATE_H_
-#define DAWNNATIVE_INPUTSTATE_H_
-
-#include "common/Constants.h"
-#include "dawn_native/Builder.h"
-#include "dawn_native/Forward.h"
-#include "dawn_native/ObjectBase.h"
-
-#include "dawn_native/dawn_platform.h"
-
-#include <array>
-#include <bitset>
-
-namespace dawn_native {
-
-    size_t IndexFormatSize(dawn::IndexFormat format);
-    uint32_t VertexFormatNumComponents(dawn::VertexFormat format);
-    size_t VertexFormatComponentSize(dawn::VertexFormat format);
-    size_t VertexFormatSize(dawn::VertexFormat format);
-
-    class InputStateBase : public ObjectBase {
-      public:
-        InputStateBase(InputStateBuilder* builder);
-
-        const std::bitset<kMaxVertexAttributes>& GetAttributesSetMask() const;
-        const VertexAttributeDescriptor& GetAttribute(uint32_t location) const;
-        const std::bitset<kMaxVertexInputs>& GetInputsSetMask() const;
-        const VertexInputDescriptor& GetInput(uint32_t slot) const;
-
-      private:
-        std::bitset<kMaxVertexAttributes> mAttributesSetMask;
-        std::array<VertexAttributeDescriptor, kMaxVertexAttributes> mAttributeInfos;
-        std::bitset<kMaxVertexInputs> mInputsSetMask;
-        std::array<VertexInputDescriptor, kMaxVertexInputs> mInputInfos;
-    };
-
-    class InputStateBuilder : public Builder<InputStateBase> {
-      public:
-        InputStateBuilder(DeviceBase* device);
-
-        // Dawn API
-        void SetAttribute(const VertexAttributeDescriptor* attribute);
-        void SetInput(const VertexInputDescriptor* input);
-
-      private:
-        friend class InputStateBase;
-
-        InputStateBase* GetResultImpl() override;
-
-        std::bitset<kMaxVertexAttributes> mAttributesSetMask;
-        std::array<VertexAttributeDescriptor, kMaxVertexAttributes> mAttributeInfos;
-        std::bitset<kMaxVertexInputs> mInputsSetMask;
-        std::array<VertexInputDescriptor, kMaxVertexInputs> mInputInfos;
-    };
-
-}  // namespace dawn_native
-
-#endif  // DAWNNATIVE_INPUTSTATE_H_
diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp
index 6ddb742..257bf76 100644
--- a/src/dawn_native/Pipeline.cpp
+++ b/src/dawn_native/Pipeline.cpp
@@ -15,7 +15,6 @@
 #include "dawn_native/Pipeline.h"
 
 #include "dawn_native/Device.h"
-#include "dawn_native/InputState.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/ShaderModule.h"
 
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index ca04eab..58186f2 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -17,7 +17,6 @@
 #include "common/BitSetIterator.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/Device.h"
-#include "dawn_native/InputState.h"
 #include "dawn_native/Texture.h"
 #include "dawn_native/ValidationUtils_autogen.h"
 
@@ -25,6 +24,77 @@
     // Helper functions
     namespace {
 
+        MaybeError ValidateVertexInputDescriptor(const VertexInputDescriptor* input,
+                                                 std::bitset<kMaxVertexInputs>* inputsSetMask) {
+            DAWN_TRY(ValidateInputStepMode(input->stepMode));
+            if (input->inputSlot >= kMaxVertexInputs) {
+                return DAWN_VALIDATION_ERROR("Setting input out of bounds");
+            }
+            if (input->stride > kMaxVertexInputStride) {
+                return DAWN_VALIDATION_ERROR("Setting input stride out of bounds");
+            }
+            if ((*inputsSetMask)[input->inputSlot]) {
+                return DAWN_VALIDATION_ERROR("Setting already set input");
+            }
+
+            inputsSetMask->set(input->inputSlot);
+            return {};
+        }
+
+        MaybeError ValidateVertexAttributeDescriptor(
+            const VertexAttributeDescriptor* attribute,
+            const std::bitset<kMaxVertexInputs>* inputsSetMask,
+            std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+            DAWN_TRY(ValidateVertexFormat(attribute->format));
+
+            if (attribute->shaderLocation >= kMaxVertexAttributes) {
+                return DAWN_VALIDATION_ERROR("Setting attribute out of bounds");
+            }
+            if (attribute->inputSlot >= kMaxVertexInputs) {
+                return DAWN_VALIDATION_ERROR("Binding slot out of bounds");
+            }
+            ASSERT(kMaxVertexAttributeEnd >= VertexFormatSize(attribute->format));
+            if (attribute->offset > kMaxVertexAttributeEnd - VertexFormatSize(attribute->format)) {
+                return DAWN_VALIDATION_ERROR("Setting attribute offset out of bounds");
+            }
+            if ((*attributesSetMask)[attribute->shaderLocation]) {
+                return DAWN_VALIDATION_ERROR("Setting already set attribute");
+            }
+            if (!(*inputsSetMask)[attribute->inputSlot]) {
+                return DAWN_VALIDATION_ERROR(
+                    "Vertex attribute slot doesn't match any vertex input slot");
+            }
+
+            attributesSetMask->set(attribute->shaderLocation);
+            return {};
+        }
+
+        MaybeError ValidateInputStateDescriptor(
+            const InputStateDescriptor* descriptor,
+            std::bitset<kMaxVertexInputs>* inputsSetMask,
+            std::bitset<kMaxVertexAttributes>* attributesSetMask) {
+            if (descriptor->nextInChain != nullptr) {
+                return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
+            }
+            if (descriptor->numInputs > kMaxVertexInputs) {
+                return DAWN_VALIDATION_ERROR("Vertex Inputs number exceeds maximum");
+            }
+            if (descriptor->numAttributes > kMaxVertexAttributes) {
+                return DAWN_VALIDATION_ERROR("Vertex Attributes number exceeds maximum");
+            }
+
+            for (uint32_t i = 0; i < descriptor->numInputs; ++i) {
+                DAWN_TRY(ValidateVertexInputDescriptor(&descriptor->inputs[i], inputsSetMask));
+            }
+
+            for (uint32_t i = 0; i < descriptor->numAttributes; ++i) {
+                DAWN_TRY(ValidateVertexAttributeDescriptor(&descriptor->attributes[i],
+                                                           inputsSetMask, attributesSetMask));
+            }
+
+            return {};
+        }
+
         MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
                                                    const PipelineStageDescriptor* descriptor,
                                                    const PipelineLayoutBase* layout,
@@ -91,6 +161,104 @@
 
     }  // anonymous namespace
 
+    // Helper functions
+    size_t IndexFormatSize(dawn::IndexFormat format) {
+        switch (format) {
+            case dawn::IndexFormat::Uint16:
+                return sizeof(uint16_t);
+            case dawn::IndexFormat::Uint32:
+                return sizeof(uint32_t);
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    uint32_t VertexFormatNumComponents(dawn::VertexFormat format) {
+        switch (format) {
+            case dawn::VertexFormat::UChar4:
+            case dawn::VertexFormat::Char4:
+            case dawn::VertexFormat::UChar4Norm:
+            case dawn::VertexFormat::Char4Norm:
+            case dawn::VertexFormat::UShort4:
+            case dawn::VertexFormat::Short4:
+            case dawn::VertexFormat::UShort4Norm:
+            case dawn::VertexFormat::Short4Norm:
+            case dawn::VertexFormat::Half4:
+            case dawn::VertexFormat::Float4:
+            case dawn::VertexFormat::UInt4:
+            case dawn::VertexFormat::Int4:
+                return 4;
+            case dawn::VertexFormat::Float3:
+            case dawn::VertexFormat::UInt3:
+            case dawn::VertexFormat::Int3:
+                return 3;
+            case dawn::VertexFormat::UChar2:
+            case dawn::VertexFormat::Char2:
+            case dawn::VertexFormat::UChar2Norm:
+            case dawn::VertexFormat::Char2Norm:
+            case dawn::VertexFormat::UShort2:
+            case dawn::VertexFormat::Short2:
+            case dawn::VertexFormat::UShort2Norm:
+            case dawn::VertexFormat::Short2Norm:
+            case dawn::VertexFormat::Half2:
+            case dawn::VertexFormat::Float2:
+            case dawn::VertexFormat::UInt2:
+            case dawn::VertexFormat::Int2:
+                return 2;
+            case dawn::VertexFormat::Float:
+            case dawn::VertexFormat::UInt:
+            case dawn::VertexFormat::Int:
+                return 1;
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    size_t VertexFormatComponentSize(dawn::VertexFormat format) {
+        switch (format) {
+            case dawn::VertexFormat::UChar2:
+            case dawn::VertexFormat::UChar4:
+            case dawn::VertexFormat::Char2:
+            case dawn::VertexFormat::Char4:
+            case dawn::VertexFormat::UChar2Norm:
+            case dawn::VertexFormat::UChar4Norm:
+            case dawn::VertexFormat::Char2Norm:
+            case dawn::VertexFormat::Char4Norm:
+                return sizeof(char);
+            case dawn::VertexFormat::UShort2:
+            case dawn::VertexFormat::UShort4:
+            case dawn::VertexFormat::UShort2Norm:
+            case dawn::VertexFormat::UShort4Norm:
+            case dawn::VertexFormat::Short2:
+            case dawn::VertexFormat::Short4:
+            case dawn::VertexFormat::Short2Norm:
+            case dawn::VertexFormat::Short4Norm:
+            case dawn::VertexFormat::Half2:
+            case dawn::VertexFormat::Half4:
+                return sizeof(uint16_t);
+            case dawn::VertexFormat::Float:
+            case dawn::VertexFormat::Float2:
+            case dawn::VertexFormat::Float3:
+            case dawn::VertexFormat::Float4:
+                return sizeof(float);
+            case dawn::VertexFormat::UInt:
+            case dawn::VertexFormat::UInt2:
+            case dawn::VertexFormat::UInt3:
+            case dawn::VertexFormat::UInt4:
+            case dawn::VertexFormat::Int:
+            case dawn::VertexFormat::Int2:
+            case dawn::VertexFormat::Int3:
+            case dawn::VertexFormat::Int4:
+                return sizeof(int32_t);
+            default:
+                UNREACHABLE();
+        }
+    }
+
+    size_t VertexFormatSize(dawn::VertexFormat format) {
+        return VertexFormatNumComponents(format) * VertexFormatComponentSize(format);
+    }
+
     MaybeError ValidateRenderPipelineDescriptor(DeviceBase* device,
                                                 const RenderPipelineDescriptor* descriptor) {
         if (descriptor->nextInChain != nullptr) {
@@ -104,14 +272,17 @@
         }
 
         DAWN_TRY(ValidateIndexFormat(descriptor->indexFormat));
+        std::bitset<kMaxVertexInputs> inputsSetMask;
+        std::bitset<kMaxVertexAttributes> attributesSetMask;
+        DAWN_TRY(ValidateInputStateDescriptor(descriptor->inputState, &inputsSetMask,
+                                              &attributesSetMask));
         DAWN_TRY(ValidatePrimitiveTopology(descriptor->primitiveTopology));
         DAWN_TRY(ValidatePipelineStageDescriptor(device, descriptor->vertexStage,
                                                  descriptor->layout, dawn::ShaderStage::Vertex));
         DAWN_TRY(ValidatePipelineStageDescriptor(device, descriptor->fragmentStage,
                                                  descriptor->layout, dawn::ShaderStage::Fragment));
 
-        if ((descriptor->vertexStage->module->GetUsedVertexAttributes() &
-             ~descriptor->inputState->GetAttributesSetMask())
+        if ((descriptor->vertexStage->module->GetUsedVertexAttributes() & ~attributesSetMask)
                 .any()) {
             return DAWN_VALIDATION_ERROR(
                 "Pipeline vertex stage uses inputs not in the input state");
@@ -168,10 +339,23 @@
                        descriptor->layout,
                        dawn::ShaderStageBit::Vertex | dawn::ShaderStageBit::Fragment),
           mIndexFormat(descriptor->indexFormat),
-          mInputState(descriptor->inputState),
+          mInputState(*descriptor->inputState),
           mPrimitiveTopology(descriptor->primitiveTopology),
           mHasDepthStencilAttachment(descriptor->depthStencilState != nullptr),
           mSampleCount(descriptor->sampleCount) {
+        uint32_t location = 0;
+        for (uint32_t i = 0; i < mInputState.numAttributes; ++i) {
+            location = mInputState.attributes[i].shaderLocation;
+            mAttributesSetMask.set(location);
+            mAttributeInfos[location] = mInputState.attributes[i];
+        }
+        uint32_t slot = 0;
+        for (uint32_t i = 0; i < mInputState.numInputs; ++i) {
+            slot = mInputState.inputs[i].inputSlot;
+            mInputsSetMask.set(slot);
+            mInputInfos[slot] = mInputState.inputs[i];
+        }
+
         if (mHasDepthStencilAttachment) {
             mDepthStencilState = *descriptor->depthStencilState;
         } else {
@@ -213,6 +397,33 @@
         return new RenderPipelineBase(device, ObjectBase::kError);
     }
 
+    const InputStateDescriptor* RenderPipelineBase::GetInputStateDescriptor() const {
+        ASSERT(!IsError());
+        return &mInputState;
+    }
+
+    const std::bitset<kMaxVertexAttributes>& RenderPipelineBase::GetAttributesSetMask() const {
+        ASSERT(!IsError());
+        return mAttributesSetMask;
+    }
+
+    const VertexAttributeDescriptor& RenderPipelineBase::GetAttribute(uint32_t location) const {
+        ASSERT(!IsError());
+        ASSERT(mAttributesSetMask[location]);
+        return mAttributeInfos[location];
+    }
+
+    const std::bitset<kMaxVertexInputs>& RenderPipelineBase::GetInputsSetMask() const {
+        ASSERT(!IsError());
+        return mInputsSetMask;
+    }
+
+    const VertexInputDescriptor& RenderPipelineBase::GetInput(uint32_t slot) const {
+        ASSERT(!IsError());
+        ASSERT(mInputsSetMask[slot]);
+        return mInputInfos[slot];
+    }
+
     const ColorStateDescriptor* RenderPipelineBase::GetColorStateDescriptor(
         uint32_t attachmentSlot) {
         ASSERT(!IsError());
@@ -230,11 +441,6 @@
         return mIndexFormat;
     }
 
-    InputStateBase* RenderPipelineBase::GetInputState() {
-        ASSERT(!IsError());
-        return mInputState.Get();
-    }
-
     dawn::PrimitiveTopology RenderPipelineBase::GetPrimitiveTopology() const {
         ASSERT(!IsError());
         return mPrimitiveTopology;
@@ -293,4 +499,10 @@
         return true;
     }
 
+    std::bitset<kMaxVertexAttributes> RenderPipelineBase::GetAttributesUsingInput(
+        uint32_t slot) const {
+        ASSERT(!IsError());
+        return attributesUsingInput[slot];
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index c2f781e..e46fc04 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -15,7 +15,6 @@
 #ifndef DAWNNATIVE_RENDERPIPELINE_H_
 #define DAWNNATIVE_RENDERPIPELINE_H_
 
-#include "dawn_native/InputState.h"
 #include "dawn_native/Pipeline.h"
 
 #include "dawn_native/dawn_platform.h"
@@ -31,6 +30,11 @@
 
     MaybeError ValidateRenderPipelineDescriptor(DeviceBase* device,
                                                 const RenderPipelineDescriptor* descriptor);
+    size_t IndexFormatSize(dawn::IndexFormat format);
+    uint32_t VertexFormatNumComponents(dawn::VertexFormat format);
+    size_t VertexFormatComponentSize(dawn::VertexFormat format);
+    size_t VertexFormatSize(dawn::VertexFormat format);
+
     bool StencilTestEnabled(const DepthStencilStateDescriptor* mDepthStencilState);
     bool BlendEnabled(const ColorStateDescriptor* mColorState);
 
@@ -40,10 +44,15 @@
 
         static RenderPipelineBase* MakeError(DeviceBase* device);
 
+        const InputStateDescriptor* GetInputStateDescriptor() const;
+        const std::bitset<kMaxVertexAttributes>& GetAttributesSetMask() const;
+        const VertexAttributeDescriptor& GetAttribute(uint32_t location) const;
+        const std::bitset<kMaxVertexInputs>& GetInputsSetMask() const;
+        const VertexInputDescriptor& GetInput(uint32_t slot) const;
+
         const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot);
         const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor();
         dawn::IndexFormat GetIndexFormat() const;
-        InputStateBase* GetInputState();
         dawn::PrimitiveTopology GetPrimitiveTopology() const;
 
         std::bitset<kMaxColorAttachments> GetColorAttachmentsMask() const;
@@ -54,14 +63,20 @@
         // A pipeline can be used in a render pass if its attachment info matches the actual
         // attachments in the render pass. This returns whether it is the case.
         bool IsCompatibleWith(const BeginRenderPassCmd* renderPassCmd) const;
+        std::bitset<kMaxVertexAttributes> GetAttributesUsingInput(uint32_t slot) const;
+        std::array<std::bitset<kMaxVertexAttributes>, kMaxVertexInputs> attributesUsingInput;
 
       private:
         RenderPipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
-        DepthStencilStateDescriptor mDepthStencilState;
         dawn::IndexFormat mIndexFormat;
-        Ref<InputStateBase> mInputState;
+        InputStateDescriptor mInputState;
+        std::bitset<kMaxVertexAttributes> mAttributesSetMask;
+        std::array<VertexAttributeDescriptor, kMaxVertexAttributes> mAttributeInfos;
+        std::bitset<kMaxVertexInputs> mInputsSetMask;
+        std::array<VertexInputDescriptor, kMaxVertexInputs> mInputInfos;
         dawn::PrimitiveTopology mPrimitiveTopology;
+        DepthStencilStateDescriptor mDepthStencilState;
         std::array<ColorStateDescriptor, kMaxColorAttachments> mColorStates;
 
         std::bitset<kMaxColorAttachments> mColorAttachmentsSet;
diff --git a/src/dawn_native/ToBackend.h b/src/dawn_native/ToBackend.h
index 1e09b1c..4f11fd4 100644
--- a/src/dawn_native/ToBackend.h
+++ b/src/dawn_native/ToBackend.h
@@ -59,11 +59,6 @@
     };
 
     template <typename BackendTraits>
-    struct ToBackendTraits<InputStateBase, BackendTraits> {
-        using BackendType = typename BackendTraits::InputStateType;
-    };
-
-    template <typename BackendTraits>
     struct ToBackendTraits<PipelineLayoutBase, BackendTraits> {
         using BackendType = typename BackendTraits::PipelineLayoutType;
     };
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 56807ae..a9c45c8 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -23,7 +23,6 @@
 #include "dawn_native/d3d12/ComputePipelineD3D12.h"
 #include "dawn_native/d3d12/DescriptorHeapAllocator.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
-#include "dawn_native/d3d12/InputStateD3D12.h"
 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
 #include "dawn_native/d3d12/RenderPipelineD3D12.h"
 #include "dawn_native/d3d12/ResourceAllocator.h"
@@ -576,11 +575,11 @@
 
     void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
                                               VertexBuffersInfo* vertexBuffersInfo,
-                                              const InputState* inputState) {
+                                              const RenderPipeline* renderPipeline) {
         DAWN_ASSERT(vertexBuffersInfo != nullptr);
-        DAWN_ASSERT(inputState != nullptr);
+        DAWN_ASSERT(renderPipeline != nullptr);
 
-        auto inputsMask = inputState->GetInputsSetMask();
+        auto inputsMask = renderPipeline->GetInputsSetMask();
 
         uint32_t startSlot = vertexBuffersInfo->startSlot;
         uint32_t endSlot = vertexBuffersInfo->endSlot;
@@ -588,14 +587,14 @@
         // If the input state has changed, we need to update the StrideInBytes
         // for the D3D12 buffer views. We also need to extend the dirty range to
         // touch all these slots because the stride may have changed.
-        if (vertexBuffersInfo->lastInputState != inputState) {
-            vertexBuffersInfo->lastInputState = inputState;
+        if (vertexBuffersInfo->lastRenderPipeline != renderPipeline) {
+            vertexBuffersInfo->lastRenderPipeline = renderPipeline;
 
             for (uint32_t slot : IterateBitSet(inputsMask)) {
                 startSlot = std::min(startSlot, slot);
                 endSlot = std::max(endSlot, slot + 1);
                 vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
-                    inputState->GetInput(slot).stride;
+                    renderPipeline->GetInput(slot).stride;
             }
         }
 
@@ -728,7 +727,6 @@
 
         RenderPipeline* lastPipeline = nullptr;
         PipelineLayout* lastLayout = nullptr;
-        InputState* lastInputState = nullptr;
         VertexBuffersInfo vertexBuffersInfo = {};
 
         Command type;
@@ -742,7 +740,7 @@
                 case Command::Draw: {
                     DrawCmd* draw = mCommands.NextCommand<DrawCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
+                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
                     commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
                                                draw->firstVertex, draw->firstInstance);
                 } break;
@@ -750,7 +748,7 @@
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = mCommands.NextCommand<DrawIndexedCmd>();
 
-                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastInputState);
+                    FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
                     commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
                                                       draw->firstIndex, draw->baseVertex,
                                                       draw->firstInstance);
@@ -768,7 +766,6 @@
                     SetRenderPipelineCmd* cmd = mCommands.NextCommand<SetRenderPipelineCmd>();
                     RenderPipeline* pipeline = ToBackend(cmd->pipeline).Get();
                     PipelineLayout* layout = ToBackend(pipeline->GetLayout());
-                    InputState* inputState = ToBackend(pipeline->GetInputState());
 
                     commandList->SetGraphicsRootSignature(layout->GetRootSignature().Get());
                     commandList->SetPipelineState(pipeline->GetPipelineState().Get());
@@ -778,7 +775,6 @@
 
                     lastPipeline = pipeline;
                     lastLayout = layout;
-                    lastInputState = inputState;
                 } break;
 
                 case Command::SetStencilReference: {
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.h b/src/dawn_native/d3d12/CommandBufferD3D12.h
index e7aee32..5a95801 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.h
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.h
@@ -15,13 +15,15 @@
 #ifndef DAWNNATIVE_D3D12_COMMANDBUFFERD3D12_H_
 #define DAWNNATIVE_D3D12_COMMANDBUFFERD3D12_H_
 
+#include "common/Constants.h"
 #include "dawn_native/CommandAllocator.h"
 #include "dawn_native/CommandBuffer.h"
 
 #include "dawn_native/d3d12/Forward.h"
-#include "dawn_native/d3d12/InputStateD3D12.h"
 #include "dawn_native/d3d12/d3d12_platform.h"
 
+#include <array>
+
 namespace dawn_native {
     struct BeginRenderPassCmd;
 }  // namespace dawn_native
@@ -30,6 +32,7 @@
 
     class Device;
     class RenderPassDescriptorHeapTracker;
+    class RenderPipeline;
 
     struct BindGroupStateTracker;
 
@@ -38,7 +41,7 @@
         // If there are multiple calls to SetVertexBuffers, the start and end
         // represent the union of the dirty ranges (the union may have non-dirty
         // data in the middle of the range).
-        const InputState* lastInputState = nullptr;
+        const RenderPipeline* lastRenderPipeline = nullptr;
         uint32_t startSlot = kMaxVertexInputs;
         uint32_t endSlot = 0;
         std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexInputs> d3d12BufferViews = {};
@@ -54,7 +57,7 @@
       private:
         void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
                                    VertexBuffersInfo* vertexBuffersInfo,
-                                   const InputState* inputState);
+                                   const RenderPipeline* lastRenderPipeline);
         void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
                                BindGroupStateTracker* bindingTracker);
         void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 6366e81..33c4efd 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -26,7 +26,6 @@
 #include "dawn_native/d3d12/CommandBufferD3D12.h"
 #include "dawn_native/d3d12/ComputePipelineD3D12.h"
 #include "dawn_native/d3d12/DescriptorHeapAllocator.h"
-#include "dawn_native/d3d12/InputStateD3D12.h"
 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
 #include "dawn_native/d3d12/PlatformFunctions.h"
 #include "dawn_native/d3d12/QueueD3D12.h"
@@ -216,9 +215,6 @@
         const ComputePipelineDescriptor* descriptor) {
         return new ComputePipeline(this, descriptor);
     }
-    InputStateBase* Device::CreateInputState(InputStateBuilder* builder) {
-        return new InputState(builder);
-    }
     ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
         const PipelineLayoutDescriptor* descriptor) {
         return new PipelineLayout(this, descriptor);
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 98118ea..c76ec92 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -41,7 +41,6 @@
         ~Device();
 
         CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) override;
-        InputStateBase* CreateInputState(InputStateBuilder* builder) override;
 
         Serial GetCompletedCommandSerial() const final override;
         Serial GetLastSubmittedCommandSerial() const final override;
diff --git a/src/dawn_native/d3d12/Forward.h b/src/dawn_native/d3d12/Forward.h
index e93b611..ade12e3 100644
--- a/src/dawn_native/d3d12/Forward.h
+++ b/src/dawn_native/d3d12/Forward.h
@@ -26,7 +26,6 @@
     class CommandBuffer;
     class ComputePipeline;
     class Device;
-    class InputState;
     class PipelineLayout;
     class Queue;
     class RenderPipeline;
@@ -45,7 +44,6 @@
         using CommandBufferType = CommandBuffer;
         using ComputePipelineType = ComputePipeline;
         using DeviceType = Device;
-        using InputStateType = InputState;
         using PipelineLayoutType = PipelineLayout;
         using QueueType = Queue;
         using RenderPipelineType = RenderPipeline;
diff --git a/src/dawn_native/d3d12/InputStateD3D12.cpp b/src/dawn_native/d3d12/InputStateD3D12.cpp
deleted file mode 100644
index fcfa6b1..0000000
--- a/src/dawn_native/d3d12/InputStateD3D12.cpp
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dawn_native/d3d12/InputStateD3D12.h"
-
-#include "common/BitSetIterator.h"
-
-namespace dawn_native { namespace d3d12 {
-
-    static DXGI_FORMAT VertexFormatType(dawn::VertexFormat format) {
-        switch (format) {
-            case dawn::VertexFormat::UChar2:
-                return DXGI_FORMAT_R8G8_UINT;
-            case dawn::VertexFormat::UChar4:
-                return DXGI_FORMAT_R8G8B8A8_UINT;
-            case dawn::VertexFormat::Char2:
-                return DXGI_FORMAT_R8G8_SINT;
-            case dawn::VertexFormat::Char4:
-                return DXGI_FORMAT_R8G8B8A8_SINT;
-            case dawn::VertexFormat::UChar2Norm:
-                return DXGI_FORMAT_R8G8_UNORM;
-            case dawn::VertexFormat::UChar4Norm:
-                return DXGI_FORMAT_R8G8B8A8_UNORM;
-            case dawn::VertexFormat::Char2Norm:
-                return DXGI_FORMAT_R8G8_SNORM;
-            case dawn::VertexFormat::Char4Norm:
-                return DXGI_FORMAT_R8G8B8A8_SNORM;
-            case dawn::VertexFormat::UShort2:
-                return DXGI_FORMAT_R16G16_UINT;
-            case dawn::VertexFormat::UShort4:
-                return DXGI_FORMAT_R16G16B16A16_UINT;
-            case dawn::VertexFormat::Short2:
-                return DXGI_FORMAT_R16G16_SINT;
-            case dawn::VertexFormat::Short4:
-                return DXGI_FORMAT_R16G16B16A16_SINT;
-            case dawn::VertexFormat::UShort2Norm:
-                return DXGI_FORMAT_R16G16_UNORM;
-            case dawn::VertexFormat::UShort4Norm:
-                return DXGI_FORMAT_R16G16B16A16_UNORM;
-            case dawn::VertexFormat::Short2Norm:
-                return DXGI_FORMAT_R16G16_SNORM;
-            case dawn::VertexFormat::Short4Norm:
-                return DXGI_FORMAT_R16G16B16A16_SNORM;
-            case dawn::VertexFormat::Half2:
-                return DXGI_FORMAT_R16G16_FLOAT;
-            case dawn::VertexFormat::Half4:
-                return DXGI_FORMAT_R16G16B16A16_FLOAT;
-            case dawn::VertexFormat::Float:
-                return DXGI_FORMAT_R32_FLOAT;
-            case dawn::VertexFormat::Float2:
-                return DXGI_FORMAT_R32G32_FLOAT;
-            case dawn::VertexFormat::Float3:
-                return DXGI_FORMAT_R32G32B32_FLOAT;
-            case dawn::VertexFormat::Float4:
-                return DXGI_FORMAT_R32G32B32A32_FLOAT;
-            case dawn::VertexFormat::UInt:
-                return DXGI_FORMAT_R32_UINT;
-            case dawn::VertexFormat::UInt2:
-                return DXGI_FORMAT_R32G32_UINT;
-            case dawn::VertexFormat::UInt3:
-                return DXGI_FORMAT_R32G32B32_UINT;
-            case dawn::VertexFormat::UInt4:
-                return DXGI_FORMAT_R32G32B32A32_UINT;
-            case dawn::VertexFormat::Int:
-                return DXGI_FORMAT_R32_SINT;
-            case dawn::VertexFormat::Int2:
-                return DXGI_FORMAT_R32G32_SINT;
-            case dawn::VertexFormat::Int3:
-                return DXGI_FORMAT_R32G32B32_SINT;
-            case dawn::VertexFormat::Int4:
-                return DXGI_FORMAT_R32G32B32A32_SINT;
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    static D3D12_INPUT_CLASSIFICATION InputStepModeFunction(dawn::InputStepMode mode) {
-        switch (mode) {
-            case dawn::InputStepMode::Vertex:
-                return D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
-            case dawn::InputStepMode::Instance:
-                return D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA;
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    InputState::InputState(InputStateBuilder* builder) : InputStateBase(builder) {
-        const auto& attributesSetMask = GetAttributesSetMask();
-
-        unsigned int count = 0;
-        for (auto i : IterateBitSet(attributesSetMask)) {
-            if (!attributesSetMask[i]) {
-                continue;
-            }
-
-            D3D12_INPUT_ELEMENT_DESC& inputElementDescriptor = mInputElementDescriptors[count++];
-
-            const VertexAttributeDescriptor& attribute = GetAttribute(i);
-
-            // If the HLSL semantic is TEXCOORDN the SemanticName should be "TEXCOORD" and the
-            // SemanticIndex N
-            inputElementDescriptor.SemanticName = "TEXCOORD";
-            inputElementDescriptor.SemanticIndex = static_cast<uint32_t>(i);
-            inputElementDescriptor.Format = VertexFormatType(attribute.format);
-            inputElementDescriptor.InputSlot = attribute.inputSlot;
-
-            const VertexInputDescriptor& input = GetInput(attribute.inputSlot);
-
-            inputElementDescriptor.AlignedByteOffset = attribute.offset;
-            inputElementDescriptor.InputSlotClass = InputStepModeFunction(input.stepMode);
-            if (inputElementDescriptor.InputSlotClass ==
-                D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA) {
-                inputElementDescriptor.InstanceDataStepRate = 0;
-            } else {
-                inputElementDescriptor.InstanceDataStepRate = 1;
-            }
-        }
-
-        mInputLayoutDescriptor.pInputElementDescs = mInputElementDescriptors;
-        mInputLayoutDescriptor.NumElements = count;
-    }
-
-    const D3D12_INPUT_LAYOUT_DESC& InputState::GetD3D12InputLayoutDescriptor() const {
-        return mInputLayoutDescriptor;
-    }
-
-}}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/InputStateD3D12.h b/src/dawn_native/d3d12/InputStateD3D12.h
deleted file mode 100644
index f42b747..0000000
--- a/src/dawn_native/d3d12/InputStateD3D12.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DAWNNATIVE_D3D12_INPUTSTATED3D12_H_
-#define DAWNNATIVE_D3D12_INPUTSTATED3D12_H_
-
-#include "dawn_native/InputState.h"
-
-#include "dawn_native/d3d12/d3d12_platform.h"
-
-namespace dawn_native { namespace d3d12 {
-
-    class Device;
-
-    class InputState : public InputStateBase {
-      public:
-        InputState(InputStateBuilder* builder);
-
-        const D3D12_INPUT_LAYOUT_DESC& GetD3D12InputLayoutDescriptor() const;
-
-      private:
-        D3D12_INPUT_LAYOUT_DESC mInputLayoutDescriptor;
-        D3D12_INPUT_ELEMENT_DESC mInputElementDescriptors[kMaxVertexAttributes];
-    };
-
-}}  // namespace dawn_native::d3d12
-
-#endif  // DAWNNATIVE_D3D12_INPUTSTATED3D12_H_
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 4e4127d..046b7ae 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -16,7 +16,6 @@
 
 #include "common/Assert.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
-#include "dawn_native/d3d12/InputStateD3D12.h"
 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
 #include "dawn_native/d3d12/PlatformFunctions.h"
 #include "dawn_native/d3d12/ShaderModuleD3D12.h"
@@ -28,6 +27,84 @@
 namespace dawn_native { namespace d3d12 {
 
     namespace {
+        DXGI_FORMAT VertexFormatType(dawn::VertexFormat format) {
+            switch (format) {
+                case dawn::VertexFormat::UChar2:
+                    return DXGI_FORMAT_R8G8_UINT;
+                case dawn::VertexFormat::UChar4:
+                    return DXGI_FORMAT_R8G8B8A8_UINT;
+                case dawn::VertexFormat::Char2:
+                    return DXGI_FORMAT_R8G8_SINT;
+                case dawn::VertexFormat::Char4:
+                    return DXGI_FORMAT_R8G8B8A8_SINT;
+                case dawn::VertexFormat::UChar2Norm:
+                    return DXGI_FORMAT_R8G8_UNORM;
+                case dawn::VertexFormat::UChar4Norm:
+                    return DXGI_FORMAT_R8G8B8A8_UNORM;
+                case dawn::VertexFormat::Char2Norm:
+                    return DXGI_FORMAT_R8G8_SNORM;
+                case dawn::VertexFormat::Char4Norm:
+                    return DXGI_FORMAT_R8G8B8A8_SNORM;
+                case dawn::VertexFormat::UShort2:
+                    return DXGI_FORMAT_R16G16_UINT;
+                case dawn::VertexFormat::UShort4:
+                    return DXGI_FORMAT_R16G16B16A16_UINT;
+                case dawn::VertexFormat::Short2:
+                    return DXGI_FORMAT_R16G16_SINT;
+                case dawn::VertexFormat::Short4:
+                    return DXGI_FORMAT_R16G16B16A16_SINT;
+                case dawn::VertexFormat::UShort2Norm:
+                    return DXGI_FORMAT_R16G16_UNORM;
+                case dawn::VertexFormat::UShort4Norm:
+                    return DXGI_FORMAT_R16G16B16A16_UNORM;
+                case dawn::VertexFormat::Short2Norm:
+                    return DXGI_FORMAT_R16G16_SNORM;
+                case dawn::VertexFormat::Short4Norm:
+                    return DXGI_FORMAT_R16G16B16A16_SNORM;
+                case dawn::VertexFormat::Half2:
+                    return DXGI_FORMAT_R16G16_FLOAT;
+                case dawn::VertexFormat::Half4:
+                    return DXGI_FORMAT_R16G16B16A16_FLOAT;
+                case dawn::VertexFormat::Float:
+                    return DXGI_FORMAT_R32_FLOAT;
+                case dawn::VertexFormat::Float2:
+                    return DXGI_FORMAT_R32G32_FLOAT;
+                case dawn::VertexFormat::Float3:
+                    return DXGI_FORMAT_R32G32B32_FLOAT;
+                case dawn::VertexFormat::Float4:
+                    return DXGI_FORMAT_R32G32B32A32_FLOAT;
+                case dawn::VertexFormat::UInt:
+                    return DXGI_FORMAT_R32_UINT;
+                case dawn::VertexFormat::UInt2:
+                    return DXGI_FORMAT_R32G32_UINT;
+                case dawn::VertexFormat::UInt3:
+                    return DXGI_FORMAT_R32G32B32_UINT;
+                case dawn::VertexFormat::UInt4:
+                    return DXGI_FORMAT_R32G32B32A32_UINT;
+                case dawn::VertexFormat::Int:
+                    return DXGI_FORMAT_R32_SINT;
+                case dawn::VertexFormat::Int2:
+                    return DXGI_FORMAT_R32G32_SINT;
+                case dawn::VertexFormat::Int3:
+                    return DXGI_FORMAT_R32G32B32_SINT;
+                case dawn::VertexFormat::Int4:
+                    return DXGI_FORMAT_R32G32B32A32_SINT;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
+        D3D12_INPUT_CLASSIFICATION InputStepModeFunction(dawn::InputStepMode mode) {
+            switch (mode) {
+                case dawn::InputStepMode::Vertex:
+                    return D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
+                case dawn::InputStepMode::Instance:
+                    return D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
         D3D12_PRIMITIVE_TOPOLOGY D3D12PrimitiveTopology(dawn::PrimitiveTopology primitiveTopology) {
             switch (primitiveTopology) {
                 case dawn::PrimitiveTopology::PointList:
@@ -261,9 +338,9 @@
         descriptorD3D12.pRootSignature = layout->GetRootSignature().Get();
 
         // D3D12 logs warnings if any empty input state is used
-        InputState* inputState = ToBackend(GetInputState());
-        if (inputState->GetAttributesSetMask().any()) {
-            descriptorD3D12.InputLayout = inputState->GetD3D12InputLayoutDescriptor();
+        std::array<D3D12_INPUT_ELEMENT_DESC, kMaxVertexAttributes> inputElementDescriptors;
+        if (GetAttributesSetMask().any()) {
+            descriptorD3D12.InputLayout = ComputeInputLayout(&inputElementDescriptors);
         }
 
         descriptorD3D12.RasterizerState.FillMode = D3D12_FILL_MODE_SOLID;
@@ -317,4 +394,42 @@
         return mPipelineState;
     }
 
+    D3D12_INPUT_LAYOUT_DESC RenderPipeline::ComputeInputLayout(
+        std::array<D3D12_INPUT_ELEMENT_DESC, kMaxVertexAttributes>* inputElementDescriptors) {
+        const auto& attributesSetMask = GetAttributesSetMask();
+        unsigned int count = 0;
+        for (auto i : IterateBitSet(attributesSetMask)) {
+            if (!attributesSetMask[i]) {
+                continue;
+            }
+
+            D3D12_INPUT_ELEMENT_DESC& inputElementDescriptor = (*inputElementDescriptors)[count++];
+
+            const VertexAttributeDescriptor& attribute = GetAttribute(i);
+
+            // If the HLSL semantic is TEXCOORDN the SemanticName should be "TEXCOORD" and the
+            // SemanticIndex N
+            inputElementDescriptor.SemanticName = "TEXCOORD";
+            inputElementDescriptor.SemanticIndex = static_cast<uint32_t>(i);
+            inputElementDescriptor.Format = VertexFormatType(attribute.format);
+            inputElementDescriptor.InputSlot = attribute.inputSlot;
+
+            const VertexInputDescriptor& input = GetInput(attribute.inputSlot);
+
+            inputElementDescriptor.AlignedByteOffset = attribute.offset;
+            inputElementDescriptor.InputSlotClass = InputStepModeFunction(input.stepMode);
+            if (inputElementDescriptor.InputSlotClass ==
+                D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA) {
+                inputElementDescriptor.InstanceDataStepRate = 0;
+            } else {
+                inputElementDescriptor.InstanceDataStepRate = 1;
+            }
+        }
+
+        D3D12_INPUT_LAYOUT_DESC inputLayoutDescriptor;
+        inputLayoutDescriptor.pInputElementDescs = &(*inputElementDescriptors)[0];
+        inputLayoutDescriptor.NumElements = count;
+        return inputLayoutDescriptor;
+    }
+
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.h b/src/dawn_native/d3d12/RenderPipelineD3D12.h
index 20502bb..b9c9029 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.h
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.h
@@ -32,6 +32,9 @@
         ComPtr<ID3D12PipelineState> GetPipelineState();
 
       private:
+        D3D12_INPUT_LAYOUT_DESC ComputeInputLayout(
+            std::array<D3D12_INPUT_ELEMENT_DESC, kMaxVertexAttributes>* inputElementDescriptors);
+
         D3D12_PRIMITIVE_TOPOLOGY mD3d12PrimitiveTopology;
         ComPtr<ID3D12PipelineState> mPipelineState;
     };
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 3506fa4..5c9058f 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -20,7 +20,6 @@
 #include "dawn_native/metal/BufferMTL.h"
 #include "dawn_native/metal/ComputePipelineMTL.h"
 #include "dawn_native/metal/DeviceMTL.h"
-#include "dawn_native/metal/InputStateMTL.h"
 #include "dawn_native/metal/PipelineLayoutMTL.h"
 #include "dawn_native/metal/RenderPipelineMTL.h"
 #include "dawn_native/metal/SamplerMTL.h"
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 072e605..59fc9aa 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -38,7 +38,6 @@
         ~Device();
 
         CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) override;
-        InputStateBase* CreateInputState(InputStateBuilder* builder) override;
 
         Serial GetCompletedCommandSerial() const final override;
         Serial GetLastSubmittedCommandSerial() const final override;
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 3b86392..322a9ba 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -21,7 +21,6 @@
 #include "dawn_native/metal/BufferMTL.h"
 #include "dawn_native/metal/CommandBufferMTL.h"
 #include "dawn_native/metal/ComputePipelineMTL.h"
-#include "dawn_native/metal/InputStateMTL.h"
 #include "dawn_native/metal/PipelineLayoutMTL.h"
 #include "dawn_native/metal/QueueMTL.h"
 #include "dawn_native/metal/RenderPipelineMTL.h"
@@ -86,9 +85,6 @@
         const ComputePipelineDescriptor* descriptor) {
         return new ComputePipeline(this, descriptor);
     }
-    InputStateBase* Device::CreateInputState(InputStateBuilder* builder) {
-        return new InputState(builder);
-    }
     ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
         const PipelineLayoutDescriptor* descriptor) {
         return new PipelineLayout(this, descriptor);
diff --git a/src/dawn_native/metal/Forward.h b/src/dawn_native/metal/Forward.h
index f2a2e3c..4e889cd 100644
--- a/src/dawn_native/metal/Forward.h
+++ b/src/dawn_native/metal/Forward.h
@@ -32,7 +32,6 @@
     class ComputePipeline;
     class Device;
     class Framebuffer;
-    class InputState;
     class PipelineLayout;
     class Queue;
     class RenderPipeline;
@@ -51,7 +50,6 @@
         using CommandBufferType = CommandBuffer;
         using ComputePipelineType = ComputePipeline;
         using DeviceType = Device;
-        using InputStateType = InputState;
         using PipelineLayoutType = PipelineLayout;
         using QueueType = Queue;
         using RenderPipelineType = RenderPipeline;
diff --git a/src/dawn_native/metal/InputStateMTL.h b/src/dawn_native/metal/InputStateMTL.h
deleted file mode 100644
index 496e6ea..0000000
--- a/src/dawn_native/metal/InputStateMTL.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DAWNNATIVE_METAL_INPUTSTATEMTL_H_
-#define DAWNNATIVE_METAL_INPUTSTATEMTL_H_
-
-#include "dawn_native/InputState.h"
-
-#import <Metal/Metal.h>
-
-namespace dawn_native { namespace metal {
-
-    class InputState : public InputStateBase {
-      public:
-        InputState(InputStateBuilder* builder);
-        ~InputState();
-
-        MTLVertexDescriptor* GetMTLVertexDescriptor();
-
-      private:
-        MTLVertexDescriptor* mMtlVertexDescriptor = nil;
-    };
-
-}}  // namespace dawn_native::metal
-
-#endif  // DAWNNATIVE_METAL_COMMANDINPUTSTATEMTL_H_
diff --git a/src/dawn_native/metal/InputStateMTL.mm b/src/dawn_native/metal/InputStateMTL.mm
deleted file mode 100644
index ead1dac..0000000
--- a/src/dawn_native/metal/InputStateMTL.mm
+++ /dev/null
@@ -1,160 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dawn_native/metal/InputStateMTL.h"
-
-#include "common/BitSetIterator.h"
-
-namespace dawn_native { namespace metal {
-
-    namespace {
-        MTLVertexFormat VertexFormatType(dawn::VertexFormat format) {
-            switch (format) {
-                case dawn::VertexFormat::UChar2:
-                    return MTLVertexFormatUChar2;
-                case dawn::VertexFormat::UChar4:
-                    return MTLVertexFormatUChar4;
-                case dawn::VertexFormat::Char2:
-                    return MTLVertexFormatChar2;
-                case dawn::VertexFormat::Char4:
-                    return MTLVertexFormatChar4;
-                case dawn::VertexFormat::UChar2Norm:
-                    return MTLVertexFormatUChar2Normalized;
-                case dawn::VertexFormat::UChar4Norm:
-                    return MTLVertexFormatUChar4Normalized;
-                case dawn::VertexFormat::Char2Norm:
-                    return MTLVertexFormatChar2Normalized;
-                case dawn::VertexFormat::Char4Norm:
-                    return MTLVertexFormatChar4Normalized;
-                case dawn::VertexFormat::UShort2:
-                    return MTLVertexFormatUShort2;
-                case dawn::VertexFormat::UShort4:
-                    return MTLVertexFormatUShort4;
-                case dawn::VertexFormat::Short2:
-                    return MTLVertexFormatShort2;
-                case dawn::VertexFormat::Short4:
-                    return MTLVertexFormatShort4;
-                case dawn::VertexFormat::UShort2Norm:
-                    return MTLVertexFormatUShort2Normalized;
-                case dawn::VertexFormat::UShort4Norm:
-                    return MTLVertexFormatUShort4Normalized;
-                case dawn::VertexFormat::Short2Norm:
-                    return MTLVertexFormatShort2Normalized;
-                case dawn::VertexFormat::Short4Norm:
-                    return MTLVertexFormatShort4Normalized;
-                case dawn::VertexFormat::Half2:
-                    return MTLVertexFormatHalf2;
-                case dawn::VertexFormat::Half4:
-                    return MTLVertexFormatHalf4;
-                case dawn::VertexFormat::Float:
-                    return MTLVertexFormatFloat;
-                case dawn::VertexFormat::Float2:
-                    return MTLVertexFormatFloat2;
-                case dawn::VertexFormat::Float3:
-                    return MTLVertexFormatFloat3;
-                case dawn::VertexFormat::Float4:
-                    return MTLVertexFormatFloat4;
-                case dawn::VertexFormat::UInt:
-                    return MTLVertexFormatUInt;
-                case dawn::VertexFormat::UInt2:
-                    return MTLVertexFormatUInt2;
-                case dawn::VertexFormat::UInt3:
-                    return MTLVertexFormatUInt3;
-                case dawn::VertexFormat::UInt4:
-                    return MTLVertexFormatUInt4;
-                case dawn::VertexFormat::Int:
-                    return MTLVertexFormatInt;
-                case dawn::VertexFormat::Int2:
-                    return MTLVertexFormatInt2;
-                case dawn::VertexFormat::Int3:
-                    return MTLVertexFormatInt3;
-                case dawn::VertexFormat::Int4:
-                    return MTLVertexFormatInt4;
-            }
-        }
-
-        MTLVertexStepFunction InputStepModeFunction(dawn::InputStepMode mode) {
-            switch (mode) {
-                case dawn::InputStepMode::Vertex:
-                    return MTLVertexStepFunctionPerVertex;
-                case dawn::InputStepMode::Instance:
-                    return MTLVertexStepFunctionPerInstance;
-            }
-        }
-    }
-
-    InputState::InputState(InputStateBuilder* builder) : InputStateBase(builder) {
-        mMtlVertexDescriptor = [MTLVertexDescriptor new];
-
-        const auto& attributesSetMask = GetAttributesSetMask();
-        for (uint32_t i = 0; i < attributesSetMask.size(); ++i) {
-            if (!attributesSetMask[i]) {
-                continue;
-            }
-            const VertexAttributeDescriptor& info = GetAttribute(i);
-
-            auto attribDesc = [MTLVertexAttributeDescriptor new];
-            attribDesc.format = VertexFormatType(info.format);
-            attribDesc.offset = info.offset;
-            attribDesc.bufferIndex = kMaxBindingsPerGroup + info.inputSlot;
-            mMtlVertexDescriptor.attributes[i] = attribDesc;
-            [attribDesc release];
-        }
-
-        for (uint32_t i : IterateBitSet(GetInputsSetMask())) {
-            const VertexInputDescriptor& info = GetInput(i);
-
-            auto layoutDesc = [MTLVertexBufferLayoutDescriptor new];
-            if (info.stride == 0) {
-                // For MTLVertexStepFunctionConstant, the stepRate must be 0,
-                // but the stride must NOT be 0, so we made up it with
-                // max(attrib.offset + sizeof(attrib) for each attrib)
-                uint32_t max_stride = 0;
-                for (uint32_t attribIndex : IterateBitSet(attributesSetMask)) {
-                    const VertexAttributeDescriptor& attrib = GetAttribute(attribIndex);
-                    // Only use the attributes that use the current input
-                    if (attrib.inputSlot != info.inputSlot) {
-                        continue;
-                    }
-                    max_stride = std::max(
-                        max_stride,
-                        static_cast<uint32_t>(VertexFormatSize(attrib.format)) + attrib.offset);
-                }
-
-                layoutDesc.stepFunction = MTLVertexStepFunctionConstant;
-                layoutDesc.stepRate = 0;
-                // Metal requires the stride must be a multiple of 4 bytes, align it with next
-                // multiple of 4 if it's not.
-                layoutDesc.stride = Align(max_stride, 4);
-            } else {
-                layoutDesc.stepFunction = InputStepModeFunction(info.stepMode);
-                layoutDesc.stepRate = 1;
-                layoutDesc.stride = info.stride;
-            }
-            // TODO(cwallez@chromium.org): make the offset depend on the pipeline layout
-            mMtlVertexDescriptor.layouts[kMaxBindingsPerGroup + i] = layoutDesc;
-            [layoutDesc release];
-        }
-    }
-
-    InputState::~InputState() {
-        [mMtlVertexDescriptor release];
-        mMtlVertexDescriptor = nil;
-    }
-
-    MTLVertexDescriptor* InputState::GetMTLVertexDescriptor() {
-        return mMtlVertexDescriptor;
-    }
-
-}}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/RenderPipelineMTL.h b/src/dawn_native/metal/RenderPipelineMTL.h
index 6da06e2..1edf434 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.h
+++ b/src/dawn_native/metal/RenderPipelineMTL.h
@@ -36,6 +36,8 @@
         id<MTLDepthStencilState> GetMTLDepthStencilState();
 
       private:
+        MTLVertexDescriptor* MakeVertexDesc();
+
         MTLIndexType mMtlIndexType;
         MTLPrimitiveType mMtlPrimitiveTopology;
         id<MTLRenderPipelineState> mMtlRenderPipelineState = nil;
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 45437a8..5020a43 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -15,7 +15,6 @@
 #include "dawn_native/metal/RenderPipelineMTL.h"
 
 #include "dawn_native/metal/DeviceMTL.h"
-#include "dawn_native/metal/InputStateMTL.h"
 #include "dawn_native/metal/PipelineLayoutMTL.h"
 #include "dawn_native/metal/ShaderModuleMTL.h"
 #include "dawn_native/metal/TextureMTL.h"
@@ -24,6 +23,80 @@
 namespace dawn_native { namespace metal {
 
     namespace {
+        MTLVertexFormat VertexFormatType(dawn::VertexFormat format) {
+            switch (format) {
+                case dawn::VertexFormat::UChar2:
+                    return MTLVertexFormatUChar2;
+                case dawn::VertexFormat::UChar4:
+                    return MTLVertexFormatUChar4;
+                case dawn::VertexFormat::Char2:
+                    return MTLVertexFormatChar2;
+                case dawn::VertexFormat::Char4:
+                    return MTLVertexFormatChar4;
+                case dawn::VertexFormat::UChar2Norm:
+                    return MTLVertexFormatUChar2Normalized;
+                case dawn::VertexFormat::UChar4Norm:
+                    return MTLVertexFormatUChar4Normalized;
+                case dawn::VertexFormat::Char2Norm:
+                    return MTLVertexFormatChar2Normalized;
+                case dawn::VertexFormat::Char4Norm:
+                    return MTLVertexFormatChar4Normalized;
+                case dawn::VertexFormat::UShort2:
+                    return MTLVertexFormatUShort2;
+                case dawn::VertexFormat::UShort4:
+                    return MTLVertexFormatUShort4;
+                case dawn::VertexFormat::Short2:
+                    return MTLVertexFormatShort2;
+                case dawn::VertexFormat::Short4:
+                    return MTLVertexFormatShort4;
+                case dawn::VertexFormat::UShort2Norm:
+                    return MTLVertexFormatUShort2Normalized;
+                case dawn::VertexFormat::UShort4Norm:
+                    return MTLVertexFormatUShort4Normalized;
+                case dawn::VertexFormat::Short2Norm:
+                    return MTLVertexFormatShort2Normalized;
+                case dawn::VertexFormat::Short4Norm:
+                    return MTLVertexFormatShort4Normalized;
+                case dawn::VertexFormat::Half2:
+                    return MTLVertexFormatHalf2;
+                case dawn::VertexFormat::Half4:
+                    return MTLVertexFormatHalf4;
+                case dawn::VertexFormat::Float:
+                    return MTLVertexFormatFloat;
+                case dawn::VertexFormat::Float2:
+                    return MTLVertexFormatFloat2;
+                case dawn::VertexFormat::Float3:
+                    return MTLVertexFormatFloat3;
+                case dawn::VertexFormat::Float4:
+                    return MTLVertexFormatFloat4;
+                case dawn::VertexFormat::UInt:
+                    return MTLVertexFormatUInt;
+                case dawn::VertexFormat::UInt2:
+                    return MTLVertexFormatUInt2;
+                case dawn::VertexFormat::UInt3:
+                    return MTLVertexFormatUInt3;
+                case dawn::VertexFormat::UInt4:
+                    return MTLVertexFormatUInt4;
+                case dawn::VertexFormat::Int:
+                    return MTLVertexFormatInt;
+                case dawn::VertexFormat::Int2:
+                    return MTLVertexFormatInt2;
+                case dawn::VertexFormat::Int3:
+                    return MTLVertexFormatInt3;
+                case dawn::VertexFormat::Int4:
+                    return MTLVertexFormatInt4;
+            }
+        }
+
+        MTLVertexStepFunction InputStepModeFunction(dawn::InputStepMode mode) {
+            switch (mode) {
+                case dawn::InputStepMode::Vertex:
+                    return MTLVertexStepFunctionPerVertex;
+                case dawn::InputStepMode::Instance:
+                    return MTLVertexStepFunctionPerInstance;
+            }
+        }
+
         MTLPrimitiveType MTLPrimitiveTopology(dawn::PrimitiveTopology primitiveTopology) {
             switch (primitiveTopology) {
                 case dawn::PrimitiveTopology::PointList:
@@ -243,8 +316,7 @@
 
         descriptorMTL.inputPrimitiveTopology = MTLInputPrimitiveTopology(GetPrimitiveTopology());
 
-        InputState* inputState = ToBackend(GetInputState());
-        descriptorMTL.vertexDescriptor = inputState->GetMTLVertexDescriptor();
+        descriptorMTL.vertexDescriptor = MakeVertexDesc();
 
         // TODO(kainino@chromium.org): push constants, textures, samplers
 
@@ -252,6 +324,7 @@
             NSError* error = nil;
             mMtlRenderPipelineState = [mtlDevice newRenderPipelineStateWithDescriptor:descriptorMTL
                                                                                 error:&error];
+            [descriptorMTL.vertexDescriptor release];
             [descriptorMTL release];
             if (error != nil) {
                 NSLog(@" error => %@", error);
@@ -289,4 +362,46 @@
         return mMtlDepthStencilState;
     }
 
+    MTLVertexDescriptor* RenderPipeline::MakeVertexDesc() {
+        MTLVertexDescriptor* mtlVertexDescriptor = [MTLVertexDescriptor new];
+
+        const auto& attributesSetMask = GetAttributesSetMask();
+        for (uint32_t i = 0; i < attributesSetMask.size(); ++i) {
+            if (!attributesSetMask[i]) {
+                continue;
+            }
+            const VertexAttributeDescriptor& info = GetAttribute(i);
+
+            auto attribDesc = [MTLVertexAttributeDescriptor new];
+            attribDesc.format = VertexFormatType(info.format);
+            attribDesc.offset = info.offset;
+            attribDesc.bufferIndex = kMaxBindingsPerGroup + info.inputSlot;
+            mtlVertexDescriptor.attributes[i] = attribDesc;
+            [attribDesc release];
+        }
+
+        for (uint32_t i : IterateBitSet(GetInputsSetMask())) {
+            const VertexInputDescriptor& info = GetInput(i);
+
+            auto layoutDesc = [MTLVertexBufferLayoutDescriptor new];
+            if (info.stride == 0) {
+                // For MTLVertexStepFunctionConstant, the stepRate must be 0,
+                // but the stride must NOT be 0, so I made up a value (256).
+                // TODO(cwallez@chromium.org): the made up value will need to be at least
+                //    max(attrib.offset + sizeof(attrib) for each attrib)
+                layoutDesc.stepFunction = MTLVertexStepFunctionConstant;
+                layoutDesc.stepRate = 0;
+                layoutDesc.stride = 256;
+            } else {
+                layoutDesc.stepFunction = InputStepModeFunction(info.stepMode);
+                layoutDesc.stepRate = 1;
+                layoutDesc.stride = info.stride;
+            }
+            // TODO(cwallez@chromium.org): make the offset depend on the pipeline layout
+            mtlVertexDescriptor.layouts[kMaxBindingsPerGroup + i] = layoutDesc;
+            [layoutDesc release];
+        }
+        return mtlVertexDescriptor;
+    }
+
 }}  // namespace dawn_native::metal
diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp
index 693aded..4a88b0b 100644
--- a/src/dawn_native/null/DeviceNull.cpp
+++ b/src/dawn_native/null/DeviceNull.cpp
@@ -82,9 +82,6 @@
         const ComputePipelineDescriptor* descriptor) {
         return new ComputePipeline(this, descriptor);
     }
-    InputStateBase* Device::CreateInputState(InputStateBuilder* builder) {
-        return new InputState(builder);
-    }
     ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
         const PipelineLayoutDescriptor* descriptor) {
         return new PipelineLayout(this, descriptor);
diff --git a/src/dawn_native/null/DeviceNull.h b/src/dawn_native/null/DeviceNull.h
index 973fe13..8880658 100644
--- a/src/dawn_native/null/DeviceNull.h
+++ b/src/dawn_native/null/DeviceNull.h
@@ -22,7 +22,6 @@
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Device.h"
-#include "dawn_native/InputState.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/Queue.h"
 #include "dawn_native/RenderPipeline.h"
@@ -44,7 +43,6 @@
     class CommandBuffer;
     using ComputePipeline = ComputePipelineBase;
     class Device;
-    using InputState = InputStateBase;
     using PipelineLayout = PipelineLayoutBase;
     class Queue;
     using RenderPipeline = RenderPipelineBase;
@@ -62,7 +60,6 @@
         using CommandBufferType = CommandBuffer;
         using ComputePipelineType = ComputePipeline;
         using DeviceType = Device;
-        using InputStateType = InputState;
         using PipelineLayoutType = PipelineLayout;
         using QueueType = Queue;
         using RenderPipelineType = RenderPipeline;
@@ -89,7 +86,6 @@
         ~Device();
 
         CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) override;
-        InputStateBase* CreateInputState(InputStateBuilder* builder) override;
 
         Serial GetCompletedCommandSerial() const final override;
         Serial GetLastSubmittedCommandSerial() const final override;
diff --git a/src/dawn_native/opengl/CommandBufferGL.cpp b/src/dawn_native/opengl/CommandBufferGL.cpp
index 9aaf17f..cbce20f 100644
--- a/src/dawn_native/opengl/CommandBufferGL.cpp
+++ b/src/dawn_native/opengl/CommandBufferGL.cpp
@@ -21,7 +21,6 @@
 #include "dawn_native/opengl/ComputePipelineGL.h"
 #include "dawn_native/opengl/DeviceGL.h"
 #include "dawn_native/opengl/Forward.h"
-#include "dawn_native/opengl/InputStateGL.h"
 #include "dawn_native/opengl/PersistentPipelineStateGL.h"
 #include "dawn_native/opengl/PipelineLayoutGL.h"
 #include "dawn_native/opengl/RenderPipelineGL.h"
@@ -212,15 +211,14 @@
             }
 
             void OnSetPipeline(RenderPipelineBase* pipeline) {
-                InputStateBase* inputState = pipeline->GetInputState();
-                if (mLastInputState == inputState) {
+                if (mLastPipeline == pipeline) {
                     return;
                 }
 
                 mIndexBufferDirty = true;
-                mDirtyVertexBuffers |= inputState->GetInputsSetMask();
+                mDirtyVertexBuffers |= pipeline->GetInputsSetMask();
 
-                mLastInputState = ToBackend(inputState);
+                mLastPipeline = pipeline;
             }
 
             void Apply() {
@@ -230,15 +228,15 @@
                 }
 
                 for (uint32_t slot :
-                     IterateBitSet(mDirtyVertexBuffers & mLastInputState->GetInputsSetMask())) {
+                     IterateBitSet(mDirtyVertexBuffers & mLastPipeline->GetInputsSetMask())) {
                     for (uint32_t location :
-                         IterateBitSet(mLastInputState->GetAttributesUsingInput(slot))) {
-                        auto attribute = mLastInputState->GetAttribute(location);
+                         IterateBitSet(mLastPipeline->GetAttributesUsingInput(slot))) {
+                        auto attribute = mLastPipeline->GetAttribute(location);
 
                         GLuint buffer = mVertexBuffers[slot]->GetHandle();
                         uint32_t offset = mVertexBufferOffsets[slot];
 
-                        auto input = mLastInputState->GetInput(slot);
+                        auto input = mLastPipeline->GetInput(slot);
                         auto components = VertexFormatNumComponents(attribute.format);
                         auto formatType = VertexFormatType(attribute.format);
 
@@ -262,7 +260,7 @@
             std::array<Buffer*, kMaxVertexInputs> mVertexBuffers;
             std::array<uint32_t, kMaxVertexInputs> mVertexBufferOffsets;
 
-            InputState* mLastInputState = nullptr;
+            RenderPipelineBase* mLastPipeline = nullptr;
         };
 
         // Handles SetBindGroup commands with the specifics of translating to OpenGL texture and
diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp
index f2a3edf..2d2b6e1 100644
--- a/src/dawn_native/opengl/DeviceGL.cpp
+++ b/src/dawn_native/opengl/DeviceGL.cpp
@@ -21,7 +21,6 @@
 #include "dawn_native/opengl/BufferGL.h"
 #include "dawn_native/opengl/CommandBufferGL.h"
 #include "dawn_native/opengl/ComputePipelineGL.h"
-#include "dawn_native/opengl/InputStateGL.h"
 #include "dawn_native/opengl/PipelineLayoutGL.h"
 #include "dawn_native/opengl/QueueGL.h"
 #include "dawn_native/opengl/RenderPipelineGL.h"
@@ -67,9 +66,6 @@
         const ComputePipelineDescriptor* descriptor) {
         return new ComputePipeline(this, descriptor);
     }
-    InputStateBase* Device::CreateInputState(InputStateBuilder* builder) {
-        return new InputState(builder);
-    }
     ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
         const PipelineLayoutDescriptor* descriptor) {
         return new PipelineLayout(this, descriptor);
diff --git a/src/dawn_native/opengl/DeviceGL.h b/src/dawn_native/opengl/DeviceGL.h
index 43fdf65..5cb1262 100644
--- a/src/dawn_native/opengl/DeviceGL.h
+++ b/src/dawn_native/opengl/DeviceGL.h
@@ -41,7 +41,6 @@
 
         // Dawn API
         CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) override;
-        InputStateBase* CreateInputState(InputStateBuilder* builder) override;
 
         Serial GetCompletedCommandSerial() const final override;
         Serial GetLastSubmittedCommandSerial() const final override;
diff --git a/src/dawn_native/opengl/Forward.h b/src/dawn_native/opengl/Forward.h
index d4bfac6..6542ff9 100644
--- a/src/dawn_native/opengl/Forward.h
+++ b/src/dawn_native/opengl/Forward.h
@@ -32,7 +32,6 @@
     class CommandBuffer;
     class ComputePipeline;
     class Device;
-    class InputState;
     class PersistentPipelineState;
     class PipelineLayout;
     class Queue;
@@ -51,7 +50,6 @@
         using CommandBufferType = CommandBuffer;
         using ComputePipelineType = ComputePipeline;
         using DeviceType = Device;
-        using InputStateType = InputState;
         using PipelineLayoutType = PipelineLayout;
         using QueueType = Queue;
         using RenderPipelineType = RenderPipeline;
diff --git a/src/dawn_native/opengl/InputStateGL.cpp b/src/dawn_native/opengl/InputStateGL.cpp
deleted file mode 100644
index f9052e2..0000000
--- a/src/dawn_native/opengl/InputStateGL.cpp
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dawn_native/opengl/InputStateGL.h"
-
-#include "common/Assert.h"
-
-namespace dawn_native { namespace opengl {
-
-    InputState::InputState(InputStateBuilder* builder) : InputStateBase(builder) {
-        glGenVertexArrays(1, &mVertexArrayObject);
-        glBindVertexArray(mVertexArrayObject);
-        auto& attributesSetMask = GetAttributesSetMask();
-        for (uint32_t location = 0; location < attributesSetMask.size(); ++location) {
-            if (!attributesSetMask[location]) {
-                continue;
-            }
-            auto attribute = GetAttribute(location);
-            glEnableVertexAttribArray(location);
-
-            attributesUsingInput[attribute.inputSlot][location] = true;
-            auto input = GetInput(attribute.inputSlot);
-
-            if (input.stride == 0) {
-                // Emulate a stride of zero (constant vertex attribute) by
-                // setting the attribute instance divisor to a huge number.
-                glVertexAttribDivisor(location, 0xffffffff);
-            } else {
-                switch (input.stepMode) {
-                    case dawn::InputStepMode::Vertex:
-                        break;
-                    case dawn::InputStepMode::Instance:
-                        glVertexAttribDivisor(location, 1);
-                        break;
-                    default:
-                        UNREACHABLE();
-                }
-            }
-        }
-    }
-
-    std::bitset<kMaxVertexAttributes> InputState::GetAttributesUsingInput(uint32_t slot) const {
-        return attributesUsingInput[slot];
-    }
-
-    GLuint InputState::GetVAO() {
-        return mVertexArrayObject;
-    }
-
-}}  // namespace dawn_native::opengl
diff --git a/src/dawn_native/opengl/InputStateGL.h b/src/dawn_native/opengl/InputStateGL.h
deleted file mode 100644
index dbdd84b..0000000
--- a/src/dawn_native/opengl/InputStateGL.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2017 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DAWNNATIVE_OPENGL_INPUTSTATEGL_H_
-#define DAWNNATIVE_OPENGL_INPUTSTATEGL_H_
-
-#include "dawn_native/InputState.h"
-
-#include "glad/glad.h"
-
-namespace dawn_native { namespace opengl {
-
-    class Device;
-
-    class InputState : public InputStateBase {
-      public:
-        InputState(InputStateBuilder* builder);
-
-        std::bitset<kMaxVertexAttributes> GetAttributesUsingInput(uint32_t slot) const;
-        GLuint GetVAO();
-
-      private:
-        GLuint mVertexArrayObject;
-        std::array<std::bitset<kMaxVertexAttributes>, kMaxVertexInputs> attributesUsingInput;
-    };
-
-}}  // namespace dawn_native::opengl
-
-#endif  // DAWNNATIVE_OPENGL_INPUTSTATEGL_H_
diff --git a/src/dawn_native/opengl/RenderPipelineGL.cpp b/src/dawn_native/opengl/RenderPipelineGL.cpp
index 830fa40..991b022 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.cpp
+++ b/src/dawn_native/opengl/RenderPipelineGL.cpp
@@ -16,13 +16,13 @@
 
 #include "dawn_native/opengl/DeviceGL.h"
 #include "dawn_native/opengl/Forward.h"
-#include "dawn_native/opengl/InputStateGL.h"
 #include "dawn_native/opengl/PersistentPipelineStateGL.h"
 #include "dawn_native/opengl/UtilsGL.h"
 
 namespace dawn_native { namespace opengl {
 
     namespace {
+
         GLenum GLPrimitiveTopology(dawn::PrimitiveTopology primitiveTopology) {
             switch (primitiveTopology) {
                 case dawn::PrimitiveTopology::PointList:
@@ -175,23 +175,62 @@
 
     RenderPipeline::RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor)
         : RenderPipelineBase(device, descriptor),
+          mVertexArrayObject(0),
           mGlPrimitiveTopology(GLPrimitiveTopology(GetPrimitiveTopology())) {
         PerStage<const ShaderModule*> modules(nullptr);
         modules[dawn::ShaderStage::Vertex] = ToBackend(descriptor->vertexStage->module);
         modules[dawn::ShaderStage::Fragment] = ToBackend(descriptor->fragmentStage->module);
 
         PipelineGL::Initialize(ToBackend(GetLayout()), modules);
+        CreateVAOForInputState(descriptor->inputState);
+    }
+
+    RenderPipeline::~RenderPipeline() {
+        glDeleteVertexArrays(1, &mVertexArrayObject);
+        glBindVertexArray(0);
     }
 
     GLenum RenderPipeline::GetGLPrimitiveTopology() const {
         return mGlPrimitiveTopology;
     }
 
+    void RenderPipeline::CreateVAOForInputState(const InputStateDescriptor* inputState) {
+        glGenVertexArrays(1, &mVertexArrayObject);
+        glBindVertexArray(mVertexArrayObject);
+        auto& attributesSetMask = GetAttributesSetMask();
+        for (uint32_t location = 0; location < attributesSetMask.size(); ++location) {
+            if (!attributesSetMask[location]) {
+                continue;
+            }
+            auto attribute = GetAttribute(location);
+            glEnableVertexAttribArray(location);
+
+            attributesUsingInput[attribute.inputSlot][location] = true;
+            auto input = GetInput(attribute.inputSlot);
+
+            if (input.stride == 0) {
+                // Emulate a stride of zero (constant vertex attribute) by
+                // setting the attribute instance divisor to a huge number.
+                glVertexAttribDivisor(location, 0xffffffff);
+            } else {
+                switch (input.stepMode) {
+                    case dawn::InputStepMode::Vertex:
+                        break;
+                    case dawn::InputStepMode::Instance:
+                        glVertexAttribDivisor(location, 1);
+                        break;
+                    default:
+                        UNREACHABLE();
+                }
+            }
+        }
+    }
+
     void RenderPipeline::ApplyNow(PersistentPipelineState& persistentPipelineState) {
         PipelineGL::ApplyNow();
 
-        auto inputState = ToBackend(GetInputState());
-        glBindVertexArray(inputState->GetVAO());
+        ASSERT(mVertexArrayObject);
+        glBindVertexArray(mVertexArrayObject);
 
         ApplyDepthStencilState(GetDepthStencilStateDescriptor(), &persistentPipelineState);
 
diff --git a/src/dawn_native/opengl/RenderPipelineGL.h b/src/dawn_native/opengl/RenderPipelineGL.h
index 28458e0..dc51f6f 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.h
+++ b/src/dawn_native/opengl/RenderPipelineGL.h
@@ -31,12 +31,17 @@
     class RenderPipeline : public RenderPipelineBase, public PipelineGL {
       public:
         RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor);
+        ~RenderPipeline();
 
         GLenum GetGLPrimitiveTopology() const;
 
         void ApplyNow(PersistentPipelineState& persistentPipelineState);
 
       private:
+        void CreateVAOForInputState(const InputStateDescriptor* inputState);
+
+        // TODO(yunchao.he@intel.com): vao need to be deduplicated between pipelines.
+        GLuint mVertexArrayObject;
         GLenum mGlPrimitiveTopology;
     };
 
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index aad6e2f..f2ab5a5 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -27,7 +27,6 @@
 #include "dawn_native/vulkan/CommandBufferVk.h"
 #include "dawn_native/vulkan/ComputePipelineVk.h"
 #include "dawn_native/vulkan/FencedDeleter.h"
-#include "dawn_native/vulkan/InputStateVk.h"
 #include "dawn_native/vulkan/PipelineLayoutVk.h"
 #include "dawn_native/vulkan/QueueVk.h"
 #include "dawn_native/vulkan/RenderPassCache.h"
@@ -152,9 +151,6 @@
         const ComputePipelineDescriptor* descriptor) {
         return new ComputePipeline(this, descriptor);
     }
-    InputStateBase* Device::CreateInputState(InputStateBuilder* builder) {
-        return new InputState(builder);
-    }
     ResultOrError<PipelineLayoutBase*> Device::CreatePipelineLayoutImpl(
         const PipelineLayoutDescriptor* descriptor) {
         return new PipelineLayout(this, descriptor);
diff --git a/src/dawn_native/vulkan/DeviceVk.h b/src/dawn_native/vulkan/DeviceVk.h
index 95cdc54..c59883e 100644
--- a/src/dawn_native/vulkan/DeviceVk.h
+++ b/src/dawn_native/vulkan/DeviceVk.h
@@ -65,7 +65,6 @@
 
         // Dawn API
         CommandBufferBase* CreateCommandBuffer(CommandEncoderBase* encoder) override;
-        InputStateBase* CreateInputState(InputStateBuilder* builder) override;
 
         Serial GetCompletedCommandSerial() const final override;
         Serial GetLastSubmittedCommandSerial() const final override;
diff --git a/src/dawn_native/vulkan/Forward.h b/src/dawn_native/vulkan/Forward.h
index 99cc23a..344678a 100644
--- a/src/dawn_native/vulkan/Forward.h
+++ b/src/dawn_native/vulkan/Forward.h
@@ -26,7 +26,6 @@
     class CommandBuffer;
     class ComputePipeline;
     class Device;
-    class InputState;
     class PipelineLayout;
     class Queue;
     class RenderPipeline;
@@ -45,7 +44,6 @@
         using CommandBufferType = CommandBuffer;
         using ComputePipelineType = ComputePipeline;
         using DeviceType = Device;
-        using InputStateType = InputState;
         using PipelineLayoutType = PipelineLayout;
         using QueueType = Queue;
         using RenderPipelineType = RenderPipeline;
diff --git a/src/dawn_native/vulkan/InputStateVk.cpp b/src/dawn_native/vulkan/InputStateVk.cpp
deleted file mode 100644
index 3e704ae..0000000
--- a/src/dawn_native/vulkan/InputStateVk.cpp
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2018 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dawn_native/vulkan/InputStateVk.h"
-
-#include "common/BitSetIterator.h"
-
-namespace dawn_native { namespace vulkan {
-
-    namespace {
-
-        VkVertexInputRate VulkanInputRate(dawn::InputStepMode stepMode) {
-            switch (stepMode) {
-                case dawn::InputStepMode::Vertex:
-                    return VK_VERTEX_INPUT_RATE_VERTEX;
-                case dawn::InputStepMode::Instance:
-                    return VK_VERTEX_INPUT_RATE_INSTANCE;
-                default:
-                    UNREACHABLE();
-            }
-        }
-
-        VkFormat VulkanVertexFormat(dawn::VertexFormat format) {
-            switch (format) {
-                case dawn::VertexFormat::UChar2:
-                    return VK_FORMAT_R8G8_UINT;
-                case dawn::VertexFormat::UChar4:
-                    return VK_FORMAT_R8G8B8A8_UINT;
-                case dawn::VertexFormat::Char2:
-                    return VK_FORMAT_R8G8_SINT;
-                case dawn::VertexFormat::Char4:
-                    return VK_FORMAT_R8G8B8A8_SINT;
-                case dawn::VertexFormat::UChar2Norm:
-                    return VK_FORMAT_R8G8_UNORM;
-                case dawn::VertexFormat::UChar4Norm:
-                    return VK_FORMAT_R8G8B8A8_UNORM;
-                case dawn::VertexFormat::Char2Norm:
-                    return VK_FORMAT_R8G8_SNORM;
-                case dawn::VertexFormat::Char4Norm:
-                    return VK_FORMAT_R8G8B8A8_SNORM;
-                case dawn::VertexFormat::UShort2:
-                    return VK_FORMAT_R16G16_UINT;
-                case dawn::VertexFormat::UShort4:
-                    return VK_FORMAT_R16G16B16A16_UINT;
-                case dawn::VertexFormat::Short2:
-                    return VK_FORMAT_R16G16_SINT;
-                case dawn::VertexFormat::Short4:
-                    return VK_FORMAT_R16G16B16A16_SINT;
-                case dawn::VertexFormat::UShort2Norm:
-                    return VK_FORMAT_R16G16_UNORM;
-                case dawn::VertexFormat::UShort4Norm:
-                    return VK_FORMAT_R16G16B16A16_UNORM;
-                case dawn::VertexFormat::Short2Norm:
-                    return VK_FORMAT_R16G16_SNORM;
-                case dawn::VertexFormat::Short4Norm:
-                    return VK_FORMAT_R16G16B16A16_SNORM;
-                case dawn::VertexFormat::Half2:
-                    return VK_FORMAT_R16G16_SFLOAT;
-                case dawn::VertexFormat::Half4:
-                    return VK_FORMAT_R16G16B16A16_SFLOAT;
-                case dawn::VertexFormat::Float:
-                    return VK_FORMAT_R32_SFLOAT;
-                case dawn::VertexFormat::Float2:
-                    return VK_FORMAT_R32G32_SFLOAT;
-                case dawn::VertexFormat::Float3:
-                    return VK_FORMAT_R32G32B32_SFLOAT;
-                case dawn::VertexFormat::Float4:
-                    return VK_FORMAT_R32G32B32A32_SFLOAT;
-                case dawn::VertexFormat::UInt:
-                    return VK_FORMAT_R32_UINT;
-                case dawn::VertexFormat::UInt2:
-                    return VK_FORMAT_R32G32_UINT;
-                case dawn::VertexFormat::UInt3:
-                    return VK_FORMAT_R32G32B32_UINT;
-                case dawn::VertexFormat::UInt4:
-                    return VK_FORMAT_R32G32B32A32_UINT;
-                case dawn::VertexFormat::Int:
-                    return VK_FORMAT_R32_SINT;
-                case dawn::VertexFormat::Int2:
-                    return VK_FORMAT_R32G32_SINT;
-                case dawn::VertexFormat::Int3:
-                    return VK_FORMAT_R32G32B32_SINT;
-                case dawn::VertexFormat::Int4:
-                    return VK_FORMAT_R32G32B32A32_SINT;
-                default:
-                    UNREACHABLE();
-            }
-        }
-
-    }  // anonymous namespace
-
-    InputState::InputState(InputStateBuilder* builder) : InputStateBase(builder) {
-        // Fill in the "binding info" that will be chained in the create info
-        uint32_t bindingCount = 0;
-        for (uint32_t i : IterateBitSet(GetInputsSetMask())) {
-            const auto& bindingInfo = GetInput(i);
-
-            auto& bindingDesc = mBindings[bindingCount];
-            bindingDesc.binding = i;
-            bindingDesc.stride = bindingInfo.stride;
-            bindingDesc.inputRate = VulkanInputRate(bindingInfo.stepMode);
-
-            bindingCount++;
-        }
-
-        // Fill in the "attribute info" that will be chained in the create info
-        uint32_t attributeCount = 0;
-        for (uint32_t i : IterateBitSet(GetAttributesSetMask())) {
-            const auto& attributeInfo = GetAttribute(i);
-
-            auto& attributeDesc = mAttributes[attributeCount];
-            attributeDesc.location = i;
-            attributeDesc.binding = attributeInfo.inputSlot;
-            attributeDesc.format = VulkanVertexFormat(attributeInfo.format);
-            attributeDesc.offset = attributeInfo.offset;
-
-            attributeCount++;
-        }
-
-        // Build the create info
-        mCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
-        mCreateInfo.pNext = nullptr;
-        mCreateInfo.flags = 0;
-        mCreateInfo.vertexBindingDescriptionCount = bindingCount;
-        mCreateInfo.pVertexBindingDescriptions = mBindings.data();
-        mCreateInfo.vertexAttributeDescriptionCount = attributeCount;
-        mCreateInfo.pVertexAttributeDescriptions = mAttributes.data();
-    }
-
-    const VkPipelineVertexInputStateCreateInfo* InputState::GetCreateInfo() const {
-        return &mCreateInfo;
-    }
-
-}}  // namespace dawn_native::vulkan
diff --git a/src/dawn_native/vulkan/InputStateVk.h b/src/dawn_native/vulkan/InputStateVk.h
deleted file mode 100644
index b44c08a..0000000
--- a/src/dawn_native/vulkan/InputStateVk.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2018 The Dawn Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DAWNNATIVE_VULKAN_INPUTSTATEVK_H_
-#define DAWNNATIVE_VULKAN_INPUTSTATEVK_H_
-
-#include "dawn_native/InputState.h"
-
-#include "common/vulkan_platform.h"
-
-namespace dawn_native { namespace vulkan {
-
-    class Device;
-
-    // Pre-computes the input state configuration to give to a graphics pipeline create info.
-    class InputState : public InputStateBase {
-      public:
-        InputState(InputStateBuilder* builder);
-
-        const VkPipelineVertexInputStateCreateInfo* GetCreateInfo() const;
-
-      private:
-        VkPipelineVertexInputStateCreateInfo mCreateInfo;
-        std::array<VkVertexInputBindingDescription, kMaxVertexInputs> mBindings;
-        std::array<VkVertexInputAttributeDescription, kMaxVertexAttributes> mAttributes;
-    };
-
-}}  // namespace dawn_native::vulkan
-
-#endif  // DAWNNATIVE_VULKAN_INPUTSTATEVK_H_
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index 81b73b5..28ef9ad 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -16,7 +16,6 @@
 
 #include "dawn_native/vulkan/DeviceVk.h"
 #include "dawn_native/vulkan/FencedDeleter.h"
-#include "dawn_native/vulkan/InputStateVk.h"
 #include "dawn_native/vulkan/PipelineLayoutVk.h"
 #include "dawn_native/vulkan/RenderPassCache.h"
 #include "dawn_native/vulkan/ShaderModuleVk.h"
@@ -26,6 +25,84 @@
 
     namespace {
 
+        VkVertexInputRate VulkanInputRate(dawn::InputStepMode stepMode) {
+            switch (stepMode) {
+                case dawn::InputStepMode::Vertex:
+                    return VK_VERTEX_INPUT_RATE_VERTEX;
+                case dawn::InputStepMode::Instance:
+                    return VK_VERTEX_INPUT_RATE_INSTANCE;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
+        VkFormat VulkanVertexFormat(dawn::VertexFormat format) {
+            switch (format) {
+                case dawn::VertexFormat::UChar2:
+                    return VK_FORMAT_R8G8_UINT;
+                case dawn::VertexFormat::UChar4:
+                    return VK_FORMAT_R8G8B8A8_UINT;
+                case dawn::VertexFormat::Char2:
+                    return VK_FORMAT_R8G8_SINT;
+                case dawn::VertexFormat::Char4:
+                    return VK_FORMAT_R8G8B8A8_SINT;
+                case dawn::VertexFormat::UChar2Norm:
+                    return VK_FORMAT_R8G8_UNORM;
+                case dawn::VertexFormat::UChar4Norm:
+                    return VK_FORMAT_R8G8B8A8_UNORM;
+                case dawn::VertexFormat::Char2Norm:
+                    return VK_FORMAT_R8G8_SNORM;
+                case dawn::VertexFormat::Char4Norm:
+                    return VK_FORMAT_R8G8B8A8_SNORM;
+                case dawn::VertexFormat::UShort2:
+                    return VK_FORMAT_R16G16_UINT;
+                case dawn::VertexFormat::UShort4:
+                    return VK_FORMAT_R16G16B16A16_UINT;
+                case dawn::VertexFormat::Short2:
+                    return VK_FORMAT_R16G16_SINT;
+                case dawn::VertexFormat::Short4:
+                    return VK_FORMAT_R16G16B16A16_SINT;
+                case dawn::VertexFormat::UShort2Norm:
+                    return VK_FORMAT_R16G16_UNORM;
+                case dawn::VertexFormat::UShort4Norm:
+                    return VK_FORMAT_R16G16B16A16_UNORM;
+                case dawn::VertexFormat::Short2Norm:
+                    return VK_FORMAT_R16G16_SNORM;
+                case dawn::VertexFormat::Short4Norm:
+                    return VK_FORMAT_R16G16B16A16_SNORM;
+                case dawn::VertexFormat::Half2:
+                    return VK_FORMAT_R16G16_SFLOAT;
+                case dawn::VertexFormat::Half4:
+                    return VK_FORMAT_R16G16B16A16_SFLOAT;
+                case dawn::VertexFormat::Float:
+                    return VK_FORMAT_R32_SFLOAT;
+                case dawn::VertexFormat::Float2:
+                    return VK_FORMAT_R32G32_SFLOAT;
+                case dawn::VertexFormat::Float3:
+                    return VK_FORMAT_R32G32B32_SFLOAT;
+                case dawn::VertexFormat::Float4:
+                    return VK_FORMAT_R32G32B32A32_SFLOAT;
+                case dawn::VertexFormat::UInt:
+                    return VK_FORMAT_R32_UINT;
+                case dawn::VertexFormat::UInt2:
+                    return VK_FORMAT_R32G32_UINT;
+                case dawn::VertexFormat::UInt3:
+                    return VK_FORMAT_R32G32B32_UINT;
+                case dawn::VertexFormat::UInt4:
+                    return VK_FORMAT_R32G32B32A32_UINT;
+                case dawn::VertexFormat::Int:
+                    return VK_FORMAT_R32_SINT;
+                case dawn::VertexFormat::Int2:
+                    return VK_FORMAT_R32G32_SINT;
+                case dawn::VertexFormat::Int3:
+                    return VK_FORMAT_R32G32B32_SINT;
+                case dawn::VertexFormat::Int4:
+                    return VK_FORMAT_R32G32B32A32_SINT;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
         VkPrimitiveTopology VulkanPrimitiveTopology(dawn::PrimitiveTopology topology) {
             switch (topology) {
                 case dawn::PrimitiveTopology::PointList:
@@ -218,6 +295,12 @@
             shaderStages[1].pName = descriptor->fragmentStage->entryPoint;
         }
 
+        std::array<VkVertexInputBindingDescription, kMaxVertexInputs> mBindings;
+        std::array<VkVertexInputAttributeDescription, kMaxVertexAttributes> mAttributes;
+        const InputStateDescriptor* inputState = GetInputStateDescriptor();
+        VkPipelineVertexInputStateCreateInfo inputStateCreateInfo =
+            ComputeInputStateDesc(inputState, &mBindings, &mAttributes);
+
         VkPipelineInputAssemblyStateCreateInfo inputAssembly;
         inputAssembly.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
         inputAssembly.pNext = nullptr;
@@ -341,7 +424,7 @@
         createInfo.flags = 0;
         createInfo.stageCount = 2;
         createInfo.pStages = shaderStages;
-        createInfo.pVertexInputState = ToBackend(GetInputState())->GetCreateInfo();
+        createInfo.pVertexInputState = &inputStateCreateInfo;
         createInfo.pInputAssemblyState = &inputAssembly;
         createInfo.pTessellationState = nullptr;
         createInfo.pViewportState = &viewport;
@@ -362,6 +445,49 @@
         }
     }
 
+    VkPipelineVertexInputStateCreateInfo RenderPipeline::ComputeInputStateDesc(
+        const InputStateDescriptor* inputState,
+        std::array<VkVertexInputBindingDescription, kMaxVertexInputs>* mBindings,
+        std::array<VkVertexInputAttributeDescription, kMaxVertexAttributes>* mAttributes) {
+        // Fill in the "binding info" that will be chained in the create info
+        uint32_t bindingCount = 0;
+        for (uint32_t i : IterateBitSet(GetInputsSetMask())) {
+            const auto& bindingInfo = GetInput(i);
+
+            auto& bindingDesc = (*mBindings)[bindingCount];
+            bindingDesc.binding = i;
+            bindingDesc.stride = bindingInfo.stride;
+            bindingDesc.inputRate = VulkanInputRate(bindingInfo.stepMode);
+
+            bindingCount++;
+        }
+
+        // Fill in the "attribute info" that will be chained in the create info
+        uint32_t attributeCount = 0;
+        for (uint32_t i : IterateBitSet(GetAttributesSetMask())) {
+            const auto& attributeInfo = GetAttribute(i);
+
+            auto& attributeDesc = (*mAttributes)[attributeCount];
+            attributeDesc.location = i;
+            attributeDesc.binding = attributeInfo.inputSlot;
+            attributeDesc.format = VulkanVertexFormat(attributeInfo.format);
+            attributeDesc.offset = attributeInfo.offset;
+
+            attributeCount++;
+        }
+
+        // Build the create info
+        VkPipelineVertexInputStateCreateInfo mCreateInfo;
+        mCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
+        mCreateInfo.pNext = nullptr;
+        mCreateInfo.flags = 0;
+        mCreateInfo.vertexBindingDescriptionCount = bindingCount;
+        mCreateInfo.pVertexBindingDescriptions = &(*mBindings)[0];
+        mCreateInfo.vertexAttributeDescriptionCount = attributeCount;
+        mCreateInfo.pVertexAttributeDescriptions = &(*mAttributes)[0];
+        return mCreateInfo;
+    }
+
     RenderPipeline::~RenderPipeline() {
         if (mHandle != VK_NULL_HANDLE) {
             ToBackend(GetDevice())->GetFencedDeleter()->DeleteWhenUnused(mHandle);
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.h b/src/dawn_native/vulkan/RenderPipelineVk.h
index 744772d..5d58fa7 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.h
+++ b/src/dawn_native/vulkan/RenderPipelineVk.h
@@ -31,6 +31,11 @@
         VkPipeline GetHandle() const;
 
       private:
+        VkPipelineVertexInputStateCreateInfo ComputeInputStateDesc(
+            const InputStateDescriptor* inputState,
+            std::array<VkVertexInputBindingDescription, kMaxVertexInputs>* mBindings,
+            std::array<VkVertexInputAttributeDescription, kMaxVertexAttributes>* mAttributes);
+
         VkPipeline mHandle = VK_NULL_HANDLE;
     };
 
diff --git a/src/tests/end2end/DestroyBufferTests.cpp b/src/tests/end2end/DestroyBufferTests.cpp
index 6c2560a..6012536 100644
--- a/src/tests/end2end/DestroyBufferTests.cpp
+++ b/src/tests/end2end/DestroyBufferTests.cpp
@@ -36,9 +36,6 @@
         attribute.offset = 0;

         attribute.format = dawn::VertexFormat::Float4;

 

-        dawn::InputState inputState =

-            device.CreateInputStateBuilder().SetInput(&input).SetAttribute(&attribute).GetResult();

-

         dawn::ShaderModule vsModule =

             utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(

               #version 450

@@ -60,7 +57,10 @@
         descriptor.cFragmentStage.module = fsModule;

         descriptor.primitiveTopology = dawn::PrimitiveTopology::TriangleStrip;

         descriptor.indexFormat = dawn::IndexFormat::Uint32;

-        descriptor.inputState = inputState;

+        descriptor.cInputState.numInputs = 1;

+        descriptor.cInputState.inputs = &input;

+        descriptor.cInputState.numAttributes = 1;

+        descriptor.cInputState.attributes = &attribute;

         descriptor.cColorStates[0]->format = renderPass.colorFormat;

 

         pipeline = device.CreateRenderPipeline(&descriptor);

@@ -135,4 +135,4 @@
     EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, 1, 3);

 }

 

-DAWN_INSTANTIATE_TEST(DestroyBufferTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
\ No newline at end of file
+DAWN_INSTANTIATE_TEST(DestroyBufferTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);

diff --git a/src/tests/end2end/DrawIndexedTests.cpp b/src/tests/end2end/DrawIndexedTests.cpp
index 8548d99..7eca1e1 100644
--- a/src/tests/end2end/DrawIndexedTests.cpp
+++ b/src/tests/end2end/DrawIndexedTests.cpp
@@ -37,18 +37,13 @@
             attribute.offset = 0;
             attribute.format = dawn::VertexFormat::Float4;
 
-            dawn::InputState inputState = device.CreateInputStateBuilder()
-                                              .SetInput(&input)
-                                              .SetAttribute(&attribute)
-                                              .GetResult();
-
-            dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+            dawn::ShaderModule vsModule =
+                utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
                 #version 450
                 layout(location = 0) in vec4 pos;
                 void main() {
                     gl_Position = pos;
-                })"
-            );
+                })");
 
             dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
                 #version 450
@@ -63,7 +58,10 @@
             descriptor.cFragmentStage.module = fsModule;
             descriptor.primitiveTopology = dawn::PrimitiveTopology::TriangleStrip;
             descriptor.indexFormat = dawn::IndexFormat::Uint32;
-            descriptor.inputState = inputState;
+            descriptor.cInputState.numInputs = 1;
+            descriptor.cInputState.inputs = &input;
+            descriptor.cInputState.numAttributes = 1;
+            descriptor.cInputState.attributes = &attribute;
             descriptor.cColorStates[0]->format = renderPass.colorFormat;
 
             pipeline = device.CreateRenderPipeline(&descriptor);
diff --git a/src/tests/end2end/DrawTests.cpp b/src/tests/end2end/DrawTests.cpp
index 073ebd3..f7292ea 100644
--- a/src/tests/end2end/DrawTests.cpp
+++ b/src/tests/end2end/DrawTests.cpp
@@ -37,9 +37,6 @@
         attribute.offset = 0;
         attribute.format = dawn::VertexFormat::Float4;
 
-        dawn::InputState inputState =
-            device.CreateInputStateBuilder().SetInput(&input).SetAttribute(&attribute).GetResult();
-
         dawn::ShaderModule vsModule =
             utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
                 #version 450
@@ -61,7 +58,10 @@
         descriptor.cFragmentStage.module = fsModule;
         descriptor.primitiveTopology = dawn::PrimitiveTopology::TriangleStrip;
         descriptor.indexFormat = dawn::IndexFormat::Uint32;
-        descriptor.inputState = inputState;
+        descriptor.cInputState.numInputs = 1;
+        descriptor.cInputState.inputs = &input;
+        descriptor.cInputState.numAttributes = 1;
+        descriptor.cInputState.attributes = &attribute;
         descriptor.cColorStates[0]->format = renderPass.colorFormat;
 
         pipeline = device.CreateRenderPipeline(&descriptor);
diff --git a/src/tests/end2end/IndexFormatTests.cpp b/src/tests/end2end/IndexFormatTests.cpp
index a62ee6e..3ba2325 100644
--- a/src/tests/end2end/IndexFormatTests.cpp
+++ b/src/tests/end2end/IndexFormatTests.cpp
@@ -42,11 +42,6 @@
             attribute.offset = 0;
             attribute.format = dawn::VertexFormat::Float4;
 
-            dawn::InputState inputState = device.CreateInputStateBuilder()
-                                              .SetInput(&input)
-                                              .SetAttribute(&attribute)
-                                              .GetResult();
-
             dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
                 #version 450
                 layout(location = 0) in vec4 pos;
@@ -68,7 +63,10 @@
             descriptor.cFragmentStage.module = fsModule;
             descriptor.primitiveTopology = dawn::PrimitiveTopology::TriangleStrip;
             descriptor.indexFormat = format;
-            descriptor.inputState = inputState;
+            descriptor.cInputState.numInputs = 1;
+            descriptor.cInputState.inputs = &input;
+            descriptor.cInputState.numAttributes = 1;
+            descriptor.cInputState.attributes = &attribute;
             descriptor.cColorStates[0]->format = renderPass.colorFormat;
 
             return device.CreateRenderPipeline(&descriptor);
diff --git a/src/tests/end2end/InputStateTests.cpp b/src/tests/end2end/InputStateTests.cpp
index c205f5f..78c14e3 100644
--- a/src/tests/end2end/InputStateTests.cpp
+++ b/src/tests/end2end/InputStateTests.cpp
@@ -64,7 +64,9 @@
             VertexFormat format;
             InputStepMode step;
         };
-        dawn::RenderPipeline MakeTestPipeline(const dawn::InputState& inputState, int multiplier, std::vector<ShaderTestSpec> testSpec) {
+        dawn::RenderPipeline MakeTestPipeline(const dawn::InputStateDescriptor& inputState,
+                                              int multiplier,
+                                              std::vector<ShaderTestSpec> testSpec) {
             std::ostringstream vs;
             vs << "#version 450\n";
 
@@ -124,7 +126,7 @@
             utils::ComboRenderPipelineDescriptor descriptor(device);
             descriptor.cVertexStage.module = vsModule;
             descriptor.cFragmentStage.module = fsModule;
-            descriptor.inputState = inputState;
+            descriptor.inputState = &inputState;
             descriptor.cColorStates[0]->format = renderPass.colorFormat;
 
             return device.CreateRenderPipeline(&descriptor);
@@ -141,28 +143,30 @@
             uint32_t offset;
             VertexFormat format;
         };
-        dawn::InputState MakeInputState(std::vector<InputSpec> inputs, std::vector<AttributeSpec> attributes) {
-            dawn::InputStateBuilder builder = device.CreateInputStateBuilder();
-
+        dawn::InputStateDescriptor MakeInputState(std::vector<InputSpec> inputs,
+                                                  std::vector<AttributeSpec> attributes) {
+            dawn::InputStateDescriptor inputState;
+            uint32_t numInputs = 0;
             for (const auto& input : inputs) {
-                dawn::VertexInputDescriptor descriptor;
-                descriptor.inputSlot = input.slot;
-                descriptor.stride = input.stride;
-                descriptor.stepMode = input.step;
-                builder.SetInput(&descriptor);
+                vertexInputs[numInputs].inputSlot = input.slot;
+                vertexInputs[numInputs].stride = input.stride;
+                vertexInputs[numInputs].stepMode = input.step;
+                numInputs++;
             }
 
+            uint32_t numAttributes = 0;
             for (const auto& attribute : attributes) {
-                dawn::VertexAttributeDescriptor descriptor;
-                descriptor.shaderLocation = attribute.location;
-                descriptor.inputSlot = attribute.slot;
-                descriptor.offset = attribute.offset;
-                descriptor.format = attribute.format;
-
-                builder.SetAttribute(&descriptor);
+                vertexAttributes[numAttributes].shaderLocation = attribute.location;
+                vertexAttributes[numAttributes].inputSlot = attribute.slot;
+                vertexAttributes[numAttributes].offset = attribute.offset;
+                vertexAttributes[numAttributes].format = attribute.format;
+                numAttributes++;
             }
-
-            return builder.GetResult();
+            inputState.numInputs = numInputs;
+            inputState.inputs = vertexInputs;
+            inputState.numAttributes = numAttributes;
+            inputState.attributes = vertexAttributes;
+            return inputState;
         }
 
         template<typename T>
@@ -214,16 +218,14 @@
         }
 
         utils::BasicRenderPass renderPass;
+        dawn::VertexAttributeDescriptor vertexAttributes[kMaxVertexAttributes];
+        dawn::VertexInputDescriptor vertexInputs[kMaxVertexInputs];
 };
 
 // Test compilation and usage of the fixture :)
 TEST_P(InputStateTest, Basic) {
-    dawn::InputState inputState = MakeInputState({
-            {0, 4 * sizeof(float), InputStepMode::Vertex}
-        }, {
-            {0, 0, 0, VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {{0, 4 * sizeof(float), InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float4, InputStepMode::Vertex}
     });
@@ -241,12 +243,8 @@
     // This test was failing only on AMD but the OpenGL backend doesn't gather PCI info yet.
     DAWN_SKIP_TEST_IF(IsLinux() && IsOpenGL());
 
-    dawn::InputState inputState = MakeInputState({
-            {0, 0, InputStepMode::Vertex}
-        }, {
-            {0, 0, 0, VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState =
+        MakeInputState({{0, 0, InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 0, {
         {0, VertexFormat::Float4, InputStepMode::Vertex}
     });
@@ -264,12 +262,8 @@
 
     // R32F case
     {
-        dawn::InputState inputState = MakeInputState({
-                {0, 0, InputStepMode::Vertex}
-            }, {
-                {0, 0, 0, VertexFormat::Float}
-            }
-        );
+        dawn::InputStateDescriptor inputState =
+            MakeInputState({{0, 0, InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float}});
         dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 0, {
             {0, VertexFormat::Float, InputStepMode::Vertex}
         });
@@ -281,12 +275,8 @@
     }
     // RG32F case
     {
-        dawn::InputState inputState = MakeInputState({
-                {0, 0, InputStepMode::Vertex}
-            }, {
-                {0, 0, 0, VertexFormat::Float2}
-            }
-        );
+        dawn::InputStateDescriptor inputState =
+            MakeInputState({{0, 0, InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float2}});
         dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 0, {
             {0, VertexFormat::Float2, InputStepMode::Vertex}
         });
@@ -298,12 +288,8 @@
     }
     // RGB32F case
     {
-        dawn::InputState inputState = MakeInputState({
-                {0, 0, InputStepMode::Vertex}
-            }, {
-                {0, 0, 0, VertexFormat::Float3}
-            }
-        );
+        dawn::InputStateDescriptor inputState =
+            MakeInputState({{0, 0, InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float3}});
         dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 0, {
             {0, VertexFormat::Float3, InputStepMode::Vertex}
         });
@@ -320,12 +306,8 @@
     // This test was failing only on AMD but the OpenGL backend doesn't gather PCI info yet.
     DAWN_SKIP_TEST_IF(IsLinux() && IsOpenGL());
 
-    dawn::InputState inputState = MakeInputState({
-            {0, 8 * sizeof(float), InputStepMode::Vertex}
-        }, {
-            {0, 0, 0, VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {{0, 8 * sizeof(float), InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float4, InputStepMode::Vertex}
     });
@@ -340,13 +322,9 @@
 
 // Test two attributes at an offset, vertex version
 TEST_P(InputStateTest, TwoAttributesAtAnOffsetVertex) {
-    dawn::InputState inputState = MakeInputState({
-            {0, 8 * sizeof(float), InputStepMode::Vertex}
-        }, {
-            {0, 0, 0, VertexFormat::Float4},
-            {1, 0, 4  * sizeof(float), VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {{0, 8 * sizeof(float), InputStepMode::Vertex}},
+        {{0, 0, 0, VertexFormat::Float4}, {1, 0, 4 * sizeof(float), VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float4, InputStepMode::Vertex}
     });
@@ -361,13 +339,9 @@
 
 // Test two attributes at an offset, instance version
 TEST_P(InputStateTest, TwoAttributesAtAnOffsetInstance) {
-    dawn::InputState inputState = MakeInputState({
-            {0, 8 * sizeof(float), InputStepMode::Instance}
-        }, {
-            {0, 0, 0, VertexFormat::Float4},
-            {1, 0, 4  * sizeof(float), VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {{0, 8 * sizeof(float), InputStepMode::Instance}},
+        {{0, 0, 0, VertexFormat::Float4}, {1, 0, 4 * sizeof(float), VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float4, InputStepMode::Instance}
     });
@@ -382,12 +356,8 @@
 
 // Test a pure-instance input state
 TEST_P(InputStateTest, PureInstance) {
-    dawn::InputState inputState = MakeInputState({
-            {0, 4 * sizeof(float), InputStepMode::Instance}
-        }, {
-            {0, 0, 0, VertexFormat::Float4}
-        }
-    );
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {{0, 4 * sizeof(float), InputStepMode::Instance}}, {{0, 0, 0, VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float4, InputStepMode::Instance}
     });
@@ -404,16 +374,15 @@
 // Test with mixed everything, vertex vs. instance, different stride and offsets
 // different attribute types
 TEST_P(InputStateTest, MixedEverything) {
-    dawn::InputState inputState = MakeInputState({
+    dawn::InputStateDescriptor inputState = MakeInputState(
+        {
             {0, 12 * sizeof(float), InputStepMode::Vertex},
             {1, 10 * sizeof(float), InputStepMode::Instance},
-        }, {
-            {0, 0, 0, VertexFormat::Float},
-            {1, 0, 6  * sizeof(float), VertexFormat::Float2},
-            {2, 1, 0, VertexFormat::Float3},
-            {3, 1, 5  * sizeof(float), VertexFormat::Float4}
-        }
-    );
+        },
+        {{0, 0, 0, VertexFormat::Float},
+         {1, 0, 6 * sizeof(float), VertexFormat::Float2},
+         {2, 1, 0, VertexFormat::Float3},
+         {3, 1, 5 * sizeof(float), VertexFormat::Float4}});
     dawn::RenderPipeline pipeline = MakeTestPipeline(inputState, 1, {
         {0, VertexFormat::Float, InputStepMode::Vertex},
         {1, VertexFormat::Float2, InputStepMode::Vertex},
@@ -439,9 +408,8 @@
 // Test input state is unaffected by unused vertex slot
 TEST_P(InputStateTest, UnusedVertexSlot) {
     // Instance input state, using slot 1
-    dawn::InputState instanceInputState =
-        MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
-                       {{0, 1, 0, VertexFormat::Float4}});
+    dawn::InputStateDescriptor instanceInputState = MakeInputState(
+        {{1, 4 * sizeof(float), InputStepMode::Instance}}, {{0, 1, 0, VertexFormat::Float4}});
     dawn::RenderPipeline instancePipeline = MakeTestPipeline(
         instanceInputState, 1, {{0, VertexFormat::Float4, InputStepMode::Instance}});
 
@@ -477,16 +445,14 @@
 // SetVertexBuffers should be reapplied when the input state changes.
 TEST_P(InputStateTest, MultiplePipelinesMixedInputState) {
     // Basic input state, using slot 0
-    dawn::InputState vertexInputState =
-        MakeInputState({{0, 4 * sizeof(float), InputStepMode::Vertex}},
-                       {{0, 0, 0, VertexFormat::Float4}});
+    dawn::InputStateDescriptor vertexInputState = MakeInputState(
+        {{0, 4 * sizeof(float), InputStepMode::Vertex}}, {{0, 0, 0, VertexFormat::Float4}});
     dawn::RenderPipeline vertexPipeline = MakeTestPipeline(
         vertexInputState, 1, {{0, VertexFormat::Float4, InputStepMode::Vertex}});
 
     // Instance input state, using slot 1
-    dawn::InputState instanceInputState =
-        MakeInputState({{1, 4 * sizeof(float), InputStepMode::Instance}},
-                       {{0, 1, 0, VertexFormat::Float4}});
+    dawn::InputStateDescriptor instanceInputState = MakeInputState(
+        {{1, 4 * sizeof(float), InputStepMode::Instance}}, {{0, 1, 0, VertexFormat::Float4}});
     dawn::RenderPipeline instancePipeline = MakeTestPipeline(
         instanceInputState, 1, {{0, VertexFormat::Float4, InputStepMode::Instance}});
 
diff --git a/src/tests/end2end/PrimitiveTopologyTests.cpp b/src/tests/end2end/PrimitiveTopologyTests.cpp
index 2831c42..6c1d73a 100644
--- a/src/tests/end2end/PrimitiveTopologyTests.cpp
+++ b/src/tests/end2end/PrimitiveTopologyTests.cpp
@@ -165,22 +165,6 @@
                     fragColor = vec4(0.0, 1.0, 0.0, 1.0);
                 })");
 
-            dawn::VertexAttributeDescriptor attribute;
-            attribute.shaderLocation = 0;
-            attribute.inputSlot = 0;
-            attribute.offset = 0;
-            attribute.format = dawn::VertexFormat::Float4;
-
-            dawn::VertexInputDescriptor input;
-            input.inputSlot = 0;
-            input.stride = 4 * sizeof(float);
-            input.stepMode = dawn::InputStepMode::Vertex;
-
-            inputState = device.CreateInputStateBuilder()
-                             .SetAttribute(&attribute)
-                             .SetInput(&input)
-                             .GetResult();
-
             vertexBuffer = utils::CreateBufferFromData(device, kVertices, sizeof(kVertices), dawn::BufferUsageBit::Vertex);
         }
 
@@ -197,12 +181,25 @@
 
         // Draw the vertices with the given primitive topology and check the pixel values of the test locations
         void DoTest(dawn::PrimitiveTopology primitiveTopology, const std::vector<LocationSpec> &locationSpecs) {
+            dawn::VertexAttributeDescriptor attribute;
+            attribute.shaderLocation = 0;
+            attribute.inputSlot = 0;
+            attribute.offset = 0;
+            attribute.format = dawn::VertexFormat::Float4;
+
+            dawn::VertexInputDescriptor input;
+            input.inputSlot = 0;
+            input.stride = 4 * sizeof(float);
+            input.stepMode = dawn::InputStepMode::Vertex;
 
             utils::ComboRenderPipelineDescriptor descriptor(device);
             descriptor.cVertexStage.module = vsModule;
             descriptor.cFragmentStage.module = fsModule;
             descriptor.primitiveTopology = primitiveTopology;
-            descriptor.inputState = inputState;
+            descriptor.cInputState.numInputs = 1;
+            descriptor.cInputState.inputs = &input;
+            descriptor.cInputState.numAttributes = 1;
+            descriptor.cInputState.attributes = &attribute;
             descriptor.cColorStates[0]->format = renderPass.colorFormat;
 
             dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&descriptor);
@@ -234,7 +231,6 @@
         utils::BasicRenderPass renderPass;
         dawn::ShaderModule vsModule;
         dawn::ShaderModule fsModule;
-        dawn::InputState inputState;
         dawn::Buffer vertexBuffer;
 };
 
diff --git a/src/tests/unittests/validation/InputStateValidationTests.cpp b/src/tests/unittests/validation/InputStateValidationTests.cpp
index d0bfce3..bcc35fd 100644
--- a/src/tests/unittests/validation/InputStateValidationTests.cpp
+++ b/src/tests/unittests/validation/InputStateValidationTests.cpp
@@ -25,9 +25,13 @@
 
 class InputStateTest : public ValidationTest {
     protected:
-        void CreatePipeline(bool success, const dawn::InputState& inputState, std::string vertexSource) {
-            dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
-            dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+      void CreatePipeline(bool success,
+                          const dawn::InputStateDescriptor* state,
+                          std::string vertexSource) {
+          dawn::ShaderModule vsModule =
+              utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
+          dawn::ShaderModule fsModule =
+              utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
                 #version 450
                 layout(location = 0) out vec4 fragColor;
                 void main() {
@@ -35,26 +39,25 @@
                 }
             )");
 
-            utils::ComboRenderPipelineDescriptor descriptor(device);
-            descriptor.cVertexStage.module = vsModule;
-            descriptor.cFragmentStage.module = fsModule;
-            descriptor.inputState = inputState;
-            descriptor.cColorStates[0]->format = dawn::TextureFormat::R8G8B8A8Unorm;
+          utils::ComboRenderPipelineDescriptor descriptor(device);
+          descriptor.cVertexStage.module = vsModule;
+          descriptor.cFragmentStage.module = fsModule;
+          if (state) {
+              descriptor.inputState = state;
+          }
+          descriptor.cColorStates[0]->format = dawn::TextureFormat::R8G8B8A8Unorm;
 
-            if (!success) {
-                ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
-            } else {
-                device.CreateRenderPipeline(&descriptor);
-            }
-        }
+          if (!success) {
+              ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+          } else {
+              device.CreateRenderPipeline(&descriptor);
+          }
+      }
 };
 
 // Check an empty input state is valid
 TEST_F(InputStateTest, EmptyIsOk) {
-    dawn::InputState state = AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .GetResult();
-
-    CreatePipeline(true, state, R"(
+    CreatePipeline(true, nullptr, R"(
         #version 450
         void main() {
             gl_Position = vec4(0.0);
@@ -64,31 +67,30 @@
 
 // Check validation that pipeline vertex inputs are backed by attributes in the input state
 TEST_F(InputStateTest, PipelineCompatibility) {
-    dawn::VertexAttributeDescriptor attribute1;
-    attribute1.shaderLocation = 0;
-    attribute1.inputSlot = 0;
-    attribute1.offset = 0;
-    attribute1.format = dawn::VertexFormat::Float;
+    dawn::VertexAttributeDescriptor attribute[2];
+    attribute[0].shaderLocation = 0;
+    attribute[0].inputSlot = 0;
+    attribute[0].offset = 0;
+    attribute[0].format = dawn::VertexFormat::Float;
 
-    dawn::VertexAttributeDescriptor attribute2;
-    attribute2.shaderLocation = 1;
-    attribute2.inputSlot = 0;
-    attribute2.offset = sizeof(float);
-    attribute2.format = dawn::VertexFormat::Float;
+    attribute[1].shaderLocation = 1;
+    attribute[1].inputSlot = 0;
+    attribute[1].offset = sizeof(float);
+    attribute[1].format = dawn::VertexFormat::Float;
 
     dawn::VertexInputDescriptor input;
     input.inputSlot = 0;
     input.stride = 2 * sizeof(float);
     input.stepMode = dawn::InputStepMode::Vertex;
 
-    dawn::InputState state = AssertWillBeSuccess(device.CreateInputStateBuilder())
-                                 .SetInput(&input)
-                                 .SetAttribute(&attribute1)
-                                 .SetAttribute(&attribute2)
-                                 .GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &input;
+    state.numAttributes = 2;
+    state.attributes = attribute;
 
     // Control case: pipeline with one input per attribute
-    CreatePipeline(true, state, R"(
+    CreatePipeline(true, &state, R"(
         #version 450
         layout(location = 0) in vec4 a;
         layout(location = 1) in vec4 b;
@@ -98,7 +100,7 @@
     )");
 
     // Check it is valid for the pipeline to use a subset of the InputState
-    CreatePipeline(true, state, R"(
+    CreatePipeline(true, &state, R"(
         #version 450
         layout(location = 0) in vec4 a;
         void main() {
@@ -107,7 +109,7 @@
     )");
 
     // Check for an error when the pipeline uses an attribute not in the input state
-    CreatePipeline(false, state, R"(
+    CreatePipeline(false, &state, R"(
         #version 450
         layout(location = 2) in vec4 a;
         void main() {
@@ -119,7 +121,17 @@
 // Test that a stride of 0 is valid
 TEST_F(InputStateTest, StrideZero) {
     // Works ok without attributes
-    AssertWillBeSuccess(device.CreateInputStateBuilder()).SetInput(&kBaseInput).GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 0;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Works ok with attributes at a large-ish offset
     dawn::VertexAttributeDescriptor attribute;
@@ -128,22 +140,42 @@
     attribute.offset = 128;
     attribute.format = dawn::VertexFormat::Float;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Test that we cannot set an already set input
 TEST_F(InputStateTest, AlreadySetInput) {
     // Control case
-    AssertWillBeSuccess(device.CreateInputStateBuilder()).SetInput(&kBaseInput).GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 0;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Oh no, input 0 is set twice
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetInput(&kBaseInput)
-        .GetResult();
+    dawn::VertexInputDescriptor vertexInput[2] = {kBaseInput, kBaseInput};
+    state.numInputs = 2;
+    state.inputs = vertexInput;
+
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check out of bounds condition on input slot
@@ -154,11 +186,26 @@
     input.stride = 0;
     input.stepMode = dawn::InputStepMode::Vertex;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder()).SetInput(&input).GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &input;
+    state.numAttributes = 0;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Test input slot OOB
     input.inputSlot = kMaxVertexInputs;
-    AssertWillBeError(device.CreateInputStateBuilder()).SetInput(&input).GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check out of bounds condition on input stride
@@ -168,11 +215,27 @@
     input.inputSlot = 0;
     input.stride = kMaxVertexInputStride;
     input.stepMode = dawn::InputStepMode::Vertex;
-    AssertWillBeSuccess(device.CreateInputStateBuilder()).SetInput(&input).GetResult();
+
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &input;
+    state.numAttributes = 0;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Test input stride OOB
     input.stride = kMaxVertexInputStride + 1;
-    AssertWillBeError(device.CreateInputStateBuilder()).SetInput(&input).GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Test that we cannot set an already set attribute
@@ -184,17 +247,30 @@
     attribute.offset = 0;
     attribute.format = dawn::VertexFormat::Float;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Oh no, attribute 0 is set twice
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .SetAttribute(&attribute)
-        .GetResult();
+    dawn::VertexAttributeDescriptor vertexAttribute[2] = {attribute, attribute};
+    state.numAttributes = 2;
+    state.attributes = vertexAttribute;
+
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check out of bounds condition on attribute shader location
@@ -206,17 +282,27 @@
     attribute.offset = 0;
     attribute.format = dawn::VertexFormat::Float;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Test attribute location OOB
     attribute.shaderLocation = kMaxVertexAttributes;
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check attribute offset out of bounds
@@ -227,17 +313,28 @@
     attribute.inputSlot = 0;
     attribute.offset = kMaxVertexAttributeEnd - sizeof(dawn::VertexFormat::Float);
     attribute.format = dawn::VertexFormat::Float;
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Test attribute offset out of bounds
     attribute.offset = kMaxVertexAttributeEnd - 1;
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check attribute offset overflow
@@ -247,10 +344,19 @@
     attribute.inputSlot = 0;
     attribute.offset = std::numeric_limits<uint32_t>::max();
     attribute.format = dawn::VertexFormat::Float;
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check that all attributes must be backed by an input
@@ -262,17 +368,27 @@
     attribute.offset = 0;
     attribute.format = dawn::VertexFormat::Float;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Attribute 0 uses input 1 which doesn't exist
     attribute.inputSlot = 1;
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
 
 // Check OOB checks for an attribute's input
@@ -284,15 +400,25 @@
     attribute.offset = 0;
     attribute.format = dawn::VertexFormat::Float;
 
-    AssertWillBeSuccess(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    dawn::InputStateDescriptor state;
+    state.numInputs = 1;
+    state.inputs = &kBaseInput;
+    state.numAttributes = 1;
+    state.attributes = &attribute;
+
+    CreatePipeline(true, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 
     // Could crash if we didn't check for OOB
     attribute.inputSlot = 1000000;
-    AssertWillBeError(device.CreateInputStateBuilder())
-        .SetInput(&kBaseInput)
-        .SetAttribute(&attribute)
-        .GetResult();
+    CreatePipeline(false, &state, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        }
+    )");
 }
diff --git a/src/tests/unittests/validation/VertexBufferValidationTests.cpp b/src/tests/unittests/validation/VertexBufferValidationTests.cpp
index 06a0825..29c1c16 100644
--- a/src/tests/unittests/validation/VertexBufferValidationTests.cpp
+++ b/src/tests/unittests/validation/VertexBufferValidationTests.cpp
@@ -67,8 +67,12 @@
             return utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vs.str().c_str());
         }
 
-        dawn::InputState MakeInputState(unsigned int numInputs) {
-            auto builder = device.CreateInputStateBuilder();
+        dawn::RenderPipeline MakeRenderPipeline(const dawn::ShaderModule& vsModule,
+                                                unsigned int numInputs) {
+            utils::ComboRenderPipelineDescriptor descriptor(device);
+            descriptor.cVertexStage.module = vsModule;
+            descriptor.cFragmentStage.module = fsModule;
+
             dawn::VertexAttributeDescriptor attribute;
             attribute.offset = 0;
             attribute.format = dawn::VertexFormat::Float3;
@@ -77,22 +81,19 @@
             input.stride = 0;
             input.stepMode = dawn::InputStepMode::Vertex;
 
+            dawn::VertexInputDescriptor vertexInputs[kMaxVertexInputs];
+            dawn::VertexAttributeDescriptor vertexAttributes[kMaxVertexAttributes];
             for (unsigned int i = 0; i < numInputs; ++i) {
                 attribute.shaderLocation = i;
                 attribute.inputSlot = i;
                 input.inputSlot = i;
-                builder.SetAttribute(&attribute);
-                builder.SetInput(&input);
+                vertexInputs[i] = input;
+                vertexAttributes[i] = attribute;
             }
-            return builder.GetResult();
-        }
-
-        dawn::RenderPipeline MakeRenderPipeline(const dawn::ShaderModule& vsModule, const dawn::InputState& inputState) {
-
-            utils::ComboRenderPipelineDescriptor descriptor(device);
-            descriptor.cVertexStage.module = vsModule;
-            descriptor.cFragmentStage.module = fsModule;
-            descriptor.inputState = inputState;
+            descriptor.cInputState.numInputs = numInputs;
+            descriptor.cInputState.inputs = vertexInputs;
+            descriptor.cInputState.numAttributes = numInputs;
+            descriptor.cInputState.attributes = vertexAttributes;
 
             return device.CreateRenderPipeline(&descriptor);
         }
@@ -105,11 +106,8 @@
     auto vsModule2 = MakeVertexShader(2);
     auto vsModule1 = MakeVertexShader(1);
 
-    auto inputState2 = MakeInputState(2);
-    auto inputState1 = MakeInputState(1);
-
-    auto pipeline2 = MakeRenderPipeline(vsModule2, inputState2);
-    auto pipeline1 = MakeRenderPipeline(vsModule1, inputState1);
+    auto pipeline2 = MakeRenderPipeline(vsModule2, 2);
+    auto pipeline1 = MakeRenderPipeline(vsModule1, 1);
 
     auto vertexBuffers = MakeVertexBuffers<2>();
     uint32_t offsets[] = { 0, 0 };
@@ -143,11 +141,8 @@
     auto vsModule2 = MakeVertexShader(2);
     auto vsModule1 = MakeVertexShader(1);
 
-    auto inputState2 = MakeInputState(2);
-    auto inputState1 = MakeInputState(1);
-
-    auto pipeline2 = MakeRenderPipeline(vsModule2, inputState2);
-    auto pipeline1 = MakeRenderPipeline(vsModule1, inputState1);
+    auto pipeline2 = MakeRenderPipeline(vsModule2, 2);
+    auto pipeline1 = MakeRenderPipeline(vsModule1, 1);
 
     auto vertexBuffers = MakeVertexBuffers<2>();
     uint32_t offsets[] = { 0, 0 };
diff --git a/src/tests/unittests/wire/WireArgumentTests.cpp b/src/tests/unittests/wire/WireArgumentTests.cpp
index 8576202..8a98d56 100644
--- a/src/tests/unittests/wire/WireArgumentTests.cpp
+++ b/src/tests/unittests/wire/WireArgumentTests.cpp
@@ -101,15 +101,12 @@
     colorStateDescriptor.colorWriteMask = DAWN_COLOR_WRITE_MASK_ALL;
 
     // Create the input state
-    DawnInputStateBuilder inputStateBuilder = dawnDeviceCreateInputStateBuilder(device);
-    DawnInputStateBuilder apiInputStateBuilder = api.GetNewInputStateBuilder();
-    EXPECT_CALL(api, DeviceCreateInputStateBuilder(apiDevice))
-        .WillOnce(Return(apiInputStateBuilder));
-
-    DawnInputState inputState = dawnInputStateBuilderGetResult(inputStateBuilder);
-    DawnInputState apiInputState = api.GetNewInputState();
-    EXPECT_CALL(api, InputStateBuilderGetResult(apiInputStateBuilder))
-        .WillOnce(Return(apiInputState));
+    DawnInputStateDescriptor inputState;
+    inputState.nextInChain = nullptr;
+    inputState.numInputs = 0;
+    inputState.inputs = nullptr;
+    inputState.numAttributes = 0;
+    inputState.attributes = nullptr;
 
     // Create the depth-stencil state
     DawnStencilStateFaceDescriptor stencilFace;
@@ -159,7 +156,7 @@
 
     pipelineDescriptor.sampleCount = 1;
     pipelineDescriptor.layout = layout;
-    pipelineDescriptor.inputState = inputState;
+    pipelineDescriptor.inputState = &inputState;
     pipelineDescriptor.indexFormat = DAWN_INDEX_FORMAT_UINT32;
     pipelineDescriptor.primitiveTopology = DAWN_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
     pipelineDescriptor.depthStencilState = &depthStencilState;
@@ -172,8 +169,6 @@
                     })))
         .WillOnce(Return(nullptr));
     EXPECT_CALL(api, ShaderModuleRelease(apiVsModule));
-    EXPECT_CALL(api, InputStateBuilderRelease(apiInputStateBuilder));
-    EXPECT_CALL(api, InputStateRelease(apiInputState));
     EXPECT_CALL(api, PipelineLayoutRelease(apiLayout));
 
     FlushClient();
diff --git a/src/tests/unittests/wire/WireOptionalTests.cpp b/src/tests/unittests/wire/WireOptionalTests.cpp
index 6b702a7..24d781d 100644
--- a/src/tests/unittests/wire/WireOptionalTests.cpp
+++ b/src/tests/unittests/wire/WireOptionalTests.cpp
@@ -86,15 +86,12 @@
     colorStateDescriptor.colorWriteMask = DAWN_COLOR_WRITE_MASK_ALL;
 
     // Create the input state
-    DawnInputStateBuilder inputStateBuilder = dawnDeviceCreateInputStateBuilder(device);
-    DawnInputStateBuilder apiInputStateBuilder = api.GetNewInputStateBuilder();
-    EXPECT_CALL(api, DeviceCreateInputStateBuilder(apiDevice))
-        .WillOnce(Return(apiInputStateBuilder));
-
-    DawnInputState inputState = dawnInputStateBuilderGetResult(inputStateBuilder);
-    DawnInputState apiInputState = api.GetNewInputState();
-    EXPECT_CALL(api, InputStateBuilderGetResult(apiInputStateBuilder))
-        .WillOnce(Return(apiInputState));
+    DawnInputStateDescriptor inputState;
+    inputState.nextInChain = nullptr;
+    inputState.numInputs = 0;
+    inputState.inputs = nullptr;
+    inputState.numAttributes = 0;
+    inputState.attributes = nullptr;
 
     // Create the depth-stencil state
     DawnStencilStateFaceDescriptor stencilFace;
@@ -144,7 +141,7 @@
 
     pipelineDescriptor.sampleCount = 1;
     pipelineDescriptor.layout = layout;
-    pipelineDescriptor.inputState = inputState;
+    pipelineDescriptor.inputState = &inputState;
     pipelineDescriptor.indexFormat = DAWN_INDEX_FORMAT_UINT32;
     pipelineDescriptor.primitiveTopology = DAWN_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
 
@@ -191,8 +188,6 @@
         .WillOnce(Return(nullptr));
 
     EXPECT_CALL(api, ShaderModuleRelease(apiVsModule));
-    EXPECT_CALL(api, InputStateBuilderRelease(apiInputStateBuilder));
-    EXPECT_CALL(api, InputStateRelease(apiInputState));
     EXPECT_CALL(api, PipelineLayoutRelease(apiLayout));
 
     FlushClient();
diff --git a/src/utils/ComboRenderPipelineDescriptor.cpp b/src/utils/ComboRenderPipelineDescriptor.cpp
index 03da2e7..f7bf479 100644
--- a/src/utils/ComboRenderPipelineDescriptor.cpp
+++ b/src/utils/ComboRenderPipelineDescriptor.cpp
@@ -37,6 +37,15 @@
             cFragmentStage.entryPoint = "main";
         }
 
+        // Set defaults for the input state descriptors.
+        {
+            descriptor->inputState = &cInputState;
+            cInputState.numInputs = 0;
+            cInputState.inputs = nullptr;
+            cInputState.numAttributes = 0;
+            cInputState.attributes = nullptr;
+        }
+
         // Set defaults for the color state descriptors.
         {
             descriptor->colorStateCount = 1;
@@ -75,7 +84,6 @@
             descriptor->depthStencilState = nullptr;
         }
 
-        descriptor->inputState = device.CreateInputStateBuilder().GetResult();
         descriptor->layout = utils::MakeBasicPipelineLayout(device, nullptr);
     }
 
diff --git a/src/utils/ComboRenderPipelineDescriptor.h b/src/utils/ComboRenderPipelineDescriptor.h
index d5e50e9..f76d7a5 100644
--- a/src/utils/ComboRenderPipelineDescriptor.h
+++ b/src/utils/ComboRenderPipelineDescriptor.h
@@ -30,6 +30,7 @@
         dawn::PipelineStageDescriptor cVertexStage;

         dawn::PipelineStageDescriptor cFragmentStage;

 

+        dawn::InputStateDescriptor cInputState;

         std::array<dawn::ColorStateDescriptor*, kMaxColorAttachments> cColorStates;

         dawn::DepthStencilStateDescriptor cDepthStencilState;