D3D12: implement WGSL pipeline cache key generation

Since spirv_to_dxil does not generate HLSL, the HLSL source can no longer
be used as the shader cache key. To keep supporting pipeline caching,
generate the cache key from the WGSL instead.

A new type, ShaderCompilationRequest, is added to isolate the
compilation inputs from the device, which helps ensure that the cache
key contains all relevant information.

Bug: dawn:1103
Change-Id: Ic2f09326dc3ac254cecf35098dcfe95aa396796f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/61160
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Michael Tang <tangm@microsoft.com>
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 718e0c1..d601325 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -28,10 +28,249 @@
 #include <d3dcompiler.h>
 
 #include <tint/tint.h>
+#include <map>
+#include <ostream>
+#include <sstream>
+#include <unordered_map>
 
 namespace dawn_native { namespace d3d12 {
 
     namespace {
+        // Queries IDxcVersionInfo from the DXC validator and packs its major version into the
+        // upper 32 bits and its minor version into the lower 32 bits of the result.
+        ResultOrError<uint64_t> GetDXCompilerVersion(const ComPtr<IDxcValidator>& dxcValidator) {
+            ComPtr<IDxcVersionInfo> versionInfo;
+            DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo),
+                                  "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));
+
+            uint32_t compilerMajor = 0;
+            uint32_t compilerMinor = 0;
+            DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor),
+                                  "IDxcVersionInfo::GetVersion"));
+
+            // Pack both into a single version number.
+            return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor;
+        }
+
+        // Returns the FXC version recorded in the cache key. D3D_COMPILER_VERSION is a macro from
+        // d3dcompiler.h, so this is fixed at build time and does not query the loaded DLL.
+        uint64_t GetD3DCompilerVersion() {
+            return D3D_COMPILER_VERSION;
+        }
+
+        // Orders binding points by group, then by binding, so that binding maps can be
+        // serialized in a deterministic order.
+        struct CompareBindingPoint {
+            constexpr bool operator()(const tint::transform::BindingPoint& lhs,
+                                      const tint::transform::BindingPoint& rhs) const {
+                if (lhs.group != rhs.group) {
+                    return lhs.group < rhs.group;
+                }
+                return lhs.binding < rhs.binding;
+            }
+        };
+
+        // Serialization helpers for the cache key. They write to any std::ostream (such as the
+        // std::stringstream used by ShaderCompilationRequest::CreateCacheKey).
+        void Serialize(std::ostream& output, const tint::ast::Access& access) {
+            output << access;
+        }
+
+        void Serialize(std::ostream& output, const tint::transform::BindingPoint& binding_point) {
+            output << "(BindingPoint";
+            output << " group=" << binding_point.group;
+            output << " binding=" << binding_point.binding;
+            output << ")";
+        }
+
+        // Serializes a map keyed by BindingPoint. The entries are copied into a sorted std::map
+        // first because std::unordered_map iteration order is unspecified, and the same inputs
+        // must always produce the same cache key.
+        template <typename T>
+        void Serialize(std::ostream& output,
+                       const std::unordered_map<tint::transform::BindingPoint, T>& map) {
+            output << "(map";
+
+            std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(),
+                                                                                   map.end());
+            for (const auto& entry : sorted) {
+                output << " ";
+                Serialize(output, entry.first);
+                output << "=";
+                Serialize(output, entry.second);
+            }
+            output << ")";
+        }
+
+        // The inputs to a shader compilation. These have been intentionally isolated from the
+        // device to help ensure that the pipeline cache key contains all inputs for compilation.
+        // NOTE: `program`, `entryPointName` and `deviceInfo` are non-owning pointers; whatever
+        // they point to must outlive the request.
+        struct ShaderCompilationRequest {
+            enum Compiler { FXC, DXC };
+
+            // Common inputs
+            Compiler compiler;
+            const tint::Program* program;
+            const char* entryPointName;
+            SingleShaderStage stage;
+            uint32_t compileFlags;
+            bool disableSymbolRenaming;
+            tint::transform::BindingRemapper::BindingPoints bindingPoints;
+            tint::transform::BindingRemapper::AccessControls accessControls;
+            bool isRobustnessEnabled;
+
+            // FXC/DXC common inputs
+            bool disableWorkgroupInit;
+
+            // FXC inputs
+            uint64_t fxcVersion;
+
+            // DXC inputs
+            uint64_t dxcVersion;
+            const D3D12DeviceInfo* deviceInfo;
+            bool hasShaderFloat16Extension;
+
+            // Collects everything needed to compile `entryPointName` for `stage`: the compiler
+            // selected by the UseDXC toggle (and its version), device toggles and extensions, and
+            // the binding remappings that make the shader's registers match the D3D12 bind group
+            // layouts of `layout`. `moduleBindingInfo` is the entry point's binding info.
+            static ResultOrError<ShaderCompilationRequest> Create(
+                const char* entryPointName,
+                SingleShaderStage stage,
+                const PipelineLayout* layout,
+                uint32_t compileFlags,
+                const Device* device,
+                const tint::Program* program,
+                const BindingInfoArray& moduleBindingInfo) {
+                Compiler compiler;
+                uint64_t dxcVersion = 0;
+                if (device->IsToggleEnabled(Toggle::UseDXC)) {
+                    compiler = Compiler::DXC;
+                    DAWN_TRY_ASSIGN(dxcVersion, GetDXCompilerVersion(device->GetDxcValidator()));
+                } else {
+                    compiler = Compiler::FXC;
+                }
+
+                using tint::transform::BindingPoint;
+                using tint::transform::BindingRemapper;
+
+                BindingRemapper::BindingPoints bindingPoints;
+                BindingRemapper::AccessControls accessControls;
+
+                // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify the
+                // Tint AST to make the "bindings" decoration match the offset chosen by
+                // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
+                // assigned to each interface variable.
+                for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
+                    const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
+                    const auto& groupBindingInfo = moduleBindingInfo[group];
+                    for (const auto& it : groupBindingInfo) {
+                        BindingNumber binding = it.first;
+                        auto const& bindingInfo = it.second;
+                        BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
+                        BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
+                                                     static_cast<uint32_t>(binding)};
+                        BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
+                                                     bgl->GetShaderRegister(bindingIndex)};
+                        if (srcBindingPoint != dstBindingPoint) {
+                            bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
+                        }
+
+                        // Declaring a read-only storage buffer in HLSL but specifying a storage
+                        // buffer in the BGL produces the wrong output. Force read-only storage
+                        // buffer bindings to be treated as UAV instead of SRV. Internal storage
+                        // buffer is a storage buffer used in the internal pipeline.
+                        const bool forceStorageBufferAsUAV =
+                            (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
+                             (bgl->GetBindingInfo(bindingIndex).buffer.type ==
+                                  wgpu::BufferBindingType::Storage ||
+                              bgl->GetBindingInfo(bindingIndex).buffer.type ==
+                                  kInternalStorageBufferBinding));
+                        if (forceStorageBufferAsUAV) {
+                            accessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
+                        }
+                    }
+                }
+
+                ShaderCompilationRequest request;
+                request.compiler = compiler;
+                request.program = program;
+                request.entryPointName = entryPointName;
+                request.stage = stage;
+                request.compileFlags = compileFlags;
+                request.disableSymbolRenaming =
+                    device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
+                request.bindingPoints = std::move(bindingPoints);
+                request.accessControls = std::move(accessControls);
+                request.isRobustnessEnabled = device->IsRobustnessEnabled();
+                request.disableWorkgroupInit =
+                    device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+                request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
+                request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
+                request.deviceInfo = &device->GetDeviceInfo();
+                request.hasShaderFloat16Extension =
+                    device->IsExtensionEnabled(Extension::ShaderFloat16);
+                return std::move(request);
+            }
+
+            // Builds the persistent cache key: the normalized WGSL of `program` followed by a
+            // serialization of the other request fields. Of `deviceInfo`, only shaderModel is
+            // keyed. NOTE(review): CompileShaderDXC also reads deviceInfo->shaderProfiles; this
+            // assumes those profiles are fully determined by the shader model.
+            ResultOrError<PersistentCacheKey> CreateCacheKey() const {
+                // Generate the WGSL from the Tint program so it's normalized.
+                // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
+                // compact representation.
+                auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
+                if (!result.success) {
+                    std::ostringstream errorStream;
+                    errorStream << "Tint WGSL failure:\n";
+                    errorStream << "Generator: " << result.error << "\n";
+                    return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
+                }
+
+                std::stringstream stream;
+
+                // Prefix the key with the type to avoid collisions from another type that could
+                // have the same key.
+                stream << static_cast<uint32_t>(PersistentKeyType::Shader);
+                stream << "\n";
+
+                // Length-prefix the WGSL so its text cannot run into the fields that follow.
+                stream << result.wgsl.length();
+                stream << "\n";
+
+                stream << result.wgsl;
+                stream << "\n";
+
+                stream << "(ShaderCompilationRequest";
+                stream << " compiler=" << compiler;
+                stream << " entryPointName=" << entryPointName;
+                stream << " stage=" << uint32_t(stage);
+                stream << " compileFlags=" << compileFlags;
+                stream << " disableSymbolRenaming=" << disableSymbolRenaming;
+
+                stream << " bindingPoints=";
+                Serialize(stream, bindingPoints);
+
+                stream << " accessControls=";
+                Serialize(stream, accessControls);
+
+                stream << " shaderModel=" << deviceInfo->shaderModel;
+                stream << " disableWorkgroupInit=" << disableWorkgroupInit;
+                stream << " isRobustnessEnabled=" << isRobustnessEnabled;
+                stream << " fxcVersion=" << fxcVersion;
+                stream << " dxcVersion=" << dxcVersion;
+                stream << " hasShaderFloat16Extension=" << hasShaderFloat16Extension;
+                stream << ")";
+                stream << "\n";
+
+                return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
+                                          std::istreambuf_iterator<char>{});
+            }
+        };
+
         std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
             std::vector<const wchar_t*> arguments;
             if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
