utils: Add more functionality to EnumSet

Add equality operators for EnumSet <-> Enum.
Add unidirectional iterator.
Add ostream support.

Change-Id: I8ea9e905bf17e618c6b12004200d37f65ccfb68c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68402
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/utils/enum_set.h b/src/utils/enum_set.h
index 46d93e1..3a71f78 100644
--- a/src/utils/enum_set.h
+++ b/src/utils/enum_set.h
@@ -17,6 +17,7 @@
 
 #include <cstdint>
 #include <functional>
+#include <ostream>
 #include <type_traits>
 
 namespace tint {
@@ -62,16 +63,79 @@
   /// Equality operator
   /// @param rhs the other EnumSet to compare this to
   /// @return true if this EnumSet is equal to rhs
-  inline bool operator==(const EnumSet& rhs) const { return set == rhs.set; }
+  inline bool operator==(EnumSet rhs) const { return set == rhs.set; }
 
   /// Inequality operator
   /// @param rhs the other EnumSet to compare this to
   /// @return true if this EnumSet is not equal to rhs
-  inline bool operator!=(const EnumSet& rhs) const { return set != rhs.set; }
+  inline bool operator!=(EnumSet rhs) const { return set != rhs.set; }
+
+  /// Equality operator
+  /// @param rhs the enum to compare this to
+  /// @return true if this EnumSet only contains `rhs`
+  inline bool operator==(Enum rhs) const { return set == Bit(rhs); }
+
+  /// Inequality operator
+  /// @param rhs the enum to compare this to
+  /// @return false if this EnumSet only contains `rhs`
+  inline bool operator!=(Enum rhs) const { return set != Bit(rhs); }
 
   /// @return the underlying value for the EnumSet
   inline uint64_t Value() const { return set; }
 
+  /// Iterator provides read-only, unidirectional iterator over the enums of an
+  /// EnumSet.
+  class Iterator {
+    static constexpr int8_t kEnd = 63;
+
+    Iterator(uint64_t s, int8_t b) : set(s), pos(b) {}
+
+    /// Make the constructor accessible to the EnumSet.
+    friend struct EnumSet;
+
+   public:
+    /// @return the Enum value at this point in the iterator
+    Enum operator*() const { return static_cast<Enum>(pos); }
+
+    /// Increments the iterator
+    /// @returns this iterator
+    Iterator& operator++() {
+      while (pos < kEnd) {
+        pos++;
+        if (set & (static_cast<uint64_t>(1) << static_cast<uint64_t>(pos))) {
+          break;
+        }
+      }
+      return *this;
+    }
+
+    /// Equality operator
+    /// @param rhs the Iterator to compare this to
+    /// @return true if the two iterators are equal
+    bool operator==(const Iterator& rhs) const {
+      return set == rhs.set && pos == rhs.pos;
+    }
+
+    /// Inequality operator
+    /// @param rhs the Iterator to compare this to
+    /// @return true if the two iterators are different
+    bool operator!=(const Iterator& rhs) const { return !(*this == rhs); }
+
+   private:
+    const uint64_t set;
+    int8_t pos;
+  };
+
+  /// @returns an read-only iterator to the beginning of the set
+  Iterator begin() {
+    auto it = Iterator{set, -1};
+    ++it;  // Move to first set bit
+    return it;
+  }
+
+  /// @returns an iterator to the beginning of the set
+  Iterator end() { return Iterator{set, Iterator::kEnd}; }
+
  private:
   static constexpr uint64_t Bit(Enum value) {
     return static_cast<uint64_t>(1) << static_cast<uint64_t>(value);
@@ -87,6 +151,24 @@
   uint64_t set = 0;
 };
 
+/// Writes the EnumSet to the std::ostream.
+/// @param out the std::ostream to write to
+/// @param set the EnumSet to write
+/// @returns out so calls can be chained
+template <typename ENUM>
+inline std::ostream& operator<<(std::ostream& out, EnumSet<ENUM> set) {
+  out << "{";
+  bool first = true;
+  for (auto e : set) {
+    if (!first) {
+      out << ", ";
+    }
+    first = false;
+    out << e;
+  }
+  return out << "}";
+}
+
 }  // namespace utils
 }  // namespace tint
 
