Refactor D3D12 shader define strings code
Instead of using template functions, generate define string pairs
in std::string. For DXC path, convert them to std::wstring. It is
okay since shader generation is already expensive.
By the way generalize it to all kinds of defines in
ShaderCompilationRequest.
Bug: dawn:1137
Change-Id: I5518e992b56497e28c8ac7e818bf19b4853dee4a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/70120
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index ec5b3b3..9fb8b19 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -34,94 +34,6 @@
#include <sstream>
#include <unordered_map>
-namespace dawn_native {
- template <typename StringType, typename T = int32_t>
- struct NumberToString {
- static StringType ToStringAsValue(T v);
- static StringType ToStringAsId(T v);
- };
-
- template <typename T>
- struct NumberToString<std::string, T> {
- static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
- static std::string ToStringAsValue(T v) {
- return std::to_string(v);
- }
- static std::string ToStringAsId(T v) {
- return std::to_string(v);
- }
- };
-
- template <typename T>
- struct NumberToString<std::wstring, T> {
- static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
- static std::wstring ToStringAsValue(T v) {
- return std::to_wstring(v);
- }
- static std::wstring ToStringAsId(T v) {
- return std::to_wstring(v);
- }
- };
-
- template <>
- struct NumberToString<std::string, float> {
- static std::string ToStringAsValue(float v) {
- std::ostringstream out;
- // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
- out.precision(8);
- out << std::fixed << v;
- return out.str();
- }
- };
-
- template <>
- struct NumberToString<std::wstring, float> {
- static std::wstring ToStringAsValue(float v) {
- std::basic_ostringstream<WCHAR> out;
- // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
- out.precision(8);
- out << std::fixed << v;
- return out.str();
- }
- };
-
- template <>
- struct NumberToString<std::string, uint32_t> {
- static std::string ToStringAsValue(uint32_t v) {
- return std::to_string(v) + "u";
- }
- };
-
- template <>
- struct NumberToString<std::wstring, uint32_t> {
- static std::wstring ToStringAsValue(uint32_t v) {
- return std::to_wstring(v) + L"u";
- }
- };
-
- template <typename StringType>
- StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
- const OverridableConstantScalar* entry,
- double value = 0) {
- switch (dawnType) {
- case EntryPointMetadata::OverridableConstant::Type::Boolean:
- return NumberToString<StringType, int32_t>::ToStringAsValue(
- entry ? entry->b : static_cast<int32_t>(value));
- case EntryPointMetadata::OverridableConstant::Type::Float32:
- return NumberToString<StringType, float>::ToStringAsValue(
- entry ? entry->f32 : static_cast<float>(value));
- case EntryPointMetadata::OverridableConstant::Type::Int32:
- return NumberToString<StringType, int32_t>::ToStringAsValue(
- entry ? entry->i32 : static_cast<int32_t>(value));
- case EntryPointMetadata::OverridableConstant::Type::Uint32:
- return NumberToString<StringType, uint32_t>::ToStringAsValue(
- entry ? entry->u32 : static_cast<uint32_t>(value));
- default:
- UNREACHABLE();
- }
- }
-} // namespace dawn_native
-
namespace dawn_native { namespace d3d12 {
namespace {
@@ -181,6 +93,73 @@
output << ")";
}
+ // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
+ std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
+ std::ostringstream out;
+ out.precision(n);
+ out << std::fixed << v;
+ return out.str();
+ }
+
+ std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
+ const OverridableConstantScalar* entry,
+ double value = 0) {
+ switch (dawnType) {
+ case EntryPointMetadata::OverridableConstant::Type::Boolean:
+ return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
+ case EntryPointMetadata::OverridableConstant::Type::Float32:
+ return FloatToStringWithPrecision(entry ? entry->f32
+ : static_cast<float>(value));
+ case EntryPointMetadata::OverridableConstant::Type::Int32:
+ return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
+ case EntryPointMetadata::OverridableConstant::Type::Uint32:
+ return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
+ default:
+ UNREACHABLE();
+ }
+ }
+
+ constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
+
+ void GetOverridableConstantsDefines(
+ std::vector<std::pair<std::string, std::string>>* defineStrings,
+ const PipelineConstantEntries* pipelineConstantEntries,
+ const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
+ std::unordered_set<std::string> overriddenConstants;
+
+ // Set pipeline overridden values
+ for (const auto& pipelineConstant : *pipelineConstantEntries) {
+ const std::string& name = pipelineConstant.first;
+ double value = pipelineConstant.second;
+
+ overriddenConstants.insert(name);
+
+ // This is already validated so `name` must exist
+ const auto& moduleConstant = shaderEntryPointConstants->at(name);
+
+ defineStrings->emplace_back(
+ kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
+ GetHLSLValueString(moduleConstant.type, nullptr, value));
+ }
+
+ // Set shader initialized default values
+ for (const auto& iter : *shaderEntryPointConstants) {
+ const std::string& name = iter.first;
+ if (overriddenConstants.count(name) != 0) {
+ // This constant already has overridden value
+ continue;
+ }
+
+ const auto& moduleConstant = shaderEntryPointConstants->at(name);
+
+ // Uninitialized default values are okay since they ar only defined to pass
+ // compilation but not used
+ defineStrings->emplace_back(
+ kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
+ GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
+ }
+ }
+
// The inputs to a shader compilation. These have been intentionally isolated from the
// device to help ensure that the pipeline cache key contains all inputs for compilation.
struct ShaderCompilationRequest {
@@ -199,8 +178,7 @@
bool usesNumWorkgroups;
uint32_t numWorkgroupsRegisterSpace;
uint32_t numWorkgroupsShaderRegister;
- const PipelineConstantEntries* pipelineConstantEntries;
- const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
+ std::vector<std::pair<std::string, std::string>> defineStrings;
// FXC/DXC common inputs
bool disableWorkgroupInit;
@@ -293,10 +271,12 @@
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo();
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
- request.pipelineConstantEntries = &programmableStage.constants;
- request.shaderEntryPointConstants =
+
+ GetOverridableConstantsDefines(
+ &request.defineStrings, &programmableStage.constants,
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
- .overridableConstants;
+ .overridableConstants);
+
return std::move(request);
}
@@ -349,17 +329,9 @@
stream << " dxcVersion=" << dxcVersion;
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
- stream << " overridableConstants={";
- for (const auto& pipelineConstant : *pipelineConstantEntries) {
- const std::string& name = pipelineConstant.first;
- double value = pipelineConstant.second;
-
- // This is already validated so `name` must exist
- const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
- stream << " <" << name << ","
- << GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
- << ">";
+ stream << " defines={";
+ for (const auto& it : defineStrings) {
+ stream << " <" << it.first << "," << it.second << ">";
}
stream << " }";
@@ -422,53 +394,6 @@
return arguments;
}
- template <typename StringType>
- const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
- const PipelineConstantEntries* pipelineConstantEntries,
- const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
- std::vector<std::pair<StringType, StringType>> defineStrings;
-
- std::unordered_set<std::string> overriddenConstants;
-
- // Set pipeline overridden values
- for (const auto& pipelineConstant : *pipelineConstantEntries) {
- const std::string& name = pipelineConstant.first;
- double value = pipelineConstant.second;
-
- overriddenConstants.insert(name);
-
- // This is already validated so `name` must exist
- const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
- defineStrings.emplace_back(
- NumberToString<StringType>::kSpecConstantPrefix +
- NumberToString<StringType>::ToStringAsId(
- static_cast<int32_t>(moduleConstant.id)),
- GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
- }
-
- // Set shader initialized default values
- for (const auto& iter : *shaderEntryPointConstants) {
- const std::string& name = iter.first;
- if (overriddenConstants.count(name) != 0) {
- // This constant already has overridden value
- continue;
- }
-
- const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
- // Uninitialized default values are okay since they are only defined to pass
- // compilation but not used
- defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
- NumberToString<StringType>::ToStringAsId(
- static_cast<int32_t>(moduleConstant.id)),
- GetHLSLValueString<StringType>(
- moduleConstant.type, &moduleConstant.defaultValue));
- }
-
- return defineStrings;
- }
-
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
IDxcCompiler* dxcCompiler,
const ShaderCompilationRequest& request,
@@ -486,8 +411,13 @@
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
// Build defines for overridable constants
- const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
- request.pipelineConstantEntries, request.shaderEntryPointConstants);
+ std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
+ defineStrings.reserve(request.defineStrings.size());
+ for (const auto& it : request.defineStrings) {
+ defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()),
+ UTF8ToWStr(it.second.c_str()));
+ }
+
std::vector<DxcDefine> dxcDefines;
dxcDefines.reserve(defineStrings.size());
for (const auto& d : defineStrings) {
@@ -538,14 +468,11 @@
ComPtr<ID3DBlob> errors;
// Build defines for overridable constants
- const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
- request.pipelineConstantEntries, request.shaderEntryPointConstants);
-
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
- if (defineStrings.size() > 0) {
- fxcDefines.reserve(defineStrings.size() + 1);
- for (const auto& d : defineStrings) {
+ if (request.defineStrings.size() > 0) {
+ fxcDefines.reserve(request.defineStrings.size() + 1);
+ for (const auto& d : request.defineStrings) {
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array