Add compute pipeline cache key generation for Vulkan.

- Adds a dependency on Vulkan-Tools for its pNext chain helpers (vk_typemap_helper.h).
- Extends the in-memory Vulkan shader module cache to also keep the generated SPIR-V.
- Adds pNext chain serialization infrastructure for Vulkan cache keys.

Change-Id: Ibe73183fbff15f7310eaaeae92fbd622be1ac096
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/85022
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8e4d6da..6adb40c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -142,7 +142,8 @@
 set_if_not_defined(DAWN_SPIRV_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-headers/src" "Directory in which to find SPIRV-Headers")
 set_if_not_defined(DAWN_SPIRV_TOOLS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-tools/src" "Directory in which to find SPIRV-Tools")
 set_if_not_defined(DAWN_TINT_DIR "${Dawn_SOURCE_DIR}" "Directory in which to find Tint")
-set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/vulkan-headers/src" "Directory in which to find Vulkan-Headers")
+set_if_not_defined(DAWN_VULKAN_DEPS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps" "Directory in which to find vulkan-deps")
+set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_VULKAN_DEPS_DIR}/vulkan-headers/src" "Directory in which to find Vulkan-Headers")
 
 # Dependencies for DAWN_BUILD_NODE_BINDINGS
 set_if_not_defined(NODE_ADDON_API_DIR "${DAWN_THIRD_PARTY_DIR}/node-addon-api" "Directory in which to find node-addon-api")
diff --git a/scripts/dawn_overrides_with_defaults.gni b/scripts/dawn_overrides_with_defaults.gni
index 46f44ef..b4142ac 100644
--- a/scripts/dawn_overrides_with_defaults.gni
+++ b/scripts/dawn_overrides_with_defaults.gni
@@ -54,14 +54,17 @@
   dawn_swiftshader_dir = ""
 }
 
-if (!defined(dawn_vulkan_headers_dir)) {
-  dawn_vulkan_headers_dir = "//third_party/vulkan-deps/vulkan-headers/src"
+if (!defined(dawn_vulkan_deps_dir)) {
+  dawn_vulkan_deps_dir = "//third_party/vulkan-deps"
   if (dawn_standalone) {
-    dawn_vulkan_headers_dir =
-        "${dawn_root}/third_party/vulkan-deps/vulkan-headers/src"
+    dawn_vulkan_deps_dir = "${dawn_root}/third_party/vulkan-deps"
   }
 }
 
+if (!defined(dawn_vulkan_headers_dir)) {
+  dawn_vulkan_headers_dir = "${dawn_vulkan_deps_dir}/vulkan-headers/src"
+}
+
 if (!defined(dawn_vulkan_loader_dir)) {
   # Default to the Vulkan loader not being available except in standalone.
   dawn_vulkan_loader_dir = ""
@@ -70,6 +73,10 @@
   }
 }
 
+if (!defined(dawn_vulkan_tools_dir)) {
+  dawn_vulkan_tools_dir = "${dawn_vulkan_deps_dir}/vulkan-tools/src"
+}
+
 if (!defined(dawn_vulkan_validation_layers_dir)) {
   # Default to VVLs not being available.
   dawn_vulkan_validation_layers_dir = ""
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 4547138..e6d90dc 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -99,6 +99,11 @@
   }
 }
 
