// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_TINT_UTILS_HASHMAP_BASE_H_
#define SRC_TINT_UTILS_HASHMAP_BASE_H_

#include <algorithm>
#include <functional>
#include <optional>
#include <tuple>
#include <utility>

#include "src/tint/debug.h"
#include "src/tint/utils/hash.h"
#include "src/tint/utils/vector.h"

namespace tint::utils {

/// Action taken by a map mutation
enum class MapAction {
    /// A new entry was added to the map
    kAdded,
    /// An existing entry in the map was replaced
    kReplaced,
    /// No action was taken as the map already contained an entry with the given key
    kKeptExisting,
};

/// KeyValue is a key-value pair.
/// @tparam KEY the type of the key
/// @tparam VALUE the type of the value
template <typename KEY, typename VALUE>
struct KeyValue {
    /// The key type
    using Key = KEY;
    /// The value type
    using Value = VALUE;

    /// The key
    Key key;

    /// The value
    Value value;

    /// Equality operator
    /// @param other the RHS of the operator
    /// @returns true if both the key and value of this KeyValue are equal to the key and value
    /// of @p other
    template <typename K, typename V>
    bool operator==(const KeyValue<K, V>& other) const {
        return key == other.key && value == other.value;
    }

    /// Inequality operator
    /// @param other the RHS of the operator
    /// @returns true if either the key or the value of this KeyValue is not equal to the key or
    /// value of @p other
    template <typename K, typename V>
    bool operator!=(const KeyValue<K, V>& other) const {
        // Defined in terms of operator==. Writing `*this != other` here would call this very
        // operator again and recurse without end.
        return !(*this == other);
    }
};

/// Streams a KeyValue in the form `[key: value]`.
/// @param out the std::ostream to write to
/// @param key_value the KeyValue to write
/// @returns the stream returned by the final insertion, so calls can be chained
template <typename KEY, typename VALUE>
std::ostream& operator<<(std::ostream& out, const KeyValue<KEY, VALUE>& key_value) {
    const auto& [key, value] = key_value;
    return out << "[" << key << ": " << value << "]";
}

/// A base class for Hashmap and Hashset that uses a robin-hood hashing algorithm.
/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html
template <typename KEY,
          typename VALUE,
          size_t N,
          typename HASH = Hasher<KEY>,
          typename EQUAL = std::equal_to<KEY>>
class HashmapBase {
    static constexpr bool ValueIsVoid = std::is_same_v<VALUE, void>;

  public:
    /// The key type
    using Key = KEY;
    /// The value type
    using Value = VALUE;
    /// The entry type for the map.
    /// This is:
    /// - Key when Value is void (used by Hashset)
    /// - KeyValue<Key, Value> when Value is not void (used by Hashmap)
    using Entry = std::conditional_t<ValueIsVoid, Key, KeyValue<Key, Value>>;

    /// STL-friendly alias to Entry. Used by gmock.
    using value_type = Entry;

  private:
    /// Extracts the key of a map entry.
    /// @param entry the entry to read the key from
    /// @returns a reference to the key of @p entry. When Value is void the entry is itself the
    /// key, otherwise the key is the `key` member of the entry.
    static const Key& KeyOf(const Entry& entry) {
        if constexpr (!ValueIsVoid) {
            return entry.key;
        } else {
            return entry;
        }
    }

    /// Obtains the address of the value stored in a map entry.
    /// @param entry the entry to read the value from
    /// @returns a pointer to the `value` member of @p entry, or nullptr when Value is void (a
    /// Hashset only has keys).
    static Value* ValueOf(Entry& entry) {
        if constexpr (!ValueIsVoid) {
            return &entry.value;
        } else {
            return nullptr;  // Hashset only has keys
        }
    }

    /// A slot is one element of the underlying slot vector.
    /// A slot is either vacant or holds an entry. A default-constructed (vacant) slot has a #hash
    /// and #distance of zero.
    struct Slot {
        /// The slot value. If this does not contain a value, then the slot is vacant.
        std::optional<Entry> entry;
        /// The precomputed hash of the entry's key.
        size_t hash = 0;
        /// The number of slots between the key's target slot (see Hash()) and this slot.
        size_t distance = 0;

