D3D12: Implement UseDxc toggle to use DXC for HLSL compilation

Factor out common shader compilation logic to ShaderModuleD3D12
used by both RenderPipeline and ComputePipeline, and implement
a new compilation path using DXC when UseDXC toggle is enabled

Bug: dawn:402

Change-Id: I67d3ae0aecee11634af917735456ddbe10b3d86a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/21840
Commit-Queue: Hugo Amiard <hugo.amiard@laposte.net>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/d3d12/BackendD3D12.cpp b/src/dawn_native/d3d12/BackendD3D12.cpp
index 44478bb..46dc1b6 100644
--- a/src/dawn_native/d3d12/BackendD3D12.cpp
+++ b/src/dawn_native/d3d12/BackendD3D12.cpp
@@ -17,6 +17,7 @@
 #include "dawn_native/D3D12Backend.h"
 #include "dawn_native/Instance.h"
 #include "dawn_native/d3d12/AdapterD3D12.h"
+#include "dawn_native/d3d12/D3D12Error.h"
 #include "dawn_native/d3d12/PlatformFunctions.h"
 
 namespace dawn_native { namespace d3d12 {
@@ -92,6 +93,26 @@
         return mFactory;
     }
 
+    ResultOrError<IDxcLibrary*> Backend::GetOrCreateDxcLibrary() {
+        if (mDxcLibrary == nullptr) {
+            DAWN_TRY(CheckHRESULT(
+                mFunctions->dxcCreateInstance(CLSID_DxcLibrary, IID_PPV_ARGS(&mDxcLibrary)),
+                "DXC create library"));
+            ASSERT(mDxcLibrary != nullptr);
+        }
+        return mDxcLibrary.Get();
+    }
+
+    ResultOrError<IDxcCompiler*> Backend::GetOrCreateDxcCompiler() {
+        if (mDxcCompiler == nullptr) {
+            DAWN_TRY(CheckHRESULT(
+                mFunctions->dxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(&mDxcCompiler)),
+                "DXC create compiler"));
+            ASSERT(mDxcCompiler != nullptr);
+        }
+        return mDxcCompiler.Get();
+    }
+
     const PlatformFunctions* Backend::GetFunctions() const {
         return mFunctions.get();
     }
diff --git a/src/dawn_native/d3d12/BackendD3D12.h b/src/dawn_native/d3d12/BackendD3D12.h
index 3161048..27ef1d1 100644
--- a/src/dawn_native/d3d12/BackendD3D12.h
+++ b/src/dawn_native/d3d12/BackendD3D12.h
@@ -30,6 +30,8 @@
         MaybeError Initialize();
 
         ComPtr<IDXGIFactory4> GetFactory() const;
+        ResultOrError<IDxcLibrary*> GetOrCreateDxcLibrary();
+        ResultOrError<IDxcCompiler*> GetOrCreateDxcCompiler();
         const PlatformFunctions* GetFunctions() const;
 
         std::vector<std::unique_ptr<AdapterBase>> DiscoverDefaultAdapters() override;
@@ -39,6 +41,8 @@
         // the D3D12 DLLs are unloaded before we are done using them.
         std::unique_ptr<PlatformFunctions> mFunctions;
         ComPtr<IDXGIFactory4> mFactory;
+        ComPtr<IDxcLibrary> mDxcLibrary;
+        ComPtr<IDxcCompiler> mDxcCompiler;
     };
 
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 940a3e9..0c9fc6f 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -19,6 +19,7 @@
 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
 #include "dawn_native/d3d12/PlatformFunctions.h"
 #include "dawn_native/d3d12/ShaderModuleD3D12.h"
+#include "dawn_native/d3d12/UtilsD3D12.h"
 
 namespace dawn_native { namespace d3d12 {
 
@@ -44,21 +45,28 @@
         std::string hlslSource;
         DAWN_TRY_ASSIGN(hlslSource, module->GetHLSLSource(ToBackend(GetLayout())));
 
-        ComPtr<ID3DBlob> compiledShader;
-        ComPtr<ID3DBlob> errors;
-
-        const PlatformFunctions* functions = device->GetFunctions();
-        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
-                                         nullptr, descriptor->computeStage.entryPoint, "cs_5_1",
-                                         compileFlags, 0, &compiledShader, &errors))) {
-            printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer()));
-            ASSERT(false);
-        }
-
         D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
         d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
