D3D12: Support caching DX shaders.

This change is a prerequisite to D3D pipeline caching.

This change introduces:
- Caching interface which enables the cache.
- Helper for backends to load/store blobs to be cached.
- Ability to cache HLSL shaders.

Bug:dawn:549
Change-Id: I2af759882d18b3f45dc63e49dcb6a3caa1be3485
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32305
Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 513fcd4..9c9c8ab 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -223,6 +223,8 @@
     "PassResourceUsageTracker.h",
     "PerStage.cpp",
     "PerStage.h",
+    "PersistentCache.cpp",
+    "PersistentCache.h",
     "Pipeline.cpp",
     "Pipeline.h",
     "PipelineLayout.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index e290aaa..4321a44 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -108,6 +108,8 @@
     "PassResourceUsage.h"
     "PassResourceUsageTracker.cpp"
     "PassResourceUsageTracker.h"
+    "PersistentCache.cpp"
+    "PersistentCache.h"
     "PerStage.cpp"
     "PerStage.h"
     "Pipeline.cpp"
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 0c493b6..2c8468f 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -31,6 +31,7 @@
 #include "dawn_native/Fence.h"
 #include "dawn_native/Instance.h"
 #include "dawn_native/InternalPipelineStore.h"
+#include "dawn_native/PersistentCache.h"
 #include "dawn_native/PipelineLayout.h"
 #include "dawn_native/QuerySet.h"
 #include "dawn_native/Queue.h"
@@ -132,6 +133,7 @@
         mCreateReadyPipelineTracker = std::make_unique<CreateReadyPipelineTracker>(this);
         mDeprecationWarnings = std::make_unique<DeprecationWarnings>();
         mInternalPipelineStore = std::make_unique<InternalPipelineStore>();
+        mPersistentCache = std::make_unique<PersistentCache>(this);
 
         // Starting from now the backend can start doing reentrant calls so the device is marked as
         // alive.
@@ -196,6 +198,7 @@
         mErrorScopeTracker = nullptr;
         mDynamicUploader = nullptr;
         mCreateReadyPipelineTracker = nullptr;
+        mPersistentCache = nullptr;
 
         mEmptyBindGroupLayout = nullptr;
 
@@ -299,6 +302,11 @@
         return mCurrentErrorScope.Get();
     }
 
+    PersistentCache* DeviceBase::GetPersistentCache() {
+        ASSERT(mPersistentCache.get() != nullptr);
+        return mPersistentCache.get();
+    }
+
     MaybeError DeviceBase::ValidateObject(const ObjectBase* object) const {
         ASSERT(object != nullptr);
         if (DAWN_UNLIKELY(object->GetDevice() != this)) {
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 6408a90..08a4c80 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -37,6 +37,7 @@
     class DynamicUploader;
     class ErrorScope;
     class ErrorScopeTracker;
+    class PersistentCache;
     class StagingBufferBase;
     struct InternalPipelineStore;
 
@@ -180,6 +181,8 @@
 
         ErrorScope* GetCurrentErrorScope();
 
+        PersistentCache* GetPersistentCache();
+
         void Reference();
         void Release();
 
@@ -388,6 +391,8 @@
         ExtensionsSet mEnabledExtensions;
 
         std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
+
+        std::unique_ptr<PersistentCache> mPersistentCache;
     };
 
 }  // namespace dawn_native
