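Reuse one uniform buffer for ApplyClearWithDraw

Replace ApplyClearWithDraw() with ClearWithDrawHelper. Initialize() runs
on the CommandEncoder before the render pass begins: it computes the
pipeline key and writes the clear values, with
CommandEncoder::APIWriteBuffer(), into a temporary uniform buffer that the
device caches and reuses (DeviceBase::GetOrCreateTemporaryUniformBuffer).
Apply() then records the clear draw inside the render pass. This replaces
creating a new uniform buffer for every render pass.

Also format the shaders used by ApplyClearWithDraw, and draw a single
full-screen triangle (3 vertices) instead of a two-triangle quad
(6 vertices).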

Bug: chromium:341129591
Change-Id: I697398653253d1b2e6b409faa99c475e498af1a3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/188843
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Commit-Queue: Peng Huang <penghuang@chromium.org>
Auto-Submit: Peng Huang <penghuang@chromium.org>
diff --git a/src/dawn/native/ApplyClearColorValueWithDrawHelper.cpp b/src/dawn/native/ApplyClearColorValueWithDrawHelper.cpp
index db0deff..3edd384 100644
--- a/src/dawn/native/ApplyClearColorValueWithDrawHelper.cpp
+++ b/src/dawn/native/ApplyClearColorValueWithDrawHelper.cpp
@@ -28,15 +28,15 @@
 #include "dawn/native/ApplyClearColorValueWithDrawHelper.h"
 
 #include <limits>
-#include <optional>
 #include <string>
 #include <utility>
-#include <vector>
 
 #include "dawn/common/Enumerator.h"
 #include "dawn/common/Range.h"
 #include "dawn/native/BindGroup.h"
 #include "dawn/native/BindGroupLayout.h"
+#include "dawn/native/Buffer.h"
+#include "dawn/native/CommandEncoder.h"
 #include "dawn/native/Device.h"
 #include "dawn/native/InternalPipelineStore.h"
 #include "dawn/native/ObjectContentHasher.h"
@@ -52,15 +52,12 @@
 // General helper functions and data structures for applying clear values with draw
 static const char kVSSource[] = R"(
 @vertex
-fn main(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4f {
+fn main(@builtin(vertex_index) vertexIndex : u32) -> @builtin(position) vec4f {
     var pos = array(
         vec2f(-1.0, -1.0),
-        vec2f( 1.0, -1.0),
-        vec2f(-1.0,  1.0),
-        vec2f(-1.0,  1.0),
-        vec2f( 1.0, -1.0),
-        vec2f( 1.0,  1.0));
-        return vec4f(pos[VertexIndex], 0.0, 1.0);
+        vec2f( 3.0, -1.0),
+        vec2f(-1.0,  3.0));
+    return vec4f(pos[vertexIndex], 0.0, 1.0);
 })";
 
 const char* GetTextureComponentTypeString(DeviceBase* device, wgpu::TextureFormat format) {
@@ -97,11 +94,11 @@
         const char* type = GetTextureComponentTypeString(device, currentFormat);
 
         outputColorDeclarationStream
-            << absl::StrFormat("@location(%u) output%u : vec4<%s>,\n", i, i, type);
+            << absl::StrFormat("    @location(%u) output%u : vec4<%s>,\n", i, i, type);
         clearValueUniformBufferDeclarationStream
-            << absl::StrFormat("color%u : vec4<%s>,\n", i, type);
-        assignOutputColorStream << absl::StrFormat("outputColor.output%u = clearColors.color%u;\n",
-                                                   i, i);
+            << absl::StrFormat("    color%u : vec4<%s>,\n", i, type);
+        assignOutputColorStream << absl::StrFormat(
+            "    outputColor.output%u = clearColors.color%u;\n", i, i);  // 4-space indent matches the WGSL body.
     }
     outputColorDeclarationStream << "}" << std::endl;
     clearValueUniformBufferDeclarationStream << "}" << std::endl;
@@ -116,7 +113,7 @@
     var outputColor : OutputColor;
 )" << assignOutputColorStream.str()
                           << R"(
-return outputColor;
+    return outputColor;
 })";
     return fragmentShaderStream.str();
 }
@@ -190,7 +187,7 @@
 }
 
 ResultOrError<Ref<BufferBase>> CreateUniformBufferWithClearValues(
-    DeviceBase* device,
+    CommandEncoder* encoder,  // Supplies the device and records the clear-value write.
     const RenderPassDescriptor* renderPassDescriptor,
     const KeyOfApplyClearColorValueWithDrawPipelines& key) {
     auto colorAttachments = ityp::SpanFromUntyped<ColorAttachmentIndex>(
@@ -235,13 +232,12 @@
 
     DAWN_ASSERT(offset > 0);
 
-    Ref<BufferBase> outputBuffer;
-    DAWN_TRY_ASSIGN(
-        outputBuffer,
-        utils::CreateBufferFromData(device, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                    clearValues.data(), offset));
+    Ref<BufferBase> buffer;
+    DAWN_TRY_ASSIGN(buffer, encoder->GetDevice()->GetOrCreateTemporaryUniformBuffer(offset));  // Cached on the device and reused across passes.
+    buffer->SetLabel("Internal_UniformClearValues");  // NOTE(review): relabels the device-shared buffer.
+    encoder->APIWriteBuffer(buffer.Get(), 0, clearValues.data(), offset);  // Goes through the encoder, ahead of the pass that reads it.
 
-    return std::move(outputBuffer);
+    return std::move(buffer);
 }
 
 bool NeedsBigIntClear(const RenderPassColorAttachment& colorAttachmentInfo) {
@@ -295,34 +291,34 @@
     return true;
 }
 
-std::optional<KeyOfApplyClearColorValueWithDrawPipelines>
-GetKeyOfApplyClearColorValueWithDrawPipelines(const DeviceBase* device,
-                                              const RenderPassDescriptor* renderPassDescriptor) {
+bool GetKeyOfApplyClearColorValueWithDrawPipelines(  // True iff some color target needs a clear-with-draw.
+    const DeviceBase* device,
+    const RenderPassDescriptor* renderPassDescriptor,
+    KeyOfApplyClearColorValueWithDrawPipelines* key) {  // Out-param, valid only on true; key->sampleCount must start at 0.
     bool clearWithDraw = device->IsToggleEnabled(Toggle::ClearColorWithDraw);
     bool clearWithDrawForBigInt =
         device->IsToggleEnabled(Toggle::ApplyClearBigIntegerColorValueWithDraw);
 
     if (!clearWithDraw && !clearWithDrawForBigInt) {
-        return std::nullopt;
+        return false;
     }
 
-    KeyOfApplyClearColorValueWithDrawPipelines key;
-    key.colorAttachmentCount = renderPassDescriptor->colorAttachmentCount;
+    key->colorAttachmentCount = renderPassDescriptor->colorAttachmentCount;
 
     auto colorAttachments = ityp::SpanFromUntyped<ColorAttachmentIndex>(
         renderPassDescriptor->colorAttachments, renderPassDescriptor->colorAttachmentCount);
 
-    key.colorTargetFormats.fill(wgpu::TextureFormat::Undefined);
+    key->colorTargetFormats.fill(wgpu::TextureFormat::Undefined);
     for (auto [i, attachment] : Enumerate(colorAttachments)) {
         if (attachment.view == nullptr) {
             continue;
         }
 
-        key.colorTargetFormats[i] = attachment.view->GetFormat().format;
-        if (key.sampleCount == 0) {
-            key.sampleCount = attachment.view->GetTexture()->GetSampleCount();
+        key->colorTargetFormats[i] = attachment.view->GetFormat().format;
+        if (key->sampleCount == 0) {  // First non-null attachment sets sampleCount.
+            key->sampleCount = attachment.view->GetTexture()->GetSampleCount();
         } else {
-            DAWN_ASSERT(key.sampleCount == attachment.view->GetTexture()->GetSampleCount());
+            DAWN_ASSERT(key->sampleCount == attachment.view->GetTexture()->GetSampleCount());
         }
 
         if (attachment.loadOp != wgpu::LoadOp::Clear) {
@@ -330,21 +326,21 @@
         }
 
         if (clearWithDraw || (clearWithDrawForBigInt && NeedsBigIntClear(attachment))) {
-            key.colorTargetsToApplyClearColorValue.set(i);
+            key->colorTargetsToApplyClearColorValue.set(i);
         }
     }
 
     if (renderPassDescriptor->depthStencilAttachment &&
         renderPassDescriptor->depthStencilAttachment->view != nullptr) {
-        key.depthStencilFormat =
+        key->depthStencilFormat =
             renderPassDescriptor->depthStencilAttachment->view->GetFormat().format;
     }
 
-    if (key.colorTargetsToApplyClearColorValue.none()) {
-        return std::nullopt;
+    if (key->colorTargetsToApplyClearColorValue.none()) {
+        return false;
     }
 
-    return key;
+    return true;
 }
 
 }  // namespace
@@ -389,33 +385,44 @@
            key1.depthStencilFormat == key2.depthStencilFormat;
 }
 
-MaybeError ApplyClearWithDraw(RenderPassEncoder* renderPassEncoder,
-                              const RenderPassDescriptor* renderPassDescriptor) {
-    DeviceBase* device = renderPassEncoder->GetDevice();
-    std::optional<KeyOfApplyClearColorValueWithDrawPipelines> key =
-        GetKeyOfApplyClearColorValueWithDrawPipelines(device, renderPassDescriptor);
-    if (!key.has_value()) {
+ClearWithDrawHelper::ClearWithDrawHelper() = default;
+ClearWithDrawHelper::~ClearWithDrawHelper() = default;
+
+MaybeError ClearWithDrawHelper::Initialize(CommandEncoder* encoder,  // Computes mKey; writes clear values if needed.
+                                           const RenderPassDescriptor* renderPassDescriptor) {
+    DeviceBase* device = encoder->GetDevice();
+    mShouldRun = GetKeyOfApplyClearColorValueWithDrawPipelines(device, renderPassDescriptor, &mKey);
+    if (!mShouldRun) {
         return {};
     }
 
+    DAWN_TRY_ASSIGN(mUniformBufferWithClearColorValues,  // Records the clear-value write on |encoder|.
+                    CreateUniformBufferWithClearValues(encoder, renderPassDescriptor, mKey));
+
+    return {};
+}
+
+MaybeError ClearWithDrawHelper::Apply(RenderPassEncoder* renderPassEncoder) {  // No-op unless Initialize() set mShouldRun.
+    if (!mShouldRun) {
+        return {};
+    }
+
+    DeviceBase* device = renderPassEncoder->GetDevice();
+
     RenderPipelineBase* pipeline = nullptr;
-    DAWN_TRY_ASSIGN(pipeline, GetOrCreateApplyClearValueWithDrawPipeline(device, key.value()));
+    DAWN_TRY_ASSIGN(pipeline, GetOrCreateApplyClearValueWithDrawPipeline(device, mKey));
 
     Ref<BindGroupLayoutBase> layout;
     DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));
 
-    Ref<BufferBase> uniformBufferWithClearColorValues;
-    DAWN_TRY_ASSIGN(uniformBufferWithClearColorValues,
-                    CreateUniformBufferWithClearValues(device, renderPassDescriptor, key.value()));
-
     Ref<BindGroupBase> bindGroup;
     DAWN_TRY_ASSIGN(bindGroup,
-                    utils::MakeBindGroup(device, layout, {{0, uniformBufferWithClearColorValues}},
+                    utils::MakeBindGroup(device, layout, {{0, mUniformBufferWithClearColorValues}},
                                          UsageValidationMode::Internal));
 
     renderPassEncoder->APISetBindGroup(0, bindGroup.Get());
     renderPassEncoder->APISetPipeline(pipeline);
-    renderPassEncoder->APIDraw(6);
+    renderPassEncoder->APIDraw(3);  // One full-screen triangle (see kVSSource).
 
     return {};
 }
diff --git a/src/dawn/native/ApplyClearColorValueWithDrawHelper.h b/src/dawn/native/ApplyClearColorValueWithDrawHelper.h
index 102de12..67e7fe6 100644
--- a/src/dawn/native/ApplyClearColorValueWithDrawHelper.h
+++ b/src/dawn/native/ApplyClearColorValueWithDrawHelper.h
@@ -28,9 +28,8 @@
 #ifndef SRC_DAWN_NATIVE_APPLYCLEARVALUEWITHDRAWHELPER_H_
 #define SRC_DAWN_NATIVE_APPLYCLEARVALUEWITHDRAWHELPER_H_
 
-#include <bitset>
-
 #include "absl/container/flat_hash_map.h"
+#include "dawn/common/Ref.h"
 #include "dawn/common/ityp_array.h"
 #include "dawn/common/ityp_bitset.h"
 #include "dawn/native/Error.h"
@@ -38,6 +37,7 @@
 
 namespace dawn::native {
 class BufferBase;
+class CommandEncoder;
 class RenderPassEncoder;
 struct RenderPassDescriptor;
 
@@ -62,8 +62,20 @@
                         KeyOfApplyClearColorValueWithDrawPipelinesHashFunc,
                         KeyOfApplyClearColorValueWithDrawPipelinesEqualityFunc>;
 
-MaybeError ApplyClearWithDraw(RenderPassEncoder* renderPassEncoder,
-                              const RenderPassDescriptor* renderPassDescriptor);
+class ClearWithDrawHelper {  // Clears color attachments by drawing, when device toggles require it.
+  public:
+    ClearWithDrawHelper();
+    ~ClearWithDrawHelper();
+
+    MaybeError Initialize(CommandEncoder* encoder,  // Before the pass: builds the key, writes clear values.
+                          const RenderPassDescriptor* renderPassDescriptor);
+    MaybeError Apply(RenderPassEncoder* renderPassEncoder);  // Inside the pass: draws the clear.
+
+  private:
+    bool mShouldRun = false;  // Set by Initialize(); Apply() is a no-op when false.
+    KeyOfApplyClearColorValueWithDrawPipelines mKey;
+    Ref<BufferBase> mUniformBufferWithClearColorValues;  // Device-cached temporary uniform buffer.
+};
 
 }  // namespace dawn::native
 
diff --git a/src/dawn/native/CommandEncoder.cpp b/src/dawn/native/CommandEncoder.cpp
index 6a3492a..f4e3bc6 100644
--- a/src/dawn/native/CommandEncoder.cpp
+++ b/src/dawn/native/CommandEncoder.cpp
@@ -1206,6 +1206,7 @@
     };
 
     UnpackedPtr<RenderPassDescriptor> descriptor;
+    ClearWithDrawHelper clearWithDrawHelper;  // Initialize() below; Apply() once the pass encoder exists.
     bool success = mEncodingContext.TryEncode(
         this,
         [&](CommandAllocator* allocator) -> MaybeError {
@@ -1215,6 +1216,8 @@
 
             DAWN_ASSERT(validationState.IsValidState());
 
+            DAWN_TRY(clearWithDrawHelper.Initialize(this, *descriptor));  // Kept before WillBeginRenderPass(): writes via |this|.
+
             mEncodingContext.WillBeginRenderPass();
             BeginRenderPassCmd* cmd =
                 allocator->Allocate<BeginRenderPassCmd>(Command::BeginRenderPass);
@@ -1407,10 +1410,10 @@
             if (validationState.WillExpandResolveTexture()) {
                 DAWN_TRY(ApplyExpandResolveTextureLoadOp(device, passEncoder.Get(), *descriptor));
             }
-            // ApplyClearWithDraw() applies clear with draw if clear_color_with_draw or
+            // clearWithDrawHelper.Apply() applies clear with draw if clear_color_with_draw or
             // apply_clear_big_integer_color_value_with_draw toggle is enabled, and the render pass
             // attachments need to be cleared.
-            DAWN_TRY(ApplyClearWithDraw(passEncoder.Get(), *descriptor));
+            DAWN_TRY(clearWithDrawHelper.Apply(passEncoder.Get()));
 
             return {};
         }();
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index eeae081..756fbdb 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -644,6 +644,7 @@
     mEmptyPipelineLayout = nullptr;
     mInternalPipelineStore = nullptr;
     mExternalTexturePlaceholderView = nullptr;
+    mTemporaryUniformBuffer = nullptr;
 
     // Note: mQueue is not released here since the application may still get it after calling
     // Destroy() via APIGetQueue.
@@ -2459,6 +2460,24 @@
     });
 }
 
+// Returns the device-owned CopyDst|Uniform buffer used for temporary uniform data, such as the
+// ApplyClearWithDraw clear values. The buffer is created on first use and re-created only when it
+// is smaller than |size|, so it may be larger than requested. Every caller shares it, so each use
+// must first write its own contents (e.g. with CommandEncoder::APIWriteBuffer).
+ResultOrError<Ref<BufferBase>> DeviceBase::GetOrCreateTemporaryUniformBuffer(size_t size) {
+    // Grow-only reuse: re-creating on every size change would reallocate each time consecutive
+    // render passes need a different amount of clear-value data.
+    if (!mTemporaryUniformBuffer || mTemporaryUniformBuffer->GetSize() < size) {
+        BufferDescriptor desc;
+        desc.label = "Internal_TemporaryUniform";
+        desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform;
+        desc.size = size;
+        DAWN_TRY_ASSIGN(mTemporaryUniformBuffer, CreateBuffer(&desc));
+    }
+
+    return mTemporaryUniformBuffer;
+}
+
 IgnoreLazyClearCountScope::IgnoreLazyClearCountScope(DeviceBase* device)
     : mDevice(device), mLazyClearCountForTesting(device->mLazyClearCountForTesting) {}
 
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index 8cef28e..9e62b3e 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -446,6 +446,8 @@
 
     void DumpMemoryStatistics(dawn::native::MemoryDump* dump) const;
 
+    ResultOrError<Ref<BufferBase>> GetOrCreateTemporaryUniformBuffer(size_t size);  // Shared CopyDst|Uniform buffer of at least |size| bytes.
+
   protected:
     // Constructor used only for mocking and testing.
     DeviceBase();
@@ -590,6 +592,7 @@
     tint::wgsl::AllowedFeatures mWGSLAllowedFeatures;
 
     std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
+    Ref<BufferBase> mTemporaryUniformBuffer;  // Cache for GetOrCreateTemporaryUniformBuffer().
 
     Ref<CallbackTaskManager> mCallbackTaskManager;
     std::unique_ptr<dawn::platform::WorkerTaskPool> mWorkerTaskPool;