Add MultiDrawIndirect Vulkan backend + Validation + end2end_tests

This change adds the Vulkan backend for MultiDrawIndirect, along with some preparatory work for the Metal backend.
Validation is implemented and works with Vulkan; it will need some more work once D3D12 is implemented.
The end2end tests are derived from the DrawIndirect tests, so they are very similar.

Change-Id: I93667798537a529e963ebcb3d8b3d269039501ab
Bug: 356461286
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/201354
Commit-Queue: Srijan Dhungana <srijan.dhungana6@gmail.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Auto-Submit: Srijan Dhungana <srijan.dhungana6@gmail.com>
diff --git a/src/dawn/common/Math.h b/src/dawn/common/Math.h
index 7ef28b5..27909fb 100644
--- a/src/dawn/common/Math.h
+++ b/src/dawn/common/Math.h
@@ -78,6 +78,14 @@
     return (value + (alignmentT - 1)) & ~(alignmentT - 1);
 }
 
+template <typename T>
+T AlignDown(T value, size_t alignment) {
+    DAWN_ASSERT(IsPowerOfTwo(alignment));
+    DAWN_ASSERT(alignment != 0);
+    T alignmentT = static_cast<T>(alignment);
+    return value & ~(alignmentT - 1);
+}
+
 template <typename T, size_t Alignment>
 constexpr size_t AlignSizeof() {
     static_assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0,
diff --git a/src/dawn/native/CommandBufferStateTracker.cpp b/src/dawn/native/CommandBufferStateTracker.cpp
index 07d6d7a..04500a9 100644
--- a/src/dawn/native/CommandBufferStateTracker.cpp
+++ b/src/dawn/native/CommandBufferStateTracker.cpp
@@ -555,7 +555,7 @@
         }
     }
 
-    if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
+    if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && IndexBufferSet()) {
         RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
         if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) ||
             mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) {
@@ -572,7 +572,7 @@
     DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_PIPELINE], "No pipeline set.");
 
     if (aspects[VALIDATION_ASPECT_INDEX_BUFFER]) {
-        DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set.");
+        DAWN_INVALID_IF(!IndexBufferSet(), "Index buffer was not set.");
 
         RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
         wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat();
@@ -754,10 +754,11 @@
     mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
 }
 
-void CommandBufferStateTracker::SetIndexBuffer(wgpu::IndexFormat format,
+void CommandBufferStateTracker::SetIndexBuffer(BufferBase* buffer,
+                                               wgpu::IndexFormat format,
                                                uint64_t offset,
                                                uint64_t size) {
-    mIndexBufferSet = true;
+    mIndexBuffer = buffer;
     mIndexFormat = format;
     mIndexBufferSize = size;
     mIndexBufferOffset = offset;
@@ -798,6 +799,10 @@
     return mLastPipeline != nullptr;
 }
 
+bool CommandBufferStateTracker::IndexBufferSet() const {
+    return mIndexBuffer != nullptr;
+}
+
 RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const {
     DAWN_ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline);
     return static_cast<RenderPipelineBase*>(mLastPipeline);
@@ -812,6 +817,10 @@
     return mLastPipelineLayout;
 }
 
+BufferBase* CommandBufferStateTracker::GetIndexBuffer() const {
+    return mIndexBuffer;
+}
+
 wgpu::IndexFormat CommandBufferStateTracker::GetIndexFormat() const {
     return mIndexFormat;
 }
diff --git a/src/dawn/native/CommandBufferStateTracker.h b/src/dawn/native/CommandBufferStateTracker.h
index 8476184..66d07cc 100644
--- a/src/dawn/native/CommandBufferStateTracker.h
+++ b/src/dawn/native/CommandBufferStateTracker.h
@@ -68,7 +68,10 @@
                       BindGroupBase* bindgroup,
                       uint32_t dynamicOffsetCount,
                       const uint32_t* dynamicOffsets);
-    void SetIndexBuffer(wgpu::IndexFormat format, uint64_t offset, uint64_t size);
+    void SetIndexBuffer(BufferBase* buffer,
+                        wgpu::IndexFormat format,
+                        uint64_t offset,
+                        uint64_t size);
     void UnsetVertexBuffer(VertexBufferSlot slot);
     void SetVertexBuffer(VertexBufferSlot slot, uint64_t size);
     void End();
@@ -79,9 +82,11 @@
     BindGroupBase* GetBindGroup(BindGroupIndex index) const;
     const std::vector<uint32_t>& GetDynamicOffsets(BindGroupIndex index) const;
     bool HasPipeline() const;
+    bool IndexBufferSet() const;
     RenderPipelineBase* GetRenderPipeline() const;
     ComputePipelineBase* GetComputePipeline() const;
     PipelineLayoutBase* GetPipelineLayout() const;
+    BufferBase* GetIndexBuffer() const;
     wgpu::IndexFormat GetIndexFormat() const;
     uint64_t GetIndexBufferSize() const;
     uint64_t GetIndexBufferOffset() const;
@@ -98,10 +103,10 @@
     VertexBufferMask mVertexBuffersUsed;
     PerVertexBuffer<uint64_t> mVertexBufferSizes = {};
 
-    bool mIndexBufferSet = false;
     wgpu::IndexFormat mIndexFormat;
     uint64_t mIndexBufferSize = 0;
     uint64_t mIndexBufferOffset = 0;
+    RAW_PTR_EXCLUSION BufferBase* mIndexBuffer = nullptr;
 
     // RAW_PTR_EXCLUSION: These pointers are very hot in command recording code and point at
     // various objects referenced by the object graph of the CommandBuffer so they cannot be
diff --git a/src/dawn/native/IndirectDrawMetadata.cpp b/src/dawn/native/IndirectDrawMetadata.cpp
index 1ef6980..04f652d 100644
--- a/src/dawn/native/IndirectDrawMetadata.cpp
+++ b/src/dawn/native/IndirectDrawMetadata.cpp
@@ -154,6 +154,11 @@
     return &mIndexedIndirectBufferValidationInfo;
 }
 
+const std::vector<IndirectDrawMetadata::IndirectMultiDraw>&
+IndirectDrawMetadata::GetIndirectMultiDraws() const {
+    return mMultiDraws;
+}
+
 void IndirectDrawMetadata::AddBundle(RenderBundleBase* bundle) {
     auto [_, inserted] = mAddedBundles.insert(bundle);
     if (!inserted) {
@@ -237,6 +242,27 @@
     mIndexedIndirectBufferValidationInfo.clear();
 }
 
+void IndirectDrawMetadata::AddMultiDrawIndirect(MultiDrawIndirectCmd* cmd) {
+    IndirectMultiDraw multiDraw;
+    multiDraw.type = DrawType::NonIndexed;
+    multiDraw.cmd = cmd;
+    mMultiDraws.push_back(multiDraw);
+}
+
+void IndirectDrawMetadata::AddMultiDrawIndexedIndirect(BufferBase* indexBuffer,
+                                                       wgpu::IndexFormat indexFormat,
+                                                       uint64_t indexBufferSize,
+                                                       uint64_t indexBufferOffset,
+                                                       MultiDrawIndexedIndirectCmd* cmd) {
+    IndirectMultiDraw multiDraw;
+    multiDraw.type = DrawType::Indexed;
+    multiDraw.cmd = cmd;
+    multiDraw.indexBufferSize = indexBufferSize;
+    multiDraw.indexFormat = indexFormat;
+
+    mMultiDraws.push_back(multiDraw);
+}
+
 bool IndirectDrawMetadata::IndexedIndirectConfig::operator<(
     const IndexedIndirectConfig& other) const {
     return std::tie(inputIndirectBufferPtr, duplicateBaseVertexInstance, drawType) <
diff --git a/src/dawn/native/IndirectDrawMetadata.h b/src/dawn/native/IndirectDrawMetadata.h
index d86cb9b..657dd08 100644
--- a/src/dawn/native/IndirectDrawMetadata.h
+++ b/src/dawn/native/IndirectDrawMetadata.h
@@ -56,13 +56,20 @@
 // commands.
 class IndirectDrawMetadata : public NonCopyable {
   public:
+    enum class DrawType : uint8_t {
+        NonIndexed,
+        Indexed,
+    };
+
     struct IndirectDraw {
         uint64_t inputBufferOffset;
         uint64_t numIndexBufferElements;
         uint64_t indexBufferOffsetInElements;
-        // This is a pointer to the command that should be populated with the validated
-        // indirect scratch buffer. It is only valid up until the encoded command buffer
-        // is submitted.
+        // When validation is enabled, the original indirect buffer is validated and copied to a new
+        // indirect buffer containing only valid commands. The pointer to the command allocated in
+        // the command allocator is used to swap the indirect buffer for the validated one before
+        // the backend processes the command. Valid until the backend has processed the
+        // commands.
         raw_ptr<DrawIndirectCmd> cmd;
     };
 
@@ -72,6 +79,20 @@
         std::vector<IndirectDraw> draws;
     };
 
+    struct IndirectMultiDraw {
+        DrawType type;
+
+        uint64_t indexBufferSize;
+        wgpu::IndexFormat indexFormat;
+
+        // When validation is enabled, the original indirect buffer is validated and copied to a new
+        // indirect buffer containing only valid commands. The pointer to the command allocated in
+        // the command allocator is used to swap the indirect buffer for the validated one before
+        // the backend processes the command. Valid until the backend has processed the
+        // commands.
+        raw_ptr<MultiDrawIndirectCmd> cmd;
+    };
+
     // Tracks information about every draw call in this render pass which uses the same indirect
     // buffer and the same-sized index buffer. Calls are grouped by indirect offset ranges so
     // that validation work can be chunked efficiently if necessary.
@@ -111,10 +132,6 @@
         std::vector<IndirectValidationBatch> mBatches;
     };
 
-    enum class DrawType {
-        NonIndexed,
-        Indexed,
-    };
     struct IndexedIndirectConfig {
         uintptr_t inputIndirectBufferPtr;
         bool duplicateBaseVertexInstance;
@@ -149,12 +166,24 @@
                          bool duplicateBaseVertexInstance,
                          DrawIndirectCmd* cmd);
 
+    void AddMultiDrawIndirect(MultiDrawIndirectCmd* cmd);
+
+    void AddMultiDrawIndexedIndirect(BufferBase* indexBuffer,
+                                     wgpu::IndexFormat indexFormat,
+                                     uint64_t indexBufferSize,
+                                     uint64_t indexBufferOffset,
+                                     MultiDrawIndexedIndirectCmd* cmd);
+
     void ClearIndexedIndirectBufferValidationInfo();
 
+    const std::vector<IndirectMultiDraw>& GetIndirectMultiDraws() const;
+
   private:
     IndexedIndirectBufferValidationInfoMap mIndexedIndirectBufferValidationInfo;
     absl::flat_hash_set<RenderBundleBase*> mAddedBundles;
 
+    std::vector<IndirectMultiDraw> mMultiDraws;
+
     uint64_t mMaxBatchOffsetRange;
     uint32_t mMaxDrawCallsPerBatch;
 };