diff --git a/src/dawn_native/PersistentCache.cpp b/src/dawn_native/PersistentCache.cpp
new file mode 100644
index 0000000..fbb1ece
--- /dev/null
+++ b/src/dawn_native/PersistentCache.cpp
@@ -0,0 +1,65 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/PersistentCache.h"
+
+#include "common/Assert.h"
+#include "dawn_native/Device.h"
+#include "dawn_platform/DawnPlatform.h"
+
+namespace dawn_native {
+
+    PersistentCache::PersistentCache(DeviceBase* device)
+        : mDevice(device), mCache(GetPlatformCache()) {
+    }
+
+    ScopedCachedBlob PersistentCache::LoadData(const PersistentCacheKey& key) {
+        ScopedCachedBlob blob = {};
+        if (mCache == nullptr) {  // No platform cache: always report a miss (empty blob).
+            return blob;
+        }
+        blob.bufferSize = mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(),
+                                           key.size(), nullptr, 0);  // Size query only.
+        if (blob.bufferSize > 0) {  // Hit: fetch the value into a buffer of that size.
+            blob.buffer.reset(new uint8_t[blob.bufferSize]);
+            const size_t bufferSize =
+                mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(),
+                                 blob.buffer.get(), blob.bufferSize);
+            ASSERT(bufferSize == blob.bufferSize);  // Must match the size queried above.
+            return blob;
+        }
+        return blob;  // Miss: empty blob (null buffer, zero size).
+    }
+
+    void PersistentCache::StoreData(const PersistentCacheKey& key, const void* value, size_t size) {
+        if (mCache == nullptr) {
+            return;
+        }
+        ASSERT(value != nullptr);
+        ASSERT(size > 0);
+        mCache->StoreData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(), value,
+                          size);
+    }
+
+    dawn_platform::CachingInterface* PersistentCache::GetPlatformCache() {
+        // TODO(dawn:549): Create a fingerprint of concatenated version strings (ex. Tint commit
+        // hash, Dawn commit hash). This will be used by the client so it may know when to discard
+        // previously cached Dawn objects should this fingerprint change.
+        dawn_platform::Platform* platform = mDevice->GetPlatform();
+        if (platform != nullptr) {
+            return platform->GetCachingInterface(/*fingerprint*/ nullptr, /*fingerprintSize*/ 0);
+        }
+        return nullptr;
+    }
+}  // namespace dawn_native
\ No newline at end of file
diff --git a/src/dawn_native/PersistentCache.h b/src/dawn_native/PersistentCache.h
new file mode 100644
index 0000000..5e9dbc0
--- /dev/null
+++ b/src/dawn_native/PersistentCache.h
@@ -0,0 +1,86 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_PERSISTENTCACHE_H_
+#define DAWNNATIVE_PERSISTENTCACHE_H_
+
+#include "dawn_native/Error.h"
+
+#include <vector>
+
+namespace dawn_platform {
+    class CachingInterface;
+}
+
+namespace dawn_native {
+
+    using PersistentCacheKey = std::vector<uint8_t>;
+
+    struct ScopedCachedBlob {
+        std::unique_ptr<uint8_t[]> buffer;
+        size_t bufferSize = 0;
+    };
+
+    class DeviceBase;
+
+    enum class PersistentKeyType { Shader };
+
+    class PersistentCache {
+      public:
+        PersistentCache(DeviceBase* device);
+
+        // Combines load/store operations into a single call.
+        // If the load was successful, a non-empty blob is returned to the caller.
+        // Else, the creation callback |createFn| gets invoked with a callback |doCache| to store
+        // the newly created blob back in the cache, and an empty blob is returned to the caller.
+        //
+        // Example usage:
+        // ScopedCachedBlob cachedBlob = {};
+        // DAWN_TRY_ASSIGN(cachedBlob, GetOrCreate(key, [&](auto doCache) -> MaybeError {
+        //      // Create a new blob to be stored
+        //      doCache(newBlobPtr, newBlobSize); // store
+        //      return {};
+        // }));
+        //
+        template <typename CreateFn>
+        ResultOrError<ScopedCachedBlob> GetOrCreate(const PersistentCacheKey& key,
+                                                    CreateFn&& createFn) {
+            // Attempt to load an existing blob from the cache.
+            ScopedCachedBlob blob = LoadData(key);
+            if (blob.bufferSize > 0) {
+                return std::move(blob);
+            }
+
+            // Allow the caller to create a new blob to be stored for the given key.
+            DAWN_TRY(createFn([this, key](const void* value, size_t size) {
+                this->StoreData(key, value, size);
+            }));
+
+            return std::move(blob);  // Cache miss: |blob| is still empty here.
+        }
+
+      private:
+        // PersistentCache impl
+        ScopedCachedBlob LoadData(const PersistentCacheKey& key);
+        void StoreData(const PersistentCacheKey& key, const void* value, size_t size);
+
+        dawn_platform::CachingInterface* GetPlatformCache();
+
+        DeviceBase* mDevice = nullptr;
+
+        dawn_platform::CachingInterface* mCache = nullptr;
+    };
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_PERSISTENTCACHE_H_
\ No newline at end of file
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 68acce9..88aa240 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -43,44 +43,14 @@
 
         ShaderModule* module = ToBackend(descriptor->computeStage.module);
 
-        const char* entryPoint = descriptor->computeStage.entryPoint;
-        std::string remappedEntryPoint;
-        std::string hlslSource;
-        if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
-            DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSLWithTint(
-                                            entryPoint, SingleShaderStage::Compute,
-                                            ToBackend(GetLayout()), &remappedEntryPoint));
-            entryPoint = remappedEntryPoint.c_str();
-
-        } else {
-            DAWN_TRY_ASSIGN(hlslSource,
-                            module->TranslateToHLSLWithSPIRVCross(
-                                entryPoint, SingleShaderStage::Compute, ToBackend(GetLayout())));
-
-            // Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
-            entryPoint = "main";
-        }
-
         D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
         d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
 