@@ -83,85 +300,207 @@
             return arguments;
         }
 
+        // Returns the text of a compiler message blob (ID3DBlob or IDxcBlobEncoding). The blob's
+        // reported size is honored instead of assuming NUL termination, and everything from the
+        // first NUL on is dropped. A null blob yields an empty string because D3DCompile may
+        // leave its error blob null.
+        template <typename Blob>
+        std::string BlobToString(Blob* blob) {
+            if (blob == nullptr) {
+                return "";
+            }
+            std::string text(static_cast<const char*>(blob->GetBufferPointer()),
+                             blob->GetBufferSize());
+            const size_t nul = text.find('\0');
+            if (nul != std::string::npos) {
+                text.resize(nul);
+            }
+            return text;
+        }
+
+        // Compiles `hlslSource` with DXC for the request's entry point, using the device's shader
+        // profile for the request's stage and arguments derived from the compile flags. A failed
+        // compilation is returned as a validation error carrying DXC's message.
+        ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
+                                                         IDxcCompiler* dxcCompiler,
+                                                         const ShaderCompilationRequest& request,
+                                                         const std::string& hlslSource) {
+            ComPtr<IDxcBlobEncoding> sourceBlob;
+            DAWN_TRY(
+                CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
+                                 hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
+                             "DXC create blob"));
+
+            std::wstring entryPointW;
+            DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName));
+
+            std::vector<const wchar_t*> arguments =
+                GetDXCArguments(request.compileFlags, request.hasShaderFloat16Extension);
+
+            ComPtr<IDxcOperationResult> result;
+            DAWN_TRY(CheckHRESULT(
+                dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
+                                     request.deviceInfo->shaderProfiles[request.stage].c_str(),
+                                     arguments.data(), arguments.size(), nullptr, 0, nullptr,
+                                     &result),
+                "DXC compile"));
+
+            HRESULT hr;
+            DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
+
+            if (FAILED(hr)) {
+                ComPtr<IDxcBlobEncoding> errors;
+                DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
+
+                std::string message =
+                    std::string("DXC compile failed with ") + BlobToString(errors.Get());
+                return DAWN_VALIDATION_ERROR(message);
+            }
+
+            ComPtr<IDxcBlob> compiledShader;
+            DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
+            return std::move(compiledShader);
+        }
+
+        // Compiles `hlslSource` with FXC (D3DCompile) using the shader model 5.1 profile for the
+        // request's stage. A failed compilation is returned as a validation error.
+        ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions,
+                                                         const ShaderCompilationRequest& request,
+                                                         const std::string& hlslSource) {
+            const char* targetProfile = nullptr;
+            switch (request.stage) {
+                case SingleShaderStage::Vertex:
+                    targetProfile = "vs_5_1";
+                    break;
+                case SingleShaderStage::Fragment:
+                    targetProfile = "ps_5_1";
+                    break;
+                case SingleShaderStage::Compute:
+                    targetProfile = "cs_5_1";
+                    break;
+            }
+
+            ComPtr<ID3DBlob> compiledShader;
+            ComPtr<ID3DBlob> errors;
+
+            if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr,
+                                             nullptr, nullptr, request.entryPointName,
+                                             targetProfile, request.compileFlags, 0,
+                                             &compiledShader, &errors))) {
+                // `errors` may be null on failure, so it is read through BlobToString.
+                std::string message =
+                    std::string("D3D compile failed with ") + BlobToString(errors.Get());
+                return DAWN_VALIDATION_ERROR(message);
+            }
+
+            return std::move(compiledShader);
+        }
+
+        // Runs the Tint transforms selected by `request` (array bounds robustness, binding
+        // remapping and renaming) on `request.program` and generates HLSL from the result. On
+        // success, `remappedEntryPointName` receives the entry point's name in that HLSL.
+        ResultOrError<std::string> TranslateToHLSL(const ShaderCompilationRequest& request,
+                                                   std::string* remappedEntryPointName) {
+            std::ostringstream errorStream;
+            errorStream << "Tint HLSL failure:\n";
+
+            tint::transform::Manager transformManager;
+            tint::transform::DataMap transformInputs;
+
+            if (request.isRobustnessEnabled) {
+                transformManager.Add<tint::transform::BoundArrayAccessors>();
+            }
+            transformManager.Add<tint::transform::BindingRemapper>();
+
+            transformManager.Add<tint::transform::Renamer>();
+
+            if (request.disableSymbolRenaming) {
+                // We still need to rename HLSL reserved keywords
+                transformInputs.Add<tint::transform::Renamer::Config>(
+                    tint::transform::Renamer::Target::kHlslKeywords);
+            }
+
+            // D3D12 registers like `t3` and `c3` have the same bindingOffset number in
+            // the remapping but should not be considered a collision because they have
+            // different types.
+            const bool mayCollide = true;
+            // `request` is const, so the binding maps are copied into the transform inputs.
+            transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
+                request.bindingPoints, request.accessControls, mayCollide);
+
+            tint::Program transformedProgram;
+            tint::transform::DataMap transformOutputs;
+            DAWN_TRY_ASSIGN(transformedProgram,
+                            RunTransforms(&transformManager, request.program, transformInputs,
+                                          &transformOutputs, nullptr));
+
+            if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
+                auto it = data->remappings.find(request.entryPointName);
+                if (it != data->remappings.end()) {
+                    *remappedEntryPointName = it->second;
+                } else {
+                    if (request.disableSymbolRenaming) {
+                        *remappedEntryPointName = request.entryPointName;
+                    } else {
+                        return DAWN_VALIDATION_ERROR(
+                            "Could not find remapped name for entry point.");
+                    }
+                }
+            } else {
+                return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
+            }
+
+            tint::writer::hlsl::Options options;
+            options.disable_workgroup_init = request.disableWorkgroupInit;
+            auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
+            if (!result.success) {
+                errorStream << "Generator: " << result.error << "\n";
+                return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
+            }
+
+            return std::move(result.hlsl);
+        }
+
+        // Translates `request` to HLSL, passes the HLSL to `DumpShadersEmitLog` when
+        // `dumpShaders` is set, then compiles it with the request's compiler and stores the
+        // resulting blob in the matching field of `compiledShader`.
+        template <typename F>
+        MaybeError CompileShader(const PlatformFunctions* functions,
+                                 IDxcLibrary* dxcLibrary,
+                                 IDxcCompiler* dxcCompiler,
+                                 ShaderCompilationRequest&& request,
+                                 bool dumpShaders,
+                                 F&& DumpShadersEmitLog,
+                                 CompiledShader* compiledShader) {
+            // Translate the source shader to HLSL.
+            std::string hlslSource;
+            std::string remappedEntryPoint;
+            DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(request, &remappedEntryPoint));
+            if (dumpShaders) {
+                std::ostringstream dumpedMsg;
+                dumpedMsg << "/* Dumped generated HLSL */\n" << hlslSource;
+                DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
+            }
+            // The backend compilers must see the entry point's (possibly renamed) HLSL name.
+            // NOTE: this points `request.entryPointName` at a local string, so it dangles once
+            // this function returns; callers must not use `request` afterwards.
+            request.entryPointName = remappedEntryPoint.c_str();
+            switch (request.compiler) {
+                case ShaderCompilationRequest::Compiler::DXC:
+                    DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader,
+                                    CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
+                    break;
+                case ShaderCompilationRequest::Compiler::FXC:
+                    DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader,
+                                    CompileShaderFXC(functions, request, hlslSource));
+                    break;
+            }
+
+            return {};
+        }
+
     }  // anonymous namespace
 
