D3D12: Make HLSL generation per-entrypoint.

Also make the CompileShaderDXC/FXC standalone functions because
they don't use ShaderModule except to GetDevice().

Bug: dawn:216
Change-Id: Iaec9abe52ad4422891474086c3b973baf07046a5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28243
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 0c9fc6f..e742784 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -42,8 +42,12 @@
         compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
 
         ShaderModule* module = ToBackend(descriptor->computeStage.module);
+
+        // Note that the HLSL will always use entryPoint "main".
         std::string hlslSource;
-        DAWN_TRY_ASSIGN(hlslSource, module->GetHLSLSource(ToBackend(GetLayout())));
+        DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSL(descriptor->computeStage.entryPoint,
+                                                            SingleShaderStage::Compute,
+                                                            ToBackend(GetLayout())));
 
         D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
         d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
@@ -52,18 +56,14 @@
         ComPtr<ID3DBlob> compiledFXCShader;
 
         if (device->IsToggleEnabled(Toggle::UseDXC)) {
-            DAWN_TRY_ASSIGN(
-                compiledDXCShader,
-                module->CompileShaderDXC(SingleShaderStage::Compute, hlslSource,
-                                         descriptor->computeStage.entryPoint, compileFlags));
+            DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(device, SingleShaderStage::Compute,
+                                                                hlslSource, "main", compileFlags));
 
             d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
             d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
         } else {
-            DAWN_TRY_ASSIGN(
-                compiledFXCShader,
-                module->CompileShaderFXC(SingleShaderStage::Compute, hlslSource,
-                                         descriptor->computeStage.entryPoint, compileFlags));
+            DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(device, SingleShaderStage::Compute,
+                                                                hlslSource, "main", compileFlags));
             d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
             d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
         }
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index ada9fc6..52ec371 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -327,20 +327,21 @@
 
         wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
         for (auto stage : IterateStages(renderStages)) {
+            // Note that the HLSL entryPoint will always be "main".
             std::string hlslSource;
-            DAWN_TRY_ASSIGN(hlslSource, modules[stage]->GetHLSLSource(ToBackend(GetLayout())));
+            DAWN_TRY_ASSIGN(hlslSource,
+                            modules[stage]->TranslateToHLSL(GetStage(stage).entryPoint.c_str(),
+                                                            stage, ToBackend(GetLayout())));
 
             if (device->IsToggleEnabled(Toggle::UseDXC)) {
                 DAWN_TRY_ASSIGN(compiledDXCShader[stage],
-                                modules[stage]->CompileShaderDXC(stage, hlslSource,
-                                                                 entryPoints[stage], compileFlags));
+                                CompileShaderDXC(device, stage, hlslSource, "main", compileFlags));
 
                 shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
                 shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
             } else {
                 DAWN_TRY_ASSIGN(compiledFXCShader[stage],
-                                modules[stage]->CompileShaderFXC(stage, hlslSource,
-                                                                 entryPoints[stage], compileFlags));
+                                CompileShaderFXC(device, stage, hlslSource, "main", compileFlags));
 
                 shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
                 shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index bb9dff1..1c9fbb5 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -17,6 +17,7 @@
 #include "common/Assert.h"
 #include "common/BitSetIterator.h"
 #include "common/Log.h"
+#include "dawn_native/SpirvUtils.h"
 #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
 #include "dawn_native/d3d12/D3D12Error.h"
 #include "dawn_native/d3d12/DeviceD3D12.h"
@@ -84,11 +85,90 @@
 
     }  // anonymous namespace
 
