// Copyright 2017 The Dawn & Tint Authors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
//    list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
//    contributors may be used to endorse or promote products derived from
//    this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifndef SRC_DAWN_NATIVE_SHADERMODULE_H_
#define SRC_DAWN_NATIVE_SHADERMODULE_H_

#include <bitset>
#include <cstdint>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "dawn/common/Constants.h"
#include "dawn/common/ContentLessObjectCacheable.h"
#include "dawn/common/MutexProtected.h"
#include "dawn/common/RefCountedWithExternalCount.h"
#include "dawn/common/Sha3.h"
#include "dawn/common/ityp_array.h"
#include "dawn/native/BindingInfo.h"
#include "dawn/native/CachedObject.h"
#include "dawn/native/CompilationMessages.h"
#include "dawn/native/Error.h"
#include "dawn/native/ErrorData.h"
#include "dawn/native/Format.h"
#include "dawn/native/Forward.h"
#include "dawn/native/IntegerTypes.h"
#include "dawn/native/Limits.h"
#include "dawn/native/LogEmitter.h"
#include "dawn/native/ObjectBase.h"
#include "dawn/native/PerStage.h"
#include "dawn/native/Serializable.h"
#include "dawn/native/dawn_platform.h"
#include "tint/tint.h"

// Forward declaration of tint::Program. Note: TintProgram below stores a tint::Program by value,
// which requires the complete type, so the definition from "tint/tint.h" must already be visible
// at that point; this declaration is therefore redundant but harmless.
namespace tint {

class Program;

}  // namespace tint

namespace dawn::native {

// Forward declarations. EntryPointMetadata is defined later in this header (it is needed earlier
// by EntryPointMetadataTable); ShaderModuleParseRequest is defined elsewhere.
struct EntryPointMetadata;
class ShaderModuleParseRequest;

// Base component type of an inter-stage variable.
// Enumerator values are written out explicitly so that any reordering shows up in review: these
// enums are stored in serializable reflection structures (InterStageVariableInfo and
// EntryPointMetadata::pixelLocalMembers).
enum class InterStageComponentType {
    I32 = 0,
    U32 = 1,
    F32 = 2,
    F16 = 3,
};

// Interpolation type of an inter-stage variable.
enum class InterpolationType {
    Perspective = 0,
    Linear = 1,
    Flat = 2,
};

// Interpolation sampling mode of an inter-stage variable.
enum class InterpolationSampling {
    None = 0,
    Center = 1,
    Centroid = 2,
    Sample = 3,
    First = 4,
    Either = 5,
};

// Scalar type of one member of a `pixel_local` variable.
enum class PixelLocalMemberType {
    I32 = 0,
    U32 = 1,
    F32 = 2,
};

// Constant values keyed by name (presumably the pipeline-stage override constants). A std::map is
// used rather than a hash map so that the keys are iterated in sorted order, which keeps the
// shader cache keys created from it deterministic.
using PipelineConstantEntries = std::map<std::string, double>;

// A map from entry point name to its EntryPointMetadata. The metadata is heap-allocated, so a
// reference to a stored EntryPointMetadata is not invalidated when the map itself rehashes.
using EntryPointMetadataTable =
    absl::flat_hash_map<std::string, std::unique_ptr<EntryPointMetadata>>;

// Ref-counted holder for a parsed tint::Program together with the tint::Source::File it was
// produced from.
struct TintProgram : public RefCounted {
    TintProgram(tint::Program program, std::unique_ptr<tint::Source::File> file)
        : file(std::move(file)), program(std::move(program)) {}