diff --git a/src/dawn/native/IndirectDrawValidationEncoder.cpp b/src/dawn/native/IndirectDrawValidationEncoder.cpp
index c37535b..182fbc6 100644
--- a/src/dawn/native/IndirectDrawValidationEncoder.cpp
+++ b/src/dawn/native/IndirectDrawValidationEncoder.cpp
@@ -45,6 +45,7 @@
 #include "dawn/native/Device.h"
 #include "dawn/native/InternalPipelineStore.h"
 #include "dawn/native/Queue.h"
+#include "dawn/native/RenderPipeline.h"
 #include "dawn/native/utils/WGPUHelpers.h"
 #include "partition_alloc/pointers/raw_ptr.h"
 
@@ -54,12 +55,13 @@
 // NOTE: This must match the workgroup_size attribute on the compute entry point below.
 constexpr uint64_t kWorkgroupSize = 64;
 
-// Bitmasks for BatchInfo::flags
+// Bitmasks for BatchInfo::flags and MultiDrawConstants::flags
 constexpr uint32_t kDuplicateBaseVertexInstance = 1;
 constexpr uint32_t kIndexedDraw = 2;
 constexpr uint32_t kValidationEnabled = 4;
 constexpr uint32_t kIndirectFirstInstanceEnabled = 8;
 constexpr uint32_t kUseFirstIndexToEmulateIndexBufferOffset = 16;
+constexpr uint32_t kIndirectDrawCountBuffer = 32;
 
 // Equivalent to the IndirectDraw struct defined in the shader below.
 struct IndirectDraw {
@@ -77,6 +79,16 @@
     uint32_t flags;
 };
 
+// Equivalent to MultiDrawConstants struct defined in the shader below.
+struct MultiDrawConstants {
+    uint32_t maxDrawCount;
+    uint32_t indirectOffsetInElements;
+    uint32_t drawCountOffsetInElements;
+    uint32_t numIndexBufferElementsLow;
+    uint32_t numIndexBufferElementsHigh;
+    uint32_t flags;
+};
+
 // The size, in bytes, of the IndirectDraw struct defined in the shader below.
 constexpr uint32_t kIndirectDrawByteSize = sizeof(uint32_t) * 4;
 
@@ -84,17 +96,30 @@
 // various failure modes.
 static const char sRenderValidationShaderSource[] = R"(
 
+            const kWorkgroupSize = 64u;
+
             const kNumDrawIndirectParams = 4u;
+            const kNumDrawIndexedIndirectParams = 5u;
 
             const kIndexCountEntry = 0u;
             const kFirstIndexEntry = 2u;
 
-            // Bitmasks for BatchInfo::flags
+            // Bitmasks for BatchInfo::flags and MultiDrawConstants::flags
             const kDuplicateBaseVertexInstance = 1u;
             const kIndexedDraw = 2u;
             const kValidationEnabled = 4u;
             const kIndirectFirstInstanceEnabled = 8u;
             const kUseFirstIndexToEmulateIndexBufferOffset = 16u;
+            const kIndirectDrawCountBuffer = 32u; // if set, drawCount is read from a buffer
+
+            struct MultiDrawConstants {
+                maxDrawCount: u32,
+                indirectOffsetInElements: u32,
+                drawCountOffsetInElements: u32,
+                numIndexBufferElementsLow: u32,
+                numIndexBufferElementsHigh: u32,
+                flags : u32,
+            }
 
             struct IndirectDraw {
                 indirectOffset: u32,
@@ -113,39 +138,44 @@
                 data: array<u32>,
             }
 
+            // We have two entry points, which use different descriptors at binding 0.
+            // Even though they are overlapping, we only use one for each entry point.
             @group(0) @binding(0) var<storage, read> batch: BatchInfo;
+            @group(0) @binding(0) var<storage, read> drawConstants: MultiDrawConstants;
             @group(0) @binding(1) var<storage, read_write> inputParams: IndirectParams;
             @group(0) @binding(2) var<storage, read_write> outputParams: IndirectParams;
+            // Although the drawCountBuffer only has a u32 value, it is stored in a buffer
+            // to allow for offsetting the buffer in the shader.
+            @group(0) @binding(3) var<storage, read_write> indirectDrawCount : IndirectParams;
 