+    ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
+                                                     SingleShaderStage stage,
+                                                     const std::string& hlslSource,
+                                                     const char* entryPoint,
+                                                     uint32_t compileFlags) {
+        IDxcLibrary* dxcLibrary;
+        DAWN_TRY_ASSIGN(dxcLibrary, device->GetOrCreateDxcLibrary());
+
+        ComPtr<IDxcBlobEncoding> sourceBlob;
+        DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
+                                  hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
+                              "DXC create blob"));
+
+        IDxcCompiler* dxcCompiler;
+        DAWN_TRY_ASSIGN(dxcCompiler, device->GetOrCreateDxcCompiler());
+
+        std::wstring entryPointW;
+        DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
+
+        std::vector<const wchar_t*> arguments =
+            GetDXCArguments(compileFlags, device->IsExtensionEnabled(Extension::ShaderFloat16));
+
+        ComPtr<IDxcOperationResult> result;
+        DAWN_TRY(CheckHRESULT(
+            dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
+                                 device->GetDeviceInfo().shaderProfiles[stage].c_str(),
+                                 arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
+            "DXC compile"));
+
+        HRESULT hr;
+        DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
+
+        if (FAILED(hr)) {
+            ComPtr<IDxcBlobEncoding> errors;
+            DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
+
+            std::string message = std::string("DXC compile failed with ") +
+                                  static_cast<char*>(errors->GetBufferPointer());
+            return DAWN_INTERNAL_ERROR(message);
+        }
+
+        ComPtr<IDxcBlob> compiledShader;
+        DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
+        return std::move(compiledShader);
+    }
+
+    ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
+                                                     SingleShaderStage stage,
+                                                     const std::string& hlslSource,
+                                                     const char* entryPoint,
+                                                     uint32_t compileFlags) {
+        const char* targetProfile = nullptr;
+        switch (stage) {
+            case SingleShaderStage::Vertex:
+                targetProfile = "vs_5_1";
+                break;
+            case SingleShaderStage::Fragment:
+                targetProfile = "ps_5_1";
+                break;
+            case SingleShaderStage::Compute:
+                targetProfile = "cs_5_1";
+                break;
+        }
+
+        ComPtr<ID3DBlob> compiledShader;
+        ComPtr<ID3DBlob> errors;
+
+        const PlatformFunctions* functions = device->GetFunctions();
+        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
+                                         nullptr, entryPoint, targetProfile, compileFlags, 0,
+                                         &compiledShader, &errors))) {
+            std::string message = std::string("D3D compile failed with ") +
+                                  static_cast<char*>(errors->GetBufferPointer());
+            return DAWN_INTERNAL_ERROR(message);
+        }
+
+        return std::move(compiledShader);
+    }
+
     // static
     ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
                                                       const ShaderModuleDescriptor* descriptor) {
         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
-        DAWN_TRY(module->Initialize());
+        DAWN_TRY(module->InitializeBase());
         return module.Detach();
     }
 
@@ -96,14 +176,10 @@
         : ShaderModuleBase(device, descriptor) {
     }
 
-    MaybeError ShaderModule::Initialize() {
-        return InitializeBase();
-    }
-
-    ResultOrError<std::string> ShaderModule::GetHLSLSource(PipelineLayout* layout) {
+    ResultOrError<std::string> ShaderModule::TranslateToHLSL(const char* entryPointName,
+                                                             SingleShaderStage stage,
+                                                             PipelineLayout* layout) const {
         ASSERT(!IsError());
-        const std::vector<uint32_t>& spirv = GetSpirv();
-
         // If these options are changed, the values in DawnSPIRVCrossHLSLFastFuzzer.cpp need to
         // be updated.
         spirv_cross::CompilerGLSL::Options options_glsl;
@@ -127,12 +203,13 @@
         options_hlsl.point_size_compat = true;
         options_hlsl.nonwritable_uav_texture_as_srv = true;
 
-        spirv_cross::CompilerHLSL compiler(spirv);
+        spirv_cross::CompilerHLSL compiler(GetSpirv());
         compiler.set_common_options(options_glsl);
         compiler.set_hlsl_options(options_hlsl);
+        compiler.set_entry_point(entryPointName, ShaderStageToExecutionModel(stage));
 
         const EntryPointMetadata::BindingInfo& moduleBindingInfo =
-            GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings;
+            GetEntryPoint(entryPointName, stage).bindings;
 
         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
             const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
@@ -158,85 +235,8 @@
                 }
             }
         }
+
         return compiler.compile();
     }
 