        /// Tests whether this slot holds the given key.
        /// The slot must not be vacant, as #entry is dereferenced when the hashes match.
        /// @param key_hash the precomputed hash of @p key
        /// @param key the key to compare against
        /// @returns true if @p key_hash matches #hash and EQUAL reports the keys as equal
        bool Equals(size_t key_hash, const Key& key) const {
            // Compare the cheap hash first; the key comparison only runs on a hash match.
            const bool same_hash = (key_hash == hash);
            return same_hash && EQUAL()(key, KeyOf(*entry));
        }
    };

    /// The target length of the underlying slot vector in relation to the number of entries in
    /// the map, expressed as a percentage. For example a value of `150` would mean there would be
    /// at least 50% more slots than the number of map entries.
    static constexpr size_t kRehashFactor = 150;

    /// @param count the number of map entries
    /// @returns the target slot vector size to hold @p count map entries. The division is an
    /// integer division, so the result is rounded down (e.g. 1 entry -> 1 slot, 3 entries -> 4
    /// slots with a kRehashFactor of 150).
    static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; }

    /// The fixed-size slot vector length, based on N and kRehashFactor.
    static constexpr size_t kNumFixedSlots = NumSlots(N);

    /// The minimum number of slots for the map. This is never less than 4, even when N is small
    /// (or zero).
    static constexpr size_t kMinSlots = std::max<size_t>(kNumFixedSlots, 4);

  public:
    /// Forward iterator over the occupied slots of the map.
    /// Iterators are invalidated if the map is modified.
    class Iterator {
      public:
        /// @returns a pointer to the entry at the iterator
        const Entry* operator->() const { return &operator*(); }

        /// @returns a reference to the entry at the iterator
        const Entry& operator*() const { return current->entry.value(); }

        /// Advances the iterator to the next occupied slot. Incrementing an iterator that is
        /// already at the end is a no-op.
        /// @returns this iterator
        Iterator& operator++() {
            if (current != end) {
                ++current;
                SkipToNextValue();
            }
            return *this;
        }

        /// Equality operator
        /// @param other the other iterator to compare this iterator to
        /// @returns true if both iterators point at the same slot
        bool operator==(const Iterator& other) const { return current == other.current; }

        /// Inequality operator
        /// @param other the other iterator to compare this iterator to
        /// @returns true if the iterators point at different slots
        bool operator!=(const Iterator& other) const { return !(*this == other); }

      private:
        /// Friend class
        friend class HashmapBase;

        /// Constructor. If @p first is a vacant slot, the iterator immediately advances to the
        /// next occupied slot (or to @p last if there is none).
        /// @param first the slot to start iterating from
        /// @param last one past the last slot in the map
        Iterator(const Slot* first, const Slot* last) : current(first), end(last) {
            SkipToNextValue();
        }

        /// Moves the iterator forward, stopping at the next slot that is not empty, or at the
        /// end of the slots. Does nothing if the current slot is already occupied.
        void SkipToNextValue() {
            for (; current != end; ++current) {
                if (current->entry.has_value()) {
                    break;
                }
            }
        }

        const Slot* current;  /// The slot the iterator is pointing to
        const Slot* end;      /// One past the last slot in the map
    };

    /// Constructor
    /// Sizes the slot vector to kMinSlots so that there is always at least one slot to scan,
    /// and so that Wrap() (which takes a modulo of the slot count) has a non-zero divisor.
    HashmapBase() { slots_.Resize(kMinSlots); }

    /// Copy constructor
    /// @param other the other HashmapBase to copy
    HashmapBase(const HashmapBase& other) = default;

    /// Move constructor
    /// NOTE(review): the state of @p other after the move depends on the move semantics of
    /// Vector, which are not visible here. Wrap() computes `index % slots_.Length()`, so confirm
    /// that a moved-from map is not queried or inserted into unless it still has slots.
    /// @param other the other HashmapBase to move
    HashmapBase(HashmapBase&& other) = default;

    /// Destructor
    /// Calls Clear(), which destructs all the entries (and re-sizes the slot vector to kMinSlots)
    /// before the members themselves are destroyed.
    ~HashmapBase() { Clear(); }

    /// Copy-assignment operator
    /// Being defaulted, this performs a member-wise copy, so the generation counter is copied
    /// from @p other along with the slots and the entry count.
    /// @param other the other HashmapBase to copy
    /// @returns this so calls can be chained
    HashmapBase& operator=(const HashmapBase& other) = default;

    /// Move-assignment operator
    /// NOTE(review): as with the move constructor, the state of @p other after the move depends
    /// on Vector's move semantics - confirm before reusing a moved-from map.
    /// @param other the other HashmapBase to move
    /// @returns this so calls can be chained
    HashmapBase& operator=(HashmapBase&& other) = default;

