[dawn][metal] ArgBufs: call useResource for all resources in the usage scope

This is required for Metal to manage both residency and hazards for us,
since it can't see what resources are used in the argument buffers.
It's similar to other backends, but a bit simpler because we're still
letting Metal do automatic hazard tracking for us.

Metal Shader Validation catches this issue, so this fix was verified using
both Animometer and ComputeBoids, with:
  MTL_SHADER_VALIDATION=1 MTL_SHADER_VALIDATION_REPORT_TO_STDERR=1
and:
  --enable-toggles=metal_use_argument_buffers --enable-backend-validation=full

Fixed: 477317116
Change-Id: I9cab58582ff08d52966ac06ba86df33399d4aa8c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/286136
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn/native/metal/BackendMTL.mm b/src/dawn/native/metal/BackendMTL.mm
index 78ea734..73fd94a 100644
--- a/src/dawn/native/metal/BackendMTL.mm
+++ b/src/dawn/native/metal/BackendMTL.mm
@@ -46,7 +46,7 @@
         return true;
     }
 
-    // Sometime validation layer can be enabled eternally via xcode or command line.
+    // Validation layer can also be enabled externally via Xcode or command line.
     if (GetEnvironmentVar("METAL_DEVICE_WRAPPER_TYPE").first == "1" ||
         GetEnvironmentVar("MTL_DEBUG_LAYER").first == "1") {
         return true;
diff --git a/src/dawn/native/metal/BindGroupLayoutMTL.mm b/src/dawn/native/metal/BindGroupLayoutMTL.mm
index 21b5905..c6d5814 100644
--- a/src/dawn/native/metal/BindGroupLayoutMTL.mm
+++ b/src/dawn/native/metal/BindGroupLayoutMTL.mm
@@ -75,7 +75,7 @@
                 DAWN_CHECK(false);
             },
             [&](const TextureBindingInfo&) { desc.dataType = MTLDataTypeTexture; },
-            [&](const StorageTextureBindingInfo&) { DAWN_CHECK(false); },
+            [&](const StorageTextureBindingInfo&) { desc.dataType = MTLDataTypeTexture; },
             [&](const TexelBufferBindingInfo&) { DAWN_CHECK(false); },
             [](const InputAttachmentBindingInfo&) { DAWN_CHECK(false); },
             [](const ExternalTextureBindingInfo&) { DAWN_CHECK(false); });
diff --git a/src/dawn/native/metal/CommandBufferMTL.h b/src/dawn/native/metal/CommandBufferMTL.h
index acbc0f1..a4a6878 100644
--- a/src/dawn/native/metal/CommandBufferMTL.h
+++ b/src/dawn/native/metal/CommandBufferMTL.h
@@ -78,7 +78,8 @@
     using CommandBufferBase::CommandBufferBase;
 
     MaybeError EncodeComputePass(CommandRecordingContext* commandContext,
-                                 BeginComputePassCmd* computePassCmd);
+                                 BeginComputePassCmd* computePassCmd,
+                                 const ComputePassResourceUsage& resourceUsage);
 
     // Empty occlusion queries aren't filled to zero on Apple GPUs. This set is used to
     // track which results should be explicitly zero'ed as a workaround. Use of empty queries
diff --git a/src/dawn/native/metal/CommandBufferMTL.mm b/src/dawn/native/metal/CommandBufferMTL.mm
index 4eb5c5e..83437c3 100644
--- a/src/dawn/native/metal/CommandBufferMTL.mm
+++ b/src/dawn/native/metal/CommandBufferMTL.mm
@@ -28,6 +28,7 @@
 #include "dawn/native/metal/CommandBufferMTL.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "dawn/common/Assert.h"
 #include "dawn/common/MatchVariant.h"
 #include "dawn/common/Range.h"
 #include "dawn/native/BindGroupTracker.h"
@@ -36,6 +37,7 @@
 #include "dawn/native/DynamicUploader.h"
 #include "dawn/native/ExternalTexture.h"
 #include "dawn/native/ImmediateConstantsTracker.h"
+#include "dawn/native/PassResourceUsage.h"
 #include "dawn/native/Queue.h"
 #include "dawn/native/RenderBundle.h"
 #include "dawn/native/metal/BindGroupLayoutMTL.h"