-        ComPtr<IDxcBlob> compiledDXCShader;
-        ComPtr<ID3DBlob> compiledFXCShader;
-        if (device->IsToggleEnabled(Toggle::UseDXC)) {
-            DAWN_TRY_ASSIGN(compiledDXCShader,
-                            CompileShaderDXC(device, SingleShaderStage::Compute, hlslSource,
-                                             entryPoint, compileFlags));
-
-            d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
-            d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
-        } else {
-            DAWN_TRY_ASSIGN(compiledFXCShader,
-                            CompileShaderFXC(device, SingleShaderStage::Compute, hlslSource,
-                                             entryPoint, compileFlags));
-            d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
-            d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
-        }
-
+        CompiledShader compiledShader;
+        DAWN_TRY_ASSIGN(compiledShader, module->Compile(descriptor->computeStage.entryPoint,
+                                                        SingleShaderStage::Compute,
+                                                        ToBackend(GetLayout()), compileFlags));
+        d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
         device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
                                                              IID_PPV_ARGS(&mPipelineState));
         return {};
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 5501308..f9a2b3f 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -306,44 +306,13 @@
         shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
         shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
 
-        PerStage<ComPtr<ID3DBlob>> compiledFXCShader;
-        PerStage<ComPtr<IDxcBlob>> compiledDXCShader;
-
+        PerStage<CompiledShader> compiledShader;
         wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
         for (auto stage : IterateStages(renderStages)) {
-            std::string hlslSource;
-            const char* entryPoint = GetStage(stage).entryPoint.c_str();
-            std::string remappedEntryPoint;
-
-            if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
-                DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithTint(
-                                                entryPoint, stage, ToBackend(GetLayout()),
-                                                &remappedEntryPoint));
-                entryPoint = remappedEntryPoint.c_str();
-
-            } else {
-                DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithSPIRVCross(
-                                                entryPoint, stage, ToBackend(GetLayout())));
-
-                // Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
-                entryPoint = "main";
-            }
-
-            if (device->IsToggleEnabled(Toggle::UseDXC)) {
-                DAWN_TRY_ASSIGN(
-                    compiledDXCShader[stage],
-                    CompileShaderDXC(device, stage, hlslSource, entryPoint, compileFlags));
-
-                shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
-                shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
-            } else {
-                DAWN_TRY_ASSIGN(
-                    compiledFXCShader[stage],
-                    CompileShaderFXC(device, stage, hlslSource, entryPoint, compileFlags));
-
-                shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
-                shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
-            }
+            DAWN_TRY_ASSIGN(compiledShader[stage],
+                            modules[stage]->Compile(entryPoints[stage], stage,
+                                                    ToBackend(GetLayout()), compileFlags));
+            *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
         }
 
         PipelineLayout* layout = ToBackend(GetLayout());
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index b0db7c3..7b10fc0 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -309,4 +309,102 @@
         return compiler.compile();
     }
 
