Implement RenderPassDescriptor.maxDrawCount

Bug: dawn:1465
Change-Id: I6b0aab25ec7a48a6521e4e7f52e42d6890ae2013
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94821
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Takahiro <hogehoge@gachapin.jp>
diff --git a/dawn.json b/dawn.json
index 9c3f328..ec3cb28 100644
--- a/dawn.json
+++ b/dawn.json
@@ -1886,6 +1886,13 @@
             {"name": "timestamp writes", "type": "render pass timestamp write", "annotation": "const*", "length": "timestamp write count"}
         ]
     },
+    "render pass descriptor max draw count": {
+        "category": "structure",
+        "chained": "in",
+        "members": [
+            {"name": "max draw count", "type": "uint64_t", "default": 50000000}
+        ]
+    },
     "render pass encoder": {
         "category": "object",
         "methods": [
@@ -2482,6 +2489,7 @@
             {"value": 12, "name": "external texture binding entry", "tags": ["dawn"]},
             {"value": 13, "name": "external texture binding layout", "tags": ["dawn"]},
             {"value": 14, "name": "surface descriptor from windows swap chain panel", "tags": ["dawn"]},
+            {"value": 15, "name": "render pass descriptor max draw count"},
             {"value": 1000, "name": "dawn texture internal usage descriptor", "tags": ["dawn"]},
             {"value": 1001, "name": "primitive depth clamping state", "tags": ["dawn", "emscripten"]},
             {"value": 1002, "name": "dawn toggles device descriptor", "tags": ["dawn", "native"]},
diff --git a/src/dawn/native/CommandEncoder.cpp b/src/dawn/native/CommandEncoder.cpp
index b8d549a..6f32d92 100644
--- a/src/dawn/native/CommandEncoder.cpp
+++ b/src/dawn/native/CommandEncoder.cpp
@@ -444,6 +444,9 @@
                                         uint32_t* height,
                                         uint32_t* sampleCount,
                                         UsageValidationMode usageValidationMode) {
+    DAWN_TRY(ValidateSingleSType(descriptor->nextInChain,
+                                 wgpu::SType::RenderPassDescriptorMaxDrawCount));
+
     uint32_t maxColorAttachments = device->GetLimits().v1.maxColorAttachments;
     DAWN_INVALID_IF(
         descriptor->colorAttachmentCount > maxColorAttachments,
diff --git a/src/dawn/native/RenderBundle.cpp b/src/dawn/native/RenderBundle.cpp
index 2781983d..d2e3d69 100644
--- a/src/dawn/native/RenderBundle.cpp
+++ b/src/dawn/native/RenderBundle.cpp
@@ -37,6 +37,7 @@
       mAttachmentState(std::move(attachmentState)),
       mDepthReadOnly(depthReadOnly),
       mStencilReadOnly(stencilReadOnly),
+      mDrawCount(encoder->GetDrawCount()),
       mResourceUsage(std::move(resourceUsage)) {
     TrackInDevice();
 }
@@ -80,6 +81,11 @@
     return mStencilReadOnly;
 }
 
+uint64_t RenderBundleBase::GetDrawCount() const {
+    ASSERT(!IsError());
+    return mDrawCount;
+}
+
 const RenderPassResourceUsage& RenderBundleBase::GetResourceUsage() const {
     ASSERT(!IsError());
     return mResourceUsage;
diff --git a/src/dawn/native/RenderBundle.h b/src/dawn/native/RenderBundle.h
index 9297e01..f86eb5f 100644
--- a/src/dawn/native/RenderBundle.h
+++ b/src/dawn/native/RenderBundle.h
@@ -52,6 +52,7 @@
     const AttachmentState* GetAttachmentState() const;
     bool IsDepthReadOnly() const;
     bool IsStencilReadOnly() const;
+    uint64_t GetDrawCount() const;
     const RenderPassResourceUsage& GetResourceUsage() const;
     const IndirectDrawMetadata& GetIndirectDrawMetadata();
 
@@ -65,6 +66,7 @@
     Ref<AttachmentState> mAttachmentState;
     bool mDepthReadOnly;
     bool mStencilReadOnly;
+    uint64_t mDrawCount;
     RenderPassResourceUsage mResourceUsage;
 };
 
diff --git a/src/dawn/native/RenderEncoderBase.cpp b/src/dawn/native/RenderEncoderBase.cpp
index 242e47b..36aa9e1 100644
--- a/src/dawn/native/RenderEncoderBase.cpp
+++ b/src/dawn/native/RenderEncoderBase.cpp
@@ -75,6 +75,11 @@
     return mStencilReadOnly;
 }
 
+uint64_t RenderEncoderBase::GetDrawCount() const {
+    ASSERT(!IsError());
+    return mDrawCount;
+}
+
 Ref<AttachmentState> RenderEncoderBase::AcquireAttachmentState() {
     return std::move(mAttachmentState);
 }
@@ -104,6 +109,8 @@
             draw->firstVertex = firstVertex;
             draw->firstInstance = firstInstance;
 
+            mDrawCount++;
+
             return {};
         },
         "encoding %s.Draw(%u, %u, %u, %u).", this, vertexCount, instanceCount, firstVertex,
