Add setIndexBufferWithFormat method

First step of a multi-part change to bring the setIndexBuffer
method up-to-date with the current WebGPU spec. This change
preserves the previous setIndexBuffer semantics for backwards
compatibility until developers have been notified and given
a grace period to transition to the new signature.

BUG=dawn:502
Change-Id: Ia8c665639494d244f52296ceadaedb320fa6c985
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27182
Commit-Queue: Brandon Jones <bajones@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/dawn.json b/dawn.json
index 48aa2fa..e9cf0dd 100644
--- a/dawn.json
+++ b/dawn.json
@@ -779,8 +779,9 @@
     "index format": {
         "category": "enum",
         "values": [
-            {"value": 0, "name": "uint16"},
-            {"value": 1, "name": "uint32"}
+            {"value": 0, "name": "undefined", "jsrepr": "undefined"},
+            {"value": 1, "name": "uint16"},
+            {"value": 2, "name": "uint32"}
         ]
     },
     "instance": {
@@ -823,7 +824,7 @@
         "category": "structure",
         "extensible": true,
         "members": [
-            {"name": "index format", "type": "index format", "default": "uint32"},
+            {"name": "index format", "type": "index format", "default": "undefined"},
             {"name": "vertex buffer count", "type": "uint32_t", "default": 0},
             {"name": "vertex buffers", "type": "vertex buffer layout descriptor", "annotation": "const*", "length": "vertex buffer count"}
         ]
@@ -1087,6 +1088,15 @@
                 ]
             },
             {
+                "name": "set index buffer with format",
+                "args": [
+                    {"name": "buffer", "type": "buffer"},
+                    {"name": "format", "type": "index format"},
+                    {"name": "offset", "type": "uint64_t", "default": "0"},
+                    {"name": "size", "type": "uint64_t", "default": "0"}
+                ]
+            },
+            {
                 "name": "finish",
                 "returns": "render bundle",
                 "args": [
@@ -1277,6 +1287,15 @@
                 ]
             },
             {
+                "name": "set index buffer with format",
+                "args": [
+                    {"name": "buffer", "type": "buffer"},
+                    {"name": "format", "type": "index format"},
+                    {"name": "offset", "type": "uint64_t", "default": "0"},
+                    {"name": "size", "type": "uint64_t", "default": "0"}
+                ]
+            },
+            {
                 "name": "write timestamp",
                 "args": [
                     {"name": "query set", "type": "query set"},
diff --git a/examples/CppHelloTriangle.cpp b/examples/CppHelloTriangle.cpp
index 378afa8..c3548ab 100644
--- a/examples/CppHelloTriangle.cpp
+++ b/examples/CppHelloTriangle.cpp
@@ -164,7 +164,7 @@
         pass.SetPipeline(pipeline);
         pass.SetBindGroup(0, bindGroup);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.DrawIndexed(3);
         pass.EndPass();
     }
diff --git a/examples/CubeReflection.cpp b/examples/CubeReflection.cpp
index 4ff18e0..52d32b5 100644
--- a/examples/CubeReflection.cpp
+++ b/examples/CubeReflection.cpp
@@ -251,7 +251,7 @@
         pass.SetPipeline(pipeline);
         pass.SetBindGroup(0, bindGroup[0]);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.DrawIndexed(36);
 
         pass.SetStencilReference(0x1);
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 7c4a327..3e806a1 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -61,7 +61,8 @@
         1 << VALIDATION_ASPECT_VERTEX_BUFFERS | 1 << VALIDATION_ASPECT_INDEX_BUFFER;
 
     static constexpr CommandBufferStateTracker::ValidationAspects kLazyAspects =
-        1 << VALIDATION_ASPECT_BIND_GROUPS | 1 << VALIDATION_ASPECT_VERTEX_BUFFERS;
+        1 << VALIDATION_ASPECT_BIND_GROUPS | 1 << VALIDATION_ASPECT_VERTEX_BUFFERS |
+        1 << VALIDATION_ASPECT_INDEX_BUFFER;
 
     MaybeError CommandBufferStateTracker::ValidateCanDispatch() {
         return ValidateOperation(kDispatchAspects);
@@ -124,6 +125,23 @@
                 mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
             }
         }
+
+        if (aspects[VALIDATION_ASPECT_INDEX_BUFFER]) {
+            if (mIndexBufferSet) {
+                wgpu::IndexFormat pipelineIndexFormat =
+                    mLastRenderPipeline->GetVertexStateDescriptor()->indexFormat;
+                if (mIndexFormat != wgpu::IndexFormat::Undefined) {
+                    if (!mLastRenderPipeline->IsStripPrimitiveTopology() ||
+                        mIndexFormat == pipelineIndexFormat) {
+                        mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
+                    }
+                } else if (pipelineIndexFormat != wgpu::IndexFormat::Undefined) {
+                    // TODO(crbug.com/dawn/502): Deprecated path. Remove once setIndexFormat always
+                    // requires an index format.
+                    mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
+                }
+            }
+        }
     }
 
     MaybeError CommandBufferStateTracker::CheckMissingAspects(ValidationAspects aspects) {
@@ -132,7 +150,29 @@
         }
 
         if (aspects[VALIDATION_ASPECT_INDEX_BUFFER]) {
-            return DAWN_VALIDATION_ERROR("Missing index buffer");
+            wgpu::IndexFormat pipelineIndexFormat =
+                mLastRenderPipeline->GetVertexStateDescriptor()->indexFormat;
+            if (!mIndexBufferSet) {
+                return DAWN_VALIDATION_ERROR("Missing index buffer");
+            } else if (mIndexFormat != wgpu::IndexFormat::Undefined &&
+                mLastRenderPipeline->IsStripPrimitiveTopology() &&
+                mIndexFormat != pipelineIndexFormat) {
+                return DAWN_VALIDATION_ERROR(
+                    "Pipeline strip index format does not match index buffer format");
+            } else if (mIndexFormat == wgpu::IndexFormat::Undefined &&
+                       pipelineIndexFormat == wgpu::IndexFormat::Undefined) {
+                // TODO(crbug.com/dawn/502): Deprecated path. Remove once setIndexFormat always
+                // requires an index format.
+                return DAWN_VALIDATION_ERROR(
+                    "Index format must be specified on the pipeline or in setIndexBuffer");
+            }
+
+            // The chunk of code above should be similar to the one in |RecomputeLazyAspects|.
+            // It returns the first invalid state found. We shouldn't be able to reach this line
+            // because to have invalid aspects one of the above conditions must have failed earlier.
+            // If this is reached, make sure lazy aspects and the error checks above are consistent.
+            UNREACHABLE();
+            return DAWN_VALIDATION_ERROR("Index buffer invalid");
         }
 
         if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
@@ -185,8 +225,9 @@
         mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
     }
 
