[dawn] Make BitSetRangeIterator::Advance simpler to reason about

The calls to ScanForwardAndShiftBits with mBits inversion interspersed
made it difficult to reason about what's happening. The second call
to ScanForwardAndShiftBits before the last mBits inversion introduced
extra 1 bits on the left that had to be cleaned up.

Instead inline ScanForwardAndShiftBits and use countr_one instead of
shifting all of mBits. Finally guard the loop to run only when the
bitset has more than 64 bits.

Simplifies test code mildly.

Bug: None
Change-Id: I81ad70f799e287e980fd1e994b239ec2f1c31e5c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/244815
Reviewed-by: Shaobo Yan <shaoboyan@microsoft.com>
Reviewed-by: Geoff Lang <geofflang@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn/common/BitSetRangeIterator.h b/src/dawn/common/BitSetRangeIterator.h
index 85773c2..aecdb57 100644
--- a/src/dawn/common/BitSetRangeIterator.h
+++ b/src/dawn/common/BitSetRangeIterator.h
@@ -40,8 +40,6 @@
 namespace dawn {
 
 // Similar to BitSetIterator but returns ranges of consecutive bits as (offset, size) pairs
-// TODO(crbug.com/366291600): // Specialization  for bitset size fits in uint64_t to skip
-// loops for bits across words boundary.
 template <size_t N, typename T>
 class BitSetRangeIterator final {
   public:
@@ -51,32 +49,27 @@
 
     class Iterator final {
       public:
-        explicit Iterator(const std::bitset<N>& bits, uint32_t offset = 0, uint32_t size = 0);
+        constexpr explicit Iterator(const std::bitset<N>& bits,
+                                    uint32_t offset = 0,
+                                    uint32_t size = 0);
         Iterator& operator++();
 
         bool operator==(const Iterator& other) const = default;
 
         // Returns a pair of offset and size of the current range
-        std::pair<T, size_t> operator*() const {
-            using U = UnderlyingType<T>;
-            DAWN_ASSERT(static_cast<U>(mOffset) <= std::numeric_limits<U>::max());
-            DAWN_ASSERT(static_cast<size_t>(mSize) <= std::numeric_limits<size_t>::max());
-            return std::make_pair(static_cast<T>(static_cast<U>(mOffset)),
-                                  static_cast<size_t>(mSize));
-        }
+        std::pair<T, size_t> operator*() const;
 
       private:
         void Advance();
-        size_t ScanForwardAndShiftBits();
 
         static constexpr size_t kBitsPerWord = sizeof(uint64_t) * 8;
-        std::bitset<N> mBits;
-        uint32_t mOffset{0};
-        uint32_t mSize{0};
+        std::bitset<N> mBits;  // The original bitset shifted by mOffset + mSize.
+        uint32_t mOffset;
+        uint32_t mSize;
     };
 
     Iterator begin() const { return Iterator(mBits); }
-    Iterator end() const { return Iterator(std::bitset<N>(0), N, 0); }
+    constexpr Iterator end() const { return Iterator(std::bitset<N>(), N, 0); }
 
   private:
     const std::bitset<N> mBits;
@@ -96,10 +89,18 @@
 }
 
 template <size_t N, typename T>
-BitSetRangeIterator<N, T>::Iterator::Iterator(const std::bitset<N>& bits,
-                                              uint32_t offset,
-                                              uint32_t size)
+constexpr BitSetRangeIterator<N, T>::Iterator::Iterator(const std::bitset<N>& bits,
+                                                        uint32_t offset,
+                                                        uint32_t size)
     : mBits(bits), mOffset(offset), mSize(size) {
+    // If the full range is set, directly compute the range. This avoids checking for the full range
+    // in each call to Advance as that can only happen in the first iteration.
+    if (mBits.all()) {
+        mSize = N;
+        mBits.reset();
+        return;
+    }
+
     Advance();
 }
 
@@ -110,52 +111,60 @@
 }
 
 template <size_t N, typename T>
