Work around empty occlusion queries on Apple GPUs

Fixed: dawn:1707
Change-Id: I73f2b595c57830266e7d34ca9e483d803a00d331
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/125180
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn/native/Toggles.cpp b/src/dawn/native/Toggles.cpp
index ca22e07..120be2d 100644
--- a/src/dawn/native/Toggles.cpp
+++ b/src/dawn/native/Toggles.cpp
@@ -358,6 +358,12 @@
       "subresource are completely initialized, and StoreOp::Discard is always translated as a "
       "Store.",
       "https://crbug.com/dawn/838", ToggleStage::Device}},
+    {Toggle::MetalFillEmptyOcclusionQueriesWithZero,
+     {"metal_fill_empty_occlusion_queries_with_zero",
+      "Apple GPUs leave stale results in the visibility result buffer instead of writing zero if "
+      "an occlusion query is empty. Workaround this by explicitly filling it with zero if there "
+      "are no draw calls.",
+      "https://crbug.com/dawn/1707", ToggleStage::Device}},
     {Toggle::UseBlitForBufferToDepthTextureCopy,
      {"use_blit_for_buffer_to_depth_texture_copy",
       "Use a blit instead of a copy command to copy buffer data to the depth aspect of a "
diff --git a/src/dawn/native/Toggles.h b/src/dawn/native/Toggles.h
index be459a7..c77a08d 100644
--- a/src/dawn/native/Toggles.h
+++ b/src/dawn/native/Toggles.h
@@ -87,6 +87,7 @@
     MetalUseCombinedDepthStencilFormatForStencil8,
     MetalUseBothDepthAndStencilAttachmentsForCombinedDepthStencilFormats,
     MetalKeepMultisubresourceDepthStencilTexturesInitialized,
+    MetalFillEmptyOcclusionQueriesWithZero,
     UseBlitForBufferToDepthTextureCopy,
     UseBlitForBufferToStencilTextureCopy,
     UseBlitForDepthTextureToTextureCopyToNonzeroSubresource,
diff --git a/src/dawn/native/metal/BackendMTL.mm b/src/dawn/native/metal/BackendMTL.mm
index 2e9e5de..bc4fd4a 100644
--- a/src/dawn/native/metal/BackendMTL.mm
+++ b/src/dawn/native/metal/BackendMTL.mm
@@ -404,6 +404,10 @@
                                    true);
         }
 
+        if (gpu_info::IsApple(vendorId)) {
+            deviceToggles->Default(Toggle::MetalFillEmptyOcclusionQueriesWithZero, true);
+        }
+
         // Local testing shows the workaround is needed on AMD Radeon HD 8870M (gcn-1) MacOS 12.1;
         // not on AMD Radeon Pro 555 (gcn-4) MacOS 13.1.
         // Conservatively enable the workaround on AMD unless the system is MacOS 13.1+
diff --git a/src/dawn/native/metal/CommandBufferMTL.h b/src/dawn/native/metal/CommandBufferMTL.h
index 76499fe..2125e2a 100644
--- a/src/dawn/native/metal/CommandBufferMTL.h
+++ b/src/dawn/native/metal/CommandBufferMTL.h
@@ -15,6 +15,9 @@
 #ifndef SRC_DAWN_NATIVE_METAL_COMMANDBUFFERMTL_H_
 #define SRC_DAWN_NATIVE_METAL_COMMANDBUFFERMTL_H_
 
+#include <set>
+#include <utility>
+
 #include "dawn/native/CommandBuffer.h"
 #include "dawn/native/Commands.h"
 #include "dawn/native/Error.h"
@@ -32,6 +35,7 @@
 class CommandRecordingContext;
 class Device;
 class Texture;
