[d3d] Move shader translation and compile code to d3d
So it can be shared between d3d11 and d3d12 backends
Bug: dawn:1705
Change-Id: Iffabe8d77a0ac3713da985c0cac5839299dc2a47
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124883
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Peng Huang <penghuang@chromium.org>
diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn
index 1ca341d..ff21df0 100644
--- a/src/dawn/native/BUILD.gn
+++ b/src/dawn/native/BUILD.gn
@@ -409,6 +409,8 @@
"d3d/Forward.h",
"d3d/PlatformFunctions.cpp",
"d3d/PlatformFunctions.h",
+ "d3d/ShaderUtils.cpp",
+ "d3d/ShaderUtils.h",
"d3d/UtilsD3D.cpp",
"d3d/UtilsD3D.h",
"d3d/d3d_platform.h",
diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt
index d1625a1..3ad2292 100644
--- a/src/dawn/native/CMakeLists.txt
+++ b/src/dawn/native/CMakeLists.txt
@@ -267,6 +267,8 @@
"d3d/Forward.h"
"d3d/PlatformFunctions.cpp"
"d3d/PlatformFunctions.h"
+ "d3d/ShaderUtils.cpp"
+ "d3d/ShaderUtils.h"
"d3d/UtilsD3D.cpp"
"d3d/UtilsD3D.h"
"d3d/d3d_platform.h"
diff --git a/src/dawn/native/d3d/ShaderUtils.cpp b/src/dawn/native/d3d/ShaderUtils.cpp
new file mode 100644
index 0000000..ca39e92
--- /dev/null
+++ b/src/dawn/native/d3d/ShaderUtils.cpp
@@ -0,0 +1,280 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/native/d3d/ShaderUtils.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "dawn/native/d3d/BlobD3D.h"
+#include "dawn/native/d3d/D3DCompilationRequest.h"
+#include "dawn/native/d3d/D3DError.h"
+#include "dawn/native/d3d/UtilsD3D.h"
+#include "dawn/platform/DawnPlatform.h"
+#include "dawn/platform/tracing/TraceEvent.h"
+
+#include "tint/tint.h"
+
+namespace dawn::native::d3d {
+
+namespace {
+
+std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
+ std::vector<const wchar_t*> arguments;
+ if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
+ arguments.push_back(L"/Gec");
+ }
+ if (compileFlags & D3DCOMPILE_IEEE_STRICTNESS) {
+ arguments.push_back(L"/Gis");
+ }
+ constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2;
+ if (compileFlags & d3dCompileFlagsBits) {
+ switch (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) {
+ case D3DCOMPILE_OPTIMIZATION_LEVEL0:
+ arguments.push_back(L"/O0");
+ break;
+ case D3DCOMPILE_OPTIMIZATION_LEVEL2:
+ arguments.push_back(L"/O2");
+ break;
+ case D3DCOMPILE_OPTIMIZATION_LEVEL3:
+ arguments.push_back(L"/O3");
+ break;
+ }
+ }
+ if (compileFlags & D3DCOMPILE_DEBUG) {
+ arguments.push_back(L"/Zi");
+ }
+ if (compileFlags & D3DCOMPILE_PACK_MATRIX_ROW_MAJOR) {
+ arguments.push_back(L"/Zpr");
+ }
+ if (compileFlags & D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR) {
+ arguments.push_back(L"/Zpc");
+ }
+ if (compileFlags & D3DCOMPILE_AVOID_FLOW_CONTROL) {
+ arguments.push_back(L"/Gfa");
+ }
+ if (compileFlags & D3DCOMPILE_PREFER_FLOW_CONTROL) {
+ arguments.push_back(L"/Gfp");
+ }
+ if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) {
+ arguments.push_back(L"/res_may_alias");
+ }
+
+ if (enable16BitTypes) {
+ // enable-16bit-types are only allowed in -HV 2018 (default)
+ arguments.push_back(L"/enable-16bit-types");
+ }
+
+ arguments.push_back(L"-HV");
+ arguments.push_back(L"2018");
+
+ return arguments;
+}
+
+ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const d3d::D3DBytecodeCompilationRequest& r,
+ const std::string& entryPointName,
+ const std::string& hlslSource) {
+ ComPtr<IDxcBlobEncoding> sourceBlob;
+ DAWN_TRY(CheckHRESULT(r.dxcLibrary->CreateBlobWithEncodingFromPinned(
+ hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
+ "DXC create blob"));
+
+ std::wstring entryPointW;
+ DAWN_TRY_ASSIGN(entryPointW, d3d::ConvertStringToWstring(entryPointName));
+
+ std::vector<const wchar_t*> arguments = GetDXCArguments(r.compileFlags, r.hasShaderF16Feature);
+
+ ComPtr<IDxcOperationResult> result;
+ DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
+ r.dxcShaderProfile.data(), arguments.data(),
+ arguments.size(), nullptr, 0, nullptr, &result),
+ "DXC compile"));
+
+ HRESULT hr;
+ DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
+
+ if (FAILED(hr)) {
+ ComPtr<IDxcBlobEncoding> errors;
+ DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
+
+ return DAWN_VALIDATION_ERROR("DXC compile failed with: %s",
+ static_cast<char*>(errors->GetBufferPointer()));
+ }
+
+ ComPtr<IDxcBlob> compiledShader;
+ DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
+ return std::move(compiledShader);
+}
+
+ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const d3d::D3DBytecodeCompilationRequest& r,
+ const std::string& entryPointName,
+ const std::string& hlslSource) {
+ ComPtr<ID3DBlob> compiledShader;
+ ComPtr<ID3DBlob> errors;
+
+ DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
+ nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
+ r.compileFlags, 0, &compiledShader, &errors)),
+ "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
+
+ return std::move(compiledShader);
+}
+
+ResultOrError<std::string> TranslateToHLSL(
+ d3d::HlslCompilationRequest r,
+ CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*> tracePlatform,
+ std::string* remappedEntryPointName,
+ bool* usesVertexOrInstanceIndex) {
+ std::ostringstream errorStream;
+ errorStream << "Tint HLSL failure:" << std::endl;
+
+ tint::transform::Manager transformManager;
+ tint::transform::DataMap transformInputs;
+
+ // Run before the renamer so that the entry point name matches `entryPointName` still.
+ transformManager.Add<tint::transform::SingleEntryPoint>();
+ transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
+
+ // Needs to run before all other transforms so that they can use builtin names safely.
+ transformManager.Add<tint::transform::Renamer>();
+ if (r.disableSymbolRenaming) {
+ // We still need to rename HLSL reserved keywords
+ transformInputs.Add<tint::transform::Renamer::Config>(
+ tint::transform::Renamer::Target::kHlslKeywords);
+ }
+
+ if (r.stage == SingleShaderStage::Vertex) {
+ transformManager.Add<tint::transform::FirstIndexOffset>();
+ transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
+ r.firstIndexOffsetShaderRegister, r.firstIndexOffsetRegisterSpace);
+ }
+
+ if (r.substituteOverrideConfig) {
+ // This needs to run after SingleEntryPoint transform which removes unused overrides for
+ // current entry point.
+ transformManager.Add<tint::transform::SubstituteOverride>();
+ transformInputs.Add<tint::transform::SubstituteOverride::Config>(
+ std::move(r.substituteOverrideConfig).value());
+ }
+
+ tint::Program transformedProgram;
+ tint::transform::DataMap transformOutputs;
+ {
+ TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms");
+ DAWN_TRY_ASSIGN(transformedProgram,
+ RunTransforms(&transformManager, r.inputProgram, transformInputs,
+ &transformOutputs, nullptr));
+ }
+
+ if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
+ auto it = data->remappings.find(r.entryPointName.data());
+ if (it != data->remappings.end()) {
+ *remappedEntryPointName = it->second;
+ } else {
+ DAWN_INVALID_IF(!r.disableSymbolRenaming,
+ "Could not find remapped name for entry point.");
+
+ *remappedEntryPointName = r.entryPointName;
+ }
+ } else {
+ return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
+ }
+
+ if (r.stage == SingleShaderStage::Compute) {
+ // Validate workgroup size after program runs transforms.
+ Extent3D _;
+ DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
+ transformedProgram, remappedEntryPointName->data(), r.limits));
+ }
+
+ if (r.stage == SingleShaderStage::Vertex) {
+ if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
+ *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
+ } else {
+ return DAWN_VALIDATION_ERROR("Transform output missing first index offset data.");
+ }
+ }
+
+ tint::writer::hlsl::Options options;
+ options.disable_robustness = !r.isRobustnessEnabled;
+ options.disable_workgroup_init = r.disableWorkgroupInit;
+ options.binding_remapper_options = r.bindingRemapper;
+ options.external_texture_options = r.externalTextureOptions;
+
+ if (r.usesNumWorkgroups) {
+ options.root_constant_binding_point =
+ tint::writer::BindingPoint{r.numWorkgroupsRegisterSpace, r.numWorkgroupsShaderRegister};
+ }
+ // TODO(dawn:549): HLSL generation outputs the indices into the
+ // array_length_from_uniform buffer that were actually used. When the blob cache can
+ // store more than compiled shaders, we should reflect these used indices and store
+ // them as well. This would allow us to only upload root constants that are actually
+ // read by the shader.
+ options.array_length_from_uniform = r.arrayLengthFromUniform;
+
+ if (r.stage == SingleShaderStage::Vertex) {
+ // Now that only vertex shader can have interstage outputs.
+ // Pass in the actually used interstage locations for tint to potentially truncate unused
+ // outputs.
+ options.interstage_locations = r.interstageLocations;
+ }
+
+ options.polyfill_reflect_vec2_f32 = r.polyfillReflectVec2F32;
+
+ TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "tint::writer::hlsl::Generate");
+ auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
+ DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error);
+
+ return std::move(result.hlsl);
+}
+
+} // anonymous namespace
+
+ResultOrError<CompiledShader> CompileShader(d3d::D3DCompilationRequest r) {
+ CompiledShader compiledShader;
+ // Compile the source shader to HLSL.
+ std::string remappedEntryPoint;
+ DAWN_TRY_ASSIGN(compiledShader.hlslSource,
+ TranslateToHLSL(std::move(r.hlsl), r.tracePlatform, &remappedEntryPoint,
+ &compiledShader.usesVertexOrInstanceIndex));
+
+ switch (r.bytecode.compiler) {
+ case d3d::Compiler::DXC: {
+ TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderDXC");
+ ComPtr<IDxcBlob> compiledDXCShader;
+ DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(r.bytecode, remappedEntryPoint,
+ compiledShader.hlslSource));
+ compiledShader.shaderBlob = CreateBlob(std::move(compiledDXCShader));
+ break;
+ }
+ case d3d::Compiler::FXC: {
+ TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderFXC");
+ ComPtr<ID3DBlob> compiledFXCShader;
+ DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(r.bytecode, remappedEntryPoint,
+ compiledShader.hlslSource));
+ compiledShader.shaderBlob = CreateBlob(std::move(compiledFXCShader));
+ break;
+ }
+ }
+
+ // If dumpShaders is false, we don't need the HLSL for logging. Clear the contents so it
+ // isn't stored into the cache.
+ if (!r.hlsl.dumpShaders) {
+ compiledShader.hlslSource = "";
+ }
+ return compiledShader;
+}
+
+} // namespace dawn::native::d3d
diff --git a/src/dawn/native/d3d/ShaderUtils.h b/src/dawn/native/d3d/ShaderUtils.h
new file mode 100644
index 0000000..bbfc2b1
--- /dev/null
+++ b/src/dawn/native/d3d/ShaderUtils.h
@@ -0,0 +1,43 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_DAWN_NATIVE_D3D_SHADEUTILS_H_
+#define SRC_DAWN_NATIVE_D3D_SHADEUTILS_H_
+
+#include <string>
+
+#include "dawn/native/Blob.h"
+#include "dawn/native/Serializable.h"
+#include "dawn/native/d3d/D3DCompilationRequest.h"
+#include "dawn/native/d3d/d3d_platform.h"
+
+namespace dawn::native::d3d {
+
+#define COMPILED_SHADER_MEMBERS(X) \
+ X(Blob, shaderBlob) \
+ X(std::string, hlslSource) \
+ X(bool, usesVertexOrInstanceIndex)
+
+// `CompiledShader` holds a ref to one of the various representations of shader blobs and
+// information used to emulate vertex/instance index starts. It also holds the `hlslSource` for the
+// shader compilation, which is only transiently available during Compile, and cleared before it
+// returns. It is not written to or loaded from the cache unless Toggle dump_shaders is true.
+DAWN_SERIALIZABLE(struct, CompiledShader, COMPILED_SHADER_MEMBERS){};
+#undef COMPILED_SHADER_MEMBERS
+
+ResultOrError<CompiledShader> CompileShader(d3d::D3DCompilationRequest r);
+
+} // namespace dawn::native::d3d
+
+#endif // SRC_DAWN_NATIVE_D3D_SHADEUTILS_H_
diff --git a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
index e781f0f..8585940 100644
--- a/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/ComputePipelineD3D12.cpp
@@ -57,10 +57,10 @@
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
// TODO(dawn:549): Compile shader everytime before we implement compiled shader cache
- CompiledShader compiledShader;
+ d3d::CompiledShader compiledShader;
DAWN_TRY_ASSIGN(compiledShader, module->Compile(computeStage, SingleShaderStage::Compute,
ToBackend(GetLayout()), compileFlags));
- d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
+ d3dDesc.CS = {compiledShader.shaderBlob.Data(), compiledShader.shaderBlob.Size()};
StreamIn(&mCacheKey, d3dDesc, ToBackend(GetLayout())->GetRootSignatureBlob());
diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
index 6cfa55d..e6988e4 100644
--- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp
@@ -380,7 +380,7 @@
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
- PerStage<CompiledShader> compiledShader;
+ PerStage<d3d::CompiledShader> compiledShader;
std::bitset<kMaxInterStageShaderVariables>* usedInterstageVariables = nullptr;
dawn::native::EntryPointMetadata fragmentEntryPoint;
@@ -397,7 +397,8 @@
ToBackend(programmableStage.module)
->Compile(programmableStage, stage, ToBackend(GetLayout()),
compileFlags, usedInterstageVariables));
- *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
+ *shaders[stage] = {compiledShader[stage].shaderBlob.Data(),
+ compiledShader[stage].shaderBlob.Size()};
}
mUsesVertexOrInstanceIndex =
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
index 745b27a..0c3d819 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp
@@ -14,22 +14,14 @@
#include "dawn/native/d3d12/ShaderModuleD3D12.h"
-#include <map>
-#include <sstream>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <utility>
-#include <vector>
#include "dawn/common/Assert.h"
#include "dawn/common/BitSetIterator.h"
#include "dawn/common/Log.h"
-#include "dawn/common/WindowsUtils.h"
-#include "dawn/native/CacheKey.h"
#include "dawn/native/Pipeline.h"
#include "dawn/native/TintUtils.h"
-#include "dawn/native/d3d/BlobD3D.h"
#include "dawn/native/d3d/D3DCompilationRequest.h"
#include "dawn/native/d3d/D3DError.h"
#include "dawn/native/d3d12/AdapterD3D12.h"
@@ -39,8 +31,6 @@
#include "dawn/native/d3d12/PipelineLayoutD3D12.h"
#include "dawn/native/d3d12/PlatformFunctionsD3D12.h"
#include "dawn/native/d3d12/UtilsD3D12.h"
-#include "dawn/native/stream/BlobSource.h"
-#include "dawn/native/stream/ByteVectorSink.h"
#include "dawn/platform/DawnPlatform.h"
#include "dawn/platform/tracing/TraceEvent.h"
@@ -49,94 +39,6 @@
namespace dawn::native::d3d12 {
namespace {
-
-std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
- std::vector<const wchar_t*> arguments;
- if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
- arguments.push_back(L"/Gec");
- }
- if (compileFlags & D3DCOMPILE_IEEE_STRICTNESS) {
- arguments.push_back(L"/Gis");
- }
- constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2;
- if (compileFlags & d3dCompileFlagsBits) {
- switch (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) {
- case D3DCOMPILE_OPTIMIZATION_LEVEL0:
- arguments.push_back(L"/O0");
- break;
- case D3DCOMPILE_OPTIMIZATION_LEVEL2:
- arguments.push_back(L"/O2");
- break;
- case D3DCOMPILE_OPTIMIZATION_LEVEL3:
- arguments.push_back(L"/O3");
- break;
- }
- }
- if (compileFlags & D3DCOMPILE_DEBUG) {
- arguments.push_back(L"/Zi");
- }
- if (compileFlags & D3DCOMPILE_PACK_MATRIX_ROW_MAJOR) {
- arguments.push_back(L"/Zpr");
- }
- if (compileFlags & D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR) {
- arguments.push_back(L"/Zpc");
- }
- if (compileFlags & D3DCOMPILE_AVOID_FLOW_CONTROL) {
- arguments.push_back(L"/Gfa");
- }
- if (compileFlags & D3DCOMPILE_PREFER_FLOW_CONTROL) {
- arguments.push_back(L"/Gfp");
- }
- if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) {
- arguments.push_back(L"/res_may_alias");
- }
-
- if (enable16BitTypes) {
- // enable-16bit-types are only allowed in -HV 2018 (default)
- arguments.push_back(L"/enable-16bit-types");
- }
-
- arguments.push_back(L"-HV");
- arguments.push_back(L"2018");
-
- return arguments;
-}
-
-ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(const d3d::D3DBytecodeCompilationRequest& r,
- const std::string& entryPointName,
- const std::string& hlslSource) {
- ComPtr<IDxcBlobEncoding> sourceBlob;
- DAWN_TRY(CheckHRESULT(r.dxcLibrary->CreateBlobWithEncodingFromPinned(
- hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
- "DXC create blob"));
-
- std::wstring entryPointW;
- DAWN_TRY_ASSIGN(entryPointW, d3d::ConvertStringToWstring(entryPointName));
-
- std::vector<const wchar_t*> arguments = GetDXCArguments(r.compileFlags, r.hasShaderF16Feature);
-
- ComPtr<IDxcOperationResult> result;
- DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
- r.dxcShaderProfile.data(), arguments.data(),
- arguments.size(), nullptr, 0, nullptr, &result),
- "DXC compile"));
-
- HRESULT hr;
- DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
-
- if (FAILED(hr)) {
- ComPtr<IDxcBlobEncoding> errors;
- DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
-
- return DAWN_VALIDATION_ERROR("DXC compile failed with: %s",
- static_cast<char*>(errors->GetBufferPointer()));
- }
-
- ComPtr<IDxcBlob> compiledShader;
- DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
- return std::move(compiledShader);
-}
-
std::string CompileFlagsToStringFXC(uint32_t compileFlags) {
struct Flag {
uint32_t value;
@@ -199,163 +101,6 @@
return result;
}
-ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const d3d::D3DBytecodeCompilationRequest& r,
- const std::string& entryPointName,
- const std::string& hlslSource) {
- ComPtr<ID3DBlob> compiledShader;
- ComPtr<ID3DBlob> errors;
-
- DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
- nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(),
- r.compileFlags, 0, &compiledShader, &errors)),
- "D3D compile failed with: %s", static_cast<char*>(errors->GetBufferPointer()));
-
- return std::move(compiledShader);
-}
-
-ResultOrError<std::string> TranslateToHLSL(
- d3d::HlslCompilationRequest r,
- CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*> tracePlatform,
- std::string* remappedEntryPointName,
- bool* usesVertexOrInstanceIndex) {
- std::ostringstream errorStream;
- errorStream << "Tint HLSL failure:" << std::endl;
-
- tint::transform::Manager transformManager;
- tint::transform::DataMap transformInputs;
-
- // Run before the renamer so that the entry point name matches `entryPointName` still.
- transformManager.Add<tint::transform::SingleEntryPoint>();
- transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
-
- // Needs to run before all other transforms so that they can use builtin names safely.
- transformManager.Add<tint::transform::Renamer>();
- if (r.disableSymbolRenaming) {
- // We still need to rename HLSL reserved keywords
- transformInputs.Add<tint::transform::Renamer::Config>(
- tint::transform::Renamer::Target::kHlslKeywords);
- }
-
- if (r.stage == SingleShaderStage::Vertex) {
- transformManager.Add<tint::transform::FirstIndexOffset>();
- transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
- r.firstIndexOffsetShaderRegister, r.firstIndexOffsetRegisterSpace);
- }
-
- if (r.substituteOverrideConfig) {
- // This needs to run after SingleEntryPoint transform which removes unused overrides for
- // current entry point.
- transformManager.Add<tint::transform::SubstituteOverride>();
- transformInputs.Add<tint::transform::SubstituteOverride::Config>(
- std::move(r.substituteOverrideConfig).value());
- }
-
- tint::Program transformedProgram;
- tint::transform::DataMap transformOutputs;
- {
- TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms");
- DAWN_TRY_ASSIGN(transformedProgram,
- RunTransforms(&transformManager, r.inputProgram, transformInputs,
- &transformOutputs, nullptr));
- }
-
- if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
- auto it = data->remappings.find(r.entryPointName.data());
- if (it != data->remappings.end()) {
- *remappedEntryPointName = it->second;
- } else {
- DAWN_INVALID_IF(!r.disableSymbolRenaming,
- "Could not find remapped name for entry point.");
-
- *remappedEntryPointName = r.entryPointName;
- }
- } else {
- return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
- }
-
- if (r.stage == SingleShaderStage::Compute) {
- // Validate workgroup size after program runs transforms.
- Extent3D _;
- DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
- transformedProgram, remappedEntryPointName->data(), r.limits));
- }
-
- if (r.stage == SingleShaderStage::Vertex) {
- if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
- *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index;
- } else {
- return DAWN_VALIDATION_ERROR("Transform output missing first index offset data.");
- }
- }
-
- tint::writer::hlsl::Options options;
- options.disable_robustness = !r.isRobustnessEnabled;
- options.disable_workgroup_init = r.disableWorkgroupInit;
- options.binding_remapper_options = r.bindingRemapper;
- options.external_texture_options = r.externalTextureOptions;
-
- if (r.usesNumWorkgroups) {
- options.root_constant_binding_point =
- tint::writer::BindingPoint{r.numWorkgroupsRegisterSpace, r.numWorkgroupsShaderRegister};
- }
- // TODO(dawn:549): HLSL generation outputs the indices into the
- // array_length_from_uniform buffer that were actually used. When the blob cache can
- // store more than compiled shaders, we should reflect these used indices and store
- // them as well. This would allow us to only upload root constants that are actually
- // read by the shader.
- options.array_length_from_uniform = r.arrayLengthFromUniform;
-
- if (r.stage == SingleShaderStage::Vertex) {
- // Now that only vertex shader can have interstage outputs.
- // Pass in the actually used interstage locations for tint to potentially truncate unused
- // outputs.
- options.interstage_locations = r.interstageLocations;
- }
-
- options.polyfill_reflect_vec2_f32 = r.polyfillReflectVec2F32;
-
- TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "tint::writer::hlsl::Generate");
- auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
- DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error);
-
- return std::move(result.hlsl);
-}
-
-ResultOrError<CompiledShader> CompileShader(d3d::D3DCompilationRequest r) {
- CompiledShader compiledShader;
- // Compile the source shader to HLSL.
- std::string remappedEntryPoint;
- DAWN_TRY_ASSIGN(compiledShader.hlslSource,
- TranslateToHLSL(std::move(r.hlsl), r.tracePlatform, &remappedEntryPoint,
- &compiledShader.usesVertexOrInstanceIndex));
-
- switch (r.bytecode.compiler) {
- case d3d::Compiler::DXC: {
- TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderDXC");
- ComPtr<IDxcBlob> compiledDXCShader;
- DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(r.bytecode, remappedEntryPoint,
- compiledShader.hlslSource));
- compiledShader.shaderBlob = CreateBlob(std::move(compiledDXCShader));
- break;
- }
- case d3d::Compiler::FXC: {
- TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderFXC");
- ComPtr<ID3DBlob> compiledFXCShader;
- DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(r.bytecode, remappedEntryPoint,
- compiledShader.hlslSource));
- compiledShader.shaderBlob = CreateBlob(std::move(compiledFXCShader));
- break;
- }
- }
-
- // If dumpShaders is false, we don't need the HLSL for logging. Clear the contents so it
- // isn't stored into the cache.
- if (!r.hlsl.dumpShaders) {
- compiledShader.hlslSource = "";
- }
- return compiledShader;
-}
-
} // anonymous namespace
// static
@@ -378,7 +123,7 @@
return InitializeBase(parseResult, compilationMessages);
}
-ResultOrError<CompiledShader> ShaderModule::Compile(
+ResultOrError<d3d::CompiledShader> ShaderModule::Compile(
const ProgrammableStage& programmableStage,
SingleShaderStage stage,
const PipelineLayout* layout,
@@ -526,9 +271,9 @@
const CombinedLimits& limits = device->GetLimits();
req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1);
- CacheResult<CompiledShader> compiledShader;
- DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob,
- CompileShader);
+ CacheResult<d3d::CompiledShader> compiledShader;
+ DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), d3d::CompiledShader::FromBlob,
+ d3d::CompileShader);
if (device->IsToggleEnabled(Toggle::DumpShaders)) {
std::ostringstream dumpedMsg;
@@ -537,11 +282,11 @@
if (device->IsToggleEnabled(Toggle::UseDXC)) {
dumpedMsg << "/* Dumped disassembled DXIL */" << std::endl;
- D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
+ const Blob& shaderBlob = compiledShader->shaderBlob;
ComPtr<IDxcBlobEncoding> dxcBlob;
ComPtr<IDxcBlobEncoding> disassembly;
if (FAILED(device->GetDxcLibrary()->CreateBlobWithEncodingFromPinned(
- code.pShaderBytecode, code.BytecodeLength, 0, &dxcBlob)) ||
+ shaderBlob.Data(), shaderBlob.Size(), 0, &dxcBlob)) ||
FAILED(device->GetDxcCompiler()->Disassemble(dxcBlob.Get(), &disassembly))) {
dumpedMsg << "DXC disassemble failed" << std::endl;
} else {
@@ -554,13 +299,13 @@
<< CompileFlagsToStringFXC(compileFlags) << std::endl;
dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
ComPtr<ID3DBlob> disassembly;
- D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode();
+ const Blob& shaderBlob = compiledShader->shaderBlob;
UINT flags =
// Some literals are printed as floats with precision(6) which is not enough
// precision for values very close to 0, so always print literals as hex values.
D3D_DISASM_PRINT_HEX_LITERALS;
- if (FAILED(device->GetFunctions()->d3dDisassemble(
- code.pShaderBytecode, code.BytecodeLength, flags, nullptr, &disassembly))) {
+ if (FAILED(device->GetFunctions()->d3dDisassemble(shaderBlob.Data(), shaderBlob.Size(),
+ flags, nullptr, &disassembly))) {
dumpedMsg << "D3D disassemble failed" << std::endl;
} else {
dumpedMsg << std::string_view(
@@ -575,13 +320,9 @@
// Clear the hlslSource. It is only used for logging and should not be used
// outside of the compilation.
- CompiledShader result = compiledShader.Acquire();
+ d3d::CompiledShader result = compiledShader.Acquire();
result.hlslSource = "";
return result;
}
-D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
- return {shaderBlob.Data(), shaderBlob.Size()};
-}
-
} // namespace dawn::native::d3d12
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h
index f646117..9cf1c0d 100644
--- a/src/dawn/native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h
@@ -20,6 +20,7 @@
#include "dawn/native/Blob.h"
#include "dawn/native/Serializable.h"
#include "dawn/native/ShaderModule.h"
+#include "dawn/native/d3d/ShaderUtils.h"
#include "dawn/native/d3d12/d3d12_platform.h"
namespace dawn::native {
@@ -31,20 +32,6 @@
class Device;
class PipelineLayout;
-#define COMPILED_SHADER_MEMBERS(X) \
- X(Blob, shaderBlob) \
- X(std::string, hlslSource) \
- X(bool, usesVertexOrInstanceIndex)
-
-// `CompiledShader` holds a ref to one of the various representations of shader blobs and
-// information used to emulate vertex/instance index starts. It also holds the `hlslSource` for the
-// shader compilation, which is only transiently available during Compile, and cleared before it
-// returns. It is not written to or loaded from the cache unless Toggle dump_shaders is true.
-DAWN_SERIALIZABLE(struct, CompiledShader, COMPILED_SHADER_MEMBERS) {
- D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
-};
-#undef COMPILED_SHADER_MEMBERS
-
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
@@ -52,7 +39,7 @@
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
- ResultOrError<CompiledShader> Compile(
+ ResultOrError<d3d::CompiledShader> Compile(
const ProgrammableStage& programmableStage,
SingleShaderStage stage,
const PipelineLayout* layout,