// Copyright 2017 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/d3d12/ShaderModuleD3D12.h"

#include "common/Assert.h"
#include "common/BitSetIterator.h"
#include "common/Log.h"
#include "common/WindowsUtils.h"
#include "dawn_native/Pipeline.h"
#include "dawn_native/TintUtils.h"
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
#include "dawn_native/d3d12/D3D12Error.h"
#include "dawn_native/d3d12/DeviceD3D12.h"
#include "dawn_native/d3d12/PipelineLayoutD3D12.h"
#include "dawn_native/d3d12/PlatformFunctions.h"
#include "dawn_native/d3d12/UtilsD3D12.h"

#include <d3dcompiler.h>

#include <tint/tint.h>
#include <locale>
#include <map>
#include <sstream>
#include <unordered_map>

namespace dawn_native::d3d12 {

    namespace {
        // Asks the DXC validator for its version and returns it packed into a single integer:
        // the major version in the upper 32 bits and the minor version in the lower 32 bits.
        // Returns an error if the validator cannot be queried for IDxcVersionInfo or if the
        // version query itself fails.
        ResultOrError<uint64_t> GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator) {
            ComPtr<IDxcVersionInfo> dxcVersionInfo;
            const HRESULT queryResult = dxcValidator.As(&dxcVersionInfo);
            DAWN_TRY(CheckHRESULT(queryResult,
                                  "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));

            uint32_t majorVersion = 0;
            uint32_t minorVersion = 0;
            const HRESULT versionResult =
                dxcVersionInfo->GetVersion(&majorVersion, &minorVersion);
            DAWN_TRY(CheckHRESULT(versionResult, "IDxcVersionInfo::GetVersion"));

            // The low 32 bits of the shifted major version are zero, so OR-ing in the 32-bit
            // minor version yields the same packed value as adding it.
            const uint64_t upperHalf = static_cast<uint64_t>(majorVersion) << 32;
            return upperHalf | static_cast<uint64_t>(minorVersion);
        }

        // Returns the D3D_COMPILER_VERSION macro widened to 64 bits. The compilation request
        // records this value as the FXC version.
        uint64_t GetD3DCompilerVersion() {
            const uint64_t fxcCompilerVersion = D3D_COMPILER_VERSION;
            return fxcCompilerVersion;
        }

        // Ordering predicate for Tint binding points: compares the group first and falls back
        // to the binding number when both groups are equal. Used as the comparator of the
        // ordered map that gives binding-point maps a stable serialization order.
        struct CompareBindingPoint {
            constexpr bool operator()(const tint::transform::BindingPoint& lhs,
                                      const tint::transform::BindingPoint& rhs) const {
                const bool sameGroup = lhs.group == rhs.group;
                return sameGroup ? lhs.binding < rhs.binding : lhs.group < rhs.group;
            }
        };

        // Appends `access` to `output` using whichever stream insertion operator applies to
        // tint::ast::Access. Member of the Serialize overload set that the map serializer
        // dispatches to for each mapped value.
        void Serialize(std::stringstream& output, const tint::ast::Access& access) {
            output << access;
        }

        // Writes a binding point as "(BindingPoint group=<group> binding=<binding>)".
        void Serialize(std::stringstream& output,
                       const tint::transform::BindingPoint& binding_point) {
            output << "(BindingPoint"
                   << " group=" << binding_point.group
                   << " binding=" << binding_point.binding
                   << ")";
        }

        // Serializes any fundamental type (integers, floating point, bool, ...) by streaming it
        // with the stream's current formatting state. The overload only participates in
        // overload resolution when T is a fundamental type.
        template <typename T, typename = std::enable_if_t<std::is_fundamental_v<T>>>
        void Serialize(std::stringstream& output, const T& val) {
            output << val;
        }

        template <typename T>
        void Serialize(std::stringstream& output,
                       const std::unordered_map<tint::transform::BindingPoint, T>& map) {
            output << "(map";

            std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(),
                                                                                   map.end());
            for (auto& [bindingPoint, value] : sorted) {
                output << " ";
                Serialize(output, bindingPoint);
                output << "=";
                Serialize(output, value);
            }
            output << ")";
        }

        // Writes the array-length-from-uniform options as
        // "(ArrayLengthFromUniformOptions ubo_binding=<binding point>
        //   bindpoint_to_size_index=<map>)" (on a single line, separated by single spaces).
        void Serialize(std::stringstream& output,
                       const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
            const auto& uboBinding = arrayLengthFromUniform.ubo_binding;
            const auto& sizeIndices = arrayLengthFromUniform.bindpoint_to_size_index;

            output << "(ArrayLengthFromUniformOptions"
                   << " ubo_binding=";
            Serialize(output, uboBinding);

            output << " bindpoint_to_size_index=";
            Serialize(output, sizeIndices);

            output << ")";
        }

        // Formats `v` as decimal text for use as the value of an HLSL define.
        //
        // `n` is the number of digits written after the decimal point. A 32 bit float needs up
        // to 9 significant decimal digits to be reproduced exactly:
        //   * For zero and for magnitudes >= 1 the value is written in fixed notation
        //     (e.g. "1.50000000"); with the default n = 8 this always has at least 9
        //     significant digits.
        //   * For non-zero magnitudes < 1, fixed notation would spend its digits on leading
        //     zeros (1.234567e-5f would become "0.00001235"), so scientific notation is used
        //     instead (e.g. "1.23456702e-05"), which keeps n + 1 significant digits.
        //
        // The text is always produced with the classic "C" locale so that the decimal separator
        // is '.' regardless of the process-wide locale.
        std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
            std::ostringstream out;
            // A stream adopts the global locale when it is constructed; override it.
            out.imbue(std::locale::classic());
            out.precision(n);

            // NaN fails every comparison and infinities fail the range test, so both keep
            // using the fixed path.
            const bool isSmallNonZero = (v != 0.0f) && (v > -1.0f) && (v < 1.0f);
            if (isSmallNonZero) {
                out << std::scientific;
            } else {
                out << std::fixed;
            }
            out << v;
            return out.str();
        }