@@ -732,9 +734,7 @@
             auto HandleTextureBinding = [&]() {
                 auto textureView = ToBackend(group->GetBindingAsTextureView(bindingIndex));
                 id<MTLTexture> texture = textureView->GetMTLTexture();
-                if (mUseArgumentBuffers) {
-                    // TODO(crbug.com/477317116): Need to make texture resident.
-                } else {
+                if (!mUseArgumentBuffers) {
                     if (hasVertStage &&
                         mBoundTextures[SingleShaderStage::Vertex][vertIndex] != texture) {
                         mBoundTextures[SingleShaderStage::Vertex][vertIndex] = texture;
@@ -774,9 +774,7 @@
                     }
 
                     const id<MTLBuffer> buffer = ToBackend(binding.buffer)->GetMTLBuffer();
-                    if (mUseArgumentBuffers) {
-                        // TODO(crbug.com/477317116): Need to make buffer resident.
-                    } else {
+                    if (!mUseArgumentBuffers) {
                         NSUInteger offset = binding.offset;
 
                         // TODO(crbug.com/dawn/854): Record bound buffer status to use
@@ -809,9 +807,7 @@
                 [&](const SamplerBindingInfo&) {
                     auto sampler = ToBackend(group->GetBindingAsSampler(bindingIndex));
                     id<MTLSamplerState> samplerState = sampler->GetMTLSamplerState();
-                    if (mUseArgumentBuffers) {
-                        // (Note useResource is not needed; MTLSamplerState is not a MTLResource.)
-                    } else {
+                    if (!mUseArgumentBuffers) {
                         if (hasVertStage &&
                             mBoundSamplers[SingleShaderStage::Vertex][vertIndex] != samplerState) {
                             mBoundSamplers[SingleShaderStage::Vertex][vertIndex] = samplerState;
@@ -839,7 +835,7 @@
                     DAWN_CHECK(false);
                 },
                 [&](const TextureBindingInfo&) { HandleTextureBinding(); },
-                [&](const StorageTextureBindingInfo&) { HandleTextureBinding(); },
+                [&](const StorageTextureBindingInfo& info) { HandleTextureBinding(); },
                 [&](const TexelBufferBindingInfo&) {
                     // Metal does not support texel buffers.
                     // TODO(crbug/382544164): Prototype texel buffer feature
@@ -1096,17 +1092,17 @@
             case Command::BeginComputePass: {
                 BeginComputePassCmd* cmd = mCommands.NextCommand<BeginComputePassCmd>();
 
-                for (TextureBase* texture :
-                     GetResourceUsages().computePasses[nextComputePassNumber].referencedTextures) {
+                const ComputePassResourceUsage& resourceUsage =
+                    GetResourceUsages().computePasses[nextComputePassNumber];
+                for (TextureBase* texture : resourceUsage.referencedTextures) {
                     ToBackend(texture)->SynchronizeTextureBeforeUse(commandContext);
                 }
-                for (const SyncScopeResourceUsage& scope :
-                     GetResourceUsages().computePasses[nextComputePassNumber].dispatchUsages) {
+                for (const SyncScopeResourceUsage& scope : resourceUsage.dispatchUsages) {
                     DAWN_TRY(LazyClearSyncScope(scope, commandContext));
                 }
                 commandContext->EndBlit();
 
-                DAWN_TRY(EncodeComputePass(commandContext, cmd));
+                DAWN_TRY(EncodeComputePass(commandContext, cmd, resourceUsage));
 
                 nextComputePassNumber++;
                 break;
@@ -1115,13 +1111,12 @@
             case Command::BeginRenderPass: {
                 BeginRenderPassCmd* cmd = mCommands.NextCommand<BeginRenderPassCmd>();
 
-                for (TextureBase* texture :
-                     this->GetResourceUsages().renderPasses[nextRenderPassNumber].textures) {
+                const RenderPassResourceUsage& resourceUsage =
+                    GetResourceUsages().renderPasses[nextRenderPassNumber];
+                for (TextureBase* texture : resourceUsage.textures) {
                     ToBackend(texture)->SynchronizeTextureBeforeUse(commandContext);
                 }
-                for (ExternalTextureBase* externalTexture : this->GetResourceUsages()
-                                                                .renderPasses[nextRenderPassNumber]
-                                                                .externalTextures) {
+                for (ExternalTextureBase* externalTexture : resourceUsage.externalTextures) {
                     for (auto& view : externalTexture->GetTextureViews()) {
                         if (view.Get()) {
                             Texture* texture = ToBackend(view->GetTexture());
@@ -1129,8 +1124,7 @@
                         }
                     }
                 }
-                DAWN_TRY(LazyClearSyncScope(GetResourceUsages().renderPasses[nextRenderPassNumber],
-                                            commandContext));
+                DAWN_TRY(LazyClearSyncScope(resourceUsage, commandContext));
                 commandContext->EndBlit();
 
                 // Before beginning, we encode a compute pass that converts multi draws into an ICB
@@ -1166,7 +1160,8 @@
 
                 EmptyOcclusionQueries emptyOcclusionQueries;
                 DAWN_TRY(EncodeMetalRenderPass(
-                    device, commandContext, descriptor.Get(), cmd->width, cmd->height,
+                    device, commandContext, &resourceUsage, descriptor.Get(), cmd->width,
+                    cmd->height,
                     [&](id<MTLRenderCommandEncoder> encoder,
                         BeginRenderPassCmd* cmd) -> MaybeError {
                         return this->EncodeRenderPass(
@@ -1545,7 +1540,9 @@
 }
 
 MaybeError CommandBuffer::EncodeComputePass(CommandRecordingContext* commandContext,
-                                            BeginComputePassCmd* computePassCmd) {
+                                            BeginComputePassCmd* computePassCmd,
+                                            const ComputePassResourceUsage& resourceUsage) {
+    uint64_t currentDispatch = 0;
     ComputePipeline* lastPipeline = nullptr;
     StorageBufferLengthTracker storageBufferLengths = {};
     BindGroupTracker bindGroups(&storageBufferLengths,
@@ -1616,9 +1613,12 @@
                 bindGroups.Apply(encoder);
                 storageBufferLengths.Apply(lastPipeline);
                 immediates.Apply(encoder, &storageBufferLengths);
+                MetalComputePassMakeResourcesResident(
+                    GetDevice(), encoder, resourceUsage.dispatchUsages[currentDispatch]);
 
                 [encoder dispatchThreadgroups:MTLSizeMake(dispatch->x, dispatch->y, dispatch->z)
                         threadsPerThreadgroup:lastPipeline->GetLocalWorkGroupSize()];
+                currentDispatch++;
                 break;
             }
 
@@ -1628,6 +1628,8 @@
                 bindGroups.Apply(encoder);
                 storageBufferLengths.Apply(lastPipeline);
                 immediates.Apply(encoder, &storageBufferLengths);
+                MetalComputePassMakeResourcesResident(
+                    GetDevice(), encoder, resourceUsage.dispatchUsages[currentDispatch]);
 
                 Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
                 buffer->TrackUsage();
@@ -1636,6 +1638,7 @@
                     dispatchThreadgroupsWithIndirectBuffer:indirectBuffer
                                       indirectBufferOffset:dispatch->indirectOffset
                                      threadsPerThreadgroup:lastPipeline->GetLocalWorkGroupSize()];
+                currentDispatch++;
                 break;
             }
 
diff --git a/src/dawn/native/metal/UtilsMetal.h b/src/dawn/native/metal/UtilsMetal.h
index b5969ab..0315fee 100644
--- a/src/dawn/native/metal/UtilsMetal.h
+++ b/src/dawn/native/metal/UtilsMetal.h
@@ -32,6 +32,7 @@
 
 #include "absl/container/inlined_vector.h"
 #include "dawn/common/NSRef.h"
+#include "dawn/native/PassResourceUsage.h"
 #include "dawn/native/dawn_platform.h"
 #include "dawn/native/metal/DeviceMTL.h"
 #include "dawn/native/metal/ShaderModuleMTL.h"
@@ -124,12 +125,17 @@
     std::function<MaybeError(id<MTLRenderCommandEncoder>, BeginRenderPassCmd* renderPassCmd)>;
 MaybeError EncodeMetalRenderPass(Device* device,
                                  CommandRecordingContext* commandContext,
+                                 const RenderPassResourceUsage* resourceUsage,
                                  MTLRenderPassDescriptor* mtlRenderPass,
                                  uint32_t width,
                                  uint32_t height,
                                  EncodeInsideRenderPass encodeInside,
                                  BeginRenderPassCmd* renderPassCmd = nullptr);
 
+void MetalComputePassMakeResourcesResident(DeviceBase* device,
+                                           id<MTLComputeCommandEncoder> encoder,
+                                           const SyncScopeResourceUsage& resourceUsage);
+
 id<MTLTexture> CreateTextureMtlForPlane(MTLTextureUsage mtlUsage,
                                         const Format& format,
                                         size_t plane,
diff --git a/src/dawn/native/metal/UtilsMetal.mm b/src/dawn/native/metal/UtilsMetal.mm
index d63edf5..5a84fdb 100644
--- a/src/dawn/native/metal/UtilsMetal.mm
+++ b/src/dawn/native/metal/UtilsMetal.mm
@@ -26,15 +26,55 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "dawn/native/metal/UtilsMetal.h"
+#include <Metal/Metal.h>
 
 #include "dawn/common/Assert.h"
+#include "dawn/common/Math.h"
+#include "dawn/native/Buffer.h"
 #include "dawn/native/CommandBuffer.h"
+#include "dawn/native/EnumMaskIterator.h"
 #include "dawn/native/Pipeline.h"
 #include "dawn/native/ShaderModule.h"
+#include "dawn/native/dawn_platform.h"
+#include "dawn/native/metal/BufferMTL.h"
 
 namespace dawn::native::metal {
 
 namespace {
+
+MTLResourceUsage ToMTLResourceUsage(wgpu::BufferUsage usage) {
+    if (IsSubset(usage, kReadOnlyBufferUsages)) {
+        return MTLResourceUsageRead;
+    } else {
+        // Technically some of these usages could be write-only, but we can't tell from here.
+        // Also it might not be safe to tell Metal those are write-only.
+        return MTLResourceUsageRead | MTLResourceUsageWrite;
+    }
+}
+
+MTLResourceUsage ToMTLResourceUsage(wgpu::TextureUsage usage) {
+    if (IsSubset(usage, kReadOnlyTextureUsages)) {
+        return MTLResourceUsageRead;
+    } else {
+        // Technically some of these usages could be write-only, but we can't tell from here.
+        // Also it might not be safe to tell Metal those are write-only.
+        return MTLResourceUsageRead | MTLResourceUsageWrite;
+    }
+}
+
+MTLRenderStages ToMTLRenderStages(wgpu::ShaderStage visibility) {
+    // Note wgpu::ShaderStage::Compute is intentionally ignored here. It may be present in the
+    // visibility (which comes from the bind group layout) but it's not relevant here.
+    MTLRenderStages stages = 0;
+    if (visibility & wgpu::ShaderStage::Vertex) {
+        stages |= MTLRenderStageVertex;
+    }
+    if (visibility & wgpu::ShaderStage::Fragment) {
+        stages |= MTLRenderStageFragment;
+    }
+    return stages;
+}
+
 // A helper struct to track state while doing workarounds for Metal render passes. It
 // contains a temporary texture and information about the attachment it replaces.
 // Helper methods encode copies between the two textures.
@@ -154,6 +194,76 @@
     commandContext->EndRender();
 }
 
+// Overloads with matching signatures to make it simpler to call these in the templated function.
+void MakeResourceResident(id<MTLComputeCommandEncoder> encoder,
+                          id<MTLResource> resource,
+                          MTLResourceUsage usage,
+                          wgpu::ShaderStage stages) {
+    // This is a compute encoder. Skip any resources that can't possibly be visible to it.
+    if (stages & (wgpu::ShaderStage::Compute)) {
+        [encoder useResource:resource usage:usage];
+    }
+}
+void MakeResourceResident(id<MTLRenderCommandEncoder> encoder,
+                          id<MTLResource> resource,
+                          MTLResourceUsage usage,
+                          wgpu::ShaderStage stages) {
+    // This is a render encoder. Skip any resources that can't possibly be visible to it.
+    if (stages & (wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment)) {
+        [encoder useResource:resource usage:usage stages:ToMTLRenderStages(stages)];
+    }
+}
+
+// Templated over MTLComputeCommandEncoder/MTLRenderCommandEncoder.
+template <typename T>
+concept MTLEncoderType = std::is_same_v<T, id<MTLComputeCommandEncoder>> ||
+                         std::is_same_v<T, id<MTLRenderCommandEncoder>>;
+template <MTLEncoderType Encoder>
+void MakeResourcesResident(Encoder encoder, const SyncScopeResourceUsage& resourceUsage) {
+    for (size_t i = 0; i < resourceUsage.buffers.size(); ++i) {
+        id<MTLBuffer> buffer = ToBackend(resourceUsage.buffers[i])->GetMTLBuffer();
+        const auto& info = resourceUsage.bufferSyncInfos[i];
+
+        if (info.shaderStages == wgpu::ShaderStage::None) {
+            // This resource is not passed in an argument buffer, it's only used for something else
+            // (like an index buffer) that gets passed to Metal on the API side.
+            continue;
+        }
+
+        MakeResourceResident(encoder, buffer, ToMTLResourceUsage(info.usage), info.shaderStages);
+    }
+
+    for (size_t i = 0; i < resourceUsage.textures.size(); ++i) {
+        Texture* texture = ToBackend(resourceUsage.textures[i]);
+
+        // Collect all the aspects/usages/stages used for any subresource.
+        Aspect aspects{};
+        wgpu::TextureUsage usages{};
+        wgpu::ShaderStage stages{};
+        resourceUsage.textureSyncInfos[i].Iterate(
+            [&](const SubresourceRange& range, const TextureSyncInfo& syncInfo) {
+                aspects |= range.aspects;
+                usages |= syncInfo.usage;
+                stages |= syncInfo.shaderStages;
+            });
+
+        if (stages == wgpu::ShaderStage::None) {
+            // This resource is not passed in an argument buffer, it's only used for something else
+            // (like a render attachment) that gets passed to Metal on the API side.
+            continue;
+        }
+
+        // There are at most three planes. Call useResource for each plane that is used.
+        const Aspect kAspectsCorrespondingToPlane0{~(Aspect::Plane1 | Aspect::Plane2)};
+        for (Aspect plane : {kAspectsCorrespondingToPlane0, Aspect::Plane1, Aspect::Plane2}) {
+            if (aspects & plane) {
+                MakeResourceResident(encoder, texture->GetMTLTexture(plane),
+                                     ToMTLResourceUsage(usages), stages);
+            }
+        }
+    }
+}
+
 }  // anonymous namespace
 
 MTLPixelFormat MetalPixelFormat(const DeviceBase* device, wgpu::TextureFormat format) {
@@ -668,6 +778,7 @@
 
 MaybeError EncodeMetalRenderPass(Device* device,
                                  CommandRecordingContext* commandContext,
+                                 const RenderPassResourceUsage* resourceUsage,
                                  MTLRenderPassDescriptor* mtlRenderPass,
                                  uint32_t width,
                                  uint32_t height,
@@ -735,8 +846,8 @@
         }
 
         if (workaroundUsed) {
-            DAWN_TRY(EncodeMetalRenderPass(device, commandContext, mtlRenderPass, width, height,
-                                           std::move(encodeInside), renderPassCmd));
+            DAWN_TRY(EncodeMetalRenderPass(device, commandContext, nullptr, mtlRenderPass, width,
+                                           height, std::move(encodeInside), renderPassCmd));
 
             for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
                 if (originalAttachments[i].texture == nullptr) {
@@ -771,8 +882,8 @@
 
         // If we found a store + MSAA resolve we need to resolve in a different render pass.
         if (hasStoreAndMSAAResolve) {
-            DAWN_TRY(EncodeMetalRenderPass(device, commandContext, mtlRenderPass, width, height,
-                                           std::move(encodeInside), renderPassCmd));
+            DAWN_TRY(EncodeMetalRenderPass(device, commandContext, nullptr, mtlRenderPass, width,
+                                           height, std::move(encodeInside), renderPassCmd));
 
             ResolveInAnotherRenderPass(commandContext, mtlRenderPass, resolveTextures);
             return {};
@@ -781,17 +892,29 @@
 
     // No (more) workarounds needed! We can finally encode the actual render pass.
     commandContext->EndBlit();
-    DAWN_TRY(encodeInside(commandContext->BeginRender(mtlRenderPass), renderPassCmd));
+    auto renderCommandEncoder = commandContext->BeginRender(mtlRenderPass);
+    if (resourceUsage != nullptr && device->IsToggleEnabled(Toggle::MetalUseArgumentBuffers)) {
+        MakeResourcesResident(renderCommandEncoder, *resourceUsage);
+    }
+    DAWN_TRY(encodeInside(renderCommandEncoder, renderPassCmd));
     commandContext->EndRender();
     return {};
 }
 
+void MetalComputePassMakeResourcesResident(DeviceBase* device,
+                                           id<MTLComputeCommandEncoder> encoder,
+                                           const SyncScopeResourceUsage& resourceUsage) {
+    if (device->IsToggleEnabled(Toggle::MetalUseArgumentBuffers)) {
+        MakeResourcesResident(encoder, resourceUsage);
+    }
+}
+
 MaybeError EncodeEmptyMetalRenderPass(Device* device,
                                       CommandRecordingContext* commandContext,
                                       MTLRenderPassDescriptor* mtlRenderPass,
                                       Extent3D size) {
     return EncodeMetalRenderPass(
-        device, commandContext, mtlRenderPass, size.width, size.height,
+        device, commandContext, nullptr, mtlRenderPass, size.width, size.height,
         [&](id<MTLRenderCommandEncoder>, BeginRenderPassCmd*) -> MaybeError { return {}; });
 }