-        d3dDesc.CS.pShaderBytecode = compiledShader->GetBufferPointer();
-        d3dDesc.CS.BytecodeLength = compiledShader->GetBufferSize();
+
+        ComPtr<IDxcBlob> compiledDXCShader;
+        ComPtr<ID3DBlob> compiledFXCShader;
+
+        if (device->IsToggleEnabled(Toggle::UseDXC)) {
+            DAWN_TRY_ASSIGN(
+                compiledDXCShader,
+                module->CompileShaderDXC(SingleShaderStage::Compute, hlslSource,
+                                         descriptor->computeStage.entryPoint, compileFlags));
+
+            d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
+            d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
+        } else {
+            DAWN_TRY_ASSIGN(
+                compiledFXCShader,
+                module->CompileShaderFXC(SingleShaderStage::Compute, hlslSource,
+                                         descriptor->computeStage.entryPoint, compileFlags));
+            d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
+            d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
+        }
 
         device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
                                                              IID_PPV_ARGS(&mPipelineState));
diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp
index 21f4aaf..9f4d887 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn_native/d3d12/DeviceD3D12.cpp
@@ -177,6 +177,14 @@
         return ToBackend(GetAdapter())->GetBackend()->GetFactory();
     }
 
+    ResultOrError<IDxcLibrary*> Device::GetOrCreateDxcLibrary() const {
+        return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcLibrary();
+    }
+
+    ResultOrError<IDxcCompiler*> Device::GetOrCreateDxcCompiler() const {
+        return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcCompiler();
+    }
+
     const PlatformFunctions* Device::GetFunctions() const {
         return ToBackend(GetAdapter())->GetBackend()->GetFunctions();
     }
diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h
index 88011b9..c55f8f7 100644
--- a/src/dawn_native/d3d12/DeviceD3D12.h
+++ b/src/dawn_native/d3d12/DeviceD3D12.h
@@ -72,6 +72,8 @@
 
         const PlatformFunctions* GetFunctions() const;
         ComPtr<IDXGIFactory4> GetFactory() const;
+        ResultOrError<IDxcLibrary*> GetOrCreateDxcLibrary() const;
+        ResultOrError<IDxcCompiler*> GetOrCreateDxcCompiler() const;
 
         ResultOrError<CommandRecordingContext*> GetPendingCommandContext();
 
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 82fdedf..f629e76 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -310,50 +310,41 @@
 
         D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {};
 
-        PerStage<ComPtr<ID3DBlob>> compiledShader;
-        ComPtr<ID3DBlob> errors;
+        PerStage<const char*> entryPoints;
+        entryPoints[SingleShaderStage::Vertex] = descriptor->vertexStage.entryPoint;
+        entryPoints[SingleShaderStage::Fragment] = descriptor->fragmentStage->entryPoint;
+
+        PerStage<ShaderModule*> modules;
+        modules[SingleShaderStage::Vertex] = ToBackend(descriptor->vertexStage.module);
+        modules[SingleShaderStage::Fragment] = ToBackend(descriptor->fragmentStage->module);
+
+        PerStage<D3D12_SHADER_BYTECODE*> shaders;
+        shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
+        shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
+
+        PerStage<ComPtr<ID3DBlob>> compiledFXCShader;
+        PerStage<ComPtr<IDxcBlob>> compiledDXCShader;
 
         wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
         for (auto stage : IterateStages(renderStages)) {
-            ShaderModule* module = nullptr;
-            const char* entryPoint = nullptr;
-            const char* compileTarget = nullptr;
-            D3D12_SHADER_BYTECODE* shader = nullptr;
-            switch (stage) {
-                case SingleShaderStage::Vertex:
-                    module = ToBackend(descriptor->vertexStage.module);
-                    entryPoint = descriptor->vertexStage.entryPoint;
-                    shader = &descriptorD3D12.VS;
-                    compileTarget = "vs_5_1";
-                    break;
-                case SingleShaderStage::Fragment:
-                    module = ToBackend(descriptor->fragmentStage->module);
-                    entryPoint = descriptor->fragmentStage->entryPoint;
-                    shader = &descriptorD3D12.PS;
-                    compileTarget = "ps_5_1";
-                    break;
-                default:
-                    UNREACHABLE();
-                    break;
-            }
 
             std::string hlslSource;
-            DAWN_TRY_ASSIGN(hlslSource, module->GetHLSLSource(ToBackend(GetLayout())));
+            DAWN_TRY_ASSIGN(hlslSource, modules[stage]->GetHLSLSource(ToBackend(GetLayout())));
 
-            const PlatformFunctions* functions = device->GetFunctions();
-            MaybeError error = CheckHRESULT(
-                functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
-                                      nullptr, entryPoint, compileTarget, compileFlags, 0,
-                                      &compiledShader[stage], &errors),
-                "D3DCompile");
-            if (error.IsError()) {
-                dawn::WarningLog() << reinterpret_cast<char*>(errors->GetBufferPointer());
-                DAWN_TRY(std::move(error));
-            }
+            if (device->IsToggleEnabled(Toggle::UseDXC)) {
+                DAWN_TRY_ASSIGN(compiledDXCShader[stage],
+                                modules[stage]->CompileShaderDXC(stage, hlslSource,
+                                                                 entryPoints[stage], compileFlags));
 
-            if (shader != nullptr) {
-                shader->pShaderBytecode = compiledShader[stage]->GetBufferPointer();
-                shader->BytecodeLength = compiledShader[stage]->GetBufferSize();
+                shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
+                shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
+            } else {
+                DAWN_TRY_ASSIGN(compiledFXCShader[stage],
+                                modules[stage]->CompileShaderFXC(stage, hlslSource,
+                                                                 entryPoints[stage], compileFlags));
+
+                shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
+                shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
             }
         }
 
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index d6410aa..896547c 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -16,14 +16,68 @@
 
 #include "common/Assert.h"
 #include "common/BitSetIterator.h"