-            fn numIndirectParamsPerDrawCallInput() -> u32 {
-                var numParams = kNumDrawIndirectParams;
+            fn numIndirectParamsPerDrawCallInput(flags : u32) -> u32 {
                 // Indexed Draw has an extra parameter (firstIndex)
-                if (bool(batch.flags & kIndexedDraw)) {
-                    numParams = numParams + 1u;
+                if (bool(flags & kIndexedDraw)) {
+                    return kNumDrawIndexedIndirectParams;
                 }
-                return numParams;
+                return kNumDrawIndirectParams;
             }
 
-            fn numIndirectParamsPerDrawCallOutput() -> u32 {
-                var numParams = numIndirectParamsPerDrawCallInput();
+            fn numIndirectParamsPerDrawCallOutput(flags : u32) -> u32 {
+                var numParams = numIndirectParamsPerDrawCallInput(flags);
                 // 2 extra parameter for duplicated first/baseVertex and firstInstance
-                if (bool(batch.flags & kDuplicateBaseVertexInstance)) {
+                if (bool(flags & kDuplicateBaseVertexInstance)) {
                     numParams = numParams + 2u;
                 }
                 return numParams;
             }
 
-            fn fail(drawIndex: u32) {
-                let numParams = numIndirectParamsPerDrawCallOutput();
+            fn fail(drawIndex: u32, flags : u32) {
+                let numParams = numIndirectParamsPerDrawCallOutput(flags);
                 let index = drawIndex * numParams;
                 for(var i = 0u; i < numParams; i = i + 1u) {
                     outputParams.data[index + i] = 0u;
                 }
             }
 
-            fn set_pass(drawIndex: u32) {
-                let numInputParams = numIndirectParamsPerDrawCallInput();
-                var outIndex = drawIndex * numIndirectParamsPerDrawCallOutput();
+            fn set_pass_single(drawIndex: u32) {
+                let numInputParams = numIndirectParamsPerDrawCallInput(batch.flags);
+                var outIndex = drawIndex * numIndirectParamsPerDrawCallOutput(batch.flags);
                 let inIndex = batch.draws[drawIndex].indirectOffset;
 
                 // The first 2 parameter is reserved for the duplicated first/baseVertex and firstInstance
@@ -168,29 +198,40 @@
                 }
             }
 
-            @compute @workgroup_size(64, 1, 1)
-            fn main(@builtin(global_invocation_id) id : vec3u) {
+            fn set_pass_multi(drawIndex: u32) {
+                let numInputParams = numIndirectParamsPerDrawCallInput(drawConstants.flags);
+                var outIndex = drawIndex * numIndirectParamsPerDrawCallOutput(drawConstants.flags);
+                let inIndex = drawIndex * numInputParams;
+                let inputOffset = drawConstants.indirectOffsetInElements;
+
+                for(var i = 0u; i < numInputParams; i = i + 1u) {
+                    outputParams.data[outIndex + i] = inputParams.data[inputOffset + inIndex + i];
+                }
+            }
+
+            @compute @workgroup_size(kWorkgroupSize, 1, 1)
+            fn validate_single_draw(@builtin(global_invocation_id) id : vec3u) {
                 if (id.x >= batch.numDraws) {
                     return;
                 }
 
                 if(!bool(batch.flags & kValidationEnabled)) {
-                    set_pass(id.x);
+                    set_pass_single(id.x);
                     return;
                 }
 
                 let inputIndex = batch.draws[id.x].indirectOffset;
                 if(!bool(batch.flags & kIndirectFirstInstanceEnabled)) {
                     // firstInstance is always the last parameter
-                    let firstInstance = inputParams.data[inputIndex + numIndirectParamsPerDrawCallInput() - 1u];
+                    let firstInstance = inputParams.data[inputIndex + numIndirectParamsPerDrawCallInput(batch.flags) - 1u];
                     if (firstInstance != 0u) {
-                        fail(id.x);
+                        fail(id.x, batch.flags);
                         return;
                     }
                 }
 
                 if (!bool(batch.flags & kIndexedDraw)) {
-                    set_pass(id.x);
+                    set_pass_single(id.x);
                     return;
                 }
 
@@ -199,7 +240,7 @@
                 if (numIndexBufferElementsHigh >= 2u) {
                     // firstIndex and indexCount are both u32. The maximum possible sum of these
                     // values is 0x1fffffffe, which is less than 0x200000000. Nothing to validate.
-                    set_pass(id.x);
+                    set_pass_single(id.x);
                     return;
                 }
 
@@ -208,7 +249,7 @@
                 let firstIndex = inputParams.data[inputIndex + kFirstIndexEntry];
                 if (numIndexBufferElementsHigh == 0u &&
                     numIndexBufferElementsLow < firstIndex) {
-                    fail(id.x);
+                    fail(id.x, batch.flags);
                     return;
                 }
 
@@ -217,65 +258,141 @@
                 let maxIndexCount = numIndexBufferElementsLow - firstIndex;
                 let indexCount = inputParams.data[inputIndex + kIndexCountEntry];
                 if (indexCount > maxIndexCount) {
-                    fail(id.x);
+                    fail(id.x, batch.flags);
                     return;
                 }
-                set_pass(id.x);
+                set_pass_single(id.x);
             }
+
+           @compute @workgroup_size(kWorkgroupSize, 1, 1)
+            fn validate_multi_draw(@builtin(global_invocation_id) id : vec3u) {
+
+                var drawCount = drawConstants.maxDrawCount;
+
+                var drawCountOffset = drawConstants.drawCountOffsetInElements;
+
+                if(bool(drawConstants.flags & kIndirectDrawCountBuffer)) {
+                    let drawCountInBuffer = indirectDrawCount.data[drawCountOffset];
+                    drawCount = min(drawCountInBuffer, drawCount);
+                }
+
+                if (id.x >= drawCount) {
+                    return;
+                }
+
+                let numIndexBufferElementsHigh = drawConstants.numIndexBufferElementsHigh;
+
+                if (numIndexBufferElementsHigh >= 2u) {
+                    // firstIndex and indexCount are both u32. The maximum possible sum of these
+                    // values is 0x1fffffffe, which is less than 0x200000000. Nothing to validate.
+                    set_pass_multi(id.x);
+                    return;
+                }
+
+                let numIndexBufferElementsLow = drawConstants.numIndexBufferElementsLow;
+                let inputOffset = drawConstants.indirectOffsetInElements;
+                let firstIndex = inputParams.data[inputOffset + id.x * numIndirectParamsPerDrawCallInput(drawConstants.flags) + kFirstIndexEntry];
+                if (numIndexBufferElementsHigh == 0u &&
+                    numIndexBufferElementsLow < firstIndex) {
+                    fail(id.x, drawConstants.flags);
+                    return;
+                }
+
+                // Note that this subtraction may underflow, but only when
+                // numIndexBufferElementsHigh is 1u. The result is still correct in that case.
+                let maxIndexCount = numIndexBufferElementsLow - firstIndex;
+                let indexCount = inputParams.data[inputOffset + id.x * numIndirectParamsPerDrawCallInput(drawConstants.flags) + kIndexCountEntry];
+                if (indexCount > maxIndexCount) {
+                    fail(id.x, drawConstants.flags);
+                    return;
+                }
+                set_pass_multi(id.x);
+
+            }
+
+
         )";
 
-ResultOrError<ComputePipelineBase*> GetOrCreateRenderValidationPipeline(DeviceBase* device) {
+ResultOrError<dawn::Ref<ComputePipelineBase>> CreateRenderValidationPipelines(
+    DeviceBase* device,
+    const char* entryPoint,
+    std::initializer_list<dawn::native::utils::BindingLayoutEntryInitializationHelper> entries) {
     InternalPipelineStore* store = device->GetInternalPipelineStore();
 
-    if (store->renderValidationPipeline == nullptr) {
-        // If we need to apply the index buffer offset to the first index then
-        // we can't handle buffers larger than 4gig otherwise we'll overflow first_index
-        // which is a 32bit value.
-        //
-        // When a buffer is less than 4gig the largest index buffer offset you can pass to
-        // SetIndexBuffer is 0xffff_fffe. Otherwise you'll get a validation error. This
-        // is converted to count of indices and so at most 0x7fff_ffff.
-        //
-        // The largest valid first_index would be 0x7fff_ffff. Anything larger will fail
-        // the validation used in this compute shader and the validated indirect buffer
-        // will have 0,0,0,0,0.
-        //
-        // Adding 0x7fff_ffff + 0x7fff_ffff does not overflow so as long as we keep
-        // maxBufferSize < 4gig we're safe.
-        DAWN_ASSERT(!device->ShouldApplyIndexBufferOffsetToFirstIndex() ||
-                    device->GetLimits().v1.maxBufferSize < 0x100000000u);
+    // If we need to apply the index buffer offset to the first index then
+    // we can't handle buffers larger than 4gig otherwise we'll overflow first_index
+    // which is a 32bit value.
+    //
+    // When a buffer is less than 4gig the largest index buffer offset you can pass to
+    // SetIndexBuffer is 0xffff_fffe. Otherwise you'll get a validation error. This
+    // is converted to count of indices and so at most 0x7fff_ffff.
+    //
+    // The largest valid first_index would be 0x7fff_ffff. Anything larger will fail
+    // the validation used in this compute shader and the validated indirect buffer
+    // will have 0,0,0,0,0.
+    //
+    // Adding 0x7fff_ffff + 0x7fff_ffff does not overflow so as long as we keep
+    // maxBufferSize < 4gig we're safe.
+    DAWN_ASSERT(!device->ShouldApplyIndexBufferOffsetToFirstIndex() ||
+                device->GetLimits().v1.maxBufferSize < 0x1'0000'0000u);
 
-        // Create compute shader module if not cached before.
-        if (store->renderValidationShader == nullptr) {
-            DAWN_TRY_ASSIGN(store->renderValidationShader,
-                            utils::CreateShaderModule(device, sRenderValidationShaderSource));
-        }
+    // Create compute shader module if not cached before.
+    if (store->indirectDrawValidationShader == nullptr) {
+        DAWN_TRY_ASSIGN(store->indirectDrawValidationShader,
+                        utils::CreateShaderModule(device, sRenderValidationShaderSource));
+    }
 
-        Ref<BindGroupLayoutBase> bindGroupLayout;
+    Ref<BindGroupLayoutBase> bindGroupLayout;
+    DAWN_TRY_ASSIGN(bindGroupLayout, utils::MakeBindGroupLayout(device, entries,
+                                                                /* allowInternalBinding */ true));
+
+    Ref<PipelineLayoutBase> pipelineLayout;
+    DAWN_TRY_ASSIGN(pipelineLayout, utils::MakeBasicPipelineLayout(device, bindGroupLayout));
+
+    ComputePipelineDescriptor computePipelineDescriptor = {};
+    computePipelineDescriptor.layout = pipelineLayout.Get();
+    computePipelineDescriptor.compute.module = store->indirectDrawValidationShader.Get();
+    computePipelineDescriptor.compute.entryPoint = entryPoint;
+
+    dawn::Ref<ComputePipelineBase> pipeline;
+    DAWN_TRY_ASSIGN(pipeline, device->CreateComputePipeline(&computePipelineDescriptor));
+
+    return pipeline;
+}
+
+ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDrawValidationPipeline(DeviceBase* device) {
+    InternalPipelineStore* store = device->GetInternalPipelineStore();
+
+    if (store->indirectDrawValidationPipeline == nullptr) {
         DAWN_TRY_ASSIGN(
-            bindGroupLayout,
-            utils::MakeBindGroupLayout(
-                device,
+            store->indirectDrawValidationPipeline,
+            CreateRenderValidationPipelines(
+                device, "validate_single_draw",
                 {
                     {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::ReadOnlyStorage},
                     {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                     {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
-                },
-                /* allowInternalBinding */ true));
-
-        Ref<PipelineLayoutBase> pipelineLayout;
-        DAWN_TRY_ASSIGN(pipelineLayout, utils::MakeBasicPipelineLayout(device, bindGroupLayout));
-
-        ComputePipelineDescriptor computePipelineDescriptor = {};
-        computePipelineDescriptor.layout = pipelineLayout.Get();
-        computePipelineDescriptor.compute.module = store->renderValidationShader.Get();
-        computePipelineDescriptor.compute.entryPoint = "main";
-
-        DAWN_TRY_ASSIGN(store->renderValidationPipeline,
-                        device->CreateComputePipeline(&computePipelineDescriptor));
+                }));
     }
+    return store->indirectDrawValidationPipeline.Get();
+}
 
-    return store->renderValidationPipeline.Get();
+ResultOrError<ComputePipelineBase*> GetOrCreateMultiDrawValidationPipeline(DeviceBase* device) {
+    InternalPipelineStore* store = device->GetInternalPipelineStore();
+
+    if (store->multiDrawValidationPipeline == nullptr) {
+        DAWN_TRY_ASSIGN(
+            store->multiDrawValidationPipeline,
+            CreateRenderValidationPipelines(
+                device, "validate_multi_draw",
+                {
+                    {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::ReadOnlyStorage},
+                    {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
+                    {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
+                    {3, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
+                }));
+    }
+    return store->multiDrawValidationPipeline.Get();
 }
 
 size_t GetBatchDataSize(uint32_t numDraws) {
@@ -335,7 +452,12 @@
     std::vector<Pass> passes;
     IndirectDrawMetadata::IndexedIndirectBufferValidationInfoMap& bufferInfoMap =
         *indirectDrawMetadata->GetIndexedIndirectBufferValidationInfo();
-    if (bufferInfoMap.empty()) {
+
+    const std::vector<IndirectDrawMetadata::IndirectMultiDraw>& multiDraws =
+        indirectDrawMetadata->GetIndirectMultiDraws();
+
+    // Nothing to validate.
+    if (bufferInfoMap.empty() && multiDraws.empty()) {
         return {};
     }
 
@@ -422,6 +544,34 @@
         }
     }
 
+    // Multi draw output params are stored after the single draw output params, so we need to
+    // track the offset of the multi draw output params.
+    outputParamsSize = Align(outputParamsSize, minStorageBufferOffsetAlignment);
+    const uint64_t multiDrawOutputParamsOffset = outputParamsSize;
+
+    uint64_t outputParamsSizeForMultiDraw = 0;
+    // Calculate size of output params for multi draws
+    for (auto& draw : multiDraws) {
+        // Don't need to validate non-indexed draws.
+        if (draw.type == IndirectDrawMetadata::DrawType::NonIndexed) {
+            continue;
+        }
+        outputParamsSizeForMultiDraw += draw.cmd->maxDrawCount * kDrawIndexedIndirectSize;
+
+        if (outputParamsSizeForMultiDraw > maxStorageBufferBindingSize) {
+            return DAWN_INTERNAL_ERROR("Too many multiDrawIndexedIndirect calls to validate");
+        }
+    }
+
+    outputParamsSize += outputParamsSizeForMultiDraw;
+
+    // If there are no output params to validate, we can skip the rest of the encoding.
+    // The above .empty() checks are not sufficient because there might exist non-indexed multi
+    // draws, which don't need validation.
+    if (outputParamsSize == 0) {
+        return {};
+    }
+
     auto* const store = device->GetInternalPipelineStore();
     ScratchBuffer& outputParamsBuffer = store->scratchIndirectStorage;
     ScratchBuffer& batchDataBuffer = store->scratchStorage;
@@ -430,9 +580,13 @@
     for (const Pass& pass : passes) {
         requiredBatchDataBufferSize = std::max(requiredBatchDataBufferSize, pass.batchDataSize);
     }
-    DAWN_TRY(batchDataBuffer.EnsureCapacity(requiredBatchDataBufferSize));
+    // Needs to at least be able to store a MultiDrawConstants struct for the multi draw validation.
+    requiredBatchDataBufferSize =
+        std::max(requiredBatchDataBufferSize, static_cast<uint64_t>(sizeof(MultiDrawConstants)));
 
     DAWN_TRY(outputParamsBuffer.EnsureCapacity(outputParamsSize));
+    DAWN_TRY(batchDataBuffer.EnsureCapacity(requiredBatchDataBufferSize));
+
     // We swap the indirect buffer used so we need to explicitly add the usage.
     usageTracker->BufferUsedAs(outputParamsBuffer.GetBuffer(), wgpu::BufferUsage::Indirect);
 
@@ -478,65 +632,185 @@
             }
         }
     }
+    if (!passes.empty()) {
+        ComputePipelineBase* pipeline;
+        DAWN_TRY_ASSIGN(pipeline, GetOrCreateIndirectDrawValidationPipeline(device));
 
-    ComputePipelineBase* pipeline;
-    DAWN_TRY_ASSIGN(pipeline, GetOrCreateRenderValidationPipeline(device));
+        Ref<BindGroupLayoutBase> layout;
+        DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));
 
-    Ref<BindGroupLayoutBase> layout;
-    DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));
+        BindGroupEntry bindings[3];
+        BindGroupEntry& bufferDataBinding = bindings[0];
+        bufferDataBinding.binding = 0;
+        bufferDataBinding.buffer = batchDataBuffer.GetBuffer();
 
-    BindGroupEntry bindings[3];
-    BindGroupEntry& bufferDataBinding = bindings[0];
-    bufferDataBinding.binding = 0;
-    bufferDataBinding.buffer = batchDataBuffer.GetBuffer();
+        BindGroupEntry& inputIndirectBinding = bindings[1];
+        inputIndirectBinding.binding = 1;
 
-    BindGroupEntry& inputIndirectBinding = bindings[1];
-    inputIndirectBinding.binding = 1;
+        BindGroupEntry& outputParamsBinding = bindings[2];
+        outputParamsBinding.binding = 2;
+        outputParamsBinding.buffer = outputParamsBuffer.GetBuffer();
 
-    BindGroupEntry& outputParamsBinding = bindings[2];
-    outputParamsBinding.binding = 2;
-    outputParamsBinding.buffer = outputParamsBuffer.GetBuffer();
+        BindGroupDescriptor bindGroupDescriptor = {};
+        bindGroupDescriptor.layout = layout.Get();
+        bindGroupDescriptor.entryCount = 3;
+        bindGroupDescriptor.entries = bindings;
 
-    BindGroupDescriptor bindGroupDescriptor = {};
-    bindGroupDescriptor.layout = layout.Get();
-    bindGroupDescriptor.entryCount = 3;
-    bindGroupDescriptor.entries = bindings;
+        // Finally, we can now encode our validation and duplication passes. Each pass first does
+        // two WriteBuffer to get batch and pass data over to the GPU, followed by a single compute
+        // pass. The compute pass encodes a separate SetBindGroup and Dispatch command for each
+        // batch.
+        for (const Pass& pass : passes) {
+            commandEncoder->APIWriteBuffer(batchDataBuffer.GetBuffer(), 0,
+                                           static_cast<const uint8_t*>(pass.batchData.get()),
+                                           pass.batchDataSize);
 
-    // Finally, we can now encode our validation and duplication passes. Each pass first does a
-    // two WriteBuffer to get batch and pass data over to the GPU, followed by a single compute
-    // pass. The compute pass encodes a separate SetBindGroup and Dispatch command for each
-    // batch.
-    for (const Pass& pass : passes) {
-        commandEncoder->APIWriteBuffer(batchDataBuffer.GetBuffer(), 0,
-                                       static_cast<const uint8_t*>(pass.batchData.get()),
-                                       pass.batchDataSize);
+            Ref<ComputePassEncoder> passEncoder = commandEncoder->BeginComputePass();
+            passEncoder->APISetPipeline(pipeline);
 
-        Ref<ComputePassEncoder> passEncoder = commandEncoder->BeginComputePass();
-        passEncoder->APISetPipeline(pipeline);
+            inputIndirectBinding.buffer = pass.inputIndirectBuffer;
 
-        inputIndirectBinding.buffer = pass.inputIndirectBuffer;
+            for (const Batch& batch : pass.batches) {
+                bufferDataBinding.offset = batch.dataBufferOffset;
+                bufferDataBinding.size = batch.dataSize;
+                inputIndirectBinding.offset = batch.inputIndirectOffset;
+                inputIndirectBinding.size = batch.inputIndirectSize;
+                outputParamsBinding.offset = batch.outputParamsOffset;
+                outputParamsBinding.size = batch.outputParamsSize;
 
-        for (const Batch& batch : pass.batches) {
-            bufferDataBinding.offset = batch.dataBufferOffset;
-            bufferDataBinding.size = batch.dataSize;
-            inputIndirectBinding.offset = batch.inputIndirectOffset;
-            inputIndirectBinding.size = batch.inputIndirectSize;
-            outputParamsBinding.offset = batch.outputParamsOffset;
-            outputParamsBinding.size = batch.outputParamsSize;
+                Ref<BindGroupBase> bindGroup;
+                DAWN_TRY_ASSIGN(bindGroup, device->CreateBindGroup(&bindGroupDescriptor));
+
+                const uint32_t numDrawsRoundedUp =
+                    (batch.batchInfo->numDraws + kWorkgroupSize - 1) / kWorkgroupSize;
+                passEncoder->APISetBindGroup(0, bindGroup.Get());
+                passEncoder->APIDispatchWorkgroups(numDrawsRoundedUp);
+            }
+
+            passEncoder->APIEnd();
+        }
+    }
+    if (!multiDraws.empty()) {
+        ScratchBuffer& drawConstantsBuffer = store->scratchStorage;
+
+        ComputePipelineBase* pipeline;
+        DAWN_TRY_ASSIGN(pipeline, GetOrCreateMultiDrawValidationPipeline(device));
+
+        Ref<BindGroupLayoutBase> layout;
+        DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));
+
+        BindGroupEntry bindings[4];
+
+        BindGroupEntry& drawConstantsBinding = bindings[0];
+        drawConstantsBinding.binding = 0;
+        drawConstantsBinding.buffer = drawConstantsBuffer.GetBuffer();
+
+        BindGroupEntry& inputIndirectBinding = bindings[1];
+        inputIndirectBinding.binding = 1;
+
+        BindGroupEntry& outputParamsBinding = bindings[2];
+        outputParamsBinding.binding = 2;
+        outputParamsBinding.buffer = outputParamsBuffer.GetBuffer();
+
+        BindGroupEntry& drawCountBinding = bindings[3];
+        drawCountBinding.binding = 3;
+
+        BindGroupDescriptor bindGroupDescriptor = {};
+        bindGroupDescriptor.layout = layout.Get();
+        bindGroupDescriptor.entryCount = 4;
+        bindGroupDescriptor.entries = bindings;
+
+        // Start of the region for multi draw output params.
+        uint64_t outputOffset = multiDrawOutputParamsOffset;
+
+        for (auto& draw : multiDraws) {
+            if (draw.type == IndirectDrawMetadata::DrawType::NonIndexed) {
+                continue;
+            }
+
+            const size_t formatSize = IndexFormatSize(draw.indexFormat);
+            uint64_t numIndexBufferElements = draw.indexBufferSize / formatSize;
+
+            // Same struct for both indexed and non-indexed draws.
+            MultiDrawIndirectCmd* cmd = draw.cmd;
+
+            // NOTE(review): outputOffset is bound below, but no alignment of it to
+            // minStorageBufferOffsetAlignment is visible here; confirm the binding offset is valid.
+            MultiDrawConstants drawConstants;
+            drawConstants.maxDrawCount = draw.cmd->maxDrawCount;
+            // We need to pass the remaining offset in elements after aligning to the
+            // minStorageBufferOffsetAlignment. See comment below.
+            drawConstants.indirectOffsetInElements = static_cast<uint32_t>(
+                (cmd->indirectOffset % minStorageBufferOffsetAlignment) / sizeof(uint32_t));
+            drawConstants.drawCountOffsetInElements = static_cast<uint32_t>(
+                (cmd->drawCountOffset % minStorageBufferOffsetAlignment) / sizeof(uint32_t));
+            drawConstants.numIndexBufferElementsLow =
+                static_cast<uint32_t>(numIndexBufferElements & 0xFFFFFFFF);
+            drawConstants.numIndexBufferElementsHigh =
+                static_cast<uint32_t>((numIndexBufferElements >> 32) & 0xFFFFFFFF);
+            drawConstants.flags = kIndexedDraw;
+            if (cmd->drawCountBuffer != nullptr) {
+                drawConstants.flags |= kIndirectDrawCountBuffer;
+            }
+
+            inputIndirectBinding.buffer = cmd->indirectBuffer.Get();
+            // We can't use the offset directly because the indirect offset is guaranteed to
+            // be aligned to 4 bytes, but when binding the buffer alignment requirement is
+            // minStorageBufferOffsetAlignment. Instead we align the offset to the
+            // minStorageBufferOffsetAlignment. Then pass the remaining offset in elements.
+            inputIndirectBinding.offset =
+                AlignDown(cmd->indirectOffset, minStorageBufferOffsetAlignment);
+
+            outputParamsBinding.buffer = outputParamsBuffer.GetBuffer();
+            outputParamsBinding.offset = outputOffset;
+
+            if (cmd->drawCountBuffer != nullptr) {
+                // If the drawCountBuffer is set, we need to bind it to the bind group.
+                // The drawCountBuffer is used to read the drawCount for the multi draw call.
+                // If the drawCount exceeds the maxDrawCount, it will be clamped to maxDrawCount.
+                drawCountBinding.buffer = cmd->drawCountBuffer.Get();
+                drawCountBinding.offset =
+                    AlignDown(cmd->drawCountOffset, minStorageBufferOffsetAlignment);
+            } else {
+                // This is an unused binding.
+                // Bind group entry for the drawCountBuffer is not needed however we need to bind
+                // something else than nullptr to the bind group entry to avoid validation errors.
+                // This buffer is never used in the shader, since there is a flag
+                // (kIndirectDrawCountBuffer) to check if the drawCountBuffer is set.
+                drawCountBinding.buffer = cmd->indirectBuffer.Get();
+                drawCountBinding.offset = 0;
+            }
 
             Ref<BindGroupBase> bindGroup;
             DAWN_TRY_ASSIGN(bindGroup, device->CreateBindGroup(&bindGroupDescriptor));
 
-            const uint32_t numDrawsRoundedUp =
-                (batch.batchInfo->numDraws + kWorkgroupSize - 1) / kWorkgroupSize;
-            passEncoder->APISetBindGroup(0, bindGroup.Get());
-            passEncoder->APIDispatchWorkgroups(numDrawsRoundedUp);
-        }
+            commandEncoder->APIWriteBuffer(drawConstantsBuffer.GetBuffer(), 0,
+                                           reinterpret_cast<const uint8_t*>(&drawConstants),
+                                           sizeof(MultiDrawConstants));
 
-        passEncoder->APIEnd();
+            Ref<ComputePassEncoder> passEncoder = commandEncoder->BeginComputePass();
+            passEncoder->APISetPipeline(pipeline);
+            passEncoder->APISetBindGroup(0, bindGroup.Get());
+
+            // TODO(crbug.com/356461286): After maxDrawCount has a limit we can
+            // dispatch exact number of workgroups without worrying about overflow:
+            // uint32_t workgroupCount = (cmd->maxDrawCount + kWorkgroupSize - 1u) / kWorkgroupSize;
+            uint32_t workgroupCount = cmd->maxDrawCount / kWorkgroupSize;
+            // Integer division rounds down so adding 1 if there is a remainder.
+            workgroupCount += cmd->maxDrawCount % kWorkgroupSize == 0 ? 0 : 1;
+            passEncoder->APIDispatchWorkgroups(workgroupCount);
+            passEncoder->APIEnd();
+
+            // Update the draw command to use the validated indirect buffer.
+            // The drawCountBuffer doesn't need to be updated because if it exceeds the maxDrawCount
+            // it will be clamped to maxDrawCount.
+            cmd->indirectBuffer = outputParamsBuffer.GetBuffer();
+            cmd->indirectOffset = outputOffset;
+
+            outputOffset += cmd->maxDrawCount * kDrawIndexedIndirectSize;
+        }
     }
 
     return {};
 }
-
 }  // namespace dawn::native