-    void CommandBufferStateTracker::SetIndexBuffer() {
-        mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
+    void CommandBufferStateTracker::SetIndexBuffer(wgpu::IndexFormat format) {
+        mIndexBufferSet = true;
+        mIndexFormat = format;
     }
 
     void CommandBufferStateTracker::SetVertexBuffer(uint32_t slot) {
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 39d32fd..67645e4 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -38,7 +38,7 @@
         void SetComputePipeline(ComputePipelineBase* pipeline);
         void SetRenderPipeline(RenderPipelineBase* pipeline);
         void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
-        void SetIndexBuffer();
+        void SetIndexBuffer(wgpu::IndexFormat format);
         void SetVertexBuffer(uint32_t slot);
 
         static constexpr size_t kNumAspects = 4;
@@ -55,6 +55,8 @@
 
         ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
         std::bitset<kMaxVertexBuffers> mVertexBufferSlotsUsed;
+        bool mIndexBufferSet = false;
+        wgpu::IndexFormat mIndexFormat;
 
         PipelineLayoutBase* mLastPipelineLayout = nullptr;
         RenderPipelineBase* mLastRenderPipeline = nullptr;
diff --git a/src/dawn_native/CommandValidation.cpp b/src/dawn_native/CommandValidation.cpp
index 497a8de..1184dc3 100644
--- a/src/dawn_native/CommandValidation.cpp
+++ b/src/dawn_native/CommandValidation.cpp
@@ -102,8 +102,8 @@
                 }
 
                 case Command::SetIndexBuffer: {
-                    commands->NextCommand<SetIndexBufferCmd>();
-                    commandBufferState->SetIndexBuffer();
+                    SetIndexBufferCmd* cmd = commands->NextCommand<SetIndexBufferCmd>();
+                    commandBufferState->SetIndexBuffer(cmd->format);
                     break;
                 }
 
diff --git a/src/dawn_native/Commands.h b/src/dawn_native/Commands.h
index 0948678..0d9f57c 100644
--- a/src/dawn_native/Commands.h
+++ b/src/dawn_native/Commands.h
@@ -228,6 +228,7 @@
 
     struct SetIndexBufferCmd {
         Ref<BufferBase> buffer;
+        wgpu::IndexFormat format;
         uint64_t offset;
         uint64_t size;
     };
diff --git a/src/dawn_native/RenderEncoderBase.cpp b/src/dawn_native/RenderEncoderBase.cpp
index 7285fba..e662309 100644
--- a/src/dawn_native/RenderEncoderBase.cpp
+++ b/src/dawn_native/RenderEncoderBase.cpp
@@ -136,6 +136,20 @@
     }
 
     void RenderEncoderBase::SetIndexBuffer(BufferBase* buffer, uint64_t offset, uint64_t size) {
+        GetDevice()->EmitDeprecationWarning(
+            "RenderEncoderBase::SetIndexBuffer is deprecated. Use RenderEncoderBase::SetIndexBufferWithFormat instead");
+
+        SetIndexBufferCommon(buffer, wgpu::IndexFormat::Undefined, offset, size, false);
+    }
+
+    void RenderEncoderBase::SetIndexBufferWithFormat(BufferBase* buffer, wgpu::IndexFormat format,
+                                                     uint64_t offset, uint64_t size) {
+        SetIndexBufferCommon(buffer, format, offset, size, true);
+    }
+
+    void RenderEncoderBase::SetIndexBufferCommon(BufferBase* buffer, wgpu::IndexFormat format,
+                                                 uint64_t offset, uint64_t size,
+                                                 bool requireFormat) {
         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
             DAWN_TRY(GetDevice()->ValidateObject(buffer));
 
@@ -153,9 +167,16 @@
                 }
             }
 
+            if (requireFormat && format == wgpu::IndexFormat::Undefined) {
+                return DAWN_VALIDATION_ERROR("Index format must be specified");
+            } else if (!requireFormat) {
+                ASSERT(format == wgpu::IndexFormat::Undefined);
+            }
+
             SetIndexBufferCmd* cmd =
                 allocator->Allocate<SetIndexBufferCmd>(Command::SetIndexBuffer);
             cmd->buffer = buffer;
+            cmd->format = format;
             cmd->offset = offset;
             cmd->size = size;
 
diff --git a/src/dawn_native/RenderEncoderBase.h b/src/dawn_native/RenderEncoderBase.h
index a4f3b9f..b3f8543 100644
--- a/src/dawn_native/RenderEncoderBase.h
+++ b/src/dawn_native/RenderEncoderBase.h
@@ -41,12 +41,17 @@
 
         void SetVertexBuffer(uint32_t slot, BufferBase* buffer, uint64_t offset, uint64_t size);
         void SetIndexBuffer(BufferBase* buffer, uint64_t offset, uint64_t size);
+        void SetIndexBufferWithFormat(BufferBase* buffer, wgpu::IndexFormat format, uint64_t offset,
+                                      uint64_t size);
 
       protected:
         // Construct an "error" render encoder base.
         RenderEncoderBase(DeviceBase* device, EncodingContext* encodingContext, ErrorTag errorTag);
 
       private:
+        void SetIndexBufferCommon(BufferBase* buffer, wgpu::IndexFormat format, uint64_t offset,
+                                  uint64_t size, bool requireFormat);
+
         const bool mDisableBaseVertex;
         const bool mDisableBaseInstance;
     };
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 87c9f62..83d9b11 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -86,12 +86,22 @@
 
         MaybeError ValidateVertexStateDescriptor(
             const VertexStateDescriptor* descriptor,
+            wgpu::PrimitiveTopology primitiveTopology,
             std::bitset<kMaxVertexAttributes>* attributesSetMask) {
             if (descriptor->nextInChain != nullptr) {
                 return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
             }
             DAWN_TRY(ValidateIndexFormat(descriptor->indexFormat));
 
+            // Pipeline descriptors using strip topologies must not have an undefined index format.
+            if (descriptor->indexFormat == wgpu::IndexFormat::Undefined) {
+                if (primitiveTopology == wgpu::PrimitiveTopology::LineStrip ||
+                    primitiveTopology == wgpu::PrimitiveTopology::TriangleStrip) {
+                    return DAWN_VALIDATION_ERROR(
+                        "indexFormat must not be undefined when using strip primitive topologies");
+                }
+            }
+
             if (descriptor->vertexBufferCount > kMaxVertexBuffers) {
                 return DAWN_VALIDATION_ERROR("Vertex buffer count exceeds maximum");
             }
@@ -321,12 +331,14 @@
             return DAWN_VALIDATION_ERROR("Null fragment stage is not supported (yet)");
         }
 
+        DAWN_TRY(ValidatePrimitiveTopology(descriptor->primitiveTopology));
+
         std::bitset<kMaxVertexAttributes> attributesSetMask;
         if (descriptor->vertexState) {
-            DAWN_TRY(ValidateVertexStateDescriptor(descriptor->vertexState, &attributesSetMask));
+            DAWN_TRY(ValidateVertexStateDescriptor(
+                descriptor->vertexState, descriptor->primitiveTopology, &attributesSetMask));
         }
 
