Update RenderPipelineBase to stop depending on deprecated struct types

Bug: dawn:642
Change-Id: Ibc9d8f87735864dcafb3ec68013e4590602af855
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/45360
Commit-Queue: Brandon Jones <bajones@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 00c8fd0..7b6d367 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -127,10 +127,8 @@
         }
 
         if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
-            wgpu::IndexFormat pipelineIndexFormat =
-                mLastRenderPipeline->GetVertexStateDescriptor()->indexFormat;
             if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) ||
-                mIndexFormat == pipelineIndexFormat) {
+                mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) {
                 mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
             }
         }
@@ -142,8 +140,7 @@
         }
 
         if (aspects[VALIDATION_ASPECT_INDEX_BUFFER]) {
-            wgpu::IndexFormat pipelineIndexFormat =
-                mLastRenderPipeline->GetVertexStateDescriptor()->indexFormat;
+            wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat();
             if (!mIndexBufferSet) {
                 return DAWN_VALIDATION_ERROR("Missing index buffer");
             } else if (IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) &&
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 6cfaccf..2e41888 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -302,15 +302,15 @@
         return {};
     }
 
-    bool StencilTestEnabled(const DepthStencilStateDescriptor* mDepthStencilState) {
-        return mDepthStencilState->stencilBack.compare != wgpu::CompareFunction::Always ||
-               mDepthStencilState->stencilBack.failOp != wgpu::StencilOperation::Keep ||
-               mDepthStencilState->stencilBack.depthFailOp != wgpu::StencilOperation::Keep ||
-               mDepthStencilState->stencilBack.passOp != wgpu::StencilOperation::Keep ||
-               mDepthStencilState->stencilFront.compare != wgpu::CompareFunction::Always ||
-               mDepthStencilState->stencilFront.failOp != wgpu::StencilOperation::Keep ||
-               mDepthStencilState->stencilFront.depthFailOp != wgpu::StencilOperation::Keep ||
-               mDepthStencilState->stencilFront.passOp != wgpu::StencilOperation::Keep;
+    bool StencilTestEnabled(const DepthStencilState* mDepthStencil) {
+        return mDepthStencil->stencilBack.compare != wgpu::CompareFunction::Always ||
+               mDepthStencil->stencilBack.failOp != wgpu::StencilOperation::Keep ||
+               mDepthStencil->stencilBack.depthFailOp != wgpu::StencilOperation::Keep ||
+               mDepthStencil->stencilBack.passOp != wgpu::StencilOperation::Keep ||
+               mDepthStencil->stencilFront.compare != wgpu::CompareFunction::Always ||
+               mDepthStencil->stencilFront.failOp != wgpu::StencilOperation::Keep ||
+               mDepthStencil->stencilFront.depthFailOp != wgpu::StencilOperation::Keep ||
+               mDepthStencil->stencilFront.passOp != wgpu::StencilOperation::Keep;
     }
 
     bool BlendEnabled(const ColorStateDescriptor* mColorState) {
@@ -330,71 +330,104 @@
                        descriptor->layout,
                        {{SingleShaderStage::Vertex, &descriptor->vertexStage},
                         {SingleShaderStage::Fragment, descriptor->fragmentStage}}),
-          mAttachmentState(device->GetOrCreateAttachmentState(descriptor)),
-          mPrimitiveTopology(descriptor->primitiveTopology),
-          mSampleMask(descriptor->sampleMask),
-          mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled) {
+          mAttachmentState(device->GetOrCreateAttachmentState(descriptor)) {
+        mPrimitive.topology = descriptor->primitiveTopology;
+
+        mMultisample.count = descriptor->sampleCount;
+        mMultisample.mask = descriptor->sampleMask;
+        mMultisample.alphaToCoverageEnabled = descriptor->alphaToCoverageEnabled;
+
         if (descriptor->vertexState != nullptr) {
-            mVertexState = *descriptor->vertexState;
-        } else {
-            mVertexState = VertexStateDescriptor();
-        }
+            const VertexStateDescriptor& vertexState = *descriptor->vertexState;
+            mVertexBufferCount = vertexState.vertexBufferCount;
+            mPrimitive.stripIndexFormat = vertexState.indexFormat;
 
-        for (uint8_t slot = 0; slot < mVertexState.vertexBufferCount; ++slot) {
-            if (mVertexState.vertexBuffers[slot].attributeCount == 0) {
-                continue;
+            for (uint8_t slot = 0; slot < mVertexBufferCount; ++slot) {
+                if (vertexState.vertexBuffers[slot].attributeCount == 0) {
+                    continue;
+                }
+
+                VertexBufferSlot typedSlot(slot);
+
+                mVertexBufferSlotsUsed.set(typedSlot);
+                mVertexBufferInfos[typedSlot].arrayStride =
+                    vertexState.vertexBuffers[slot].arrayStride;
+                mVertexBufferInfos[typedSlot].stepMode = vertexState.vertexBuffers[slot].stepMode;
+
+                for (uint32_t i = 0; i < vertexState.vertexBuffers[slot].attributeCount; ++i) {
+                    VertexAttributeLocation location = VertexAttributeLocation(static_cast<uint8_t>(
+                        vertexState.vertexBuffers[slot].attributes[i].shaderLocation));
+                    mAttributeLocationsUsed.set(location);
+                    mAttributeInfos[location].shaderLocation = location;
+                    mAttributeInfos[location].vertexBufferSlot = typedSlot;
+                    mAttributeInfos[location].offset =
+                        vertexState.vertexBuffers[slot].attributes[i].offset;
+
+                    mAttributeInfos[location].format = dawn::NormalizeVertexFormat(
+                        vertexState.vertexBuffers[slot].attributes[i].format);
+                }
             }
-
-            VertexBufferSlot typedSlot(slot);
-
-            mVertexBufferSlotsUsed.set(typedSlot);
-            mVertexBufferInfos[typedSlot].arrayStride =
-                mVertexState.vertexBuffers[slot].arrayStride;
-            mVertexBufferInfos[typedSlot].stepMode = mVertexState.vertexBuffers[slot].stepMode;
-
-            for (uint32_t i = 0; i < mVertexState.vertexBuffers[slot].attributeCount; ++i) {
-                VertexAttributeLocation location = VertexAttributeLocation(static_cast<uint8_t>(
-                    mVertexState.vertexBuffers[slot].attributes[i].shaderLocation));
-                mAttributeLocationsUsed.set(location);
-                mAttributeInfos[location].shaderLocation = location;
-                mAttributeInfos[location].vertexBufferSlot = typedSlot;
-                mAttributeInfos[location].offset =
-                    mVertexState.vertexBuffers[slot].attributes[i].offset;
-
-                mAttributeInfos[location].format = dawn::NormalizeVertexFormat(
-                    mVertexState.vertexBuffers[slot].attributes[i].format);
-            }
-        }
-
-        if (descriptor->rasterizationState != nullptr) {
-            mRasterizationState = *descriptor->rasterizationState;
         } else {
-            mRasterizationState = RasterizationStateDescriptor();
+            mVertexBufferCount = 0;
+            mPrimitive.stripIndexFormat = wgpu::IndexFormat::Undefined;
         }
 
         if (mAttachmentState->HasDepthStencilAttachment()) {
-            mDepthStencilState = *descriptor->depthStencilState;
+            const DepthStencilStateDescriptor& depthStencil = *descriptor->depthStencilState;
+            mDepthStencil.format = depthStencil.format;
+            mDepthStencil.depthWriteEnabled = depthStencil.depthWriteEnabled;
+            mDepthStencil.depthCompare = depthStencil.depthCompare;
+            mDepthStencil.stencilBack = depthStencil.stencilBack;
+            mDepthStencil.stencilFront = depthStencil.stencilFront;
+            mDepthStencil.stencilReadMask = depthStencil.stencilReadMask;
+            mDepthStencil.stencilWriteMask = depthStencil.stencilWriteMask;
         } else {
             // These default values below are useful for backends to fill information.
             // The values indicate that depth and stencil test are disabled when backends
             // set their own depth stencil states/descriptors according to the values in
-            // mDepthStencilState.
-            mDepthStencilState.depthCompare = wgpu::CompareFunction::Always;
-            mDepthStencilState.depthWriteEnabled = false;
-            mDepthStencilState.stencilBack.compare = wgpu::CompareFunction::Always;
-            mDepthStencilState.stencilBack.failOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilBack.depthFailOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilBack.passOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilFront.compare = wgpu::CompareFunction::Always;
-            mDepthStencilState.stencilFront.failOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilFront.depthFailOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilFront.passOp = wgpu::StencilOperation::Keep;
-            mDepthStencilState.stencilReadMask = 0xff;
-            mDepthStencilState.stencilWriteMask = 0xff;
+            // mDepthStencil.
+            mDepthStencil.format = wgpu::TextureFormat::Undefined;
+            mDepthStencil.depthWriteEnabled = false;
+            mDepthStencil.depthCompare = wgpu::CompareFunction::Always;
+            mDepthStencil.stencilBack.compare = wgpu::CompareFunction::Always;
+            mDepthStencil.stencilBack.failOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilBack.depthFailOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilBack.passOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilFront.compare = wgpu::CompareFunction::Always;
+            mDepthStencil.stencilFront.failOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilFront.depthFailOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilFront.passOp = wgpu::StencilOperation::Keep;
+            mDepthStencil.stencilReadMask = 0xff;
+            mDepthStencil.stencilWriteMask = 0xff;
+        }
+
+        if (descriptor->rasterizationState != nullptr) {
+            mPrimitive.frontFace = descriptor->rasterizationState->frontFace;
+            mPrimitive.cullMode = descriptor->rasterizationState->cullMode;
+            mDepthStencil.depthBias = descriptor->rasterizationState->depthBias;
+            mDepthStencil.depthBiasSlopeScale = descriptor->rasterizationState->depthBiasSlopeScale;
+            mDepthStencil.depthBiasClamp = descriptor->rasterizationState->depthBiasClamp;
+        } else {
+            mPrimitive.frontFace = wgpu::FrontFace::CCW;
+            mPrimitive.cullMode = wgpu::CullMode::None;
+            mDepthStencil.depthBias = 0;
+            mDepthStencil.depthBiasSlopeScale = 0.0f;
+            mDepthStencil.depthBiasClamp = 0.0f;
         }
 
         for (ColorAttachmentIndex i : IterateBitSet(mAttachmentState->GetColorAttachmentsMask())) {
-            mColorStates[i] = descriptor->colorStates[static_cast<uint8_t>(i)];
+            const ColorStateDescriptor* colorState =
+                &descriptor->colorStates[static_cast<uint8_t>(i)];
+            mTargets[i].format = colorState->format;
+            mTargets[i].writeMask = colorState->writeMask;
+
+            if (BlendEnabled(colorState)) {
+                mTargetBlend[i].color = colorState->colorBlend;
+                mTargetBlend[i].alpha = colorState->alphaBlend;
+                mTargets[i].blend = &mTargetBlend[i];
+            } else {
+                mTargets[i].blend = nullptr;
+            }
         }
 
         // TODO(cwallez@chromium.org): Check against the shader module that the correct color
@@ -416,11 +449,6 @@
         }
     }
 
-    const VertexStateDescriptor* RenderPipelineBase::GetVertexStateDescriptor() const {
-        ASSERT(!IsError());
-        return &mVertexState;
-    }
-
     const ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>&
     RenderPipelineBase::GetAttributeLocationsUsed() const {
         ASSERT(!IsError());
@@ -446,51 +474,61 @@
         return mVertexBufferInfos[slot];
     }
 
-    const ColorStateDescriptor* RenderPipelineBase::GetColorStateDescriptor(
-        ColorAttachmentIndex attachmentSlot) const {
+    uint32_t RenderPipelineBase::GetVertexBufferCount() const {
         ASSERT(!IsError());
-        ASSERT(attachmentSlot < mColorStates.size());
-        return &mColorStates[attachmentSlot];
+        return mVertexBufferCount;
     }
 
-    const DepthStencilStateDescriptor* RenderPipelineBase::GetDepthStencilStateDescriptor() const {
+    const ColorTargetState* RenderPipelineBase::GetColorTargetState(
+        ColorAttachmentIndex attachmentSlot) const {
         ASSERT(!IsError());
-        return &mDepthStencilState;
+        ASSERT(attachmentSlot < mTargets.size());
+        return &mTargets[attachmentSlot];
+    }
+
+    const DepthStencilState* RenderPipelineBase::GetDepthStencilState() const {
+        ASSERT(!IsError());
+        return &mDepthStencil;
     }
 
     wgpu::PrimitiveTopology RenderPipelineBase::GetPrimitiveTopology() const {
         ASSERT(!IsError());
-        return mPrimitiveTopology;
+        return mPrimitive.topology;
+    }
+
+    wgpu::IndexFormat RenderPipelineBase::GetStripIndexFormat() const {
+        ASSERT(!IsError());
+        return mPrimitive.stripIndexFormat;
     }
 
     wgpu::CullMode RenderPipelineBase::GetCullMode() const {
         ASSERT(!IsError());
-        return mRasterizationState.cullMode;
+        return mPrimitive.cullMode;
     }
 
     wgpu::FrontFace RenderPipelineBase::GetFrontFace() const {
         ASSERT(!IsError());
-        return mRasterizationState.frontFace;
+        return mPrimitive.frontFace;
     }
 
     bool RenderPipelineBase::IsDepthBiasEnabled() const {
         ASSERT(!IsError());
-        return mRasterizationState.depthBias != 0 || mRasterizationState.depthBiasSlopeScale != 0;
+        return mDepthStencil.depthBias != 0 || mDepthStencil.depthBiasSlopeScale != 0;
     }
 
     int32_t RenderPipelineBase::GetDepthBias() const {
         ASSERT(!IsError());
-        return mRasterizationState.depthBias;
+        return mDepthStencil.depthBias;
     }
 
     float RenderPipelineBase::GetDepthBiasSlopeScale() const {
         ASSERT(!IsError());
-        return mRasterizationState.depthBiasSlopeScale;
+        return mDepthStencil.depthBiasSlopeScale;
     }
 
     float RenderPipelineBase::GetDepthBiasClamp() const {
         ASSERT(!IsError());
-        return mRasterizationState.depthBiasClamp;
+        return mDepthStencil.depthBiasClamp;
     }
 
     ityp::bitset<ColorAttachmentIndex, kMaxColorAttachments>
@@ -507,13 +545,13 @@
     wgpu::TextureFormat RenderPipelineBase::GetColorAttachmentFormat(
         ColorAttachmentIndex attachment) const {
         ASSERT(!IsError());
-        return mColorStates[attachment].format;
+        return mTargets[attachment].format;
     }
 
     wgpu::TextureFormat RenderPipelineBase::GetDepthStencilFormat() const {
         ASSERT(!IsError());
         ASSERT(mAttachmentState->HasDepthStencilAttachment());
-        return mDepthStencilState.format;
+        return mDepthStencil.format;
     }
 
     uint32_t RenderPipelineBase::GetSampleCount() const {
@@ -523,12 +561,12 @@
 
     uint32_t RenderPipelineBase::GetSampleMask() const {
         ASSERT(!IsError());
-        return mSampleMask;
+        return mMultisample.mask;
     }
 
     bool RenderPipelineBase::IsAlphaToCoverageEnabled() const {
         ASSERT(!IsError());
-        return mAlphaToCoverageEnabled;
+        return mMultisample.alphaToCoverageEnabled;
     }
 
     const AttachmentState* RenderPipelineBase::GetAttachmentState() const {
@@ -549,22 +587,25 @@
 
         // Record attachments
         for (ColorAttachmentIndex i : IterateBitSet(mAttachmentState->GetColorAttachmentsMask())) {
-            const ColorStateDescriptor& desc = *GetColorStateDescriptor(i);
+            const ColorTargetState& desc = *GetColorTargetState(i);
             recorder.Record(desc.writeMask);
-            recorder.Record(desc.colorBlend.operation, desc.colorBlend.srcFactor,
-                            desc.colorBlend.dstFactor);
-            recorder.Record(desc.alphaBlend.operation, desc.alphaBlend.srcFactor,
-                            desc.alphaBlend.dstFactor);
+            if (desc.blend != nullptr) {
+                recorder.Record(desc.blend->color.operation, desc.blend->color.srcFactor,
+                                desc.blend->color.dstFactor);
+                recorder.Record(desc.blend->alpha.operation, desc.blend->alpha.srcFactor,
+                                desc.blend->alpha.dstFactor);
+            }
         }
 
         if (mAttachmentState->HasDepthStencilAttachment()) {
-            const DepthStencilStateDescriptor& desc = mDepthStencilState;
+            const DepthStencilState& desc = mDepthStencil;
             recorder.Record(desc.depthWriteEnabled, desc.depthCompare);
             recorder.Record(desc.stencilReadMask, desc.stencilWriteMask);
             recorder.Record(desc.stencilFront.compare, desc.stencilFront.failOp,
                             desc.stencilFront.depthFailOp, desc.stencilFront.passOp);
             recorder.Record(desc.stencilBack.compare, desc.stencilBack.failOp,
                             desc.stencilBack.depthFailOp, desc.stencilBack.passOp);
+            recorder.Record(desc.depthBias, desc.depthBiasSlopeScale, desc.depthBiasClamp);
         }
 
         // Record vertex state
@@ -580,17 +621,13 @@
             recorder.Record(desc.arrayStride, desc.stepMode);
         }
 
-        recorder.Record(mVertexState.indexFormat);
+        // Record primitive state
+        recorder.Record(mPrimitive.topology, mPrimitive.stripIndexFormat, mPrimitive.frontFace,
+                        mPrimitive.cullMode);
 
-        // Record rasterization state
-        {
-            const RasterizationStateDescriptor& desc = mRasterizationState;
-            recorder.Record(desc.frontFace, desc.cullMode);
-            recorder.Record(desc.depthBias, desc.depthBiasSlopeScale, desc.depthBiasClamp);
-        }
-
-        // Record other state
-        recorder.Record(mPrimitiveTopology, mSampleMask, mAlphaToCoverageEnabled);
+        // Record multisample state
+        // The sample count is already hashed as part of the attachment state.
+        recorder.Record(mMultisample.mask, mMultisample.alphaToCoverageEnabled);
 
         return recorder.GetContentHash();
     }
@@ -610,44 +647,59 @@
 
         for (ColorAttachmentIndex i :
              IterateBitSet(a->mAttachmentState->GetColorAttachmentsMask())) {
-            const ColorStateDescriptor& descA = *a->GetColorStateDescriptor(i);
-            const ColorStateDescriptor& descB = *b->GetColorStateDescriptor(i);
+            const ColorTargetState& descA = *a->GetColorTargetState(i);
+            const ColorTargetState& descB = *b->GetColorTargetState(i);
             if (descA.writeMask != descB.writeMask) {
                 return false;
             }
-            if (descA.colorBlend.operation != descB.colorBlend.operation ||
-                descA.colorBlend.srcFactor != descB.colorBlend.srcFactor ||
-                descA.colorBlend.dstFactor != descB.colorBlend.dstFactor) {
+            if ((descA.blend == nullptr) != (descB.blend == nullptr)) {
                 return false;
             }
-            if (descA.alphaBlend.operation != descB.alphaBlend.operation ||
-                descA.alphaBlend.srcFactor != descB.alphaBlend.srcFactor ||
-                descA.alphaBlend.dstFactor != descB.alphaBlend.dstFactor) {
-                return false;
+            if (descA.blend != nullptr) {
+                if (descA.blend->color.operation != descB.blend->color.operation ||
+                    descA.blend->color.srcFactor != descB.blend->color.srcFactor ||
+                    descA.blend->color.dstFactor != descB.blend->color.dstFactor) {
+                    return false;
+                }
+                if (descA.blend->alpha.operation != descB.blend->alpha.operation ||
+                    descA.blend->alpha.srcFactor != descB.blend->alpha.srcFactor ||
+                    descA.blend->alpha.dstFactor != descB.blend->alpha.dstFactor) {
+                    return false;
+                }
             }
         }
 
+        // Check depth/stencil state
         if (a->mAttachmentState->HasDepthStencilAttachment()) {
-            const DepthStencilStateDescriptor& descA = a->mDepthStencilState;
-            const DepthStencilStateDescriptor& descB = b->mDepthStencilState;
-            if (descA.depthWriteEnabled != descB.depthWriteEnabled ||
-                descA.depthCompare != descB.depthCompare) {
+            const DepthStencilState& stateA = a->mDepthStencil;
+            const DepthStencilState& stateB = b->mDepthStencil;
+
+            ASSERT(!std::isnan(stateA.depthBiasSlopeScale));
+            ASSERT(!std::isnan(stateB.depthBiasSlopeScale));
+            ASSERT(!std::isnan(stateA.depthBiasClamp));
+            ASSERT(!std::isnan(stateB.depthBiasClamp));
+
+            if (stateA.depthWriteEnabled != stateB.depthWriteEnabled ||
+                stateA.depthCompare != stateB.depthCompare ||
+                stateA.depthBias != stateB.depthBias ||
+                stateA.depthBiasSlopeScale != stateB.depthBiasSlopeScale ||
+                stateA.depthBiasClamp != stateB.depthBiasClamp) {
                 return false;
             }
-            if (descA.stencilReadMask != descB.stencilReadMask ||
-                descA.stencilWriteMask != descB.stencilWriteMask) {
+            if (stateA.stencilFront.compare != stateB.stencilFront.compare ||
+                stateA.stencilFront.failOp != stateB.stencilFront.failOp ||
+                stateA.stencilFront.depthFailOp != stateB.stencilFront.depthFailOp ||
+                stateA.stencilFront.passOp != stateB.stencilFront.passOp) {
                 return false;
             }
-            if (descA.stencilFront.compare != descB.stencilFront.compare ||
-                descA.stencilFront.failOp != descB.stencilFront.failOp ||
-                descA.stencilFront.depthFailOp != descB.stencilFront.depthFailOp ||
-                descA.stencilFront.passOp != descB.stencilFront.passOp) {
+            if (stateA.stencilBack.compare != stateB.stencilBack.compare ||
+                stateA.stencilBack.failOp != stateB.stencilBack.failOp ||
+                stateA.stencilBack.depthFailOp != stateB.stencilBack.depthFailOp ||
+                stateA.stencilBack.passOp != stateB.stencilBack.passOp) {
                 return false;
             }
-            if (descA.stencilBack.compare != descB.stencilBack.compare ||
-                descA.stencilBack.failOp != descB.stencilBack.failOp ||
-                descA.stencilBack.depthFailOp != descB.stencilBack.depthFailOp ||
-                descA.stencilBack.passOp != descB.stencilBack.passOp) {
+            if (stateA.stencilReadMask != stateB.stencilReadMask ||
+                stateA.stencilWriteMask != stateB.stencilWriteMask) {
                 return false;
             }
         }
@@ -679,34 +731,26 @@
             }
         }
 
-        if (a->mVertexState.indexFormat != b->mVertexState.indexFormat) {
-            return false;
-        }
-
-        // Check rasterization state
+        // Check primitive state
         {
-            const RasterizationStateDescriptor& descA = a->mRasterizationState;
-            const RasterizationStateDescriptor& descB = b->mRasterizationState;
-            if (descA.frontFace != descB.frontFace || descA.cullMode != descB.cullMode) {
-                return false;
-            }
-
-            ASSERT(!std::isnan(descA.depthBiasSlopeScale));
-            ASSERT(!std::isnan(descB.depthBiasSlopeScale));
-            ASSERT(!std::isnan(descA.depthBiasClamp));
-            ASSERT(!std::isnan(descB.depthBiasClamp));
-
-            if (descA.depthBias != descB.depthBias ||
-                descA.depthBiasSlopeScale != descB.depthBiasSlopeScale ||
-                descA.depthBiasClamp != descB.depthBiasClamp) {
+            const PrimitiveState& stateA = a->mPrimitive;
+            const PrimitiveState& stateB = b->mPrimitive;
+            if (stateA.topology != stateB.topology ||
+                stateA.stripIndexFormat != stateB.stripIndexFormat ||
+                stateA.frontFace != stateB.frontFace || stateA.cullMode != stateB.cullMode) {
                 return false;
             }
         }
 
-        // Check other state
-        if (a->mPrimitiveTopology != b->mPrimitiveTopology || a->mSampleMask != b->mSampleMask ||
-            a->mAlphaToCoverageEnabled != b->mAlphaToCoverageEnabled) {
-            return false;
+        // Check multisample state
+        {
+            const MultisampleState& stateA = a->mMultisample;
+            const MultisampleState& stateB = b->mMultisample;
+            // Sample count already checked as part of the attachment state.
+            if (stateA.mask != stateB.mask ||
+                stateA.alphaToCoverageEnabled != stateB.alphaToCoverageEnabled) {
+                return false;
+            }
         }
 
         return true;
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index 7d24b3f..5259620 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -35,8 +35,7 @@
 
     bool IsStripPrimitiveTopology(wgpu::PrimitiveTopology primitiveTopology);
 
-    bool StencilTestEnabled(const DepthStencilStateDescriptor* mDepthStencilState);
-    bool BlendEnabled(const ColorStateDescriptor* mColorState);
+    bool StencilTestEnabled(const DepthStencilState* mDepthStencil);
 
     struct VertexAttributeInfo {
         wgpu::VertexFormat format;
@@ -57,17 +56,17 @@
 
         static RenderPipelineBase* MakeError(DeviceBase* device);
 
-        const VertexStateDescriptor* GetVertexStateDescriptor() const;
         const ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>&
         GetAttributeLocationsUsed() const;
         const VertexAttributeInfo& GetAttribute(VertexAttributeLocation location) const;
         const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& GetVertexBufferSlotsUsed() const;
         const VertexBufferInfo& GetVertexBuffer(VertexBufferSlot slot) const;
+        uint32_t GetVertexBufferCount() const;
 
-        const ColorStateDescriptor* GetColorStateDescriptor(
-            ColorAttachmentIndex attachmentSlot) const;
-        const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor() const;
+        const ColorTargetState* GetColorTargetState(ColorAttachmentIndex attachmentSlot) const;
+        const DepthStencilState* GetDepthStencilState() const;
         wgpu::PrimitiveTopology GetPrimitiveTopology() const;
+        wgpu::IndexFormat GetStripIndexFormat() const;
         wgpu::CullMode GetCullMode() const;
         wgpu::FrontFace GetFrontFace() const;
         bool IsDepthBiasEnabled() const;
@@ -96,7 +95,7 @@
         RenderPipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
         // Vertex state
-        VertexStateDescriptor mVertexState;
+        uint32_t mVertexBufferCount;
         ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> mAttributeLocationsUsed;
         ityp::array<VertexAttributeLocation, VertexAttributeInfo, kMaxVertexAttributes>
             mAttributeInfos;
@@ -105,14 +104,13 @@
 
         // Attachments
         Ref<AttachmentState> mAttachmentState;
-        DepthStencilStateDescriptor mDepthStencilState;
-        ityp::array<ColorAttachmentIndex, ColorStateDescriptor, kMaxColorAttachments> mColorStates;
+        ityp::array<ColorAttachmentIndex, ColorTargetState, kMaxColorAttachments> mTargets;
+        ityp::array<ColorAttachmentIndex, BlendState, kMaxColorAttachments> mTargetBlend;
 
         // Other state
-        wgpu::PrimitiveTopology mPrimitiveTopology;
-        RasterizationStateDescriptor mRasterizationState;
-        uint32_t mSampleMask;
-        bool mAlphaToCoverageEnabled;
+        PrimitiveState mPrimitive;
+        DepthStencilState mDepthStencil;
+        MultisampleState mMultisample;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 3e39686..371e9a7 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -207,16 +207,18 @@
             return static_cast<uint8_t>(writeMask);
         }
 
-        D3D12_RENDER_TARGET_BLEND_DESC ComputeColorDesc(const ColorStateDescriptor* descriptor) {
-            D3D12_RENDER_TARGET_BLEND_DESC blendDesc;
-            blendDesc.BlendEnable = BlendEnabled(descriptor);
-            blendDesc.SrcBlend = D3D12Blend(descriptor->colorBlend.srcFactor);
-            blendDesc.DestBlend = D3D12Blend(descriptor->colorBlend.dstFactor);
-            blendDesc.BlendOp = D3D12BlendOperation(descriptor->colorBlend.operation);
-            blendDesc.SrcBlendAlpha = D3D12Blend(descriptor->alphaBlend.srcFactor);
-            blendDesc.DestBlendAlpha = D3D12Blend(descriptor->alphaBlend.dstFactor);
-            blendDesc.BlendOpAlpha = D3D12BlendOperation(descriptor->alphaBlend.operation);
-            blendDesc.RenderTargetWriteMask = D3D12RenderTargetWriteMask(descriptor->writeMask);
+        D3D12_RENDER_TARGET_BLEND_DESC ComputeColorDesc(const ColorTargetState* state) {
+            D3D12_RENDER_TARGET_BLEND_DESC blendDesc = {};  // Keep unset blend fields defined.
+            blendDesc.BlendEnable = state->blend != nullptr;
+            if (blendDesc.BlendEnable) {
+                blendDesc.SrcBlend = D3D12Blend(state->blend->color.srcFactor);
+                blendDesc.DestBlend = D3D12Blend(state->blend->color.dstFactor);
+                blendDesc.BlendOp = D3D12BlendOperation(state->blend->color.operation);
+                blendDesc.SrcBlendAlpha = D3D12Blend(state->blend->alpha.srcFactor);
+                blendDesc.DestBlendAlpha = D3D12Blend(state->blend->alpha.dstFactor);
+                blendDesc.BlendOpAlpha = D3D12BlendOperation(state->blend->alpha.operation);
+            }
+            blendDesc.RenderTargetWriteMask = D3D12RenderTargetWriteMask(state->writeMask);
             blendDesc.LogicOpEnable = false;
             blendDesc.LogicOp = D3D12_LOGIC_OP_NOOP;
             return blendDesc;
@@ -254,8 +256,7 @@
             return desc;
         }
 
-        D3D12_DEPTH_STENCIL_DESC ComputeDepthStencilDesc(
-            const DepthStencilStateDescriptor* descriptor) {
+        D3D12_DEPTH_STENCIL_DESC ComputeDepthStencilDesc(const DepthStencilState* descriptor) {
             D3D12_DEPTH_STENCIL_DESC mDepthStencilDescriptor;
             mDepthStencilDescriptor.DepthEnable = TRUE;
             mDepthStencilDescriptor.DepthWriteMask = descriptor->depthWriteEnabled
@@ -347,8 +348,8 @@
             descriptorD3D12.InputLayout = ComputeInputLayout(&inputElementDescriptors);
         }
 
-        descriptorD3D12.IBStripCutValue = ComputeIndexBufferStripCutValue(
-            GetPrimitiveTopology(), GetVertexStateDescriptor()->indexFormat);
+        descriptorD3D12.IBStripCutValue =
+            ComputeIndexBufferStripCutValue(GetPrimitiveTopology(), GetStripIndexFormat());
 
         descriptorD3D12.RasterizerState.FillMode = D3D12_FILL_MODE_SOLID;
         descriptorD3D12.RasterizerState.CullMode = D3D12CullMode(GetCullMode());
@@ -372,15 +373,14 @@
             descriptorD3D12.RTVFormats[static_cast<uint8_t>(i)] =
                 D3D12TextureFormat(GetColorAttachmentFormat(i));
             descriptorD3D12.BlendState.RenderTarget[static_cast<uint8_t>(i)] =
-                ComputeColorDesc(GetColorStateDescriptor(i));
+                ComputeColorDesc(GetColorTargetState(i));
         }
         descriptorD3D12.NumRenderTargets = static_cast<uint32_t>(GetColorAttachmentsMask().count());
 
         descriptorD3D12.BlendState.AlphaToCoverageEnable = descriptor->alphaToCoverageEnabled;
         descriptorD3D12.BlendState.IndependentBlendEnable = TRUE;
 
-        descriptorD3D12.DepthStencilState =
-            ComputeDepthStencilDesc(GetDepthStencilStateDescriptor());
+        descriptorD3D12.DepthStencilState = ComputeDepthStencilDesc(GetDepthStencilState());
 
         descriptorD3D12.SampleMask = GetSampleMask();
         descriptorD3D12.PrimitiveTopologyType = D3D12PrimitiveTopologyType(GetPrimitiveTopology());
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index a8794d7..6c2d356 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -284,7 +284,7 @@
                                                ->GetBufferBindingCount(SingleShaderStage::Vertex);
 
                     if (enableVertexPulling) {
-                        bufferCount += pipeline->GetVertexStateDescriptor()->vertexBufferCount;
+                        bufferCount += pipeline->GetVertexBufferCount();
                     }
 
                     [render setVertexBytes:data[SingleShaderStage::Vertex].data()
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index bea3898..b2f6c8f 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -201,21 +201,23 @@
         }
 
         void ComputeBlendDesc(MTLRenderPipelineColorAttachmentDescriptor* attachment,
-                              const ColorStateDescriptor* descriptor,
+                              const ColorTargetState* state,
                               bool isDeclaredInFragmentShader) {
-            attachment.blendingEnabled = BlendEnabled(descriptor);
-            attachment.sourceRGBBlendFactor =
-                MetalBlendFactor(descriptor->colorBlend.srcFactor, false);
-            attachment.destinationRGBBlendFactor =
-                MetalBlendFactor(descriptor->colorBlend.dstFactor, false);
-            attachment.rgbBlendOperation = MetalBlendOperation(descriptor->colorBlend.operation);
-            attachment.sourceAlphaBlendFactor =
-                MetalBlendFactor(descriptor->alphaBlend.srcFactor, true);
-            attachment.destinationAlphaBlendFactor =
-                MetalBlendFactor(descriptor->alphaBlend.dstFactor, true);
-            attachment.alphaBlendOperation = MetalBlendOperation(descriptor->alphaBlend.operation);
+            attachment.blendingEnabled = state->blend != nullptr;
+            if (attachment.blendingEnabled) {
+                attachment.sourceRGBBlendFactor =
+                    MetalBlendFactor(state->blend->color.srcFactor, false);
+                attachment.destinationRGBBlendFactor =
+                    MetalBlendFactor(state->blend->color.dstFactor, false);
+                attachment.rgbBlendOperation = MetalBlendOperation(state->blend->color.operation);
+                attachment.sourceAlphaBlendFactor =
+                    MetalBlendFactor(state->blend->alpha.srcFactor, true);
+                attachment.destinationAlphaBlendFactor =
+                    MetalBlendFactor(state->blend->alpha.dstFactor, true);
+                attachment.alphaBlendOperation = MetalBlendOperation(state->blend->alpha.operation);
+            }
             attachment.writeMask =
-                MetalColorWriteMask(descriptor->writeMask, isDeclaredInFragmentShader);
+                MetalColorWriteMask(state->writeMask, isDeclaredInFragmentShader);
         }
 
         MTLStencilOperation MetalStencilOperation(wgpu::StencilOperation stencilOperation) {
@@ -239,8 +241,7 @@
             }
         }
 
-        NSRef<MTLDepthStencilDescriptor> MakeDepthStencilDesc(
-            const DepthStencilStateDescriptor* descriptor) {
+        NSRef<MTLDepthStencilDescriptor> MakeDepthStencilDesc(const DepthStencilState* descriptor) {
             NSRef<MTLDepthStencilDescriptor> mtlDepthStencilDescRef =
                 AcquireNSRef([MTLDepthStencilDescriptor new]);
             MTLDepthStencilDescriptor* mtlDepthStencilDescriptor = mtlDepthStencilDescRef.Get();
@@ -340,9 +341,17 @@
         ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
         const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
         ShaderModule::MetalFunctionData vertexData;
+
+        const VertexStateDescriptor* vertexStatePtr = descriptor->vertexState;
+        VertexStateDescriptor vertexState;
+        if (vertexStatePtr == nullptr) {
+            vertexState = {};
+            vertexStatePtr = &vertexState;
+        }
+
         DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
-                                              ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
-                                              this));
+                                              ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this,
+                                              vertexStatePtr));
 
         descriptorMTL.vertexFunction = vertexData.function.Get();
         if (vertexData.needsStorageBufferLength) {
@@ -379,7 +388,7 @@
         for (ColorAttachmentIndex i : IterateBitSet(GetColorAttachmentsMask())) {
             descriptorMTL.colorAttachments[static_cast<uint8_t>(i)].pixelFormat =
                 MetalPixelFormat(GetColorAttachmentFormat(i));
-            const ColorStateDescriptor* descriptor = GetColorStateDescriptor(i);
+            const ColorTargetState* descriptor = GetColorTargetState(i);
             ComputeBlendDesc(descriptorMTL.colorAttachments[static_cast<uint8_t>(i)], descriptor,
                              fragmentOutputsWritten[i]);
         }
@@ -403,7 +412,7 @@
         // call setDepthStencilState() for a given render pipeline in CommandEncoder, in order to
         // improve performance.
         NSRef<MTLDepthStencilDescriptor> depthStencilDesc =
-            MakeDepthStencilDesc(GetDepthStencilStateDescriptor());
+            MakeDepthStencilDesc(GetDepthStencilState());
         mMtlDepthStencilState =
             AcquireNSPRef([mtlDevice newDepthStencilStateWithDescriptor:depthStencilDesc.Get()]);
 
diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h
index 3d777e9..1a4f549 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.h
+++ b/src/dawn_native/metal/ShaderModuleMTL.h
@@ -47,7 +47,8 @@
                                   const PipelineLayout* layout,
                                   MetalFunctionData* out,
                                   uint32_t sampleMask = 0xFFFFFFFF,
-                                  const RenderPipeline* renderPipeline = nullptr);
+                                  const RenderPipeline* renderPipeline = nullptr,
+                                  const VertexStateDescriptor* vertexState = nullptr);
 
       private:
         ResultOrError<std::string> TranslateToMSLWithTint(const char* entryPointName,
@@ -55,6 +56,7 @@
                                                           const PipelineLayout* layout,
                                                           uint32_t sampleMask,
                                                           const RenderPipeline* renderPipeline,
+                                                          const VertexStateDescriptor* vertexState,
                                                           std::string* remappedEntryPointName,
                                                           bool* needsStorageBufferLength);
         ResultOrError<std::string> TranslateToMSLWithSPIRVCross(
@@ -63,6 +65,7 @@
             const PipelineLayout* layout,
             uint32_t sampleMask,
             const RenderPipeline* renderPipeline,
+            const VertexStateDescriptor* vertexState,
             std::string* remappedEntryPointName,
             bool* needsStorageBufferLength);
 
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index f678ebd..9cc493d7 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -56,6 +56,7 @@
         // TODO(crbug.com/tint/387): AND in a fixed sample mask in the shader.
         uint32_t sampleMask,
         const RenderPipeline* renderPipeline,
+        const VertexStateDescriptor* vertexState,
         std::string* remappedEntryPointName,
         bool* needsStorageBufferLength) {
         // TODO(crbug.com/tint/256): Set this accordingly if arrayLength(..) is used.
@@ -68,8 +69,7 @@
         if (stage == SingleShaderStage::Vertex &&
             GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
             transformManager.append(
-                MakeVertexPullingTransform(*renderPipeline->GetVertexStateDescriptor(),
-                                           entryPointName, kPullingBufferBindingSet));
+                MakeVertexPullingTransform(*vertexState, entryPointName, kPullingBufferBindingSet));
 
             for (VertexBufferSlot slot :
                  IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
@@ -118,6 +118,7 @@
         const PipelineLayout* layout,
         uint32_t sampleMask,
         const RenderPipeline* renderPipeline,
+        const VertexStateDescriptor* vertexState,
         std::string* remappedEntryPointName,
         bool* needsStorageBufferLength) {
         const std::vector<uint32_t>* spirv = &GetSpirv();
@@ -128,14 +129,12 @@
             stage == SingleShaderStage::Vertex) {
             if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
                 DAWN_TRY_ASSIGN(pullingSpirv,
-                                GeneratePullingSpirv(GetTintProgram(),
-                                                     *renderPipeline->GetVertexStateDescriptor(),
-                                                     entryPointName, kPullingBufferBindingSet));
+                                GeneratePullingSpirv(GetTintProgram(), *vertexState, entryPointName,
+                                                     kPullingBufferBindingSet));
             } else {
-                DAWN_TRY_ASSIGN(
-                    pullingSpirv,
-                    GeneratePullingSpirv(GetSpirv(), *renderPipeline->GetVertexStateDescriptor(),
-                                         entryPointName, kPullingBufferBindingSet));
+                DAWN_TRY_ASSIGN(pullingSpirv,
+                                GeneratePullingSpirv(GetSpirv(), *vertexState, entryPointName,
+                                                     kPullingBufferBindingSet));
             }
             spirv = &pullingSpirv;
         }
@@ -228,20 +227,29 @@
                                             const PipelineLayout* layout,
                                             ShaderModule::MetalFunctionData* out,
                                             uint32_t sampleMask,
-                                            const RenderPipeline* renderPipeline) {
+                                            const RenderPipeline* renderPipeline,
+                                            const VertexStateDescriptor* vertexState) {
         ASSERT(!IsError());
         ASSERT(out);
 
+        // Vertex stages must specify a renderPipeline and vertexState
+        if (stage == SingleShaderStage::Vertex) {
+            ASSERT(renderPipeline != nullptr);
+            ASSERT(vertexState != nullptr);
+        }
+
         std::string remappedEntryPointName;
         std::string msl;
         if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
-            DAWN_TRY_ASSIGN(msl, TranslateToMSLWithTint(entryPointName, stage, layout, sampleMask,
-                                                        renderPipeline, &remappedEntryPointName,
-                                                        &out->needsStorageBufferLength));
+            DAWN_TRY_ASSIGN(
+                msl, TranslateToMSLWithTint(entryPointName, stage, layout, sampleMask,
+                                            renderPipeline, vertexState, &remappedEntryPointName,
+                                            &out->needsStorageBufferLength));
         } else {
-            DAWN_TRY_ASSIGN(msl, TranslateToMSLWithSPIRVCross(
-                                     entryPointName, stage, layout, sampleMask, renderPipeline,
-                                     &remappedEntryPointName, &out->needsStorageBufferLength));
+            DAWN_TRY_ASSIGN(msl, TranslateToMSLWithSPIRVCross(entryPointName, stage, layout,
+                                                              sampleMask, renderPipeline,
+                                                              vertexState, &remappedEntryPointName,
+                                                              &out->needsStorageBufferLength));
         }
 
         // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
diff --git a/src/dawn_native/opengl/RenderPipelineGL.cpp b/src/dawn_native/opengl/RenderPipelineGL.cpp
index 9a5da5c..c5e9805 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.cpp
+++ b/src/dawn_native/opengl/RenderPipelineGL.cpp
@@ -104,43 +104,42 @@
 
         void ApplyColorState(const OpenGLFunctions& gl,
                              ColorAttachmentIndex attachment,
-                             const ColorStateDescriptor* descriptor) {
+                             const ColorTargetState* state) {
             GLuint colorBuffer = static_cast<GLuint>(static_cast<uint8_t>(attachment));
-            if (BlendEnabled(descriptor)) {
+            if (state->blend != nullptr) {
                 gl.Enablei(GL_BLEND, colorBuffer);
-                gl.BlendEquationSeparatei(colorBuffer,
-                                          GLBlendMode(descriptor->colorBlend.operation),
-                                          GLBlendMode(descriptor->alphaBlend.operation));
+                gl.BlendEquationSeparatei(colorBuffer, GLBlendMode(state->blend->color.operation),
+                                          GLBlendMode(state->blend->alpha.operation));
                 gl.BlendFuncSeparatei(colorBuffer,
-                                      GLBlendFactor(descriptor->colorBlend.srcFactor, false),
-                                      GLBlendFactor(descriptor->colorBlend.dstFactor, false),
-                                      GLBlendFactor(descriptor->alphaBlend.srcFactor, true),
-                                      GLBlendFactor(descriptor->alphaBlend.dstFactor, true));
+                                      GLBlendFactor(state->blend->color.srcFactor, false),
+                                      GLBlendFactor(state->blend->color.dstFactor, false),
+                                      GLBlendFactor(state->blend->alpha.srcFactor, true),
+                                      GLBlendFactor(state->blend->alpha.dstFactor, true));
             } else {
                 gl.Disablei(GL_BLEND, colorBuffer);
             }
-            gl.ColorMaski(colorBuffer, descriptor->writeMask & wgpu::ColorWriteMask::Red,
-                          descriptor->writeMask & wgpu::ColorWriteMask::Green,
-                          descriptor->writeMask & wgpu::ColorWriteMask::Blue,
-                          descriptor->writeMask & wgpu::ColorWriteMask::Alpha);
+            gl.ColorMaski(colorBuffer, state->writeMask & wgpu::ColorWriteMask::Red,
+                          state->writeMask & wgpu::ColorWriteMask::Green,
+                          state->writeMask & wgpu::ColorWriteMask::Blue,
+                          state->writeMask & wgpu::ColorWriteMask::Alpha);
         }
 
-        void ApplyColorState(const OpenGLFunctions& gl, const ColorStateDescriptor* descriptor) {
-            if (BlendEnabled(descriptor)) {
+        void ApplyColorState(const OpenGLFunctions& gl, const ColorTargetState* state) {
+            if (state->blend != nullptr) {
                 gl.Enable(GL_BLEND);
-                gl.BlendEquationSeparate(GLBlendMode(descriptor->colorBlend.operation),
-                                         GLBlendMode(descriptor->alphaBlend.operation));
-                gl.BlendFuncSeparate(GLBlendFactor(descriptor->colorBlend.srcFactor, false),
-                                     GLBlendFactor(descriptor->colorBlend.dstFactor, false),
-                                     GLBlendFactor(descriptor->alphaBlend.srcFactor, true),
-                                     GLBlendFactor(descriptor->alphaBlend.dstFactor, true));
+                gl.BlendEquationSeparate(GLBlendMode(state->blend->color.operation),
+                                         GLBlendMode(state->blend->alpha.operation));
+                gl.BlendFuncSeparate(GLBlendFactor(state->blend->color.srcFactor, false),
+                                     GLBlendFactor(state->blend->color.dstFactor, false),
+                                     GLBlendFactor(state->blend->alpha.srcFactor, true),
+                                     GLBlendFactor(state->blend->alpha.dstFactor, true));
             } else {
                 gl.Disable(GL_BLEND);
             }
-            gl.ColorMask(descriptor->writeMask & wgpu::ColorWriteMask::Red,
-                         descriptor->writeMask & wgpu::ColorWriteMask::Green,
-                         descriptor->writeMask & wgpu::ColorWriteMask::Blue,
-                         descriptor->writeMask & wgpu::ColorWriteMask::Alpha);
+            gl.ColorMask(state->writeMask & wgpu::ColorWriteMask::Red,
+                         state->writeMask & wgpu::ColorWriteMask::Green,
+                         state->writeMask & wgpu::ColorWriteMask::Blue,
+                         state->writeMask & wgpu::ColorWriteMask::Alpha);
         }
 
         bool Equal(const BlendDescriptor& lhs, const BlendDescriptor& rhs) {
@@ -170,7 +169,7 @@
         }
 
         void ApplyDepthStencilState(const OpenGLFunctions& gl,
-                                    const DepthStencilStateDescriptor* descriptor,
+                                    const DepthStencilState* descriptor,
                                     PersistentPipelineState* persistentPipelineState) {
             // Depth writes only occur if depth is enabled
             if (descriptor->depthCompare == wgpu::CompareFunction::Always &&
@@ -278,7 +277,7 @@
 
         ApplyFrontFaceAndCulling(gl, GetFrontFace(), GetCullMode());
 
-        ApplyDepthStencilState(gl, GetDepthStencilStateDescriptor(), &persistentPipelineState);
+        ApplyDepthStencilState(gl, GetDepthStencilState(), &persistentPipelineState);
 
         gl.SampleMaski(0, GetSampleMask());
         if (IsAlphaToCoverageEnabled()) {
@@ -302,21 +301,26 @@
 
         if (!GetDevice()->IsToggleEnabled(Toggle::DisableIndexedDrawBuffers)) {
             for (ColorAttachmentIndex attachmentSlot : IterateBitSet(GetColorAttachmentsMask())) {
-                ApplyColorState(gl, attachmentSlot, GetColorStateDescriptor(attachmentSlot));
+                ApplyColorState(gl, attachmentSlot, GetColorTargetState(attachmentSlot));
             }
         } else {
-            const ColorStateDescriptor* prevDescriptor = nullptr;
+            const ColorTargetState* prevDescriptor = nullptr;
             for (ColorAttachmentIndex attachmentSlot : IterateBitSet(GetColorAttachmentsMask())) {
-                const ColorStateDescriptor* descriptor = GetColorStateDescriptor(attachmentSlot);
+                const ColorTargetState* descriptor = GetColorTargetState(attachmentSlot);
                 if (!prevDescriptor) {
                     ApplyColorState(gl, descriptor);
                     prevDescriptor = descriptor;
-                } else if (!Equal(descriptor->alphaBlend, prevDescriptor->alphaBlend) ||
-                           !Equal(descriptor->colorBlend, prevDescriptor->colorBlend) ||
-                           descriptor->writeMask != prevDescriptor->writeMask) {
-                    // TODO(crbug.com/dawn/582): Add validation to prevent this as it is not
-                    // supported on GLES < 3.2.
+                } else if ((descriptor->blend == nullptr) != (prevDescriptor->blend == nullptr)) {
+                    // TODO(crbug.com/dawn/582): GLES < 3.2 does not support different blend states
+                    // per color target. Add validation to prevent this.
                     ASSERT(false);
+                } else if (descriptor->blend != nullptr) {
+                    if (!Equal(descriptor->blend->alpha, prevDescriptor->blend->alpha) ||
+                        !Equal(descriptor->blend->color, prevDescriptor->blend->color) ||
+                        descriptor->writeMask != prevDescriptor->writeMask) {
+                        // TODO(crbug.com/dawn/582): Validate that all targets match.
+                        ASSERT(false);
+                    }
                 }
             }
         }
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index 2d9c627..2438fde 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -222,18 +222,29 @@
                                               : static_cast<VkColorComponentFlags>(0);
         }
 
-        VkPipelineColorBlendAttachmentState ComputeColorDesc(const ColorStateDescriptor* descriptor,
+        VkPipelineColorBlendAttachmentState ComputeColorDesc(const ColorTargetState* state,
                                                              bool isDeclaredInFragmentShader) {
             VkPipelineColorBlendAttachmentState attachment;
-            attachment.blendEnable = BlendEnabled(descriptor) ? VK_TRUE : VK_FALSE;
-            attachment.srcColorBlendFactor = VulkanBlendFactor(descriptor->colorBlend.srcFactor);
-            attachment.dstColorBlendFactor = VulkanBlendFactor(descriptor->colorBlend.dstFactor);
-            attachment.colorBlendOp = VulkanBlendOperation(descriptor->colorBlend.operation);
-            attachment.srcAlphaBlendFactor = VulkanBlendFactor(descriptor->alphaBlend.srcFactor);
-            attachment.dstAlphaBlendFactor = VulkanBlendFactor(descriptor->alphaBlend.dstFactor);
-            attachment.alphaBlendOp = VulkanBlendOperation(descriptor->alphaBlend.operation);
+            attachment.blendEnable = state->blend != nullptr ? VK_TRUE : VK_FALSE;
+            if (attachment.blendEnable) {
+                attachment.srcColorBlendFactor = VulkanBlendFactor(state->blend->color.srcFactor);
+                attachment.dstColorBlendFactor = VulkanBlendFactor(state->blend->color.dstFactor);
+                attachment.colorBlendOp = VulkanBlendOperation(state->blend->color.operation);
+                attachment.srcAlphaBlendFactor = VulkanBlendFactor(state->blend->alpha.srcFactor);
+                attachment.dstAlphaBlendFactor = VulkanBlendFactor(state->blend->alpha.dstFactor);
+                attachment.alphaBlendOp = VulkanBlendOperation(state->blend->alpha.operation);
+            } else {
+                // Swiftshader's Vulkan implementation appears to expect these values to be valid
+                // even when blending is not enabled.
+                attachment.srcColorBlendFactor = VK_BLEND_FACTOR_ONE;
+                attachment.dstColorBlendFactor = VK_BLEND_FACTOR_ZERO;
+                attachment.colorBlendOp = VK_BLEND_OP_ADD;
+                attachment.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+                attachment.dstAlphaBlendFactor = VK_BLEND_FACTOR_ZERO;
+                attachment.alphaBlendOp = VK_BLEND_OP_ADD;
+            }
             attachment.colorWriteMask =
-                VulkanColorWriteMask(descriptor->writeMask, isDeclaredInFragmentShader);
+                VulkanColorWriteMask(state->writeMask, isDeclaredInFragmentShader);
             return attachment;
         }
 
@@ -259,7 +270,7 @@
         }
 
         VkPipelineDepthStencilStateCreateInfo ComputeDepthStencilDesc(
-            const DepthStencilStateDescriptor* descriptor) {
+            const DepthStencilState* descriptor) {
             VkPipelineDepthStencilStateCreateInfo depthStencilState;
             depthStencilState.sType = VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO;
             depthStencilState.pNext = nullptr;
@@ -404,7 +415,7 @@
         multisample.alphaToOneEnable = VK_FALSE;
 
         VkPipelineDepthStencilStateCreateInfo depthStencilState =
-            ComputeDepthStencilDesc(GetDepthStencilStateDescriptor());
+            ComputeDepthStencilDesc(GetDepthStencilState());
 
         // Initialize the "blend state info" that will be chained in the "create info" from the data
         // pre-computed in the ColorState
@@ -413,9 +424,8 @@
         const auto& fragmentOutputsWritten =
             GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputsWritten;
         for (ColorAttachmentIndex i : IterateBitSet(GetColorAttachmentsMask())) {
-            const ColorStateDescriptor* colorStateDescriptor = GetColorStateDescriptor(i);
-            colorBlendAttachments[i] =
-                ComputeColorDesc(colorStateDescriptor, fragmentOutputsWritten[i]);
+            const ColorTargetState* target = GetColorTargetState(i);
+            colorBlendAttachments[i] = ComputeColorDesc(target, fragmentOutputsWritten[i]);
         }
         VkPipelineColorBlendStateCreateInfo colorBlend;
         colorBlend.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;