    // `file` is declared before `program` on purpose: members are destroyed in reverse
    // declaration order, so the source file is destroyed after the Program. This way the file
    // stays alive for the Program's whole lifetime, including while the Program is destroyed.
    const std::unique_ptr<tint::Source::File> file;  // Keep the tint::Source::File alive
    const tint::Program program;
};

// Serializable form of a validation error (a message plus context strings). It allows validation
// errors produced while parsing a shader module to be stored in, and restored from, the blob cache.
// The member order in the X-macro below is part of the serialized layout.
#define CACHED_VALIDATION_ERROR_MEMBER(X) \
    X(std::string, message)               \
    X(std::vector<std::string>, contexts)
// clang-format off
DAWN_SERIALIZABLE(struct, CachedValidationError, CACHED_VALIDATION_ERROR_MEMBER){
    CachedValidationError() = default;
    // Presumably captures the message and contexts of `errorData`; see the implementation.
    explicit CachedValidationError(std::unique_ptr<ErrorData>&& errorData);
    // Presumably rebuilds an ErrorData from the cached message and contexts.
    std::unique_ptr<ErrorData> ToErrorData() const;
};
// clang-format on
#undef CACHED_VALIDATION_ERROR_MEMBER

// ShaderModuleParseResult is used for shader module creation and can be generated by
// ParseShaderModule or loaded from blob cache.
// tintProgram is wrapped in UnsafeUnserializedValue, so it is not part of the serialized data; this
// is why a result loaded from the blob cache holds no tintProgram.
#define SHADER_MODULE_PARSE_RESULT_MEMBER(X)                                                  \
    X(UnsafeUnserializedValue<std::optional<Ref<TintProgram>>>, tintProgram)                  \
    /* EntryPointMetadataTable might be unnecessary in cases like Tint Program recreation. */ \
    X(std::optional<EntryPointMetadataTable>, metadataTable)                                  \
    X(ParsedCompilationMessages, compilationMessages)                                         \
    /* Nullopt if no validation error occurs. */                                              \
    X(std::optional<CachedValidationError>, cachedValidationError)
DAWN_SERIALIZABLE(struct, ShaderModuleParseResult, SHADER_MODULE_PARSE_RESULT_MEMBER) {
    // Check if ShaderModuleParseResult holds a valid tintProgram. A ShaderModuleParseResult loaded
    // from blob cache holds no tintProgram.
    bool HasTintProgram() const;
    // Check if ShaderModuleParseResult holds validation error.
    bool HasError() const;
    // Converts the held validation error back into ErrorData (presumably only meaningful when
    // HasError() is true — confirm against the implementation).
    std::unique_ptr<ErrorData> ToErrorData() const;

    // Records `errorData` as this result's validation error (presumably stored as
    // cachedValidationError so that it can be cached).
    void SetValidationError(std::unique_ptr<ErrorData> && errorData);
};
#undef SHADER_MODULE_PARSE_RESULT_MEMBER

// An entry point chosen for a shader stage, as returned by ShaderModuleBase::ReifyEntryPointName.
struct ShaderModuleEntryPoint {
    // Presumably true when no entry point name was requested and the stage's default entry point
    // was used instead. Initialized so a default-constructed value is never indeterminate.
    bool defaulted = false;
    // Name of the entry point.
    std::string name;
};

// Emits the shader described by `shaderModuleDesc` through `logEmitter` (presumably a debugging
// aid; see the implementation for when it is invoked).
void DumpShaderFromDescriptor(LogEmitter* logEmitter,
                              const UnpackedPtr<ShaderModuleDescriptor>& shaderModuleDesc);

// Parse a shader module from a validated ShaderModuleDescriptor, and generate reflection
// information if required. Validation errors generated during parsing are also made cacheable and
// returned within ShaderModuleParseResult together with compilation messages, rather than as an
// error (i.e. ResultOrError::IsSuccess() is true in this case). Other types of errors still get
// returned as ErrorData in ResultOrError (i.e. ResultOrError::IsError() is true).
ResultOrError<ShaderModuleParseResult> ParseShaderModule(ShaderModuleParseRequest req);

// Checks that the bindings used by `entryPoint` are compatible with `layout`; returns an error if
// they are not.
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
                                                   const EntryPointMetadata& entryPoint,
                                                   const PipelineLayoutBase* layout);

// Return extent3D with workgroup size dimension info if it is valid.
// width = x, height = y, depthOrArrayLength = z.
// `workgroupStorageSize`, `usesSubgroupMatrix` and `maxSubgroupSize` describe the compute entry
// point. `limits` are presumably the limits the workgroup size is validated against, and
// `adapterSupportedLimits` are the limits the adapter supports (presumably used to make error
// messages more helpful) — confirm against the implementation.
ResultOrError<Extent3D> ValidateComputeStageWorkgroupSize(
    uint32_t x,
    uint32_t y,
    uint32_t z,
    size_t workgroupStorageSize,
    bool usesSubgroupMatrix,
    uint32_t maxSubgroupSize,
    const LimitsForCompilationRequest& limits,
    const LimitsForCompilationRequest& adapterSupportedLimits);