+#include "common/Log.h"
 #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
+#include "dawn_native/d3d12/D3D12Error.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
+#include "dawn_native/d3d12/PlatformFunctions.h"
+#include "dawn_native/d3d12/UtilsD3D12.h"
+
+#include <d3dcompiler.h>
 
 #include <spirv_hlsl.hpp>
 
 namespace dawn_native { namespace d3d12 {
 
+    namespace {
+        std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags) {
+            std::vector<const wchar_t*> arguments;
+            if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
+                arguments.push_back(L"/Gec");
+            }
+            if (compileFlags & D3DCOMPILE_IEEE_STRICTNESS) {
+                arguments.push_back(L"/Gis");
+            }
+            if (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) {
+                switch (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) {
+                    case D3DCOMPILE_OPTIMIZATION_LEVEL0:
+                        arguments.push_back(L"/O0");
+                        break;
+                    case D3DCOMPILE_OPTIMIZATION_LEVEL2:
+                        arguments.push_back(L"/O2");
+                        break;
+                    case D3DCOMPILE_OPTIMIZATION_LEVEL3:
+                        arguments.push_back(L"/O3");
+                        break;
+                }
+            }
+            if (compileFlags & D3DCOMPILE_DEBUG) {
+                arguments.push_back(L"/Zi");
+            }
+            if (compileFlags & D3DCOMPILE_PACK_MATRIX_ROW_MAJOR) {
+                arguments.push_back(L"/Zpr");
+            }
+            if (compileFlags & D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR) {
+                arguments.push_back(L"/Zpc");
+            }
+            if (compileFlags & D3DCOMPILE_AVOID_FLOW_CONTROL) {
+                arguments.push_back(L"/Gfa");
+            }
+            if (compileFlags & D3DCOMPILE_PREFER_FLOW_CONTROL) {
+                arguments.push_back(L"/Gfp");
+            }
+            if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) {
+                arguments.push_back(L"/res_may_alias");
+            }
+            // Enable FXC backward compatibility by setting the language version to 2016
+            arguments.push_back(L"-HV");
+            arguments.push_back(L"2016");
+            return arguments;
+        }
+
+    }  // anonymous namespace
+
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
@@ -134,4 +188,92 @@
         }
     }
 