+    ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
+                                                        SingleShaderStage stage,
+                                                        PipelineLayout* layout,
+                                                        uint32_t compileFlags) {
+        Device* device = ToBackend(GetDevice());
+
+        // Translate the source shader to HLSL.
+        std::string hlslSource;
+        std::string remappedEntryPoint;
+        if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+            DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSLWithTint(entryPointName, stage, layout,
+                                                                &remappedEntryPoint));
+            entryPointName = remappedEntryPoint.c_str();
+        } else {
+            DAWN_TRY_ASSIGN(hlslSource,
+                            TranslateToHLSLWithSPIRVCross(entryPointName, stage, layout));
+
+            // Note that the HLSL will always use entryPoint "main" under
+            // SPIRV-cross.
+            entryPointName = "main";
+        }
+
+        // Use HLSL source as the input for the key since it already reflects the pipeline layout
+        // it was generated with. The pipeline layout would only be required if we keyed from WGSL:
+        // two different pipeline layouts could produce different shader blobs, and the wrong blob
+        // could be loaded if the pipeline layout was missing from the key.
+        // TODO(dawn:549): Consider keying from WGSL and serializing the pipeline layout it used.
+        const PersistentCacheKey& shaderCacheKey =
+            CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags);
+
+        CompiledShader compiledShader = {};
+        DAWN_TRY_ASSIGN(compiledShader.cachedShader,
+                        device->GetPersistentCache()->GetOrCreate(
+                            shaderCacheKey, [&](auto doCache) -> MaybeError {
+                                if (device->IsToggleEnabled(Toggle::UseDXC)) {
+                                    DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
+                                                    CompileShaderDXC(device, stage, hlslSource,
+                                                                     entryPointName, compileFlags));
+                                } else {
+                                    DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
+                                                    CompileShaderFXC(device, stage, hlslSource,
+                                                                     entryPointName, compileFlags));
+                                }
+                                const D3D12_SHADER_BYTECODE shader =
+                                    compiledShader.GetD3D12ShaderBytecode();
+                                doCache(shader.pShaderBytecode, shader.BytecodeLength);
+                                return {};
+                            }));
+
+        return std::move(compiledShader);  // Holds the cached blob or a fresh FXC/DXC blob.
+    }
+
+    D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
+        if (cachedShader.buffer != nullptr) {  // Checked first: blob loaded from the cache.
+            return {cachedShader.buffer.get(), cachedShader.bufferSize};
+        } else if (compiledFXCShader != nullptr) {  // Otherwise a blob compiled with FXC.
+            return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
+        } else if (compiledDXCShader != nullptr) {  // Otherwise a blob compiled with DXC.
+            return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
+        }
+        UNREACHABLE();  // None of the three representations is populated.
+        return {};
+    }
+
+    PersistentCacheKey ShaderModule::CreateHLSLKey(const char* entryPointName,
+                                                   SingleShaderStage stage,
+                                                   const std::string& hlslSource,
+                                                   uint32_t compileFlags) const {
+        std::stringstream stream;
+
+        // Prefix the key with the type to avoid collisions from another type that could have the
+        // same key.
+        stream << static_cast<uint32_t>(PersistentKeyType::Shader);
+
+        // Add "guard" strings that the user cannot provide to help ensure the generated HLSL
+        // used to create this key is not being manufactured by the user to load the wrong shader
+        // blob.
+        // These strings can be HLSL comments because Tint does not emit HLSL comments.
+        // TODO(dawn:549): Replace guard strings with something more secure.
+        ASSERT(hlslSource.find("//") == std::string::npos);
+
+        stream << "// Start shader autogenerated by Dawn.";
+        stream << hlslSource;
+        stream << "// End of shader autogenerated by Dawn.";
+
+        stream << compileFlags;
+
+        // TODO(dawn:549): add the HLSL compiler version for good measure.
+
+        // If the source contains multiple entry points, ensure they are cached separately
+        // per stage since DX shader code can only be compiled per stage using the same
+        // entry point.
+        stream << static_cast<uint32_t>(stage);
+        stream << entryPointName;
+
+        return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
+                                  std::istreambuf_iterator<char>{});
+    }
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 4fc2532..63cb2e9 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -15,6 +15,7 @@
 #ifndef DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
 #define DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
 
+#include "dawn_native/PersistentCache.h"
 #include "dawn_native/ShaderModule.h"
 
 #include "dawn_native/d3d12/d3d12_platform.h"
@@ -24,22 +25,28 @@
     class Device;
     class PipelineLayout;
 
-    ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
-                                                     SingleShaderStage stage,
-                                                     const std::string& hlslSource,
-                                                     const char* entryPoint,
-                                                     uint32_t compileFlags);
-    ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
-                                                     SingleShaderStage stage,
-                                                     const std::string& hlslSource,
-                                                     const char* entryPoint,
-                                                     uint32_t compileFlags);
+    // Manages a ref to one of the various representations of shader blobs.
+    struct CompiledShader {
+        ScopedCachedBlob cachedShader;
+        ComPtr<ID3DBlob> compiledFXCShader;
+        ComPtr<IDxcBlob> compiledDXCShader;
+        D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
+    };
 
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
                                                    const ShaderModuleDescriptor* descriptor);
 