-size_t BitSetRangeIterator<N, T>::Iterator::ScanForwardAndShiftBits() {
-    if (mBits.none()) {
-        return N;  // Or some other indicator that there are no bits.
-    }
-
-    constexpr std::bitset<N> wordMask(std::numeric_limits<uint64_t>::max());
-    size_t offset = 0;
-    while ((mBits & wordMask).to_ullong() == 0) {
-        offset += kBitsPerWord;
-        mBits >>= kBitsPerWord;
-    }
-
-    size_t nextBit = static_cast<size_t>(
-        std::countr_zero(static_cast<uint64_t>((mBits & wordMask).to_ullong())));
-    mBits >>= nextBit;
-    return offset + nextBit;
+std::pair<T, size_t> BitSetRangeIterator<N, T>::Iterator::operator*() const {
+    using U = UnderlyingType<T>;
+    DAWN_ASSERT(static_cast<U>(mOffset) <= std::numeric_limits<U>::max());
+    DAWN_ASSERT(static_cast<size_t>(mSize) <= std::numeric_limits<size_t>::max());
+    return std::make_pair(static_cast<T>(static_cast<U>(mOffset)), static_cast<size_t>(mSize));
 }
 
 template <size_t N, typename T>
 void BitSetRangeIterator<N, T>::Iterator::Advance() {
+    constexpr std::bitset<N> kBlockMask(std::numeric_limits<uint64_t>::max());
+
     // Bits are currently shifted to mOffset + mSize.
     mOffset += mSize;
+    mSize = 0;
 
-    size_t rangeStart = ScanForwardAndShiftBits();
-    if (rangeStart == N) {
-        // Reached the end, no more ranges.
+    // There are no more 1s, so there are no more ranges.
+    if (mBits.none()) {
         mOffset = N;
-        mSize = 0;
         return;
     }
 
-    mOffset += rangeStart;
-    mBits = ~mBits;
+    // Look for the next 1, shifting mBits as we go and accounting for it in mOffset.
+    // The loop jumps in blocks of 64bit while the rest of the code dose the last sub-64bit count.
+    {
+        if constexpr (N > kBitsPerWord) {
+            while ((mBits & kBlockMask).to_ullong() == 0) {
+                mOffset += kBitsPerWord;
+                mBits >>= kBitsPerWord;
+            }
+        }
 
-    size_t rangeCount = ScanForwardAndShiftBits();
-    if (rangeCount == N) {
-        // All bits until the end of the set are set.
-        rangeCount = N - mOffset;
+        size_t nextBit = static_cast<size_t>(
+            std::countr_zero(static_cast<uint64_t>((mBits & kBlockMask).to_ullong())));
+        mOffset += nextBit;
+        mBits >>= nextBit;
     }
 
-    mSize = rangeCount;
-    mBits = ~mBits;
+    // Look for the next 0, shifting mBits as we go and accounting for it in mSize. There is a next
+    // zero bit because the case with all bits set to 1 is handled in the iterator constructor.
+    // The loop jumps in blocks of 64bit while the rest of the code dose the last sub-64bit count.
+    DAWN_ASSERT(!mBits.all());
+    {
+        if constexpr (N > kBitsPerWord) {
+            while ((mBits & kBlockMask).to_ullong() == kBlockMask) {
+                mSize += kBitsPerWord;
+                mBits >>= kBitsPerWord;
+            }
+        }
 
-    // Clear the bits for the current range.
-    mBits <<= rangeCount;
-    mBits >>= rangeCount;
+        size_t nextBit = static_cast<size_t>(
+            std::countr_one(static_cast<uint64_t>((mBits & kBlockMask).to_ullong())));
+        mSize += nextBit;
+        mBits >>= nextBit;
+    }
 }
 
 // Helper to avoid needing to specify the template parameter size
diff --git a/src/dawn/tests/unittests/BitSetRangeIteratorTests.cpp b/src/dawn/tests/unittests/BitSetRangeIteratorTests.cpp
index 1798b47..5a0068d 100644
--- a/src/dawn/tests/unittests/BitSetRangeIteratorTests.cpp
+++ b/src/dawn/tests/unittests/BitSetRangeIteratorTests.cpp
@@ -37,34 +37,36 @@
 
 class BitSetRangeIteratorTest : public testing::Test {
   protected:
+    struct Range {
+        uint32_t offset;
+        size_t size;
+    };
+
     template <size_t N>
-    void RunSingleBitSetRangeTests(uint32_t offset, uint32_t size) {
+    void RunBitSetRangeTests(std::vector<Range> ranges) {
         std::bitset<N> stateBits;
 
-        std::vector<std::pair<uint32_t, size_t>> expectedRanges;
-        expectedRanges.push_back({offset, size});
-
-        for (const auto& range : expectedRanges) {
-            for (uint32_t i = 0; i < range.second; ++i) {
-                stateBits.set(range.first + i);
+        for (const auto& range : ranges) {
+            for (uint32_t i = 0; i < range.size; ++i) {
+                stateBits.set(range.offset + i);
             }
         }
 
-        std::vector<std::pair<uint32_t, size_t>> foundRanges;
-        for (auto range : IterateBitSetRanges(stateBits)) {
-            foundRanges.push_back(range);
+        std::vector<Range> foundRanges;
+        for (auto [offset, size] : IterateBitSetRanges(stateBits)) {
+            foundRanges.push_back({offset, size});
         }
 
-        EXPECT_EQ(expectedRanges.size(), foundRanges.size());
-        for (size_t i = 0; i < expectedRanges.size(); ++i) {
-            EXPECT_EQ(expectedRanges[i].first, foundRanges[i].first);
-            EXPECT_EQ(expectedRanges[i].second, foundRanges[i].second);
+        EXPECT_EQ(ranges.size(), foundRanges.size());
+        for (size_t i = 0; i < ranges.size(); ++i) {
+            EXPECT_EQ(ranges[i].offset, foundRanges[i].offset);
+            EXPECT_EQ(ranges[i].size, foundRanges[i].size);
         }
     }
 
     template <size_t N>
     void RunSingleBitTests() {
-        RunSingleBitSetRangeTests<N>(N / 4u * 3u, 1);
+        RunBitSetRangeTests<N>({{N / 4u * 3u, 1}});
     }
 
     template <size_t N>
@@ -147,37 +149,25 @@
 // Test basic range iteration with single bits (each range has size 1)
 TEST_F(BitSetRangeIteratorTest, SingleBit) {
     // Smaller than 1 word
-    {
-        RunSingleBitTests<kBitsPerWord - 1>();
-    }
+    RunSingleBitTests<kBitsPerWord - 1>();
 
     // Equal to 1 word
-    {
-        RunSingleBitTests<kBitsPerWord>();
-    }
+    RunSingleBitTests<kBitsPerWord>();
 
     // Larger than 1 word
-    {
-        RunSingleBitTests<kBitsPerWord * 2 - 1>();
-    }
+    RunSingleBitTests<kBitsPerWord * 2 - 1>();
 }
 
 // Test ranges with consecutive bits
 TEST_F(BitSetRangeIteratorTest, ConsecutiveBitRanges) {
     // Smaller than 1 word
-    {
-        RunConsecutiveBitRangesTests<kBitsPerWord - 1>();
-    }
+    RunConsecutiveBitRangesTests<kBitsPerWord - 1>();
 
     // Equal to 1 word
-    {
-        RunConsecutiveBitRangesTests<kBitsPerWord>();
-    }
+    RunConsecutiveBitRangesTests<kBitsPerWord>();
 
     // Larger than 1 word
-    {
-        RunConsecutiveBitRangesTests<kBitsPerWord * 2 - 1>();
-    }
+    RunConsecutiveBitRangesTests<kBitsPerWord * 2 - 1>();
 }
 
 // Test an empty iterator
@@ -190,55 +180,35 @@
 // Test iterating a result of combining two bitsets
 TEST_F(BitSetRangeIteratorTest, NonLValueBitset) {
     // Smaller than 1 word
-    {
-        RunNonLValueBitset<kBitsPerWord - 1>();
-    }
+    RunNonLValueBitset<kBitsPerWord - 1>();
 
     // Equal to 1 word
-    {
-        RunNonLValueBitset<kBitsPerWord>();
-    }
+    RunNonLValueBitset<kBitsPerWord>();
 
     // Larger than 1 word
-    {
-        RunNonLValueBitset<kBitsPerWord * 2 - 1>();
-    }
+    RunNonLValueBitset<kBitsPerWord * 2 - 1>();
 }
 
 // Test ranges that cross word boundaries
 TEST_F(BitSetRangeIteratorTest, CrossWordBoundaryRanges) {
-    std::bitset<kBitsPerWord * 2> stateBits;
-    // Set a range that crosses bit word boundary
-    for (uint32_t i = kBitsPerWord - 2; i <= kBitsPerWord + 1; ++i) {
-        stateBits.set(i);
-    }
+    // One range that crosses the boundary.
+    // RunBitSetRangeTests<kBitsPerWord * 2>({{kBitsPerWord - 2, 4}});
 
-    std::vector<std::pair<uint32_t, size_t>> foundRanges;
-    for (auto range : IterateBitSetRanges(stateBits)) {
-        foundRanges.push_back(range);
-    }
-
-    EXPECT_EQ(1u, foundRanges.size());
-    EXPECT_EQ(kBitsPerWord - 2, foundRanges[0].first);
-    EXPECT_EQ(4u, foundRanges[0].second);
+    // One range that crosses the boundary then another one.
+    RunBitSetRangeTests<kBitsPerWord * 3>(
+        {{kBitsPerWord - 1, 2 + kBitsPerWord}, {kBitsPerWord * 2 + 2, 1}});
 }
 
 // Test ranges that start from first bit.
-TEST_F(BitSetRangeIteratorTest, SingleBitSetRange) {
+TEST_F(BitSetRangeIteratorTest, RangeWithZeroethBit) {
     // Smaller than 1 word
-    {
-        RunSingleBitSetRangeTests<kBitsPerWord - 1>(0, kBitsPerWord - 1);
-    }
+    RunBitSetRangeTests<kBitsPerWord - 1>({{0, kBitsPerWord - 1}});
 
     // Equal to 1 word
-    {
-        RunSingleBitSetRangeTests<kBitsPerWord>(0, kBitsPerWord - 1);
-    }
+    RunBitSetRangeTests<kBitsPerWord>({{0, kBitsPerWord - 1}});
 
     // Larger than 1 word
-    {
-        RunSingleBitSetRangeTests<kBitsPerWord * 2>(kBitsPerWord / 2, kBitsPerWord);
-    }
+    RunBitSetRangeTests<kBitsPerWord * 2>({{kBitsPerWord / 2, kBitsPerWord}});
 }
 
 }  // anonymous namespace