-    ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
-                                                     SingleShaderStage stage,
-                                                     const std::string& hlslSource,
-                                                     const char* entryPoint,
-                                                     uint32_t compileFlags) {
-        ComPtr<IDxcLibrary> dxcLibrary = device->GetDxcLibrary();
-
-        ComPtr<IDxcBlobEncoding> sourceBlob;
-        DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
-                                  hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
-                              "DXC create blob"));
-
-        ComPtr<IDxcCompiler> dxcCompiler = device->GetDxcCompiler();
-
-        std::wstring entryPointW;
-        DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
-
-        std::vector<const wchar_t*> arguments =
-            GetDXCArguments(compileFlags, device->IsExtensionEnabled(Extension::ShaderFloat16));
-
-        ComPtr<IDxcOperationResult> result;
-        DAWN_TRY(CheckHRESULT(
-            dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
-                                 device->GetDeviceInfo().shaderProfiles[stage].c_str(),
-                                 arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
-            "DXC compile"));
-
-        HRESULT hr;
-        DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
-
-        if (FAILED(hr)) {
-            ComPtr<IDxcBlobEncoding> errors;
-            DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
-
-            std::string message = std::string("DXC compile failed with ") +
-                                  static_cast<char*>(errors->GetBufferPointer());
-            return DAWN_VALIDATION_ERROR(message);
-        }
-
-        ComPtr<IDxcBlob> compiledShader;
-        DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
-        return std::move(compiledShader);
-    }
-
-    ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
-                                                     SingleShaderStage stage,
-                                                     const std::string& hlslSource,
-                                                     const char* entryPoint,
-                                                     uint32_t compileFlags) {
-        const char* targetProfile = nullptr;
-        switch (stage) {
-            case SingleShaderStage::Vertex:
-                targetProfile = "vs_5_1";
-                break;
-            case SingleShaderStage::Fragment:
-                targetProfile = "ps_5_1";
-                break;
-            case SingleShaderStage::Compute:
-                targetProfile = "cs_5_1";
-                break;
-        }
-
-        ComPtr<ID3DBlob> compiledShader;
-        ComPtr<ID3DBlob> errors;
-
-        const PlatformFunctions* functions = device->GetFunctions();
-        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
-                                         nullptr, entryPoint, targetProfile, compileFlags, 0,
-                                         &compiledShader, &errors))) {
-            std::string message = std::string("D3D compile failed with ") +
-                                  static_cast<char*>(errors->GetBufferPointer());
-            return DAWN_VALIDATION_ERROR(message);
-        }
-
-        return std::move(compiledShader);
-    }
-
     // static
     ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
                                                           const ShaderModuleDescriptor* descriptor,