diff --git a/src/dawn/native/InternalPipelineStore.cpp b/src/dawn/native/InternalPipelineStore.cpp
index 58d1d4d..c5a78a4 100644
--- a/src/dawn/native/InternalPipelineStore.cpp
+++ b/src/dawn/native/InternalPipelineStore.cpp
@@ -41,6 +41,9 @@
     : scratchStorage(device, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Storage),
       scratchIndirectStorage(
           device,
+          wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect | wgpu::BufferUsage::Storage),
+      scratchMultiDrawStorage(
+          device,
           wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect | wgpu::BufferUsage::Storage) {}
 
 InternalPipelineStore::~InternalPipelineStore() = default;
diff --git a/src/dawn/native/InternalPipelineStore.h b/src/dawn/native/InternalPipelineStore.h
index 527cc35..40b78dc 100644
--- a/src/dawn/native/InternalPipelineStore.h
+++ b/src/dawn/native/InternalPipelineStore.h
@@ -73,8 +73,13 @@
     // buffer for indirect dispatch or draw calls.
     ScratchBuffer scratchIndirectStorage;
 
-    Ref<ComputePipelineBase> renderValidationPipeline;
-    Ref<ShaderModuleBase> renderValidationShader;
+    // A render pass can have both DrawIndirect and MultiDrawIndirect calls.
+    // We need a separate buffer to store the validated multiDrawCommands.
+    ScratchBuffer scratchMultiDrawStorage;
+
+    Ref<ShaderModuleBase> indirectDrawValidationShader;
+    Ref<ComputePipelineBase> indirectDrawValidationPipeline;
+    Ref<ComputePipelineBase> multiDrawValidationPipeline;
     Ref<ComputePipelineBase> dispatchIndirectValidationPipeline;
 
     Ref<RenderPipelineBase> blitRG8ToDepth16UnormPipeline;
