Make monolithic PipelineCache thread safe

When vulkan_monolithic_pipeline_cache is enabled there is a single
VkPipelineCache+PipelineCache object owned by each device. That
PipelineCache can be used by multiple worker threads at the same time so
it needs to be thread safe.

Change mMonolithicPipelineCache creation so it happens during device
initialization. This avoids problems around multiple threads trying to
create it lazily. This might decrease the chance of having previous
pipeline cache data loaded from disk into BlobCache in time for
creation.

Also make PipelineCacheBase+PipelineCache more thread safe. Most member
variables are already set at creation time and then never modified.
mNeedsStore is an exception so use an atomic there. mStoredDataSize is
another member variable that is modified after creation but
SerializeToBlobImpl() should only be called from a single thread.

Bug: 370343334
Change-Id: I1815c0866e24ca40e8544a2948bb814824e611df
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/214736
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Kyle Charbonneau <kylechar@google.com>
diff --git a/src/dawn/native/PipelineCache.cpp b/src/dawn/native/PipelineCache.cpp
index 9a0bba1..ac6e5b8 100644
--- a/src/dawn/native/PipelineCache.cpp
+++ b/src/dawn/native/PipelineCache.cpp
@@ -27,6 +27,8 @@
 
 #include "dawn/native/PipelineCache.h"
 
