Check FP16 support on D3D backend

True FP16 is only supported in DXC through Shader Model 6.2, also
check the value of the Native16BitShaderOpsSupported member of
D3D12_FEATURE_DATA_D3D12_OPTIONS4 to view whether hardware actually
supports FP16 operations.

BUG=dawn:426
TEST=dawn_end2end_tests

Change-Id: If675f7ba650cb1bd8c792928b70619b9ccda048a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/23243
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Xinghua Cao <xinghua.cao@intel.com>
diff --git a/src/dawn_native/d3d12/AdapterD3D12.cpp b/src/dawn_native/d3d12/AdapterD3D12.cpp
index 68fa27c..c6078e5 100644
--- a/src/dawn_native/d3d12/AdapterD3D12.cpp
+++ b/src/dawn_native/d3d12/AdapterD3D12.cpp
@@ -102,6 +102,9 @@
         mSupportedExtensions.EnableExtension(Extension::TextureCompressionBC);
         mSupportedExtensions.EnableExtension(Extension::PipelineStatisticsQuery);
         mSupportedExtensions.EnableExtension(Extension::TimestampQuery);
+        if (mDeviceInfo.supportsShaderFloat16 && GetBackend()->GetFunctions()->IsDXCAvailable()) {
+            mSupportedExtensions.EnableExtension(Extension::ShaderFloat16);
+        }
     }
 
     MaybeError Adapter::InitializeDebugLayerFilters() {
diff --git a/src/dawn_native/d3d12/D3D12Info.cpp b/src/dawn_native/d3d12/D3D12Info.cpp
index 6505e44..2ca6429 100644
--- a/src/dawn_native/d3d12/D3D12Info.cpp
+++ b/src/dawn_native/d3d12/D3D12Info.cpp
@@ -59,6 +59,63 @@
             }
         }
 
-        return info;
+        D3D12_FEATURE_DATA_SHADER_MODEL knownShaderModels[] = {{D3D_SHADER_MODEL_6_2},
+                                                               {D3D_SHADER_MODEL_6_1},
+                                                               {D3D_SHADER_MODEL_6_0},
+                                                               {D3D_SHADER_MODEL_5_1}};
+        for (D3D12_FEATURE_DATA_SHADER_MODEL shaderModel : knownShaderModels) {
+            if (SUCCEEDED(adapter.GetDevice()->CheckFeatureSupport(
+                    D3D12_FEATURE_SHADER_MODEL, &shaderModel, sizeof(shaderModel)))) {
+                if (shaderModel.HighestShaderModel < D3D_SHADER_MODEL_5_1) {
+                    return DAWN_INTERNAL_ERROR(
+                        "Driver could not support Shader Model 5.1 or higher");
+                }
+
+                switch (shaderModel.HighestShaderModel) {
+                    case D3D_SHADER_MODEL_6_2: {
+                        info.shaderModel = 62;
+                        info.shaderProfiles[SingleShaderStage::Vertex] = L"vs_6_2";
+                        info.shaderProfiles[SingleShaderStage::Fragment] = L"ps_6_2";
+                        info.shaderProfiles[SingleShaderStage::Compute] = L"cs_6_2";
+
+                        D3D12_FEATURE_DATA_D3D12_OPTIONS4 featureData4 = {};
+                        if (SUCCEEDED(adapter.GetDevice()->CheckFeatureSupport(
+                                D3D12_FEATURE_D3D12_OPTIONS4, &featureData4,
+                                sizeof(featureData4)))) {
+                            info.supportsShaderFloat16 =
+                                shaderModel.HighestShaderModel >= D3D_SHADER_MODEL_6_2 &&
+                                featureData4.Native16BitShaderOpsSupported;
+                        }
+                        break;
+                    }
+                    case D3D_SHADER_MODEL_6_1: {
+                        info.shaderModel = 61;
+                        info.shaderProfiles[SingleShaderStage::Vertex] = L"vs_6_1";
+                        info.shaderProfiles[SingleShaderStage::Fragment] = L"ps_6_1";
+                        info.shaderProfiles[SingleShaderStage::Compute] = L"cs_6_1";
+                        break;
+                    }
+                    case D3D_SHADER_MODEL_6_0: {
+                        info.shaderModel = 60;
+                        info.shaderProfiles[SingleShaderStage::Vertex] = L"vs_6_0";
+                        info.shaderProfiles[SingleShaderStage::Fragment] = L"ps_6_0";
+                        info.shaderProfiles[SingleShaderStage::Compute] = L"cs_6_0";
+                        break;
+                    }
+                    default: {
+                        info.shaderModel = 51;
+                        info.shaderProfiles[SingleShaderStage::Vertex] = L"vs_5_1";
+                        info.shaderProfiles[SingleShaderStage::Fragment] = L"ps_5_1";
+                        info.shaderProfiles[SingleShaderStage::Compute] = L"cs_5_1";
+                        break;
+                    }
+                }
+
+                // Successfully find the maximum supported shader model.
+                break;
+            }
+        }
+
+        return std::move(info);
     }
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/D3D12Info.h b/src/dawn_native/d3d12/D3D12Info.h
index 78d3820..cce74b1 100644
--- a/src/dawn_native/d3d12/D3D12Info.h
+++ b/src/dawn_native/d3d12/D3D12Info.h
@@ -16,6 +16,7 @@
 #define DAWNNATIVE_D3D12_D3D12INFO_H_
 
 #include "dawn_native/Error.h"
