D3D12: implement WGSL pipeline cache key generation
Since spirv_to_dxil does not generate HLSL, pipeline caching cannot be
keyed on the HLSL source; generate the cache key from the WGSL instead.
A new type, ShaderCompilationRequest, is added to isolate the
compilation inputs to help ensure that the cache key contains all
relevant information.
Bug: dawn:1103
Change-Id: Ic2f09326dc3ac254cecf35098dcfe95aa396796f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/61160
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Michael Tang <tangm@microsoft.com>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 718e0c1..d601325 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -28,10 +28,227 @@
#include <d3dcompiler.h>
#include <tint/tint.h>
+#include <map>
+#include <sstream>
+#include <unordered_map>
namespace dawn_native { namespace d3d12 {
namespace {
+ ResultOrError<uint64_t> GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator) {
+ // Packs DXC's version as (major << 32) + minor so it fits in one cache-key integer.
+ ComPtr<IDxcVersionInfo> info;
+ DAWN_TRY(CheckHRESULT(dxcValidator.As(&info),
+ "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+ DAWN_TRY(CheckHRESULT(info->GetVersion(&majorVersion, &minorVersion),
+ "IDxcVersionInfo::GetVersion"));
+ const uint64_t highBits = static_cast<uint64_t>(majorVersion) << 32;
+ return highBits + minorVersion;
+ }
+
+ uint64_t GetD3DCompilerVersion() {
+ return static_cast<uint64_t>(D3D_COMPILER_VERSION);  // Fixed at build time by d3dcompiler.h.
+ }
+
+ // Orders BindingPoints by group first, then by binding within the same group.
+ struct CompareBindingPoint {
+ constexpr bool operator()(const tint::transform::BindingPoint& lhs,
+ const tint::transform::BindingPoint& rhs) const {
+ if (lhs.group == rhs.group) {
+ return lhs.binding < rhs.binding;
+ }
+ return lhs.group < rhs.group;
+ }
+ };
+
+ // Appends the textual form of an access mode using tint's operator<<.
+ void Serialize(std::stringstream& output, const tint::ast::Access& access) { output << access; }
+
+ // Appends "(BindingPoint group=G binding=B)".
+ void Serialize(std::stringstream& output,
+ const tint::transform::BindingPoint& binding_point) {
+ output << "(BindingPoint"
+ << " group=" << binding_point.group
+ << " binding=" << binding_point.binding << ")";
+ }
+
+ // Appends "(map K1=V1 K2=V2 ...)". Entries are first copied into a std::map ordered by
+ // CompareBindingPoint so the text does not depend on unordered_map iteration order.
+ template <typename T>
+ void Serialize(std::stringstream& output,
+ const std::unordered_map<tint::transform::BindingPoint, T>& map) {
+ using SortedMap = std::map<tint::transform::BindingPoint, T, CompareBindingPoint>;
+ const SortedMap sorted(map.begin(), map.end());
+ output << "(map";
+ for (const auto& entry : sorted) {
+ output << " ";
+ Serialize(output, entry.first);
+ output << "=";
+ Serialize(output, entry.second);
+ }
+ output << ")";
+ }
+
+ // The inputs to a shader compilation. These have been intentionally isolated from the
+ // device to help ensure that the pipeline cache key contains all inputs for compilation.
+ struct ShaderCompilationRequest {
+ enum Compiler { FXC, DXC };
+
+ // Common inputs (defaulted so a default-constructed request has no indeterminate fields)
+ Compiler compiler = Compiler::FXC;
+ const tint::Program* program = nullptr;
+ const char* entryPointName = nullptr;
+ SingleShaderStage stage = SingleShaderStage::Vertex;
+ uint32_t compileFlags = 0;
+ bool disableSymbolRenaming = false;
+ tint::transform::BindingRemapper::BindingPoints bindingPoints;
+ tint::transform::BindingRemapper::AccessControls accessControls;
+ bool isRobustnessEnabled = false;
+
+ // FXC/DXC common inputs
+ bool disableWorkgroupInit = false;
+
+ // FXC inputs
+ uint64_t fxcVersion = 0;
+
+ // DXC inputs
+ uint64_t dxcVersion = 0;
+ const D3D12DeviceInfo* deviceInfo = nullptr;  // Not owned; set to &device->GetDeviceInfo().
+ bool hasShaderFloat16Extension = false;
+
+ static ResultOrError<ShaderCompilationRequest> Create(
+ const char* entryPointName,
+ SingleShaderStage stage,
+ const PipelineLayout* layout,
+ uint32_t compileFlags,
+ const Device* device,
+ const tint::Program* program,
+ const BindingInfoArray& moduleBindingInfo) {
+ Compiler compiler;
+ uint64_t dxcVersion = 0;
+ if (device->IsToggleEnabled(Toggle::UseDXC)) {
+ compiler = Compiler::DXC;
+ DAWN_TRY_ASSIGN(dxcVersion, GetDXCompilerVersion(device->GetDxcValidator()));
+ } else {
+ compiler = Compiler::FXC;
+ }
+
+ using tint::transform::BindingPoint;
+ using tint::transform::BindingRemapper;
+
+ BindingRemapper::BindingPoints bindingPoints;
+ BindingRemapper::AccessControls accessControls;
+
+ // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify the
+ // Tint AST to make the "bindings" decoration match the offset chosen by
+ // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
+ // assigned to each interface variable.
+ for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+ const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+ const auto& groupBindingInfo = moduleBindingInfo[group];
+ for (const auto& it : groupBindingInfo) {
+ BindingNumber binding = it.first;
+ auto const& bindingInfo = it.second;
+ BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
+ BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
+ static_cast<uint32_t>(binding)};
+ BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
+ bgl->GetShaderRegister(bindingIndex)};
+ if (srcBindingPoint != dstBindingPoint) {
+ bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+ }
+
+ // Declaring a read-only storage buffer in HLSL but specifying a storage
+ // buffer in the BGL produces the wrong output. Force read-only storage
+ // buffer bindings to be treated as UAV instead of SRV. Internal storage
+ // buffer is a storage buffer used in the internal pipeline.
+ const wgpu::BufferBindingType layoutBufferType =
+ bgl->GetBindingInfo(bindingIndex).buffer.type;
+ const bool forceStorageBufferAsUAV =
+ bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
+ (layoutBufferType == wgpu::BufferBindingType::Storage ||
+ layoutBufferType == kInternalStorageBufferBinding);
+ if (forceStorageBufferAsUAV) {
+ accessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
+ }
+ }
+ }
+
+ ShaderCompilationRequest request;
+ request.compiler = compiler;
+ request.program = program;
+ request.entryPointName = entryPointName;
+ request.stage = stage;
+ request.compileFlags = compileFlags;
+ request.disableSymbolRenaming =
+ device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
+ request.bindingPoints = std::move(bindingPoints);
+ request.accessControls = std::move(accessControls);
+ request.isRobustnessEnabled = device->IsRobustnessEnabled();
+ request.disableWorkgroupInit =
+ device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+ request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
+ request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
+ request.deviceInfo = &device->GetDeviceInfo();
+ request.hasShaderFloat16Extension =
+ device->IsExtensionEnabled(Extension::ShaderFloat16);
+ return std::move(request);
+ }
+
+ ResultOrError<PersistentCacheKey> CreateCacheKey() const {
+ // Generate the WGSL from the Tint program so it's normalized.
+ // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
+ // compact representation.
+ auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
+ if (!result.success) {
+ std::ostringstream errorStream;
+ errorStream << "Tint WGSL failure:\n";
+ errorStream << "Generator: " << result.error << "\n";
+ return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
+ }
+
+ std::stringstream stream;
+
+ // Prefix the key with the type to avoid collisions from another type that could
+ // have the same key.
+ stream << static_cast<uint32_t>(PersistentKeyType::Shader);
+ stream << "\n";
+
+ stream << result.wgsl.length();  // Length prefix delimits the WGSL text below.
+ stream << "\n";
+
+ stream << result.wgsl;
+ stream << "\n";
+
+ stream << "(ShaderCompilationRequest";
+ stream << " compiler=" << compiler;
+ stream << " entryPointName=" << entryPointName;
+ stream << " stage=" << uint32_t(stage);
+ stream << " compileFlags=" << compileFlags;
+ stream << " disableSymbolRenaming=" << disableSymbolRenaming;
+
+ stream << " bindingPoints=";
+ Serialize(stream, bindingPoints);
+
+ stream << " accessControls=";
+ Serialize(stream, accessControls);
+
+ stream << " shaderModel=" << deviceInfo->shaderModel;
+ stream << " disableWorkgroupInit=" << disableWorkgroupInit;
+ stream << " isRobustnessEnabled=" << isRobustnessEnabled;
+ stream << " fxcVersion=" << fxcVersion;
+ stream << " dxcVersion=" << dxcVersion;
+ stream << " hasShaderFloat16Extension=" << hasShaderFloat16Extension;
+ stream << ")";
+ stream << "\n";
+
+ return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
+ std::istreambuf_iterator<char>{});
+ }
+ };
+
std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
std::vector<const wchar_t*> arguments;
if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
@@ -83,85 +300,173 @@
return arguments;
}
+ // Compiles HLSL with DXC, taking the entry point, flags and target profile from request.
+ ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
+ IDxcCompiler* dxcCompiler,
+ const ShaderCompilationRequest& request,
+ const std::string& hlslSource) {
+ // DXC takes the source as a UTF-8 blob and the entry point as a wide string.
+ ComPtr<IDxcBlobEncoding> hlslBlob;
+ DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
+ hlslSource.c_str(), hlslSource.length(), CP_UTF8, &hlslBlob),
+ "DXC create blob"));
+
+ std::wstring wideEntryPoint;
+ DAWN_TRY_ASSIGN(wideEntryPoint, ConvertStringToWstring(request.entryPointName));
+
+ const wchar_t* targetProfile = request.deviceInfo->shaderProfiles[request.stage].c_str();
+ std::vector<const wchar_t*> dxcArgs =
+ GetDXCArguments(request.compileFlags, request.hasShaderFloat16Extension);
+
+ ComPtr<IDxcOperationResult> operation;
+ DAWN_TRY(CheckHRESULT(dxcCompiler->Compile(hlslBlob.Get(), nullptr, wideEntryPoint.c_str(),
+ targetProfile, dxcArgs.data(), dxcArgs.size(),
+ nullptr, 0, nullptr, &operation),
+ "DXC compile"));
+
+ // A successful Compile() call still reports compile errors through GetStatus().
+ HRESULT status;
+ DAWN_TRY(CheckHRESULT(operation->GetStatus(&status), "DXC get status"));
+
+ if (FAILED(status)) {
+ ComPtr<IDxcBlobEncoding> errorBlob;
+ DAWN_TRY(CheckHRESULT(operation->GetErrorBuffer(&errorBlob), "DXC get error buffer"));
+ const char* errorText = static_cast<char*>(errorBlob->GetBufferPointer());
+ std::string message = std::string("DXC compile failed with ") + errorText;
+ return DAWN_VALIDATION_ERROR(message);
+ }
+
+ ComPtr<IDxcBlob> dxil;
+ DAWN_TRY(CheckHRESULT(operation->GetResult(&dxil), "DXC get result"));
+ return std::move(dxil);
+ }
+
+ ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions,
+ const ShaderCompilationRequest& request,
+ const std::string& hlslSource) {
+ // FXC targets shader model 5.1; pick the profile matching the requested stage.
+ const char* targetProfile = nullptr;
+ switch (request.stage) {
+ case SingleShaderStage::Vertex:
+ targetProfile = "vs_5_1";
+ break;
+ case SingleShaderStage::Fragment:
+ targetProfile = "ps_5_1";
+ break;
+ case SingleShaderStage::Compute:
+ targetProfile = "cs_5_1";
+ break;
+ }
+ ComPtr<ID3DBlob> compiledShader;
+ ComPtr<ID3DBlob> errors;
+ // The error blob is an optional output and may be null on failure; it is checked below.
+ if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr,
+ nullptr, nullptr, request.entryPointName,
+ targetProfile, request.compileFlags, 0,
+ &compiledShader, &errors))) {
+ const char* errorText = errors.Get() != nullptr
+ ? static_cast<const char*>(errors->GetBufferPointer()) : "(no error message)";
+ std::string message = std::string("D3D compile failed with ") + errorText;
+ return DAWN_VALIDATION_ERROR(message);
+ }
+ return std::move(compiledShader);
+ }
+
+ ResultOrError<std::string> TranslateToHLSL(const ShaderCompilationRequest& request,
+ std::string* remappedEntryPointName) {
+ std::ostringstream errorStream;
+ errorStream << "Tint HLSL failure:\n";
+ // Transforms: optional bounds checks, then binding remapping, then symbol renaming.
+ tint::transform::Manager transformManager;
+ tint::transform::DataMap transformInputs;
+
+ if (request.isRobustnessEnabled) {
+ transformManager.Add<tint::transform::BoundArrayAccessors>();
+ }
+ transformManager.Add<tint::transform::BindingRemapper>();
+
+ transformManager.Add<tint::transform::Renamer>();
+
+ if (request.disableSymbolRenaming) {
+ // We still need to rename HLSL reserved keywords
+ transformInputs.Add<tint::transform::Renamer::Config>(
+ tint::transform::Renamer::Target::kHlslKeywords);
+ }
+
+ // D3D12 registers like `t3` and `c3` have the same bindingOffset number in
+ // the remapping but should not be considered a collision because they have
+ // different types.
+ const bool mayCollide = true;
+ transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
+ request.bindingPoints, request.accessControls, mayCollide);  // Copies: request is const.
+
+ tint::Program transformedProgram;
+ tint::transform::DataMap transformOutputs;
+ DAWN_TRY_ASSIGN(transformedProgram,
+ RunTransforms(&transformManager, request.program, transformInputs,
+ &transformOutputs, nullptr));
+
+ if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
+ auto it = data->remappings.find(request.entryPointName);
+ if (it != data->remappings.end()) {
+ *remappedEntryPointName = it->second;
+ } else {
+ if (request.disableSymbolRenaming) {
+ *remappedEntryPointName = request.entryPointName;
+ } else {
+ return DAWN_VALIDATION_ERROR(
+ "Could not find remapped name for entry point.");
+ }
+ }
+ } else {
+ return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
+ }
+
+ tint::writer::hlsl::Options options;
+ options.disable_workgroup_init = request.disableWorkgroupInit;
+ auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
+ if (!result.success) {
+ errorStream << "Generator: " << result.error << "\n";
+ return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+ }
+
+ return std::move(result.hlsl);
+ }
+
+ // Translates `request` to HLSL, optionally logs the HLSL, then compiles it with the
+ // compiler selected by request.compiler, storing the blob in *compiledShader.
+ template <typename F>
+ MaybeError CompileShader(const PlatformFunctions* functions,
+ IDxcLibrary* dxcLibrary,
+ IDxcCompiler* dxcCompiler,
+ ShaderCompilationRequest&& request,
+ bool dumpShaders,
+ F&& DumpShadersEmitLog,
+ CompiledShader* compiledShader) {
+ std::string remappedEntryPoint;
+ std::string hlslSource;
+ DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(request, &remappedEntryPoint));
+
+ if (dumpShaders) {
+ std::ostringstream dumpedMsg;
+ dumpedMsg << "/* Dumped generated HLSL */" << '\n' << hlslSource;
+ DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
+ }
+
+ // The backend compilers must be given the (possibly renamed) entry point from Tint.
+ request.entryPointName = remappedEntryPoint.c_str();
+ if (request.compiler == ShaderCompilationRequest::Compiler::DXC) {
+ DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader,
+ CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
+ } else if (request.compiler == ShaderCompilationRequest::Compiler::FXC) {
+ DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader,
+ CompileShaderFXC(functions, request, hlslSource));
+ }
+ return {};
+ }
+
} // anonymous namespace
- ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
- SingleShaderStage stage,
- const std::string& hlslSource,
- const char* entryPoint,
- uint32_t compileFlags) {
- ComPtr<IDxcLibrary> dxcLibrary = device->GetDxcLibrary();
-
- ComPtr<IDxcBlobEncoding> sourceBlob;
- DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
- hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
- "DXC create blob"));
-
- ComPtr<IDxcCompiler> dxcCompiler = device->GetDxcCompiler();
-
- std::wstring entryPointW;
- DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
-
- std::vector<const wchar_t*> arguments =
- GetDXCArguments(compileFlags, device->IsExtensionEnabled(Extension::ShaderFloat16));
-
- ComPtr<IDxcOperationResult> result;
- DAWN_TRY(CheckHRESULT(
- dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
- device->GetDeviceInfo().shaderProfiles[stage].c_str(),
- arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
- "DXC compile"));
-
- HRESULT hr;
- DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
-
- if (FAILED(hr)) {
- ComPtr<IDxcBlobEncoding> errors;
- DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
-
- std::string message = std::string("DXC compile failed with ") +
- static_cast<char*>(errors->GetBufferPointer());
- return DAWN_VALIDATION_ERROR(message);
- }
-
- ComPtr<IDxcBlob> compiledShader;
- DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
- return std::move(compiledShader);
- }
-
- ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
- SingleShaderStage stage,
- const std::string& hlslSource,
- const char* entryPoint,
- uint32_t compileFlags) {
- const char* targetProfile = nullptr;
- switch (stage) {
- case SingleShaderStage::Vertex:
- targetProfile = "vs_5_1";
- break;
- case SingleShaderStage::Fragment:
- targetProfile = "ps_5_1";
- break;
- case SingleShaderStage::Compute:
- targetProfile = "cs_5_1";
- break;
- }
-
- ComPtr<ID3DBlob> compiledShader;
- ComPtr<ID3DBlob> errors;
-
- const PlatformFunctions* functions = device->GetFunctions();
- if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
- nullptr, entryPoint, targetProfile, compileFlags, 0,
- &compiledShader, &errors))) {
- std::string message = std::string("D3D compile failed with ") +
- static_cast<char*>(errors->GetBufferPointer());
- return DAWN_VALIDATION_ERROR(message);
- }
-
- return std::move(compiledShader);
- }
-
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
@@ -180,120 +485,13 @@
return InitializeBase(parseResult);
}
- ResultOrError<std::string> ShaderModule::TranslateToHLSL(
- const tint::Program* program,
- const char* entryPointName,
- SingleShaderStage stage,
- PipelineLayout* layout,
- std::string* remappedEntryPointName) const {
- ASSERT(!IsError());
-
- ScopedTintICEHandler scopedICEHandler(GetDevice());
-
- using BindingRemapper = tint::transform::BindingRemapper;
- using BindingPoint = tint::transform::BindingPoint;
- BindingRemapper::BindingPoints bindingPoints;
- BindingRemapper::AccessControls accessControls;
-
- const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
-
- // d3d12::BindGroupLayout packs the bindings per HLSL register-space.
- // We modify the Tint AST to make the "bindings" decoration match the
- // offset chosen by d3d12::BindGroupLayout so that Tint produces HLSL
- // with the correct registers assigned to each interface variable.
- for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
- const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
- const auto& groupBindingInfo = moduleBindingInfo[group];
- for (const auto& it : groupBindingInfo) {
- BindingNumber binding = it.first;
- auto const& bindingInfo = it.second;
- BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
- BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
- static_cast<uint32_t>(binding)};
- BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
- bgl->GetShaderRegister(bindingIndex)};
- if (srcBindingPoint != dstBindingPoint) {
- bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
- }
-
- // Declaring a read-only storage buffer in HLSL but specifying a
- // storage buffer in the BGL produces the wrong output.
- // Force read-only storage buffer bindings to be treated as UAV
- // instead of SRV.
- // Internal storage buffer is a storage buffer used in the internal pipeline.
- const bool forceStorageBufferAsUAV =
- (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
- (bgl->GetBindingInfo(bindingIndex).buffer.type ==
- wgpu::BufferBindingType::Storage ||
- bgl->GetBindingInfo(bindingIndex).buffer.type ==
- kInternalStorageBufferBinding));
- if (forceStorageBufferAsUAV) {
- accessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
- }
- }
- }
-
- std::ostringstream errorStream;
- errorStream << "Tint HLSL failure:" << std::endl;
-
- tint::transform::Manager transformManager;
- tint::transform::DataMap transformInputs;
-
- if (GetDevice()->IsRobustnessEnabled()) {
- transformManager.Add<tint::transform::BoundArrayAccessors>();
- }
- transformManager.Add<tint::transform::BindingRemapper>();
-
- transformManager.Add<tint::transform::Renamer>();
-
- if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
- // We still need to rename HLSL reserved keywords
- transformInputs.Add<tint::transform::Renamer::Config>(
- tint::transform::Renamer::Target::kHlslKeywords);
- }
-
- // D3D12 registers like `t3` and `c3` have the same bindingOffset number in the
- // remapping but should not be considered a collision because they have different types.
- const bool mayCollide = true;
- transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
- std::move(accessControls), mayCollide);
-
- tint::Program transformedProgram;
- tint::transform::DataMap transformOutputs;
- DAWN_TRY_ASSIGN(
- transformedProgram,
- RunTransforms(&transformManager, program, transformInputs, &transformOutputs, nullptr));
-
- if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
- auto it = data->remappings.find(entryPointName);
- if (it != data->remappings.end()) {
- *remappedEntryPointName = it->second;
- } else {
- if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
- *remappedEntryPointName = entryPointName;
- } else {
- return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
- }
- }
- } else {
- return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
- }
-
- tint::writer::hlsl::Options options;
- options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
- auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
- if (!result.success) {
- errorStream << "Generator: " << result.error << std::endl;
- return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
- }
-
- return std::move(result.hlsl);
- }
-
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
SingleShaderStage stage,
PipelineLayout* layout,
uint32_t compileFlags) {
+ ASSERT(!IsError());
+ ScopedTintICEHandler scopedICEHandler(GetDevice());
+
Device* device = ToBackend(GetDevice());
CompiledShader compiledShader = {};
@@ -333,47 +531,33 @@
program = GetTintProgram();
}
- // Compile the source shader to HLSL.
- std::string hlslSource;
- std::string remappedEntryPoint;
- DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(program, entryPointName, stage, layout,
- &remappedEntryPoint));
- entryPointName = remappedEntryPoint.c_str();
+ ShaderCompilationRequest request;
+ DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
+ entryPointName, stage, layout, compileFlags, device, program,
+ GetEntryPoint(entryPointName).bindings));
- if (device->IsToggleEnabled(Toggle::DumpShaders)) {
- std::ostringstream dumpedMsg;
- dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource;
- GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
- }
-
- // Use HLSL source as the input for the key since it does need to know about the pipeline
- // layout. The pipeline layout is only required if we key from WGSL: two different pipeline
- // layouts could be used to produce different shader blobs and the wrong shader blob could
- // be loaded since the pipeline layout was missing from the key.
- // The compiler flags or version used could also produce different HLSL source. HLSL key
- // needs both to ensure the shader cache key is unique to the HLSL source.
- // TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used.
PersistentCacheKey shaderCacheKey;
- DAWN_TRY_ASSIGN(shaderCacheKey,
- CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags));
+ DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
- DAWN_TRY_ASSIGN(compiledShader.cachedShader,
- device->GetPersistentCache()->GetOrCreate(
- shaderCacheKey, [&](auto doCache) -> MaybeError {
- if (device->IsToggleEnabled(Toggle::UseDXC)) {
- DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
- CompileShaderDXC(device, stage, hlslSource,
- entryPointName, compileFlags));
- } else {
- DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
- CompileShaderFXC(device, stage, hlslSource,
- entryPointName, compileFlags));
- }
- const D3D12_SHADER_BYTECODE shader =
- compiledShader.GetD3D12ShaderBytecode();
- doCache(shader.pShaderBytecode, shader.BytecodeLength);
- return {};
- }));
+ DAWN_TRY_ASSIGN(
+ compiledShader.cachedShader,
+ device->GetPersistentCache()->GetOrCreate(
+ shaderCacheKey, [&](auto doCache) -> MaybeError {
+ DAWN_TRY(CompileShader(
+ device->GetFunctions(),
+ device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get()
+ : nullptr,
+ device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get()
+ : nullptr,
+ std::move(request), device->IsToggleEnabled(Toggle::DumpShaders),
+ [&](WGPULoggingType loggingType, const char* message) {
+ GetDevice()->EmitLog(loggingType, message);
+ },
+ &compiledShader));
+ const D3D12_SHADER_BYTECODE shader = compiledShader.GetD3D12ShaderBytecode();
+ doCache(shader.pShaderBytecode, shader.BytecodeLength);
+ return {};
+ }));
return std::move(compiledShader);
}
@@ -389,69 +573,4 @@
UNREACHABLE();
return {};
}
-
- ResultOrError<PersistentCacheKey> ShaderModule::CreateHLSLKey(const char* entryPointName,
- SingleShaderStage stage,
- const std::string& hlslSource,
- uint32_t compileFlags) const {
- std::stringstream stream;
-
- // Prefix the key with the type to avoid collisions from another type that could have the
- // same key.
- stream << static_cast<uint32_t>(PersistentKeyType::Shader);
-
- // Provide "guard" strings that the user cannot provide to help ensure the generated HLSL
- // used to create this key is not being manufactured by the user to load the wrong shader
- // blob.
- // These strings can be HLSL comments because Tint does not emit HLSL comments.
- // TODO(dawn:549): Replace guards strings with something more secure.
- constexpr char kStartGuard[] = "// Start shader autogenerated by Dawn.";
- constexpr char kEndGuard[] = "// End shader autogenerated by Dawn.";
- ASSERT(hlslSource.find(kStartGuard) == std::string::npos);
- ASSERT(hlslSource.find(kEndGuard) == std::string::npos);
-
- stream << kStartGuard << "\n";
- stream << hlslSource;
- stream << "\n" << kEndGuard;
-
- stream << compileFlags;
-
- // Add the HLSL compiler version for good measure.
- // Prepend the compiler name to ensure the version is always unique.
- if (GetDevice()->IsToggleEnabled(Toggle::UseDXC)) {
- uint64_t dxCompilerVersion;
- DAWN_TRY_ASSIGN(dxCompilerVersion, GetDXCompilerVersion());
- stream << "DXC" << dxCompilerVersion;
- } else {
- stream << "FXC" << GetD3DCompilerVersion();
- }
-
- // If the source contains multiple entry points, ensure they are cached seperately
- // per stage since DX shader code can only be compiled per stage using the same
- // entry point.
- stream << static_cast<uint32_t>(stage);
- stream << entryPointName;
-
- return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
- std::istreambuf_iterator<char>{});
- }
-
- ResultOrError<uint64_t> ShaderModule::GetDXCompilerVersion() const {
- ComPtr<IDxcValidator> dxcValidator = ToBackend(GetDevice())->GetDxcValidator();
-
- ComPtr<IDxcVersionInfo> versionInfo;
- DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo),
- "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));
-
- uint32_t compilerMajor, compilerMinor;
- DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor),
- "IDxcVersionInfo::GetVersion"));
-
- // Pack both into a single version number.
- return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor;
- }
-
- uint64_t ShaderModule::GetD3DCompilerVersion() const {
- return D3D_COMPILER_VERSION;
- }
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 436d7bd..880a35c 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -58,20 +58,6 @@
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
-
- ResultOrError<std::string> TranslateToHLSL(const tint::Program* program,
- const char* entryPointName,
- SingleShaderStage stage,
- PipelineLayout* layout,
- std::string* remappedEntryPointName) const;
-
- ResultOrError<PersistentCacheKey> CreateHLSLKey(const char* entryPointName,
- SingleShaderStage stage,
- const std::string& hlslSource,
- uint32_t compileFlags) const;
-
- ResultOrError<uint64_t> GetDXCompilerVersion() const;
- uint64_t GetD3DCompilerVersion() const;
};
}} // namespace dawn_native::d3d12