diff --git a/src/dawn/native/RenderEncoderBase.cpp b/src/dawn/native/RenderEncoderBase.cpp
index abe87e1..1995db6 100644
--- a/src/dawn/native/RenderEncoderBase.cpp
+++ b/src/dawn/native/RenderEncoderBase.cpp
@@ -381,10 +381,15 @@
             cmd->drawCountBuffer = drawCountBuffer;
             cmd->drawCountOffset = drawCountBufferOffset;
 
+            mIndirectDrawMetadata.AddMultiDrawIndirect(cmd);
+
             // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
             // validation, but it will unecessarily transition to indirectBuffer usage in the
             // backend.
             mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
+            if (drawCountBuffer != nullptr) {
+                mUsageTracker.BufferUsedAs(drawCountBuffer, wgpu::BufferUsage::Indirect);
+            }
 
             mDrawCount += maxDrawCount;
 
@@ -465,10 +470,18 @@
             cmd->drawCountBuffer = drawCountBuffer;
             cmd->drawCountOffset = drawCountBufferOffset;
 
+            mIndirectDrawMetadata.AddMultiDrawIndexedIndirect(
+                mCommandBufferState.GetIndexBuffer(), mCommandBufferState.GetIndexFormat(),
+                mCommandBufferState.GetIndexBufferSize(),
+                mCommandBufferState.GetIndexBufferOffset(), cmd);
+
             // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
             // validation, but it will unecessarily transition to indirectBuffer usage in the
             // backend.
             mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
+            if (drawCountBuffer != nullptr) {
+                mUsageTracker.BufferUsedAs(drawCountBuffer, wgpu::BufferUsage::Indirect);
+            }
 
             mDrawCount += maxDrawCount;
 
@@ -555,7 +568,7 @@
                 }
             }
 
-            mCommandBufferState.SetIndexBuffer(format, offset, size);
+            mCommandBufferState.SetIndexBuffer(buffer, format, offset, size);
 
             SetIndexBufferCmd* cmd =
                 allocator->Allocate<SetIndexBufferCmd>(Command::SetIndexBuffer);
diff --git a/src/dawn/native/ScratchBuffer.cpp b/src/dawn/native/ScratchBuffer.cpp
index 4fb1f22..3697dfd 100644
--- a/src/dawn/native/ScratchBuffer.cpp
+++ b/src/dawn/native/ScratchBuffer.cpp
@@ -41,6 +41,7 @@
 }
 
 MaybeError ScratchBuffer::EnsureCapacity(uint64_t capacity) {
+    DAWN_ASSERT(capacity > 0);
     if (!mBuffer.Get() || mBuffer->GetSize() < capacity) {
         BufferDescriptor descriptor;
         descriptor.size = capacity;
diff --git a/src/dawn/native/vulkan/CommandBufferVk.cpp b/src/dawn/native/vulkan/CommandBufferVk.cpp
index 9f7dd8b..c3eade0 100644
--- a/src/dawn/native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn/native/vulkan/CommandBufferVk.cpp
@@ -1247,6 +1247,57 @@
                 break;
             }
 
+            case Command::MultiDrawIndirect: {
+                MultiDrawIndirectCmd* cmd = iter->NextCommand<MultiDrawIndirectCmd>();
+
+                Buffer* indirectBuffer = ToBackend(cmd->indirectBuffer.Get());
+                DAWN_ASSERT(indirectBuffer != nullptr);
+
+                // Count buffer is optional
+                Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
+
+                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+
+                if (countBuffer == nullptr) {
+                    device->fn.CmdDrawIndirect(commands, indirectBuffer->GetHandle(),
+                                               static_cast<VkDeviceSize>(cmd->indirectOffset),
+                                               cmd->maxDrawCount, kDrawIndirectSize);
+                } else {
+                    device->fn.CmdDrawIndirectCountKHR(
+                        commands, indirectBuffer->GetHandle(),
+                        static_cast<VkDeviceSize>(cmd->indirectOffset), countBuffer->GetHandle(),
+                        static_cast<VkDeviceSize>(cmd->drawCountOffset), cmd->maxDrawCount,
+                        kDrawIndirectSize);
+                }
+                break;
+            }
+            case Command::MultiDrawIndexedIndirect: {
+                MultiDrawIndexedIndirectCmd* cmd = iter->NextCommand<MultiDrawIndexedIndirectCmd>();
+
+                Buffer* indirectBuffer = ToBackend(cmd->indirectBuffer.Get());
+                DAWN_ASSERT(indirectBuffer != nullptr);
+
+                // Count buffer is optional
+                Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
+
+                descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+
+                if (countBuffer == nullptr) {
+                    device->fn.CmdDrawIndexedIndirect(
+                        commands, indirectBuffer->GetHandle(),
+                        static_cast<VkDeviceSize>(cmd->indirectOffset), cmd->maxDrawCount,
+                        kDrawIndexedIndirectSize);
+                } else {
+                    device->fn.CmdDrawIndexedIndirectCountKHR(
+                        commands, indirectBuffer->GetHandle(),
+                        static_cast<VkDeviceSize>(cmd->indirectOffset), countBuffer->GetHandle(),
+                        static_cast<VkDeviceSize>(cmd->drawCountOffset), cmd->maxDrawCount,
+                        kDrawIndexedIndirectSize);
+                }
+
+                break;
+            }
+
             case Command::InsertDebugMarker: {
                 if (device->GetGlobalInfo().HasExt(InstanceExt::DebugUtils)) {
                     InsertDebugMarkerCmd* cmd = iter->NextCommand<InsertDebugMarkerCmd>();
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 072dde8..c01d64e 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -536,6 +536,12 @@
                           VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SAMPLER_YCBCR_CONVERSION_FEATURES);
     }
 
+    if (HasFeature(Feature::MultiDrawIndirect)) {
+        DAWN_ASSERT(usedKnobs.HasExt(DeviceExt::DrawIndirectCount) &&
+                    mDeviceInfo.features.multiDrawIndirect == VK_TRUE);
+        usedKnobs.features.multiDrawIndirect = VK_TRUE;
+    }
+
     // Find a universal queue family
     {
         // Note that GRAPHICS and COMPUTE imply TRANSFER so we don't need to check for it.
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
index df2f473..4d060ae 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
@@ -261,6 +261,11 @@
         }
     }
 