-        DAWN_TRY(ValidatePrimitiveTopology(descriptor->primitiveTopology));
         DAWN_TRY(ValidateProgrammableStageDescriptor(
             device, &descriptor->vertexStage, descriptor->layout, SingleShaderStage::Vertex));
         DAWN_TRY(ValidateProgrammableStageDescriptor(
@@ -531,6 +543,12 @@
         return mPrimitiveTopology;
     }
 
+    bool RenderPipelineBase::IsStripPrimitiveTopology() const {
+        ASSERT(!IsError());
+        return mPrimitiveTopology == wgpu::PrimitiveTopology::LineStrip ||
+               mPrimitiveTopology == wgpu::PrimitiveTopology::TriangleStrip;
+    }
+
     wgpu::CullMode RenderPipelineBase::GetCullMode() const {
         ASSERT(!IsError());
         return mRasterizationState.cullMode;
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index bdc3af0..5c328f9 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -68,6 +68,7 @@
         const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot) const;
         const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor() const;
         wgpu::PrimitiveTopology GetPrimitiveTopology() const;
+        bool IsStripPrimitiveTopology() const;
         wgpu::CullMode GetCullMode() const;
         wgpu::FrontFace GetFrontFace() const;
 
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 8ca8fc2..8e8cdd7 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -48,6 +48,8 @@
 
         DXGI_FORMAT DXGIIndexFormat(wgpu::IndexFormat format) {
             switch (format) {
+                case wgpu::IndexFormat::Undefined:
+                    return DXGI_FORMAT_UNKNOWN;
                 case wgpu::IndexFormat::Uint16:
                     return DXGI_FORMAT_R16_UINT;
                 case wgpu::IndexFormat::Uint32:
@@ -467,9 +469,11 @@
 
         class IndexBufferTracker {
           public:
-            void OnSetIndexBuffer(Buffer* buffer, uint64_t offset, uint64_t size) {
+            void OnSetIndexBuffer(Buffer* buffer, wgpu::IndexFormat format, uint64_t offset,
+                                  uint64_t size) {
                 mD3D12BufferView.BufferLocation = buffer->GetVA() + offset;
                 mD3D12BufferView.SizeInBytes = size;
+                mBufferIndexFormat = DXGIIndexFormat(format);
 
                 // We don't need to dirty the state unless BufferLocation or SizeInBytes
                 // change, but most of the time this will always be the case.
@@ -477,20 +481,26 @@
             }
 
             void OnSetPipeline(const RenderPipelineBase* pipeline) {
-                mD3D12BufferView.Format =
+                mPipelineIndexFormat =
                     DXGIIndexFormat(pipeline->GetVertexStateDescriptor()->indexFormat);
             }
 
             void Apply(ID3D12GraphicsCommandList* commandList) {
-                if (mD3D12BufferView.Format == mLastAppliedIndexFormat) {
-                    return;
+                DXGI_FORMAT newIndexFormat = mBufferIndexFormat;
+                if (newIndexFormat == DXGI_FORMAT_UNKNOWN) {
+                    newIndexFormat = mPipelineIndexFormat;
                 }
 
-                commandList->IASetIndexBuffer(&mD3D12BufferView);
-                mLastAppliedIndexFormat = mD3D12BufferView.Format;
+                if (newIndexFormat != mLastAppliedIndexFormat) {
+                    mD3D12BufferView.Format = newIndexFormat;
+                    commandList->IASetIndexBuffer(&mD3D12BufferView);
+                    mLastAppliedIndexFormat = newIndexFormat;
+                }
             }
 
           private:
+            DXGI_FORMAT mBufferIndexFormat = DXGI_FORMAT_UNKNOWN;
+            DXGI_FORMAT mPipelineIndexFormat = DXGI_FORMAT_UNKNOWN;
             DXGI_FORMAT mLastAppliedIndexFormat = DXGI_FORMAT_UNKNOWN;
             D3D12_INDEX_BUFFER_VIEW mD3D12BufferView = {};
         };
@@ -1285,8 +1295,8 @@
                 case Command::SetIndexBuffer: {
                     SetIndexBufferCmd* cmd = iter->NextCommand<SetIndexBufferCmd>();
 
-                    indexBufferTracker.OnSetIndexBuffer(ToBackend(cmd->buffer.Get()), cmd->offset,
-                                                        cmd->size);
+                    indexBufferTracker.OnSetIndexBuffer(ToBackend(cmd->buffer.Get()), cmd->format,
+                                                        cmd->offset, cmd->size);
                     break;
                 }
 
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 351090a..7891c7f 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -41,6 +41,17 @@
             MTLStoreActionStoreAndMultisampleResolve;
 #pragma clang diagnostic pop
 
+        MTLIndexType MTLIndexFormat(wgpu::IndexFormat format) {
+            switch (format) {
+                case wgpu::IndexFormat::Uint16:
+                    return MTLIndexTypeUInt16;
+                case wgpu::IndexFormat::Uint32:
+                    return MTLIndexTypeUInt32;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
         // Creates an autoreleased MTLRenderPassDescriptor matching desc
         MTLRenderPassDescriptor* CreateMTLRenderPassDescriptor(BeginRenderPassCmd* renderPass) {
             MTLRenderPassDescriptor* descriptor = [MTLRenderPassDescriptor renderPassDescriptor];
@@ -979,6 +990,7 @@
         RenderPipeline* lastPipeline = nullptr;
         id<MTLBuffer> indexBuffer = nil;
         uint32_t indexBufferBaseOffset = 0;
+        wgpu::IndexFormat indexBufferFormat = wgpu::IndexFormat::Undefined;
         StorageBufferLengthTracker storageBufferLengths = {};
         VertexBufferTracker vertexBuffers(&storageBufferLengths);
         BindGroupTracker bindGroups(&storageBufferLengths);
@@ -1015,13 +1027,20 @@
 
                 case Command::DrawIndexed: {
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
-                    size_t formatSize =
-                        IndexFormatSize(lastPipeline->GetVertexStateDescriptor()->indexFormat);
 
                     vertexBuffers.Apply(encoder, lastPipeline, enableVertexPulling);
                     bindGroups.Apply(encoder);
                     storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
+                    // If a index format was specified in setIndexBuffer always use it.
+                    wgpu::IndexFormat indexFormat = indexBufferFormat;
+                    if (indexFormat == wgpu::IndexFormat::Undefined) {
+                        // Otherwise use the pipeline's index format.
+                        // TODO(crbug.com/dawn/502): This path is deprecated.
+                        indexFormat = lastPipeline->GetVertexStateDescriptor()->indexFormat;
+                    }
+                    size_t formatSize = IndexFormatSize(indexFormat);
+
                     // The index and instance count must be non-zero, otherwise no-op
                     if (draw->indexCount != 0 && draw->instanceCount != 0) {
                         // MTLFeatureSet_iOS_GPUFamily3_v1 does not support baseInstance and
@@ -1029,7 +1048,7 @@
                         if (draw->baseVertex == 0 && draw->firstInstance == 0) {
                             [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
                                                 indexCount:draw->indexCount
-                                                 indexType:lastPipeline->GetMTLIndexType()
+                                                 indexType:MTLIndexFormat(indexFormat)
                                                indexBuffer:indexBuffer
                                          indexBufferOffset:indexBufferBaseOffset +
                                                            draw->firstIndex * formatSize
@@ -1037,7 +1056,7 @@
                         } else {
                             [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
                                                 indexCount:draw->indexCount
-                                                 indexType:lastPipeline->GetMTLIndexType()
+                                                 indexType:MTLIndexFormat(indexFormat)
                                                indexBuffer:indexBuffer
                                          indexBufferOffset:indexBufferBaseOffset +
                                                            draw->firstIndex * formatSize
@@ -1071,10 +1090,18 @@
                     bindGroups.Apply(encoder);
                     storageBufferLengths.Apply(encoder, lastPipeline, enableVertexPulling);
 
+                    // If a index format was specified in setIndexBuffer always use it.
+                    wgpu::IndexFormat indexFormat = indexBufferFormat;
+                    if (indexFormat == wgpu::IndexFormat::Undefined) {
+                        // Otherwise use the pipeline's index format.
+                        // TODO(crbug.com/dawn/502): This path is deprecated.
+                        indexFormat = lastPipeline->GetVertexStateDescriptor()->indexFormat;
+                    }
+
                     Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
                     id<MTLBuffer> indirectBuffer = buffer->GetMTLBuffer();
                     [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
-                                         indexType:lastPipeline->GetMTLIndexType()
+                                         indexType:MTLIndexFormat(indexFormat)
                                        indexBuffer:indexBuffer
                                  indexBufferOffset:indexBufferBaseOffset
                                     indirectBuffer:indirectBuffer
@@ -1142,6 +1169,9 @@
                     auto b = ToBackend(cmd->buffer.Get());
                     indexBuffer = b->GetMTLBuffer();
                     indexBufferBaseOffset = cmd->offset;
+                    // TODO(crbug.com/dawn/502): Once setIndexBuffer is required to specify an
+                    // index buffer format store as an MTLIndexType.
+                    indexBufferFormat = cmd->format;
                     break;
                 }
 
diff --git a/src/dawn_native/metal/RenderPipelineMTL.h b/src/dawn_native/metal/RenderPipelineMTL.h
index e27e1c4..4d8656b 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.h
+++ b/src/dawn_native/metal/RenderPipelineMTL.h
@@ -28,7 +28,6 @@
         static ResultOrError<RenderPipeline*> Create(Device* device,
                                                      const RenderPipelineDescriptor* descriptor);
 
-        MTLIndexType GetMTLIndexType() const;
         MTLPrimitiveType GetMTLPrimitiveTopology() const;
         MTLWinding GetMTLFrontFace() const;
         MTLCullMode GetMTLCullMode() const;
@@ -50,7 +49,6 @@
 
         MTLVertexDescriptor* MakeVertexDesc();
 
-        MTLIndexType mMtlIndexType;
         MTLPrimitiveType mMtlPrimitiveTopology;
         MTLWinding mMtlFrontFace;
         MTLCullMode mMtlCullMode;
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index 1e9efe1..5823ffe 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -126,15 +126,6 @@
             }
         }
 
-        MTLIndexType MTLIndexFormat(wgpu::IndexFormat format) {
-            switch (format) {
-                case wgpu::IndexFormat::Uint16:
-                    return MTLIndexTypeUInt16;
-                case wgpu::IndexFormat::Uint32:
-                    return MTLIndexTypeUInt32;
-            }
-        }
-
         MTLBlendFactor MetalBlendFactor(wgpu::BlendFactor factor, bool alpha) {
             switch (factor) {
                 case wgpu::BlendFactor::Zero:
@@ -321,7 +312,6 @@
     }
 
     MaybeError RenderPipeline::Initialize(const RenderPipelineDescriptor* descriptor) {
-        mMtlIndexType = MTLIndexFormat(GetVertexStateDescriptor()->indexFormat);
         mMtlPrimitiveTopology = MTLPrimitiveTopology(GetPrimitiveTopology());
         mMtlFrontFace = MTLFrontFace(GetFrontFace());
         mMtlCullMode = ToMTLCullMode(GetCullMode());
@@ -420,10 +410,6 @@
         [mMtlDepthStencilState release];
     }
 
-    MTLIndexType RenderPipeline::GetMTLIndexType() const {
-        return mMtlIndexType;
-    }
-
     MTLPrimitiveType RenderPipeline::GetMTLPrimitiveTopology() const {
         return mMtlPrimitiveTopology;
     }
diff --git a/src/dawn_native/opengl/CommandBufferGL.cpp b/src/dawn_native/opengl/CommandBufferGL.cpp
index be6b58d..e9e98b1 100644
--- a/src/dawn_native/opengl/CommandBufferGL.cpp
+++ b/src/dawn_native/opengl/CommandBufferGL.cpp
@@ -982,6 +982,7 @@
 
         RenderPipeline* lastPipeline = nullptr;
         uint64_t indexBufferBaseOffset = 0;
+        wgpu::IndexFormat indexBufferFormat;
 
         VertexStateBufferBindingTracker vertexStateBufferBindingTracker;
         BindGroupTracker bindGroupTracker = {};
@@ -1011,14 +1012,19 @@
                     vertexStateBufferBindingTracker.Apply(gl);
                     bindGroupTracker.Apply(gl);
 
-                    wgpu::IndexFormat indexFormat =
-                        lastPipeline->GetVertexStateDescriptor()->indexFormat;
+                    // If a index format was specified in setIndexBuffer always use it.
+                    wgpu::IndexFormat indexFormat = indexBufferFormat;
+                    if (indexFormat == wgpu::IndexFormat::Undefined) {
+                        // Otherwise use the pipeline's index format.
+                        // TODO(crbug.com/dawn/502): This path is deprecated.
+                        indexFormat = lastPipeline->GetVertexStateDescriptor()->indexFormat;
+                    }
                     size_t formatSize = IndexFormatSize(indexFormat);
-                    GLenum formatType = IndexFormatType(indexFormat);
 
                     if (draw->firstInstance > 0) {
                         gl.DrawElementsInstancedBaseVertexBaseInstance(
-                            lastPipeline->GetGLPrimitiveTopology(), draw->indexCount, formatType,
+                            lastPipeline->GetGLPrimitiveTopology(), draw->indexCount,
+                            IndexFormatType(indexFormat),
                             reinterpret_cast<void*>(draw->firstIndex * formatSize +
                                                     indexBufferBaseOffset),
                             draw->instanceCount, draw->baseVertex, draw->firstInstance);
@@ -1027,7 +1033,7 @@
                         if (draw->baseVertex != 0) {
                             gl.DrawElementsInstancedBaseVertex(
                                 lastPipeline->GetGLPrimitiveTopology(), draw->indexCount,
-                                formatType,
+                                IndexFormatType(indexFormat),
                                 reinterpret_cast<void*>(draw->firstIndex * formatSize +
                                                         indexBufferBaseOffset),
                                 draw->instanceCount, draw->baseVertex);
@@ -1035,7 +1041,7 @@
                             // This branch is only needed on OpenGL < 3.2; ES < 3.2
                             gl.DrawElementsInstanced(
                                 lastPipeline->GetGLPrimitiveTopology(), draw->indexCount,
-                                formatType,
+                                IndexFormatType(indexFormat),
                                 reinterpret_cast<void*>(draw->firstIndex * formatSize +
                                                         indexBufferBaseOffset),
                                 draw->instanceCount);
@@ -1064,16 +1070,20 @@
                     vertexStateBufferBindingTracker.Apply(gl);
                     bindGroupTracker.Apply(gl);
 
-                    wgpu::IndexFormat indexFormat =
-                        lastPipeline->GetVertexStateDescriptor()->indexFormat;
-                    GLenum formatType = IndexFormatType(indexFormat);
-
                     uint64_t indirectBufferOffset = draw->indirectOffset;
                     Buffer* indirectBuffer = ToBackend(draw->indirectBuffer.Get());
 
+                    // If a index format was specified in setIndexBuffer always use it.
+                    wgpu::IndexFormat indexFormat = indexBufferFormat;
+                    if (indexFormat == wgpu::IndexFormat::Undefined) {
+                        // Otherwise use the pipeline's index format.
+                        // TODO(crbug.com/dawn/502): This path is deprecated.
+                        indexFormat = lastPipeline->GetVertexStateDescriptor()->indexFormat;
+                    }
+
                     gl.BindBuffer(GL_DRAW_INDIRECT_BUFFER, indirectBuffer->GetHandle());
                     gl.DrawElementsIndirect(
-                        lastPipeline->GetGLPrimitiveTopology(), formatType,
+                        lastPipeline->GetGLPrimitiveTopology(), IndexFormatType(indexFormat),
                         reinterpret_cast<void*>(static_cast<intptr_t>(indirectBufferOffset)));
                     break;
                 }
@@ -1110,6 +1120,9 @@
 
                 case Command::SetIndexBuffer: {
                     SetIndexBufferCmd* cmd = iter->NextCommand<SetIndexBufferCmd>();
+                    // TODO(crbug.com/dawn/502): Once setIndexBuffer is required to specify an
+                    // index buffer format store as an GLenum.
+                    indexBufferFormat = cmd->format;
                     indexBufferBaseOffset = cmd->offset;
                     vertexStateBufferBindingTracker.OnSetIndexBuffer(cmd->buffer.Get());
                     break;
diff --git a/src/dawn_native/vulkan/CommandBufferVk.cpp b/src/dawn_native/vulkan/CommandBufferVk.cpp
index ce86612..1cf5028 100644
--- a/src/dawn_native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn_native/vulkan/CommandBufferVk.cpp
@@ -185,6 +185,41 @@
             }
         };
 
+        class IndexBufferTracker {
+          public:
+            void OnSetIndexBuffer(VkBuffer buffer, wgpu::IndexFormat format, VkDeviceSize offset) {
+                mIndexBuffer = buffer;
+                mOffset = offset;
+                mBufferIndexFormat = format;
+
+                mLastAppliedIndexFormat = wgpu::IndexFormat::Undefined;
+            }
+
+            void OnSetPipeline(RenderPipeline* pipeline) {
+                mPipelineIndexFormat = pipeline->GetVertexStateDescriptor()->indexFormat;
+            }
+
+            void Apply(Device* device, VkCommandBuffer commands) {
+                wgpu::IndexFormat newIndexFormat = mBufferIndexFormat;
+                if (newIndexFormat == wgpu::IndexFormat::Undefined) {
+                    newIndexFormat = mPipelineIndexFormat;
+                }
+
+                if (newIndexFormat != mLastAppliedIndexFormat) {
+                    device->fn.CmdBindIndexBuffer(commands, mIndexBuffer, mOffset,
+                                                  VulkanIndexType(newIndexFormat));
+                    mLastAppliedIndexFormat = newIndexFormat;
+                }
+            }
+
+          private:
+            wgpu::IndexFormat mBufferIndexFormat = wgpu::IndexFormat::Undefined;
+            wgpu::IndexFormat mPipelineIndexFormat = wgpu::IndexFormat::Undefined;
+            wgpu::IndexFormat mLastAppliedIndexFormat = wgpu::IndexFormat::Undefined;
+            VkBuffer mIndexBuffer = VK_NULL_HANDLE;
+            VkDeviceSize mOffset;
+        };
+
         MaybeError RecordBeginRenderPass(CommandRecordingContext* recordingContext,
                                          Device* device,
                                          BeginRenderPassCmd* renderPass) {
@@ -799,6 +834,7 @@
         }
 
         RenderDescriptorSetTracker descriptorSets = {};
+        IndexBufferTracker indexBufferTracker = {};
         RenderPipeline* lastPipeline = nullptr;
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
@@ -816,6 +852,7 @@
                     DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
                     descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    indexBufferTracker.Apply(device, commands);
                     device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
                                               draw->firstIndex, draw->baseVertex,
                                               draw->firstInstance);
@@ -838,6 +875,7 @@
                     VkBuffer indirectBuffer = ToBackend(draw->indirectBuffer)->GetHandle();
 
                     descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                    indexBufferTracker.Apply(device, commands);
                     device->fn.CmdDrawIndexedIndirect(
                         commands, indirectBuffer, static_cast<VkDeviceSize>(draw->indirectOffset),
                         1, 0);
@@ -911,13 +949,8 @@
                     SetIndexBufferCmd* cmd = iter->NextCommand<SetIndexBufferCmd>();
                     VkBuffer indexBuffer = ToBackend(cmd->buffer)->GetHandle();
 
-                    // TODO(cwallez@chromium.org): get the index type from the last render pipeline
-                    // and rebind if needed on pipeline change
-                    ASSERT(lastPipeline != nullptr);
-                    VkIndexType indexType =
-                        VulkanIndexType(lastPipeline->GetVertexStateDescriptor()->indexFormat);
-                    device->fn.CmdBindIndexBuffer(
-                        commands, indexBuffer, static_cast<VkDeviceSize>(cmd->offset), indexType);
+                    indexBufferTracker.OnSetIndexBuffer(indexBuffer, cmd->format,
+                                                        static_cast<VkDeviceSize>(cmd->offset));
                     break;
                 }
 
@@ -930,6 +963,7 @@
                     lastPipeline = pipeline;
 
                     descriptorSets.OnSetPipeline(pipeline);
+                    indexBufferTracker.OnSetPipeline(pipeline);
                     break;
                 }
 
diff --git a/src/tests/end2end/BufferZeroInitTests.cpp b/src/tests/end2end/BufferZeroInitTests.cpp
index 8a15bfe..d18d1cf 100644
--- a/src/tests/end2end/BufferZeroInitTests.cpp
+++ b/src/tests/end2end/BufferZeroInitTests.cpp
@@ -313,7 +313,8 @@
 
         // Bind the buffer with offset == indexBufferOffset and size sizeof(uint32_t) as the index
         // buffer.
-        renderPass.SetIndexBuffer(indexBuffer, indexBufferOffset, sizeof(uint32_t));
+        renderPass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint16,
+                                            indexBufferOffset, sizeof(uint32_t));
         renderPass.DrawIndexed(1);
         renderPass.EndPass();
 
@@ -392,7 +393,7 @@
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
         renderPass.SetPipeline(renderPipeline);
-        renderPass.SetIndexBuffer(indexBuffer);
+        renderPass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint16);
         renderPass.DrawIndexedIndirect(indirectBuffer, indirectBufferOffset);
         renderPass.EndPass();
 
diff --git a/src/tests/end2end/DeprecatedAPITests.cpp b/src/tests/end2end/DeprecatedAPITests.cpp
index 1fd9d1a..b24af76 100644
--- a/src/tests/end2end/DeprecatedAPITests.cpp
+++ b/src/tests/end2end/DeprecatedAPITests.cpp
@@ -82,3 +82,108 @@
 
     wgpu::Extent3D copySize = {1, 1, 1};
 };
+
+constexpr uint32_t kRTSize = 400;
+
+class SetIndexBufferDeprecationTests : public DeprecationTests {
+  protected:
+    void SetUp() override {
+        DeprecationTests::SetUp();
+
+        renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+    }
+
+    utils::BasicRenderPass renderPass;
+
+    wgpu::RenderPipeline MakeTestPipeline(wgpu::IndexFormat format) {
+        wgpu::ShaderModule vsModule =
+            utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, R"(
+                #version 450
+                layout(location = 0) in vec4 pos;
+                void main() {
+                    gl_Position = pos;
+                })");
+
+        wgpu::ShaderModule fsModule =
+            utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"(
+                #version 450
+                layout(location = 0) out vec4 fragColor;
+                void main() {
+                    fragColor = vec4(0.0, 1.0, 0.0, 1.0);
+                })");
+
+        utils::ComboRenderPipelineDescriptor descriptor(device);
+        descriptor.vertexStage.module = vsModule;
+        descriptor.cFragmentStage.module = fsModule;
+        descriptor.primitiveTopology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.cVertexState.indexFormat = format;
+        descriptor.cVertexState.vertexBufferCount = 1;
+        descriptor.cVertexState.cVertexBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cVertexState.cVertexBuffers[0].attributeCount = 1;
+        descriptor.cVertexState.cAttributes[0].format = wgpu::VertexFormat::Float4;
+        descriptor.cColorStates[0].format = renderPass.colorFormat;
+
+        return device.CreateRenderPipeline(&descriptor);
+    }
+};
+
+// Test that the Uint32 index format is correctly interpreted
+TEST_P(SetIndexBufferDeprecationTests, Uint32) {
+    wgpu::RenderPipeline pipeline = MakeTestPipeline(wgpu::IndexFormat::Uint32);
+
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData<float>(
+        device, wgpu::BufferUsage::Vertex,
+        {-1.0f, -1.0f, 0.0f, 1.0f,  // Note Vertices[0] = Vertices[1]
+         -1.0f, -1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 1.0f});
+    // If this is interpreted as Uint16, then it would be 0, 1, 0, ... and would draw nothing.
+    wgpu::Buffer indexBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Index, {1, 2, 3});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetPipeline(pipeline);
+        pass.SetVertexBuffer(0, vertexBuffer);
+        EXPECT_DEPRECATION_WARNING(pass.SetIndexBuffer(indexBuffer));
+        pass.DrawIndexed(3);
+        pass.EndPass();
+    }
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(RGBA8::kGreen, renderPass.color, 100, 300);
+}
+
+// Test that the Uint16 index format is correctly interpreted
+TEST_P(SetIndexBufferDeprecationTests, Uint16) {
+    wgpu::RenderPipeline pipeline = MakeTestPipeline(wgpu::IndexFormat::Uint16);
+
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData<float>(
+        device, wgpu::BufferUsage::Vertex,
+        {-1.0f, -1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 1.0f});
+    // If this is interpreted as uint32, it will have index 1 and 2 be both 0 and render nothing
+    wgpu::Buffer indexBuffer =
+        utils::CreateBufferFromData<uint16_t>(device, wgpu::BufferUsage::Index, {1, 2, 0, 0, 0, 0});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetPipeline(pipeline);
+        pass.SetVertexBuffer(0, vertexBuffer);
+        EXPECT_DEPRECATION_WARNING(pass.SetIndexBuffer(indexBuffer));
+        pass.DrawIndexed(3);
+        pass.EndPass();
+    }
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(RGBA8::kGreen, renderPass.color, 100, 300);
+}
+
+DAWN_INSTANTIATE_TEST(SetIndexBufferDeprecationTests,
+                      D3D12Backend(),
+                      MetalBackend(),
+                      OpenGLBackend(),
+                      VulkanBackend());
diff --git a/src/tests/end2end/DrawIndexedIndirectTests.cpp b/src/tests/end2end/DrawIndexedIndirectTests.cpp
index a8e8c12..fc96092 100644
--- a/src/tests/end2end/DrawIndexedIndirectTests.cpp
+++ b/src/tests/end2end/DrawIndexedIndirectTests.cpp
@@ -88,7 +88,7 @@
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
             pass.SetPipeline(pipeline);
             pass.SetVertexBuffer(0, vertexBuffer);
-            pass.SetIndexBuffer(indexBuffer, indexOffset);
+            pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32, indexOffset);
             pass.DrawIndexedIndirect(indirectBuffer, indirectOffset);
             pass.EndPass();
         }
