Remove ConcurrentTransformedShaderModuleCache

ConcurrentTransformedShaderModuleCache held both compiled SPIR-V and
VkShaderModule. There is already caching of SPIR-V in BlobCache. The
transformation from SPIR-V to VkShaderModule is inexpensive. This extra
layer of caching isn't really necessary so remove it.

Render/compute pipeline now call vkDestroyShaderModule() after
compilation is complete. This also removes the need for FencedDeleter to
handle that.

Bug: 411152029
Change-Id: Ia75f313849a5bade092f2a99d608cceb76a1b3c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/239415
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Kyle Charbonneau <kylechar@google.com>
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index 595d877..c8260b0 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -115,8 +115,7 @@
 
     if (buildCacheKey) {
         // Record cache key information now since the createInfo is not stored.
-        StreamIn(&mCacheKey, createInfo, layout,
-                 stream::Iterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount));
+        StreamIn(&mCacheKey, createInfo, layout, moduleAndSpirv.spirv);
     }
 
     // Try to see if we have anything in the blob cache.
@@ -140,6 +139,8 @@
 
     SetLabelImpl();
 
+    device->fn.DestroyShaderModule(device->GetVkDevice(), moduleAndSpirv.module, nullptr);
+
     return {};
 }
 
diff --git a/src/dawn/native/vulkan/FencedDeleter.cpp b/src/dawn/native/vulkan/FencedDeleter.cpp
index fd8280b..5cb2366 100644
--- a/src/dawn/native/vulkan/FencedDeleter.cpp
+++ b/src/dawn/native/vulkan/FencedDeleter.cpp
@@ -51,7 +51,6 @@
     DAWN_ASSERT(mSamplerYcbcrConversionsToDelete.Empty());
     DAWN_ASSERT(mSamplersToDelete.Empty());
     DAWN_ASSERT(mSemaphoresToDelete.Empty());
-    DAWN_ASSERT(mShaderModulesToDelete.Empty());
     DAWN_ASSERT(mSurfacesToDelete.Empty());
     DAWN_ASSERT(mSwapChainsToDelete.Empty());
 }
@@ -112,10 +111,6 @@
     mSemaphoresToDelete.Enqueue(semaphore, GetCurrentDeletionSerial());
 }
 
-void FencedDeleter::DeleteWhenUnused(VkShaderModule module) {
-    mShaderModulesToDelete.Enqueue(module, GetCurrentDeletionSerial());
-}
-
 void FencedDeleter::DeleteWhenUnused(VkSurfaceKHR surface) {
     mSurfacesToDelete.Enqueue(surface, GetCurrentDeletionSerial());
 }
@@ -146,7 +141,6 @@
     GetLastSubmitted(mSamplerYcbcrConversionsToDelete);
     GetLastSubmitted(mSamplersToDelete);
     GetLastSubmitted(mSemaphoresToDelete);
-    GetLastSubmitted(mShaderModulesToDelete);
     GetLastSubmitted(mSurfacesToDelete);
     GetLastSubmitted(mSwapChainsToDelete);
 
@@ -202,11 +196,6 @@
     }
     mImageViewsToDelete.ClearUpTo(completedSerial);
 
-    for (VkShaderModule module : mShaderModulesToDelete.IterateUpTo(completedSerial)) {
-        mDevice->fn.DestroyShaderModule(vkDevice, module, nullptr);
-    }
-    mShaderModulesToDelete.ClearUpTo(completedSerial);
-
     for (VkPipeline pipeline : mPipelinesToDelete.IterateUpTo(completedSerial)) {
         mDevice->fn.DestroyPipeline(vkDevice, pipeline, nullptr);
     }
diff --git a/src/dawn/native/vulkan/FencedDeleter.h b/src/dawn/native/vulkan/FencedDeleter.h
index f1c624d..1eb36c6 100644
--- a/src/dawn/native/vulkan/FencedDeleter.h
+++ b/src/dawn/native/vulkan/FencedDeleter.h
@@ -56,7 +56,6 @@
     void DeleteWhenUnused(VkSamplerYcbcrConversion samplerYcbcrConversion);
     void DeleteWhenUnused(VkSampler sampler);
     void DeleteWhenUnused(VkSemaphore semaphore);
-    void DeleteWhenUnused(VkShaderModule module);
     void DeleteWhenUnused(VkSurfaceKHR surface);
     void DeleteWhenUnused(VkSwapchainKHR swapChain);
 
@@ -84,7 +83,6 @@
     SerialQueue<ExecutionSerial, VkSamplerYcbcrConversion> mSamplerYcbcrConversionsToDelete;
     SerialQueue<ExecutionSerial, VkSampler> mSamplersToDelete;
     SerialQueue<ExecutionSerial, VkSemaphore> mSemaphoresToDelete;
-    SerialQueue<ExecutionSerial, VkShaderModule> mShaderModulesToDelete;
     SerialQueue<ExecutionSerial, VkSurfaceKHR> mSurfacesToDelete;
     SerialQueue<ExecutionSerial, VkSwapchainKHR> mSwapChainsToDelete;
 };
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index c5b3245..dbe5a15 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -382,7 +382,7 @@
         mHasInputAttachment = mHasInputAttachment || moduleAndSpirv.hasInputAttachment;
         if (buildCacheKey) {
             // Record cache key for each shader since it will become inaccessible later on.
-            StreamIn(&mCacheKey, stream::Iterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount));
+            StreamIn(&mCacheKey, moduleAndSpirv.spirv);
         }
 
         VkPipelineShaderStageCreateInfo* shaderStage = &shaderStages[stageCount];