+# Config that adds include directory for vulkan-deps, specifically for Vulkan-Tools.
+config("vulkan_deps_include") {
+  include_dirs = [ "${dawn_vulkan_deps_dir}" ]
+}
+
 dawn_json_generator("utils_gen") {
   target = "native_utils"
   outputs = [
@@ -571,6 +576,8 @@
   }
 
   if (dawn_enable_vulkan) {
+    configs += [ ":vulkan_deps_include" ]
+    deps += [ "${dawn_vulkan_tools_dir}:vulkan_tools_headers" ]
     public_deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ]
     sources += [
       "vulkan/AdapterVk.cpp",
@@ -583,6 +590,8 @@
       "vulkan/BindGroupVk.h",
       "vulkan/BufferVk.cpp",
       "vulkan/BufferVk.h",
+      "vulkan/CacheKeyVk.cpp",
+      "vulkan/CacheKeyVk.h",
       "vulkan/CommandBufferVk.cpp",
       "vulkan/CommandBufferVk.h",
       "vulkan/CommandRecordingContext.h",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index 170ea70..e7cab14 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -455,6 +455,8 @@
         "vulkan/BindGroupVk.h"
         "vulkan/BufferVk.cpp"
         "vulkan/BufferVk.h"
+        "vulkan/CacheKeyVk.cpp"
+        "vulkan/CacheKeyVk.h"
         "vulkan/CommandBufferVk.cpp"
         "vulkan/CommandBufferVk.h"
         "vulkan/CommandRecordingContext.h"
@@ -510,6 +512,7 @@
     )
 
     target_link_libraries(dawn_native PUBLIC Vulkan-Headers)
+    target_include_directories(dawn_native PRIVATE ${DAWN_VULKAN_DEPS_DIR})
 
     if (UNIX AND NOT APPLE)
         target_sources(dawn_native PRIVATE
diff --git a/src/dawn/native/CacheKey.cpp b/src/dawn/native/CacheKey.cpp
index 3495577..dea67f8 100644
--- a/src/dawn/native/CacheKey.cpp
+++ b/src/dawn/native/CacheKey.cpp
@@ -14,8 +14,25 @@
 
 #include "dawn/native/CacheKey.h"
 
+#include <iomanip>
+
 namespace dawn::native {
 
+    // Prints each element of the key in hex, zero-padded to two digits and followed by a space.
+    std::ostream& operator<<(std::ostream& os, const CacheKey& key) {
+        // Save the caller's formatting state so that printing a key does not leave std::hex or
+        // the '0' fill character applied to later output on the same stream.
+        const std::ios_base::fmtflags flags = os.flags();
+        const char fill = os.fill('0');
+        os << std::hex;
+        for (const int b : key) {
+            os << std::setw(2) << b << " ";
+        }
+        os.flags(flags);
+        os.fill(fill);
+        return os;
+    }
+
     template <>
     void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
         key->Record(static_cast<size_t>(t.length()));
diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h
index ce21f6d..e97e770 100644
--- a/src/dawn/native/CacheKey.h
+++ b/src/dawn/native/CacheKey.h
@@ -15,17 +15,20 @@
 #ifndef DAWNNATIVE_CACHE_KEY_H_
 #define DAWNNATIVE_CACHE_KEY_H_
 
+#include <iostream>
 #include <limits>
 #include <string>
 #include <type_traits>
 #include <vector>
 
-#include "dawn/common/Assert.h"
-
 namespace dawn::native {
 
-    // Forward declare CacheKey class because of co-dependency.
+    // Forward declare classes because of co-dependency.
     class CacheKey;
+    class CachedObject;
+
+    // Stream operator for CacheKey for debugging.
+    std::ostream& operator<<(std::ostream& os, const CacheKey& key);
 
     // Overridable serializer struct that should be implemented for cache key serializable
     // types/classes.
@@ -82,7 +85,31 @@
         }
     };
 
-    // Specialized overload for string literals. Note we drop the null-terminator.
+    // Specialized overload for enums: serializes the enum's underlying integer value.
+    template <typename T>
+    class CacheKeySerializer<T, std::enable_if_t<std::is_enum_v<T>>> {
+      public:
+        static void Serialize(CacheKey* key, const T t) {
+            CacheKeySerializer<std::underlying_type_t<T>>::Serialize(
+                key, static_cast<std::underlying_type_t<T>>(t));
+        }
+    };
+
+    // Specialized overload for pointers. Since we are serializing for a cache key, we always
+    // serialize via value, not by pointer. To handle nullptr scenarios, we always serialize whether
+    // the pointer was nullptr followed by the contents if applicable.
+    template <typename T>
+    class CacheKeySerializer<T, std::enable_if_t<std::is_pointer_v<T>>> {
+      public:
+        static void Serialize(CacheKey* key, const T t) {
+            key->Record(t == nullptr);
+            if (t != nullptr) {
+                CacheKeySerializer<std::remove_cv_t<std::remove_pointer_t<T>>>::Serialize(key, *t);
+            }
+        }
+    };
+
+    // Specialized overload for string literals.
     template <size_t N>
     class CacheKeySerializer<char[N]> {
       public:
@@ -93,6 +120,15 @@
         }
     };
 
+    // Specialized overload for CachedObjects: records the object's own cache key.
+    template <typename T>
+    class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>> {
+      public:
+        static void Serialize(CacheKey* key, const T& t) {
+            key->Record(t.GetCacheKey());
+        }
+    };
+
 }  // namespace dawn::native
 
 #endif  // DAWNNATIVE_CACHE_KEY_H_
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
index 8ed4340..c377c96 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.cpp
@@ -16,6 +16,7 @@
 
 #include "dawn/common/BitSetIterator.h"
 #include "dawn/common/ityp_vector.h"