-    ResultOrError<ComPtr<IDxcBlob>> ShaderModule::CompileShaderDXC(SingleShaderStage stage,
-                                                                   const std::string& hlslSource,
-                                                                   const char* entryPoint,
-                                                                   uint32_t compileFlags) {
-        IDxcLibrary* dxcLibrary;
-        DAWN_TRY_ASSIGN(dxcLibrary, ToBackend(GetDevice())->GetOrCreateDxcLibrary());
-
-        ComPtr<IDxcBlobEncoding> sourceBlob;
-        DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
-                                  hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
-                              "DXC create blob"));
-
-        IDxcCompiler* dxcCompiler;
-        DAWN_TRY_ASSIGN(dxcCompiler, ToBackend(GetDevice())->GetOrCreateDxcCompiler());
-
-        std::wstring entryPointW;
-        DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
-
-        std::vector<const wchar_t*> arguments = GetDXCArguments(
-            compileFlags, GetDevice()->IsExtensionEnabled(Extension::ShaderFloat16));
-
-        ComPtr<IDxcOperationResult> result;
-        DAWN_TRY(
-            CheckHRESULT(dxcCompiler->Compile(
-                             sourceBlob.Get(), nullptr, entryPointW.c_str(),
-                             ToBackend(GetDevice())->GetDeviceInfo().shaderProfiles[stage].c_str(),
-                             arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
-                         "DXC compile"));
-
-        HRESULT hr;
-        DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
-
-        if (FAILED(hr)) {
-            ComPtr<IDxcBlobEncoding> errors;
-            DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
-
-            std::string message = std::string("DXC compile failed with ") +
-                                  static_cast<char*>(errors->GetBufferPointer());
-            return DAWN_INTERNAL_ERROR(message);
-        }
-
-        ComPtr<IDxcBlob> compiledShader;
-        DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
-        return std::move(compiledShader);
-    }
-
-    ResultOrError<ComPtr<ID3DBlob>> ShaderModule::CompileShaderFXC(SingleShaderStage stage,
-                                                                   const std::string& hlslSource,
-                                                                   const char* entryPoint,
-                                                                   uint32_t compileFlags) {
-        const char* targetProfile = nullptr;
-        switch (stage) {
-            case SingleShaderStage::Vertex:
-                targetProfile = "vs_5_1";
-                break;
-            case SingleShaderStage::Fragment:
-                targetProfile = "ps_5_1";
-                break;
-            case SingleShaderStage::Compute:
-                targetProfile = "cs_5_1";
-                break;
-        }
-
-        ComPtr<ID3DBlob> compiledShader;
-        ComPtr<ID3DBlob> errors;
-
-        const PlatformFunctions* functions = ToBackend(GetDevice())->GetFunctions();
-        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
-                                         nullptr, entryPoint, targetProfile, compileFlags, 0,
-                                         &compiledShader, &errors))) {
-            std::string message = std::string("D3D compile failed with ") +
-                                  static_cast<char*>(errors->GetBufferPointer());
-            return DAWN_INTERNAL_ERROR(message);
-        }
-
-        return std::move(compiledShader);
-    }
-
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index c64e8ce..554c365 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -24,26 +24,29 @@
     class Device;
     class PipelineLayout;
 
+    ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
+                                                     SingleShaderStage stage,
+                                                     const std::string& hlslSource,
+                                                     const char* entryPoint,
+                                                     uint32_t compileFlags);
+    ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
+                                                     SingleShaderStage stage,
+                                                     const std::string& hlslSource,
+                                                     const char* entryPoint,
+                                                     uint32_t compileFlags);
+
     class ShaderModule final : public ShaderModuleBase {
       public:
         static ResultOrError<ShaderModule*> Create(Device* device,
                                                    const ShaderModuleDescriptor* descriptor);
 
-        ResultOrError<std::string> GetHLSLSource(PipelineLayout* layout);
-
-        ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(SingleShaderStage stage,
-                                                         const std::string& hlslSource,
-                                                         const char* entryPoint,
-                                                         uint32_t compileFlags);
-        ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(SingleShaderStage stage,
-                                                         const std::string& hlslSource,
-                                                         const char* entryPoint,
-                                                         uint32_t compileFlags);
+        ResultOrError<std::string> TranslateToHLSL(const char* entryPointName,
+                                                   SingleShaderStage stage,
+                                                   PipelineLayout* layout) const;
 
       private:
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
-        MaybeError Initialize();
     };
 
 }}  // namespace dawn_native::d3d12