+        ResultOrError<CompiledShader> Compile(const char* entryPointName,
+                                              SingleShaderStage stage,
+                                              PipelineLayout* layout,
+                                              uint32_t compileFlags);
+
+      private:
+        ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
+        ~ShaderModule() override = default;
+
         ResultOrError<std::string> TranslateToHLSLWithTint(
             const char* entryPointName,
             SingleShaderStage stage,
@@ -50,9 +57,10 @@
                                                                  SingleShaderStage stage,
                                                                  PipelineLayout* layout) const;
 
-      private:
-        ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
-        ~ShaderModule() override = default;
+        PersistentCacheKey CreateHLSLKey(const char* entryPointName,
+                                         SingleShaderStage stage,
+                                         const std::string& hlslSource,
+                                         uint32_t compileFlags) const;
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_platform/BUILD.gn b/src/dawn_platform/BUILD.gn
index d26a859..91c9e75 100644
--- a/src/dawn_platform/BUILD.gn
+++ b/src/dawn_platform/BUILD.gn
@@ -31,4 +31,9 @@
   ]
 
   deps = [ "${dawn_root}/src/common" ]
+
+  public_deps = [
+    # DawnPlatform.h has #include <dawn/webgpu.h>
+    "${dawn_root}/src/dawn:dawn_headers",
+  ]
 }
diff --git a/src/dawn_platform/CMakeLists.txt b/src/dawn_platform/CMakeLists.txt
index 67b1b92..b8075e2 100644
--- a/src/dawn_platform/CMakeLists.txt
+++ b/src/dawn_platform/CMakeLists.txt
@@ -27,4 +27,4 @@
     "tracing/EventTracer.h"
     "tracing/TraceEvent.h"
 )
-target_link_libraries(dawn_platform PRIVATE dawn_internal_config dawn_common)
+target_link_libraries(dawn_platform PUBLIC dawn_headers PRIVATE dawn_internal_config dawn_common)
diff --git a/src/dawn_platform/DawnPlatform.cpp b/src/dawn_platform/DawnPlatform.cpp
index 6fe61d5..6b71708 100644
--- a/src/dawn_platform/DawnPlatform.cpp
+++ b/src/dawn_platform/DawnPlatform.cpp
@@ -14,10 +14,45 @@
 
 #include "dawn_platform/DawnPlatform.h"
 
+#include "common/Assert.h"
+
 namespace dawn_platform {
 
+    CachingInterface::CachingInterface() = default;
+
+    CachingInterface::~CachingInterface() = default;
+
     Platform::Platform() = default;
 
     Platform::~Platform() = default;
 
+    const unsigned char* Platform::GetTraceCategoryEnabledFlag(TraceCategory category) {
+        static unsigned char disabled = 0;
+        return &disabled;
+    }
+
+    double Platform::MonotonicallyIncreasingTime() {
+        return 0;
+    }
+
+    uint64_t Platform::AddTraceEvent(char phase,
+                                     const unsigned char* categoryGroupEnabled,
+                                     const char* name,
+                                     uint64_t id,
+                                     double timestamp,
+                                     int numArgs,
+                                     const char** argNames,
+                                     const unsigned char* argTypes,
+                                     const uint64_t* argValues,
+                                     unsigned char flags) {
+        // AddTraceEvent cannot be called if events are disabled.
+        ASSERT(false);
+        return 0;
+    }
+
+    dawn_platform::CachingInterface* Platform::GetCachingInterface(const void* fingerprint,
+                                                                   size_t fingerprintSize) {
+        return nullptr;
+    }
+
 }  // namespace dawn_platform
\ No newline at end of file
diff --git a/src/include/dawn_platform/DawnPlatform.h b/src/include/dawn_platform/DawnPlatform.h
index 107b91e..4a00f53 100644
--- a/src/include/dawn_platform/DawnPlatform.h
+++ b/src/include/dawn_platform/DawnPlatform.h
@@ -17,8 +17,11 @@
 
 #include "dawn_platform/dawn_platform_export.h"
 
+#include <cstddef>
 #include <cstdint>
 