@@ -180,120 +485,13 @@
         return InitializeBase(parseResult);
     }
 
-    ResultOrError<std::string> ShaderModule::TranslateToHLSL(
-        const tint::Program* program,
-        const char* entryPointName,
-        SingleShaderStage stage,
-        PipelineLayout* layout,
-        std::string* remappedEntryPointName) const {
-        ASSERT(!IsError());
-
-        ScopedTintICEHandler scopedICEHandler(GetDevice());
-
-        using BindingRemapper = tint::transform::BindingRemapper;
-        using BindingPoint = tint::transform::BindingPoint;
-        BindingRemapper::BindingPoints bindingPoints;
-        BindingRemapper::AccessControls accessControls;
-
-        const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
-
-        // d3d12::BindGroupLayout packs the bindings per HLSL register-space.
-        // We modify the Tint AST to make the "bindings" decoration match the
-        // offset chosen by d3d12::BindGroupLayout so that Tint produces HLSL
-        // with the correct registers assigned to each interface variable.
-        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
-            const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
-            const auto& groupBindingInfo = moduleBindingInfo[group];
-            for (const auto& it : groupBindingInfo) {
-                BindingNumber binding = it.first;
-                auto const& bindingInfo = it.second;
-                BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
-                BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
-                                             static_cast<uint32_t>(binding)};
-                BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
-                                             bgl->GetShaderRegister(bindingIndex)};
-                if (srcBindingPoint != dstBindingPoint) {
-                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
-                }
-
-                // Declaring a read-only storage buffer in HLSL but specifying a
-                // storage buffer in the BGL produces the wrong output.
-                // Force read-only storage buffer bindings to be treated as UAV
-                // instead of SRV.
-                // Internal storage buffer is a storage buffer used in the internal pipeline.
-                const bool forceStorageBufferAsUAV =
-                    (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
-                     (bgl->GetBindingInfo(bindingIndex).buffer.type ==
-                          wgpu::BufferBindingType::Storage ||
-                      bgl->GetBindingInfo(bindingIndex).buffer.type ==
-                          kInternalStorageBufferBinding));
-                if (forceStorageBufferAsUAV) {
-                    accessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite);
-                }
-            }
-        }
-
-        std::ostringstream errorStream;
-        errorStream << "Tint HLSL failure:" << std::endl;
-
-        tint::transform::Manager transformManager;
-        tint::transform::DataMap transformInputs;
-
-        if (GetDevice()->IsRobustnessEnabled()) {
-            transformManager.Add<tint::transform::BoundArrayAccessors>();
-        }
-        transformManager.Add<tint::transform::BindingRemapper>();
-
-        transformManager.Add<tint::transform::Renamer>();
-
-        if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
-            // We still need to rename HLSL reserved keywords
-            transformInputs.Add<tint::transform::Renamer::Config>(
-                tint::transform::Renamer::Target::kHlslKeywords);
-        }
-
-        // D3D12 registers like `t3` and `c3` have the same bindingOffset number in the
-        // remapping but should not be considered a collision because they have different types.
-        const bool mayCollide = true;
-        transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
-                                                         std::move(accessControls), mayCollide);
-
-        tint::Program transformedProgram;
-        tint::transform::DataMap transformOutputs;
-        DAWN_TRY_ASSIGN(
-            transformedProgram,
-            RunTransforms(&transformManager, program, transformInputs, &transformOutputs, nullptr));
-
-        if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
-            auto it = data->remappings.find(entryPointName);
-            if (it != data->remappings.end()) {
-                *remappedEntryPointName = it->second;
-            } else {
-                if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
-                    *remappedEntryPointName = entryPointName;
-                } else {
-                    return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
-                }
-            }
-        } else {
-            return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
-        }
-
-        tint::writer::hlsl::Options options;
-        options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
-        auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
-        if (!result.success) {
-            errorStream << "Generator: " << result.error << std::endl;
-            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
-        }
-
-        return std::move(result.hlsl);
-    }
-
     ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
                                                         SingleShaderStage stage,
                                                         PipelineLayout* layout,
                                                         uint32_t compileFlags) {
+        ASSERT(!IsError());  // Moved from the removed ShaderModule::TranslateToHLSL.
+        ScopedTintICEHandler scopedICEHandler(GetDevice());  // Also moved; scoped to Compile.
+
         Device* device = ToBackend(GetDevice());
 
         CompiledShader compiledShader = {};
@@ -333,47 +531,33 @@
             program = GetTintProgram();
         }
 