    /// Empties the map.
    /// The slot vector is reset back to kMinSlots vacant slots, and the generation counter is
    /// incremented.
    void Clear() {
        // Drop every slot (destructing the entries they hold), then restore the minimum number
        // of vacant slots.
        slots_.Clear();
        slots_.Resize(kMinSlots);

        ++generation_;
        count_ = 0;
    }

    /// Removes an entry from the map.
    /// @param key the entry key.
    /// @returns true if an entry was removed.
    bool Remove(const Key& key) {
        const auto [found, start] = IndexOf(key);
        if (!found) {
            return false;
        }

        // Shuffle the entries backwards until we either find a free slot, or a slot that has zero
        // distance.
        Slot* prev = nullptr;
        Scan(start, [&](size_t, size_t index) {
            auto& slot = slots_[index];
            if (prev) {
                // note: `distance == 0` also includes empty slots.
                if (slot.distance == 0) {
                    // Clear the previous slot, and stop shuffling.
                    *prev = {};
                    return Action::kStop;
                } else {
                    // Shuffle the slot backwards.
                    prev->entry = std::move(slot.entry);
                    prev->hash = slot.hash;
                    prev->distance = slot.distance - 1;
                }
            }
            prev = &slot;
            return Action::kContinue;
        });

        // Entry was removed.
        count_--;
        generation_++;

        return true;
    }

    /// Tests whether the map holds an entry for a key.
    /// @param key the key to search for.
    /// @returns true if the map contains an entry with the given key.
    bool Contains(const Key& key) const {
        // IndexOf() returns {found, slot index}; only the `found` flag is of interest here.
        return std::get<0>(IndexOf(key));
    }

    /// Pre-allocates memory so that the map can hold at least `capacity` entries.
    /// @param capacity the new capacity of the map.
    void Reserve(size_t capacity) {
        // Calculate the number of slots required to hold `capacity` entries.
        const size_t num_slots = std::max(NumSlots(capacity), kMinSlots);
        if (slots_.Length() >= num_slots) {
            // Already have enough slots.
            return;
        }

        // Move all the values out of the map and into a vector.
        Vector<Entry, N> entries;
        entries.Reserve(count_);
        for (auto& slot : slots_) {
            if (slot.entry.has_value()) {
                entries.Push(std::move(slot.entry.value()));
            }
        }

        // Clear the map, grow the number of slots.
        Clear();
        slots_.Resize(num_slots);

        // As the number of slots has grown, the slot indices will have changed from before, so
        // re-add all the entries back into the map.
        for (auto& entry : entries) {
            if constexpr (ValueIsVoid) {
                struct NoValue {};
                Put<PutMode::kAdd>(std::move(entry), NoValue{});
            } else {
                Put<PutMode::kAdd>(std::move(entry.key), std::move(entry.value));
            }
        }
    }

    /// @returns the number of entries in the map.
    size_t Count() const { return count_; }

    /// @returns true if the map contains no entries.
    bool IsEmpty() const { return count_ == 0; }

    /// @returns a counter which is incremented whenever an entry is added, replaced or removed,
    /// and whenever the map is cleared. This can be used to detect that the map has been modified
    /// (which invalidates iterators).
    size_t Generation() const { return generation_; }