diff --git a/src/utils/enum_set_test.cc b/src/utils/enum_set_test.cc
index 80c3dcf..9a5186d 100644
--- a/src/utils/enum_set_test.cc
+++ b/src/utils/enum_set_test.cc
@@ -14,13 +14,30 @@
 
 #include "src/utils/enum_set.h"
 
-#include "gtest/gtest.h"
+#include <sstream>
+#include <vector>
+
+#include "gmock/gmock.h"
 
 namespace tint {
 namespace utils {
 namespace {
 
-enum class E { A, B, C };
+using ::testing::ElementsAre;
+
+enum class E { A = 0, B = 3, C = 7 };
+
+std::ostream& operator<<(std::ostream& out, E e) {
+  switch (e) {
+    case E::A:
+      return out << "A";
+    case E::B:
+      return out << "B";
+    case E::C:
+      return out << "C";
+  }
+  return out << "E(" << static_cast<uint32_t>(e) << ")";
+}
 
 TEST(EnumSetTest, ConstructEmpty) {
   EnumSet<E> set;
@@ -59,16 +76,34 @@
   EXPECT_FALSE(set.Contains(E::C));
 }
 
-TEST(EnumSetTest, Equality) {
+TEST(EnumSetTest, EqualitySet) {
   EXPECT_TRUE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::B));
   EXPECT_FALSE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::C));
 }
 
-TEST(EnumSetTest, Inequality) {
+TEST(EnumSetTest, InequalitySet) {
   EXPECT_FALSE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::B));
   EXPECT_TRUE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::C));
 }
 
+TEST(EnumSetTest, EqualityEnum) {
+  EXPECT_TRUE(EnumSet<E>(E::A) == E::A);
+  EXPECT_FALSE(EnumSet<E>(E::B) == E::A);
+  EXPECT_FALSE(EnumSet<E>(E::B) == E::C);
+  EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::A);
+  EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::B);
+  EXPECT_FALSE(EnumSet<E>(E::A, E::B) == E::C);
+}
+
+TEST(EnumSetTest, InequalityEnum) {
+  EXPECT_FALSE(EnumSet<E>(E::A) != E::A);
+  EXPECT_TRUE(EnumSet<E>(E::B) != E::A);
+  EXPECT_TRUE(EnumSet<E>(E::B) != E::C);
+  EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::A);
+  EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::B);
+  EXPECT_TRUE(EnumSet<E>(E::A, E::B) != E::C);
+}
+
 TEST(EnumSetTest, Hash) {
   auto hash = [&](EnumSet<E> s) { return std::hash<EnumSet<E>>()(s); };
   EXPECT_EQ(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::B)));
@@ -78,9 +113,44 @@
 TEST(EnumSetTest, Value) {
   EXPECT_EQ(EnumSet<E>().Value(), 0u);
   EXPECT_EQ(EnumSet<E>(E::A).Value(), 1u);
-  EXPECT_EQ(EnumSet<E>(E::B).Value(), 2u);
-  EXPECT_EQ(EnumSet<E>(E::C).Value(), 4u);
-  EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 5u);
+  EXPECT_EQ(EnumSet<E>(E::B).Value(), 8u);
+  EXPECT_EQ(EnumSet<E>(E::C).Value(), 128u);
+  EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 129u);
+}
+
+TEST(EnumSetTest, Iterator) {
+  auto set = EnumSet<E>(E::C, E::A);
+
+  auto it = set.begin();
+  EXPECT_EQ(*it, E::A);
+  EXPECT_NE(it, set.end());
+  ++it;
+  EXPECT_EQ(*it, E::C);
+  EXPECT_NE(it, set.end());
+  ++it;
+  EXPECT_EQ(it, set.end());
+}
+
+TEST(EnumSetTest, IteratorEmpty) {
+  auto set = EnumSet<E>();
+  EXPECT_EQ(set.begin(), set.end());
+}
+
+TEST(EnumSetTest, Loop) {
+  auto set = EnumSet<E>(E::C, E::A);
+
+  std::vector<E> seen;
+  for (auto e : set) {
+    seen.emplace_back(e);
+  }
+
+  EXPECT_THAT(seen, ElementsAre(E::A, E::C));
+}
+
+TEST(EnumSetTest, Ostream) {
+  std::stringstream ss;
+  ss << EnumSet<E>(E::A, E::C);
+  EXPECT_EQ(ss.str(), "{A, C}");
 }
 
 }  // namespace