@@ -631,6 +631,10 @@
 
     SetLabelImpl();
 
+    for (uint32_t i = 0; i < stageCount; ++i) {
+        device->fn.DestroyShaderModule(device->GetVkDevice(), shaderStages[i].module, nullptr);
+    }
+
     return {};
 }
 
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index bfcb45a..e2414c8 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -99,73 +99,6 @@
     return hash;
 }
 
-class ShaderModule::ConcurrentTransformedShaderModuleCache {
-  public:
-    explicit ConcurrentTransformedShaderModuleCache(Device* device) : mDevice(device) {}
-
-    ~ConcurrentTransformedShaderModuleCache() {
-        std::lock_guard<std::mutex> lock(mMutex);
-
-        for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
-            mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.vkModule);
-        }
-    }
-
-    std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key) {
-        std::lock_guard<std::mutex> lock(mMutex);
-
-        auto iter = mTransformedShaderModuleCache.find(key);
-        if (iter != mTransformedShaderModuleCache.end()) {
-            return iter->second.AsRefs();
-        }
-        return {};
-    }
-    ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
-                            VkShaderModule module,
-                            CompiledSpirv compilation,
-                            bool hasInputAttachment) {
-        DAWN_ASSERT(module != VK_NULL_HANDLE);
-        std::lock_guard<std::mutex> lock(mMutex);
-
-        auto iter = mTransformedShaderModuleCache.find(key);
-        if (iter == mTransformedShaderModuleCache.end()) {
-            bool added = false;
-            std::tie(iter, added) = mTransformedShaderModuleCache.emplace(
-                key, Entry{module, std::move(compilation.spirv), hasInputAttachment});
-            DAWN_ASSERT(added);
-        } else {
-            // No need to use FencedDeleter since this shader module was just created and does
-            // not need to wait for queue operations to complete.
-            // Also, use of fenced deleter here is not thread safe.
-            mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr);
-        }
-        return iter->second.AsRefs();
-    }
-
-  private:
-    struct Entry {
-        VkShaderModule vkModule;
-        std::vector<uint32_t> spirv;
-        bool hasInputAttachment;
-
-        ModuleAndSpirv AsRefs() const {
-            return {
-                vkModule,
-                spirv.data(),
-                spirv.size(),
-                hasInputAttachment,
-            };
-        }
-    };
-
-    raw_ptr<Device> mDevice;
-    std::mutex mMutex;
-    absl::flat_hash_map<TransformedShaderModuleCacheKey,
-                        Entry,
-                        TransformedShaderModuleCacheKeyHashFunc>
-        mTransformedShaderModuleCache;
-};
-
 // static
 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
     Device* device,
@@ -181,9 +114,7 @@
 ShaderModule::ShaderModule(Device* device,
                            const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                            std::vector<tint::wgsl::Extension> internalExtensions)
-    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)),
-      mTransformedShaderModuleCache(
-          std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
+    : ShaderModuleBase(device, descriptor, std::move(internalExtensions)) {}
 
 MaybeError ShaderModule::Initialize(
     ShaderModuleParseResult* parseResult,
@@ -193,8 +124,6 @@
 
 void ShaderModule::DestroyImpl() {
     ShaderModuleBase::DestroyImpl();
-    // Remove reference to internal cache to trigger cleanup.
-    mTransformedShaderModuleCache = nullptr;
 }
 
 ShaderModule::~ShaderModule() = default;
@@ -236,10 +165,6 @@
     auto cacheKey = TransformedShaderModuleCacheKey{reinterpret_cast<uintptr_t>(layout),
                                                     programmableStage.entryPoint.c_str(),
                                                     programmableStage.constants, emitPointSize};
-    auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
-    if (handleAndSpirv.has_value()) {
-        return std::move(*handleAndSpirv);
-    }
 
 #if TINT_BUILD_SPV_WRITER
     // Creation of module and spirv is deferred to this point when using tint generator
@@ -507,8 +432,10 @@
         // Set the label on `newHandle` now, and not on `moduleAndSpirv.module` later
         // since `moduleAndSpirv.module` may be in use by multiple threads.
         SetDebugName(ToBackend(GetDevice()), newHandle, "Dawn_ShaderModule", GetLabel());
-        moduleAndSpirv = mTransformedShaderModuleCache->AddOrGet(
-            cacheKey, newHandle, compilation.Acquire(), hasInputAttachment);
+
+        moduleAndSpirv.module = newHandle;
+        moduleAndSpirv.spirv = std::move(compilation->spirv);
+        moduleAndSpirv.hasInputAttachment = hasInputAttachment;
     }
 
     return std::move(moduleAndSpirv);
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index cc2d5fc..ee04246 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -67,8 +67,7 @@
   public:
     struct ModuleAndSpirv {
         VkShaderModule module;
-        const uint32_t* spirv;
-        size_t wordCount;
+        std::vector<uint32_t> spirv;
         bool hasInputAttachment;
     };
 
@@ -79,6 +78,7 @@
         ShaderModuleParseResult* parseResult,
         std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
 
+    // Caller is responsible for destroying the `VkShaderModule` returned.
     ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
                                                     const ProgrammableStage& programmableStage,
                                                     const PipelineLayout* layout,
@@ -93,10 +93,6 @@
     MaybeError Initialize(ShaderModuleParseResult* parseResult,
                           std::unique_ptr<OwnedCompilationMessages>* compilationMessages);
     void DestroyImpl() override;
-
-    // New handles created by GetHandleAndSpirv at pipeline creation time.
-    class ConcurrentTransformedShaderModuleCache;
-    std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
 };
 
 }  // namespace vulkan