-        // Compile the source shader to HLSL.
-        std::string hlslSource;
-        std::string remappedEntryPoint;
-        DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(program, entryPointName, stage, layout,
-                                                    &remappedEntryPoint));
-        entryPointName = remappedEntryPoint.c_str();
+        ShaderCompilationRequest request;
+        DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
+                                     entryPointName, stage, layout, compileFlags, device, program,
+                                     GetEntryPoint(entryPointName).bindings));
 
-        if (device->IsToggleEnabled(Toggle::DumpShaders)) {
-            std::ostringstream dumpedMsg;
-            dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource;
-            GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
-        }
-
-        // Use HLSL source as the input for the key since it does need to know about the pipeline
-        // layout. The pipeline layout is only required if we key from WGSL: two different pipeline
-        // layouts could be used to produce different shader blobs and the wrong shader blob could
-        // be loaded since the pipeline layout was missing from the key.
-        // The compiler flags or version used could also produce different HLSL source. HLSL key
-        // needs both to ensure the shader cache key is unique to the HLSL source.
-        // TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used.
         PersistentCacheKey shaderCacheKey;
-        DAWN_TRY_ASSIGN(shaderCacheKey,
-                        CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags));
+        DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
 
-        DAWN_TRY_ASSIGN(compiledShader.cachedShader,
-                        device->GetPersistentCache()->GetOrCreate(
-                            shaderCacheKey, [&](auto doCache) -> MaybeError {
-                                if (device->IsToggleEnabled(Toggle::UseDXC)) {
-                                    DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
-                                                    CompileShaderDXC(device, stage, hlslSource,
-                                                                     entryPointName, compileFlags));
-                                } else {
-                                    DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
-                                                    CompileShaderFXC(device, stage, hlslSource,
-                                                                     entryPointName, compileFlags));
-                                }
-                                const D3D12_SHADER_BYTECODE shader =
-                                    compiledShader.GetD3D12ShaderBytecode();
-                                doCache(shader.pShaderBytecode, shader.BytecodeLength);
-                                return {};
-                            }));
+        DAWN_TRY_ASSIGN(
+            compiledShader.cachedShader,
+            device->GetPersistentCache()->GetOrCreate(
+                shaderCacheKey, [&](auto doCache) -> MaybeError {
+                    DAWN_TRY(CompileShader(
+                        device->GetFunctions(),
+                        device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get()
+                                                                : nullptr,
+                        device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get()
+                                                                : nullptr,
+                        std::move(request), device->IsToggleEnabled(Toggle::DumpShaders),
+                        [&](WGPULoggingType loggingType, const char* message) {
+                            GetDevice()->EmitLog(loggingType, message);
+                        },
+                        &compiledShader));
+                    const D3D12_SHADER_BYTECODE shader = compiledShader.GetD3D12ShaderBytecode();
+                    doCache(shader.pShaderBytecode, shader.BytecodeLength);
+                    return {};
+                }));
 
         return std::move(compiledShader);
     }