diff --git a/src/tests/end2end/DrawIndexedTests.cpp b/src/tests/end2end/DrawIndexedTests.cpp
index f980af3..07b2602 100644
--- a/src/tests/end2end/DrawIndexedTests.cpp
+++ b/src/tests/end2end/DrawIndexedTests.cpp
@@ -88,7 +88,7 @@
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
             pass.SetPipeline(pipeline);
             pass.SetVertexBuffer(0, vertexBuffer);
-            pass.SetIndexBuffer(indexBuffer, bufferOffset);
+            pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32, bufferOffset);
             pass.DrawIndexed(indexCount, instanceCount, firstIndex, baseVertex, firstInstance);
             pass.EndPass();
         }
diff --git a/src/tests/end2end/GpuMemorySynchronizationTests.cpp b/src/tests/end2end/GpuMemorySynchronizationTests.cpp
index 4cf64d0..f340e88 100644
--- a/src/tests/end2end/GpuMemorySynchronizationTests.cpp
+++ b/src/tests/end2end/GpuMemorySynchronizationTests.cpp
@@ -552,7 +552,7 @@
     wgpu::RenderPassEncoder pass1 = encoder.BeginRenderPass(&renderPass.renderPassInfo);
     pass1.SetPipeline(rp);
     pass1.SetVertexBuffer(0, vertexBuffer);