// Presumably computes the minimum buffer binding sizes that `entryPoint` requires, organized
// according to `layout` — see the implementation for the exact semantics.
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
                                                        const PipelineLayoutBase* layout);

// Shader metadata for a binding, very similar to information contained in a pipeline layout.
// Exactly one alternative is held, depending on the kind of resource that is bound.
using ShaderBindingInfoVariant = std::variant<BufferBindingInfo,
                                              SamplerBindingInfo,
                                              TextureBindingInfo,
                                              StorageTextureBindingInfo,
                                              TexelBufferBindingInfo,
                                              ExternalTextureBindingInfo,
                                              InputAttachmentBindingInfo>;
// Reflection data for a single @binding. The member order in the X-macro below is part of the
// serialized layout.
#define SHADER_BINDING_INFO_MEMBER(X)              \
    X(BindingNumber, binding)                      \
    X(BindingIndex, arraySize)                     \
    /*The variable name of the binding resource.*/ \
    X(std::string, name)                           \
    X(ShaderBindingInfoVariant, bindingInfo)
DAWN_SERIALIZABLE(struct, ShaderBindingInfo, SHADER_BINDING_INFO_MEMBER){};
#undef SHADER_BINDING_INFO_MEMBER

// Reflection data of the bindings in one bind group, keyed by binding number.
using BindingGroupInfoMap = absl::flat_hash_map<BindingNumber, ShaderBindingInfo>;
// Reflection data of every bind group, indexed by bind group index (kMaxBindGroups entries).
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;

// Shader metadata that's the equivalent for the dynamic binding arrays in the BGLs.
// `start` is the first binding number of the array and `kind` the kind of resources it holds.
#define GROUP_DYNAMIC_BINDING_ARRAY_INFO_MEMBERS(X) \
    X(BindingNumber, start)                         \
    X(wgpu::DynamicBindingKind, kind)
DAWN_SERIALIZABLE(struct, GroupDynamicBindingArrayInfo, GROUP_DYNAMIC_BINDING_ARRAY_INFO_MEMBERS){};
#undef GROUP_DYNAMIC_BINDING_ARRAY_INFO_MEMBERS

// Dynamic binding array reflection, keyed by the bind group that contains the array. Groups
// without a dynamic binding array have no entry.
using DynamicBindingArrayInfo = absl::flat_hash_map<BindGroupIndex, GroupDynamicBindingArrayInfo>;