+#include <dawn/webgpu.h>
+
 namespace dawn_platform {
 
     enum class TraceCategory {
@@ -28,14 +31,43 @@
         GPUWork,     // Actual GPU work
     };
 
+    class DAWN_PLATFORM_EXPORT CachingInterface {
+      public:
+        CachingInterface();
+        virtual ~CachingInterface();
+
+        // LoadData has two modes. The first mode is used to get a value which
+        // corresponds to the |key|. The |valueOut| is a caller provided buffer
+        // allocated to the size |valueSize| which is loaded with data of the
+        // size returned. The second mode is used to query for the existence of
+        // the |key| where |valueOut| is nullptr and |valueSize| must be 0.
+        // The return size is non-zero if the |key| exists.
+        virtual size_t LoadData(const WGPUDevice device,
+                                const void* key,
+                                size_t keySize,
+                                void* valueOut,
+                                size_t valueSize) = 0;
+
+        // StoreData puts a |value| in the cache which corresponds to the |key|.
+        virtual void StoreData(const WGPUDevice device,
+                               const void* key,
+                               size_t keySize,
+                               const void* value,
+                               size_t valueSize) = 0;
+
+      private:
+        CachingInterface(const CachingInterface&) = delete;
+        CachingInterface& operator=(const CachingInterface&) = delete;
+    };
+
     class DAWN_PLATFORM_EXPORT Platform {
       public:
         Platform();
         virtual ~Platform();
 
-        virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category) = 0;
+        virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category);
 
-        virtual double MonotonicallyIncreasingTime() = 0;
+        virtual double MonotonicallyIncreasingTime();
 
         virtual uint64_t AddTraceEvent(char phase,
                                        const unsigned char* categoryGroupEnabled,
@@ -46,7 +78,13 @@
                                        const char** argNames,
                                        const unsigned char* argTypes,
                                        const uint64_t* argValues,
-                                       unsigned char flags) = 0;
+                                       unsigned char flags);
+
+        // The |fingerprint| is provided by Dawn to inform the client to discard the Dawn caches
+        // when the fingerprint changes. The returned CachingInterface is expected to outlive the
+        // device which uses it to persistently cache objects.
+        virtual CachingInterface* GetCachingInterface(const void* fingerprint,
+                                                      size_t fingerprintSize);
 
       private:
         Platform(const Platform&) = delete;
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index cfd4805..3fd7253 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -333,7 +333,10 @@
   libs = []
 
   if (dawn_enable_d3d12) {
-    sources += [ "end2end/D3D12ResourceWrappingTests.cpp" ]
+    sources += [
+      "end2end/D3D12CachingTests.cpp",
+      "end2end/D3D12ResourceWrappingTests.cpp",
+    ]
     libs += [
       "d3d11.lib",
       "dxgi.lib",
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index 7743875..ded6059 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -744,6 +744,10 @@
         mBackendAdapter = *it;
     }
 
+    // Set up the per-test platform. Tests can provide one by overriding CreateTestPlatform.
+    mTestPlatform = CreateTestPlatform();
+    gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get());
+
     // Create the device from the adapter
     for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) {
         ASSERT(gTestEnv->GetInstance()->GetToggleInfo(forceEnabledWorkaround) != nullptr);
@@ -1080,6 +1084,10 @@
     }
 }
 
+std::unique_ptr<dawn_platform::Platform> DawnTestBase::CreateTestPlatform() {
+    return nullptr;
+}
+
 bool RGBA8::operator==(const RGBA8& other) const {
     return r == other.r && g == other.g && b == other.b && a == other.a;
 }
diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h
index f2f3106..264f6f4 100644
--- a/src/tests/DawnTest.h
+++ b/src/tests/DawnTest.h
@@ -20,6 +20,7 @@
 #include "dawn/webgpu_cpp.h"
 #include "dawn_native/DawnNative.h"
 
+#include <dawn_platform/DawnPlatform.h>
 #include <gtest/gtest.h>
 
 #include <memory>
@@ -268,6 +269,8 @@
     wgpu::Instance GetInstance() const;
     dawn_native::Adapter GetAdapter() const;
 
+    virtual std::unique_ptr<dawn_platform::Platform> CreateTestPlatform();
+
   protected:
     wgpu::Device device;
     wgpu::Queue queue;
@@ -403,6 +406,8 @@
     void ResolveExpectations();
 
     dawn_native::Adapter mBackendAdapter;
+
+    std::unique_ptr<dawn_platform::Platform> mTestPlatform;
 };
 
 // Skip a test when the given condition is satisfied.