+class QuerySet;
 
 void RecordCopyBufferToTexture(CommandRecordingContext* commandContext,
                                id<MTLBuffer> mtlBuffer,
@@ -60,8 +64,15 @@
 
     MaybeError EncodeComputePass(CommandRecordingContext* commandContext,
                                  BeginComputePassCmd* computePassCmd);
+
+    // Empty occlusion queries aren't filled to zero on Apple GPUs. This set is used to
+    // track which results should be explicitly zeroed as a workaround. Use of empty queries
+    // *should* mostly be a degenerate scenario, so this std::set shouldn't be performance-critical.
+    // The set is passed as nullptr to `EncodeRenderPass` if the workaround is not in use.
+    using EmptyOcclusionQueries = std::set<std::pair<QuerySet*, uint32_t>>;
     MaybeError EncodeRenderPass(id<MTLRenderCommandEncoder> encoder,
-                                BeginRenderPassCmd* renderPassCmd);
+                                BeginRenderPassCmd* renderPassCmd,
+                                EmptyOcclusionQueries* emptyOcclusionQueries);
 };
 
 }  // namespace dawn::native::metal
diff --git a/src/dawn/native/metal/CommandBufferMTL.mm b/src/dawn/native/metal/CommandBufferMTL.mm
index 2812563..2ae9526 100644
--- a/src/dawn/native/metal/CommandBufferMTL.mm
+++ b/src/dawn/native/metal/CommandBufferMTL.mm
@@ -812,12 +812,25 @@
                 Device* device = ToBackend(GetDevice());
                 NSRef<MTLRenderPassDescriptor> descriptor = CreateMTLRenderPassDescriptor(
                     device, cmd, device->UseCounterSamplingAtStageBoundary());
+
+                EmptyOcclusionQueries emptyOcclusionQueries;
                 DAWN_TRY(EncodeMetalRenderPass(
                     device, commandContext, descriptor.Get(), cmd->width, cmd->height,
-                    [this](id<MTLRenderCommandEncoder> encoder, BeginRenderPassCmd* cmd)
-                        -> MaybeError { return this->EncodeRenderPass(encoder, cmd); },
+                    [&](id<MTLRenderCommandEncoder> encoder,
+                        BeginRenderPassCmd* cmd) -> MaybeError {
+                        return this->EncodeRenderPass(
+                            encoder, cmd,
+                            device->IsToggleEnabled(Toggle::MetalFillEmptyOcclusionQueriesWithZero)
+                                ? &emptyOcclusionQueries
+                                : nullptr);
+                    },
                     cmd));
-
+                for (const auto& [querySet, queryIndex] : emptyOcclusionQueries) {
+                    [commandContext->EnsureBlit()
+                        fillBuffer:querySet->GetVisibilityBuffer()
+                             range:NSMakeRange(queryIndex * sizeof(uint64_t), sizeof(uint64_t))
+                             value:0u];
+                }
                 nextRenderPassNumber++;
                 break;
             }
@@ -1347,7 +1360,8 @@
 }
 
 MaybeError CommandBuffer::EncodeRenderPass(id<MTLRenderCommandEncoder> encoder,
-                                           BeginRenderPassCmd* renderPassCmd) {
+                                           BeginRenderPassCmd* renderPassCmd,
+                                           EmptyOcclusionQueries* emptyOcclusionQueries) {
     bool enableVertexPulling = GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling);
     RenderPipeline* lastPipeline = nullptr;
     id<MTLBuffer> indexBuffer = nullptr;
@@ -1355,6 +1369,8 @@
     MTLIndexType indexBufferType;
     uint64_t indexFormatSize = 0;
 
+    bool didDrawInCurrentOcclusionQuery = false;
+
     StorageBufferLengthTracker storageBufferLengths = {};
     VertexBufferTracker vertexBuffers(&storageBufferLengths);
     BindGroupTracker bindGroups(&storageBufferLengths);
@@ -1390,12 +1406,14 @@
                                     vertexStart:draw->firstVertex
                                     vertexCount:draw->vertexCount
                                   instanceCount:draw->instanceCount];
+                        didDrawInCurrentOcclusionQuery = true;
                     } else {
                         [encoder drawPrimitives:lastPipeline->GetMTLPrimitiveTopology()
                                     vertexStart:draw->firstVertex
                                     vertexCount:draw->vertexCount
                                   instanceCount:draw->instanceCount
                                    baseInstance:draw->firstInstance];
+                        didDrawInCurrentOcclusionQuery = true;
                     }
                 }
                 break;