@@ -141,6 +148,8 @@
             draw->baseVertex = baseVertex;
             draw->firstInstance = firstInstance;
 
+            mDrawCount++;
+
             return {};
         },
         "encoding %s.DrawIndexed(%u, %u, %u, %i, %u).", this, indexCount, instanceCount, firstIndex,
@@ -191,6 +200,8 @@
             // backend.
             mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
 
+            mDrawCount++;
+
             return {};
         },
         "encoding %s.DrawIndirect(%s, %u).", this, indirectBuffer, indirectOffset);
@@ -243,6 +254,8 @@
             // backend.
             mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
 
+            mDrawCount++;
+
             return {};
         },
         "encoding %s.DrawIndexedIndirect(%s, %u).", this, indirectBuffer, indirectOffset);
diff --git a/src/dawn/native/RenderEncoderBase.h b/src/dawn/native/RenderEncoderBase.h
index 0bdcc4d..0cb675a 100644
--- a/src/dawn/native/RenderEncoderBase.h
+++ b/src/dawn/native/RenderEncoderBase.h
@@ -62,6 +62,7 @@
     const AttachmentState* GetAttachmentState() const;
     bool IsDepthReadOnly() const;
     bool IsStencilReadOnly() const;
+    uint64_t GetDrawCount() const;
     Ref<AttachmentState> AcquireAttachmentState();
 
   protected:
@@ -74,6 +75,8 @@
     RenderPassResourceUsageTracker mUsageTracker;
     IndirectDrawMetadata mIndirectDrawMetadata;
 
+    uint64_t mDrawCount = 0;
+
   private:
     Ref<AttachmentState> mAttachmentState;
     const bool mDisableBaseVertex;
diff --git a/src/dawn/native/RenderPassEncoder.cpp b/src/dawn/native/RenderPassEncoder.cpp
index 716ce97..9994cf5 100644
--- a/src/dawn/native/RenderPassEncoder.cpp
+++ b/src/dawn/native/RenderPassEncoder.cpp
@@ -20,6 +20,7 @@
 
 #include "dawn/common/Constants.h"
 #include "dawn/native/Buffer.h"
+#include "dawn/native/ChainUtils_autogen.h"
 #include "dawn/native/CommandEncoder.h"
 #include "dawn/native/CommandValidation.h"
 #include "dawn/native/Commands.h"
@@ -72,6 +73,11 @@
       mOcclusionQuerySet(descriptor->occlusionQuerySet),
       mTimestampWritesAtEnd(std::move(timestampWritesAtEnd)) {
     mUsageTracker = std::move(usageTracker);
+    const RenderPassDescriptorMaxDrawCount* maxDrawCountInfo = nullptr;
+    FindInChain(descriptor->nextInChain, &maxDrawCountInfo);
+    if (maxDrawCountInfo) {
+        mMaxDrawCount = maxDrawCountInfo->maxDrawCount;
+    }
     TrackInDevice();
 }
 
@@ -140,6 +146,10 @@
                     mOcclusionQueryActive,
                     "Render pass %s ended with incomplete occlusion query index %u of %s.", this,
                     mCurrentOcclusionQueryIndex, mOcclusionQuerySet.Get());
+
+                DAWN_INVALID_IF(mDrawCount > mMaxDrawCount,
+                                "The drawCount (%u) of %s is greater than the maxDrawCount (%u).",
+                                mDrawCount, this, mMaxDrawCount);
             }
 
             EndRenderPassCmd* cmd = allocator->Allocate<EndRenderPassCmd>(Command::EndRenderPass);
@@ -320,6 +330,8 @@
                 if (IsValidationEnabled()) {
                     mIndirectDrawMetadata.AddBundle(renderBundles[i]);
                 }
+
+                mDrawCount += bundles[i]->GetDrawCount();
             }
 
             return {};
diff --git a/src/dawn/native/RenderPassEncoder.h b/src/dawn/native/RenderPassEncoder.h
index ad4c130..45714d0 100644
--- a/src/dawn/native/RenderPassEncoder.h
+++ b/src/dawn/native/RenderPassEncoder.h
@@ -97,6 +97,9 @@
     uint32_t mCurrentOcclusionQueryIndex = 0;
     bool mOcclusionQueryActive = false;
 
+    // This is the hardcoded value in the WebGPU spec.
+    uint64_t mMaxDrawCount = 50000000;
+
     std::vector<TimestampWrite> mTimestampWritesAtEnd;
 };
 
diff --git a/src/dawn/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp b/src/dawn/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp
index 452a222..6a85a46 100644
--- a/src/dawn/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp
+++ b/src/dawn/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp
@@ -18,6 +18,8 @@
 
 #include "dawn/common/Constants.h"
 