+#include "dawn_native/PerStage.h"
 #include "dawn_native/d3d12/d3d12_platform.h"
 
 namespace dawn_native { namespace d3d12 {
@@ -26,6 +27,11 @@
         bool isUMA;
         uint32_t resourceHeapTier;
         bool supportsRenderPass;
+        bool supportsShaderFloat16;
+        // shaderModel indicates the maximum supported shader model, for example, the value 62
+        // indicates that current driver supports the maximum shader model is D3D_SHADER_MODEL_6_2.
+        uint32_t shaderModel;
+        PerStage<std::wstring> shaderProfiles;
     };
 
     ResultOrError<D3D12DeviceInfo> GatherDeviceInfo(const Adapter& adapter);
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 770a006..0bdbafa 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -31,7 +31,7 @@
 namespace dawn_native { namespace d3d12 {
 
     namespace {
-        std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags) {
+        std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
             std::vector<const wchar_t*> arguments;
             if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
                 arguments.push_back(L"/Gec");
@@ -70,9 +70,15 @@
             if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) {
                 arguments.push_back(L"/res_may_alias");
             }
-            // Enable FXC backward compatibility by setting the language version to 2016
-            arguments.push_back(L"-HV");
-            arguments.push_back(L"2016");
+
+            if (enable16BitTypes) {
+                // enable-16bit-types are only allowed in -HV 2018 (default)
+                arguments.push_back(L"/enable-16bit-types");
+            } else {
+                // Enable FXC backward compatibility by setting the language version to 2016
+                arguments.push_back(L"-HV");
+                arguments.push_back(L"2016");
+            }
             return arguments;
         }
 
@@ -98,7 +104,12 @@
             shaderc_spvc::CompileOptions options = GetCompileOptions();
 
             options.SetForceZeroInitializedVariables(true);
-            options.SetHLSLShaderModel(51);
+            if (GetDevice()->IsExtensionEnabled(Extension::ShaderFloat16)) {
+                options.SetHLSLShaderModel(ToBackend(GetDevice())->GetDeviceInfo().shaderModel);
+                options.SetHLSLEnable16BitTypes(true);
+            } else {
+                options.SetHLSLShaderModel(51);
+            }
             // PointCoord and PointSize are not supported in HLSL
             // TODO (hao.x.li@intel.com): The point_coord_compat and point_size_compat are
             // required temporarily for https://bugs.chromium.org/p/dawn/issues/detail?id=146,
@@ -138,7 +149,12 @@
             options_glsl.force_zero_initialized_variables = true;
 
             spirv_cross::CompilerHLSL::Options options_hlsl;
-            options_hlsl.shader_model = 51;
+            if (GetDevice()->IsExtensionEnabled(Extension::ShaderFloat16)) {
+                options_hlsl.shader_model = ToBackend(GetDevice())->GetDeviceInfo().shaderModel;
+                options_hlsl.enable_16bit_types = true;
+            } else {
+                options_hlsl.shader_model = 51;
+            }
             // PointCoord and PointSize are not supported in HLSL
             // TODO (hao.x.li@intel.com): The point_coord_compat and point_size_compat are
             // required temporarily for https://bugs.chromium.org/p/dawn/issues/detail?id=146,
@@ -210,19 +226,6 @@
                                                                    const std::string& hlslSource,
                                                                    const char* entryPoint,
                                                                    uint32_t compileFlags) {
-        const wchar_t* targetProfile = nullptr;
-        switch (stage) {
-            case SingleShaderStage::Vertex:
-                targetProfile = L"vs_6_0";
-                break;
-            case SingleShaderStage::Fragment:
-                targetProfile = L"ps_6_0";
-                break;
-            case SingleShaderStage::Compute:
-                targetProfile = L"cs_6_0";
-                break;
-        }
-
         IDxcLibrary* dxcLibrary;
         DAWN_TRY_ASSIGN(dxcLibrary, ToBackend(GetDevice())->GetOrCreateDxcLibrary());
 
@@ -237,13 +240,16 @@
         std::wstring entryPointW;
         DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
 
-        std::vector<const wchar_t*> arguments = GetDXCArguments(compileFlags);
+        std::vector<const wchar_t*> arguments = GetDXCArguments(
+            compileFlags, GetDevice()->IsExtensionEnabled(Extension::ShaderFloat16));
 
         ComPtr<IDxcOperationResult> result;
-        DAWN_TRY(CheckHRESULT(
-            dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), targetProfile,
-                                 arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
-            "DXC compile"));
+        DAWN_TRY(
+            CheckHRESULT(dxcCompiler->Compile(
+                             sourceBlob.Get(), nullptr, entryPointW.c_str(),
+                             ToBackend(GetDevice())->GetDeviceInfo().shaderProfiles[stage].c_str(),
+                             arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
+                         "DXC compile"));
 
         HRESULT hr;
         DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));