+#include <atomic>
+
 namespace dawn::native {
 
 PipelineCacheBase::PipelineCacheBase(BlobCache* cache, const CacheKey& key, bool storeOnIdle)
@@ -62,7 +64,7 @@
     if (mStoreOnIdle) {
         // Assume pipeline cache was modified by compiling a pipeline. It will be stored in
         // BlobCache at some later point in StoreOnIdle() if necessary.
-        mNeedsStore = true;
+        mNeedsStore.store(true, std::memory_order_relaxed);
     } else {
         // TODO(dawn:549): Flush is currently synchronously happening on the same thread as pipeline
         // compilation, but it's perhaps deferrable.
@@ -75,8 +77,7 @@
 
 MaybeError PipelineCacheBase::StoreOnIdle() {
     DAWN_ASSERT(mStoreOnIdle);
-    if (mNeedsStore) {
-        mNeedsStore = false;
+    if (mNeedsStore.exchange(false, std::memory_order_relaxed)) {
         DAWN_TRY(Flush());
     }
     return {};
diff --git a/src/dawn/native/PipelineCache.h b/src/dawn/native/PipelineCache.h
index 5b9015c..b96d6fd 100644
--- a/src/dawn/native/PipelineCache.h
+++ b/src/dawn/native/PipelineCache.h
@@ -28,6 +28,8 @@
 #ifndef SRC_DAWN_NATIVE_PIPELINECACHE_H_
 #define SRC_DAWN_NATIVE_PIPELINECACHE_H_
 
+#include <atomic>
+
 #include "dawn/common/RefCounted.h"
 #include "dawn/native/BlobCache.h"
 #include "dawn/native/CacheKey.h"
@@ -73,12 +75,17 @@
     // The blob cache is owned by the Adapter and pipeline caches are owned/created by devices
     // or adapters. Since the device owns a reference to the Instance which owns the Adapter,
     // the blob cache is guaranteed to be valid throughout the lifetime of the object.
-    raw_ptr<BlobCache> mCache;
-    CacheKey mKey;
+    const raw_ptr<BlobCache> mCache;
+    const CacheKey mKey;
     const bool mStoreOnIdle;
     bool mInitialized = false;
     bool mCacheHit = false;
-    bool mNeedsStore = false;
+
+    // Multiple threads can be using the pipeline cache concurrently and
+    // modifying this variable. Loads and stores are done with relaxed ordering
+    // since we don't care so much about strict ordering just avoiding UB from
+    // concurrent read/writes.
+    std::atomic<bool> mNeedsStore = false;
 };
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 9abe3f7..d7eb276 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -164,6 +164,16 @@
 #endif
     }
 
+    if (IsToggleEnabled(Toggle::VulkanMonolithicPipelineCache)) {
+        CacheKey cacheKey = GetCacheKey();
+        // `pipelineCacheUUID` is supposed to change if anything in the driver changes such that
+        // the serialized VkPipelineCache is no longer valid.
+        auto& deviceProperties = GetDeviceInfo().properties;
+        StreamIn(&cacheKey, deviceProperties.pipelineCacheUUID);
+
+        mMonolithicPipelineCache = PipelineCache::CreateMonolithic(this, cacheKey);
+    }
+
     SetLabelImpl();
 
     ToBackend(GetPhysicalDevice())->GetVulkanInstance()->StartListeningForDeviceMessages(this);
@@ -236,16 +246,7 @@
     return TextureView::Create(texture, descriptor);
 }
 Ref<PipelineCacheBase> Device::GetOrCreatePipelineCacheImpl(const CacheKey& key) {
-    if (IsToggleEnabled(Toggle::VulkanMonolithicPipelineCache)) {
-        if (!mMonolithicPipelineCache) {
-            CacheKey cacheKey = GetCacheKey();
-            // `pipelineCacheUUID` is supposed to change if anything in the driver changes such that
-            // the serialized VkPipelineCache is no longer valid.
-            auto& deviceProperties = GetDeviceInfo().properties;
-            StreamIn(&cacheKey, deviceProperties.pipelineCacheUUID);
-
-            mMonolithicPipelineCache = PipelineCache::CreateMonolithic(this, cacheKey);
-        }
+    if (mMonolithicPipelineCache) {
         return mMonolithicPipelineCache;
     }
 
diff --git a/src/dawn/native/vulkan/PipelineCacheVk.cpp b/src/dawn/native/vulkan/PipelineCacheVk.cpp
index 9841077..da69502 100644
--- a/src/dawn/native/vulkan/PipelineCacheVk.cpp
+++ b/src/dawn/native/vulkan/PipelineCacheVk.cpp
@@ -39,7 +39,7 @@
 namespace dawn::native::vulkan {
 
 // static
-Ref<PipelineCache> PipelineCache::Create(DeviceBase* device, const CacheKey& key) {
+Ref<PipelineCache> PipelineCache::Create(Device* device, const CacheKey& key) {
     Ref<PipelineCache> cache =
         AcquireRef(new PipelineCache(device, key, /*isMonolithicCache=*/false));
     cache->Initialize();
@@ -47,29 +47,24 @@
 }
 
 // static
-Ref<PipelineCache> PipelineCache::CreateMonolithic(DeviceBase* device, const CacheKey& key) {
+Ref<PipelineCache> PipelineCache::CreateMonolithic(Device* device, const CacheKey& key) {
     Ref<PipelineCache> cache =
         AcquireRef(new PipelineCache(device, key, /*isMonolithicCache=*/true));
     cache->Initialize();
     return cache;
 }
 
-PipelineCache::PipelineCache(DeviceBase* device, const CacheKey& key, bool isMonolithicCache)
+PipelineCache::PipelineCache(Device* device, const CacheKey& key, bool isMonolithicCache)
     : PipelineCacheBase(device->GetBlobCache(), key, isMonolithicCache), mDevice(device) {}
 
 PipelineCache::~PipelineCache() {
     if (mHandle == VK_NULL_HANDLE) {
         return;
     }
-    Device* device = ToBackend(GetDevice());
-    device->fn.DestroyPipelineCache(device->GetVkDevice(), mHandle, nullptr);
+    mDevice->fn.DestroyPipelineCache(mDevice->GetVkDevice(), mHandle, nullptr);
     mHandle = VK_NULL_HANDLE;
 }
 
-DeviceBase* PipelineCache::GetDevice() const {
-    return mDevice;
-}
-
 VkPipelineCache PipelineCache::GetHandle() const {
     return mHandle;
 }
@@ -81,20 +76,19 @@
     }
 
     size_t bufferSize;
-    Device* device = ToBackend(GetDevice());
     DAWN_TRY(CheckVkSuccess(
-        device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle, &bufferSize, nullptr),
+        mDevice->fn.GetPipelineCacheData(mDevice->GetVkDevice(), mHandle, &bufferSize, nullptr),
         "GetPipelineCacheData"));
 
     if (bufferSize == 0 || bufferSize == mStoredDataSize) {
-        // If current VkPipelineCache data size is same as `mCachedDataSize` assume nothing has
+        // If current VkPipelineCache data size is same as `mStoredDataSize` assume nothing has
         // changed vs what is stored in the BlobCache.
         return {};
     }
     *blob = CreateBlob(bufferSize);
-    DAWN_TRY(CheckVkSuccess(
-        device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle, &bufferSize, blob->Data()),
-        "GetPipelineCacheData"));
+    DAWN_TRY(CheckVkSuccess(mDevice->fn.GetPipelineCacheData(mDevice->GetVkDevice(), mHandle,
+                                                             &bufferSize, blob->Data()),
+                            "GetPipelineCacheData"));
     mStoredDataSize = bufferSize;
 
     return {};
@@ -110,18 +104,17 @@
     createInfo.initialDataSize = blob.Size();
     createInfo.pInitialData = blob.Data();
 
-    Device* device = ToBackend(GetDevice());
     mHandle = VK_NULL_HANDLE;
 
     // Attempts to create the pipeline cache but does not bubble the error, instead only logging.
     // This should be fine because the handle will be left as null and pipeline creation should
     // continue as if there was no cache.
     MaybeError maybeError = CheckVkSuccess(
-        device->fn.CreatePipelineCache(device->GetVkDevice(), &createInfo, nullptr, &*mHandle),
+        mDevice->fn.CreatePipelineCache(mDevice->GetVkDevice(), &createInfo, nullptr, &*mHandle),
         "CreatePipelineCache");
     if (maybeError.IsError()) {
         std::unique_ptr<ErrorData> error = maybeError.AcquireError();
-        GetDevice()->EmitLog(WGPULoggingType_Info, error->GetFormattedMessage().c_str());
+        mDevice->EmitLog(WGPULoggingType_Info, error->GetFormattedMessage().c_str());
         return;
     }
 
diff --git a/src/dawn/native/vulkan/PipelineCacheVk.h b/src/dawn/native/vulkan/PipelineCacheVk.h
index 8a260f9..509b84c 100644
--- a/src/dawn/native/vulkan/PipelineCacheVk.h
+++ b/src/dawn/native/vulkan/PipelineCacheVk.h
@@ -34,33 +34,32 @@
 
 #include "dawn/common/vulkan_platform.h"
 
-namespace dawn::native {
-class DeviceBase;
-}
-
 namespace dawn::native::vulkan {
 
+class Device;
+
 class PipelineCache final : public PipelineCacheBase {
   public:
-    static Ref<PipelineCache> Create(DeviceBase* device, const CacheKey& key);
+    static Ref<PipelineCache> Create(Device* device, const CacheKey& key);
 
     // Creates a pipeline cache that is intended to be monolithic. The cache will only be serialized
     // and stored to BlobCache when StoreOnIdle() is called.
-    static Ref<PipelineCache> CreateMonolithic(DeviceBase* device, const CacheKey& key);
+    static Ref<PipelineCache> CreateMonolithic(Device* device, const CacheKey& key);
 
-    DeviceBase* GetDevice() const;
     VkPipelineCache GetHandle() const;
 
   private:
-    explicit PipelineCache(DeviceBase* device, const CacheKey& key, bool isMonolithicCache);
+    explicit PipelineCache(Device* device, const CacheKey& key, bool isMonolithicCache);
     ~PipelineCache() override;
 
     void Initialize();
     MaybeError SerializeToBlobImpl(Blob* blob) override;
 
-    raw_ptr<DeviceBase> mDevice;
+    const raw_ptr<Device> mDevice;
     VkPipelineCache mHandle = VK_NULL_HANDLE;
 
+    // Only a single thread should be inside SerializeToBlobImpl() at one time so this should never
+    // be accessed concurrently on multiple threads.
     size_t mStoredDataSize = 0;
 };