// Types for the shader reflection data structures are defined in the detail namespace to avoid
// cluttering the dawn::native namespace. They can be exposed through aliases within
// EntryPointMetadata where needed. For every DAWN_SERIALIZABLE struct below, the member order of its
// X-macro is part of the serialized layout.
namespace detail {
// A sampler binding and a texture binding used together by an entry point.
#define SAMPLER_TEXTURE_PAIR_MEMBER(X) \
    X(BindingSlot, sampler)            \
    X(BindingSlot, texture)
DAWN_SERIALIZABLE(struct, SamplerTexturePair, SAMPLER_TEXTURE_PAIR_MEMBER){};
#undef SAMPLER_TEXTURE_PAIR_MEMBER

// Match tint::inspector::Inspector::LevelSampleInfo
enum class TextureQueryType : uint8_t { TextureNumLevels, TextureNumSamples };

// A texture metadata query of kind `type` (number of levels or number of samples) for the texture
// identified by `group` and `binding`.
#define TEXTURE_METADATA_QUERY_MEMBER(X) \
    X(TextureQueryType, type)            \
    X(uint32_t, group)                   \
    X(uint32_t, binding)
DAWN_SERIALIZABLE(struct, TextureMetadataQuery, TEXTURE_METADATA_QUERY_MEMBER) {
    using TextureQueryType = detail::TextureQueryType;
};
#undef TEXTURE_METADATA_QUERY_MEMBER

// Structure to record the basic types (float, int and uint) of the fragment shader framebuffer
// input/outputs (inputs being "framebuffer fetch").
#define FRAGMENT_RENDER_ATTACHMENT_INFO_MEMBER(X) \
    X(TextureComponentType, baseType)             \
    X(uint8_t, componentCount)                    \
    X(uint8_t, blendSrc)
DAWN_SERIALIZABLE(struct, FragmentRenderAttachmentInfo, FRAGMENT_RENDER_ATTACHMENT_INFO_MEMBER){};
#undef FRAGMENT_RENDER_ATTACHMENT_INFO_MEMBER

// Reflection data of a single inter-stage variable: its name, component type and count, and
// interpolation settings.
#define INTER_STAGE_VARIABLE_INFO_MEMBER(X) \
    X(std::string, name)                    \
    X(InterStageComponentType, baseType)    \
    X(uint32_t, componentCount)             \
    X(InterpolationType, interpolationType) \
    X(InterpolationSampling, interpolationSampling)
DAWN_SERIALIZABLE(struct, InterStageVariableInfo, INTER_STAGE_VARIABLE_INFO_MEMBER){};
#undef INTER_STAGE_VARIABLE_INFO_MEMBER

// Match tint::OverrideId
#define OVERRIDE_ID_MEMBER(X) X(uint16_t, value)
DAWN_SERIALIZABLE(struct, OverrideId, OVERRIDE_ID_MEMBER){};
#undef OVERRIDE_ID_MEMBER

// Scalar types an override variable can have.
enum class OverrideType { Boolean, Float32, Uint32, Int32, Float16 };

// Reflection data of a single override variable: its id and type, whether the shader gives it an
// initializer, and whether it is used.
#define OVERRIDE_MEMBER(X) \
    X(OverrideId, id)      \
    X(OverrideType, type)  \
    X(bool, isInitialized) \
    X(bool, isUsed)
DAWN_SERIALIZABLE(struct, Override, OVERRIDE_MEMBER) {
    using Type = OverrideType;
};
#undef OVERRIDE_MEMBER

// Override variables keyed by their identifier.
using OverridesMap = absl::flat_hash_map<std::string, Override>;
}  // namespace detail

// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). This structure is
// serializable and doesn't depend on the shader program, thus it can outlive the shader program and
// get cached on disk. They are stored in the ShaderModuleBase so pointers to EntryPointMetadata are
// safe to store as long as you also keep a Ref to the ShaderModuleBase.
// The member order in the X-macro below is part of the serialized layout.
#define ENTRY_POINT_METADATA_MEMBER(X)                                                            \
    /* It is valid for a shader to contain entry points that go over limits. To keep this      */ \
    /* structure with packed arrays and bitsets, we still validate against limits when doing   */ \
    /* reflection, but store the errors in this vector, for later use if the application tries */ \
    /* to use the entry point.                                                                 */ \
    X(std::vector<std::string>, infringedLimitErrors)                                             \
    /* bindings[G][B] is the reflection data for the binding defined with @group(G) @binding(B)*/ \
    X(BindingInfoArray, bindings)                                                                 \
    /* dynamicBindingArray[G] is the reflection data for the dynamic binding array of @group(G)*/ \
    /* if one is present in the shader module                                                  */ \
    X(DynamicBindingArrayInfo, dynamicBindingArrays)                                              \
    /* Contains the reflection information of all sampler and non-sampler texture (storage     */ \
    /* texture not included) usage in the entry point. For non-sampler usage,                  */ \
    /* nonSamplerBindingPoint is used for sampler slot.                                        */ \
    X(std::vector<detail::SamplerTexturePair>, samplerAndNonSamplerTexturePairs)                  \
    X(std::vector<detail::TextureMetadataQuery>, textureQueries)                                  \
    /* The set of vertex attributes this entryPoint uses.*/                                       \
    X(PerVertexAttribute<VertexFormatBaseType>, vertexInputBaseTypes)                             \
    X(VertexAttributeMask, usedVertexInputs)                                                      \
    /* An array to record the basic types of the fragment shader framebuffer input/outputs.*/     \
    X(PerColorAttachment<detail::FragmentRenderAttachmentInfo>, fragmentOutputVariables)          \
    X(ColorAttachmentMask, fragmentOutputMask)                                                    \
    X(PerColorAttachment<detail::FragmentRenderAttachmentInfo>, fragmentInputVariables)           \
    X(ColorAttachmentMask, fragmentInputMask)                                                     \
    /* Now that we only support vertex and fragment stages, there can't be both inter-stage    */ \
    /* inputs and outputs in one shader stage.                                                 */ \
    X(std::vector<bool>, usedInterStageVariables)                                                 \
    X(std::vector<detail::InterStageVariableInfo>, interStageVariables)                           \
    X(uint32_t, totalInterStageShaderVariables)                                                   \
    /* The shader stage for this entry point.*/                                                   \
    X(SingleShaderStage, stage)                                                                   \
    /* Map identifier to override variable. */                                                    \
    /* Identifier is unique: either the variable name or the numeric ID if specified */           \
    X(detail::OverridesMap, overrides)                                                            \
    /* Override variables that are not initialized in shaders. They need value initialization  */ \
    /* from pipeline stage or it is a validation error                                         */ \
    X(absl::flat_hash_set<std::string>, uninitializedOverrides)                                   \
    /* Store constants with shader initialized values as well.                                 */ \
    /* This is used by metal backend to set values with default initializers that are not      */ \
    /* overridden.                                                                             */ \
    X(absl::flat_hash_set<std::string>, initializedOverrides)                                     \
    /* Reflection information about potential `pixel_local` variable use. */                      \
    X(bool, usesPixelLocal)                                                                       \
    X(size_t, pixelLocalBlockSize)                                                                \
    X(std::vector<PixelLocalMemberType>, pixelLocalMembers)                                       \
    X(bool, usesFragDepth)                                                                        \
    X(bool, usesInstanceIndex)                                                                    \
    X(bool, usesNumWorkgroups)                                                                    \
    X(bool, usesSampleMaskOutput)                                                                 \
    X(bool, usesSampleIndex)                                                                      \
    X(bool, usesVertexIndex)                                                                      \
    X(bool, usesTextureLoadWithDepthTexture)                                                      \
    X(bool, usesDepthTextureWithNonComparisonSampler)                                             \
    X(bool, usesSubgroupMatrix)                                                                   \
    /* Immediate Data block byte size */                                                          \
    X(uint32_t, immediateDataRangeByteSize)