        // Renders the value of an overridable constant as the text of an HLSL literal.
        //
        // dawnType: the type of the constant as declared by the shader.
        // entry:    when non-null, the value is read from the member of `*entry` that matches
        //           `dawnType` and `value` is ignored.
        // value:    when `entry` is null, this double is converted to `dawnType` and rendered.
        //
        // Booleans are rendered as integers, floats through FloatToStringWithPrecision, and
        // 32 bit integers through std::to_string.
        std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
                                       const OverridableConstantScalar* entry,
                                       double value = 0) {
            switch (dawnType) {
                case EntryPointMetadata::OverridableConstant::Type::Boolean:
                    // A boolean supplied as a double is true for any non-zero value. Truncating
                    // the double to an integer first would turn values such as 0.5 into false
                    // and is undefined for values outside the int32_t range.
                    return std::to_string(entry ? entry->b
                                                : static_cast<int32_t>(value != 0.0));
                case EntryPointMetadata::OverridableConstant::Type::Float32:
                    return FloatToStringWithPrecision(entry ? entry->f32
                                                            : static_cast<float>(value));
                case EntryPointMetadata::OverridableConstant::Type::Int32:
                    return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
                case EntryPointMetadata::OverridableConstant::Type::Uint32:
                    return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
                default:
                    UNREACHABLE();
            }
        }

        // Prefix of the macro names used for overridable constants; the numeric ID of the
        // constant is appended to it.
        constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";

        // Appends one (macro name, value text) pair to `defineStrings` for every overridable
        // constant of the entry point.
        //
        // Constants listed in `pipelineConstantEntries` are emitted first, using the value
        // supplied by the pipeline. Every remaining constant in `shaderEntryPointConstants` is
        // then emitted with its default value from the shader.
        void GetOverridableConstantsDefines(
            std::vector<std::pair<std::string, std::string>>* defineStrings,
            const PipelineConstantEntries* pipelineConstantEntries,
            const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
            // Builds the macro name of a constant from its numeric ID.
            const auto makeDefineName = [](const auto& constant) {
                return kSpecConstantPrefix + std::to_string(static_cast<int32_t>(constant.id));
            };

            // Names that already received a pipeline-provided value.
            std::unordered_set<std::string> handledNames;

            // Pass 1: values overridden by the pipeline.
            for (const auto& [name, value] : *pipelineConstantEntries) {
                handledNames.insert(name);

                // This is already validated so `name` must exist
                const auto& constant = shaderEntryPointConstants->at(name);
                defineStrings->emplace_back(makeDefineName(constant),
                                            GetHLSLValueString(constant.type, nullptr, value));
            }

            // Pass 2: defaults for everything the pipeline did not override.
            for (const auto& [name, constant] : *shaderEntryPointConstants) {
                if (handledNames.count(name) != 0) {
                    continue;
                }

                // Uninitialized default values are okay since they are only defined to pass
                // compilation but not used
                defineStrings->emplace_back(
                    makeDefineName(constant),
                    GetHLSLValueString(constant.type, &constant.defaultValue));
            }
        }

        // The inputs to a shader compilation. These have been intentionally isolated from the
        // device to help ensure that the pipeline cache key contains all inputs for compilation.
        //
        // Create() collects every input from the device, the pipeline layout and the entry
        // point metadata. CreateCacheKey() serializes those inputs into a PersistentCacheKey.
        struct ShaderCompilationRequest {
            // Which HLSL compiler will consume the generated source.
            enum Compiler { FXC, DXC };

            // Common inputs
            Compiler compiler;
            // Raw pointer: the request does not manage the lifetime of the program.
            const tint::Program* program;
            // Raw pointer: the request does not manage the lifetime of the string.
            const char* entryPointName;
            SingleShaderStage stage;
            uint32_t compileFlags;
            bool disableSymbolRenaming;
            // Binding points whose binding number differs from the register chosen by the
            // bind group layout (source -> destination).
            tint::transform::BindingRemapper::BindingPoints remappedBindingPoints;
            // Bindings whose access mode is overridden (see Create()).
            tint::transform::BindingRemapper::AccessControls remappedAccessControls;
            bool isRobustnessEnabled;
            bool usesNumWorkgroups;
            uint32_t numWorkgroupsRegisterSpace;
            uint32_t numWorkgroupsShaderRegister;
            tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
            // (macro name, value text) pairs for the overridable constants.
            std::vector<std::pair<std::string, std::string>> defineStrings;

            // FXC/DXC common inputs
            bool disableWorkgroupInit;

            // FXC inputs
            // Zero when the request targets DXC.
            uint64_t fxcVersion;

            // DXC inputs
            // Zero when the request targets FXC.
            uint64_t dxcVersion;
            // Raw pointer to the device info returned by Device::GetDeviceInfo().
            const D3D12DeviceInfo* deviceInfo;
            bool hasShaderFloat16Feature;

            // Builds a request for compiling `entryPointName` of `program` for `stage`.
            //
            // The compiler is chosen from the UseDXC toggle. Binding remappings and the
            // array-length-from-uniform options are derived from `layout` and
            // `entryPoint.bindings`; the defines for overridable constants come from
            // `programmableStage`. Returns an error if the DXC version cannot be queried.
            static ResultOrError<ShaderCompilationRequest> Create(
                const char* entryPointName,
                SingleShaderStage stage,
                const PipelineLayout* layout,
                uint32_t compileFlags,
                const Device* device,
                const tint::Program* program,
                const EntryPointMetadata& entryPoint,
                const ProgrammableStage& programmableStage) {
                Compiler compiler;
                uint64_t dxcVersion = 0;
                if (device->IsToggleEnabled(Toggle::UseDXC)) {
                    compiler = Compiler::DXC;
                    DAWN_TRY_ASSIGN(dxcVersion, GetDXCompilerVersion(device->GetDxcValidator()));
                } else {
                    compiler = Compiler::FXC;
                }

                using tint::transform::BindingPoint;
                using tint::transform::BindingRemapper;

                BindingRemapper::BindingPoints remappedBindingPoints;
                BindingRemapper::AccessControls remappedAccessControls;

                // The uniform buffer holding the lengths of dynamic storage buffers lives at
                // the register space / shader register reported by the pipeline layout.
                tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
                arrayLengthFromUniform.ubo_binding = {
                    layout->GetDynamicStorageBufferLengthsRegisterSpace(),
                    layout->GetDynamicStorageBufferLengthsShaderRegister()};

                const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
                for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
                    const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
                    const auto& groupBindingInfo = moduleBindingInfo[group];

                    // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
                    // the Tint AST to make the "bindings" decoration match the offset chosen by
                    // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
                    // assigned to each interface variable.
                    for (const auto& [binding, bindingInfo] : groupBindingInfo) {
                        BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
                        BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
                                                     static_cast<uint32_t>(binding)};
                        BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
                                                     bgl->GetShaderRegister(bindingIndex)};
                        // Only record bindings that actually move.
                        if (srcBindingPoint != dstBindingPoint) {
                            remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                        }

                        // Declaring a read-only storage buffer in HLSL but specifying a storage
                        // buffer in the BGL produces the wrong output. Force read-only storage
                        // buffer bindings to be treated as UAV instead of SRV. Internal storage
                        // buffer is a storage buffer used in the internal pipeline.
                        const bool forceStorageBufferAsUAV =
                            (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
                             (bgl->GetBindingInfo(bindingIndex).buffer.type ==
                                  wgpu::BufferBindingType::Storage ||
                              bgl->GetBindingInfo(bindingIndex).buffer.type ==
                                  kInternalStorageBufferBinding));
                        if (forceStorageBufferAsUAV) {
                            // The access control is keyed by the source (pre-remap) binding.
                            remappedAccessControls.emplace(srcBindingPoint,
                                                           tint::ast::Access::kReadWrite);
                        }
                    }

                    // Add arrayLengthFromUniform options
                    {
                        for (const auto& bindingAndRegisterOffset :
                             layout->GetDynamicStorageBufferLengthInfo()[group]
                                 .bindingAndRegisterOffsets) {
                            BindingNumber binding = bindingAndRegisterOffset.binding;
                            uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;

                            BindingPoint bindingPoint{static_cast<uint32_t>(group),
                                                      static_cast<uint32_t>(binding)};
                            // Get the renamed binding point if it was remapped.
                            auto it = remappedBindingPoints.find(bindingPoint);
                            if (it != remappedBindingPoints.end()) {
                                bindingPoint = it->second;
                            }

                            arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
                                                                                   registerOffset);
                        }
                    }
                }

                // Copy everything gathered above, plus the relevant device state, into the
                // request. Every member is assigned explicitly.
                ShaderCompilationRequest request;
                request.compiler = compiler;
                request.program = program;
                request.entryPointName = entryPointName;
                request.stage = stage;
                request.compileFlags = compileFlags;
                request.disableSymbolRenaming =
                    device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
                request.remappedBindingPoints = std::move(remappedBindingPoints);
                request.remappedAccessControls = std::move(remappedAccessControls);
                request.isRobustnessEnabled = device->IsRobustnessEnabled();
                request.disableWorkgroupInit =
                    device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
                request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
                request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
                request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
                request.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
                // Only the version of the compiler that is actually used is recorded; the
                // other one is zero.
                request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
                request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
                request.deviceInfo = &device->GetDeviceInfo();
                request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);

                // Note: the constants map is looked up again through the stage's module and
                // entry point name rather than taken from the `entryPoint` parameter.
                GetOverridableConstantsDefines(
                    &request.defineStrings, &programmableStage.constants,
                    &programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
                         .overridableConstants);

                return std::move(request);
            }

            // Serializes the request into a PersistentCacheKey.
            //
            // The key consists of the key type, the length of the WGSL regenerated from
            // `program`, that WGSL text, and a textual dump of the remaining inputs, each
            // terminated by a newline. Returns an internal error if Tint fails to generate
            // WGSL from the program.
            ResultOrError<PersistentCacheKey> CreateCacheKey() const {
                // Generate the WGSL from the Tint program so it's normalized.
                // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
                // compact representation.
                auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
                if (!result.success) {
                    std::ostringstream errorStream;
                    errorStream << "Tint WGSL failure:" << std::endl;
                    errorStream << "Generator: " << result.error << std::endl;
                    return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
                }

                std::stringstream stream;

                // Prefix the key with the type to avoid collisions from another type that could
                // have the same key.
                stream << static_cast<uint32_t>(PersistentKeyType::Shader);
                stream << "\n";

                // The length is written ahead of the WGSL text itself.
                stream << result.wgsl.length();
                stream << "\n";

                stream << result.wgsl;
                stream << "\n";

                stream << "(ShaderCompilationRequest";
                // `compiler` is an unscoped enum and is streamed as its integer value.
                stream << " compiler=" << compiler;
                stream << " entryPointName=" << entryPointName;
                stream << " stage=" << uint32_t(stage);
                stream << " compileFlags=" << compileFlags;
                stream << " disableSymbolRenaming=" << disableSymbolRenaming;

                stream << " remappedBindingPoints=";
                Serialize(stream, remappedBindingPoints);

                stream << " remappedAccessControls=";
                Serialize(stream, remappedAccessControls);

                // The label below is spelled "useNumWorkgroups" while the member is
                // `usesNumWorkgroups`; the label is part of the persisted key.
                stream << " useNumWorkgroups=" << usesNumWorkgroups;
                stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
                stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;

                stream << " arrayLengthFromUniform=";
                Serialize(stream, arrayLengthFromUniform);

                // Of the device info, only the shader model is written to the key.
                stream << " shaderModel=" << deviceInfo->shaderModel;
                stream << " disableWorkgroupInit=" << disableWorkgroupInit;
                stream << " isRobustnessEnabled=" << isRobustnessEnabled;
                stream << " fxcVersion=" << fxcVersion;
                stream << " dxcVersion=" << dxcVersion;
                stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;

                // Defines are written in the order in which they are stored.
                stream << " defines={";
                for (const auto& [name, value] : defineStrings) {
                    stream << " <" << name << "," << value << ">";
                }
                stream << " }";

                stream << ")";
                stream << "\n";

                // Build the key from all characters accumulated in the stream.
                return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
                                          std::istreambuf_iterator<char>{});
            }
        };

        // Translates D3DCOMPILE_* flags into the equivalent DXC command line arguments.
        //
        // The returned pointers refer to string literals. "/enable-16bit-types" is added when
        // `enable16BitTypes` is set, and "-HV 2018" is always appended last.
        std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
            std::vector<const wchar_t*> dxcArgs;

            // Appends `argument` when any bit of `flag` is present in `compileFlags`.
            const auto appendIfSet = [&dxcArgs, compileFlags](uint32_t flag,
                                                              const wchar_t* argument) {
                if ((compileFlags & flag) != 0) {
                    dxcArgs.push_back(argument);
                }
            };

            appendIfSet(D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY, L"/Gec");
            appendIfSet(D3DCOMPILE_IEEE_STRICTNESS, L"/Gis");

            // The optimization level is extracted with D3DCOMPILE_OPTIMIZATION_LEVEL2 as the
            // mask. Nothing is emitted when none of the masked bits is set.
            constexpr uint32_t kOptimizationLevelMask = D3DCOMPILE_OPTIMIZATION_LEVEL2;
            const uint32_t optimizationBits = compileFlags & kOptimizationLevelMask;
            if (optimizationBits != 0) {
                if (optimizationBits == D3DCOMPILE_OPTIMIZATION_LEVEL0) {
                    dxcArgs.push_back(L"/O0");
                } else if (optimizationBits == D3DCOMPILE_OPTIMIZATION_LEVEL2) {
                    dxcArgs.push_back(L"/O2");
                } else if (optimizationBits == D3DCOMPILE_OPTIMIZATION_LEVEL3) {
                    dxcArgs.push_back(L"/O3");
                }
            }

            appendIfSet(D3DCOMPILE_DEBUG, L"/Zi");
            appendIfSet(D3DCOMPILE_PACK_MATRIX_ROW_MAJOR, L"/Zpr");
            appendIfSet(D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR, L"/Zpc");
            appendIfSet(D3DCOMPILE_AVOID_FLOW_CONTROL, L"/Gfa");
            appendIfSet(D3DCOMPILE_PREFER_FLOW_CONTROL, L"/Gfp");
            appendIfSet(D3DCOMPILE_RESOURCES_MAY_ALIAS, L"/res_may_alias");

            if (enable16BitTypes) {
                // enable-16bit-types are only allowed in -HV 2018 (default)
                dxcArgs.push_back(L"/enable-16bit-types");
            }

            dxcArgs.push_back(L"-HV");
            dxcArgs.push_back(L"2018");

            return dxcArgs;
        }

        // Compiles `hlslSource` with DXC and returns the compiled blob.
        //
        // The entry point, compile flags, shader profile and overridable-constant defines are
        // taken from `request`. Any failing DXC call is reported through CheckHRESULT; a
        // compilation that DXC itself reports as failed becomes a validation error carrying
        // the contents of DXC's error buffer.
        ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
                                                         IDxcCompiler* dxcCompiler,
                                                         const ShaderCompilationRequest& request,
                                                         const std::string& hlslSource) {
            // Hand the HLSL text to DXC as a UTF-8 encoded blob.
            ComPtr<IDxcBlobEncoding> hlslBlob;
            const HRESULT createBlobResult = dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
                hlslSource.c_str(), hlslSource.length(), CP_UTF8, &hlslBlob);
            DAWN_TRY(CheckHRESULT(createBlobResult, "DXC create blob"));

            std::wstring wideEntryPoint;
            DAWN_TRY_ASSIGN(wideEntryPoint, ConvertStringToWstring(request.entryPointName));

            std::vector<const wchar_t*> dxcArguments =
                GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);

            // Build defines for overridable constants. The wide strings are fully built first
            // and kept alive in `wideDefines`; `dxcDefines` only stores pointers into them.
            std::vector<std::pair<std::wstring, std::wstring>> wideDefines;
            wideDefines.reserve(request.defineStrings.size());
            for (const auto& define : request.defineStrings) {
                wideDefines.emplace_back(UTF8ToWStr(define.first.c_str()),
                                         UTF8ToWStr(define.second.c_str()));
            }

            std::vector<DxcDefine> dxcDefines;
            dxcDefines.reserve(wideDefines.size());
            for (const auto& wideDefine : wideDefines) {
                dxcDefines.push_back({wideDefine.first.c_str(), wideDefine.second.c_str()});
            }

            const auto& targetProfile = request.deviceInfo->shaderProfiles[request.stage];

            ComPtr<IDxcOperationResult> operationResult;
            const HRESULT compileCallResult = dxcCompiler->Compile(
                hlslBlob.Get(), nullptr, wideEntryPoint.c_str(), targetProfile.c_str(),
                dxcArguments.data(), dxcArguments.size(), dxcDefines.data(), dxcDefines.size(),
                nullptr, &operationResult);
            DAWN_TRY(CheckHRESULT(compileCallResult, "DXC compile"));

            // The call succeeding does not mean the compilation did: query its status.
            HRESULT compileStatus;
            DAWN_TRY(CheckHRESULT(operationResult->GetStatus(&compileStatus), "DXC get status"));

            if (FAILED(compileStatus)) {
                ComPtr<IDxcBlobEncoding> errorBlob;
                DAWN_TRY(CheckHRESULT(operationResult->GetErrorBuffer(&errorBlob),
                                      "DXC get error buffer"));

                return DAWN_FORMAT_VALIDATION_ERROR(
                    "DXC compile failed with: %s",
                    static_cast<char*>(errorBlob->GetBufferPointer()));
            }

            ComPtr<IDxcBlob> compiledShader;
            DAWN_TRY(CheckHRESULT(operationResult->GetResult(&compiledShader), "DXC get result"));
            return std::move(compiledShader);
        }

        // Produces a human-readable listing of `compileFlags`: the name of every set flag from
        // the table below, one per line, followed by a final line naming the optimization
        // level.
        std::string CompileFlagsToStringFXC(uint32_t compileFlags) {
            struct NamedFlag {
                uint32_t bits;
                const char* label;
            };

            // Populated from d3dcompiler.h
#define DAWN_NAMED_FLAG(f) NamedFlag{f, #f}
            constexpr NamedFlag kNamedFlags[] = {
                DAWN_NAMED_FLAG(D3DCOMPILE_DEBUG),
                DAWN_NAMED_FLAG(D3DCOMPILE_SKIP_VALIDATION),
                DAWN_NAMED_FLAG(D3DCOMPILE_SKIP_OPTIMIZATION),
                DAWN_NAMED_FLAG(D3DCOMPILE_PACK_MATRIX_ROW_MAJOR),
                DAWN_NAMED_FLAG(D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR),
                DAWN_NAMED_FLAG(D3DCOMPILE_PARTIAL_PRECISION),
                DAWN_NAMED_FLAG(D3DCOMPILE_FORCE_VS_SOFTWARE_NO_OPT),
                DAWN_NAMED_FLAG(D3DCOMPILE_FORCE_PS_SOFTWARE_NO_OPT),
                DAWN_NAMED_FLAG(D3DCOMPILE_NO_PRESHADER),
                DAWN_NAMED_FLAG(D3DCOMPILE_AVOID_FLOW_CONTROL),
                DAWN_NAMED_FLAG(D3DCOMPILE_PREFER_FLOW_CONTROL),
                DAWN_NAMED_FLAG(D3DCOMPILE_ENABLE_STRICTNESS),
                DAWN_NAMED_FLAG(D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY),
                DAWN_NAMED_FLAG(D3DCOMPILE_IEEE_STRICTNESS),
                DAWN_NAMED_FLAG(D3DCOMPILE_RESERVED16),
                DAWN_NAMED_FLAG(D3DCOMPILE_RESERVED17),
                DAWN_NAMED_FLAG(D3DCOMPILE_WARNINGS_ARE_ERRORS),
                DAWN_NAMED_FLAG(D3DCOMPILE_RESOURCES_MAY_ALIAS),
                DAWN_NAMED_FLAG(D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES),
                DAWN_NAMED_FLAG(D3DCOMPILE_ALL_RESOURCES_BOUND),
                DAWN_NAMED_FLAG(D3DCOMPILE_DEBUG_NAME_FOR_SOURCE),
                DAWN_NAMED_FLAG(D3DCOMPILE_DEBUG_NAME_FOR_BINARY),
            };
#undef DAWN_NAMED_FLAG

            std::string listing;
            for (const NamedFlag& flag : kNamedFlags) {
                if ((compileFlags & flag.bits) == 0) {
                    continue;
                }
                listing.append(flag.label);
                listing.append("\n");
            }

            // Optimization level must be handled separately as two bits are used, and the values
            // don't map neatly to 0-3.
            constexpr uint32_t kOptimizationLevelMask = D3DCOMPILE_OPTIMIZATION_LEVEL2;
            const char* levelName = "";
            switch (compileFlags & kOptimizationLevelMask) {
                case D3DCOMPILE_OPTIMIZATION_LEVEL0:
                    levelName = "D3DCOMPILE_OPTIMIZATION_LEVEL0";
                    break;
                case D3DCOMPILE_OPTIMIZATION_LEVEL1:
                    levelName = "D3DCOMPILE_OPTIMIZATION_LEVEL1";
                    break;
                case D3DCOMPILE_OPTIMIZATION_LEVEL2:
                    levelName = "D3DCOMPILE_OPTIMIZATION_LEVEL2";
                    break;
                case D3DCOMPILE_OPTIMIZATION_LEVEL3:
                    levelName = "D3DCOMPILE_OPTIMIZATION_LEVEL3";
                    break;
            }
            listing.append(levelName);
            listing.append("\n");

            return listing;
        }

        // Compiles `hlslSource` with FXC (through `functions->d3dCompile`) and returns the
        // compiled blob.
        //
        // The target profile is the 5_1 profile matching `request.stage`. The entry point,
        // compile flags and overridable-constant defines are taken from `request`. A failed
        // compilation is returned as a validation error that carries the compiler's error
        // text when the compiler produced one.
        ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions,
                                                         const ShaderCompilationRequest& request,
                                                         const std::string& hlslSource) {
            const char* targetProfile = nullptr;
            switch (request.stage) {
                case SingleShaderStage::Vertex:
                    targetProfile = "vs_5_1";
                    break;
                case SingleShaderStage::Fragment:
                    targetProfile = "ps_5_1";
                    break;
                case SingleShaderStage::Compute:
                    targetProfile = "cs_5_1";
                    break;
            }

            ComPtr<ID3DBlob> compiledShader;
            ComPtr<ID3DBlob> errors;

            // Build defines for overridable constants. The macros only store pointers into
            // `request.defineStrings`, which outlives the compile call.
            const D3D_SHADER_MACRO* pDefines = nullptr;
            std::vector<D3D_SHADER_MACRO> fxcDefines;
            if (request.defineStrings.size() > 0) {
                fxcDefines.reserve(request.defineStrings.size() + 1);
                for (const auto& [name, value] : request.defineStrings) {
                    fxcDefines.push_back({name.c_str(), value.c_str()});
                }
                // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
                fxcDefines.push_back({nullptr, nullptr});
                pDefines = fxcDefines.data();
            }

            const HRESULT hr = functions->d3dCompile(
                hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr,
                request.entryPointName, targetProfile, request.compileFlags, 0, &compiledShader,
                &errors);
            if (FAILED(hr)) {
                // The compiler is not required to produce an error blob on failure; do not
                // dereference it blindly.
                const char* message =
                    errors.Get() != nullptr
                        ? static_cast<const char*>(errors->GetBufferPointer())
                        : "<no error message>";
                return DAWN_FORMAT_VALIDATION_ERROR("D3D compile failed with: %s", message);
            }

            return std::move(compiledShader);
        }

        // Runs the Tint transforms required by the D3D12 backend on `request.program` and
        // generates HLSL from the result.
        //
        // Transforms, in order: Robustness (only when enabled in the request), BindingRemapper,
        // SingleEntryPoint and Renamer. On success the HLSL source is returned and
        // `*remappedEntryPointName` receives the name of the entry point after renaming (the
        // original name if the Renamer left it untouched, which is only accepted when symbol
        // renaming is disabled). `request` is not modified.
        ResultOrError<std::string> TranslateToHLSL(const ShaderCompilationRequest& request,
                                                   std::string* remappedEntryPointName) {
            tint::transform::Manager transformManager;
            tint::transform::DataMap transformInputs;

            if (request.isRobustnessEnabled) {
                transformManager.Add<tint::transform::Robustness>();
            }

            transformManager.Add<tint::transform::BindingRemapper>();

            transformManager.Add<tint::transform::SingleEntryPoint>();
            transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);

            transformManager.Add<tint::transform::Renamer>();

            if (request.disableSymbolRenaming) {
                // We still need to rename HLSL reserved keywords
                transformInputs.Add<tint::transform::Renamer::Config>(
                    tint::transform::Renamer::Target::kHlslKeywords);
            }

            // D3D12 registers like `t3` and `c3` have the same bindingOffset number in
            // the remapping but should not be considered a collision because they have
            // different types.
            const bool mayCollide = true;
            // `request` is const, so the remapping tables are copied into the transform inputs.
            transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
                request.remappedBindingPoints, request.remappedAccessControls, mayCollide);

            tint::Program transformedProgram;
            tint::transform::DataMap transformOutputs;
            DAWN_TRY_ASSIGN(transformedProgram,
                            RunTransforms(&transformManager, request.program, transformInputs,
                                          &transformOutputs, nullptr));

            // Look up the name the Renamer gave to the entry point.
            if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
                auto it = data->remappings.find(request.entryPointName);
                if (it != data->remappings.end()) {
                    *remappedEntryPointName = it->second;
                } else {
                    // A missing remapping is only acceptable when full renaming is disabled.
                    DAWN_INVALID_IF(!request.disableSymbolRenaming,
                                    "Could not find remapped name for entry point.");

                    *remappedEntryPointName = request.entryPointName;
                }
            } else {
                return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
            }

            tint::writer::hlsl::Options options;
            options.disable_workgroup_init = request.disableWorkgroupInit;
            if (request.usesNumWorkgroups) {
                options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
                options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
            }
            // TODO(dawn:549): HLSL generation outputs the indices into the
            // array_length_from_uniform buffer that were actually used. When the blob cache can
            // store more than compiled shaders, we should reflect these used indices and store
            // them as well. This would allow us to only upload root constants that are actually
            // read by the shader.
            options.array_length_from_uniform = request.arrayLengthFromUniform;
            auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
            DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s",
                            result.error);

            return std::move(result.hlsl);
        }

        // Translates the program described by |request| to HLSL and compiles it with the
        // compiler selected by |request.compiler|, storing the blob in the matching member of
        // |compiledShader| (compiledDXCShader for DXC, compiledFXCShader for FXC).
        //
        // When |dumpShaders| is true the generated HLSL is handed to |DumpShadersEmitLog|, and
        // for FXC the compile flags plus the DXBC disassembly (or a failure note) are emitted as
        // well. |DumpShadersEmitLog| is called as DumpShadersEmitLog(WGPULoggingType, const char*).
        //
        // Returns an error if HLSL translation or the backend compilation fails.
        template <typename F>
        MaybeError CompileShader(const PlatformFunctions* functions,
                                 IDxcLibrary* dxcLibrary,
                                 IDxcCompiler* dxcCompiler,
                                 ShaderCompilationRequest&& request,
                                 bool dumpShaders,
                                 F&& DumpShadersEmitLog,
                                 CompiledShader* compiledShader) {
            // Step 1: produce HLSL and learn the (possibly renamed) entry point.
            std::string remappedEntryPoint;
            std::string hlslSource;
            DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(request, &remappedEntryPoint));

            if (dumpShaders) {
                std::ostringstream hlslDump;
                hlslDump << "/* Dumped generated HLSL */" << std::endl << hlslSource;
                const std::string hlslDumpText = hlslDump.str();
                DumpShadersEmitLog(WGPULoggingType_Info, hlslDumpText.c_str());
            }

            // Step 2: compile using the entry point name as it appears in the generated HLSL.
            // NOTE(review): this stores a pointer into the local |remappedEntryPoint|; if
            // entryPointName is a raw pointer (as the c_str() assignment suggests), |request|
            // must not have its entry point name read after this function returns.
            request.entryPointName = remappedEntryPoint.c_str();
            switch (request.compiler) {
                case ShaderCompilationRequest::Compiler::DXC:
                    DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader,
                                    CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
                    break;
                case ShaderCompilationRequest::Compiler::FXC:
                    DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader,
                                    CompileShaderFXC(functions, request, hlslSource));
                    break;
            }

            // Step 3: the disassembly dump is only produced for FXC output.
            const bool dumpDisassembly =
                dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC;
            if (!dumpDisassembly) {
                return {};
            }

            std::ostringstream fxcDump;
            fxcDump << "/* FXC compile flags */ " << std::endl
                    << CompileFlagsToStringFXC(request.compileFlags) << std::endl;
            fxcDump << "/* Dumped disassembled DXBC */" << std::endl;

            ComPtr<ID3DBlob> disassembly;
            const bool disassembleFailed = FAILED(functions->d3dDisassemble(
                compiledShader->compiledFXCShader->GetBufferPointer(),
                compiledShader->compiledFXCShader->GetBufferSize(), 0, nullptr, &disassembly));
            if (disassembleFailed) {
                // A failed disassembly is reported in the dump but does not fail the compile.
                fxcDump << "D3D disassemble failed" << std::endl;
            } else {
                fxcDump << reinterpret_cast<const char*>(disassembly->GetBufferPointer());
            }

            const std::string fxcDumpText = fxcDump.str();
            DumpShadersEmitLog(WGPULoggingType_Info, fxcDumpText.c_str());

            return {};
        }

    }  // anonymous namespace

    // Allocates a ShaderModule for |device| from |descriptor| and initializes it with
    // |parseResult|. Returns the new module, or the error produced by Initialize().
    // static
    ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
                                                          const ShaderModuleDescriptor* descriptor,
                                                          ShaderModuleParseResult* parseResult) {
        // AcquireRef wraps the freshly allocated object in a Ref; on an Initialize() failure
        // the early return from DAWN_TRY lets that Ref go out of scope.
        Ref<ShaderModule> shaderModule = AcquireRef(new ShaderModule(device, descriptor));
        DAWN_TRY(shaderModule->Initialize(parseResult));
        return shaderModule;
    }

    // Forwards |device| and |descriptor| to ShaderModuleBase; the constructor body itself does
    // nothing. Initialization that can fail happens in Initialize().
    ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
        : ShaderModuleBase(device, descriptor) {
    }

    // Initializes the module from |parseResult| by delegating to InitializeBase() and returns
    // its result unchanged.
    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
        // Kept alive for the whole call; presumably reports Tint internal compiler errors
        // through this module's device while in scope (confirm in TintUtils).
        ScopedTintICEHandler tintICEScope(GetDevice());
        return InitializeBase(parseResult);
    }

    // Compiles the entry point named by |programmableStage| for |stage| against |layout| with
    // the given |compileFlags|, going through the device's persistent cache.
    //
    // For vertex shaders the Tint FirstIndexOffset transform is run first, and the vertex and
    // instance index usage it reports is recorded in the returned shader's firstOffsetInfo.
    //
    // Returns the compiled shader, or an error if the transform, the request or cache key
    // creation, or the compilation itself fails.
    ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
                                                        SingleShaderStage stage,
                                                        PipelineLayout* layout,
                                                        uint32_t compileFlags) {
        ASSERT(!IsError());
        ScopedTintICEHandler scopedICEHandler(GetDevice());

        Device* d3d12Device = ToBackend(GetDevice());

        CompiledShader compiledShader = {};

        tint::transform::Manager transformManager;
        tint::transform::DataMap transformInputs;

        // |transformedProgram| only holds a value for vertex shaders; |programToCompile| points
        // either at it or at the module's own Tint program.
        tint::Program transformedProgram;
        const tint::Program* programToCompile = nullptr;
        if (stage != SingleShaderStage::Vertex) {
            programToCompile = GetTintProgram();
        } else {
            transformManager.Add<tint::transform::FirstIndexOffset>();
            transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
                layout->GetFirstIndexOffsetShaderRegister(),
                layout->GetFirstIndexOffsetRegisterSpace());

            tint::transform::DataMap transformOutputs;
            DAWN_TRY_ASSIGN(transformedProgram,
                            RunTransforms(&transformManager, GetTintProgram(), transformInputs,
                                          &transformOutputs, nullptr));

            auto* offsetData = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>();
            if (offsetData != nullptr) {
                // TODO(dawn:549): Consider adding this information to the pipeline cache once we
                // can store more than the shader blob in it.
                auto& offsetInfo = compiledShader.firstOffsetInfo;

                offsetInfo.usesVertexIndex = offsetData->has_vertex_index;
                if (offsetInfo.usesVertexIndex) {
                    offsetInfo.vertexIndexOffset = offsetData->first_vertex_offset;
                }

                offsetInfo.usesInstanceIndex = offsetData->has_instance_index;
                if (offsetInfo.usesInstanceIndex) {
                    offsetInfo.instanceIndexOffset = offsetData->first_instance_offset;
                }
            }

            programToCompile = &transformedProgram;
        }

        ShaderCompilationRequest compileRequest;
        DAWN_TRY_ASSIGN(compileRequest,
                        ShaderCompilationRequest::Create(
                            programmableStage.entryPoint.c_str(), stage, layout, compileFlags,
                            d3d12Device, programToCompile,
                            GetEntryPoint(programmableStage.entryPoint), programmableStage));

        PersistentCacheKey cacheKey;
        DAWN_TRY_ASSIGN(cacheKey, compileRequest.CreateCacheKey());

        // Sends shader dump text to the device log.
        auto emitLog = [this](WGPULoggingType loggingType, const char* message) {
            GetDevice()->EmitLog(loggingType, message);
        };

        // Compiles the shader and hands the resulting bytecode to |doCache|. Presumably the
        // persistent cache only calls this when it has no entry for |cacheKey|.
        auto compileAndCache = [&](auto doCache) -> MaybeError {
            const bool useDXC = d3d12Device->IsToggleEnabled(Toggle::UseDXC);
            const bool dumpShaders = d3d12Device->IsToggleEnabled(Toggle::DumpShaders);

            // The DXC objects are only fetched when the UseDXC toggle is on.
            DAWN_TRY(CompileShader(d3d12Device->GetFunctions(),
                                   useDXC ? d3d12Device->GetDxcLibrary().Get() : nullptr,
                                   useDXC ? d3d12Device->GetDxcCompiler().Get() : nullptr,
                                   std::move(compileRequest), dumpShaders, emitLog,
                                   &compiledShader));

            const D3D12_SHADER_BYTECODE bytecode = compiledShader.GetD3D12ShaderBytecode();
            doCache(bytecode.pShaderBytecode, bytecode.BytecodeLength);
            return {};
        };

        DAWN_TRY_ASSIGN(compiledShader.cachedShader,
                        d3d12Device->GetPersistentCache()->GetOrCreate(cacheKey, compileAndCache));

        return std::move(compiledShader);
    }

    // Returns a D3D12_SHADER_BYTECODE view of whichever blob this object holds, checked in
    // order: the cached shader buffer, then the FXC blob, then the DXC blob. The returned
    // pointer refers to storage held by those members. Having none of them set is treated as
    // unreachable.
    D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
        if (cachedShader.buffer != nullptr) {
            return {cachedShader.buffer.get(), cachedShader.bufferSize};
        }
        if (compiledFXCShader != nullptr) {
            return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
        }
        if (compiledDXCShader != nullptr) {
            return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
        }

        UNREACHABLE();
        return {};
    }
}  // namespace dawn_native::d3d12