-    pass1.SetIndexBuffer(indexBuffer, 0);
+    pass1.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32, 0);
     pass1.SetBindGroup(0, bindGroup1);
     pass1.DrawIndexed(6);
     pass1.EndPass();
@@ -676,7 +676,7 @@
     wgpu::RenderPassEncoder pass1 = encoder.BeginRenderPass(&renderPass.renderPassInfo);
     pass1.SetPipeline(rp);
     pass1.SetVertexBuffer(0, buffer);
-    pass1.SetIndexBuffer(buffer, offsetof(Data, indices));
+    pass1.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, offsetof(Data, indices));
     pass1.SetBindGroup(0, bindGroup1);
     pass1.DrawIndexed(6);
     pass1.EndPass();
diff --git a/src/tests/end2end/IndexFormatTests.cpp b/src/tests/end2end/IndexFormatTests.cpp
index 554c674..cdd9765 100644
--- a/src/tests/end2end/IndexFormatTests.cpp
+++ b/src/tests/end2end/IndexFormatTests.cpp
@@ -30,7 +30,8 @@
 
     utils::BasicRenderPass renderPass;
 
-    wgpu::RenderPipeline MakeTestPipeline(wgpu::IndexFormat format) {
+    wgpu::RenderPipeline MakeTestPipeline(wgpu::IndexFormat format,
+        wgpu::PrimitiveTopology primitiveTopology = wgpu::PrimitiveTopology::TriangleStrip) {
         wgpu::ShaderModule vsModule =
             utils::CreateShaderModule(device, utils::SingleShaderStage::Vertex, R"(
                 #version 450
@@ -50,7 +51,7 @@
         utils::ComboRenderPipelineDescriptor descriptor(device);
         descriptor.vertexStage.module = vsModule;
         descriptor.cFragmentStage.module = fsModule;
-        descriptor.primitiveTopology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.primitiveTopology = primitiveTopology;
         descriptor.cVertexState.indexFormat = format;
         descriptor.cVertexState.vertexBufferCount = 1;
         descriptor.cVertexState.cVertexBuffers[0].arrayStride = 4 * sizeof(float);
@@ -79,7 +80,7 @@
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
         pass.SetPipeline(pipeline);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.DrawIndexed(3);
         pass.EndPass();
     }
@@ -106,7 +107,7 @@
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
         pass.SetPipeline(pipeline);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint16);
         pass.DrawIndexed(3);
         pass.EndPass();
     }
