// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <functional>
#include <utility>
#include <vector>

#include "dawn/native/CommandBuffer.h"
#include "dawn/native/Commands.h"
#include "dawn/native/ComputePassEncoder.h"
#include "dawn/tests/DawnNativeTest.h"
#include "dawn/utils/WGPUHelpers.h"

namespace dawn::native {

    class CommandBufferEncodingTests : public DawnNativeTest {
      protected:
        void ExpectCommands(
            dawn::native::CommandIterator* commands,
            std::vector<std::pair<dawn::native::Command,
                                  std::function<void(dawn::native::CommandIterator*)>>>
                expectedCommands) {
            dawn::native::Command commandId;
            for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
                ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
                ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
                    << "at command " << commandIndex;
                expectedCommands[commandIndex].second(commands);
            }
        }
    };

    // Indirect dispatch validation changes the bind groups in the middle
    // of a pass. Test that bindings are restored after the validation runs,
    // and that the recorded dispatch reads from the validated scratch buffer.
    TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
        wgpu::BindGroupLayout staticLayout =
            utils::MakeBindGroupLayout(device, {{
                                                   0,
                                                   wgpu::ShaderStage::Compute,
                                                   wgpu::BufferBindingType::Uniform,
                                               }});

        wgpu::BindGroupLayout dynamicLayout =
            utils::MakeBindGroupLayout(device, {{
                                                   0,
                                                   wgpu::ShaderStage::Compute,
                                                   wgpu::BufferBindingType::Uniform,
                                                   true,
                                               }});

        // Create a simple pipeline
        wgpu::ComputePipelineDescriptor csDesc;
        csDesc.compute.module = utils::CreateShaderModule(device, R"(
        @stage(compute) @workgroup_size(1, 1, 1)
        fn main() {
        })");
        csDesc.compute.entryPoint = "main";

        // Two pipelines whose layouts place the static and dynamic groups in opposite
        // slots, so the restored bind-group state differs between them.
        wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
        csDesc.layout = pl0;
        wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);

        wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
        csDesc.layout = pl1;
        wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);

        // Create buffers to use for both the indirect buffer and the bind groups.
        wgpu::Buffer indirectBuffer = utils::CreateBufferFromData<uint32_t>(
            device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});

        wgpu::BufferDescriptor uniformBufferDesc = {};
        uniformBufferDesc.size = 512;
        uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
        wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);

        wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});

        wgpu::BindGroup dynamicBG =
            utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});

        uint32_t dynamicOffset = 256;
        std::vector<uint32_t> emptyDynamicOffsets = {};
        std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};

        // Begin encoding commands.
        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();

        CommandBufferStateTracker* stateTracker =
            FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();

        // Perform a dispatch indirect which will be preceded by a validation dispatch.
        pass.SetPipeline(pipeline0);
        pass.SetBindGroup(0, staticBG);
        pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
        EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());

        pass.DispatchWorkgroupsIndirect(indirectBuffer, 0);

        // Expect restored state.
        EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);

        // Dispatch again to check that the restored state can be used.
        // Also pass an indirect offset which should get replaced with the offset
        // into the scratch indirect buffer (0).
        pass.DispatchWorkgroupsIndirect(indirectBuffer, 4);

        // Expect restored state.
        EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);

        // Change the pipeline
        pass.SetPipeline(pipeline1);
        pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
        pass.SetBindGroup(1, staticBG);
        EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());

        pass.DispatchWorkgroupsIndirect(indirectBuffer, 0);

        // Expect restored state.
        EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
        EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
        EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);

        pass.End();

        wgpu::CommandBuffer commandBuffer = encoder.Finish();

        // Checks that the next command sets exactly `pipeline`.
        auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
            return [pipeline](CommandIterator* commands) {
                auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
                EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
            };
        };

        // Checks that the next command binds `bg` at `index` with exactly `offsets`.
        // The dynamic offset data is consumed before any assertion so the iterator
        // stays aligned with the command stream.
        auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
                                     std::vector<uint32_t> offsets = {}) {
            return [index, bg, offsets](CommandIterator* commands) {
                auto* cmd = commands->NextCommand<SetBindGroupCmd>();
                uint32_t* dynamicOffsets = nullptr;
                if (cmd->dynamicOffsetCount > 0) {
                    dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
                }

                ASSERT_EQ(cmd->index, BindGroupIndex(index));
                ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
                ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
                for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
                    ASSERT_EQ(dynamicOffsets[i], offsets[i]);
                }
            };
        };

        // Initialize as null. Once we know the pointer, we'll check
        // that it's the same buffer every time.
        WGPUBuffer indirectScratchBuffer = nullptr;
        auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
            auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
            WGPUBuffer dispatchBuffer = ToAPI(cmd->indirectBuffer.Get());
            // The validated dispatch must read from the scratch indirect buffer, not from
            // the user-provided indirect buffer. Without these checks the test would also
            // pass if the indirect buffer were never replaced.
            ASSERT_NE(dispatchBuffer, nullptr);
            ASSERT_NE(dispatchBuffer, indirectBuffer.Get());
            if (indirectScratchBuffer == nullptr) {
                indirectScratchBuffer = dispatchBuffer;
            }
            ASSERT_EQ(dispatchBuffer, indirectScratchBuffer);
            ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
        };

        // Initialize as null. Once we know the pointer, we'll check
        // that it's the same pipeline every time.
        WGPUComputePipeline validationPipeline = nullptr;
        auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
            auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
            WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
            if (validationPipeline != nullptr) {
                EXPECT_EQ(pipeline, validationPipeline);
            } else {
                EXPECT_NE(pipeline, nullptr);
                validationPipeline = pipeline;
            }
        };

        // The validation pass binds a single group at index 0 with no dynamic offsets.
        auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
            auto* cmd = commands->NextCommand<SetBindGroupCmd>();
            ASSERT_EQ(cmd->index, BindGroupIndex(0));
            ASSERT_NE(cmd->group.Get(), nullptr);
            ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
        };

        // The validation pass issues a 1x1x1 direct dispatch.
        auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
            auto* cmd = commands->NextCommand<DispatchCmd>();
            ASSERT_EQ(cmd->x, 1u);
            ASSERT_EQ(cmd->y, 1u);
            ASSERT_EQ(cmd->z, 1u);
        };

        ExpectCommands(
            FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
            {
                {Command::BeginComputePass,
                 [&](CommandIterator* commands) {
                     SkipCommand(commands, Command::BeginComputePass);
                 }},
                // Expect the state to be set.
                {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
                {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
                {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},

                // Expect the validation.
                {Command::SetComputePipeline, ExpectSetValidationPipeline},
                {Command::SetBindGroup, ExpectSetValidationBindGroup},
                {Command::Dispatch, ExpectSetValidationDispatch},

                // Expect the state to be restored.
                {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
                {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
                {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},

                // Expect the dispatchIndirect.
                {Command::DispatchIndirect, ExpectDispatchIndirect},

                // Expect the validation.
                {Command::SetComputePipeline, ExpectSetValidationPipeline},
                {Command::SetBindGroup, ExpectSetValidationBindGroup},
                {Command::Dispatch, ExpectSetValidationDispatch},

                // Expect the state to be restored.
                {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
                {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
                {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},

                // Expect the dispatchIndirect.
                {Command::DispatchIndirect, ExpectDispatchIndirect},

                // Expect the state to be set (new pipeline).
                {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
                {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
                {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},

                // Expect the validation.
                {Command::SetComputePipeline, ExpectSetValidationPipeline},
                {Command::SetBindGroup, ExpectSetValidationBindGroup},
                {Command::Dispatch, ExpectSetValidationDispatch},

                // Expect the state to be restored.
                {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
                {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
                {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},

                // Expect the dispatchIndirect.
                {Command::DispatchIndirect, ExpectDispatchIndirect},

                {Command::EndComputePass,
                 [&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
            });
    }

    // Test that restoring a snapshot fully replaces the state tracker's contents, so
    // that state changes which occurred between taking the snapshot and restoring it
    // do not leak through.
    TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
        auto* passImpl = FromAPI(pass.Get());

        CommandBufferStateTracker* tracker = passImpl->GetCommandBufferStateTrackerForTesting();

        // Copy the tracker before anything is bound; the copy must not hold a pipeline.
        CommandBufferStateTracker savedState = *tracker;
        EXPECT_FALSE(savedState.HasPipeline());

        // Build a trivial compute pipeline (no explicit layout is provided).
        wgpu::ComputePipelineDescriptor descriptor;
        descriptor.compute.module = utils::CreateShaderModule(device, R"(
        @stage(compute) @workgroup_size(1, 1, 1)
        fn main() {
        })");
        descriptor.compute.entryPoint = "main";
        wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&descriptor);

        // Bind it and confirm the live tracker picked it up.
        pass.SetPipeline(pipeline);
        EXPECT_EQ(ToAPI(tracker->GetComputePipeline()), pipeline.Get());

        // Put the saved copy back; the pipeline bound above must no longer be visible.
        passImpl->RestoreCommandBufferStateForTesting(std::move(savedState));
        EXPECT_FALSE(tracker->HasPipeline());
    }

}  // namespace dawn::native
