Cache WGSL -> DXBC/DXIL compilation
Bug: dawn:1480
Change-Id: I858111f62be457c2e7cd5017bbf4c10e76395e83
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95340
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/native/Blob.cpp b/src/dawn/native/Blob.cpp
index cecac30..78b18e7 100644
--- a/src/dawn/native/Blob.cpp
+++ b/src/dawn/native/Blob.cpp
@@ -17,6 +17,7 @@
#include "dawn/common/Assert.h"
#include "dawn/common/Math.h"
#include "dawn/native/Blob.h"
+#include "dawn/native/stream/Stream.h"
namespace dawn::native {
@@ -99,4 +100,29 @@
*this = std::move(blob);
}
+template <>
+void stream::Stream<Blob>::Write(stream::Sink* s, const Blob& b) {
+ size_t size = b.Size();
+ StreamIn(s, size);
+ if (size > 0) {
+ void* ptr = s->GetSpace(size);
+ memcpy(ptr, b.Data(), size);
+ }
+}
+
+template <>
+MaybeError stream::Stream<Blob>::Read(stream::Source* s, Blob* b) {
+ size_t size;
+ DAWN_TRY(StreamOut(s, &size));
+ if (size > 0) {
+ const void* ptr;
+ DAWN_TRY(s->Read(&ptr, size));
+ *b = CreateBlob(size);
+ memcpy(b->Data(), ptr, size);
+ } else {
+ *b = Blob();
+ }
+ return {};
+}
+
} // namespace dawn::native
diff --git a/src/dawn/native/StreamImplTint.cpp b/src/dawn/native/StreamImplTint.cpp
index 1c0c04d..13a70ca 100644
--- a/src/dawn/native/StreamImplTint.cpp
+++ b/src/dawn/native/StreamImplTint.cpp
@@ -14,7 +14,8 @@
#include "dawn/native/stream/Stream.h"
-#include "tint/tint.h"
+#include "dawn/native/TintUtils.h"
+#include "tint/writer/array_length_from_uniform_options.h"
namespace dawn::native {
@@ -96,4 +97,21 @@
StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location);
}
+// static
+template <>
+void stream::Stream<tint::writer::ArrayLengthFromUniformOptions>::Write(
+ stream::Sink* sink,
+ const tint::writer::ArrayLengthFromUniformOptions& o) {
+ static_assert(offsetof(tint::writer::ArrayLengthFromUniformOptions, ubo_binding) == 0,
+ "Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
+ static_assert(
+ offsetof(tint::writer::ArrayLengthFromUniformOptions, bindpoint_to_size_index) == 8,
+ "Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
+ static_assert(
+ sizeof(tint::writer::ArrayLengthFromUniformOptions) ==
+ 8 + sizeof(tint::writer::ArrayLengthFromUniformOptions::bindpoint_to_size_index),
+ "Please update serialization for tint::writer::ArrayLengthFromUniformOptions");
+ StreamIn(sink, o.ubo_binding, o.bindpoint_to_size_index);
+}
+
} // namespace dawn::native
diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp
index c1585a5..f59e6e5 100644
--- a/src/dawn/native/TintUtils.cpp
+++ b/src/dawn/native/TintUtils.cpp
@@ -185,7 +185,10 @@
} // namespace dawn::native
-bool std::less<tint::sem::BindingPoint>::operator()(const tint::sem::BindingPoint& a,
- const tint::sem::BindingPoint& b) const {
+namespace tint::sem {
+
+bool operator<(const BindingPoint& a, const BindingPoint& b) {
return std::tie(a.group, a.binding) < std::tie(b.group, b.binding);
}
+
+} // namespace tint::sem
diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h
index e17fc69..7c03881 100644
--- a/src/dawn/native/TintUtils.h
+++ b/src/dawn/native/TintUtils.h
@@ -49,10 +49,11 @@
} // namespace dawn::native
-// std::less operator for std::map containing BindingPoint
-template <>
-struct std::less<tint::sem::BindingPoint> {
- bool operator()(const tint::sem::BindingPoint& a, const tint::sem::BindingPoint& b) const;
-};
+namespace tint::sem {
+
+// Define operator< for std::map containing BindingPoint
+bool operator<(const BindingPoint& a, const BindingPoint& b);
+
+} // namespace tint::sem
#endif // SRC_DAWN_NATIVE_TINTUTILS_H_
diff --git a/src/dawn/native/d3d12/BlobD3D12.cpp b/src/dawn/native/d3d12/BlobD3D12.cpp
index ef9bbb9..3b596575 100644
--- a/src/dawn/native/d3d12/BlobD3D12.cpp
+++ b/src/dawn/native/d3d12/BlobD3D12.cpp
@@ -28,4 +28,16 @@
});
}
+Blob CreateBlob(ComPtr<IDxcBlob> blob) {
+ // Detach so the deleter callback can "own" the reference
+ IDxcBlob* ptr = blob.Detach();
+ return Blob::UnsafeCreateWithDeleter(reinterpret_cast<uint8_t*>(ptr->GetBufferPointer()),
+ ptr->GetBufferSize(), [=]() {
+ // Reattach and drop to delete it.
+ ComPtr<IDxcBlob> b;
+ b.Attach(ptr);
+ b = nullptr;
+ });
+}
+
} // namespace dawn::native
diff --git a/src/dawn/native/d3d12/BlobD3D12.h b/src/dawn/native/d3d12/BlobD3D12.h
index 563ac73..cc8c99c 100644
--- a/src/dawn/native/d3d12/BlobD3D12.h
+++ b/src/dawn/native/d3d12/BlobD3D12.h
@@ -18,5 +18,6 @@
namespace dawn::native {
Blob CreateBlob(ComPtr<ID3DBlob> blob);
+Blob CreateBlob(ComPtr<IDxcBlob> blob);
} // namespace dawn::native
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index 3e09cc0..2fd04bd 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -29,79 +29,41 @@
#include "dawn/common/Log.h"
#include "dawn/common/WindowsUtils.h"
#include "dawn/native/CacheKey.h"
+#include "dawn/native/CacheRequest.h"
#include "dawn/native/Pipeline.h"
#include "dawn/native/TintUtils.h"
#include "dawn/native/d3d12/AdapterD3D12.h"
#include "dawn/native/d3d12/BackendD3D12.h"
#include "dawn/native/d3d12/BindGroupLayoutD3D12.h"
+#include "dawn/native/d3d12/BlobD3D12.h"
#include "dawn/native/d3d12/D3D12Error.h"
#include "dawn/native/d3d12/DeviceD3D12.h"
#include "dawn/native/d3d12/PipelineLayoutD3D12.h"
#include "dawn/native/d3d12/PlatformFunctions.h"
#include "dawn/native/d3d12/UtilsD3D12.h"
+#include "dawn/native/stream/BlobSource.h"
+#include "dawn/native/stream/ByteVectorSink.h"
#include "dawn/platform/DawnPlatform.h"
#include "dawn/platform/tracing/TraceEvent.h"
#include "tint/tint.h"
+namespace dawn::native::stream {
+
+// Define no-op serializations for pD3DCompile, IDxcLibrary, and IDxcCompiler.
+// These are output-only interfaces used to generate bytecode.
+template <>
+void Stream<IDxcLibrary*>::Write(Sink*, IDxcLibrary* const&) {}
+template <>
+void Stream<IDxcCompiler*>::Write(Sink*, IDxcCompiler* const&) {}
+template <>
+void Stream<pD3DCompile>::Write(Sink*, pD3DCompile const&) {}
+
+} // namespace dawn::native::stream
+
namespace dawn::native::d3d12 {
namespace {
-uint64_t GetD3DCompilerVersion() {
- return D3D_COMPILER_VERSION;
-}
-
-struct CompareBindingPoint {
- constexpr bool operator()(const tint::transform::BindingPoint& lhs,
- const tint::transform::BindingPoint& rhs) const {
- if (lhs.group != rhs.group) {
- return lhs.group < rhs.group;
- } else {
- return lhs.binding < rhs.binding;
- }
- }
-};
-
-void StreamIn(std::stringstream& output, const tint::ast::Access& access) {
- output << access;
-}
-
-void StreamIn(std::stringstream& output, const tint::transform::BindingPoint& binding_point) {
- output << "(BindingPoint";
- output << " group=" << binding_point.group;
- output << " binding=" << binding_point.binding;
- output << ")";
-}
-
-template <typename T, typename = typename std::enable_if<std::is_fundamental<T>::value>::type>
-void StreamIn(std::stringstream& output, const T& val) {
- output << val;
-}
-
-template <typename T>
-void StreamIn(std::stringstream& output,
- const std::unordered_map<tint::transform::BindingPoint, T>& map) {
- output << "(map";
-
- std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(), map.end());
- for (auto& [bindingPoint, value] : sorted) {
- output << " ";
- StreamIn(output, bindingPoint);
- output << "=";
- StreamIn(output, value);
- }
- output << ")";
-}
-
-void StreamIn(std::stringstream& output,
- const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
- output << "(ArrayLengthFromUniformOptions";
- output << " ubo_binding=";
- StreamIn(output, arrayLengthFromUniform.ubo_binding);
- output << " bindpoint_to_size_index=";
- StreamIn(output, arrayLengthFromUniform.bindpoint_to_size_index);
- output << ")";
-}
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
@@ -130,253 +92,113 @@
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
-void GetOverridableConstantsDefines(
- std::vector<std::pair<std::string, std::string>>* defineStrings,
- const PipelineConstantEntries* pipelineConstantEntries,
- const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) {
+using DefineStrings = std::vector<std::pair<std::string, std::string>>;
+
+DefineStrings GetOverridableConstantsDefines(
+ const PipelineConstantEntries& pipelineConstantEntries,
+ const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) {
+ DefineStrings defineStrings;
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
- for (const auto& [name, value] : *pipelineConstantEntries) {
+ for (const auto& [name, value] : pipelineConstantEntries) {
overriddenConstants.insert(name);
// This is already validated so `name` must exist
- const auto& moduleConstant = shaderEntryPointConstants->at(name);
+ const auto& moduleConstant = shaderEntryPointConstants.at(name);
- defineStrings->emplace_back(
+ defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
- for (const auto& iter : *shaderEntryPointConstants) {
+ for (const auto& iter : shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
- const auto& moduleConstant = shaderEntryPointConstants->at(name);
+ const auto& moduleConstant = shaderEntryPointConstants.at(name);
// Uninitialized default values are okay since they ar only defined to pass
// compilation but not used
- defineStrings->emplace_back(
+ defineStrings.emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
}
+ return defineStrings;
}
-// The inputs to a shader compilation. These have been intentionally isolated from the
-// device to help ensure that the pipeline cache key contains all inputs for compilation.
-struct ShaderCompilationRequest {
- enum Compiler { FXC, DXC };
+enum class Compiler { FXC, DXC };
- // Common inputs
- Compiler compiler;
- const tint::Program* program;
- const char* entryPointName;
- SingleShaderStage stage;
- uint32_t compileFlags;
- bool disableSymbolRenaming;
- tint::transform::BindingRemapper::BindingPoints remappedBindingPoints;
- tint::transform::BindingRemapper::AccessControls remappedAccessControls;
- bool isRobustnessEnabled;
- bool usesNumWorkgroups;
- uint32_t numWorkgroupsRegisterSpace;
- uint32_t numWorkgroupsShaderRegister;
- tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
- std::vector<std::pair<std::string, std::string>> defineStrings;
+#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \
+ X(const tint::Program*, inputProgram) \
+ X(std::string_view, entryPointName) \
+ X(SingleShaderStage, stage) \
+ X(uint32_t, shaderModel) \
+ X(uint32_t, compileFlags) \
+ X(Compiler, compiler) \
+ X(uint64_t, compilerVersion) \
+ X(std::wstring_view, dxcShaderProfile) \
+ X(std::string_view, fxcShaderProfile) \
+ X(pD3DCompile, d3dCompile) \
+ X(IDxcLibrary*, dxcLibrary) \
+ X(IDxcCompiler*, dxcCompiler) \
+ X(uint32_t, firstIndexOffsetShaderRegister) \
+ X(uint32_t, firstIndexOffsetRegisterSpace) \
+ X(bool, usesNumWorkgroups) \
+ X(uint32_t, numWorkgroupsShaderRegister) \
+ X(uint32_t, numWorkgroupsRegisterSpace) \
+ X(DefineStrings, defineStrings) \
+ X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
+ X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \
+ X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \
+ X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \
+ X(bool, disableSymbolRenaming) \
+ X(bool, isRobustnessEnabled) \
+ X(bool, disableWorkgroupInit) \
+ X(bool, dumpShaders)
- // FXC/DXC common inputs
- bool disableWorkgroupInit;
+#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \
+ X(bool, hasShaderFloat16Feature) \
+ X(uint32_t, compileFlags) \
+ X(Compiler, compiler) \
+ X(uint64_t, compilerVersion) \
+ X(std::wstring_view, dxcShaderProfile) \
+ X(std::string_view, fxcShaderProfile) \
+ X(pD3DCompile, d3dCompile) \
+ X(IDxcLibrary*, dxcLibrary) \
+ X(IDxcCompiler*, dxcCompiler) \
+ X(DefineStrings, defineStrings)
- // FXC inputs
- uint64_t fxcVersion;
+struct HlslCompilationRequest {
+ DAWN_VISITABLE_MEMBERS(HLSL_COMPILATION_REQUEST_MEMBERS)
- // DXC inputs
- uint64_t dxcVersion;
- const D3D12DeviceInfo* deviceInfo;
- bool hasShaderFloat16Feature;
-
- static ResultOrError<ShaderCompilationRequest> Create(
- const char* entryPointName,
- SingleShaderStage stage,
- const PipelineLayout* layout,
- uint32_t compileFlags,
- const Device* device,
- const tint::Program* program,
- const EntryPointMetadata& entryPoint,
- const ProgrammableStage& programmableStage) {
- Compiler compiler;
- uint64_t dxcVersion = 0;
- if (device->IsToggleEnabled(Toggle::UseDXC)) {
- compiler = Compiler::DXC;
- DAWN_TRY_ASSIGN(dxcVersion,
- ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion());
- } else {
- compiler = Compiler::FXC;
- }
-
- using tint::transform::BindingPoint;
- using tint::transform::BindingRemapper;
-
- BindingRemapper::BindingPoints remappedBindingPoints;
- BindingRemapper::AccessControls remappedAccessControls;
-
- tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
- arrayLengthFromUniform.ubo_binding = {
- layout->GetDynamicStorageBufferLengthsRegisterSpace(),
- layout->GetDynamicStorageBufferLengthsShaderRegister()};
-
- const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
- for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
- const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
- const auto& groupBindingInfo = moduleBindingInfo[group];
-
- // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
- // the Tint AST to make the "bindings" decoration match the offset chosen by
- // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
- // assigned to each interface variable.
- for (const auto& [binding, bindingInfo] : groupBindingInfo) {
- BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
- BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
- static_cast<uint32_t>(binding)};
- BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
- bgl->GetShaderRegister(bindingIndex)};
- if (srcBindingPoint != dstBindingPoint) {
- remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
- }
-
- // Declaring a read-only storage buffer in HLSL but specifying a storage
- // buffer in the BGL produces the wrong output. Force read-only storage
- // buffer bindings to be treated as UAV instead of SRV. Internal storage
- // buffer is a storage buffer used in the internal pipeline.
- const bool forceStorageBufferAsUAV =
- (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
- (bgl->GetBindingInfo(bindingIndex).buffer.type ==
- wgpu::BufferBindingType::Storage ||
- bgl->GetBindingInfo(bindingIndex).buffer.type ==
- kInternalStorageBufferBinding));
- if (forceStorageBufferAsUAV) {
- remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
- }
- }
-
- // Add arrayLengthFromUniform options
- {
- for (const auto& bindingAndRegisterOffset :
- layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) {
- BindingNumber binding = bindingAndRegisterOffset.binding;
- uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
-
- BindingPoint bindingPoint{static_cast<uint32_t>(group),
- static_cast<uint32_t>(binding)};
- // Get the renamed binding point if it was remapped.
- auto it = remappedBindingPoints.find(bindingPoint);
- if (it != remappedBindingPoints.end()) {
- bindingPoint = it->second;
- }
-
- arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
- registerOffset);
- }
- }
- }
-
- ShaderCompilationRequest request;
- request.compiler = compiler;
- request.program = program;
- request.entryPointName = entryPointName;
- request.stage = stage;
- request.compileFlags = compileFlags;
- request.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
- request.remappedBindingPoints = std::move(remappedBindingPoints);
- request.remappedAccessControls = std::move(remappedAccessControls);
- request.isRobustnessEnabled = device->IsRobustnessEnabled();
- request.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
- request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
- request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
- request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
- request.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
- request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
- request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
- request.deviceInfo = &device->GetDeviceInfo();
- request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
-
- GetOverridableConstantsDefines(
- &request.defineStrings, &programmableStage.constants,
- &programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides);
-
- return std::move(request);
- }
-
- // TODO(dawn:1341): Move to use CacheKey instead of the vector.
- ResultOrError<std::vector<uint8_t>> CreateCacheKey() const {
- // Generate the WGSL from the Tint program so it's normalized.
- // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
- // compact representation.
- auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
- if (!result.success) {
- std::ostringstream errorStream;
- errorStream << "Tint WGSL failure:" << std::endl;
- errorStream << "Generator: " << result.error << std::endl;
- return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
- }
-
- std::stringstream stream;
-
- // Prefix the key with the type to avoid collisions from another type that could
- // have the same key.
- stream << static_cast<uint32_t>(CacheKey::Type::Shader);
- stream << "\n";
-
- stream << result.wgsl.length();
- stream << "\n";
-
- stream << result.wgsl;
- stream << "\n";
-
- stream << "(ShaderCompilationRequest";
- stream << " compiler=" << compiler;
- stream << " entryPointName=" << entryPointName;
- stream << " stage=" << uint32_t(stage);
- stream << " compileFlags=" << compileFlags;
- stream << " disableSymbolRenaming=" << disableSymbolRenaming;
-
- stream << " remappedBindingPoints=";
- StreamIn(stream, remappedBindingPoints);
-
- stream << " remappedAccessControls=";
- StreamIn(stream, remappedAccessControls);
-
- stream << " useNumWorkgroups=" << usesNumWorkgroups;
- stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
- stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
-
- stream << " arrayLengthFromUniform=";
- StreamIn(stream, arrayLengthFromUniform);
-
- stream << " shaderModel=" << deviceInfo->shaderModel;
- stream << " disableWorkgroupInit=" << disableWorkgroupInit;
- stream << " isRobustnessEnabled=" << isRobustnessEnabled;
- stream << " fxcVersion=" << fxcVersion;
- stream << " dxcVersion=" << dxcVersion;
- stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
-
- stream << " defines={";
- for (const auto& [name, value] : defineStrings) {
- stream << " <" << name << "," << value << ">";
- }
- stream << " }";
-
- stream << ")";
- stream << "\n";
-
- return std::vector<uint8_t>(std::istreambuf_iterator<char>{stream},
- std::istreambuf_iterator<char>{});
+ friend void StreamIn(stream::Sink* sink, const HlslCompilationRequest& r) {
+ r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); });
}
};
+struct D3DBytecodeCompilationRequest {
+ DAWN_VISITABLE_MEMBERS(D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS)
+
+ friend void StreamIn(stream::Sink* sink, const D3DBytecodeCompilationRequest& r) {
+ r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); });
+ }
+};
+
+#define D3D_COMPILATION_REQUEST_MEMBERS(X) \
+ X(HlslCompilationRequest, hlsl) \
+ X(D3DBytecodeCompilationRequest, bytecode) \
+ X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
+
+DAWN_MAKE_CACHE_REQUEST(D3DCompilationRequest, D3D_COMPILATION_REQUEST_MEMBERS);
+#undef HLSL_COMPILATION_REQUEST_MEMBERS
+#undef D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS
+#undef D3D_COMPILATION_REQUEST_MEMBERS
+
std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
std::vector<const wchar_t*> arguments;
if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
@@ -429,25 +251,24 @@
return arguments;
}
-ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
- IDxcCompiler* dxcCompiler,
- const ShaderCompilationRequest& request,
+ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const D3DBytecodeCompilationRequest& r,
+ const std::string& entryPointName,
const std::string& hlslSource) {
ComPtr<IDxcBlobEncoding> sourceBlob;
- DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
+ DAWN_TRY(CheckHRESULT(r.dxcLibrary->CreateBlobWithEncodingFromPinned(
hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
"DXC create blob"));
std::wstring entryPointW;
- DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName));
+ DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPointName));
std::vector<const wchar_t*> arguments =
- GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
+ GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature);
// Build defines for overridable constants
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
- defineStrings.reserve(request.defineStrings.size());
- for (const auto& [name, value] : request.defineStrings) {
+ defineStrings.reserve(r.defineStrings.size());
+ for (const auto& [name, value] : r.defineStrings) {
defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str()));
}
@@ -458,12 +279,11 @@
}
ComPtr<IDxcOperationResult> result;
- DAWN_TRY(
- CheckHRESULT(dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
- request.deviceInfo->shaderProfiles[request.stage].c_str(),
- arguments.data(), arguments.size(), dxcDefines.data(),
- dxcDefines.size(), nullptr, &result),
- "DXC compile"));
+ DAWN_TRY(CheckHRESULT(
+ r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
+ r.dxcShaderProfile.data(), arguments.data(), arguments.size(),
+ dxcDefines.data(), dxcDefines.size(), nullptr, &result),
+ "DXC compile"));
HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
@@ -543,31 +363,18 @@
return result;
}
-ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions,
- const ShaderCompilationRequest& request,
+ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const D3DBytecodeCompilationRequest& r,
+ const std::string& entryPointName,
const std::string& hlslSource) {
- const char* targetProfile = nullptr;
- switch (request.stage) {
- case SingleShaderStage::Vertex:
- targetProfile = "vs_5_1";
- break;
- case SingleShaderStage::Fragment:
- targetProfile = "ps_5_1";
- break;
- case SingleShaderStage::Compute:
- targetProfile = "cs_5_1";
- break;
- }
-
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
// Build defines for overridable constants
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
- if (request.defineStrings.size() > 0) {
- fxcDefines.reserve(request.defineStrings.size() + 1);
- for (const auto& [name, value] : request.defineStrings) {
+ if (r.defineStrings.size() > 0) {
+ fxcDefines.reserve(r.defineStrings.size() + 1);
+ for (const auto& [name, value] : r.defineStrings) {
fxcDefines.push_back({name.c_str(), value.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
@@ -575,36 +382,49 @@
pDefines = fxcDefines.data();
}
- DAWN_INVALID_IF(
- FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
- nullptr, request.entryPointName, targetProfile,
- request.compileFlags, 0, &compiledShader, &errors)),
- "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
+ DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines,
+ nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
+ r.compileFlags, 0, &compiledShader, &errors)),
+ "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
return std::move(compiledShader);
}
-ResultOrError<std::string> TranslateToHLSL(dawn::platform::Platform* platform,
- const ShaderCompilationRequest& request,
- std::string* remappedEntryPointName) {
+ResultOrError<std::string> TranslateToHLSL(
+ HlslCompilationRequest r,
+ CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*> tracePlatform,
+ std::string* remappedEntryPointName,
+ bool* usesVertexOrInstanceIndex) {
std::ostringstream errorStream;
errorStream << "Tint HLSL failure:" << std::endl;
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
- if (request.isRobustnessEnabled) {
+ if (!r.newBindingsMap.empty()) {
+ transformManager.Add<tint::transform::MultiplanarExternalTexture>();
+ transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
+ std::move(r.newBindingsMap));
+ }
+
+ if (r.stage == SingleShaderStage::Vertex) {
+ transformManager.Add<tint::transform::FirstIndexOffset>();
+ transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
+ r.firstIndexOffsetShaderRegister, r.firstIndexOffsetRegisterSpace);
+ }
+
+ if (r.isRobustnessEnabled) {
transformManager.Add<tint::transform::Robustness>();
}
transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::SingleEntryPoint>();
- transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);
+ transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
transformManager.Add<tint::transform::Renamer>();
- if (request.disableSymbolRenaming) {
+ if (r.disableSymbolRenaming) {
// We still need to rename HLSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kHlslKeywords);
@@ -615,104 +435,92 @@
// different types.
const bool mayCollide = true;
transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
- std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls),
- mayCollide);
+ std::move(r.remappedBindingPoints), std::move(r.remappedAccessControls), mayCollide);
tint::Program transformedProgram;
tint::transform::DataMap transformOutputs;
{
- TRACE_EVENT0(platform, General, "RunTransforms");
+ TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(transformedProgram,
- RunTransforms(&transformManager, request.program, transformInputs,
+ RunTransforms(&transformManager, r.inputProgram, transformInputs,
&transformOutputs, nullptr));
}
if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
- auto it = data->remappings.find(request.entryPointName);
+ auto it = data->remappings.find(r.entryPointName.data());
if (it != data->remappings.end()) {
*remappedEntryPointName = it->second;
} else {
- DAWN_INVALID_IF(!request.disableSymbolRenaming,
+ DAWN_INVALID_IF(!r.disableSymbolRenaming,
"Could not find remapped name for entry point.");
- *remappedEntryPointName = request.entryPointName;
+ *remappedEntryPointName = r.entryPointName;
}
} else {
return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
}
+ if (r.stage == SingleShaderStage::Vertex) {
+ if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
+ *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
+ } else {
+ return DAWN_FORMAT_VALIDATION_ERROR(
+ "Transform output missing first index offset data.");
+ }
+ }
+
tint::writer::hlsl::Options options;
- options.disable_workgroup_init = request.disableWorkgroupInit;
- if (request.usesNumWorkgroups) {
- options.root_constant_binding_point = tint::sem::BindingPoint{
- request.numWorkgroupsRegisterSpace, request.numWorkgroupsShaderRegister};
+ options.disable_workgroup_init = r.disableWorkgroupInit;
+ if (r.usesNumWorkgroups) {
+ options.root_constant_binding_point =
+ tint::sem::BindingPoint{r.numWorkgroupsRegisterSpace, r.numWorkgroupsShaderRegister};
}
// TODO(dawn:549): HLSL generation outputs the indices into the
// array_length_from_uniform buffer that were actually used. When the blob cache can
// store more than compiled shaders, we should reflect these used indices and store
// them as well. This would allow us to only upload root constants that are actually
// read by the shader.
- options.array_length_from_uniform = request.arrayLengthFromUniform;
- TRACE_EVENT0(platform, General, "tint::writer::hlsl::Generate");
+ options.array_length_from_uniform = r.arrayLengthFromUniform;
+ TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "tint::writer::hlsl::Generate");
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error);
return std::move(result.hlsl);
}
-template <typename F>
-MaybeError CompileShader(dawn::platform::Platform* platform,
- const PlatformFunctions* functions,
- IDxcLibrary* dxcLibrary,
- IDxcCompiler* dxcCompiler,
- ShaderCompilationRequest&& request,
- bool dumpShaders,
- F&& DumpShadersEmitLog,
- CompiledShader* compiledShader) {
+ResultOrError<CompiledShader> CompileShader(D3DCompilationRequest r) {
+ CompiledShader compiledShader;
// Compile the source shader to HLSL.
- std::string hlslSource;
std::string remappedEntryPoint;
- DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(platform, request, &remappedEntryPoint));
- if (dumpShaders) {
- std::ostringstream dumpedMsg;
- dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource;
- DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
- }
- request.entryPointName = remappedEntryPoint.c_str();
- switch (request.compiler) {
- case ShaderCompilationRequest::Compiler::DXC: {
- TRACE_EVENT0(platform, General, "CompileShaderDXC");
- DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader,
- CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
+ DAWN_TRY_ASSIGN(compiledShader.hlslSource,
+ TranslateToHLSL(std::move(r.hlsl), r.tracePlatform, &remappedEntryPoint,
+ &compiledShader.usesVertexOrInstanceIndex));
+
+ switch (r.bytecode.compiler) {
+ case Compiler::DXC: {
+ TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderDXC");
+ ComPtr<IDxcBlob> compiledDXCShader;
+ DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(r.bytecode, remappedEntryPoint,
+ compiledShader.hlslSource));
+ compiledShader.shaderBlob = CreateBlob(std::move(compiledDXCShader));
break;
}
- case ShaderCompilationRequest::Compiler::FXC: {
- TRACE_EVENT0(platform, General, "CompileShaderFXC");
- DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader,
- CompileShaderFXC(functions, request, hlslSource));
+ case Compiler::FXC: {
+ TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderFXC");
+ ComPtr<ID3DBlob> compiledFXCShader;
+ DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(r.bytecode, remappedEntryPoint,
+ compiledShader.hlslSource));
+ compiledShader.shaderBlob = CreateBlob(std::move(compiledFXCShader));
break;
}
}
- if (dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC) {
- std::ostringstream dumpedMsg;
- dumpedMsg << "/* FXC compile flags */ " << std::endl
- << CompileFlagsToStringFXC(request.compileFlags) << std::endl;
-
- dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
-
- ComPtr<ID3DBlob> disassembly;
- if (FAILED(functions->d3dDisassemble(compiledShader->compiledFXCShader->GetBufferPointer(),
- compiledShader->compiledFXCShader->GetBufferSize(), 0,
- nullptr, &disassembly))) {
- dumpedMsg << "D3D disassemble failed" << std::endl;
- } else {
- dumpedMsg << reinterpret_cast<const char*>(disassembly->GetBufferPointer());
- }
- DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
+ // If dumpShaders is false, we don't need the HLSL for logging. Clear the contents so it
+ // isn't stored into the cache.
+ if (!r.hlsl.dumpShaders) {
+ compiledShader.hlslSource = "";
}
-
- return {};
+ return compiledShader;
}
} // anonymous namespace
@@ -741,74 +549,202 @@
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t compileFlags) {
- TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleD3D12::Compile");
+ Device* device = ToBackend(GetDevice());
+ TRACE_EVENT0(device->GetPlatform(), General, "ShaderModuleD3D12::Compile");
ASSERT(!IsError());
- ScopedTintICEHandler scopedICEHandler(GetDevice());
+ ScopedTintICEHandler scopedICEHandler(device);
+ const EntryPointMetadata& entryPoint = GetEntryPoint(programmableStage.entryPoint);
- Device* device = ToBackend(GetDevice());
+ D3DCompilationRequest req = {};
+ req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform());
+ req.hlsl.shaderModel = device->GetDeviceInfo().shaderModel;
+ req.hlsl.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
+ req.hlsl.isRobustnessEnabled = device->IsRobustnessEnabled();
+ req.hlsl.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+ req.hlsl.dumpShaders = device->IsToggleEnabled(Toggle::DumpShaders);
- CompiledShader compiledShader = {};
-
- tint::transform::Manager transformManager;
- tint::transform::DataMap transformInputs;
-
- const tint::Program* program = GetTintProgram();
- tint::Program programAsValue;
-
- auto externalTextureBindings = BuildExternalTextureTransformBindings(layout);
- if (!externalTextureBindings.empty()) {
- transformManager.Add<tint::transform::MultiplanarExternalTexture>();
- transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
- std::move(externalTextureBindings));
- }
-
- if (stage == SingleShaderStage::Vertex) {
- transformManager.Add<tint::transform::FirstIndexOffset>();
- transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
- layout->GetFirstIndexOffsetShaderRegister(),
- layout->GetFirstIndexOffsetRegisterSpace());
- }
-
- tint::transform::DataMap transformOutputs;
- DAWN_TRY_ASSIGN(programAsValue, RunTransforms(&transformManager, program, transformInputs,
- &transformOutputs, nullptr));
- program = &programAsValue;
-
- if (stage == SingleShaderStage::Vertex) {
- if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
- // TODO(dawn:549): Consider adding this information to the pipeline cache once we
- // can store more than the shader blob in it.
- compiledShader.usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
+ req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
+ req.bytecode.compileFlags = compileFlags;
+ req.bytecode.defineStrings =
+ GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides);
+ if (device->IsToggleEnabled(Toggle::UseDXC)) {
+ req.bytecode.compiler = Compiler::DXC;
+ req.bytecode.dxcLibrary = device->GetDxcLibrary().Get();
+ req.bytecode.dxcCompiler = device->GetDxcCompiler().Get();
+ DAWN_TRY_ASSIGN(req.bytecode.compilerVersion,
+ ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion());
+ req.bytecode.dxcShaderProfile = device->GetDeviceInfo().shaderProfiles[stage];
+ } else {
+ req.bytecode.compiler = Compiler::FXC;
+ req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile;
+ req.bytecode.compilerVersion = D3D_COMPILER_VERSION;
+ switch (stage) {
+ case SingleShaderStage::Vertex:
+ req.bytecode.fxcShaderProfile = "vs_5_1";
+ break;
+ case SingleShaderStage::Fragment:
+ req.bytecode.fxcShaderProfile = "ps_5_1";
+ break;
+ case SingleShaderStage::Compute:
+ req.bytecode.fxcShaderProfile = "cs_5_1";
+ break;
}
}
- ShaderCompilationRequest request;
- DAWN_TRY_ASSIGN(request,
- ShaderCompilationRequest::Create(
- programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
- program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
+ using tint::transform::BindingPoint;
+ using tint::transform::BindingRemapper;
- // TODO(dawn:1341): Add shader cache key generation and caching for the compiled shader.
- DAWN_TRY(CompileShader(
- device->GetPlatform(), device->GetFunctions(),
- device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get() : nullptr,
- device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get() : nullptr,
- std::move(request), device->IsToggleEnabled(Toggle::DumpShaders),
- [&](WGPULoggingType loggingType, const char* message) {
- GetDevice()->EmitLog(loggingType, message);
- },
- &compiledShader));
- return std::move(compiledShader);
+ BindingRemapper::BindingPoints remappedBindingPoints;
+ BindingRemapper::AccessControls remappedAccessControls;
+
+ tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
+ arrayLengthFromUniform.ubo_binding = {layout->GetDynamicStorageBufferLengthsRegisterSpace(),
+ layout->GetDynamicStorageBufferLengthsShaderRegister()};
+
+ const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
+ for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+ const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+ const auto& groupBindingInfo = moduleBindingInfo[group];
+
+ // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
+ // the Tint AST to make the "bindings" decoration match the offset chosen by
+ // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
+ // assigned to each interface variable.
+ for (const auto& [binding, bindingInfo] : groupBindingInfo) {
+ BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
+ BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
+ static_cast<uint32_t>(binding)};
+ BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
+ bgl->GetShaderRegister(bindingIndex)};
+ if (srcBindingPoint != dstBindingPoint) {
+ remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+ }
+
+ // Declaring a read-only storage buffer in HLSL but specifying a storage
+ // buffer in the BGL produces the wrong output. Force read-only storage
+ // buffer bindings to be treated as UAV instead of SRV. Internal storage
+ // buffer is a storage buffer used in the internal pipeline.
+ const bool forceStorageBufferAsUAV =
+ (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
+ (bgl->GetBindingInfo(bindingIndex).buffer.type ==
+ wgpu::BufferBindingType::Storage ||
+ bgl->GetBindingInfo(bindingIndex).buffer.type == kInternalStorageBufferBinding));
+ if (forceStorageBufferAsUAV) {
+ remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
+ }
+ }
+
+ // Add arrayLengthFromUniform options
+ {
+ for (const auto& bindingAndRegisterOffset :
+ layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) {
+ BindingNumber binding = bindingAndRegisterOffset.binding;
+ uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
+
+ BindingPoint bindingPoint{static_cast<uint32_t>(group),
+ static_cast<uint32_t>(binding)};
+ // Get the renamed binding point if it was remapped.
+ auto it = remappedBindingPoints.find(bindingPoint);
+ if (it != remappedBindingPoints.end()) {
+ bindingPoint = it->second;
+ }
+
+ arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
+ registerOffset);
+ }
+ }
+ }
+
+ req.hlsl.inputProgram = GetTintProgram();
+ req.hlsl.entryPointName = programmableStage.entryPoint.c_str();
+ req.hlsl.stage = stage;
+ req.hlsl.firstIndexOffsetShaderRegister = layout->GetFirstIndexOffsetShaderRegister();
+ req.hlsl.firstIndexOffsetRegisterSpace = layout->GetFirstIndexOffsetRegisterSpace();
+ req.hlsl.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
+ req.hlsl.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
+ req.hlsl.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
+ req.hlsl.remappedBindingPoints = std::move(remappedBindingPoints);
+ req.hlsl.remappedAccessControls = std::move(remappedAccessControls);
+ req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout);
+ req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
+
+ CacheResult<CompiledShader> compiledShader;
+ DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
+ CompileShader);
+
+ if (device->IsToggleEnabled(Toggle::DumpShaders)) {
+     // Build one message holding the generated HLSL followed by the disassembled bytecode and
+     // emit it exactly once at the end. Emitting the stream after the HLSL section as well
+     // would log the HLSL twice, because the same stream keeps accumulating afterwards.
+     std::ostringstream dumpedMsg;
+     dumpedMsg << "/* Dumped generated HLSL */" << std::endl
+               << compiledShader->hlslSource << std::endl;
+
+     if (device->IsToggleEnabled(Toggle::UseDXC)) {
+         dumpedMsg << "/* Dumped disassembled DXIL */" << std::endl;
+         D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
+         ComPtr<IDxcBlobEncoding> dxcBlob;
+         ComPtr<IDxcBlobEncoding> disassembly;
+         // Wrap the bytecode in a DXC blob, then ask the DXC compiler to disassemble it.
+         if (FAILED(device->GetDxcLibrary()->CreateBlobWithEncodingFromPinned(
+                 code.pShaderBytecode, code.BytecodeLength, 0, &dxcBlob)) ||
+             FAILED(device->GetDxcCompiler()->Disassemble(dxcBlob.Get(), &disassembly))) {
+             dumpedMsg << "DXC disassemble failed" << std::endl;
+         } else {
+             // Append with an explicit length instead of relying on a null terminator.
+             dumpedMsg << std::string_view(
+                 static_cast<const char*>(disassembly->GetBufferPointer()),
+                 disassembly->GetBufferSize());
+         }
+     } else {
+         dumpedMsg << "/* FXC compile flags */ " << std::endl
+                   << CompileFlagsToStringFXC(compileFlags) << std::endl;
+         dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
+         ComPtr<ID3DBlob> disassembly;
+         D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
+         if (FAILED(device->GetFunctions()->d3dDisassemble(
+                 code.pShaderBytecode, code.BytecodeLength, 0, nullptr, &disassembly))) {
+             dumpedMsg << "D3D disassemble failed" << std::endl;
+         } else {
+             // Append with an explicit length instead of relying on a null terminator.
+             dumpedMsg << std::string_view(
+                 static_cast<const char*>(disassembly->GetBufferPointer()),
+                 disassembly->GetBufferSize());
+         }
+     }
+     device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
+ }
+
+ if (BlobCache* cache = device->GetBlobCache()) {
+ cache->EnsureStored(compiledShader);
+ }
+
+ // Clear the hlslSource. It is only used for logging and should not be used
+ // outside of the compilation.
+ CompiledShader result = compiledShader.Acquire();
+ result.hlslSource = "";
+ return result;
}
D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
- if (compiledFXCShader != nullptr) {
- return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
- } else if (compiledDXCShader != nullptr) {
- return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
- }
- UNREACHABLE();
- return {};
+    // Describe the bytes held by `shaderBlob` as a D3D12 {pointer, length} bytecode pair.
+    D3D12_SHADER_BYTECODE bytecode = {};
+    bytecode.pShaderBytecode = shaderBlob.Data();
+    bytecode.BytecodeLength = shaderBlob.Size();
+    return bytecode;
}
+
} // namespace dawn::native::d3d12
+
+namespace dawn::native {
+
+// BlobCache::Store specialization for d3d12::CompiledShader: every visitable member of `c` is
+// streamed into a byte sink, and the resulting blob is stored under `key`.
+template <>
+void BlobCache::Store<d3d12::CompiledShader>(const CacheKey& key, const d3d12::CompiledShader& c) {
+    stream::ByteVectorSink serialized;
+    auto writeMembers = [&serialized](const auto&... members) {
+        StreamIn(&serialized, members...);
+    };
+    c.VisitAll(writeMembers);
+    Store(key, CreateBlob(std::move(serialized)));
+}
+
+// Rebuilds a d3d12::CompiledShader by reading each of its visitable members back out of `blob`.
+// If reading the members fails, DAWN_TRY returns that error to the caller.
+// static
+ResultOrError<d3d12::CompiledShader> d3d12::CompiledShader::FromBlob(Blob blob) {
+    stream::BlobSource reader(std::move(blob));
+    d3d12::CompiledShader shader;
+    auto readMembers = [&reader](auto&... members) { return StreamOut(&reader, &members...); };
+    DAWN_TRY(shader.VisitAll(readMembers));
+    return shader;
+}
+
+} // namespace dawn::native
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h
index 7f68b10..528e1e4 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h
@@ -15,8 +15,11 @@
#ifndef SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_
#define SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_
-#include "dawn/native/ShaderModule.h"
+#include <string>
+#include "dawn/native/Blob.h"
+#include "dawn/native/ShaderModule.h"
+#include "dawn/native/VisitableMembers.h"
#include "dawn/native/d3d12/d3d12_platform.h"
namespace dawn::native {
@@ -28,14 +31,22 @@
class Device;
class PipelineLayout;
-// Manages a ref to one of the various representations of shader blobs and information used to
-// emulate vertex/instance index starts
+// X-macro listing the members of `CompiledShader` as X(type, name) entries. It is handed to
+// DAWN_VISITABLE_MEMBERS inside the struct and #undef'd right after that use.
+#define COMPILED_SHADER_MEMBERS(X) \
+ X(Blob, shaderBlob) \
+ X(std::string, hlslSource) \
+ X(bool, usesVertexOrInstanceIndex)
+
+// `CompiledShader` holds the compiled shader bytecode in `shaderBlob` and, in
+// `usesVertexOrInstanceIndex`, information used to emulate vertex/instance index starts. It also
+// holds the `hlslSource` for the shader compilation, which is only transiently available during
+// Compile, and cleared before it returns. `hlslSource` is not written to or loaded from the cache
+// unless Toggle dump_shaders is true.
struct CompiledShader {
- ComPtr<ID3DBlob> compiledFXCShader;
- ComPtr<IDxcBlob> compiledDXCShader;
+ // Deserializes a `CompiledShader` from `blob`.
+ static ResultOrError<CompiledShader> FromBlob(Blob blob);
+
+ // Returns the data pointer and size of `shaderBlob` as a D3D12_SHADER_BYTECODE.
D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
- bool usesVertexOrInstanceIndex;
+ DAWN_VISITABLE_MEMBERS(COMPILED_SHADER_MEMBERS)
+#undef COMPILED_SHADER_MEMBERS
};
class ShaderModule final : public ShaderModuleBase {
diff --git a/src/dawn/native/d3d12/UtilsD3D12.cpp b/src/dawn/native/d3d12/UtilsD3D12.cpp
index 0e761f8..e706d8f 100644
--- a/src/dawn/native/d3d12/UtilsD3D12.cpp
+++ b/src/dawn/native/d3d12/UtilsD3D12.cpp
@@ -81,19 +81,19 @@
} // anonymous namespace
-ResultOrError<std::wstring> ConvertStringToWstring(const char* str) {
- size_t len = strlen(str);
+ResultOrError<std::wstring> ConvertStringToWstring(std::string_view s) {
+ size_t len = s.length();
if (len == 0) {
return std::wstring();
}
- int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, nullptr, 0);
+ int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, nullptr, 0);
if (numChars == 0) {
return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
}
std::wstring result;
result.resize(numChars);
int numConvertedChars =
- MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, &result[0], numChars);
+ MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, &result[0], numChars);
if (numConvertedChars != numChars) {
return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
}
diff --git a/src/dawn/native/d3d12/UtilsD3D12.h b/src/dawn/native/d3d12/UtilsD3D12.h
index 1418f54..dcbe782 100644
--- a/src/dawn/native/d3d12/UtilsD3D12.h
+++ b/src/dawn/native/d3d12/UtilsD3D12.h
@@ -26,7 +26,7 @@
namespace dawn::native::d3d12 {
-ResultOrError<std::wstring> ConvertStringToWstring(const char* str);
+ResultOrError<std::wstring> ConvertStringToWstring(std::string_view s);
D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func);
diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp
index 1ca241c..beb3823 100644
--- a/src/dawn/native/stream/Stream.cpp
+++ b/src/dawn/native/stream/Stream.cpp
@@ -48,4 +48,14 @@
}
}
+// Serializes a std::wstring_view as its character count followed by its raw wchar_t contents.
+// Nothing beyond the count is written for an empty view.
+template <>
+void Stream<std::wstring_view>::Write(Sink* s, const std::wstring_view& t) {
+    const size_t charCount = t.length();
+    StreamIn(s, charCount);
+
+    const size_t byteCount = charCount * sizeof(wchar_t);
+    if (byteCount == 0) {
+        return;
+    }
+    memcpy(s->GetSpace(byteCount), t.data(), byteCount);
+}
+
} // namespace dawn::native::stream
diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h
index 3433317..d077cccc 100644
--- a/src/dawn/native/stream/Stream.h
+++ b/src/dawn/native/stream/Stream.h
@@ -297,10 +297,9 @@
public:
static void Write(stream::Sink* sink, const std::unordered_map<K, V>& m) {
std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
- std::sort(ordered.begin(), ordered.end(),
- [](const std::pair<K, V>& a, const std::pair<K, V>& b) {
- return std::less<K>{}(a.first, b.first);
- });
+ // Sort the copied entries by key (compared with operator<) so that they are streamed in key
+ // order rather than in the map's iteration order.
+ std::sort(
+ ordered.begin(), ordered.end(),
+ [](const std::pair<K, V>& a, const std::pair<K, V>& b) { return a.first < b.first; });
StreamIn(sink, ordered);
}
};
diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp
index 5dcf918..cbf5a92 100644
--- a/src/dawn/tests/end2end/PipelineCachingTests.cpp
+++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp
@@ -108,8 +108,8 @@
const EntryCounts counts = {
// pipeline caching is only implemented on D3D12/Vulkan
IsD3D12() || IsVulkan() ? 1u : 0u,
- // shader module caching is only implemented on Vulkan/Metal
- IsVulkan() || IsMetal() ? 1u : 0u,
+ // shader module caching is only implemented on Vulkan/D3D12/Metal
+ IsVulkan() || IsMetal() || IsD3D12() ? 1u : 0u,
};
NiceMock<CachingInterfaceMock> mMockCache;
};
@@ -646,6 +646,7 @@
DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests,
D3D12Backend({"enable_blob_cache"}),
+ D3D12Backend({"enable_blob_cache", "use_dxc"}),
MetalBackend({"enable_blob_cache"}),
OpenGLBackend({"enable_blob_cache"}),
OpenGLESBackend({"enable_blob_cache"}),
diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp
index 6f25700..a1196f3 100644
--- a/src/dawn/tests/unittests/native/StreamTests.cpp
+++ b/src/dawn/tests/unittests/native/StreamTests.cpp
@@ -174,6 +174,30 @@
EXPECT_CACHE_KEY_EQ(str, expected);
}
+// Check that a std::wstring_view is serialized by ByteVectorSink as its length followed by its
+// raw wchar_t contents.
+TEST(SerializeTests, StdWStringViews) {
+    static constexpr std::wstring_view str(L"Hello world!");
+    const size_t charCount = str.length();
+    const size_t byteCount = charCount * sizeof(wchar_t);
+
+    // Build the expected bytes by hand: the character count, then the characters.
+    ByteVectorSink expected;
+    StreamIn(&expected, charCount);
+    void* dst = expected.GetSpace(byteCount);
+    memcpy(dst, str.data(), byteCount);
+
+    EXPECT_CACHE_KEY_EQ(str, expected);
+}
+
+// Check that a Blob is serialized by ByteVectorSink as its size followed by its bytes.
+TEST(SerializeTests, Blob) {
+    // sizeof(data) counts the string literal's terminating null, so it is part of the blob.
+    uint8_t data[] = "dawn native Blob";
+    const size_t dataSize = sizeof(data);
+    // `data` is a local array, so the blob's deleter has nothing to do.
+    Blob blob = Blob::UnsafeCreateWithDeleter(data, dataSize, []() {});
+
+    ByteVectorSink expected;
+    StreamIn(&expected, dataSize);
+    expected.insert(expected.end(), &data[0], &data[0] + dataSize);
+
+    EXPECT_CACHE_KEY_EQ(blob, expected);
+}
+
// Test that ByteVectorSink serializes other ByteVectorSinks as expected.
TEST(SerializeTests, ByteVectorSinks) {
ByteVectorSink data = {'d', 'a', 't', 'a'};
@@ -309,6 +333,41 @@
}
}
+// Round-trip Blobs through the stream: serializing and then deserializing must give back a blob
+// with the same size and contents.
+// Tested here instead of in the type-parameterized tests since Blobs are not copyable.
+TEST(StreamTests, SerializeDeserializeBlobs) {
+    // Streams `blob` into a sink, reads it back from that sink's bytes, and expects the result
+    // to match `blob` in size and contents.
+    auto expectRoundTrip = [](const Blob& blob) {
+        ByteVectorSink sink;
+        StreamIn(&sink, blob);
+
+        BlobSource src(CreateBlob(sink));
+        Blob out;
+        auto err = StreamOut(&src, &out);
+        EXPECT_FALSE(err.IsError());
+        EXPECT_EQ(blob.Size(), out.Size());
+        EXPECT_EQ(memcmp(blob.Data(), out.Data(), blob.Size()), 0);
+    };
+
+    // Test an empty blob
+    {
+        Blob blob;
+        EXPECT_EQ(blob.Size(), 0u);
+        expectRoundTrip(blob);
+    }
+
+    // Test a blob with some data
+    expectRoundTrip(CreateBlob(std::vector<double>{6.24, 3.12222}));
+}
+
template <size_t N>
std::bitset<N - 1> BitsetFromBitString(const char (&str)[N]) {
// N - 1 because the last character is the null terminator.