// Copyright 2019 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "dawn/dawn_proc.h"
#include "dawn/native/Instance.h"
#include "dawn/native/null/DeviceNull.h"
#include "dawn/utils/TerribleCommandBuffer.h"
#include "dawn/wire/WireClient.h"

namespace {

// GetProcAddress is implemented twice: dawn_native and dawn_wire each carry their own
// (duplicated) handling of it. The parameterized tests therefore run against both, and
// DawnFlavor is the test parameter that selects which implementation is exercised.
enum class DawnFlavor {
    Native = 0,  // dawn::native procs with a device from the null backend.
    Wire = 1,    // dawn::wire::client procs with a device reserved on a WireClient.
};

// Prints a DawnFlavor as the name of the library under test ("dawn_native" / "dawn_wire").
// The instantiation below uses testing::PrintToStringParamName(), which relies on this
// printer to build the parameterized test names.
std::ostream& operator<<(std::ostream& stream, DawnFlavor flavor) {
    switch (flavor) {
        case DawnFlavor::Native:
            return stream << "dawn_native";
        case DawnFlavor::Wire:
            return stream << "dawn_wire";
        default:
            UNREACHABLE();
            return stream;
    }
}

// Parameterized fixture that creates a device through either dawn_native (null backend) or
// dawn_wire (client only), and installs the matching proc table globally so that the C++
// wgpu:: wrappers dispatch to the implementation under test.
class GetProcAddressTests : public testing::TestWithParam<DawnFlavor> {
  public:
    GetProcAddressTests()
        : testing::TestWithParam<DawnFlavor>(),
          mNativeInstance(dawn::native::InstanceBase::Create()),
          mNativeAdapter(mNativeInstance.Get()) {}

    void SetUp() override {
        switch (GetParam()) {
            case DawnFlavor::Native: {
                mDevice = wgpu::Device::Acquire(
                    reinterpret_cast<WGPUDevice>(mNativeAdapter.APICreateDevice()));
                mProcs = dawn::native::GetProcs();
                break;
            }

            case DawnFlavor::Wire: {
                // Only a client is created: commands are serialized into mC2sBuf and never
                // flushed, which is enough to obtain a device handle for the lookups.
                mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();

                dawn::wire::WireClientDescriptor clientDesc = {};
                clientDesc.serializer = mC2sBuf.get();
                mWireClient = std::make_unique<dawn::wire::WireClient>(clientDesc);

                mDevice = wgpu::Device::Acquire(mWireClient->ReserveDevice().device);
                mProcs = dawn::wire::client::GetProcs();
                break;
            }

            default:
                UNREACHABLE();
                break;
        }

        // Make the wgpu:: wrappers (including mDevice's release) use this implementation.
        dawnProcSetProcs(&mProcs);
    }

    void TearDown() override {
        // Destroy the device before freeing the instance or the wire client in the destructor.
        // This must happen while the proc table installed in SetUp is still active, because
        // the release is dispatched through it.
        mDevice = wgpu::Device();

        // The proc table is process-global: don't leave it bound to this fixture's native
        // instance or wire client, which are destroyed right after this test.
        dawnProcSetProcs(nullptr);
    }

  protected:
    Ref<dawn::native::InstanceBase> mNativeInstance;
    dawn::native::null::Adapter mNativeAdapter;

    std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
    std::unique_ptr<dawn::wire::WireClient> mWireClient;