+    if (mDeviceInfo.HasExt(DeviceExt::DrawIndirectCount) &&
+        mDeviceInfo.features.multiDrawIndirect == VK_TRUE) {
+        EnableFeature(Feature::MultiDrawIndirect);
+    }
+
     // unclippedDepth=true translates to depthClamp=true, which implicitly disables clipping.
     if (mDeviceInfo.features.depthClamp == VK_TRUE) {
         EnableFeature(Feature::DepthClipControl);
diff --git a/src/dawn/native/vulkan/VulkanExtensions.cpp b/src/dawn/native/vulkan/VulkanExtensions.cpp
index 818eb2f..a41f882 100644
--- a/src/dawn/native/vulkan/VulkanExtensions.cpp
+++ b/src/dawn/native/vulkan/VulkanExtensions.cpp
@@ -169,6 +169,7 @@
     {DeviceExt::ShaderFloat16Int8, "VK_KHR_shader_float16_int8", VulkanVersion_1_2},
     {DeviceExt::ShaderSubgroupExtendedTypes, "VK_KHR_shader_subgroup_extended_types",
      VulkanVersion_1_2},
+    {DeviceExt::DrawIndirectCount, "VK_KHR_draw_indirect_count", NeverPromoted},
 
     {DeviceExt::ShaderIntegerDotProduct, "VK_KHR_shader_integer_dot_product", VulkanVersion_1_3},
     {DeviceExt::ZeroInitializeWorkgroupMemory, "VK_KHR_zero_initialize_workgroup_memory",
@@ -235,6 +236,7 @@
             case DeviceExt::Maintenance2:
             case DeviceExt::ImageFormatList:
             case DeviceExt::StorageBufferStorageClass:
+            case DeviceExt::DrawIndirectCount:
                 hasDependencies = true;
                 break;
 
diff --git a/src/dawn/native/vulkan/VulkanExtensions.h b/src/dawn/native/vulkan/VulkanExtensions.h
index 0061809..c536ff8 100644
--- a/src/dawn/native/vulkan/VulkanExtensions.h
+++ b/src/dawn/native/vulkan/VulkanExtensions.h
@@ -107,6 +107,7 @@
     ImageFormatList,
     ShaderFloat16Int8,
     ShaderSubgroupExtendedTypes,
+    DrawIndirectCount,
 
     // Promoted to 1.3
     ShaderIntegerDotProduct,
diff --git a/src/dawn/native/vulkan/VulkanFunctions.cpp b/src/dawn/native/vulkan/VulkanFunctions.cpp
index b5b4ce5..4f21fea 100644
--- a/src/dawn/native/vulkan/VulkanFunctions.cpp
+++ b/src/dawn/native/vulkan/VulkanFunctions.cpp
@@ -396,6 +396,11 @@
         GET_DEVICE_PROC(DestroySamplerYcbcrConversion);
     }
 
+    if (deviceInfo.HasExt(DeviceExt::DrawIndirectCount)) {
+        GET_DEVICE_PROC(CmdDrawIndirectCountKHR);
+        GET_DEVICE_PROC(CmdDrawIndexedIndirectCountKHR);
+    }
+
 #if VK_USE_PLATFORM_FUCHSIA
     if (deviceInfo.HasExt(DeviceExt::ExternalMemoryZirconHandle)) {
         GET_DEVICE_PROC(GetMemoryZirconHandleFUCHSIA);
diff --git a/src/dawn/native/vulkan/VulkanFunctions.h b/src/dawn/native/vulkan/VulkanFunctions.h
index 61d2ba8..2aef174 100644
--- a/src/dawn/native/vulkan/VulkanFunctions.h
+++ b/src/dawn/native/vulkan/VulkanFunctions.h
@@ -350,6 +350,10 @@
     VkFn<PFN_vkAcquireNextImageKHR> AcquireNextImageKHR = nullptr;
     VkFn<PFN_vkQueuePresentKHR> QueuePresentKHR = nullptr;
 
+    // VK_KHR_draw_indirect_count
+    VkFn<PFN_vkCmdDrawIndirectCount> CmdDrawIndirectCountKHR = nullptr;
+    VkFn<PFN_vkCmdDrawIndexedIndirectCount> CmdDrawIndexedIndirectCountKHR = nullptr;
+
 #if VK_USE_PLATFORM_FUCHSIA
     // VK_FUCHSIA_external_memory
     VkFn<PFN_vkGetMemoryZirconHandleFUCHSIA> GetMemoryZirconHandleFUCHSIA = nullptr;
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 2e22feb..957178b 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -630,6 +630,8 @@
     "end2end/MaxLimitTests.cpp",
     "end2end/MemoryAllocationStressTests.cpp",
     "end2end/MemoryHeapPropertiesTests.cpp",
+    "end2end/MultiDrawIndexedIndirectTests.cpp",
+    "end2end/MultiDrawIndirectTests.cpp",
     "end2end/MultisampledRenderingTests.cpp",
     "end2end/MultisampledSamplingTests.cpp",
     "end2end/MultithreadTests.cpp",
diff --git a/src/dawn/tests/end2end/MultiDrawIndexedIndirectTests.cpp b/src/dawn/tests/end2end/MultiDrawIndexedIndirectTests.cpp
new file mode 100644
index 0000000..50ebc68
--- /dev/null
+++ b/src/dawn/tests/end2end/MultiDrawIndexedIndirectTests.cpp
@@ -0,0 +1,436 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <vector>
+
+#include "dawn/tests/DawnTest.h"
+#include "dawn/utils/ComboRenderBundleEncoderDescriptor.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
+#include "dawn/utils/WGPUHelpers.h"
+
+namespace dawn {
+namespace {
+
+constexpr uint32_t kRTSize = 4;
+
+class MultiDrawIndexedIndirectTest : public DawnTest {
+  protected:
+    wgpu::RequiredLimits GetRequiredLimits(const wgpu::SupportedLimits& supported) override {
+        // Request every supported limit; large values can reach into the upper 32 bits of 64-bit
+        // limit values, which helps detect integer arithmetic bugs like overflows and truncations.
+        wgpu::RequiredLimits required = {};
+        required.limits = supported.limits;
+        return required;
+    }
+
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        if (!SupportsFeatures({wgpu::FeatureName::MultiDrawIndirect})) {
+            return {};
+        }
+        return {wgpu::FeatureName::MultiDrawIndirect};
+    }
+
+    void SetUp() override {
+        DawnTest::SetUp();
+        DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::MultiDrawIndirect));
+
+        renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+            @vertex
+            fn main(@location(0) pos : vec4f) -> @builtin(position) vec4f {
+                return pos;
+            })");
+
+        wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main() -> @location(0) vec4f {
+                return vec4f(0.0, 1.0, 0.0, 1.0);
+            })");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = vsModule;
+        descriptor.cFragment.module = fsModule;
+        descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.primitive.stripIndexFormat = wgpu::IndexFormat::Uint32;
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cBuffers[0].attributeCount = 1;
+        descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
+        descriptor.cTargets[0].format = renderPass.colorFormat;
+
+        pipeline = device.CreateRenderPipeline(&descriptor);
+
+        vertexBuffer = utils::CreateBufferFromData<float>(
+            device, wgpu::BufferUsage::Vertex,
+            {// First quad: the first 3 vertices represent the bottom left triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f,
+             0.0f, 1.0f,
+
+             // Second quad: the first 3 vertices represent the top right triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, -1.0f, -1.0f,
+             0.0f, 1.0f});
+    }
+
+    utils::BasicRenderPass renderPass;
+    wgpu::RenderPipeline pipeline;
+    wgpu::Buffer vertexBuffer;
+
+    wgpu::Buffer CreateIndirectBuffer(std::initializer_list<uint32_t> indirectParamList) {
+        return utils::CreateBufferFromData<uint32_t>(
+            device, wgpu::BufferUsage::Indirect | wgpu::BufferUsage::Storage, indirectParamList);
+    }
+
+    wgpu::Buffer CreateIndexBuffer(std::initializer_list<uint32_t> indexList) {
+        return utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Index, indexList);
+    }
+
+    wgpu::CommandBuffer EncodeDrawCommands(std::initializer_list<uint32_t> bufferList,
+                                           wgpu::Buffer indexBuffer,
+                                           uint64_t indexOffset,
+                                           uint64_t indirectOffset,
+                                           uint32_t maxDrawCount,
+                                           wgpu::Buffer drawCountBuffer = nullptr,
+                                           uint64_t drawCountOffset = 0) {
+        wgpu::Buffer indirectBuffer = CreateIndirectBuffer(bufferList);
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        {
+            wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+            pass.SetPipeline(pipeline);
+            pass.SetVertexBuffer(0, vertexBuffer);
+            pass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32, indexOffset);
+            pass.MultiDrawIndexedIndirect(indirectBuffer, indirectOffset, maxDrawCount,
+                                          drawCountBuffer, drawCountOffset);
+            pass.End();
+        }
+
+        return encoder.Finish();
+    }
+
+    void TestDraw(wgpu::CommandBuffer commands,
+                  utils::RGBA8 bottomLeftExpected,
+                  utils::RGBA8 topRightExpected) {
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(bottomLeftExpected, renderPass.color, 1, 3);
+        EXPECT_PIXEL_RGBA8_EQ(topRightExpected, renderPass.color, 3, 1);
+    }
+
+    void Test(std::initializer_list<uint32_t> bufferList,
+              uint64_t indexOffset,
+              uint64_t indirectOffset,
+              uint32_t maxDrawCount,
+              utils::RGBA8 bottomLeftExpected,
+              utils::RGBA8 topRightExpected) {
+        wgpu::Buffer indexBuffer =
+            CreateIndexBuffer({0, 1, 2, 0, 3, 1,
+                               // The indices below are added to test negative baseVertex
+                               0 + 4, 1 + 4, 2 + 4, 0 + 4, 3 + 4, 1 + 4});
+        TestDraw(
+            EncodeDrawCommands(bufferList, indexBuffer, indexOffset, indirectOffset, maxDrawCount),
+            bottomLeftExpected, topRightExpected);
+    }
+};
+
+// Basic indexed triangle draws through MultiDrawIndexedIndirect with maxDrawCount = 1.
+TEST_P(MultiDrawIndexedIndirectTest, Uint32) {
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    // Test a draw with no indices.
+    Test({0, 0, 0, 0, 0}, 0, 0, 1, notFilled, notFilled);
+
+    // Test a draw with only the first 3 indices of the first quad (bottom left triangle)
+    Test({3, 1, 0, 0, 0}, 0, 0, 1, filled, notFilled);
+
+    // Test a draw with only the last 3 indices of the first quad (top right triangle)
+    Test({3, 1, 3, 0, 0}, 0, 0, 1, notFilled, filled);
+
+    // Test a draw with all 6 indices (both triangles).
+    Test({6, 1, 0, 0, 0}, 0, 0, 1, filled, filled);
+}
+
+// Test that the 'baseVertex' indirect parameter of MultiDrawIndexedIndirect() works.
+TEST_P(MultiDrawIndexedIndirectTest, BaseVertex) {
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    // Test a draw with only the first 3 indices of the second quad (top right triangle)
+    Test({3, 1, 0, 4, 0}, 0, 0, 1, notFilled, filled);
+
+    // Test a draw with only the last 3 indices of the second quad (bottom left triangle)
+    Test({3, 1, 3, 4, 0}, 0, 0, 1, filled, notFilled);
+
+    const int negFour = -4;
+    uint32_t unsignedNegFour;
+    std::memcpy(&unsignedNegFour, &negFour, sizeof(int));
+
+    // Test negative baseVertex (-4 bit-copied into a uint32_t indirect parameter).
+    // Test a draw with only the first 3 indices of the first quad (bottom left triangle)
+    Test({3, 1, 0, unsignedNegFour, 0}, 6 * sizeof(uint32_t), 0, 1, filled, notFilled);
+
+    // Test a draw with only the last 3 indices of the first quad (top right triangle)
+    Test({3, 1, 3, unsignedNegFour, 0}, 6 * sizeof(uint32_t), 0, 1, notFilled, filled);
+
+    // Last 3 indices of the first quad (top right triangle) with a non-zero indirect offset
+    Test({0, 3, 1, 3, unsignedNegFour, 0}, 6 * sizeof(uint32_t), 1 * sizeof(uint32_t), 1, notFilled,
+         filled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, IndirectOffset) {
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    // Test that indirectOffset selects which entry is drawn. The indirect buffer holds 2 entries:
+    // 1) first 3 indices of the second quad (top right triangle)
+    // 2) last 3 indices of the second quad (bottom left triangle)
+
+    // Draw #1: indirectOffset 0
+    Test({3, 1, 0, 4, 0, 3, 1, 3, 4, 0}, 0, 0, 1, notFilled, filled);
+
+    // Draw #2: indirectOffset skips the first entry (5 uint32 values)
+    Test({3, 1, 0, 4, 0, 3, 1, 3, 4, 0}, 0, 5 * sizeof(uint32_t), 1, filled, notFilled);
+}
+
+// Test that min(drawCount, maxDrawCount) entries are drawn for various drawCount values.
+TEST_P(MultiDrawIndexedIndirectTest, DrawCount) {
+    // TODO(crbug.com/356461286): NVIDIA Drivers for Vulkan Linux are drawing more than specified.
+    DAWN_SUPPRESS_TEST_IF(IsLinux() && IsNvidia() && IsVulkan());
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1});
+    // The drawCount buffer holds the values 0, 1 and 2 at byte offsets 0, 4 and 8.
+    wgpu::Buffer drawCountBuffer = CreateIndirectBuffer({0, 1, 2});
+    // Test a draw with drawCount = 0, which should not draw anything.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2,
+                                drawCountBuffer, 0),
+             notFilled, notFilled);
+    // Test a draw with drawCount = 1 where drawCount < maxDrawCount: only the first entry draws.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2,
+                                drawCountBuffer, 4),
+             filled, notFilled);
+    // Test a draw with drawCount = 2 where drawCount = maxDrawCount: both entries draw.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2,
+                                drawCountBuffer, 8),
+             filled, filled);
+    // Test a draw with drawCount = 2 where drawCount > maxDrawCount: only maxDrawCount = 1 draws.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 1,
+                                drawCountBuffer, 8),
+             filled, notFilled);
+    // Test a draw without drawCount buffer. Should draw maxDrawCount times.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2), filled,
+             filled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateWithOffsets) {
+    // It doesn't make sense to test invalid inputs when validation is disabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1, 0, 1, 2});
+
+    // Validation must account for the index buffer offset: only the first draw fits in 9 indices.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0}, indexBuffer, 6 * sizeof(uint32_t), 0, 1), filled,
+             notFilled);
+    TestDraw(EncodeDrawCommands({4, 1, 0, 0, 0}, indexBuffer, 6 * sizeof(uint32_t), 0, 1),
+             notFilled, notFilled);
+    TestDraw(EncodeDrawCommands({3, 1, 4, 0, 0}, indexBuffer, 3 * sizeof(uint32_t), 0, 1),
+             notFilled, notFilled);
+
+    // Validation must account for the indirect offset: at offset 16 the entry has indexCount = 0.
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 1000, 1, 0, 0, 0}, indexBuffer, 0,
+                                4 * sizeof(uint32_t), 1),
+             notFilled, notFilled);
+    TestDraw(EncodeDrawCommands({3, 1, 0, 0, 0, 1000, 1, 0, 0, 0}, indexBuffer, 0, 0, 1), filled,
+             notFilled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateMultiplePasses) {
+    // It doesn't make sense to test invalid inputs when validation is disabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1, 0, 1, 2});
+
+    // Submit several validated passes back to back. This exercises that scratch buffer data used
+    // by an earlier pass's validation commands is not overwritten before those commands run.
+    // Entries reaching past the 9 indices are expected to draw nothing.
+    TestDraw(EncodeDrawCommands({10, 1, 0, 0, 0, 3, 1, 9, 0, 0}, indexBuffer, 0, 0, 2), notFilled,
+             notFilled);
+    TestDraw(EncodeDrawCommands({3, 1, 6, 0, 0, 3, 1, 0, 0, 0}, indexBuffer, 0, 0, 2), filled,
+             notFilled);
+    TestDraw(EncodeDrawCommands({4, 1, 6, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2), notFilled,
+             filled);
+    TestDraw(EncodeDrawCommands({6, 1, 6, 0, 0, 6, 1, 6, 0, 0}, indexBuffer, 0, 0, 2), notFilled,
+             notFilled);
+    TestDraw(EncodeDrawCommands({3, 1, 3, 0, 0}, indexBuffer, 0, 0, 1), notFilled, filled);
+    TestDraw(EncodeDrawCommands({6, 1, 0, 0, 0}, indexBuffer, 0, 0, 1), filled, filled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateEncodeMultipleThenSubmitInOrder) {
+    // Some draws below are invalid on purpose, so the test needs validation to be enabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1, 0, 1, 2});
+    wgpu::CommandBuffer encoded[6] = {
+        EncodeDrawCommands({10, 1, 0, 0, 0, 3, 1, 9, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 6, 0, 0, 3, 1, 0, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({4, 1, 6, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({6, 1, 6, 0, 0, 6, 1, 6, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 3, 0, 0}, indexBuffer, 0, 0, 1),
+        EncodeDrawCommands({6, 1, 0, 0, 0}, indexBuffer, 0, 0, 1),
+    };
+
+    TestDraw(encoded[0], notFilled, notFilled);
+    TestDraw(encoded[1], filled, notFilled);
+    TestDraw(encoded[2], notFilled, filled);
+    TestDraw(encoded[3], notFilled, notFilled);
+    TestDraw(encoded[4], notFilled, filled);
+    TestDraw(encoded[5], filled, filled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateEncodeMultipleThenSubmitOutOfOrder) {
+    // Some draws below are invalid on purpose, so the test needs validation to be enabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1, 0, 1, 2});
+    wgpu::CommandBuffer encoded[6] = {
+        EncodeDrawCommands({10, 1, 0, 0, 0, 3, 1, 9, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 6, 0, 0, 3, 1, 0, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({4, 1, 6, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({6, 1, 6, 0, 0, 6, 1, 6, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 3, 0, 0}, indexBuffer, 0, 0, 1),
+        EncodeDrawCommands({6, 1, 0, 0, 0}, indexBuffer, 0, 0, 1),
+    };
+
+    TestDraw(encoded[0], notFilled, notFilled);
+    TestDraw(encoded[4], notFilled, filled);
+    TestDraw(encoded[2], notFilled, filled);
+    TestDraw(encoded[5], filled, filled);
+    TestDraw(encoded[1], filled, notFilled);
+    TestDraw(encoded[3], notFilled, notFilled);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateEncodeMultipleThenSubmitAtOnce) {
+    // Some draws below are invalid on purpose, so the test needs validation to be enabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1, 0, 1, 2});
+    wgpu::CommandBuffer encoded[5] = {
+        EncodeDrawCommands({10, 1, 0, 0, 0, 3, 1, 9, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 6, 0, 0, 3, 1, 0, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({4, 1, 6, 0, 0, 3, 1, 3, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({6, 1, 6, 0, 0, 6, 1, 6, 0, 0}, indexBuffer, 0, 0, 2),
+        EncodeDrawCommands({3, 1, 3, 0, 0}, indexBuffer, 0, 0, 1),
+    };
+
+    queue.Submit(5, encoded);
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, 1, 3);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, 3, 1);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateMultiDrawMixed) {
+    // It doesn't make sense to test invalid inputs when validation is disabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+    utils::RGBA8 notFilled(0, 0, 0, 0);
+
+    // Test that a non-indexed and an indexed multi draw in one pass do not affect each other.
+    // Bytes [0, 32) are read as two non-indexed entries, bytes [20, 60) as two indexed entries.
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1});
+
+    wgpu::Buffer indirectBuffer =
+        CreateIndirectBuffer({3, 1, 6, 0, 0, 0, 0, 0, 0, 0, 3, 1, 3, 0, 0});
+
+    wgpu::Buffer drawCountBuffer = CreateIndirectBuffer({0, 1, 2});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+    pass.SetPipeline(pipeline);
+    pass.SetVertexBuffer(0, vertexBuffer);
+    pass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32, 0);
+    pass.MultiDrawIndirect(indirectBuffer, 0, 2, nullptr);
+    pass.MultiDrawIndexedIndirect(indirectBuffer, 20, 2, drawCountBuffer, 8);  // drawCount = 2
+    pass.End();
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(notFilled, renderPass.color, 1, 3);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, 3, 1);
+}
+
+TEST_P(MultiDrawIndexedIndirectTest, ValidateMultiAndSingleDrawsInSingleRenderPass) {
+    // It doesn't make sense to test invalid inputs when validation is disabled.
+    DAWN_SUPPRESS_TEST_IF(HasToggleEnabled("skip_validation"));
+
+    utils::RGBA8 filled(0, 255, 0, 255);
+
+    // Test that single and multi draws, indexed or not, do not affect each other within one pass.
+
+    wgpu::Buffer indexBuffer = CreateIndexBuffer({0, 1, 2, 0, 3, 1});
+
+    wgpu::Buffer indirectBuffer =
+        CreateIndirectBuffer({3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 3, 0, 0});
+
+    wgpu::Buffer drawCountBuffer = CreateIndirectBuffer({0, 1, 2});
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+    pass.SetPipeline(pipeline);
+    pass.SetVertexBuffer(0, vertexBuffer);
+    pass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32, 0);
+    pass.DrawIndexedIndirect(indirectBuffer, 0);  // draw the first triangle
+    pass.DrawIndirect(indirectBuffer, 16);        // no draw (vertexCount is 0)
+    pass.MultiDrawIndexedIndirect(indirectBuffer, 20, 2, drawCountBuffer, 8);  // draw the second
+    pass.MultiDrawIndirect(indirectBuffer, 28, 2, nullptr);                    // no draw
+    pass.End();
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, 1, 3);
+    EXPECT_PIXEL_RGBA8_EQ(filled, renderPass.color, 3, 1);
+}
+
+DAWN_INSTANTIATE_TEST(MultiDrawIndexedIndirectTest, VulkanBackend());
+
+}  // anonymous namespace
+}  // namespace dawn
diff --git a/src/dawn/tests/end2end/MultiDrawIndirectTests.cpp b/src/dawn/tests/end2end/MultiDrawIndirectTests.cpp
new file mode 100644
index 0000000..b127fc7
--- /dev/null
+++ b/src/dawn/tests/end2end/MultiDrawIndirectTests.cpp
@@ -0,0 +1,336 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <iostream>
+#include <vector>
+
+#include "dawn/tests/DawnTest.h"
+
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
+#include "dawn/utils/WGPUHelpers.h"
+
+namespace dawn {
+namespace {
+
+constexpr uint32_t kRTSize = 4;                 // Both dimensions of the basic render pass.
+constexpr utils::RGBA8 filled(0, 255, 0, 255);  // Color written by every fragment shader here.
+constexpr utils::RGBA8 notFilled(0, 0, 0, 0);   // Expected where no triangle was drawn.
+
+class MultiDrawIndirectTest : public DawnTest {  // Fixture for non-indexed multi draws.
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        if (!SupportsFeatures({wgpu::FeatureName::MultiDrawIndirect})) {
+            return {};  // Not requested when unsupported; SetUp() then marks the test unsupported.
+        }
+        return {wgpu::FeatureName::MultiDrawIndirect};
+    }
+
+    void SetUp() override {
+        DawnTest::SetUp();
+        DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::MultiDrawIndirect));
+
+        renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+            @vertex
+            fn main(@location(0) pos : vec4f) -> @builtin(position) vec4f {
+                return pos;
+            })");
+
+        wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main() -> @location(0) vec4f {
+                return vec4f(0.0, 1.0, 0.0, 1.0);
+            })");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = vsModule;
+        descriptor.cFragment.module = fsModule;
+        descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.primitive.stripIndexFormat = wgpu::IndexFormat::Uint32;
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cBuffers[0].attributeCount = 1;
+        descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
+        descriptor.cTargets[0].format = renderPass.colorFormat;
+
+        pipeline = device.CreateRenderPipeline(&descriptor);
+
+        vertexBuffer = utils::CreateBufferFromData<float>(
+            device, wgpu::BufferUsage::Vertex,
+            {// The bottom left triangle (vertices 0-2, one vec4f position each)
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, -1.0f, 0.0f, 1.0f,
+
+             // The top right triangle (vertices 3-5)
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f});
+    }
+
+    utils::BasicRenderPass renderPass;
+    wgpu::RenderPipeline pipeline;
+    wgpu::Buffer vertexBuffer;
+    // Draws `bufferList` with one MultiDrawIndirect call, then checks the two probed pixels.
+    void Test(std::initializer_list<uint32_t> bufferList,
+              uint64_t indirectOffset,
+              uint32_t maxDrawCount,
+              utils::RGBA8 bottomLeftExpected,
+              utils::RGBA8 topRightExpected,
+              wgpu::Buffer drawCountBuffer = nullptr,
+              uint64_t drawCountOffset = 0) {
+        wgpu::Buffer indirectBuffer =
+            utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, bufferList);
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        {
+            wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+            pass.SetPipeline(pipeline);
+            pass.SetVertexBuffer(0, vertexBuffer);
+            pass.MultiDrawIndirect(indirectBuffer, indirectOffset, maxDrawCount, drawCountBuffer,
+                                   drawCountOffset);
+            pass.End();
+        }
+
+        wgpu::CommandBuffer commands = encoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(bottomLeftExpected, renderPass.color, 1, 3);
+        EXPECT_PIXEL_RGBA8_EQ(topRightExpected, renderPass.color, 3, 1);
+    }
+};
+
+// Basic draws; each entry is {vertexCount, instanceCount, firstVertex, firstInstance}.
+TEST_P(MultiDrawIndirectTest, Uint32) {
+    // Test a draw with no vertices.
+    Test({0, 0, 0, 0}, 0, 1, notFilled, notFilled);
+
+    // Test a draw with only the first 3 vertices (bottom left triangle)
+    Test({3, 1, 0, 0}, 0, 1, filled, notFilled);
+
+    // Test a draw with only the last 3 vertices (top right triangle)
+    Test({3, 1, 3, 0}, 0, 1, notFilled, filled);
+
+    // Test a draw with all 6 vertices (both triangles)
+    Test({6, 1, 0, 0}, 0, 1, filled, filled);
+
+    // Test a multi draw with 2 entries, one per triangle (both triangles)
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 2, filled, filled);
+}
+
+// Test that min(drawCount, maxDrawCount) entries are drawn for various drawCount values.
+TEST_P(MultiDrawIndirectTest, DrawCount) {
+    // TODO(crbug.com/356461286): NVIDIA Drivers for Vulkan Linux are drawing more than
+    // maxDrawCount.
+    DAWN_SUPPRESS_TEST_IF(IsLinux() && IsNvidia() && IsVulkan());
+
+    wgpu::Buffer drawBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {0, 1, 2});
+    // drawBuffer holds the counts 0, 1 and 2 at byte offsets 0, 4 and 8. Here drawCount = 0.
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 2, notFilled, notFilled, drawBuffer, 0);
+    // Test a draw with drawCount < maxDrawCount (1 < 2).
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 2, filled, notFilled, drawBuffer, 4);
+    // Test a draw with drawCount > maxDrawCount (2 > 1).
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 1, filled, notFilled, drawBuffer, 8);
+    // Test a draw with drawCount = maxDrawCount (2 = 2).
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 2, filled, filled, drawBuffer, 8);
+    // Test a draw without drawCount buffer. It should be treated as drawCount = maxDrawCount.
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 2, filled, filled);
+}
+
+// Test that indirectOffset selects which entry of the indirect buffer is drawn.
+TEST_P(MultiDrawIndirectTest, IndirectOffset) {
+    // Draw #1: indirectOffset 0
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 0, 1, filled, notFilled);
+
+    // Draw #2: indirectOffset skips one entry (kDrawIndirectSize bytes)
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, kDrawIndirectSize, 1, notFilled, filled);
+}
+
+DAWN_INSTANTIATE_TEST(MultiDrawIndirectTest, VulkanBackend());
+
+class MultiDrawIndirectUsingFirstVertexTest : public DawnTest {
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        if (!SupportsFeatures({wgpu::FeatureName::MultiDrawIndirect})) {
+            return {};  // Not requested when unsupported; SetUp() then marks the test unsupported.
+        }
+        return {wgpu::FeatureName::MultiDrawIndirect};
+    }
+    virtual void SetupShaderModule() {  // offset[] is indexed by vertex_index / 3
+        vsModule = utils::CreateShaderModule(device, R"(
+            struct VertexInput {
+                @builtin(vertex_index) id : u32,
+                @location(0) pos: vec4f,
+            };
+            @group(0) @binding(0) var<uniform> offset: array<vec4f, 2>;
+            @vertex
+            fn main(input: VertexInput) -> @builtin(position) vec4f {
+                return input.pos + offset[input.id / 3u];
+            })");
+        fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main() -> @location(0) vec4f {
+                return vec4f(0.0, 1.0, 0.0, 1.0);
+            })");
+    }
+    void GeneralSetup() {  // Builds render pass, pipeline, vertex data and bind group.
+        renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+        SetupShaderModule();
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = vsModule;
+        descriptor.cFragment.module = fsModule;
+        descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleStrip;
+        descriptor.primitive.stripIndexFormat = wgpu::IndexFormat::Uint32;
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cBuffers[0].attributeCount = 1;
+        descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
+        descriptor.cTargets[0].format = renderPass.colorFormat;
+
+        pipeline = device.CreateRenderPipeline(&descriptor);
+
+        // Every vertex is shifted by -calibration in x. A triangle only lands on screen when the
+        // vertex shader adds offset[1] (the good calibration) back; see SetupShaderModule().
+        constexpr float calibration = 99.0f;
+        vertexBuffer = utils::CreateBufferFromData<float>(
+            device, wgpu::BufferUsage::Vertex,
+            {// The bottom left triangle
+             -1.0f - calibration, 1.0f, 0.0f, 1.0f, 1.0f - calibration, -1.0f, 0.0f, 1.0f,
+             -1.0f - calibration, -1.0f, 0.0f, 1.0f,
+             // The top right triangle
+             -1.0f - calibration, 1.0f, 0.0f, 1.0f, 1.0f - calibration, -1.0f, 0.0f, 1.0f,
+             1.0f - calibration, 1.0f, 0.0f, 1.0f});
+        // The two vec4f calibration offsets read by the vertex shader as offset[0] and offset[1]
+        wgpu::Buffer uniformBuffer =
+            utils::CreateBufferFromData<float>(device, wgpu::BufferUsage::Uniform,
+                                               {
+                                                   // Bad calibration at [0]
+                                                   0.0,
+                                                   0.0,
+                                                   0.0,
+                                                   0.0,
+                                                   // Good calibration at [1]
+                                                   calibration,
+                                                   0.0,
+                                                   0.0,
+                                                   0.0,
+                                               });
+        bindGroup =
+            utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, uniformBuffer}});
+    }
+    void SetUp() override {
+        DawnTest::SetUp();
+        DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::MultiDrawIndirect));
+        GeneralSetup();
+    }
+    utils::BasicRenderPass renderPass;
+    wgpu::RenderPipeline pipeline;
+    wgpu::Buffer vertexBuffer;
+    wgpu::BindGroup bindGroup;
+    wgpu::ShaderModule vsModule;
+    wgpu::ShaderModule fsModule;
+    // Issues a single MultiDrawIndirect (offset 0, no drawCount buffer) and checks two pixels.
+    void Test(std::initializer_list<uint32_t> bufferList,
+              uint32_t maxDrawCount,
+              utils::RGBA8 bottomLeftExpected,
+              utils::RGBA8 topRightExpected) {
+        wgpu::Buffer indirectBuffer =
+            utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, bufferList);
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        {
+            wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+            pass.SetPipeline(pipeline);
+            pass.SetVertexBuffer(0, vertexBuffer);
+            pass.SetBindGroup(0, bindGroup);
+            pass.MultiDrawIndirect(indirectBuffer, 0, maxDrawCount, nullptr, 0);
+            pass.End();
+        }
+        wgpu::CommandBuffer commands = encoder.Finish();
+        queue.Submit(1, &commands);
+        EXPECT_PIXEL_RGBA8_EQ(bottomLeftExpected, renderPass.color, 1, 3);
+        EXPECT_PIXEL_RGBA8_EQ(topRightExpected, renderPass.color, 3, 1);
+    }
+};
+
+TEST_P(MultiDrawIndirectUsingFirstVertexTest, IndirectOffset) {
+    // One multi draw (at indirect offset 0) whose indirect buffer contains 2 entries:
+    // 1) firstVertex = 0, first 3 vertices (bottom left triangle): offset[0], stays off screen
+    // 2) firstVertex = 3, last 3 vertices (top right triangle): offset[1], becomes visible
+    // The shader picks offset[vertex_index / 3], so only #2 gets the good calibration.
+    Test({3, 1, 0, 0, 3, 1, 3, 0}, 2, notFilled, filled);
+}
+
+DAWN_INSTANTIATE_TEST(MultiDrawIndirectUsingFirstVertexTest, VulkanBackend());
+
+class MultiDrawIndirectUsingInstanceIndexTest : public MultiDrawIndirectUsingFirstVertexTest {
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {  // NOTE(review): same as base.
+        if (!SupportsFeatures({wgpu::FeatureName::MultiDrawIndirect})) {
+            return {};
+        }
+        return {wgpu::FeatureName::MultiDrawIndirect};
+    }
+
+    void SetupShaderModule() override {  // offset[] is indexed by instance_index
+        vsModule = utils::CreateShaderModule(device, R"(
+            struct VertexInput {
+                @builtin(instance_index) id : u32,
+                @location(0) pos: vec4f,
+            };
+
+            @group(0) @binding(0) var<uniform> offset: array<vec4f, 2>;
+
+            @vertex
+            fn main(input: VertexInput) -> @builtin(position) vec4f {
+                return input.pos + offset[input.id];
+            })");
+
+        fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main() -> @location(0) vec4f {
+                return vec4f(0.0, 1.0, 0.0, 1.0);
+            })");
+    }
+
+    void SetUp() override {  // NOTE(review): duplicates the base-class SetUp().
+        DawnTest::SetUp();
+        DAWN_TEST_UNSUPPORTED_IF(!device.HasFeature(wgpu::FeatureName::MultiDrawIndirect));
+        GeneralSetup();
+    }
+};
+
+TEST_P(MultiDrawIndirectUsingInstanceIndexTest, IndirectOffset) {
+    // One multi draw (at indirect offset 0) whose indirect buffer contains 2 entries:
+    // 1) only the first 3 vertices (bottom left triangle)
+    // 2) only the last 3 vertices (top right triangle)
+
+    // Test 1: entry #1 has firstInstance = 1, so only it reads the good calibration offset[1]
+    Test({3, 1, 0, 1, 3, 1, 3, 0}, 2, filled, notFilled);
+
+    // Test 2: entry #2 has firstInstance = 1, so only it reads the good calibration offset[1]
+    Test({3, 1, 0, 0, 3, 1, 3, 1}, 2, notFilled, filled);
+}
+
+DAWN_INSTANTIATE_TEST(MultiDrawIndirectUsingInstanceIndexTest, VulkanBackend());
+
+}  // namespace
+}  // namespace dawn
diff --git a/src/dawn/tests/unittests/MathTests.cpp b/src/dawn/tests/unittests/MathTests.cpp
index 5ae7fa7..dd1d87c 100644
--- a/src/dawn/tests/unittests/MathTests.cpp
+++ b/src/dawn/tests/unittests/MathTests.cpp
@@ -198,6 +198,36 @@
     ASSERT_EQ(Align(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull);
 }
 