DAWN_SERIALIZABLE(struct, EntryPointMetadata, ENTRY_POINT_METADATA_MEMBER) {
    using SamplerTexturePair = detail::SamplerTexturePair;
    // Placeholder stored in SamplerTexturePair::sampler when a texture is used without a sampler.
    // Both the group and the binding are set to the maximum uint32_t value.
    // TODO(crbug.com/409438000): Remove the hack of sampler placeholders for non-sampler texture.
    static constexpr const BindingSlot nonSamplerBindingPoint{
        {BindGroupIndex{std::numeric_limits<uint32_t>::max()},
         BindingNumber{std::numeric_limits<uint32_t>::max()}}};

    // Convenience aliases exposing the reflection types from the detail namespace.
    using TextureMetadataQuery = detail::TextureMetadataQuery;
    using FragmentRenderAttachmentInfo = detail::FragmentRenderAttachmentInfo;
    using InterStageVariableInfo = detail::InterStageVariableInfo;
    using OverrideId = detail::OverrideId;
    using Override = detail::Override;
};
#undef ENTRY_POINT_METADATA_MEMBER

// The WebGPU override variables only support these scalar types.
// The union does not record which member is active; the caller must know the override's type.
union OverrideScalar {
    // Use int32_t for boolean to initialize the full 32bit
    int32_t b;
    float f32;
    int32_t i32;
    uint32_t u32;
};

