utils: Add more functionality to EnumSet Add equality operators for EnumSet <-> Enum. Add unidirectional iterator. Add ostream support. Change-Id: I8ea9e905bf17e618c6b12004200d37f65ccfb68c Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68402 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/utils/enum_set.h b/src/utils/enum_set.h index 46d93e1..3a71f78 100644 --- a/src/utils/enum_set.h +++ b/src/utils/enum_set.h
@@ -17,6 +17,7 @@ #include <cstdint> #include <functional> +#include <ostream> #include <type_traits> namespace tint { @@ -62,16 +63,79 @@ /// Equality operator /// @param rhs the other EnumSet to compare this to /// @return true if this EnumSet is equal to rhs - inline bool operator==(const EnumSet& rhs) const { return set == rhs.set; } + inline bool operator==(EnumSet rhs) const { return set == rhs.set; } /// Inequality operator /// @param rhs the other EnumSet to compare this to /// @return true if this EnumSet is not equal to rhs - inline bool operator!=(const EnumSet& rhs) const { return set != rhs.set; } + inline bool operator!=(EnumSet rhs) const { return set != rhs.set; } + + /// Equality operator + /// @param rhs the enum to compare this to + /// @return true if this EnumSet only contains `rhs` + inline bool operator==(Enum rhs) const { return set == Bit(rhs); } + + /// Inequality operator + /// @param rhs the enum to compare this to + /// @return false if this EnumSet only contains `rhs` + inline bool operator!=(Enum rhs) const { return set != Bit(rhs); } /// @return the underlying value for the EnumSet inline uint64_t Value() const { return set; } + /// Iterator provides read-only, unidirectional iterator over the enums of an + /// EnumSet. + class Iterator { + static constexpr int8_t kEnd = 63; + + Iterator(uint64_t s, int8_t b) : set(s), pos(b) {} + + /// Make the constructor accessible to the EnumSet. + friend struct EnumSet; + + public: + /// @return the Enum value at this point in the iterator + Enum operator*() const { return static_cast<Enum>(pos); } + + /// Increments the iterator + /// @returns this iterator + Iterator& operator++() { + while (pos < kEnd) { + pos++; + if (set & (static_cast<uint64_t>(1) << static_cast<uint64_t>(pos))) { + break; + } + } + return *this; + } + + /// Equality operator + /// @param rhs the Iterator to compare this to + /// @return true if the two iterators are equal + bool operator==(const Iterator& rhs) const { + return set == rhs.set && pos == rhs.pos; + } + + /// Inequality operator + /// @param rhs the Iterator to compare this to + /// @return true if the two iterators are different + bool operator!=(const Iterator& rhs) const { return !(*this == rhs); } + + private: + const uint64_t set; + int8_t pos; + }; + + /// @returns an read-only iterator to the beginning of the set + Iterator begin() { + auto it = Iterator{set, -1}; + ++it; // Move to first set bit + return it; + } + + /// @returns an iterator to the beginning of the set + Iterator end() { return Iterator{set, Iterator::kEnd}; } + private: static constexpr uint64_t Bit(Enum value) { return static_cast<uint64_t>(1) << static_cast<uint64_t>(value); @@ -87,6 +151,24 @@ uint64_t set = 0; }; +/// Writes the EnumSet to the std::ostream. +/// @param out the std::ostream to write to +/// @param set the EnumSet to write +/// @returns out so calls can be chained +template <typename ENUM> +inline std::ostream& operator<<(std::ostream& out, EnumSet<ENUM> set) { + out << "{"; + bool first = true; + for (auto e : set) { + if (!first) { + out << ", "; + } + first = false; + out << e; + } + return out << "}"; +} + } // namespace utils } // namespace tint
diff --git a/src/utils/enum_set_test.cc b/src/utils/enum_set_test.cc index 80c3dcf..9a5186d 100644 --- a/src/utils/enum_set_test.cc +++ b/src/utils/enum_set_test.cc
@@ -14,13 +14,30 @@ #include "src/utils/enum_set.h" -#include "gtest/gtest.h" +#include <sstream> +#include <vector> + +#include "gmock/gmock.h" namespace tint { namespace utils { namespace { -enum class E { A, B, C }; +using ::testing::ElementsAre; + +enum class E { A = 0, B = 3, C = 7 }; + +std::ostream& operator<<(std::ostream& out, E e) { + switch (e) { + case E::A: + return out << "A"; + case E::B: + return out << "B"; + case E::C: + return out << "C"; + } + return out << "E(" << static_cast<uint32_t>(e) << ")"; +} TEST(EnumSetTest, ConstructEmpty) { EnumSet<E> set; @@ -59,16 +76,34 @@ EXPECT_FALSE(set.Contains(E::C)); } -TEST(EnumSetTest, Equality) { +TEST(EnumSetTest, EqualitySet) { EXPECT_TRUE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::B)); EXPECT_FALSE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::C)); } -TEST(EnumSetTest, Inequality) { +TEST(EnumSetTest, InequalitySet) { EXPECT_FALSE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::B)); EXPECT_TRUE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::C)); } +TEST(EnumSetTest, EqualityEnum) { + EXPECT_TRUE(EnumSet<E>(E::A) == E::A); + EXPECT_FALSE(EnumSet<E>(E::B) == E::A); + EXPECT_FALSE(EnumSet<E>(E::B) == E::C); + EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::A); + EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::B); + EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::C); +} + +TEST(EnumSetTest, InequalityEnum) { + EXPECT_FALSE(EnumSet<E>(E::A) != E::A); + EXPECT_TRUE(EnumSet<E>(E::B) != E::A); + EXPECT_TRUE(EnumSet<E>(E::B) != E::C); + EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::A); + EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::B); + EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::C); +} + TEST(EnumSetTest, Hash) { auto hash = [&](EnumSet<E> s) { return std::hash<EnumSet<E>>()(s); }; EXPECT_EQ(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::B))); @@ -78,9 +113,44 @@ TEST(EnumSetTest, Value) { EXPECT_EQ(EnumSet<E>().Value(), 0u); EXPECT_EQ(EnumSet<E>(E::A).Value(), 1u); - EXPECT_EQ(EnumSet<E>(E::B).Value(), 2u); - EXPECT_EQ(EnumSet<E>(E::C).Value(), 4u); - EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 5u); + EXPECT_EQ(EnumSet<E>(E::B).Value(), 8u); + EXPECT_EQ(EnumSet<E>(E::C).Value(), 128u); + EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 129u); +} + +TEST(EnumSetTest, Iterator) { + auto set = EnumSet<E>(E::C, E::A); + + auto it = set.begin(); + EXPECT_EQ(*it, E::A); + EXPECT_NE(it, set.end()); + ++it; + EXPECT_EQ(*it, E::C); + EXPECT_NE(it, set.end()); + ++it; + EXPECT_EQ(it, set.end()); +} + +TEST(EnumSetTest, IteratorEmpty) { + auto set = EnumSet<E>(); + EXPECT_EQ(set.begin(), set.end()); +} + +TEST(EnumSetTest, Loop) { + auto set = EnumSet<E>(E::C, E::A); + + std::vector<E> seen; + for (auto e : set) { + seen.emplace_back(e); + } + + EXPECT_THAT(seen, ElementsAre(E::A, E::C)); +} + +TEST(EnumSetTest, Ostream) { + std::stringstream ss; + ss << EnumSet<E>(E::A, E::C); + EXPECT_EQ(ss.str(), "{A, C}"); } } // namespace