// Copyright 2017 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn/native/BindGroup.h"

#include "dawn/common/Assert.h"
#include "dawn/common/Math.h"
#include "dawn/common/ityp_bitset.h"
#include "dawn/native/BindGroupLayout.h"
#include "dawn/native/Buffer.h"
#include "dawn/native/ChainUtils_autogen.h"
#include "dawn/native/Device.h"
#include "dawn/native/ExternalTexture.h"
#include "dawn/native/ObjectBase.h"
#include "dawn/native/ObjectType_autogen.h"
#include "dawn/native/Sampler.h"
#include "dawn/native/Texture.h"

namespace dawn::native {

namespace {

// Helper functions to perform binding-type specific validation

// Validates a bind group entry that targets a Buffer binding of the layout.
// Checks that only `buffer` is set, that the buffer is a valid object, and that the
// (offset, size) range fits in the buffer, is correctly aligned, has the required usage and
// respects the layout's minimum and the device's maximum binding sizes.
MaybeError ValidateBufferBinding(const DeviceBase* device,
                                 const BindGroupEntry& entry,
                                 const BindingInfo& bindingInfo) {
    DAWN_INVALID_IF(entry.buffer == nullptr, "Binding entry buffer not set.");

    DAWN_INVALID_IF(entry.sampler != nullptr || entry.textureView != nullptr,
                    "Expected only buffer to be set for binding entry.");

    DAWN_INVALID_IF(entry.nextInChain != nullptr, "nextInChain must be nullptr.");

    DAWN_TRY(device->ValidateObject(entry.buffer));

    ASSERT(bindingInfo.bindingType == BindingInfoType::Buffer);

    // Usage, size limit and offset alignment all depend on the kind of buffer binding.
    wgpu::BufferUsage requiredUsage;
    uint64_t maxBindingSize;
    uint64_t requiredBindingAlignment;
    switch (bindingInfo.buffer.type) {
        case wgpu::BufferBindingType::Uniform:
            requiredUsage = wgpu::BufferUsage::Uniform;
            maxBindingSize = device->GetLimits().v1.maxUniformBufferBindingSize;
            requiredBindingAlignment = device->GetLimits().v1.minUniformBufferOffsetAlignment;
            break;
        case wgpu::BufferBindingType::Storage:
        case wgpu::BufferBindingType::ReadOnlyStorage:
            requiredUsage = wgpu::BufferUsage::Storage;
            maxBindingSize = device->GetLimits().v1.maxStorageBufferBindingSize;
            requiredBindingAlignment = device->GetLimits().v1.minStorageBufferOffsetAlignment;
            break;
        case kInternalStorageBufferBinding:
            requiredUsage = kInternalStorageBuffer;
            maxBindingSize = device->GetLimits().v1.maxStorageBufferBindingSize;
            requiredBindingAlignment = device->GetLimits().v1.minStorageBufferOffsetAlignment;
            break;
        case wgpu::BufferBindingType::Undefined:
            UNREACHABLE();
    }

    uint64_t bufferSize = entry.buffer->GetSize();

    // Handle wgpu::WholeSize, avoiding overflows.
    DAWN_INVALID_IF(entry.offset > bufferSize,
                    "Binding offset (%u) is larger than the size (%u) of %s.", entry.offset,
                    bufferSize, entry.buffer);

    uint64_t bindingSize =
        (entry.size == wgpu::kWholeSize) ? bufferSize - entry.offset : entry.size;

    DAWN_INVALID_IF(bindingSize > bufferSize,
                    "Binding size (%u) is larger than the size (%u) of %s.", bindingSize,
                    bufferSize, entry.buffer);

    DAWN_INVALID_IF(bindingSize == 0, "Binding size is zero");

    // Note that no overflow can happen because we already checked that
    // bufferSize >= bindingSize
    // The format arguments follow the message order: offset, binding size, then buffer size.
    DAWN_INVALID_IF(entry.offset > bufferSize - bindingSize,
                    "Binding range (offset: %u, size: %u) doesn't fit in the size (%u) of %s.",
                    entry.offset, bindingSize, bufferSize, entry.buffer);

    DAWN_INVALID_IF(!IsAligned(entry.offset, requiredBindingAlignment),
                    "Offset (%u) does not satisfy the minimum %s alignment (%u).", entry.offset,
                    bindingInfo.buffer.type, requiredBindingAlignment);

    DAWN_INVALID_IF(!(entry.buffer->GetUsage() & requiredUsage),
                    "Binding usage (%s) of %s doesn't match expected usage (%s).",
                    entry.buffer->GetUsageExternalOnly(), entry.buffer, requiredUsage);

    DAWN_INVALID_IF(bindingSize < bindingInfo.buffer.minBindingSize,
                    "Binding size (%u) is smaller than the minimum binding size (%u).", bindingSize,
                    bindingInfo.buffer.minBindingSize);

    DAWN_INVALID_IF(bindingSize > maxBindingSize,
                    "Binding size (%u) is larger than the maximum binding size (%u).", bindingSize,
                    maxBindingSize);

    return {};
}

// Validates a bind group entry that targets a sampled Texture or a StorageTexture binding.
// The entry must carry only a valid texture view selecting a single aspect; the view's texture
// must then satisfy the usage, sample/format and dimension requirements of the layout entry.
MaybeError ValidateTextureBinding(DeviceBase* device,
                                  const BindGroupEntry& entry,
                                  const BindingInfo& bindingInfo) {
    DAWN_INVALID_IF(entry.textureView == nullptr, "Binding entry textureView not set.");
    DAWN_INVALID_IF(entry.sampler != nullptr || entry.buffer != nullptr,
                    "Expected only textureView to be set for binding entry.");
    DAWN_INVALID_IF(entry.nextInChain != nullptr, "nextInChain must be nullptr.");
    DAWN_TRY(device->ValidateObject(entry.textureView));

    TextureViewBase* view = entry.textureView;

    // A view bound to a single binding may only expose one aspect.
    Aspect aspect = view->GetAspects();
    DAWN_INVALID_IF(!HasOneBit(aspect), "Multiple aspects (%s) selected in %s.", aspect, view);

    TextureBase* texture = view->GetTexture();

    if (bindingInfo.bindingType == BindingInfoType::Texture) {
        const SampleTypeBit requiredType =
            SampleTypeToSampleTypeBit(bindingInfo.texture.sampleType);
        const SampleTypeBit supportedTypes =
            texture->GetFormat().GetAspectInfo(aspect).supportedSampleTypes;

        DAWN_INVALID_IF(!(texture->GetUsage() & wgpu::TextureUsage::TextureBinding),
                        "Usage (%s) of %s doesn't include TextureUsage::TextureBinding.",
                        texture->GetUsage(), texture);

        DAWN_INVALID_IF(texture->IsMultisampledTexture() != bindingInfo.texture.multisampled,
                        "Sample count (%u) of %s doesn't match expectation (multisampled: %d).",
                        texture->GetSampleCount(), texture, bindingInfo.texture.multisampled);

        DAWN_INVALID_IF((supportedTypes & requiredType) == 0,
                        "None of the supported sample types (%s) of %s match the expected sample "
                        "types (%s).",
                        supportedTypes, texture, requiredType);

        DAWN_INVALID_IF(view->GetDimension() != bindingInfo.texture.viewDimension,
                        "Dimension (%s) of %s doesn't match the expected dimension (%s).",
                        view->GetDimension(), view, bindingInfo.texture.viewDimension);
    } else if (bindingInfo.bindingType == BindingInfoType::StorageTexture) {
        DAWN_INVALID_IF(!(texture->GetUsage() & wgpu::TextureUsage::StorageBinding),
                        "Usage (%s) of %s doesn't include TextureUsage::StorageBinding.",
                        texture->GetUsage(), texture);

        ASSERT(!texture->IsMultisampledTexture());

        DAWN_INVALID_IF(texture->GetFormat().format != bindingInfo.storageTexture.format,
                        "Format (%s) of %s expected to be (%s).", texture->GetFormat().format,
                        texture, bindingInfo.storageTexture.format);

        DAWN_INVALID_IF(view->GetDimension() != bindingInfo.storageTexture.viewDimension,
                        "Dimension (%s) of %s doesn't match the expected dimension (%s).",
                        view->GetDimension(), view, bindingInfo.storageTexture.viewDimension);

        // Storage texture views must address exactly one mip level.
        DAWN_INVALID_IF(view->GetLevelCount() != 1, "mipLevelCount (%u) of %s expected to be 1.",
                        view->GetLevelCount(), view);
    } else {
        UNREACHABLE();
    }

    return {};
}

// Validates a bind group entry that targets a Sampler binding of the layout.
// The entry must carry only a valid sampler, and the sampler's filtering/comparison properties
// must be compatible with the layout's sampler binding type:
//  - NonFiltering: neither filtering nor comparison samplers are allowed.
//  - Filtering:    comparison samplers are not allowed.
//  - Comparison:   only comparison samplers are allowed.
MaybeError ValidateSamplerBinding(const DeviceBase* device,
                                  const BindGroupEntry& entry,
                                  const BindingInfo& bindingInfo) {
    DAWN_INVALID_IF(entry.sampler == nullptr, "Binding entry sampler not set.");

    DAWN_INVALID_IF(entry.textureView != nullptr || entry.buffer != nullptr,
                    "Expected only sampler to be set for binding entry.");

    DAWN_INVALID_IF(entry.nextInChain != nullptr, "nextInChain must be nullptr.");

    DAWN_TRY(device->ValidateObject(entry.sampler));

    ASSERT(bindingInfo.bindingType == BindingInfoType::Sampler);

    switch (bindingInfo.sampler.type) {
        case wgpu::SamplerBindingType::NonFiltering:
            DAWN_INVALID_IF(entry.sampler->IsFiltering(),
                            "Filtering sampler %s is incompatible with non-filtering sampler "
                            "binding.",
                            entry.sampler);
            // Non-filtering bindings also reject comparison samplers.
            [[fallthrough]];
        case wgpu::SamplerBindingType::Filtering:
            DAWN_INVALID_IF(entry.sampler->IsComparison(),
                            "Comparison sampler %s is incompatible with non-comparison sampler "
                            "binding.",
                            entry.sampler);
            break;
        case wgpu::SamplerBindingType::Comparison:
            DAWN_INVALID_IF(!entry.sampler->IsComparison(),
                            "Non-comparison sampler %s is incompatible with comparison sampler "
                            "binding.",
                            entry.sampler);
            break;
        default:
            UNREACHABLE();
            break;
    }

    return {};
}

// Validates a bind group entry that carries an ExternalTextureBindingEntry in its chain.
// Only the external texture may be set, the entry's binding number must be one the layout
// expanded for an external texture, the chained struct's own chain may hold only an
// ExternalTextureBindingEntry sType, and the external texture must be a valid object.
MaybeError ValidateExternalTextureBinding(
    const DeviceBase* device,
    const BindGroupEntry& entry,
    const ExternalTextureBindingEntry* externalTextureBindingEntry,
    const ExternalTextureBindingExpansionMap& expansions) {
    DAWN_INVALID_IF(externalTextureBindingEntry == nullptr,
                    "Binding entry external texture not set.");

    const bool hasOtherResource =
        entry.sampler != nullptr || entry.textureView != nullptr || entry.buffer != nullptr;
    DAWN_INVALID_IF(hasOtherResource,
                    "Expected only external texture to be set for binding entry.");

    const bool isExpandedInLayout =
        expansions.find(BindingNumber(entry.binding)) != expansions.end();
    DAWN_INVALID_IF(!isExpandedInLayout,
                    "External texture binding entry %u is not present in the bind group layout.",
                    entry.binding);

    DAWN_TRY(ValidateSingleSType(externalTextureBindingEntry->nextInChain,
                                 wgpu::SType::ExternalTextureBindingEntry));
    DAWN_TRY(device->ValidateObject(externalTextureBindingEntry->externalTexture));

    return {};
}

}  // anonymous namespace

// Validates a BindGroupDescriptor against its layout.
// Every entry must name a binding number present in the layout, each binding may be set at
// most once, and the number of entries must equal the layout's unexpanded binding count
// (external textures count as one entry even though the layout expands them). Each entry is
// then validated according to the binding type the layout declares for it.
MaybeError ValidateBindGroupDescriptor(DeviceBase* device, const BindGroupDescriptor* descriptor) {
    DAWN_INVALID_IF(descriptor->nextInChain != nullptr, "nextInChain must be nullptr.");

    DAWN_TRY(device->ValidateObject(descriptor->layout));

    // The comparison and the reported expected count must both use the *unexpanded* count;
    // the expanded GetBindingCount() would overstate it when external textures are present.
    DAWN_INVALID_IF(
        descriptor->entryCount != descriptor->layout->GetUnexpandedBindingCount(),
        "Number of entries (%u) did not match the number of entries (%u) specified in %s."
        "\nExpected layout: %s",
        descriptor->entryCount,
        static_cast<uint32_t>(descriptor->layout->GetUnexpandedBindingCount()), descriptor->layout,
        descriptor->layout->EntriesToString());

    const BindGroupLayoutBase::BindingMap& bindingMap = descriptor->layout->GetBindingMap();
    ASSERT(bindingMap.size() <= kMaxBindingsPerPipelineLayout);

    ityp::bitset<BindingIndex, kMaxBindingsPerPipelineLayout> bindingsSet;
    for (uint32_t i = 0; i < descriptor->entryCount; ++i) {
        const BindGroupEntry& entry = descriptor->entries[i];

        auto it = bindingMap.find(BindingNumber(entry.binding));
        DAWN_INVALID_IF(it == bindingMap.end(),
                        "In entries[%u], binding index %u not present in the bind group layout."
                        "\nExpected layout: %s",
                        i, entry.binding, descriptor->layout->EntriesToString());

        BindingIndex bindingIndex = it->second;
        ASSERT(bindingIndex < descriptor->layout->GetBindingCount());

        DAWN_INVALID_IF(bindingsSet[bindingIndex],
                        "In entries[%u], binding index %u already used by a previous entry", i,
                        entry.binding);

        bindingsSet.set(bindingIndex);

        // Below this block we validate entries based on the bind group layout, in which
        // external textures have been expanded into their underlying contents. For this reason
        // we must identify external texture binding entries by checking the bind group entry
        // itself.
        // TODO(dawn:1293): Store external textures in
        // BindGroupLayoutBase::BindingDataPointers::bindings so checking external textures can
        // be moved in the switch below.
        const ExternalTextureBindingEntry* externalTextureBindingEntry = nullptr;
        FindInChain(entry.nextInChain, &externalTextureBindingEntry);
        if (externalTextureBindingEntry != nullptr) {
            DAWN_TRY(ValidateExternalTextureBinding(
                device, entry, externalTextureBindingEntry,
                descriptor->layout->GetExternalTextureBindingExpansionMap()));
            continue;
        }

        const BindingInfo& bindingInfo = descriptor->layout->GetBindingInfo(bindingIndex);

        // Perform binding-type specific validation.
        switch (bindingInfo.bindingType) {
            case BindingInfoType::Buffer:
                DAWN_TRY_CONTEXT(ValidateBufferBinding(device, entry, bindingInfo),
                                 "validating entries[%u] as a Buffer."
                                 "\nExpected entry layout: %s",
                                 i, bindingInfo);
                break;
            case BindingInfoType::Texture:
            case BindingInfoType::StorageTexture:
                DAWN_TRY_CONTEXT(ValidateTextureBinding(device, entry, bindingInfo),
                                 "validating entries[%u] as a Texture."
                                 "\nExpected entry layout: %s",
                                 i, bindingInfo);
                break;
            case BindingInfoType::Sampler:
                DAWN_TRY_CONTEXT(ValidateSamplerBinding(device, entry, bindingInfo),
                                 "validating entries[%u] as a Sampler."
                                 "\nExpected entry layout: %s",
                                 i, bindingInfo);
                break;
            case BindingInfoType::ExternalTexture:
                // External textures were handled (and `continue`d) above.
                UNREACHABLE();
                break;
        }
    }

    // This should always be true because
    //  - numBindings has to match between the bind group and its layout.
    //  - Each binding must be set at most once
    //
    // We don't validate the equality because it wouldn't be possible to cover it with a test.
    ASSERT(bindingsSet.count() == descriptor->layout->GetUnexpandedBindingCount());

    return {};
}

// BindGroup

// Builds a bind group from an already-validated descriptor.
// Binding storage starts at `bindingDataStart` and is laid out by the layout's
// ComputeBindingDataPointers(). Every binding slot is first filled with a null Ref. Each
// descriptor entry is then stored at its layout binding index. External texture entries are
// expanded into their two plane views plus their params buffer. Finally, the sizes of buffer
// bindings whose layout minBindingSize is 0 are recorded for later verification.
BindGroupBase::BindGroupBase(DeviceBase* device,
                             const BindGroupDescriptor* descriptor,
                             void* bindingDataStart)
    : ApiObjectBase(device, descriptor->label),
      mLayout(descriptor->layout),
      mBindingData(mLayout->ComputeBindingDataPointers(bindingDataStart)) {
    for (BindingIndex i{0}; i < mLayout->GetBindingCount(); ++i) {
        // TODO(enga): Shouldn't be needed when bindings are tightly packed.
        // This is to fill Ref<ObjectBase> holes with nullptrs.
        new (&mBindingData.bindings[i]) Ref<ObjectBase>();
    }

    for (uint32_t i = 0; i < descriptor->entryCount; ++i) {
        const BindGroupEntry& entry = descriptor->entries[i];

        BindingIndex bindingIndex =
            descriptor->layout->GetBindingIndex(BindingNumber(entry.binding));
        ASSERT(bindingIndex < mLayout->GetBindingCount());

        // Only a single binding type should be set, so once we found it we can skip to the
        // next loop iteration.

        if (entry.buffer != nullptr) {
            ASSERT(mBindingData.bindings[bindingIndex] == nullptr);
            mBindingData.bindings[bindingIndex] = entry.buffer;
            mBindingData.bufferData[bindingIndex].offset = entry.offset;
            uint64_t bufferSize = (entry.size == wgpu::kWholeSize)
                                      ? entry.buffer->GetSize() - entry.offset
                                      : entry.size;
            mBindingData.bufferData[bindingIndex].size = bufferSize;
            continue;
        }

        if (entry.textureView != nullptr) {
            ASSERT(mBindingData.bindings[bindingIndex] == nullptr);
            mBindingData.bindings[bindingIndex] = entry.textureView;
            continue;
        }

        if (entry.sampler != nullptr) {
            ASSERT(mBindingData.bindings[bindingIndex] == nullptr);
            mBindingData.bindings[bindingIndex] = entry.sampler;
            continue;
        }

        // Here we unpack external texture bindings into multiple additional bindings for the
        // external texture's contents. New binding locations previously determined in the bind
        // group layout are created in this bind group and filled with the external texture's
        // underlying resources.
        const ExternalTextureBindingEntry* externalTextureBindingEntry = nullptr;
        FindInChain(entry.nextInChain, &externalTextureBindingEntry);
        if (externalTextureBindingEntry != nullptr) {
            ExternalTextureBase* externalTexture = externalTextureBindingEntry->externalTexture;
            mBoundExternalTextures.push_back(externalTexture);

            // Bind the expansion map by const reference instead of copying the whole map for
            // every external texture entry.
            const ExternalTextureBindingExpansionMap& expansions =
                mLayout->GetExternalTextureBindingExpansionMap();
            auto it = expansions.find(BindingNumber(entry.binding));

            ASSERT(it != expansions.end());

            BindingIndex plane0BindingIndex =
                descriptor->layout->GetBindingIndex(it->second.plane0);
            BindingIndex plane1BindingIndex =
                descriptor->layout->GetBindingIndex(it->second.plane1);
            BindingIndex paramsBindingIndex =
                descriptor->layout->GetBindingIndex(it->second.params);

            ASSERT(mBindingData.bindings[plane0BindingIndex] == nullptr);
            mBindingData.bindings[plane0BindingIndex] = externalTexture->GetTextureViews()[0];

            ASSERT(mBindingData.bindings[plane1BindingIndex] == nullptr);
            mBindingData.bindings[plane1BindingIndex] = externalTexture->GetTextureViews()[1];

            ASSERT(mBindingData.bindings[paramsBindingIndex] == nullptr);
            mBindingData.bindings[paramsBindingIndex] = externalTexture->GetParamsBuffer();
            mBindingData.bufferData[paramsBindingIndex].offset = 0;
            mBindingData.bufferData[paramsBindingIndex].size =
                sizeof(dawn_native::ExternalTextureParams);

            continue;
        }
    }

    // Record the bound sizes of buffers whose layout did not declare a minimum binding size;
    // they are packed in buffer-binding order.
    uint32_t packedIdx = 0;
    for (BindingIndex bindingIndex{0}; bindingIndex < descriptor->layout->GetBufferCount();
         ++bindingIndex) {
        if (descriptor->layout->GetBindingInfo(bindingIndex).buffer.minBindingSize == 0) {
            mBindingData.unverifiedBufferSizes[packedIdx] =
                mBindingData.bufferData[bindingIndex].size;
            ++packedIdx;
        }
    }

    TrackInDevice();
}

// Constructs a bind group with no descriptor: mLayout is left empty and no binding storage is
// set up. The object is still registered with the device through TrackInDevice().
BindGroupBase::BindGroupBase(DeviceBase* device)
    : ApiObjectBase(device, kLabelNotImplemented) {
    TrackInDevice();
}

// The per-binding Refs are destroyed explicitly in DestroyImpl(); nothing extra is done here.
BindGroupBase::~BindGroupBase() = default;

void BindGroupBase::DestroyImpl() {
    if (mLayout != nullptr) {
        ASSERT(!IsError());
        for (BindingIndex i{0}; i < mLayout->GetBindingCount(); ++i) {
            mBindingData.bindings[i].~Ref<ObjectBase>();
        }
    }
}

void BindGroupBase::DeleteThis() {
    // Add another ref to the layout so that if this is the last ref, the layout
    // is destroyed after the bind group. The bind group is slab-allocated inside
    // memory owned by the layout (except for the null backend).
    Ref<BindGroupLayoutBase> layout = mLayout;
    ApiObjectBase::DeleteThis();
}

// Error-object constructor: no layout is set and the binding data pointers are
// value-initialized.
BindGroupBase::BindGroupBase(DeviceBase* device, ObjectBase::ErrorTag tag)
    : ApiObjectBase(device, tag), mBindingData() {}

// static
// Creates a heap-allocated bind group in the error state. NOTE(review): the caller presumably
// takes ownership of the returned pointer — confirm against call sites.
BindGroupBase* BindGroupBase::MakeError(DeviceBase* device) {
    return new BindGroupBase(device, ObjectBase::kError);
}

// Identifies this object as a bind group.
ObjectType BindGroupBase::GetType() const {
    return ObjectType::BindGroup;
}

// Returns the layout this bind group was created with (not owned by the caller).
// Must not be called on an error bind group.
BindGroupLayoutBase* BindGroupBase::GetLayout() {
    ASSERT(!IsError());
    return mLayout.Get();
}

// Const overload of GetLayout(); same preconditions.
const BindGroupLayoutBase* BindGroupBase::GetLayout() const {
    ASSERT(!IsError());
    return mLayout.Get();
}

// Returns the bound sizes of buffer bindings whose layout minBindingSize is 0, packed in
// buffer-binding order (filled by the descriptor constructor). Not valid on error bind groups.
const ityp::span<uint32_t, uint64_t>& BindGroupBase::GetUnverifiedBufferSizes() const {
    ASSERT(!IsError());
    return mBindingData.unverifiedBufferSizes;
}

// Returns the buffer, offset and size stored at `bindingIndex`, which must be a Buffer binding
// of the layout. Not valid on error bind groups.
BufferBinding BindGroupBase::GetBindingAsBufferBinding(BindingIndex bindingIndex) {
    ASSERT(!IsError());
    ASSERT(bindingIndex < mLayout->GetBindingCount());
    ASSERT(mLayout->GetBindingInfo(bindingIndex).bindingType == BindingInfoType::Buffer);

    const auto& range = mBindingData.bufferData[bindingIndex];
    auto* boundBuffer = static_cast<BufferBase*>(mBindingData.bindings[bindingIndex].Get());
    return {boundBuffer, range.offset, range.size};
}

// Returns the sampler stored at `bindingIndex`, which must be a Sampler binding of the
// layout. Not valid on error bind groups.
SamplerBase* BindGroupBase::GetBindingAsSampler(BindingIndex bindingIndex) const {
    ASSERT(!IsError());
    ASSERT(bindingIndex < mLayout->GetBindingCount());
    ASSERT(mLayout->GetBindingInfo(bindingIndex).bindingType == BindingInfoType::Sampler);

    ObjectBase* boundObject = mBindingData.bindings[bindingIndex].Get();
    return static_cast<SamplerBase*>(boundObject);
}

// Returns the texture view stored at `bindingIndex`, which must be a Texture or StorageTexture
// binding of the layout. Not valid on error bind groups.
TextureViewBase* BindGroupBase::GetBindingAsTextureView(BindingIndex bindingIndex) {
    ASSERT(!IsError());
    ASSERT(bindingIndex < mLayout->GetBindingCount());
    ASSERT(mLayout->GetBindingInfo(bindingIndex).bindingType == BindingInfoType::Texture ||
           mLayout->GetBindingInfo(bindingIndex).bindingType == BindingInfoType::StorageTexture);

    ObjectBase* boundObject = mBindingData.bindings[bindingIndex].Get();
    return static_cast<TextureViewBase*>(boundObject);
}

// Returns the external textures referenced by this bind group's entries, in descriptor entry
// order (appended by the descriptor constructor).
const std::vector<Ref<ExternalTextureBase>>& BindGroupBase::GetBoundExternalTextures() const {
    return mBoundExternalTextures;
}

}  // namespace dawn::native