@@ -156,7 +157,7 @@
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
         pass.SetPipeline(pipeline);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.DrawIndexed(7);
         pass.EndPass();
     }
@@ -198,7 +199,7 @@
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
         pass.SetPipeline(pipeline);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint16);
         pass.DrawIndexed(7);
         pass.EndPass();
     }
@@ -215,8 +216,6 @@
 // prevent a case in D3D12 where the index format would be captured from the last
 // pipeline on SetIndexBuffer.
 TEST_P(IndexFormatTest, ChangePipelineAfterSetIndexBuffer) {
-    DAWN_SKIP_TEST_IF(IsD3D12() || IsVulkan());
-
     wgpu::RenderPipeline pipeline32 = MakeTestPipeline(wgpu::IndexFormat::Uint32);
     wgpu::RenderPipeline pipeline16 = MakeTestPipeline(wgpu::IndexFormat::Uint16);
 
@@ -233,7 +232,7 @@
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
         pass.SetPipeline(pipeline16);
         pass.SetVertexBuffer(0, vertexBuffer);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.SetPipeline(pipeline32);
         pass.DrawIndexed(3);
         pass.EndPass();
@@ -250,7 +249,7 @@
 // because it needs to be done lazily (to query the format from the last pipeline).
 // TODO(cwallez@chromium.org): This is currently disallowed by the validation but
 // we want to support eventually.
-TEST_P(IndexFormatTest, DISABLED_SetIndexBufferBeforeSetPipeline) {
+TEST_P(IndexFormatTest, SetIndexBufferBeforeSetPipeline) {
     wgpu::RenderPipeline pipeline = MakeTestPipeline(wgpu::IndexFormat::Uint32);
 
     wgpu::Buffer vertexBuffer = utils::CreateBufferFromData<float>(
@@ -262,7 +261,7 @@
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
     {
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
-        pass.SetIndexBuffer(indexBuffer);
+        pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
         pass.SetPipeline(pipeline);
         pass.SetVertexBuffer(0, vertexBuffer);
         pass.DrawIndexed(3);
@@ -275,6 +274,51 @@
     EXPECT_PIXEL_RGBA8_EQ(RGBA8(0, 255, 0, 255), renderPass.color, 100, 300);
 }
 
+// Test that index buffers of multiple formats can be used with a pipeline that
+// doesn't use strip primitive topology.
+TEST_P(IndexFormatTest, SetIndexBufferDifferentFormats) {
+    wgpu::RenderPipeline pipeline = MakeTestPipeline(wgpu::IndexFormat::Undefined,
+                                                     wgpu::PrimitiveTopology::TriangleList);
+
+    wgpu::Buffer vertexBuffer = utils::CreateBufferFromData<float>(
+        device, wgpu::BufferUsage::Vertex,
+        {-1.0f, -1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 1.0f});
+    wgpu::Buffer indexBuffer32 =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Index, {0, 1, 2});
+    wgpu::Buffer indexBuffer16 =
+        utils::CreateBufferFromData<uint16_t>(device, wgpu::BufferUsage::Index, {0, 1, 2, 0});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetIndexBufferWithFormat(indexBuffer32, wgpu::IndexFormat::Uint32);
+        pass.SetPipeline(pipeline);
+        pass.SetVertexBuffer(0, vertexBuffer);
+        pass.DrawIndexed(3);
+        pass.EndPass();
+    }
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(RGBA8(0, 255, 0, 255), renderPass.color, 100, 300);
+
+    encoder = device.CreateCommandEncoder();
+    {
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetIndexBufferWithFormat(indexBuffer16, wgpu::IndexFormat::Uint16);
+        pass.SetPipeline(pipeline);
+        pass.SetVertexBuffer(0, vertexBuffer);
+        pass.DrawIndexed(3);
+        pass.EndPass();
+    }
+
+    commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(RGBA8(0, 255, 0, 255), renderPass.color, 100, 300);
+}
+
 DAWN_INSTANTIATE_TEST(IndexFormatTest,
                       D3D12Backend(),
                       MetalBackend(),
diff --git a/src/tests/unittests/validation/DrawIndirectValidationTests.cpp b/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
index 4b30651..e54c7d0 100644
--- a/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
+++ b/src/tests/unittests/validation/DrawIndirectValidationTests.cpp
@@ -84,7 +84,7 @@
             uint32_t zeros[100] = {};
             wgpu::Buffer indexBuffer =
                 utils::CreateBufferFromData(device, zeros, sizeof(zeros), wgpu::BufferUsage::Index);
-            pass.SetIndexBuffer(indexBuffer);
+            pass.SetIndexBufferWithFormat(indexBuffer, wgpu::IndexFormat::Uint32);
             pass.DrawIndexedIndirect(indirectBuffer, indirectOffset);
         } else {
             pass.DrawIndirect(indirectBuffer, indirectOffset);
diff --git a/src/tests/unittests/validation/IndexBufferValidationTests.cpp b/src/tests/unittests/validation/IndexBufferValidationTests.cpp
index 32ea153..b2b68b5 100644
--- a/src/tests/unittests/validation/IndexBufferValidationTests.cpp
+++ b/src/tests/unittests/validation/IndexBufferValidationTests.cpp
@@ -31,13 +31,13 @@
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
         // Explicit size
-        pass.SetIndexBuffer(buffer, 0, 256);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 0, 256);
         // Implicit size
-        pass.SetIndexBuffer(buffer, 0, 0);
-        pass.SetIndexBuffer(buffer, 256 - 4, 0);
-        pass.SetIndexBuffer(buffer, 4, 0);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 0, 0);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256 - 4, 0);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 4, 0);
         // Implicit size of zero