@@ -389,69 +573,4 @@
         UNREACHABLE();
         return {};
     }
-
-    ResultOrError<PersistentCacheKey> ShaderModule::CreateHLSLKey(const char* entryPointName,
-                                                                  SingleShaderStage stage,
-                                                                  const std::string& hlslSource,
-                                                                  uint32_t compileFlags) const {
-        std::stringstream stream;
-
-        // Prefix the key with the type to avoid collisions from another type that could have the
-        // same key.
-        stream << static_cast<uint32_t>(PersistentKeyType::Shader);
-
-        // Provide "guard" strings that the user cannot provide to help ensure the generated HLSL
-        // used to create this key is not being manufactured by the user to load the wrong shader
-        // blob.
-        // These strings can be HLSL comments because Tint does not emit HLSL comments.
-        // TODO(dawn:549): Replace guards strings with something more secure.
-        constexpr char kStartGuard[] = "// Start shader autogenerated by Dawn.";
-        constexpr char kEndGuard[] = "// End shader autogenerated by Dawn.";
-        ASSERT(hlslSource.find(kStartGuard) == std::string::npos);
-        ASSERT(hlslSource.find(kEndGuard) == std::string::npos);
-
-        stream << kStartGuard << "\n";
-        stream << hlslSource;
-        stream << "\n" << kEndGuard;
-
-        stream << compileFlags;
-
-        // Add the HLSL compiler version for good measure.
-        // Prepend the compiler name to ensure the version is always unique.
-        if (GetDevice()->IsToggleEnabled(Toggle::UseDXC)) {
-            uint64_t dxCompilerVersion;
-            DAWN_TRY_ASSIGN(dxCompilerVersion, GetDXCompilerVersion());
-            stream << "DXC" << dxCompilerVersion;
-        } else {
-            stream << "FXC" << GetD3DCompilerVersion();
-        }
-
-        // If the source contains multiple entry points, ensure they are cached seperately
-        // per stage since DX shader code can only be compiled per stage using the same
-        // entry point.
-        stream << static_cast<uint32_t>(stage);
-        stream << entryPointName;
-
-        return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
-                                  std::istreambuf_iterator<char>{});
-    }
-
-    ResultOrError<uint64_t> ShaderModule::GetDXCompilerVersion() const {
-        ComPtr<IDxcValidator> dxcValidator = ToBackend(GetDevice())->GetDxcValidator();
-
-        ComPtr<IDxcVersionInfo> versionInfo;
-        DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo),
-                              "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));
-
-        uint32_t compilerMajor, compilerMinor;
-        DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor),
-                              "IDxcVersionInfo::GetVersion"));
-
-        // Pack both into a single version number.
-        return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor;
-    }
-
-    uint64_t ShaderModule::GetD3DCompilerVersion() const {
-        return D3D_COMPILER_VERSION;
-    }
 }}  // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h
index 436d7bd..880a35c 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.h
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h
@@ -58,20 +58,6 @@
         ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
         ~ShaderModule() override = default;
         MaybeError Initialize(ShaderModuleParseResult* parseResult);
-
-        ResultOrError<std::string> TranslateToHLSL(const tint::Program* program,
-                                                   const char* entryPointName,
-                                                   SingleShaderStage stage,
-                                                   PipelineLayout* layout,
-                                                   std::string* remappedEntryPointName) const;
-
-        ResultOrError<PersistentCacheKey> CreateHLSLKey(const char* entryPointName,
-                                                        SingleShaderStage stage,
-                                                        const std::string& hlslSource,
-                                                        uint32_t compileFlags) const;
-
-        ResultOrError<uint64_t> GetDXCompilerVersion() const;
-        uint64_t GetD3DCompilerVersion() const;
     };
 
 }}  // namespace dawn_native::d3d12