+#include "dawn/native/CacheKey.h"
 #include "dawn/native/vulkan/BindGroupVk.h"
 #include "dawn/native/vulkan/DescriptorSetAllocator.h"
 #include "dawn/native/vulkan/DeviceVk.h"
@@ -115,6 +116,9 @@
         createInfo.bindingCount = static_cast<uint32_t>(bindings.size());
         createInfo.pBindings = bindings.data();
 
+        // Record cache key information now since the createInfo is not stored.
+        GetCacheKey()->Record(createInfo);
+
         Device* device = ToBackend(GetDevice());
         DAWN_TRY(CheckVkSuccess(device->fn.CreateDescriptorSetLayout(
                                     device->GetVkDevice(), &createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/BindGroupLayoutVk.h b/src/dawn/native/vulkan/BindGroupLayoutVk.h
index 558ff7f..d8adedc 100644
--- a/src/dawn/native/vulkan/BindGroupLayoutVk.h
+++ b/src/dawn/native/vulkan/BindGroupLayoutVk.h
@@ -22,6 +22,10 @@
 
 #include <vector>
 
+namespace dawn::native {
+    class CacheKey;
+}  // namespace dawn::native
+
 namespace dawn::native::vulkan {
 
     class BindGroup;
diff --git a/src/dawn/native/vulkan/CacheKeyVk.cpp b/src/dawn/native/vulkan/CacheKeyVk.cpp
new file mode 100644
index 0000000..9fffbd1
--- /dev/null
+++ b/src/dawn/native/vulkan/CacheKeyVk.cpp
@@ -0,0 +1,102 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/vulkan/CacheKeyVk.h"
+
+#include <cstring>
+
+namespace dawn::native {
+
+    template <>
+    void CacheKeySerializer<VkDescriptorSetLayoutBinding>::Serialize(
+        CacheKey* key,
+        const VkDescriptorSetLayoutBinding& t) {
+        // pImmutableSamplers is not recorded. NOTE(review): this assumes these layouts never use
+        // immutable samplers; confirm before relying on it.
+        key->Record(t.binding, t.descriptorType, t.descriptorCount, t.stageFlags);
+    }
+
+    template <>
+    void CacheKeySerializer<VkDescriptorSetLayoutCreateInfo>::Serialize(
+        CacheKey* key,
+        const VkDescriptorSetLayoutCreateInfo& t) {
+        key->Record(t.flags).RecordIterable(t.pBindings, t.bindingCount);
+        vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+    }
+
+    template <>
+    void CacheKeySerializer<VkPushConstantRange>::Serialize(CacheKey* key,
+                                                            const VkPushConstantRange& t) {
+        key->Record(t.stageFlags, t.offset, t.size);
+    }
+
+    template <>
+    void CacheKeySerializer<VkPipelineLayoutCreateInfo>::Serialize(
+        CacheKey* key,
+        const VkPipelineLayoutCreateInfo& t) {
+        // The set layouts are not serialized here because they are pointers to backend objects.
+        // They need to be cross-referenced with the frontend objects and serialized from there.
+        key->Record(t.flags).RecordIterable(t.pPushConstantRanges, t.pushConstantRangeCount);
+        vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+    }
+
+    template <>
+    void CacheKeySerializer<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>::Serialize(
+        CacheKey* key,
+        const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT& t) {
+        key->Record(t.requiredSubgroupSize);
+    }
+
+    template <>
+    void CacheKeySerializer<VkSpecializationMapEntry>::Serialize(
+        CacheKey* key,
+        const VkSpecializationMapEntry& t) {
+        key->Record(t.constantID, t.offset, t.size);
+    }
+
+    template <>
+    void CacheKeySerializer<VkSpecializationInfo>::Serialize(CacheKey* key,
+                                                             const VkSpecializationInfo& t) {
+        key->RecordIterable(t.pMapEntries, t.mapEntryCount)
+            .RecordIterable(static_cast<const uint8_t*>(t.pData), t.dataSize);
+    }
+
+    template <>
+    void CacheKeySerializer<VkPipelineShaderStageCreateInfo>::Serialize(
+        CacheKey* key,
+        const VkPipelineShaderStageCreateInfo& t) {
+        // The shader module is not serialized here because it is a pointer to a backend object.
+        key->Record(t.flags, t.stage)
+            .RecordIterable(t.pName, std::strlen(t.pName))
+            .Record(t.pSpecializationInfo);
+        vulkan::SerializePnext<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(
+            key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+    }
+
+    template <>
+    void CacheKeySerializer<VkComputePipelineCreateInfo>::Serialize(
+        CacheKey* key,
+        const VkComputePipelineCreateInfo& t) {
+        // The pipeline layout is not serialized here because it is a pointer to a backend object.
+        // It needs to be cross-referenced with the frontend objects and serialized from there. The
+        // base pipeline information is also currently not recorded since we do not use them in our
+        // backend implementation. If we decide to use them later on, they also need to be
+        // cross-referenced from the frontend.
+        key->Record(t.flags, t.stage);
+        // Like the other create info serializers, assert that nothing is chained to this struct so
+        // that a newly chained extension cannot silently be left out of the cache key.
+        vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
+    }
+
+}  // namespace dawn::native
diff --git a/src/dawn/native/vulkan/CacheKeyVk.h b/src/dawn/native/vulkan/CacheKeyVk.h
new file mode 100644
index 0000000..ab8e02d
--- /dev/null
+++ b/src/dawn/native/vulkan/CacheKeyVk.h
@@ -0,0 +1,82 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_VULKAN_CACHEKEYVK_H_
+#define DAWNNATIVE_VULKAN_CACHEKEYVK_H_
+
+#include "dawn/common/Assert.h"
+#include "dawn/common/vulkan_platform.h"
+#include "dawn/native/CacheKey.h"
+
+#include "vulkan-tools/src/icd/generated/vk_typemap_helper.h"
+
+namespace dawn::native::vulkan {
+
+    namespace detail {
+
+        // Asserts that every struct in the pNext chain of `root` has an sType matching exactly
+        // one of VK_STRUCT_TYPES; with an empty pack this asserts that the chain is empty.
+        template <typename... VK_STRUCT_TYPES>
+        void ValidatePnextImpl(const VkBaseOutStructure* root) {
+            const VkBaseOutStructure* next =
+                reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
+            while (next != nullptr) {
+                // Assert that the type of each pNext struct is exactly one of the specified
+                // templates.
+                ASSERT(((LvlTypeMap<VK_STRUCT_TYPES>::kSType == next->sType ? 1 : 0) + ... + 0) ==
+                       1);
+                next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
+            }
+        }
+
+        // Records the struct of type VK_STRUCT_TYPE from the pNext chain of `root`, if any.
+        // The pointer is always recorded (null or not) so that the presence of each type is
+        // part of the key; otherwise chains holding different types could serialize to the
+        // same bytes.
+        template <typename VK_STRUCT_TYPE>
+        void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
+            const VkBaseOutStructure* next =
+                reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
+            const VK_STRUCT_TYPE* found = nullptr;
+            while (next != nullptr) {
+                if (LvlTypeMap<VK_STRUCT_TYPE>::kSType == next->sType) {
+                    // The chain must not contain more than one struct of the same type. If
+                    // asserts are disabled, the first match is kept.
+                    ASSERT(found == nullptr);
+                    if (found == nullptr) {
+                        found = reinterpret_cast<const VK_STRUCT_TYPE*>(next);
+                    }
+                }
+                next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
+            }
+            key->Record(found);
+        }
+
+    }  // namespace detail
+
+    // Serializes the pNext chain of `root` into `key`. Every struct type that may appear in the
+    // chain must be listed in VK_STRUCT_TYPES, otherwise an assert fires; SerializePnext<>()
+    // therefore asserts that the chain is empty. This ensures failures occur if new extensions
+    // are chained without updating the cache key serialization.
+    template <typename... VK_STRUCT_TYPES>
+    void SerializePnext(CacheKey* key, const VkBaseOutStructure* root) {
+        detail::ValidatePnextImpl<VK_STRUCT_TYPES...>(root);
+        // Serialize each listed type in order. The comma fold expands to nothing for an empty
+        // pack, so SerializePnext<>() needs no separate (header-defined) specialization.
+        (detail::SerializePnextImpl<VK_STRUCT_TYPES>(key, root), ...);
+    }
+
+}  // namespace dawn::native::vulkan
+
+#endif  // DAWNNATIVE_VULKAN_CACHEKEYVK_H_
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index fa13e26..68ac7d9 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -22,6 +22,8 @@
 #include "dawn/native/vulkan/UtilsVulkan.h"
 #include "dawn/native/vulkan/VulkanError.h"
 
+#include <tuple>
+
 namespace dawn::native::vulkan {
 
     // static
@@ -46,10 +48,11 @@
         createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
         // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
         const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
-        DAWN_TRY_ASSIGN(createInfo.stage.module,
-                        ToBackend(computeStage.module.Get())
-                            ->GetTransformedModuleHandle(computeStage.entryPoint.c_str(),
-                                                         ToBackend(GetLayout())));
+        ShaderModule* module = ToBackend(computeStage.module.Get());
+        PipelineLayout* layout = ToBackend(GetLayout());
+        const ShaderModule::Spirv* spirv = nullptr;
+        DAWN_TRY_ASSIGN((std::tie(createInfo.stage.module, spirv)),
+                        module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
 
         createInfo.stage.pName = computeStage.entryPoint.c_str();
 
@@ -74,6 +77,11 @@
                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
         }
 
+        // Record cache key information now since the createInfo is not stored. The layout and
+        // SPIR-V are recorded explicitly because the createInfo serializer skips the pipeline
+        // layout and shader module, which are pointers to backend objects.
+        GetCacheKey()->Record(createInfo, layout).RecordIterable(*spirv);
+
         DAWN_TRY(CheckVkSuccess(
             device->fn.CreateComputePipelines(device->GetVkDevice(), ::VK_NULL_HANDLE, 1,
                                               &createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/PipelineLayoutVk.cpp b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
index 245f2c9..e6653bd 100644
--- a/src/dawn/native/vulkan/PipelineLayoutVk.cpp
+++ b/src/dawn/native/vulkan/PipelineLayoutVk.cpp
@@ -38,8 +38,11 @@
         // this constraints at the Dawn level?
         uint32_t numSetLayouts = 0;
         std::array<VkDescriptorSetLayout, kMaxBindGroups> setLayouts;
+        std::array<const CachedObject*, kMaxBindGroups> cachedObjects;
         for (BindGroupIndex setIndex : IterateBitSet(GetBindGroupLayoutsMask())) {
-            setLayouts[numSetLayouts] = ToBackend(GetBindGroupLayout(setIndex))->GetHandle();
+            const BindGroupLayoutBase* bindGroupLayout = GetBindGroupLayout(setIndex);
+            setLayouts[numSetLayouts] = ToBackend(bindGroupLayout)->GetHandle();
+            cachedObjects[numSetLayouts] = bindGroupLayout;
             numSetLayouts++;
         }
 
@@ -52,6 +55,9 @@
         createInfo.pushConstantRangeCount = 0;
         createInfo.pPushConstantRanges = nullptr;
 
+        // Record cache key information now since the createInfo is not stored.
+        GetCacheKey()->RecordIterable(cachedObjects.data(), numSetLayouts).Record(createInfo);
+
         Device* device = ToBackend(GetDevice());
         DAWN_TRY(CheckVkSuccess(
             device->fn.CreatePipelineLayout(device->GetVkDevice(), &createInfo, nullptr, &*mHandle),
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index 4f30496..405da49 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -332,6 +332,7 @@
 
     MaybeError RenderPipeline::Initialize() {
         Device* device = ToBackend(GetDevice());
+        PipelineLayout* layout = ToBackend(GetLayout());
 
         // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
         std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
@@ -344,10 +345,11 @@
             VkPipelineShaderStageCreateInfo shaderStage;
 
             const ProgrammableStage& programmableStage = GetStage(stage);
-            DAWN_TRY_ASSIGN(shaderStage.module,
-                            ToBackend(programmableStage.module)
-                                ->GetTransformedModuleHandle(programmableStage.entryPoint.c_str(),
-                                                             ToBackend(GetLayout())));
+            ShaderModule* module = ToBackend(programmableStage.module.Get());
+            const ShaderModule::Spirv* spirv;
+            DAWN_TRY_ASSIGN(
+                std::tie(shaderStage.module, spirv),
+                module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
 
             shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
             shaderStage.pNext = nullptr;
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 9b8c291..d56e599 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -38,35 +38,40 @@
     ShaderModule::ConcurrentTransformedShaderModuleCache::
         ~ConcurrentTransformedShaderModuleCache() {
         std::lock_guard<std::mutex> lock(mMutex);
-        for (const auto& [_, module] : mTransformedShaderModuleCache) {
-            mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
+        for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
+            mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first);
         }
     }
 
-    VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule(
+    std::optional<ShaderModule::ModuleAndSpirv>
+    ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
         const PipelineLayoutEntryPointPair& key) {
         std::lock_guard<std::mutex> lock(mMutex);
         auto iter = mTransformedShaderModuleCache.find(key);
         if (iter != mTransformedShaderModuleCache.end()) {
-            auto cached = iter->second;
-            return cached;
+            return std::make_pair(iter->second.first, iter->second.second.get());
         }
-        return VK_NULL_HANDLE;
+        return {};
     }
 
-    VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule(
+    ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
         const PipelineLayoutEntryPointPair& key,
-        VkShaderModule value) {
-        ASSERT(value != VK_NULL_HANDLE);
+        VkShaderModule module,
+        std::vector<uint32_t>&& spirv) {
+        ASSERT(module != VK_NULL_HANDLE);
         std::lock_guard<std::mutex> lock(mMutex);
         auto iter = mTransformedShaderModuleCache.find(key);
         if (iter == mTransformedShaderModuleCache.end()) {
-            mTransformedShaderModuleCache.emplace(key, value);
-            return value;
+            // Move the SPIR-V into the cache instead of copying it, and keep the iterator to the
+            // new entry so that no further lookup is needed.
+            Entry entry(module, std::make_unique<Spirv>(std::move(spirv)));
+            iter = mTransformedShaderModuleCache.emplace(key, std::move(entry)).first;
         } else {
-            mDevice->GetFencedDeleter()->DeleteWhenUnused(value);
-            return iter->second;
+            // An entry for this key already exists (e.g. added after the caller's Find()), so
+            // release the redundant module and return the cached entry instead.
+            mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
         }
+        return std::make_pair(iter->second.first, iter->second.second.get());
     }
 
     // static
@@ -109,25 +112,24 @@
 
     ShaderModule::~ShaderModule() = default;
 
-    ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
+    ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
         const char* entryPointName,
         PipelineLayout* layout) {
-        TRACE_EVENT0(GetDevice()->GetPlatform(), General,
-                     "ShaderModuleVk::GetTransformedModuleHandle");
+        TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
 
         // If the shader was destroyed, we should never call this function.
         ASSERT(IsAlive());
 
         ScopedTintICEHandler scopedICEHandler(GetDevice());
 
+        // Check to see if we have the handle and spirv cached already.
         auto cacheKey = std::make_pair(layout, entryPointName);
-        VkShaderModule cachedShaderModule =
-            mTransformedShaderModuleCache->FindShaderModule(cacheKey);
-        if (cachedShaderModule != VK_NULL_HANDLE) {
-            return cachedShaderModule;
+        auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
+        if (handleAndSpirv.has_value()) {
+            return std::move(*handleAndSpirv);
         }
 
-        // Creation of VkShaderModule is deferred to this point when using tint generator
+        // Creation of module and spirv is deferred to this point when using tint generator
 
         // Remap BindingNumber to BindingIndex in WGSL shader
         using BindingRemapper = tint::transform::BindingRemapper;
@@ -207,7 +209,7 @@
         options.use_zero_initialize_workgroup_memory_extension =
             GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
 
-        std::vector<uint32_t> spirv;
+        Spirv spirv;
         {
             TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::spirv::Generate()");
             auto result = tint::writer::spirv::Generate(&program, options);
@@ -236,15 +238,17 @@
                                         device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
                                     "CreateShaderModule"));
         }
+        ModuleAndSpirv moduleAndSpirv;
         if (newHandle != VK_NULL_HANDLE) {
-            newHandle =
-                mTransformedShaderModuleCache->AddOrGetCachedShaderModule(cacheKey, newHandle);
+            moduleAndSpirv =
+                mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, std::move(spirv));
         }
 
         SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE,
-                     reinterpret_cast<uint64_t&>(newHandle), "Dawn_ShaderModule", GetLabel());
+                     reinterpret_cast<uint64_t&>(moduleAndSpirv.first), "Dawn_ShaderModule",
+                     GetLabel());
 
-        return newHandle;
+        return std::move(moduleAndSpirv);
     }
 
 }  // namespace dawn::native::vulkan
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index 7040b74..3b69b75 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -20,7 +20,11 @@
 #include "dawn/common/vulkan_platform.h"
 #include "dawn/native/Error.h"
 
+#include <memory>
 #include <mutex>
+#include <optional>
+#include <utility>
+#include <vector>
 
 namespace dawn::native::vulkan {
 
@@ -29,12 +33,20 @@
 
     class ShaderModule final : public ShaderModuleBase {
       public:
+        // SPIR-V words generated for a transformed shader module.
+        using Spirv = std::vector<uint32_t>;
+        // A shader module handle paired with its SPIR-V. The SPIR-V is owned by the shader
+        // module's transformed module cache; this pointer is non-owning.
+        using ModuleAndSpirv = std::pair<VkShaderModule, const Spirv*>;
+
         static ResultOrError<Ref<ShaderModule>> Create(Device* device,
                                                        const ShaderModuleDescriptor* descriptor,
                                                        ShaderModuleParseResult* parseResult);
 
-        ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName,
-                                                                 PipelineLayout* layout);
+        // Returns the module handle and SPIR-V for `entryPointName` transformed for `layout`,
+        // generating and caching them on first use.
+        ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
+                                                        PipelineLayout* layout);
 
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
@@ -42,20 +54,27 @@
         MaybeError Initialize(ShaderModuleParseResult* parseResult);
         void DestroyImpl() override;
 
-        // New handles created by GetTransformedModuleHandle at pipeline creation time
+        // Caches the handles and SPIR-V created by GetHandleAndSpirv at pipeline creation time.
         class ConcurrentTransformedShaderModuleCache {
           public:
             explicit ConcurrentTransformedShaderModuleCache(Device* device);
             ~ConcurrentTransformedShaderModuleCache();
-            VkShaderModule FindShaderModule(const PipelineLayoutEntryPointPair& key);
-            VkShaderModule AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair& key,
-                                                      VkShaderModule value);
+
+            // Returns the cached module and SPIR-V for `key`, or std::nullopt if absent.
+            std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
+            // Caches `module` and `spirv` under `key` unless an entry already exists, in which
+            // case `module` is released instead. Returns the entry stored in the cache.
+            ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
+                                    VkShaderModule module,
+                                    Spirv&& spirv);
 
           private:
+            using Entry = std::pair<VkShaderModule, std::unique_ptr<Spirv>>;
+
             Device* mDevice;
             std::mutex mMutex;
             std::unordered_map<PipelineLayoutEntryPointPair,
-                               VkShaderModule,
+                               Entry,
                                PipelineLayoutEntryPointPairHashFunc>
                 mTransformedShaderModuleCache;
         };
diff --git a/src/dawn/tests/unittests/native/CacheKeyTests.cpp b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
index 45fd360..009b7b6 100644
--- a/src/dawn/tests/unittests/native/CacheKeyTests.cpp
+++ b/src/dawn/tests/unittests/native/CacheKeyTests.cpp
@@ -51,7 +51,7 @@
 
         // Matcher to compare CacheKeys for easier testing.
         MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
-            return memcmp(arg.data(), key.data(), arg.size()) == 0;
+            return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0;
         }
 
         TEST(CacheKeyTests, RecordSingleMember) {