+// Tests for AlignDown
+TEST(Math, AlignDown) {
+    // 0 aligns to 0
+    ASSERT_EQ(AlignDown(0u, 4), 0u);
+    ASSERT_EQ(AlignDown(0u, 256), 0u);
+    ASSERT_EQ(AlignDown(0u, 512), 0u);
+
+    // Multiples align to self
+    ASSERT_EQ(AlignDown(8u, 8), 8u);
+    ASSERT_EQ(AlignDown(16u, 8), 16u);
+    ASSERT_EQ(AlignDown(24u, 8), 24u);
+    ASSERT_EQ(AlignDown(256u, 256), 256u);
+    ASSERT_EQ(AlignDown(512u, 256), 512u);
+    ASSERT_EQ(AlignDown(768u, 256), 768u);
+
+    // Alignment with 1 is self
+    for (uint32_t i = 0; i < 128; ++i) {
+        ASSERT_EQ(AlignDown(i, 1), i);
+    }
+
+    // Everything in the open range (align, 2*align) aligns down to align
+    for (uint32_t i = 1; i < 64; ++i) {
+        ASSERT_EQ(AlignDown(64 + i, 64), 64u);
+    }
+
+    // Test extrema
+    ASSERT_EQ(AlignDown(static_cast<uint64_t>(0xFFFFFFFF), 4), 0xFFFFFFFC);
+    ASSERT_EQ(AlignDown(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull);
+}
+
 TEST(Math, AlignSizeof) {
     // Basic types should align to self if alignment is a divisor.
     ASSERT_EQ((AlignSizeof<uint8_t, 1>()), 1u);