-        pass.SetIndexBuffer(buffer, 256, 0);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256, 0);
         pass.EndPass();
         encoder.Finish();
     }
@@ -46,7 +46,7 @@
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
-        pass.SetIndexBuffer(buffer, 4, 256);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 4, 256);
         pass.EndPass();
         ASSERT_DEVICE_ERROR(encoder.Finish());
     }
@@ -55,7 +55,7 @@
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass);
-        pass.SetIndexBuffer(buffer, 256 + 4, 0);
+        pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256 + 4, 0);
         pass.EndPass();
         ASSERT_DEVICE_ERROR(encoder.Finish());
     }
@@ -68,27 +68,27 @@
     {
         wgpu::RenderBundleEncoder encoder = device.CreateRenderBundleEncoder(&renderBundleDesc);
         // Explicit size
-        encoder.SetIndexBuffer(buffer, 0, 256);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 0, 256);
         // Implicit size
-        encoder.SetIndexBuffer(buffer, 0, 0);
-        encoder.SetIndexBuffer(buffer, 256 - 4, 0);
-        encoder.SetIndexBuffer(buffer, 4, 0);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 0, 0);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256 - 4, 0);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 4, 0);
         // Implicit size of zero
-        encoder.SetIndexBuffer(buffer, 256, 0);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256, 0);
         encoder.Finish();
     }
 
     // Bad case, offset + size is larger than the buffer
     {
         wgpu::RenderBundleEncoder encoder = device.CreateRenderBundleEncoder(&renderBundleDesc);
-        encoder.SetIndexBuffer(buffer, 4, 256);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 4, 256);
         ASSERT_DEVICE_ERROR(encoder.Finish());
     }
 
     // Bad case, size is 0 but the offset is larger than the buffer
     {
         wgpu::RenderBundleEncoder encoder = device.CreateRenderBundleEncoder(&renderBundleDesc);
-        encoder.SetIndexBuffer(buffer, 256 + 4, 0);
+        encoder.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32, 256 + 4, 0);
         ASSERT_DEVICE_ERROR(encoder.Finish());
     }
 }
