Vulkan: Use ityp::bitset for Instance/DeviceExtSet

ityp::bitset allows the creation of a bitset indexed by enums. Use this
instead of our custom wrapper around bitset that only supports .Set and
.Has.

Bug: dawn:635
Change-Id: I6680feb9b1741648d974cf1cef48cb1863aa20af
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/38103
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/vulkan/BackendVk.cpp b/src/dawn_native/vulkan/BackendVk.cpp
index d725d98..2543828 100644
--- a/src/dawn_native/vulkan/BackendVk.cpp
+++ b/src/dawn_native/vulkan/BackendVk.cpp
@@ -240,8 +240,8 @@
         usedKnobs.extensions = extensionsToRequest;
 
         std::vector<const char*> extensionNames;
-        for (uint32_t ext : IterateBitSet(extensionsToRequest.extensionBitSet)) {
-            const InstanceExtInfo& info = GetInstanceExtInfo(static_cast<InstanceExt>(ext));
+        for (InstanceExt ext : IterateBitSet(extensionsToRequest)) {
+            const InstanceExtInfo& info = GetInstanceExtInfo(ext);
 
             if (info.versionPromoted > mGlobalInfo.apiVersion) {
                 extensionNames.push_back(info.name);
diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp
index 4ad8210..8b79d3e 100644
--- a/src/dawn_native/vulkan/DeviceVk.cpp
+++ b/src/dawn_native/vulkan/DeviceVk.cpp
@@ -279,8 +279,8 @@
 
         // However only request the extensions that haven't been promoted in the device's apiVersion
         std::vector<const char*> extensionNames;
-        for (uint32_t ext : IterateBitSet(usedKnobs.extensions.extensionBitSet)) {
-            const DeviceExtInfo& info = GetDeviceExtInfo(static_cast<DeviceExt>(ext));
+        for (DeviceExt ext : IterateBitSet(usedKnobs.extensions)) {
+            const DeviceExtInfo& info = GetDeviceExtInfo(ext);
 
             if (info.versionPromoted > mDeviceInfo.properties.apiVersion) {
                 extensionNames.push_back(info.name);
diff --git a/src/dawn_native/vulkan/VulkanExtensions.cpp b/src/dawn_native/vulkan/VulkanExtensions.cpp
index e12de93..05a8d60 100644
--- a/src/dawn_native/vulkan/VulkanExtensions.cpp
+++ b/src/dawn_native/vulkan/VulkanExtensions.cpp
@@ -51,14 +51,6 @@
         //
     }};
 
-    void InstanceExtSet::Set(InstanceExt extension, bool enabled) {
-        extensionBitSet.set(static_cast<uint32_t>(extension), enabled);
-    }
-
-    bool InstanceExtSet::Has(InstanceExt extension) const {
-        return extensionBitSet[static_cast<uint32_t>(extension)];
-    }
-
     const InstanceExtInfo& GetInstanceExtInfo(InstanceExt ext) {
         uint32_t index = static_cast<uint32_t>(ext);
         ASSERT(index < sInstanceExtInfos.size());
@@ -84,8 +76,8 @@
         InstanceExtSet trimmedSet;
 
         auto HasDep = [&](InstanceExt ext) -> bool {
-            ASSERT(visitedSet.Has(ext));
-            return trimmedSet.Has(ext);
+            ASSERT(visitedSet[ext]);
+            return trimmedSet[ext];
         };
 
         for (uint32_t i = 0; i < sInstanceExtInfos.size(); i++) {
@@ -117,8 +109,8 @@
                     UNREACHABLE();
             }
 
-            trimmedSet.Set(ext, hasDependencies && advertisedExts.Has(ext));
-            visitedSet.Set(ext, true);
+            trimmedSet.set(ext, hasDependencies && advertisedExts[ext]);
+            visitedSet.set(ext, true);
         }
 
         return trimmedSet;
@@ -127,7 +119,7 @@
     void MarkPromotedExtensions(InstanceExtSet* extensions, uint32_t version) {
         for (const InstanceExtInfo& info : sInstanceExtInfos) {
             if (info.versionPromoted <= version) {
-                extensions->Set(info.index, true);
+                extensions->set(info.index, true);
             }
         }
     }
@@ -167,14 +159,6 @@
         //
     }};
 
-    void DeviceExtSet::Set(DeviceExt extension, bool enabled) {
-        extensionBitSet.set(static_cast<uint32_t>(extension), enabled);
-    }
-
-    bool DeviceExtSet::Has(DeviceExt extension) const {
-        return extensionBitSet[static_cast<uint32_t>(extension)];
-    }
-
     const DeviceExtInfo& GetDeviceExtInfo(DeviceExt ext) {
         uint32_t index = static_cast<uint32_t>(ext);
         ASSERT(index < sDeviceExtInfos.size());
@@ -199,8 +183,8 @@
         DeviceExtSet trimmedSet;
 
         auto HasDep = [&](DeviceExt ext) -> bool {
-            ASSERT(visitedSet.Has(ext));
-            return trimmedSet.Has(ext);
+            ASSERT(visitedSet[ext]);
+            return trimmedSet[ext];
         };
 
         for (uint32_t i = 0; i < sDeviceExtInfos.size(); i++) {
@@ -222,16 +206,15 @@
                 // advertises the extension. So if we didn't have this check, we'd risk a calling
                 // a nullptr.
                 case DeviceExt::GetPhysicalDeviceProperties2:
-                    hasDependencies = instanceExts.Has(InstanceExt::GetPhysicalDeviceProperties2);
+                    hasDependencies = instanceExts[InstanceExt::GetPhysicalDeviceProperties2];
                     break;
                 case DeviceExt::ExternalMemoryCapabilities:
-                    hasDependencies = instanceExts.Has(InstanceExt::ExternalMemoryCapabilities) &&
+                    hasDependencies = instanceExts[InstanceExt::ExternalMemoryCapabilities] &&
                                       HasDep(DeviceExt::GetPhysicalDeviceProperties2);
                     break;
                 case DeviceExt::ExternalSemaphoreCapabilities:
-                    hasDependencies =
-                        instanceExts.Has(InstanceExt::ExternalSemaphoreCapabilities) &&
-                        HasDep(DeviceExt::GetPhysicalDeviceProperties2);
+                    hasDependencies = instanceExts[InstanceExt::ExternalSemaphoreCapabilities] &&
+                                      HasDep(DeviceExt::GetPhysicalDeviceProperties2);
                     break;
 
                 case DeviceExt::ImageDrmFormatModifier:
@@ -242,7 +225,7 @@
                     break;
 
                 case DeviceExt::Swapchain:
-                    hasDependencies = instanceExts.Has(InstanceExt::Surface);
+                    hasDependencies = instanceExts[InstanceExt::Surface];
                     break;
 
                 case DeviceExt::SamplerYCbCrConversion:
@@ -295,8 +278,8 @@
                     UNREACHABLE();
             }
 
-            trimmedSet.Set(ext, hasDependencies && advertisedExts.Has(ext));
-            visitedSet.Set(ext, true);
+            trimmedSet.set(ext, hasDependencies && advertisedExts[ext]);
+            visitedSet.set(ext, true);
         }
 
         return trimmedSet;
@@ -305,7 +288,7 @@
     void MarkPromotedExtensions(DeviceExtSet* extensions, uint32_t version) {
         for (const DeviceExtInfo& info : sDeviceExtInfos) {
             if (info.versionPromoted <= version) {
-                extensions->Set(info.index, true);
+                extensions->set(info.index, true);
             }
         }
     }
diff --git a/src/dawn_native/vulkan/VulkanExtensions.h b/src/dawn_native/vulkan/VulkanExtensions.h
index 123c579..d3950f1 100644
--- a/src/dawn_native/vulkan/VulkanExtensions.h
+++ b/src/dawn_native/vulkan/VulkanExtensions.h
@@ -15,7 +15,8 @@
 #ifndef DAWNNATIVE_VULKAN_VULKANEXTENSIONS_H_
 #define DAWNNATIVE_VULKAN_VULKANEXTENSIONS_H_
 
-#include <bitset>
+#include "common/ityp_bitset.h"
+
 #include <unordered_map>
 
 namespace dawn_native { namespace vulkan {
@@ -43,12 +44,8 @@
         EnumCount,
     };
 
-    // A bitset wrapper that is indexed with InstanceExt.
-    struct InstanceExtSet {
-        std::bitset<static_cast<size_t>(InstanceExt::EnumCount)> extensionBitSet;
-        void Set(InstanceExt extension, bool enabled);
-        bool Has(InstanceExt extension) const;
-    };
+    // A bitset that is indexed with InstanceExt.
+    using InstanceExtSet = ityp::bitset<InstanceExt, static_cast<uint32_t>(InstanceExt::EnumCount)>;
 
     // Information about a known instance extension.
     struct InstanceExtInfo {
@@ -106,14 +103,10 @@
         EnumCount,
     };
 
-    // A bitset wrapper that is indexed with DeviceExt.
-    struct DeviceExtSet {
-        std::bitset<static_cast<size_t>(DeviceExt::EnumCount)> extensionBitSet;
-        void Set(DeviceExt extension, bool enabled);
-        bool Has(DeviceExt extension) const;
-    };
+    // A bitset that is indexed with DeviceExt.
+    using DeviceExtSet = ityp::bitset<DeviceExt, static_cast<uint32_t>(DeviceExt::EnumCount)>;
 
-    // A bitset wrapper that is indexed with DeviceExt.
+    // Information about a known device extension.
     struct DeviceExtInfo {
         DeviceExt index;
         const char* name;
diff --git a/src/dawn_native/vulkan/VulkanInfo.cpp b/src/dawn_native/vulkan/VulkanInfo.cpp
index 84f3edf..365ffa7 100644
--- a/src/dawn_native/vulkan/VulkanInfo.cpp
+++ b/src/dawn_native/vulkan/VulkanInfo.cpp
@@ -52,11 +52,11 @@
     const char kLayerNameFuchsiaImagePipeSwapchain[] = "VK_LAYER_FUCHSIA_imagepipe_swapchain";
 
     bool VulkanGlobalKnobs::HasExt(InstanceExt ext) const {
-        return extensions.Has(ext);
+        return extensions[ext];
     }
 
     bool VulkanDeviceKnobs::HasExt(DeviceExt ext) const {
-        return extensions.Has(ext);
+        return extensions[ext];
     }
 
     ResultOrError<VulkanGlobalInfo> GatherGlobalInfo(const Backend& backend) {
@@ -124,7 +124,7 @@
             for (const VkExtensionProperties& extension : extensionsProperties) {
                 auto it = knownExts.find(extension.extensionName);
                 if (it != knownExts.end()) {
-                    info.extensions.Set(it->second, true);
+                    info.extensions.set(it->second, true);
                 }
             }
 
@@ -141,7 +141,7 @@
                     auto it = knownExts.find(extension.extensionName);
                     if (it != knownExts.end() &&
                         it->second == InstanceExt::FuchsiaImagePipeSurface) {
-                        info.extensions.Set(InstanceExt::FuchsiaImagePipeSurface, true);
+                        info.extensions.set(InstanceExt::FuchsiaImagePipeSurface, true);
                     }
                 }
             }
@@ -240,7 +240,7 @@
             for (const VkExtensionProperties& extension : extensionsProperties) {
                 auto it = knownExts.find(extension.extensionName);
                 if (it != knownExts.end()) {
-                    info.extensions.Set(it->second, true);
+                    info.extensions.set(it->second, true);
                 }
             }
 
@@ -262,17 +262,17 @@
         properties2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
         PNextChainBuilder propertiesChain(&properties2);
 
-        if (info.extensions.Has(DeviceExt::ShaderFloat16Int8)) {
+        if (info.extensions[DeviceExt::ShaderFloat16Int8]) {
             featuresChain.Add(&info.shaderFloat16Int8Features,
                               VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR);
         }
 
-        if (info.extensions.Has(DeviceExt::_16BitStorage)) {
+        if (info.extensions[DeviceExt::_16BitStorage]) {
             featuresChain.Add(&info._16BitStorageFeatures,
                               VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES);
         }
 
-        if (info.extensions.Has(DeviceExt::SubgroupSizeControl)) {
+        if (info.extensions[DeviceExt::SubgroupSizeControl]) {
             featuresChain.Add(&info.subgroupSizeControlFeatures,
                               VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT);
             propertiesChain.Add(
@@ -280,7 +280,7 @@
                 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_PROPERTIES_EXT);
         }
 
-        if (info.extensions.Has(DeviceExt::DriverProperties)) {
+        if (info.extensions[DeviceExt::DriverProperties]) {
             propertiesChain.Add(&info.driverProperties,
                                 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES);
         }
@@ -291,7 +291,7 @@
         // Note that info.properties has already been filled at the start of this function to get
         // `apiVersion`.
         ASSERT(info.properties.apiVersion != 0);
-        if (info.extensions.Has(DeviceExt::GetPhysicalDeviceProperties2)) {
+        if (info.extensions[DeviceExt::GetPhysicalDeviceProperties2]) {
             vkFunctions.GetPhysicalDeviceProperties2(physicalDevice, &properties2);
             vkFunctions.GetPhysicalDeviceFeatures2(physicalDevice, &features2);
             info.features = features2.features;