utils: Add more functionality to EnumSet
Add equality operators for EnumSet <-> Enum.
Add unidirectional iterator.
Add ostream support.
Change-Id: I8ea9e905bf17e618c6b12004200d37f65ccfb68c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68402
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/utils/enum_set.h b/src/utils/enum_set.h
index 46d93e1..3a71f78 100644
--- a/src/utils/enum_set.h
+++ b/src/utils/enum_set.h
@@ -17,6 +17,7 @@
#include <cstdint>
#include <functional>
+#include <ostream>
#include <type_traits>
namespace tint {
@@ -62,16 +63,79 @@
/// Equality operator
/// @param rhs the other EnumSet to compare this to
/// @return true if this EnumSet is equal to rhs
- inline bool operator==(const EnumSet& rhs) const { return set == rhs.set; }
+ inline bool operator==(EnumSet rhs) const { return set == rhs.set; }
/// Inequality operator
/// @param rhs the other EnumSet to compare this to
/// @return true if this EnumSet is not equal to rhs
- inline bool operator!=(const EnumSet& rhs) const { return set != rhs.set; }
+ inline bool operator!=(EnumSet rhs) const { return set != rhs.set; }
+
+ /// Equality operator
+ /// @param rhs the enum to compare this to
+ /// @return true if this EnumSet only contains `rhs`
+ inline bool operator==(Enum rhs) const { return set == Bit(rhs); }
+
+ /// Inequality operator
+ /// @param rhs the enum to compare this to
+ /// @return false if this EnumSet only contains `rhs`
+ inline bool operator!=(Enum rhs) const { return set != Bit(rhs); }
/// @return the underlying value for the EnumSet
inline uint64_t Value() const { return set; }
+ /// Iterator provides read-only, unidirectional iterator over the enums of an
+ /// EnumSet.
+ class Iterator {
+ static constexpr int8_t kEnd = 63;
+
+ Iterator(uint64_t s, int8_t b) : set(s), pos(b) {}
+
+ /// Make the constructor accessible to the EnumSet.
+ friend struct EnumSet;
+
+ public:
+ /// @return the Enum value at this point in the iterator
+ Enum operator*() const { return static_cast<Enum>(pos); }
+
+ /// Increments the iterator
+ /// @returns this iterator
+ Iterator& operator++() {
+ while (pos < kEnd) {
+ pos++;
+ if (set & (static_cast<uint64_t>(1) << static_cast<uint64_t>(pos))) {
+ break;
+ }
+ }
+ return *this;
+ }
+
+ /// Equality operator
+ /// @param rhs the Iterator to compare this to
+ /// @return true if the two iterators are equal
+ bool operator==(const Iterator& rhs) const {
+ return set == rhs.set && pos == rhs.pos;
+ }
+
+ /// Inequality operator
+ /// @param rhs the Iterator to compare this to
+ /// @return true if the two iterators are different
+ bool operator!=(const Iterator& rhs) const { return !(*this == rhs); }
+
+ private:
+ const uint64_t set;
+ int8_t pos;
+ };
+
+ /// @returns an read-only iterator to the beginning of the set
+ Iterator begin() {
+ auto it = Iterator{set, -1};
+ ++it; // Move to first set bit
+ return it;
+ }
+
+ /// @returns an iterator to the beginning of the set
+ Iterator end() { return Iterator{set, Iterator::kEnd}; }
+
private:
static constexpr uint64_t Bit(Enum value) {
return static_cast<uint64_t>(1) << static_cast<uint64_t>(value);
@@ -87,6 +151,24 @@
uint64_t set = 0;
};
+/// Writes the EnumSet to the std::ostream.
+/// @param out the std::ostream to write to
+/// @param set the EnumSet to write
+/// @returns out so calls can be chained
+template <typename ENUM>
+inline std::ostream& operator<<(std::ostream& out, EnumSet<ENUM> set) {
+ out << "{";
+ bool first = true;
+ for (auto e : set) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+ out << e;
+ }
+ return out << "}";
+}
+
} // namespace utils
} // namespace tint
diff --git a/src/utils/enum_set_test.cc b/src/utils/enum_set_test.cc
index 80c3dcf..9a5186d 100644
--- a/src/utils/enum_set_test.cc
+++ b/src/utils/enum_set_test.cc
@@ -14,13 +14,30 @@
#include "src/utils/enum_set.h"
-#include "gtest/gtest.h"
+#include <sstream>
+#include <vector>
+
+#include "gmock/gmock.h"
namespace tint {
namespace utils {
namespace {
-enum class E { A, B, C };
+using ::testing::ElementsAre;
+
+enum class E { A = 0, B = 3, C = 7 };
+
+std::ostream& operator<<(std::ostream& out, E e) {
+ switch (e) {
+ case E::A:
+ return out << "A";
+ case E::B:
+ return out << "B";
+ case E::C:
+ return out << "C";
+ }
+ return out << "E(" << static_cast<uint32_t>(e) << ")";
+}
TEST(EnumSetTest, ConstructEmpty) {
EnumSet<E> set;
@@ -59,16 +76,34 @@
EXPECT_FALSE(set.Contains(E::C));
}
-TEST(EnumSetTest, Equality) {
+TEST(EnumSetTest, EqualitySet) {
EXPECT_TRUE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::B));
EXPECT_FALSE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::C));
}
-TEST(EnumSetTest, Inequality) {
+TEST(EnumSetTest, InequalitySet) {
EXPECT_FALSE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::B));
EXPECT_TRUE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::C));
}
+TEST(EnumSetTest, EqualityEnum) {
+ EXPECT_TRUE(EnumSet<E>(E::A) == E::A);
+ EXPECT_FALSE(EnumSet<E>(E::B) == E::A);
+ EXPECT_FALSE(EnumSet<E>(E::B) == E::C);
+ EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::A);
+ EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::B);
+ EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::C);
+}
+
+TEST(EnumSetTest, InequalityEnum) {
+ EXPECT_FALSE(EnumSet<E>(E::A) != E::A);
+ EXPECT_TRUE(EnumSet<E>(E::B) != E::A);
+ EXPECT_TRUE(EnumSet<E>(E::B) != E::C);
+ EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::A);
+ EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::B);
+ EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::C);
+}
+
TEST(EnumSetTest, Hash) {
auto hash = [&](EnumSet<E> s) { return std::hash<EnumSet<E>>()(s); };
EXPECT_EQ(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::B)));
@@ -78,9 +113,44 @@
TEST(EnumSetTest, Value) {
EXPECT_EQ(EnumSet<E>().Value(), 0u);
EXPECT_EQ(EnumSet<E>(E::A).Value(), 1u);
- EXPECT_EQ(EnumSet<E>(E::B).Value(), 2u);
- EXPECT_EQ(EnumSet<E>(E::C).Value(), 4u);
- EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 5u);
+ EXPECT_EQ(EnumSet<E>(E::B).Value(), 8u);
+ EXPECT_EQ(EnumSet<E>(E::C).Value(), 128u);
+ EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 129u);
+}
+
+TEST(EnumSetTest, Iterator) {
+ auto set = EnumSet<E>(E::C, E::A);
+
+ auto it = set.begin();
+ EXPECT_EQ(*it, E::A);
+ EXPECT_NE(it, set.end());
+ ++it;
+ EXPECT_EQ(*it, E::C);
+ EXPECT_NE(it, set.end());
+ ++it;
+ EXPECT_EQ(it, set.end());
+}
+
+TEST(EnumSetTest, IteratorEmpty) {
+ auto set = EnumSet<E>();
+ EXPECT_EQ(set.begin(), set.end());
+}
+
+TEST(EnumSetTest, Loop) {
+ auto set = EnumSet<E>(E::C, E::A);
+
+ std::vector<E> seen;
+ for (auto e : set) {
+ seen.emplace_back(e);
+ }
+
+ EXPECT_THAT(seen, ElementsAre(E::A, E::C));
+}
+
+TEST(EnumSetTest, Ostream) {
+ std::stringstream ss;
+ ss << EnumSet<E>(E::A, E::C);
+ EXPECT_EQ(ss.str(), "{A, C}");
}
} // namespace