diff --git a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
index d0af33f..cc353c1 100644
--- a/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
+++ b/src/tests/unittests/validation/RenderPipelineValidationTests.cpp
@@ -526,3 +526,60 @@
         })");
     ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
 }
+
+// Tests that strip primitive topologies require an index format
+TEST_F(RenderPipelineValidationTest, StripIndexFormatRequired) {
+    constexpr uint32_t kNumStripType = 2u;
+    constexpr uint32_t kNumListType = 3u;
+    constexpr uint32_t kNumIndexFormat = 3u;
+
+    std::array<wgpu::PrimitiveTopology, kNumStripType> kStripTopologyTypes = {{
+        wgpu::PrimitiveTopology::LineStrip,
+        wgpu::PrimitiveTopology::TriangleStrip
+    }};
+
+    std::array<wgpu::PrimitiveTopology, kNumListType> kListTopologyTypes = {{
+        wgpu::PrimitiveTopology::PointList,
+        wgpu::PrimitiveTopology::LineList,
+        wgpu::PrimitiveTopology::TriangleList
+    }};
+
+    std::array<wgpu::IndexFormat, kNumIndexFormat> kIndexFormatTypes = {{
+        wgpu::IndexFormat::Undefined,
+        wgpu::IndexFormat::Uint16,
+        wgpu::IndexFormat::Uint32
+    }};
+
+    for (wgpu::PrimitiveTopology primitiveTopology : kStripTopologyTypes) {
+        for (wgpu::IndexFormat indexFormat : kIndexFormatTypes) {
+            utils::ComboRenderPipelineDescriptor descriptor(device);
+            descriptor.vertexStage.module = vsModule;
+            descriptor.cFragmentStage.module = fsModule;
+            descriptor.primitiveTopology = primitiveTopology;
+            descriptor.cVertexState.indexFormat = indexFormat;
+
+            if (indexFormat == wgpu::IndexFormat::Undefined) {
+                // Fail because the index format is undefined and the primitive
+                // topology is a strip type.
+                ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor));
+            } else {
+                // Succeeds because the index format is given.
+                device.CreateRenderPipeline(&descriptor);
+            }
+        }
+    }
+
+    for (wgpu::PrimitiveTopology primitiveTopology : kListTopologyTypes) {
+        for (wgpu::IndexFormat indexFormat : kIndexFormatTypes) {
+            utils::ComboRenderPipelineDescriptor descriptor(device);
+            descriptor.vertexStage.module = vsModule;
+            descriptor.cFragmentStage.module = fsModule;
+            descriptor.primitiveTopology = primitiveTopology;
+            descriptor.cVertexState.indexFormat = indexFormat;
+
+            // Succeeds even when the index format is undefined because the
+            // primitive topology isn't a strip type.
+            device.CreateRenderPipeline(&descriptor);
+        }
+    }
+}
diff --git a/src/tests/unittests/validation/ResourceUsageTrackingTests.cpp b/src/tests/unittests/validation/ResourceUsageTrackingTests.cpp
index 846a874..2bab964 100644
--- a/src/tests/unittests/validation/ResourceUsageTrackingTests.cpp
+++ b/src/tests/unittests/validation/ResourceUsageTrackingTests.cpp
@@ -91,7 +91,7 @@
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             DummyRenderPass dummyRenderPass(device);
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-            pass.SetIndexBuffer(buffer);
+            pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32);
             pass.SetVertexBuffer(0, buffer);
             pass.EndPass();
             encoder.Finish();
@@ -135,7 +135,7 @@
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             DummyRenderPass dummyRenderPass(device);
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-            pass.SetIndexBuffer(buffer);
+            pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32);
             pass.SetBindGroup(0, bg);
             pass.EndPass();
             ASSERT_DEVICE_ERROR(encoder.Finish());
@@ -258,12 +258,12 @@
             DummyRenderPass dummyRenderPass(device);
 
             wgpu::RenderPassEncoder pass0 = encoder.BeginRenderPass(&dummyRenderPass);
-            pass0.SetIndexBuffer(buffer0);
+            pass0.SetIndexBufferWithFormat(buffer0, wgpu::IndexFormat::Uint32);
             pass0.SetBindGroup(0, bg1);
             pass0.EndPass();
 
             wgpu::RenderPassEncoder pass1 = encoder.BeginRenderPass(&dummyRenderPass);
-            pass1.SetIndexBuffer(buffer1);
+            pass1.SetIndexBufferWithFormat(buffer1, wgpu::IndexFormat::Uint32);
             pass1.SetBindGroup(0, bg0);
             pass1.EndPass();
 
@@ -349,7 +349,7 @@
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
             pass.SetPipeline(rp);
 
-            pass.SetIndexBuffer(buffer);
+            pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32);
             pass.Draw(3);
 
             pass.SetBindGroup(0, bg);
@@ -414,7 +414,7 @@
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
             pass.SetPipeline(rp);
 
-            pass.SetIndexBuffer(buffer);
+            pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32);
             pass.SetBindGroup(0, writeBG);
             pass.Draw(3);
 
@@ -514,8 +514,8 @@
         {
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-            pass.SetIndexBuffer(buffer0);
-            pass.SetIndexBuffer(buffer1);
+            pass.SetIndexBufferWithFormat(buffer0, wgpu::IndexFormat::Uint32);
+            pass.SetIndexBufferWithFormat(buffer1, wgpu::IndexFormat::Uint32);
             pass.SetBindGroup(0, bg);
             pass.EndPass();
             ASSERT_DEVICE_ERROR(encoder.Finish());
@@ -526,8 +526,8 @@
         {
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-            pass.SetIndexBuffer(buffer1);
-            pass.SetIndexBuffer(buffer0);
+            pass.SetIndexBufferWithFormat(buffer1, wgpu::IndexFormat::Uint32);
+            pass.SetIndexBufferWithFormat(buffer0, wgpu::IndexFormat::Uint32);
             pass.SetBindGroup(0, bg);
             pass.EndPass();
             ASSERT_DEVICE_ERROR(encoder.Finish());
@@ -584,7 +584,7 @@
             {
                 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
                 wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-                pass.SetIndexBuffer(buffer0);
+                pass.SetIndexBufferWithFormat(buffer0, wgpu::IndexFormat::Uint32);
                 pass.SetBindGroup(0, bg0);
                 pass.SetBindGroup(0, bg1);
                 pass.EndPass();
@@ -596,7 +596,7 @@
             {
                 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
                 wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-                pass.SetIndexBuffer(buffer0);
+                pass.SetIndexBufferWithFormat(buffer0, wgpu::IndexFormat::Uint32);
                 pass.SetBindGroup(0, bg1);
                 pass.SetBindGroup(0, bg0);
                 pass.EndPass();
@@ -724,7 +724,7 @@
             wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
             DummyRenderPass dummyRenderPass(device);
             wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
-            pass.SetIndexBuffer(buffer);
+            pass.SetIndexBufferWithFormat(buffer, wgpu::IndexFormat::Uint32);
             pass.SetBindGroup(0, bg);
             pass.EndPass();
             ASSERT_DEVICE_ERROR(encoder.Finish());