Refactor D3D12 shader define strings code

Instead of using template functions, generate define string pairs
in std::string. For DXC path, convert them to std::wstring. It is
okay since shader generation is already expensive.

By the way generalize it to all kinds of defines in
ShaderCompilationRequest.

Bug: dawn:1137
Change-Id: I5518e992b56497e28c8ac7e818bf19b4853dee4a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/70120
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index ec5b3b3..9fb8b19 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -34,94 +34,6 @@
 #include <sstream>
 #include <unordered_map>
 
-namespace dawn_native {
-    template <typename StringType, typename T = int32_t>
-    struct NumberToString {
-        static StringType ToStringAsValue(T v);
-        static StringType ToStringAsId(T v);
-    };
-
-    template <typename T>
-    struct NumberToString<std::string, T> {
-        static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
-        static std::string ToStringAsValue(T v) {
-            return std::to_string(v);
-        }
-        static std::string ToStringAsId(T v) {
-            return std::to_string(v);
-        }
-    };
-
-    template <typename T>
-    struct NumberToString<std::wstring, T> {
-        static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
-        static std::wstring ToStringAsValue(T v) {
-            return std::to_wstring(v);
-        }
-        static std::wstring ToStringAsId(T v) {
-            return std::to_wstring(v);
-        }
-    };
-
-    template <>
-    struct NumberToString<std::string, float> {
-        static std::string ToStringAsValue(float v) {
-            std::ostringstream out;
-            // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
-            out.precision(8);
-            out << std::fixed << v;
-            return out.str();
-        }
-    };
-
-    template <>
-    struct NumberToString<std::wstring, float> {
-        static std::wstring ToStringAsValue(float v) {
-            std::basic_ostringstream<WCHAR> out;
-            // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
-            out.precision(8);
-            out << std::fixed << v;
-            return out.str();
-        }
-    };
-
-    template <>
-    struct NumberToString<std::string, uint32_t> {
-        static std::string ToStringAsValue(uint32_t v) {
-            return std::to_string(v) + "u";
-        }
-    };
-
-    template <>
-    struct NumberToString<std::wstring, uint32_t> {
-        static std::wstring ToStringAsValue(uint32_t v) {
-            return std::to_wstring(v) + L"u";
-        }
-    };
-
-    template <typename StringType>
-    StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
-                                  const OverridableConstantScalar* entry,
-                                  double value = 0) {
-        switch (dawnType) {
-            case EntryPointMetadata::OverridableConstant::Type::Boolean:
-                return NumberToString<StringType, int32_t>::ToStringAsValue(
-                    entry ? entry->b : static_cast<int32_t>(value));
-            case EntryPointMetadata::OverridableConstant::Type::Float32:
-                return NumberToString<StringType, float>::ToStringAsValue(
-                    entry ? entry->f32 : static_cast<float>(value));
-            case EntryPointMetadata::OverridableConstant::Type::Int32:
-                return NumberToString<StringType, int32_t>::ToStringAsValue(
-                    entry ? entry->i32 : static_cast<int32_t>(value));
-            case EntryPointMetadata::OverridableConstant::Type::Uint32:
-                return NumberToString<StringType, uint32_t>::ToStringAsValue(
-                    entry ? entry->u32 : static_cast<uint32_t>(value));
-            default:
-                UNREACHABLE();
-        }
-    }
-}  // namespace dawn_native
-
 namespace dawn_native { namespace d3d12 {
 
     namespace {
@@ -181,6 +93,73 @@
             output << ")";
         }
 
+        // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
+        std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
+            std::ostringstream out;
+            out.precision(n);
+            out << std::fixed << v;
+            return out.str();
+        }
+
+        std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
+                                       const OverridableConstantScalar* entry,
+                                       double value = 0) {
+            switch (dawnType) {
+                case EntryPointMetadata::OverridableConstant::Type::Boolean:
+                    return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
+                case EntryPointMetadata::OverridableConstant::Type::Float32:
+                    return FloatToStringWithPrecision(entry ? entry->f32
+                                                            : static_cast<float>(value));
+                case EntryPointMetadata::OverridableConstant::Type::Int32:
+                    return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
+                case EntryPointMetadata::OverridableConstant::Type::Uint32:
+                    return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
+                default:
+                    UNREACHABLE();
+            }
+        }
+
+        constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
+
+        void GetOverridableConstantsDefines(
+            std::vector<std::pair<std::string, std::string>>* defineStrings,
+            const PipelineConstantEntries* pipelineConstantEntries,
+            const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
+            std::unordered_set<std::string> overriddenConstants;
+
+            // Set pipeline overridden values
+            for (const auto& pipelineConstant : *pipelineConstantEntries) {
+                const std::string& name = pipelineConstant.first;
+                double value = pipelineConstant.second;
+
+                overriddenConstants.insert(name);
+
+                // This is already validated so `name` must exist
+                const auto& moduleConstant = shaderEntryPointConstants->at(name);
+
+                defineStrings->emplace_back(
+                    kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
+                    GetHLSLValueString(moduleConstant.type, nullptr, value));
+            }
+
+            // Set shader initialized default values
+            for (const auto& iter : *shaderEntryPointConstants) {
+                const std::string& name = iter.first;
+                if (overriddenConstants.count(name) != 0) {
+                    // This constant already has overridden value
+                    continue;
+                }
+
+                const auto& moduleConstant = shaderEntryPointConstants->at(name);
+
+                // Uninitialized default values are okay since they ar only defined to pass
+                // compilation but not used
+                defineStrings->emplace_back(
+                    kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
+                    GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
+            }
+        }
+
         // The inputs to a shader compilation. These have been intentionally isolated from the
         // device to help ensure that the pipeline cache key contains all inputs for compilation.
         struct ShaderCompilationRequest {
@@ -199,8 +178,7 @@
             bool usesNumWorkgroups;
             uint32_t numWorkgroupsRegisterSpace;
             uint32_t numWorkgroupsShaderRegister;
-            const PipelineConstantEntries* pipelineConstantEntries;
-            const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
+            std::vector<std::pair<std::string, std::string>> defineStrings;
 
             // FXC/DXC common inputs
             bool disableWorkgroupInit;
@@ -293,10 +271,12 @@
                 request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
                 request.deviceInfo = &device->GetDeviceInfo();
                 request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
-                request.pipelineConstantEntries = &programmableStage.constants;
-                request.shaderEntryPointConstants =
+
+                GetOverridableConstantsDefines(
+                    &request.defineStrings, &programmableStage.constants,
                     &programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
-                         .overridableConstants;
+                         .overridableConstants);
+
                 return std::move(request);
             }
 
@@ -349,17 +329,9 @@
                 stream << " dxcVersion=" << dxcVersion;
                 stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
 
-                stream << " overridableConstants={";
-                for (const auto& pipelineConstant : *pipelineConstantEntries) {
-                    const std::string& name = pipelineConstant.first;
-                    double value = pipelineConstant.second;
-
-                    // This is already validated so `name` must exist
-                    const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
-                    stream << " <" << name << ","
-                           << GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
-                           << ">";
+                stream << " defines={";
+                for (const auto& it : defineStrings) {
+                    stream << " <" << it.first << "," << it.second << ">";
                 }
                 stream << " }";
 
@@ -422,53 +394,6 @@
             return arguments;
         }
 
-        template <typename StringType>
-        const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
-            const PipelineConstantEntries* pipelineConstantEntries,
-            const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
-            std::vector<std::pair<StringType, StringType>> defineStrings;
-
-            std::unordered_set<std::string> overriddenConstants;
-
-            // Set pipeline overridden values
-            for (const auto& pipelineConstant : *pipelineConstantEntries) {
-                const std::string& name = pipelineConstant.first;
-                double value = pipelineConstant.second;
-
-                overriddenConstants.insert(name);
-
-                // This is already validated so `name` must exist
-                const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
-                defineStrings.emplace_back(
-                    NumberToString<StringType>::kSpecConstantPrefix +
-                        NumberToString<StringType>::ToStringAsId(
-                            static_cast<int32_t>(moduleConstant.id)),
-                    GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
-            }
-
-            // Set shader initialized default values
-            for (const auto& iter : *shaderEntryPointConstants) {
-                const std::string& name = iter.first;
-                if (overriddenConstants.count(name) != 0) {
-                    // This constant already has overridden value
-                    continue;
-                }
-
-                const auto& moduleConstant = shaderEntryPointConstants->at(name);
-
-                // Uninitialized default values are okay since they are only defined to pass
-                // compilation but not used
-                defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
-                                               NumberToString<StringType>::ToStringAsId(
-                                                   static_cast<int32_t>(moduleConstant.id)),
-                                           GetHLSLValueString<StringType>(
-                                               moduleConstant.type, &moduleConstant.defaultValue));
-            }
-
-            return defineStrings;
-        }
-
         ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
                                                          IDxcCompiler* dxcCompiler,
                                                          const ShaderCompilationRequest& request,
@@ -486,8 +411,13 @@
                 GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
 
             // Build defines for overridable constants
-            const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
-                request.pipelineConstantEntries, request.shaderEntryPointConstants);
+            std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
+            defineStrings.reserve(request.defineStrings.size());
+            for (const auto& it : request.defineStrings) {
+                defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()),
+                                           UTF8ToWStr(it.second.c_str()));
+            }
+
             std::vector<DxcDefine> dxcDefines;
             dxcDefines.reserve(defineStrings.size());
             for (const auto& d : defineStrings) {
@@ -538,14 +468,11 @@
             ComPtr<ID3DBlob> errors;
 
             // Build defines for overridable constants
-            const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
-                request.pipelineConstantEntries, request.shaderEntryPointConstants);
-
             const D3D_SHADER_MACRO* pDefines = nullptr;
             std::vector<D3D_SHADER_MACRO> fxcDefines;
-            if (defineStrings.size() > 0) {
-                fxcDefines.reserve(defineStrings.size() + 1);
-                for (const auto& d : defineStrings) {
+            if (request.defineStrings.size() > 0) {
+                fxcDefines.reserve(request.defineStrings.size() + 1);
+                for (const auto& d : request.defineStrings) {
                     fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
                 }
                 // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array