    wgpu::Device mDevice;
    // Value-initialized so the table is never indeterminate before SetUp fills it in.
    DawnProcTable mProcs = {};
};

// Test that GetProcAddress resolves a few valid names to the matching proc table entries,
// both without a device and with one.
TEST_P(GetProcAddressTests, ValidExamples) {
    const WGPUProc expectedCreateBuffer = reinterpret_cast<WGPUProc>(mProcs.deviceCreateBuffer);
    const WGPUProc expectedQueueSubmit = reinterpret_cast<WGPUProc>(mProcs.queueSubmit);
    const WGPUDevice device = mDevice.Get();

    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuDeviceCreateBuffer"), expectedCreateBuffer);
    ASSERT_EQ(mProcs.getProcAddress(device, "wgpuDeviceCreateBuffer"), expectedCreateBuffer);
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuQueueSubmit"), expectedQueueSubmit);
    ASSERT_EQ(mProcs.getProcAddress(device, "wgpuQueueSubmit"), expectedQueueSubmit);
}

// Test that a null procName yields nullptr, both without a device and with one.
TEST_P(GetProcAddressTests, Nullptr) {
    const WGPUDevice devices[] = {nullptr, mDevice.Get()};
    for (WGPUDevice device : devices) {
        ASSERT_EQ(mProcs.getProcAddress(device, nullptr), nullptr);
    }
}

// Test that GetProcAddress returns nullptr for names that are not in the proc map, both
// without a device and with one.
TEST_P(GetProcAddressTests, InvalidExamples) {
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuDeviceDoSomething"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "wgpuDeviceDoSomething"), nullptr);

    // Proc names all carry the "wgpu" prefix. Lowercase 'z' sorts after 'w', so "zzzzzzz"
    // triggers the condition where lower_bound returns the end of the procMap.
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "zzzzzzz"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "zzzzzzz"), nullptr);

    // Uppercase 'Z' sorts before 'w', so "ZZ" instead makes lower_bound return the beginning
    // of the procMap (as do "" and "0" below).
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "ZZ"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "ZZ"), nullptr);

    // Some more potential corner cases.
    ASSERT_EQ(mProcs.getProcAddress(nullptr, ""), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), ""), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "0"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "0"), nullptr);

    // A strict prefix and a strict extension of a valid name. lower_bound lands on or just
    // past "wgpuDeviceCreateBuffer", so these exercise the exact-match check that must follow
    // the search.
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuDeviceCreateBuffe"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "wgpuDeviceCreateBuffe"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuDeviceCreateBufferX"), nullptr);
    ASSERT_EQ(mProcs.getProcAddress(mDevice.Get(), "wgpuDeviceCreateBufferX"), nullptr);
}

// Test that GetProcAddress also resolves the freestanding functions (wgpuGetProcAddress and
// wgpuCreateInstance), which are handled specially, both without a device and with one.
TEST_P(GetProcAddressTests, FreeStandingFunctions) {
    const WGPUProc expectedGetProcAddress = reinterpret_cast<WGPUProc>(mProcs.getProcAddress);
    const WGPUProc expectedCreateInstance = reinterpret_cast<WGPUProc>(mProcs.createInstance);
    const WGPUDevice device = mDevice.Get();

    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuGetProcAddress"), expectedGetProcAddress);
    ASSERT_EQ(mProcs.getProcAddress(device, "wgpuGetProcAddress"), expectedGetProcAddress);

    ASSERT_EQ(mProcs.getProcAddress(nullptr, "wgpuCreateInstance"), expectedCreateInstance);
    ASSERT_EQ(mProcs.getProcAddress(device, "wgpuCreateInstance"), expectedCreateInstance);
}

// Run every GetProcAddressTests case once per flavor. The empty first argument means no
// instantiation prefix is added. testing::PrintToStringParamName() names each instantiation
// from the parameter's printed form (the operator<< above), i.e. "dawn_native" or "dawn_wire".
INSTANTIATE_TEST_SUITE_P(,
                         GetProcAddressTests,
                         testing::Values(DawnFlavor::Native, DawnFlavor::Wire),
                         testing::PrintToStringParamName());

// Check that dawn_native's proc map names are in strictly increasing order, as the
// lower_bound lookup in GetProcAddress requires.
TEST(GetProcAddressInternalTests, CheckDawnNativeProcMapOrder) {
    const std::vector<const char*> names = dawn::native::GetProcMapNamesForTesting();
    const char* previous = nullptr;
    for (const char* current : names) {
        if (previous != nullptr) {
            ASSERT_LT(std::string(previous), std::string(current));
        }
        previous = current;
    }
}

// Check that the dawn_wire client's proc map names are in strictly increasing order, as the
// lower_bound lookup in GetProcAddress requires.
TEST(GetProcAddressInternalTests, CheckDawnWireClientProcMapOrder) {
    const std::vector<const char*> names = dawn::wire::client::GetProcMapNamesForTesting();
    const char* previous = nullptr;
    for (const char* current : names) {
        if (previous != nullptr) {
            ASSERT_LT(std::string(previous), std::string(current));
        }
        previous = current;
    }
}
}  // anonymous namespace
