D3D12: Support caching DX shaders.
This change is a prerequisite to D3D pipeline caching.
This change introduces:
- Caching interface which enables the cache.
- Helper for backends to load/store blobs to be cached.
- Ability to cache HLSL shaders.
Bug:dawn:549
Change-Id: I2af759882d18b3f45dc63e49dcb6a3caa1be3485
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32305
Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 513fcd4..9c9c8ab 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -223,6 +223,8 @@
"PassResourceUsageTracker.h",
"PerStage.cpp",
"PerStage.h",
+ "PersistentCache.cpp",
+ "PersistentCache.h",
"Pipeline.cpp",
"Pipeline.h",
"PipelineLayout.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt
index e290aaa..4321a44 100644
--- a/src/dawn_native/CMakeLists.txt
+++ b/src/dawn_native/CMakeLists.txt
@@ -108,6 +108,8 @@
"PassResourceUsage.h"
"PassResourceUsageTracker.cpp"
"PassResourceUsageTracker.h"
+ "PersistentCache.cpp"
+ "PersistentCache.h"
"PerStage.cpp"
"PerStage.h"
"Pipeline.cpp"
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index 0c493b6..2c8468f 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -31,6 +31,7 @@
#include "dawn_native/Fence.h"
#include "dawn_native/Instance.h"
#include "dawn_native/InternalPipelineStore.h"
+#include "dawn_native/PersistentCache.h"
#include "dawn_native/PipelineLayout.h"
#include "dawn_native/QuerySet.h"
#include "dawn_native/Queue.h"
@@ -132,6 +133,7 @@
mCreateReadyPipelineTracker = std::make_unique<CreateReadyPipelineTracker>(this);
mDeprecationWarnings = std::make_unique<DeprecationWarnings>();
mInternalPipelineStore = std::make_unique<InternalPipelineStore>();
+ mPersistentCache = std::make_unique<PersistentCache>(this);
// Starting from now the backend can start doing reentrant calls so the device is marked as
// alive.
@@ -196,6 +198,7 @@
mErrorScopeTracker = nullptr;
mDynamicUploader = nullptr;
mCreateReadyPipelineTracker = nullptr;
+ mPersistentCache = nullptr;
mEmptyBindGroupLayout = nullptr;
@@ -299,6 +302,11 @@
return mCurrentErrorScope.Get();
}
+ PersistentCache* DeviceBase::GetPersistentCache() {
+ ASSERT(mPersistentCache.get() != nullptr);
+ return mPersistentCache.get();
+ }
+
MaybeError DeviceBase::ValidateObject(const ObjectBase* object) const {
ASSERT(object != nullptr);
if (DAWN_UNLIKELY(object->GetDevice() != this)) {
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 6408a90..08a4c80 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -37,6 +37,7 @@
class DynamicUploader;
class ErrorScope;
class ErrorScopeTracker;
+ class PersistentCache;
class StagingBufferBase;
struct InternalPipelineStore;
@@ -180,6 +181,8 @@
ErrorScope* GetCurrentErrorScope();
+ PersistentCache* GetPersistentCache();
+
void Reference();
void Release();
@@ -388,6 +391,8 @@
ExtensionsSet mEnabledExtensions;
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
+
+ std::unique_ptr<PersistentCache> mPersistentCache;
};
} // namespace dawn_native
diff --git a/src/dawn_native/PersistentCache.cpp b/src/dawn_native/PersistentCache.cpp
new file mode 100644
index 0000000..fbb1ece
--- /dev/null
+++ b/src/dawn_native/PersistentCache.cpp
@@ -0,0 +1,65 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/PersistentCache.h"
+
+#include "common/Assert.h"
+#include "dawn_native/Device.h"
+#include "dawn_platform/DawnPlatform.h"
+
+namespace dawn_native {
+
+ PersistentCache::PersistentCache(DeviceBase* device)
+ : mDevice(device), mCache(GetPlatformCache()) {
+ }
+
+ ScopedCachedBlob PersistentCache::LoadData(const PersistentCacheKey& key) {
+ ScopedCachedBlob blob = {};
+ if (mCache == nullptr) {
+ return blob;
+ }
+ blob.bufferSize = mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(),
+ key.size(), nullptr, 0);
+ if (blob.bufferSize > 0) {
+ blob.buffer.reset(new uint8_t[blob.bufferSize]);
+ const size_t bufferSize =
+ mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(),
+ blob.buffer.get(), blob.bufferSize);
+ ASSERT(bufferSize == blob.bufferSize);
+ return blob;
+ }
+ return blob;
+ }
+
+ void PersistentCache::StoreData(const PersistentCacheKey& key, const void* value, size_t size) {
+ if (mCache == nullptr) {
+ return;
+ }
+ ASSERT(value != nullptr);
+ ASSERT(size > 0);
+ mCache->StoreData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(), value,
+ size);
+ }
+
+ dawn_platform::CachingInterface* PersistentCache::GetPlatformCache() {
+ // TODO(dawn:549): Create a fingerprint of concatenated version strings (ex. Tint commit
+ // hash, Dawn commit hash). This will be used by the client so it may know when to discard
+ // previously cached Dawn objects should this fingerprint change.
+ dawn_platform::Platform* platform = mDevice->GetPlatform();
+ if (platform != nullptr) {
+ return platform->GetCachingInterface(/*fingerprint*/ nullptr, /*fingerprintSize*/ 0);
+ }
+ return nullptr;
+ }
+} // namespace dawn_native
\ No newline at end of file
diff --git a/src/dawn_native/PersistentCache.h b/src/dawn_native/PersistentCache.h
new file mode 100644
index 0000000..5e9dbc0
--- /dev/null
+++ b/src/dawn_native/PersistentCache.h
@@ -0,0 +1,86 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_PERSISTENTCACHE_H_
+#define DAWNNATIVE_PERSISTENTCACHE_H_
+
+#include "dawn_native/Error.h"
+
+#include <vector>
+
+namespace dawn_platform {
+ class CachingInterface;
+}
+
+namespace dawn_native {
+
+ using PersistentCacheKey = std::vector<uint8_t>;
+
+ struct ScopedCachedBlob {
+ std::unique_ptr<uint8_t[]> buffer;
+ size_t bufferSize = 0;
+ };
+
+ class DeviceBase;
+
+ enum class PersistentKeyType { Shader };
+
+ class PersistentCache {
+ public:
+ PersistentCache(DeviceBase* device);
+
+ // Combines load/store operations into a single call.
+ // If the load was successful, a non-empty blob is returned to the caller.
+ // Else, the creation callback |createFn| gets invoked with a callback
+ // |doCache| to store the newly created blob back in the cache.
+ //
+ // Example usage:
+ //
+ // ScopedCachedBlob cachedBlob = {};
+ // DAWN_TRY_ASSIGN(cachedBlob, GetOrCreate(key, [&](auto doCache)) {
+ // // Create a new blob to be stored
+ // doCache(newBlobPtr, newBlobSize); // store
+ // }));
+ //
+ template <typename CreateFn>
+ ResultOrError<ScopedCachedBlob> GetOrCreate(const PersistentCacheKey& key,
+ CreateFn&& createFn) {
+ // Attempt to load an existing blob from the cache.
+ ScopedCachedBlob blob = LoadData(key);
+ if (blob.bufferSize > 0) {
+ return std::move(blob);
+ }
+
+ // Allow the caller to create a new blob to be stored for the given key.
+ DAWN_TRY(createFn([this, key](const void* value, size_t size) {
+ this->StoreData(key, value, size);
+ }));
+
+ return std::move(blob);
+ }
+
+ private:
+ // PersistentCache impl
+ ScopedCachedBlob LoadData(const PersistentCacheKey& key);
+ void StoreData(const PersistentCacheKey& key, const void* value, size_t size);
+
+ dawn_platform::CachingInterface* GetPlatformCache();
+
+ DeviceBase* mDevice = nullptr;
+
+ dawn_platform::CachingInterface* mCache = nullptr;
+ };
+} // namespace dawn_native
+
+#endif // DAWNNATIVE_PERSISTENTCACHE_H_
\ No newline at end of file
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 68acce9..88aa240 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -43,44 +43,14 @@
ShaderModule* module = ToBackend(descriptor->computeStage.module);
- const char* entryPoint = descriptor->computeStage.entryPoint;
- std::string remappedEntryPoint;
- std::string hlslSource;
- if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
- DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSLWithTint(
- entryPoint, SingleShaderStage::Compute,
- ToBackend(GetLayout()), &remappedEntryPoint));
- entryPoint = remappedEntryPoint.c_str();
-
- } else {
- DAWN_TRY_ASSIGN(hlslSource,
- module->TranslateToHLSLWithSPIRVCross(
- entryPoint, SingleShaderStage::Compute, ToBackend(GetLayout())));
-
- // Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
- entryPoint = "main";
- }
-
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
- ComPtr<IDxcBlob> compiledDXCShader;
- ComPtr<ID3DBlob> compiledFXCShader;
- if (device->IsToggleEnabled(Toggle::UseDXC)) {
- DAWN_TRY_ASSIGN(compiledDXCShader,
- CompileShaderDXC(device, SingleShaderStage::Compute, hlslSource,
- entryPoint, compileFlags));
-
- d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
- d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
- } else {
- DAWN_TRY_ASSIGN(compiledFXCShader,
- CompileShaderFXC(device, SingleShaderStage::Compute, hlslSource,
- entryPoint, compileFlags));
- d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
- d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
- }
-
+ CompiledShader compiledShader;
+ DAWN_TRY_ASSIGN(compiledShader, module->Compile(descriptor->computeStage.entryPoint,
+ SingleShaderStage::Compute,
+ ToBackend(GetLayout()), compileFlags));
+ d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
IID_PPV_ARGS(&mPipelineState));
return {};
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 5501308..f9a2b3f 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -306,44 +306,13 @@
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
- PerStage<ComPtr<ID3DBlob>> compiledFXCShader;
- PerStage<ComPtr<IDxcBlob>> compiledDXCShader;
-
+ PerStage<CompiledShader> compiledShader;
wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
for (auto stage : IterateStages(renderStages)) {
- std::string hlslSource;
- const char* entryPoint = GetStage(stage).entryPoint.c_str();
- std::string remappedEntryPoint;
-
- if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
- DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithTint(
- entryPoint, stage, ToBackend(GetLayout()),
- &remappedEntryPoint));
- entryPoint = remappedEntryPoint.c_str();
-
- } else {
- DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithSPIRVCross(
- entryPoint, stage, ToBackend(GetLayout())));
-
- // Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
- entryPoint = "main";
- }
-
- if (device->IsToggleEnabled(Toggle::UseDXC)) {
- DAWN_TRY_ASSIGN(
- compiledDXCShader[stage],
- CompileShaderDXC(device, stage, hlslSource, entryPoint, compileFlags));
-
- shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
- shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
- } else {
- DAWN_TRY_ASSIGN(
- compiledFXCShader[stage],
- CompileShaderFXC(device, stage, hlslSource, entryPoint, compileFlags));
-
- shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
- shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
- }
+ DAWN_TRY_ASSIGN(compiledShader[stage],
+ modules[stage]->Compile(entryPoints[stage], stage,
+ ToBackend(GetLayout()), compileFlags));
+ *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
}
PipelineLayout* layout = ToBackend(GetLayout());
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index b0db7c3..7b10fc0 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -309,4 +309,102 @@
return compiler.compile();
}
+ ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
+ SingleShaderStage stage,
+ PipelineLayout* layout,
+ uint32_t compileFlags) {
+ Device* device = ToBackend(GetDevice());
+
+ // Compile the source shader to HLSL.
+ std::string hlslSource;
+ std::string remappedEntryPoint;
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSLWithTint(entryPointName, stage, layout,
+ &remappedEntryPoint));
+ entryPointName = remappedEntryPoint.c_str();
+ } else {
+ DAWN_TRY_ASSIGN(hlslSource,
+ TranslateToHLSLWithSPIRVCross(entryPointName, stage, layout));
+
+ // Note that the HLSL will always use entryPoint "main" under
+ // SPIRV-cross.
+ entryPointName = "main";
+ }
+
+ // Use HLSL source as the input for the key since it does need to know about the pipeline
+ // layout. The pipeline layout is only required if we key from WGSL: two different pipeline
+ // layouts could be used to produce different shader blobs and the wrong shader blob could
+ // be loaded since the pipeline layout was missing from the key.
+ // TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used.
+ const PersistentCacheKey& shaderCacheKey =
+ CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags);
+
+ CompiledShader compiledShader = {};
+ DAWN_TRY_ASSIGN(compiledShader.cachedShader,
+ device->GetPersistentCache()->GetOrCreate(
+ shaderCacheKey, [&](auto doCache) -> MaybeError {
+ if (device->IsToggleEnabled(Toggle::UseDXC)) {
+ DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
+ CompileShaderDXC(device, stage, hlslSource,
+ entryPointName, compileFlags));
+ } else {
+ DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
+ CompileShaderFXC(device, stage, hlslSource,
+ entryPointName, compileFlags));
+ }
+ const D3D12_SHADER_BYTECODE shader =
+ compiledShader.GetD3D12ShaderBytecode();
+ doCache(shader.pShaderBytecode, shader.BytecodeLength);
+ return {};
+ }));
+
+ return std::move(compiledShader);
+ }
+
+ D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
+ if (cachedShader.buffer != nullptr) {
+ return {cachedShader.buffer.get(), cachedShader.bufferSize};
+ } else if (compiledFXCShader != nullptr) {
+ return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
+ } else if (compiledDXCShader != nullptr) {
+ return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
+ }
+ UNREACHABLE();
+ return {};
+ }
+
+ PersistentCacheKey ShaderModule::CreateHLSLKey(const char* entryPointName,
+ SingleShaderStage stage,
+ const std::string& hlslSource,
+ uint32_t compileFlags) const {
+ std::stringstream stream;
+
+ // Prefix the key with the type to avoid collisions from another type that could have the
+ // same key.
+ stream << static_cast<uint32_t>(PersistentKeyType::Shader);
+
+ // Provide "guard" strings that the user cannot provide to help ensure the generated HLSL
+ // used to create this key is not being manufactured by the user to load the wrong shader
+ // blob.
+ // These strings can be HLSL comments because Tint does not emit HLSL comments.
+ // TODO(dawn:549): Replace guards strings with something more secure.
+ ASSERT(hlslSource.find("//") == std::string::npos);
+
+ stream << "// Start shader autogenerated by Dawn.";
+ stream << hlslSource;
+ stream << "// End of shader autogenerated by Dawn.";
+
+ stream << compileFlags;
+
+ // TODO(dawn:549): add the HLSL compiler version for good measure.
+
+ // If the source contains multiple entry points, ensure they are cached seperately
+ // per stage since DX shader code can only be compiled per stage using the same
+ // entry point.
+ stream << static_cast<uint32_t>(stage);
+ stream << entryPointName;
+
+ return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
+ std::istreambuf_iterator<char>{});
+ }
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 4fc2532..63cb2e9 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -15,6 +15,7 @@
#ifndef DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
#define DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
+#include "dawn_native/PersistentCache.h"
#include "dawn_native/ShaderModule.h"
#include "dawn_native/d3d12/d3d12_platform.h"
@@ -24,22 +25,28 @@
class Device;
class PipelineLayout;
- ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
- SingleShaderStage stage,
- const std::string& hlslSource,
- const char* entryPoint,
- uint32_t compileFlags);
- ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
- SingleShaderStage stage,
- const std::string& hlslSource,
- const char* entryPoint,
- uint32_t compileFlags);
+ // Manages a ref to one of the various representations of shader blobs.
+ struct CompiledShader {
+ ScopedCachedBlob cachedShader;
+ ComPtr<ID3DBlob> compiledFXCShader;
+ ComPtr<IDxcBlob> compiledDXCShader;
+ D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
+ };
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor);
+ ResultOrError<CompiledShader> Compile(const char* entryPointName,
+ SingleShaderStage stage,
+ PipelineLayout* layout,
+ uint32_t compileFlags);
+
+ private:
+ ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
+ ~ShaderModule() override = default;
+
ResultOrError<std::string> TranslateToHLSLWithTint(
const char* entryPointName,
SingleShaderStage stage,
@@ -50,9 +57,10 @@
SingleShaderStage stage,
PipelineLayout* layout) const;
- private:
- ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
- ~ShaderModule() override = default;
+ PersistentCacheKey CreateHLSLKey(const char* entryPointName,
+ SingleShaderStage stage,
+ const std::string& hlslSource,
+ uint32_t compileFlags) const;
};
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_platform/BUILD.gn b/src/dawn_platform/BUILD.gn
index d26a859..91c9e75 100644
--- a/src/dawn_platform/BUILD.gn
+++ b/src/dawn_platform/BUILD.gn
@@ -31,4 +31,9 @@
]
deps = [ "${dawn_root}/src/common" ]
+
+ public_deps = [
+ # DawnPlatform.h has #include <dawn/webgpu.h>
+ "${dawn_root}/src/dawn:dawn_headers",
+ ]
}
diff --git a/src/dawn_platform/CMakeLists.txt b/src/dawn_platform/CMakeLists.txt
index 67b1b92..b8075e2 100644
--- a/src/dawn_platform/CMakeLists.txt
+++ b/src/dawn_platform/CMakeLists.txt
@@ -27,4 +27,4 @@
"tracing/EventTracer.h"
"tracing/TraceEvent.h"
)
-target_link_libraries(dawn_platform PRIVATE dawn_internal_config dawn_common)
+target_link_libraries(dawn_platform PUBLIC dawn_headers PRIVATE dawn_internal_config dawn_common)
diff --git a/src/dawn_platform/DawnPlatform.cpp b/src/dawn_platform/DawnPlatform.cpp
index 6fe61d5..6b71708 100644
--- a/src/dawn_platform/DawnPlatform.cpp
+++ b/src/dawn_platform/DawnPlatform.cpp
@@ -14,10 +14,45 @@
#include "dawn_platform/DawnPlatform.h"
+#include "common/Assert.h"
+
namespace dawn_platform {
+ CachingInterface::CachingInterface() = default;
+
+ CachingInterface::~CachingInterface() = default;
+
Platform::Platform() = default;
Platform::~Platform() = default;
+ const unsigned char* Platform::GetTraceCategoryEnabledFlag(TraceCategory category) {
+ static unsigned char disabled = 0;
+ return &disabled;
+ }
+
+ double Platform::MonotonicallyIncreasingTime() {
+ return 0;
+ }
+
+ uint64_t Platform::AddTraceEvent(char phase,
+ const unsigned char* categoryGroupEnabled,
+ const char* name,
+ uint64_t id,
+ double timestamp,
+ int numArgs,
+ const char** argNames,
+ const unsigned char* argTypes,
+ const uint64_t* argValues,
+ unsigned char flags) {
+ // AddTraceEvent cannot be called if events are disabled.
+ ASSERT(false);
+ return 0;
+ }
+
+ dawn_platform::CachingInterface* Platform::GetCachingInterface(const void* fingerprint,
+ size_t fingerprintSize) {
+ return nullptr;
+ }
+
} // namespace dawn_platform
\ No newline at end of file
diff --git a/src/include/dawn_platform/DawnPlatform.h b/src/include/dawn_platform/DawnPlatform.h
index 107b91e..4a00f53 100644
--- a/src/include/dawn_platform/DawnPlatform.h
+++ b/src/include/dawn_platform/DawnPlatform.h
@@ -17,8 +17,11 @@
#include "dawn_platform/dawn_platform_export.h"
+#include <cstddef>
#include <cstdint>
+#include <dawn/webgpu.h>
+
namespace dawn_platform {
enum class TraceCategory {
@@ -28,14 +31,43 @@
GPUWork, // Actual GPU work
};
+ class DAWN_PLATFORM_EXPORT CachingInterface {
+ public:
+ CachingInterface();
+ virtual ~CachingInterface();
+
+ // LoadData has two modes. The first mode is used to get a value which
+ // corresponds to the |key|. The |valueOut| is a caller provided buffer
+ // allocated to the size |valueSize| which is loaded with data of the
+ // size returned. The second mode is used to query for the existence of
+ // the |key| where |valueOut| is nullptr and |valueSize| must be 0.
+ // The return size is non-zero if the |key| exists.
+ virtual size_t LoadData(const WGPUDevice device,
+ const void* key,
+ size_t keySize,
+ void* valueOut,
+ size_t valueSize) = 0;
+
+ // StoreData puts a |value| in the cache which corresponds to the |key|.
+ virtual void StoreData(const WGPUDevice device,
+ const void* key,
+ size_t keySize,
+ const void* value,
+ size_t valueSize) = 0;
+
+ private:
+ CachingInterface(const CachingInterface&) = delete;
+ CachingInterface& operator=(const CachingInterface&) = delete;
+ };
+
class DAWN_PLATFORM_EXPORT Platform {
public:
Platform();
virtual ~Platform();
- virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category) = 0;
+ virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category);
- virtual double MonotonicallyIncreasingTime() = 0;
+ virtual double MonotonicallyIncreasingTime();
virtual uint64_t AddTraceEvent(char phase,
const unsigned char* categoryGroupEnabled,
@@ -46,7 +78,13 @@
const char** argNames,
const unsigned char* argTypes,
const uint64_t* argValues,
- unsigned char flags) = 0;
+ unsigned char flags);
+
+ // The |fingerprint| is provided by Dawn to inform the client to discard the Dawn caches
+ // when the fingerprint changes. The returned CachingInterface is expected to outlive the
+ // device which uses it to persistently cache objects.
+ virtual CachingInterface* GetCachingInterface(const void* fingerprint,
+ size_t fingerprintSize);
private:
Platform(const Platform&) = delete;
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index cfd4805..3fd7253 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -333,7 +333,10 @@
libs = []
if (dawn_enable_d3d12) {
- sources += [ "end2end/D3D12ResourceWrappingTests.cpp" ]
+ sources += [
+ "end2end/D3D12CachingTests.cpp",
+ "end2end/D3D12ResourceWrappingTests.cpp",
+ ]
libs += [
"d3d11.lib",
"dxgi.lib",
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index 7743875..ded6059 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -744,6 +744,10 @@
mBackendAdapter = *it;
}
+ // Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform.
+ mTestPlatform = CreateTestPlatform();
+ gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get());
+
// Create the device from the adapter
for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) {
ASSERT(gTestEnv->GetInstance()->GetToggleInfo(forceEnabledWorkaround) != nullptr);
@@ -1080,6 +1084,10 @@
}
}
+std::unique_ptr<dawn_platform::Platform> DawnTestBase::CreateTestPlatform() {
+ return nullptr;
+}
+
bool RGBA8::operator==(const RGBA8& other) const {
return r == other.r && g == other.g && b == other.b && a == other.a;
}
diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h
index f2f3106..264f6f4 100644
--- a/src/tests/DawnTest.h
+++ b/src/tests/DawnTest.h
@@ -20,6 +20,7 @@
#include "dawn/webgpu_cpp.h"
#include "dawn_native/DawnNative.h"
+#include <dawn_platform/DawnPlatform.h>
#include <gtest/gtest.h>
#include <memory>
@@ -268,6 +269,8 @@
wgpu::Instance GetInstance() const;
dawn_native::Adapter GetAdapter() const;
+ virtual std::unique_ptr<dawn_platform::Platform> CreateTestPlatform();
+
protected:
wgpu::Device device;
wgpu::Queue queue;
@@ -403,6 +406,8 @@
void ResolveExpectations();
dawn_native::Adapter mBackendAdapter;
+
+ std::unique_ptr<dawn_platform::Platform> mTestPlatform;
};
// Skip a test when the given condition is satisfied.
diff --git a/src/tests/end2end/D3D12CachingTests.cpp b/src/tests/end2end/D3D12CachingTests.cpp
new file mode 100644
index 0000000..27690b7
--- /dev/null
+++ b/src/tests/end2end/D3D12CachingTests.cpp
@@ -0,0 +1,287 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnTest.h"
+
+#include "utils/ComboRenderPipelineDescriptor.h"
+#include "utils/WGPUHelpers.h"
+
+#define EXPECT_CACHE_HIT(N, statement) \
+ do { \
+ size_t before = mPersistentCache.mHitCount; \
+ statement; \
+ FlushWire(); \
+ size_t after = mPersistentCache.mHitCount; \
+ EXPECT_EQ(N, after - before); \
+ } while (0)
+
+// FakePersistentCache implements a in-memory persistent cache.
+class FakePersistentCache : public dawn_platform::CachingInterface {
+ public:
+ // PersistentCache API
+ void StoreData(const WGPUDevice device,
+ const void* key,
+ size_t keySize,
+ const void* value,
+ size_t valueSize) override {
+ if (mIsDisabled)
+ return;
+ const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
+
+ const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
+ std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
+
+ EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
+ }
+
+ size_t LoadData(const WGPUDevice device,
+ const void* key,
+ size_t keySize,
+ void* value,
+ size_t valueSize) override {
+ const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
+ auto entry = mCache.find(keyStr);
+ if (entry == mCache.end()) {
+ return 0;
+ }
+ if (valueSize >= entry->second.size()) {
+ memcpy(value, entry->second.data(), entry->second.size());
+ }
+ mHitCount++;
+ return entry->second.size();
+ }
+
+ using Blob = std::vector<uint8_t>;
+ using FakeCache = std::unordered_map<std::string, Blob>;
+
+ FakeCache mCache;
+
+ size_t mHitCount = 0;
+ bool mIsDisabled = false;
+};
+
+// Test platform that only supports caching.
+class DawnTestPlatform : public dawn_platform::Platform {
+ public:
+ DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
+ : mCachingInterface(cachingInterface) {
+ }
+ ~DawnTestPlatform() override = default;
+
+ dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
+ size_t fingerprintSize) override {
+ return mCachingInterface;
+ }
+
+ dawn_platform::CachingInterface* mCachingInterface = nullptr;
+};
+
+class D3D12CachingTests : public DawnTest {
+ protected:
+ std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
+ return std::make_unique<DawnTestPlatform>(&mPersistentCache);
+ }
+
+ FakePersistentCache mPersistentCache;
+};
+
+// Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled.
+TEST_P(D3D12CachingTests, SameShaderNoCache) {
+ mPersistentCache.mIsDisabled = true;
+
+ wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+ [[builtin(position)]] var<out> Position : vec4<f32>;
+
+ [[stage(vertex)]]
+ fn vertex_main() -> void {
+ Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+ return;
+ }
+
+ [[location(0)]] var<out> outColor : vec4<f32>;
+
+ [[stage(fragment)]]
+ fn fragment_main() -> void {
+ outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+ return;
+ }
+ )");
+
+ // Store the WGSL shader into the cache.
+ {
+ utils::ComboRenderPipelineDescriptor desc(device);
+ desc.vertexStage.module = module;
+ desc.vertexStage.entryPoint = "vertex_main";
+ desc.cFragmentStage.module = module;
+ desc.cFragmentStage.entryPoint = "fragment_main";
+
+ EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
+
+ // Load the same WGSL shader from the cache.
+ {
+ utils::ComboRenderPipelineDescriptor desc(device);
+ desc.vertexStage.module = module;
+ desc.vertexStage.entryPoint = "vertex_main";
+ desc.cFragmentStage.module = module;
+ desc.cFragmentStage.entryPoint = "fragment_main";
+
+ EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
+}
+
+// Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
+// of HLSL shaders. WGSL shader should result into caching 2 HLSL shaders (stage x
+// entrypoints)
+TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
+ wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+ [[builtin(position)]] var<out> Position : vec4<f32>;
+
+ [[stage(vertex)]]
+ fn vertex_main() -> void {
+ Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
+ return;
+ }
+
+ [[location(0)]] var<out> outColor : vec4<f32>;
+
+ [[stage(fragment)]]
+ fn fragment_main() -> void {
+ outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
+ return;
+ }
+ )");
+
+ // Store the WGSL shader into the cache.
+ {
+ utils::ComboRenderPipelineDescriptor desc(device);
+ desc.vertexStage.module = module;
+ desc.vertexStage.entryPoint = "vertex_main";
+ desc.cFragmentStage.module = module;
+ desc.cFragmentStage.entryPoint = "fragment_main";
+
+ EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+ // Load the same WGSL shader from the cache.
+ {
+ utils::ComboRenderPipelineDescriptor desc(device);
+ desc.vertexStage.module = module;
+ desc.vertexStage.entryPoint = "vertex_main";
+ desc.cFragmentStage.module = module;
+ desc.cFragmentStage.entryPoint = "fragment_main";
+
+ // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
+ // kNumOfShaders hits.
+ EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+ // Modify the WGSL shader functions and make sure it doesn't hit.
+ wgpu::ShaderModule newModule = utils::CreateShaderModuleFromWGSL(device, R"(
+ [[builtin(position)]] var<out> Position : vec4<f32>;
+
+ [[stage(vertex)]]
+ fn vertex_main() -> void {
+ Position = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ return;
+ }
+
+ [[location(0)]] var<out> outColor : vec4<f32>;
+
+ [[stage(fragment)]]
+ fn fragment_main() -> void {
+ outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
+ return;
+ }
+ )");
+
+ {
+ utils::ComboRenderPipelineDescriptor desc(device);
+ desc.vertexStage.module = newModule;
+ desc.vertexStage.entryPoint = "vertex_main";
+ desc.cFragmentStage.module = newModule;
+ desc.cFragmentStage.entryPoint = "fragment_main";
+ EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
+ }
+
+ // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
+ // kNumOfShaders hits.
+ EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
+}
+
+// Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
+// of HLSL shaders. WGSL shader should result into caching 1 HLSL shader (stage x entrypoints)
+TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
+ wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
+ [[block]] struct Data {
+ [[offset(0)]] data : u32;
+ };
+ [[binding(0), set(0)]] var<storage_buffer> data : Data;
+
+ [[stage(compute)]]
+ fn write1() -> void {
+ data.data = 1u;
+ return;
+ }
+
+ [[stage(compute)]]
+ fn write42() -> void {
+ data.data = 42u;
+ return;
+ }
+ )");
+
+ // Store the WGSL shader into the cache.
+ {
+ wgpu::ComputePipelineDescriptor desc;
+ desc.computeStage.module = module;
+ desc.computeStage.entryPoint = "write1";
+ EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
+
+ desc.computeStage.module = module;
+ desc.computeStage.entryPoint = "write42";
+ EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+
+ // Load the same WGSL shader from the cache.
+ {
+ wgpu::ComputePipelineDescriptor desc;
+ desc.computeStage.module = module;
+ desc.computeStage.entryPoint = "write1";
+
+ // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
+ // kNumOfShaders hits.
+ EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
+
+ desc.computeStage.module = module;
+ desc.computeStage.entryPoint = "write42";
+
+ // Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
+ EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
+ }
+
+ EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
+}
+
+DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());
\ No newline at end of file