// Base class for shader module objects. It keeps the original shader source (SPIR-V or WGSL) and a
// hash of it for caching, the reflection metadata of each entry point, the compilation messages,
// and a tint::Program that is (re)created on demand.
class ShaderModuleBase : public RefCountedWithExternalCount<ApiObjectBase>,
                         public CachedObject,
                         public ContentLessObjectCacheable<ShaderModuleBase> {
  public:
    using Base = RefCountedWithExternalCount<ApiObjectBase>;
    ShaderModuleBase(DeviceBase* device,
                     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                     std::vector<tint::wgsl::Extension> internalExtensions,
                     ApiObjectBase::UntrackedByDeviceTag tag);
    ShaderModuleBase(DeviceBase* device,
                     const UnpackedPtr<ShaderModuleDescriptor>& descriptor,
                     std::vector<tint::wgsl::Extension> internalExtensions);
    ~ShaderModuleBase() override;

    // Returns a shader module in the error state, carrying `label` and `compilationMessages`.
    static Ref<ShaderModuleBase> MakeError(DeviceBase* device,
                                           StringView label,
                                           ParsedCompilationMessages&& compilationMessages);

    ObjectType GetType() const override;

    // Return true iff the program has an entrypoint called `entryPoint`.
    bool HasEntryPoint(absl::string_view entryPoint) const;

    // Return the number of entry points for a stage.
    size_t GetEntryPointCount(SingleShaderStage stage) const { return mEntryPointCounts[stage]; }

    // Return the entry point for a stage. If no entry point name, returns the default one.
    ShaderModuleEntryPoint ReifyEntryPointName(StringView entryPointName,
                                               SingleShaderStage stage) const;

    // Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument
    // must be true.
    const EntryPointMetadata& GetEntryPoint(absl::string_view entryPoint) const;

    // Functions necessary for the unordered_set<ShaderModuleBase*>-based cache.
    size_t ComputeContentHash() override;

    struct EqualityFunc {
        bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
    };

    std::optional<bool> GetStrictMath() const;

    using ShaderModuleHasher = Sha3_512;
    using ShaderModuleHash = ShaderModuleHasher::Output;
    const ShaderModuleHash& GetHash() const;

    using ScopedUseTintProgram = APIRef<ShaderModuleBase>;
    ScopedUseTintProgram UseTintProgram();

    // Get tintProgram, (re)create it if necessary.
    Ref<TintProgram> GetTintProgram();

    Future APIGetCompilationInfo(const WGPUCompilationInfoCallbackInfo& callbackInfo);

    const OwnedCompilationMessages* GetCompilationMessages() const;
    std::string GetCompilationLog() const;
    void SetCompilationMessagesForTesting(
        std::unique_ptr<OwnedCompilationMessages>* compilationMessages);

    // Return nullable tintProgram directly without any recreation, can be used for testing the
    // releasing/recreation behaviors.
    Ref<TintProgram> GetNullableTintProgramForTesting() const;
    int GetTintProgramRecreateCountForTesting() const;

  protected:
    void DestroyImpl() override;

    MaybeError InitializeBase(ShaderModuleParseResult* parseResult);

  private:
    ShaderModuleBase(DeviceBase* device,
                     ObjectBase::ErrorTag tag,
                     StringView label,
                     ParsedCompilationMessages&& compilationMessages);

    void WillDropLastExternalRef() override;

    // The fields below use default member initializers so they are never indeterminate, even when
    // a constructor (for example the error-tag one) does not set them explicitly. Constructors
    // that do set them in their member initializer lists are unaffected.

    // The original data in the descriptor for caching.
    enum class Type : uint8_t { Undefined, Spirv, Wgsl };
    Type mType = Type::Undefined;
    bool mAllowSpirvNonUniformDerivitives = false;
    std::vector<uint32_t> mOriginalSpirv;
    std::string mWgsl;

    // Secure hash computed from shader code and other metadata to be used as a cache key
    // representing the shader module.
    ShaderModuleHash mHash{};

    // TODO(dawn:2503): Remove the optional once Dawn has a consistent default across backends.
    // Right now D3D uses strictness by default, and Vulkan/Metal use fast math by default.
    std::optional<bool> mStrictMath;

    EntryPointMetadataTable mEntryPoints;
    PerStage<std::string> mDefaultEntryPointNames;
    // Value-initialized so GetEntryPointCount() never reads indeterminate counts.
    PerStage<size_t> mEntryPointCounts{};

    struct TintData {
        // tintProgram is nullable so that it can be lazily (re)generated right before actual using.
        Ref<TintProgram> tintProgram = nullptr;
        int tintProgramRecreateCount = 0;
    };
    MutexProtected<TintData> mTintData;

    std::unique_ptr<const OwnedCompilationMessages> mCompilationMessages;

    const std::vector<tint::wgsl::Extension> mInternalExtensions;
};

}  // namespace dawn::native

#endif  // SRC_DAWN_NATIVE_SHADERMODULE_H_