@@ -1420,6 +1438,7 @@
                                      indexBufferOffset:indexBufferBaseOffset +
                                                        draw->firstIndex * indexFormatSize
                                          instanceCount:draw->instanceCount];
+                        didDrawInCurrentOcclusionQuery = true;
                     } else {
                         [encoder drawIndexedPrimitives:lastPipeline->GetMTLPrimitiveTopology()
                                             indexCount:draw->indexCount
@@ -1430,6 +1449,7 @@
                                          instanceCount:draw->instanceCount
                                             baseVertex:draw->baseVertex
                                           baseInstance:draw->firstInstance];
+                        didDrawInCurrentOcclusionQuery = true;
                     }
                 }
                 break;
@@ -1448,6 +1468,7 @@
                 [encoder drawPrimitives:lastPipeline->GetMTLPrimitiveTopology()
                           indirectBuffer:indirectBuffer
                     indirectBufferOffset:draw->indirectOffset];
+                didDrawInCurrentOcclusionQuery = true;
                 break;
             }
 
@@ -1469,6 +1490,7 @@
                              indexBufferOffset:indexBufferBaseOffset
                                 indirectBuffer:indirectBuffer
                           indirectBufferOffset:draw->indirectOffset];
+                didDrawInCurrentOcclusionQuery = true;
                 break;
             }
 
@@ -1649,6 +1671,7 @@
 
                 [encoder setVisibilityResultMode:MTLVisibilityResultModeBoolean
                                           offset:cmd->queryIndex * sizeof(uint64_t)];
+                didDrawInCurrentOcclusionQuery = false;
                 break;
             }
 
@@ -1657,6 +1680,20 @@
 
                 [encoder setVisibilityResultMode:MTLVisibilityResultModeDisabled
                                           offset:cmd->queryIndex * sizeof(uint64_t)];
+                if (emptyOcclusionQueries) {
+                    // Empty occlusion queries aren't filled to zero on Apple GPUs.
+                    // Keep track of them so we can clear them if necessary.
+                    auto key = std::make_pair(ToBackend(renderPassCmd->occlusionQuerySet.Get()),
+                                              cmd->queryIndex);
+                    if (!didDrawInCurrentOcclusionQuery) {
+                        emptyOcclusionQueries->insert(std::move(key));
+                    } else {
+                        auto it = emptyOcclusionQueries->find(std::move(key));
+                        if (it != emptyOcclusionQueries->end()) {
+                            emptyOcclusionQueries->erase(it);
+                        }
+                    }
+                }
                 break;
             }
 
diff --git a/src/dawn/tests/end2end/QueryTests.cpp b/src/dawn/tests/end2end/QueryTests.cpp
index c5caf88..0635c75 100644
--- a/src/dawn/tests/end2end/QueryTests.cpp
+++ b/src/dawn/tests/end2end/QueryTests.cpp
@@ -366,8 +366,6 @@
 
 // Test setting an occlusion query to non-zero, then rewriting it without drawing, resolves to 0.
 TEST_P(OcclusionQueryTests, RewriteNoDrawToZero) {
-    // TODO(crbug.com/dawn/1707): The second query does not reset it to 0.
-    DAWN_SUPPRESS_TEST_IF(IsMacOS() && IsMetal() && IsApple());
     constexpr uint32_t kQueryCount = 1;
 
     wgpu::QuerySet querySet = CreateOcclusionQuerySet(kQueryCount);
@@ -406,8 +404,6 @@
 // Test setting an occlusion query to non-zero, then rewriting it without drawing, resolves to 0.
 // Do the two queries+resolves in separate submits.
 TEST_P(OcclusionQueryTests, RewriteNoDrawToZeroSeparateSubmit) {
-    // TODO(crbug.com/dawn/1707): The second query does not reset it to 0.
-    DAWN_SUPPRESS_TEST_IF(IsMacOS() && IsMetal() && IsApple());
     constexpr uint32_t kQueryCount = 1;
 
     wgpu::QuerySet querySet = CreateOcclusionQuerySet(kQueryCount);
@@ -445,6 +441,66 @@
                   new OcclusionExpectation(OcclusionExpectation::Result::Zero));
 }
 