    /// @returns an iterator to the first entry of the map. The iterator skips over any leading
    /// vacant slots, so this is equal to end() if the map is empty.
    Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; }

    /// @returns an iterator to one past the last slot of the map.
    Iterator end() const { return Iterator{slots_.end(), slots_.end()}; }

    /// A debug function that checks the internal invariants of the map.
    /// For every occupied slot this asserts that the cached hash matches the hash of the key, and
    /// that the slot sits exactly `distance` slots after the key's target slot. It finally
    /// asserts that the number of occupied slots matches the entry count.
    void ValidateIntegrity() const {
        size_t num_alive = 0;
        const size_t num_slots = slots_.Length();
        for (size_t slot_idx = 0; slot_idx < num_slots; ++slot_idx) {
            const auto& slot = slots_[slot_idx];
            if (!slot.entry.has_value()) {
                continue;  // Vacant slots have nothing to validate.
            }
            ++num_alive;
            const auto [index, hash] = Hash(KeyOf(*slot.entry));
            TINT_ASSERT(Utils, hash == slot.hash);
            TINT_ASSERT(Utils, slot_idx == Wrap(index + slot.distance));
        }
        TINT_ASSERT(Utils, num_alive == count_);
    }

  protected:
    /// The behaviour of Put() when an entry already exists with the given key.
    enum class PutMode {
        /// Do not replace existing entries with the new value.
        kAdd,
        /// Replace existing entries with the new value.
        kReplace,
    };

    /// Result of Put()
    struct PutResult {
        /// Whether the insert replaced or added a new entry to the map, or left an existing entry
        /// untouched.
        MapAction action = MapAction::kAdded;
        /// A pointer to the value of the entry for the key. When the action is
        /// MapAction::kKeptExisting, this points to the value of the existing entry.
        /// This is nullptr when the map has no values (Value is void).
        Value* value = nullptr;

        /// @returns true if the entry was added to the map, or an existing entry was replaced.
        operator bool() const { return action != MapAction::kKeptExisting; }
    };

    /// The common implementation for Add() and Replace()
    /// Note that the check for growing the slot vector is made before the key is looked up, so
    /// the map may rehash even if the key already exists and no new entry is added.
    /// @tparam MODE whether an existing entry with the same key is replaced or preserved.
    /// @tparam K the forwarded type of the key.
    /// @tparam V the forwarded type of the value. Unused when Value is void.
    /// @param key the key of the entry to add to the map.
    /// @param value the value of the entry to add to the map.
    /// @returns A PutResult describing the result of the insertion
    template <PutMode MODE, typename K, typename V>
    PutResult Put(K&& key, V&& value) {
        // Ensure the map can fit a new entry
        if (ShouldRehash(count_ + 1)) {
            Reserve((count_ + 1) * 2);
        }

        // Hash after any rehash, as the target slot index depends on the number of slots.
        const auto hash = Hash(key);

        // Builds the entry to store in a slot. This forwards (and so may move from) `key` and
        // `value`, so it must be called at most once, and only after `key` is no longer needed
        // for comparisons.
        auto make_entry = [&]() {
            if constexpr (ValueIsVoid) {
                return std::forward<K>(key);
            } else {
                return Entry{std::forward<K>(key), std::forward<V>(value)};
            }
        };

        PutResult result{};
        Scan(hash.scan_start, [&](size_t distance, size_t index) {
            auto& slot = slots_[index];
            if (!slot.entry.has_value()) {
                // Found an empty slot.
                // Place value directly into the slot, and we're done.
                slot.entry.emplace(make_entry());
                slot.hash = hash.code;
                slot.distance = distance;
                count_++;
                generation_++;
                result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)};
                return Action::kStop;
            }

            // Slot has an entry

            if (slot.Equals(hash.code, key)) {
                // Slot holds an entry with an equal key. Replace or preserve?
                if constexpr (MODE == PutMode::kReplace) {
                    slot.entry = make_entry();
                    generation_++;
                    result = PutResult{MapAction::kReplaced, ValueOf(*slot.entry)};
                } else {
                    // The map is left unmodified, so the generation is not incremented.
                    result = PutResult{MapAction::kKeptExisting, ValueOf(*slot.entry)};
                }
                return Action::kStop;
            }

            if (slot.distance < distance) {
                // Existing slot has a closer distance than the value we're attempting to insert.
                // Steal from the rich!
                // Move the current slot to a temporary (evicted), and put the value into the slot.
                Slot evicted{make_entry(), hash.code, distance};
                std::swap(evicted, slot);

                // Find a new home for the evicted slot.
                evicted.distance++;  // We've already swapped at index.
                InsertShuffle(Wrap(index + 1), std::move(evicted));

                count_++;
                generation_++;
                // `slot` now holds the newly inserted entry.
                result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)};

                return Action::kStop;
            }
            return Action::kContinue;
        });

        return result;
    }

    /// Return type of the Scan() callback, controlling whether Scan() visits any further slots.
    enum class Action {
        /// Continue scanning, moving on to the next slot
        kContinue,
        /// Immediately stop scanning; no further slots are visited
        kStop,
    };

    /// Visits the slots one after another, beginning with the slot at index @p start and wrapping
    /// around to the first slot after the last, calling @p f for each slot until @p f returns
    /// Action::kStop. An internal compiler error is raised if every slot was visited without
    /// @p f returning Action::kStop.
    /// @param start the index of the first slot to start scanning from.
    /// @param f the callback function which:
    /// * must be a function with the signature `Action(size_t distance, size_t index)`, where
    ///   `distance` is the number of slots visited before this one, and `index` is the slot index.
    /// * must return Action::kStop within one whole cycle of the slots.
    template <typename F>
    void Scan(size_t start, F&& f) const {
        size_t distance = 0;

        // First leg: from `start` up to the last slot.
        size_t index = start;
        while (index < slots_.Length()) {
            if (f(distance, index) == Action::kStop) {
                return;
            }
            ++distance;
            ++index;
        }

        // Second leg: wrap around, and visit the slots that come before `start`.
        index = 0;
        while (index < start) {
            if (f(distance, index) == Action::kStop) {
                return;
            }
            ++distance;
            ++index;
        }

        tint::diag::List diags;
        TINT_ICE(Utils, diags) << "HashmapBase::Scan() looped entire map without finding a slot";
    }

    /// HashResult is the return value of Hash()
    struct HashResult {
        /// The target (zero-distance) slot index for the key. This is the hash code wrapped to
        /// the current number of slots.
        size_t scan_start;
        /// The calculated hash code of the key.
        size_t code;
    };

    /// Hashes a key with HASH, and calculates the slot that the key would ideally occupy.
    /// @param key the key to hash
    /// @returns a HashResult holding the target slot index for the given key, and the hash code
    /// of the key.
    HashResult Hash(const Key& key) const {
        const size_t code = HASH()(key);
        return HashResult{Wrap(code), code};
    }

    /// Looks for the key in the map.
    /// @param key the key to search for.
    /// @returns a tuple holding a boolean representing whether the key was found in the map, and
    /// if found, the index of the slot that holds the key.
    std::tuple<bool, size_t> IndexOf(const Key& key) const {
        const auto hash = Hash(key);

        bool found = false;
        size_t idx = 0;

        Scan(hash.scan_start, [&](size_t distance, size_t index) {
            auto& slot = slots_[index];
            if (!slot.entry.has_value()) {
                return Action::kStop;
            }
            if (slot.Equals(hash.code, key)) {
                found = true;
                idx = index;
                return Action::kStop;
            }
            if (slot.distance < distance) {
                // If the slot distance is less than the current probe distance, then the slot must
                // be for entry that has an index that comes after key. In this situation, we know
                // that the map does not contain the key, as it would have been found before this
                // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html
                // suggests that the condition should inverted, but this is wrong.
                return Action::kStop;
            }
            return Action::kContinue;
        });

        return {found, idx};
    }

    /// Shuffles slots for an insertion that has been placed one slot before `start`.
    /// @param start the index of the first slot to start shuffling.
    /// @param evicted the slot content that was evicted for the insertion.
    void InsertShuffle(size_t start, Slot&& evicted) {
        Scan(start, [&](size_t, size_t index) {
            auto& slot = slots_[index];

            if (!slot.entry.has_value()) {
                // Empty slot found for evicted.
                slot = std::move(evicted);
                return Action::kStop;  //  We're done.
            }

            if (slot.distance < evicted.distance) {
                // Occupied slot has shorter distance to evicted.
                // Swap slot and evicted.
                std::swap(slot, evicted);
            }

            // evicted moves further from the target slot...
            evicted.distance++;

            return Action::kContinue;
        });
    }

    /// @param count the number of entries that the map needs to hold
    /// @returns true if the slot vector is too small to hold @p count entries, in which case the
    /// slot vector should be grown and the entries rehashed.
    bool ShouldRehash(size_t count) const { return slots_.Length() < NumSlots(count); }

    /// Wraps a value into the range of valid slot indices.
    /// @param index an input value
    /// @returns the input value modulo the number of slots.
    size_t Wrap(size_t index) const {
        const size_t num_slots = slots_.Length();
        return index % num_slots;
    }

    /// The vector of slots. The vector length is equal to its capacity.
    /// The constructor and Clear() size this to kMinSlots, and Reserve() only ever grows it.
    Vector<Slot, kNumFixedSlots> slots_;

    /// The number of entries in the map. This is the number of slots that hold an entry.
    size_t count_ = 0;

    /// Counter that's incremented with each modification to the map (entries being added,
    /// replaced or removed, and the map being cleared).
    size_t generation_ = 0;
};

}  // namespace tint::utils

#endif  // SRC_TINT_UTILS_HASHMAP_BASE_H_