diff --git a/src/tests/end2end/D3D12CachingTests.cpp b/src/tests/end2end/D3D12CachingTests.cpp
new file mode 100644
index 0000000..27690b7
--- /dev/null
+++ b/src/tests/end2end/D3D12CachingTests.cpp
@@ -0,0 +1,287 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnTest.h"
+
+#include "utils/ComboRenderPipelineDescriptor.h"
+#include "utils/WGPUHelpers.h"
+
+#define EXPECT_CACHE_HIT(N, statement)              \
+    do {                                            \
+        size_t before = mPersistentCache.mHitCount; \
+        statement;                                  \
+        FlushWire();                                \
+        size_t after = mPersistentCache.mHitCount;  \
+        EXPECT_EQ(N, after - before);               \
+    } while (0)  // Expects |statement| to raise mPersistentCache.mHitCount by exactly N.
+
+// FakePersistentCache implements an in-memory persistent cache that counts successful lookups.
+class FakePersistentCache : public dawn_platform::CachingInterface {
+  public:
+    // dawn_platform::CachingInterface API
+    void StoreData(const WGPUDevice device,
+                   const void* key,
+                   size_t keySize,
+                   const void* value,
+                   size_t valueSize) override {
+        if (mIsDisabled)  // A disabled cache silently drops every store.
+            return;
+        const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
+
+        const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
+        std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
+
+        EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);  // Keys must be new.
+    }
+
+    size_t LoadData(const WGPUDevice device,
+                    const void* key,
+                    size_t keySize,
+                    void* value,
+                    size_t valueSize) override {
+        const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
+        auto entry = mCache.find(keyStr);
+        if (entry == mCache.end()) {
+            return 0;  // Miss: reported as a zero-sized blob and not counted as a hit.
+        }
+        if (valueSize >= entry->second.size()) {  // Copy only when the caller's buffer fits.
+            memcpy(value, entry->second.data(), entry->second.size());
+        }
+        mHitCount++;  // Counted for every found key, whether or not the blob was copied.
+        return entry->second.size();
+    }
+
+    using Blob = std::vector<uint8_t>;
+    using FakeCache = std::unordered_map<std::string, Blob>;
+
+    FakeCache mCache;  // Stored blobs, keyed by the raw key bytes.
+
+    size_t mHitCount = 0;      // Number of LoadData calls that found their key.
+    bool mIsDisabled = false;  // When true, StoreData is a no-op (LoadData is unaffected).
+};
+
+// Test platform that only supports caching: it hands out the caching interface it was given.
+class DawnTestPlatform : public dawn_platform::Platform {
+  public:
+    DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
+        : mCachingInterface(cachingInterface) {
+    }
+    ~DawnTestPlatform() override = default;
+
+    dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
+                                                         size_t fingerprintSize) override {
+        return mCachingInterface;  // The fingerprint is ignored; one cache serves all requests.
+    }
+
+    dawn_platform::CachingInterface* mCachingInterface = nullptr;  // Not owned.
+};
+
+class D3D12CachingTests : public DawnTest {
+  protected:
+    std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
+        return std::make_unique<DawnTestPlatform>(&mPersistentCache);  // Route caching to the fake.
+    }
+
+    FakePersistentCache mPersistentCache;  // In-memory cache the tests inspect.
+};
+
+// Test that duplicate WGSL re-compiles HLSL, with no stores or hits, when the cache is disabled.
+TEST_P(D3D12CachingTests, SameShaderNoCache) {
+    mPersistentCache.mIsDisabled = true;  // StoreData becomes a no-op.
+
+    wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+        [[builtin(position)]] var<out> Position : vec4<f32>;
+
+        [[stage(vertex)]]
+        fn vertex_main() -> void {
+            Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            return;
+        }
+
+        [[location(0)]] var<out> outColor : vec4<f32>;
+
+        [[stage(fragment)]]
+        fn fragment_main() -> void {
+          outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+          return;
+        }
+    )");
+
+    // Create the pipeline once; the disabled cache must not store anything.
+    {
+        utils::ComboRenderPipelineDescriptor desc(device);
+        desc.vertexStage.module = module;
+        desc.vertexStage.entryPoint = "vertex_main";
+        desc.cFragmentStage.module = module;
+        desc.cFragmentStage.entryPoint = "fragment_main";
+
+        EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
+
+    // Create the same pipeline again; there is nothing cached to hit.
+    {
+        utils::ComboRenderPipelineDescriptor desc(device);
+        desc.vertexStage.module = module;
+        desc.vertexStage.entryPoint = "vertex_main";
+        desc.cFragmentStage.module = module;
+        desc.cFragmentStage.entryPoint = "fragment_main";
+
+        EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
+}
+
+// Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
+// of HLSL shaders. The WGSL shader should result in caching 2 HLSL shaders (stage x
+// entrypoints).
+TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
+    wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+        [[builtin(position)]] var<out> Position : vec4<f32>;
+
+        [[stage(vertex)]]
+        fn vertex_main() -> void {
+            Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+            return;
+        }
+
+        [[location(0)]] var<out> outColor : vec4<f32>;
+
+        [[stage(fragment)]]
+        fn fragment_main() -> void {
+          outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+          return;
+        }
+    )");
+
+    // Store the HLSL shaders generated from the WGSL shader into the cache.
+    {
+        utils::ComboRenderPipelineDescriptor desc(device);
+        desc.vertexStage.module = module;
+        desc.vertexStage.entryPoint = "vertex_main";
+        desc.cFragmentStage.module = module;
+        desc.cFragmentStage.entryPoint = "fragment_main";
+
+        EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+    // Load the same WGSL shader from the cache.
+    {
+        utils::ComboRenderPipelineDescriptor desc(device);
+        desc.vertexStage.module = module;
+        desc.vertexStage.entryPoint = "vertex_main";
+        desc.cFragmentStage.module = module;
+        desc.cFragmentStage.entryPoint = "fragment_main";
+
+        // Each cached HLSL shader calls LoadData twice (once to peek, again to get), so check for
+        // 2 x 2 shaders = 4 hits.
+        EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+    // Modify the WGSL shader functions and make sure it doesn't hit.
+    wgpu::ShaderModule newModule = utils::CreateShaderModuleFromWGSL(device, R"(
+      [[builtin(position)]] var<out> Position : vec4<f32>;
+
+      [[stage(vertex)]]
+      fn vertex_main() -> void {
+          Position = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+          return;
+      }
+
+      [[location(0)]] var<out> outColor : vec4<f32>;
+
+      [[stage(fragment)]]
+      fn fragment_main() -> void {
+        outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+        return;
+      }
+  )");
+
+    {
+        utils::ComboRenderPipelineDescriptor desc(device);
+        desc.vertexStage.module = newModule;
+        desc.vertexStage.entryPoint = "vertex_main";
+        desc.cFragmentStage.module = newModule;
+        desc.cFragmentStage.entryPoint = "fragment_main";
+        EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+    }
+
+    // The modified shaders missed the cache and were stored as new entries, so the cache now
+    // holds the 2 original HLSL shaders plus 2 new ones.
+    EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
+}
+
+// Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
+// of HLSL shaders. WGSL shader should result in caching 2 HLSL shaders (1 stage x 2 entrypoints)
+TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
+    wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+        [[block]] struct Data {
+            [[offset(0)]] data : u32;
+        };
+        [[binding(0), set(0)]] var<storage_buffer> data : Data;
+
+        [[stage(compute)]]
+        fn write1() -> void {
+            data.data = 1u;
+            return;
+        }
+
+        [[stage(compute)]]
+        fn write42() -> void {
+            data.data = 42u;
+            return;
+        }
+    )");
+
+    // Store one HLSL shader per entrypoint of the WGSL shader into the cache.
+    {
+        wgpu::ComputePipelineDescriptor desc;
+        desc.computeStage.module = module;
+        desc.computeStage.entryPoint = "write1";
+        EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
+
+        desc.computeStage.module = module;
+        desc.computeStage.entryPoint = "write42";
+        EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+    // Load the same WGSL shader from the cache.
+    {
+        wgpu::ComputePipelineDescriptor desc;
+        desc.computeStage.module = module;
+        desc.computeStage.entryPoint = "write1";
+
+        // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check for
+        // 2 hits for the single compute shader.
+        EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
+
+        desc.computeStage.module = module;
+        desc.computeStage.entryPoint = "write42";
+
+        // Cached HLSL shader calls LoadData twice, so again check for 2 hits.
+        EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
+    }
+
+    EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+}
+
+DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());
\ No newline at end of file