+    ResultOrError<ComPtr<IDxcBlob>> ShaderModule::CompileShaderDXC(SingleShaderStage stage,
+                                                                   const std::string& hlslSource,
+                                                                   const char* entryPoint,
+                                                                   uint32_t compileFlags) {
+        const wchar_t* targetProfile = nullptr;
+        switch (stage) {
+            case SingleShaderStage::Vertex:
+                targetProfile = L"vs_6_0";
+                break;
+            case SingleShaderStage::Fragment:
+                targetProfile = L"ps_6_0";
+                break;
+            case SingleShaderStage::Compute:
+                targetProfile = L"cs_6_0";
+                break;
+        }
+
+        IDxcLibrary* dxcLibrary;
+        DAWN_TRY_ASSIGN(dxcLibrary, ToBackend(GetDevice())->GetOrCreateDxcLibrary());
+
+        ComPtr<IDxcBlobEncoding> sourceBlob;
+        DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
+                                  hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
+                              "DXC create blob"));
+
+        IDxcCompiler* dxcCompiler;
+        DAWN_TRY_ASSIGN(dxcCompiler, ToBackend(GetDevice())->GetOrCreateDxcCompiler());
+
+        std::wstring entryPointW;
+        DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
+
+        std::vector<const wchar_t*> arguments = GetDXCArguments(compileFlags);
+
+        ComPtr<IDxcOperationResult> result;
+        DAWN_TRY(CheckHRESULT(
+            dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), targetProfile,
+                                 arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
+            "DXC compile"));
+
+        HRESULT hr;
+        DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
+
+        if (FAILED(hr)) {
+            ComPtr<IDxcBlobEncoding> errors;
+            DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
+
+            std::string message = std::string("DXC compile failed with ") +
+                                  static_cast<char*>(errors->GetBufferPointer());
+            return DAWN_INTERNAL_ERROR(message);
+        }
+
+        ComPtr<IDxcBlob> compiledShader;
+        DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
+        return std::move(compiledShader);
+    }
+
+    ResultOrError<ComPtr<ID3DBlob>> ShaderModule::CompileShaderFXC(SingleShaderStage stage,
+                                                                   const std::string& hlslSource,
+                                                                   const char* entryPoint,
+                                                                   uint32_t compileFlags) {
+        const char* targetProfile = nullptr;
+        switch (stage) {
+            case SingleShaderStage::Vertex:
+                targetProfile = "vs_5_1";
+                break;
+            case SingleShaderStage::Fragment:
+                targetProfile = "ps_5_1";
+                break;
+            case SingleShaderStage::Compute:
+                targetProfile = "cs_5_1";
+                break;
+        }
+
+        ComPtr<ID3DBlob> compiledShader;
+        ComPtr<ID3DBlob> errors;
+
+        const PlatformFunctions* functions = ToBackend(GetDevice())->GetFunctions();
+        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
+                                         nullptr, entryPoint, targetProfile, compileFlags, 0,
+                                         &compiledShader, &errors))) {
+            std::string message = std::string("D3D compile failed with ") +
+                                  static_cast<char*>(errors->GetBufferPointer());
+            return DAWN_INTERNAL_ERROR(message);
+        }
+
+        return std::move(compiledShader);
+    }
+
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index e34d881..c64e8ce 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -17,6 +17,8 @@
 
 #include "dawn_native/ShaderModule.h"
 
+#include "dawn_native/d3d12/d3d12_platform.h"
+
 namespace dawn_native { namespace d3d12 {
 
     class Device;
@@ -29,6 +31,15 @@
 
         ResultOrError<std::string> GetHLSLSource(PipelineLayout* layout);
 
+        ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(SingleShaderStage stage,
+                                                         const std::string& hlslSource,
+                                                         const char* entryPoint,
+                                                         uint32_t compileFlags);
+        ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(SingleShaderStage stage,
+                                                         const std::string& hlslSource,
+                                                         const char* entryPoint,
+                                                         uint32_t compileFlags);
+
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
diff --git a/src/dawn_native/d3d12/UtilsD3D12.cpp b/src/dawn_native/d3d12/UtilsD3D12.cpp
index 04e6669..d8c20ef 100644
--- a/src/dawn_native/d3d12/UtilsD3D12.cpp
+++ b/src/dawn_native/d3d12/UtilsD3D12.cpp
@@ -16,8 +16,29 @@
 
 #include "common/Assert.h"
 
+#include <stringapiset.h>
+
 namespace dawn_native { namespace d3d12 {
 
+    ResultOrError<std::wstring> ConvertStringToWstring(const char* str) {
+        size_t len = strlen(str);
+        if (len == 0) {
+            return std::wstring();
+        }
+        int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, nullptr, 0);
+        if (numChars == 0) {
+            return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
+        }
+        std::wstring result;
+        result.resize(numChars);
+        int numConvertedChars =
+            MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, &result[0], numChars);
+        if (numConvertedChars != numChars) {
+            return DAWN_INTERNAL_ERROR("Failed to convert string to wide string");
+        }
+        return std::move(result);
+    }
+
     D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func) {
         switch (func) {
             case wgpu::CompareFunction::Never:
diff --git a/src/dawn_native/d3d12/UtilsD3D12.h b/src/dawn_native/d3d12/UtilsD3D12.h
index 36a5abe..d1559e7 100644
--- a/src/dawn_native/d3d12/UtilsD3D12.h
+++ b/src/dawn_native/d3d12/UtilsD3D12.h
@@ -23,6 +23,8 @@
 
 namespace dawn_native { namespace d3d12 {
 
+    ResultOrError<std::wstring> ConvertStringToWstring(const char* str);
+
     D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func);
 
     D3D12_TEXTURE_COPY_LOCATION ComputeTextureCopyLocationForTexture(const Texture* texture,
diff --git a/src/dawn_native/d3d12/d3d12_platform.h b/src/dawn_native/d3d12/d3d12_platform.h
index a64486c..1962468 100644
--- a/src/dawn_native/d3d12/d3d12_platform.h
+++ b/src/dawn_native/d3d12/d3d12_platform.h
@@ -18,6 +18,7 @@
 #include <d3d11_2.h>
 #include <d3d11on12.h>
 #include <d3d12.h>
+#include <dxcapi.h>
 #include <dxgi1_4.h>
 #include <wrl.h>