+#include "dawn/utils/ComboRenderBundleEncoderDescriptor.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
 namespace {
@@ -544,6 +546,211 @@
     AssertBeginRenderPassError(&renderPass);
 }
 
+// drawCount must not exceed maxDrawCount
+TEST_F(RenderPassDescriptorValidationTest, MaxDrawCount) {
+    constexpr wgpu::TextureFormat kColorFormat = wgpu::TextureFormat::RGBA8Unorm;
+    constexpr uint64_t kMaxDrawCount = 16;
+
+    wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+        @vertex fn main() -> @builtin(position) vec4<f32> {
+            return vec4<f32>(0.0, 0.0, 0.0, 1.0);
+        })");
+
+    wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
+        @fragment fn main() -> @location(0) vec4<f32> {
+            return vec4<f32>(0.0, 1.0, 0.0, 1.0);
+        })");
+
+    utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+    pipelineDescriptor.vertex.module = vsModule;
+    pipelineDescriptor.cFragment.module = fsModule;
+    wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&pipelineDescriptor);
+
+    wgpu::TextureDescriptor colorTextureDescriptor;
+    colorTextureDescriptor.size = {1, 1};
+    colorTextureDescriptor.format = kColorFormat;
+    colorTextureDescriptor.usage = wgpu::TextureUsage::RenderAttachment;
+    wgpu::Texture colorTexture = device.CreateTexture(&colorTextureDescriptor);
+
+    utils::ComboRenderBundleEncoderDescriptor bundleEncoderDescriptor;
+    bundleEncoderDescriptor.colorFormatsCount = 1;
+    bundleEncoderDescriptor.cColorFormats[0] = kColorFormat;
+
+    wgpu::Buffer indexBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Index, {0, 1, 2});
+    wgpu::Buffer indirectBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {3, 1, 0, 0});
+    wgpu::Buffer indexedIndirectBuffer =
+        utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {3, 1, 0, 0, 0});
+
+    wgpu::RenderPassDescriptorMaxDrawCount maxDrawCount;
+    maxDrawCount.maxDrawCount = kMaxDrawCount;
+
+    // Valid. drawCount is less than the default maxDrawCount.
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.Draw(3);
+        }
+
+        renderPass.End();
+        encoder.Finish();
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+        renderPass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndexed(3);
+        }
+
+        renderPass.End();
+        encoder.Finish();
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndirect(indirectBuffer, 0);
+        }
+
+        renderPass.End();
+        encoder.Finish();
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+        renderPass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndexedIndirect(indexedIndirectBuffer, 0);
+        }
+
+        renderPass.End();
+        encoder.Finish();
+    }
+
+    {
+        wgpu::RenderBundleEncoder renderBundleEncoder =
+            device.CreateRenderBundleEncoder(&bundleEncoderDescriptor);
+        renderBundleEncoder.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderBundleEncoder.Draw(3);
+        }
+
+        wgpu::RenderBundle renderBundle = renderBundleEncoder.Finish();
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.ExecuteBundles(1, &renderBundle);
+        renderPass.End();
+        encoder.Finish();
+    }
+
+    // Invalid. drawCount counts up with draw calls and
+    // it is greater than maxDrawCount.
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        renderPassDescriptor.nextInChain = &maxDrawCount;
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.Draw(3);
+        }
+
+        renderPass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        renderPassDescriptor.nextInChain = &maxDrawCount;
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+        renderPass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndexed(3);
+        }
+
+        renderPass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        renderPassDescriptor.nextInChain = &maxDrawCount;
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndirect(indirectBuffer, 0);
+        }
+
+        renderPass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        renderPassDescriptor.nextInChain = &maxDrawCount;
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.SetPipeline(pipeline);
+        renderPass.SetIndexBuffer(indexBuffer, wgpu::IndexFormat::Uint32);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderPass.DrawIndexedIndirect(indexedIndirectBuffer, 0);
+        }
+
+        renderPass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+
+    {
+        wgpu::RenderBundleEncoder renderBundleEncoder =
+            device.CreateRenderBundleEncoder(&bundleEncoderDescriptor);
+        renderBundleEncoder.SetPipeline(pipeline);
+
+        for (uint64_t i = 0; i <= kMaxDrawCount; i++) {
+            renderBundleEncoder.Draw(3);
+        }
+
+        wgpu::RenderBundle renderBundle = renderBundleEncoder.Finish();
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        utils::ComboRenderPassDescriptor renderPassDescriptor({colorTexture.CreateView()});
+        renderPassDescriptor.nextInChain = &maxDrawCount;
+        wgpu::RenderPassEncoder renderPass = encoder.BeginRenderPass(&renderPassDescriptor);
+        renderPass.ExecuteBundles(1, &renderBundle);
+        renderPass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+}
+
 class MultisampledRenderPassDescriptorValidationTest : public RenderPassDescriptorValidationTest {
   public:
     utils::ComboRenderPassDescriptor CreateMultisampledRenderPass() {