+// Test that rewriting an occlusion query resolves to zero when the rewriting pass does a draw in
+// which all primitives fail the depth test (as opposed to an empty query with no draw at all).
+TEST_P(OcclusionQueryTests, RewriteToZeroWithDraw) {
+    constexpr uint32_t kQueryCount = 1;
+
+    utils::ComboRenderPipelineDescriptor descriptor;
+    descriptor.vertex.module = vsModule;
+    descriptor.cFragment.module = fsModule;
+
+    // Enable the depth and stencil tests with all compare functions set to Never (always fail).
+    wgpu::DepthStencilState* depthStencil = descriptor.EnableDepthStencil(kDepthStencilFormat);
+    depthStencil->depthCompare = wgpu::CompareFunction::Never;
+    depthStencil->stencilFront.compare = wgpu::CompareFunction::Never;
+    depthStencil->stencilBack.compare = wgpu::CompareFunction::Never;
+
+    wgpu::RenderPipeline renderPipeline = device.CreateRenderPipeline(&descriptor);
+
+    wgpu::Texture renderTarget = CreateRenderTexture(kColorFormat);
+    wgpu::TextureView renderTargetView = renderTarget.CreateView();
+
+    wgpu::Texture depthTexture = CreateRenderTexture(kDepthStencilFormat);
+    wgpu::TextureView depthTextureView = depthTexture.CreateView();
+
+    wgpu::QuerySet querySet = CreateOcclusionQuerySet(kQueryCount);
+    wgpu::Buffer destination = CreateResolveBuffer(kQueryCount * sizeof(uint64_t));
+    // Fill the destination with the sentinel value so the test can tell that a 0 was actually
+    // written by the resolve when no sample passes the occlusion test.
+    queue.WriteBuffer(destination, 0, &kSentinelValue, sizeof(kSentinelValue));
+
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    {  // First pass: draw inside query 0 using the `pipeline` declared outside this test.
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+        renderPass.renderPassInfo.occlusionQuerySet = querySet;
+
+        wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+        pass.SetPipeline(pipeline);
+        pass.BeginOcclusionQuery(0);
+        pass.Draw(3);
+        pass.EndOcclusionQuery();
+        pass.End();
+    }
+    {  // Second pass: rewrite query 0 with a draw that uses the never-passing pipeline.
+        utils::ComboRenderPassDescriptor renderPass({renderTargetView}, depthTextureView);
+        renderPass.occlusionQuerySet = querySet;
+
+        wgpu::RenderPassEncoder rewritePass = encoder.BeginRenderPass(&renderPass);
+        rewritePass.SetPipeline(renderPipeline);
+        rewritePass.BeginOcclusionQuery(0);
+        rewritePass.Draw(3);
+        rewritePass.EndOcclusionQuery();
+        rewritePass.End();
+    }
+    encoder.ResolveQuerySet(querySet, 0, kQueryCount, destination, 0);
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_BUFFER(destination, 0, sizeof(uint64_t),
+                  new OcclusionExpectation(OcclusionExpectation::Result::Zero));
+}
+
 // Test resolving occlusion query to the destination buffer with offset
 TEST_P(OcclusionQueryTests, ResolveToBufferWithOffset) {
     constexpr uint32_t kQueryCount = 2;
@@ -1227,7 +1283,11 @@
     }
 }
 
-DAWN_INSTANTIATE_TEST(OcclusionQueryTests, D3D12Backend(), MetalBackend(), VulkanBackend());
+DAWN_INSTANTIATE_TEST(OcclusionQueryTests,
+                      D3D12Backend(),
+                      MetalBackend(),
+                      MetalBackend({"metal_fill_empty_occlusion_queries_with_zero"}),
+                      VulkanBackend());
 DAWN_INSTANTIATE_TEST(PipelineStatisticsQueryTests,
                       D3D12Backend(),
                       MetalBackend(),