Typeify VertexBufferSlot and VertexAttributeLocation

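Replaces the raw uint32_t vertex buffer slot and vertex attribute
location indices with the TypedInteger types VertexBufferSlot and
VertexAttributeLocation (both backed by uint8_t), and switches the
corresponding std::bitset/std::array members to ityp::bitset and
ityp::array so indexing is checked against the matching index type.

Since arithmetic on uint8_t promotes to int, TypedInteger gains
ityp::Add and ityp::Sub helpers that assert against overflow and cast
the result back to the underlying type, while operator+/- now
static_assert when the promoted result type differs. A rough usage
sketch, mirroring the D3D12 backend change below (not itself part of
the patch):

  // Advance one slot and compute a slot count without tripping the
  // "Use ityp::Add/Sub instead" static_asserts on the small type.
  VertexBufferSlot next = ityp::Add(slot, VertexBufferSlot(uint8_t(1)));
  uint8_t count = static_cast<uint8_t>(ityp::Sub(endSlot, startSlot));
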
Bug: dawn:442
Change-Id: Ic4c29eed51984d367dc7fd6055e33d26bfc7faed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28041
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/common/Constants.h b/src/common/Constants.h
index 5a7c1df..d47afc0 100644
--- a/src/common/Constants.h
+++ b/src/common/Constants.h
@@ -18,7 +18,7 @@
 #include <cstdint>
 
 static constexpr uint32_t kMaxBindGroups = 4u;
-static constexpr uint32_t kMaxVertexAttributes = 16u;
+static constexpr uint8_t kMaxVertexAttributes = 16u;
 // Vulkan has a standalone limit named maxVertexInputAttributeOffset (2047u at least) for vertex
 // attribute offset. The limit might be meaningless because Vulkan has another limit named
 // maxVertexInputBindingStride (2048u at least). We use maxVertexAttributeEnd (2048u) here to
@@ -26,7 +26,7 @@
 // (char). We may use maxVertexInputBindingStride (maxVertexBufferStride below) instead to replace
 // maxVertexAttributeEnd in future.
 static constexpr uint32_t kMaxVertexAttributeEnd = 2048u;
-static constexpr uint32_t kMaxVertexBuffers = 16u;
+static constexpr uint8_t kMaxVertexBuffers = 16u;
 static constexpr uint32_t kMaxVertexBufferStride = 2048u;
 static constexpr uint32_t kNumStages = 3;
 static constexpr uint8_t kMaxColorAttachments = 4u;
diff --git a/src/common/TypedInteger.h b/src/common/TypedInteger.h
index 5474d9a..fd45895 100644
--- a/src/common/TypedInteger.h
+++ b/src/common/TypedInteger.h
@@ -16,6 +16,7 @@
 #define COMMON_TYPEDINTEGER_H_
 
 #include "common/Assert.h"
+#include "common/UnderlyingType.h"
 
 #include <limits>
 #include <type_traits>
@@ -131,6 +132,63 @@
         }
 
         template <typename T2 = T>
+        static constexpr std::enable_if_t<std::is_unsigned<T2>::value, decltype(T(0) + T2(0))>
+        AddImpl(TypedIntegerImpl<Tag, T> lhs, TypedIntegerImpl<Tag, T2> rhs) {
+            static_assert(std::is_same<T, T2>::value, "");
+
+            // Overflow would wrap around
+            ASSERT(lhs.mValue + rhs.mValue >= lhs.mValue);
+            return lhs.mValue + rhs.mValue;
+        }
+
+        template <typename T2 = T>
+        static constexpr std::enable_if_t<std::is_signed<T2>::value, decltype(T(0) + T2(0))>
+        AddImpl(TypedIntegerImpl<Tag, T> lhs, TypedIntegerImpl<Tag, T2> rhs) {
+            static_assert(std::is_same<T, T2>::value, "");
+
+            if (lhs.mValue > 0) {
+                // rhs is positive: |rhs| is at most the distance between max and |lhs|.
+                // rhs is negative: (positive + negative) won't overflow
+                ASSERT(rhs.mValue <= std::numeric_limits<T>::max() - lhs.mValue);
+            } else {
+                // rhs is positive: (negative + positive) won't underflow
+                // rhs is negative: |rhs| isn't less than the (negative) distance between min
+                // and |lhs|
+                ASSERT(rhs.mValue >= std::numeric_limits<T>::min() - lhs.mValue);
+            }
+            return lhs.mValue + rhs.mValue;
+        }
+
+        template <typename T2 = T>
+        static constexpr std::enable_if_t<std::is_unsigned<T>::value, decltype(T(0) - T2(0))>
+        SubImpl(TypedIntegerImpl<Tag, T> lhs, TypedIntegerImpl<Tag, T2> rhs) {
+            static_assert(std::is_same<T, T2>::value, "");
+
+            // Overflow would wrap around
+            ASSERT(lhs.mValue - rhs.mValue <= lhs.mValue);
+            return lhs.mValue - rhs.mValue;
+        }
+
+        template <typename T2 = T>
+        static constexpr std::enable_if_t<std::is_signed<T>::value, decltype(T(0) - T2(0))> SubImpl(
+            TypedIntegerImpl<Tag, T> lhs,
+            TypedIntegerImpl<Tag, T2> rhs) {
+            static_assert(std::is_same<T, T2>::value, "");
+
+            if (lhs.mValue > 0) {
+                // rhs is positive: positive minus positive won't overflow
+                // rhs is negative: |rhs| isn't less than the (negative) distance between |lhs|
+                // and max.
+                ASSERT(rhs.mValue >= lhs.mValue - std::numeric_limits<T>::max());
+            } else {
+                // rhs is positive: |rhs| is at most the distance between min and |lhs|
+                // rhs is negative: negative minus negative won't overflow
+                ASSERT(rhs.mValue <= lhs.mValue - std::numeric_limits<T>::min());
+            }
+            return lhs.mValue - rhs.mValue;
+        }
+
+        template <typename T2 = T>
         constexpr std::enable_if_t<std::is_signed<T2>::value, TypedIntegerImpl> operator-() const {
             static_assert(std::is_same<T, T2>::value, "");
             // The negation of the most negative value cannot be represented.
@@ -138,57 +196,16 @@
             return TypedIntegerImpl(-this->mValue);
         }
 
-        template <typename T2 = T>
-        constexpr std::enable_if_t<std::is_unsigned<T2>::value, TypedIntegerImpl> operator+(
-            TypedIntegerImpl rhs) const {
-            static_assert(std::is_same<T, T2>::value, "");
-            // Overflow would wrap around
-            ASSERT(this->mValue + rhs.mValue >= this->mValue);
-
-            return TypedIntegerImpl(this->mValue + rhs.mValue);
+        constexpr TypedIntegerImpl operator+(TypedIntegerImpl rhs) const {
+            auto result = AddImpl(*this, rhs);
+            static_assert(std::is_same<T, decltype(result)>::value, "Use ityp::Add instead.");
+            return TypedIntegerImpl(result);
         }
 
-        template <typename T2 = T>
-        constexpr std::enable_if_t<std::is_unsigned<T2>::value, TypedIntegerImpl> operator-(
-            TypedIntegerImpl rhs) const {
-            static_assert(std::is_same<T, T2>::value, "");
-            // Overflow would wrap around
-            ASSERT(this->mValue - rhs.mValue <= this->mValue);
-            return TypedIntegerImpl(this->mValue - rhs.mValue);
-        }
-
-        template <typename T2 = T>
-        constexpr std::enable_if_t<std::is_signed<T2>::value, TypedIntegerImpl> operator+(
-            TypedIntegerImpl rhs) const {
-            static_assert(std::is_same<T, T2>::value, "");
-            if (this->mValue > 0) {
-                // rhs is positive: |rhs| is at most the distance between max and |this|.
-                // rhs is negative: (positive + negative) won't overflow
-                ASSERT(rhs.mValue <= std::numeric_limits<T>::max() - this->mValue);
-            } else {
-                // rhs is postive: (negative + positive) won't underflow
-                // rhs is negative: |rhs| isn't less than the (negative) distance between min
-                // and |this|
-                ASSERT(rhs.mValue >= std::numeric_limits<T>::min() - this->mValue);
-            }
-            return TypedIntegerImpl(this->mValue + rhs.mValue);
-        }
-
-        template <typename T2 = T>
-        constexpr std::enable_if_t<std::is_signed<T2>::value, TypedIntegerImpl> operator-(
-            TypedIntegerImpl rhs) const {
-            static_assert(std::is_same<T, T2>::value, "");
-            if (this->mValue > 0) {
-                // rhs is positive: positive minus positive won't overflow
-                // rhs is negative: |rhs| isn't less than the (negative) distance between |this|
-                // and max.
-                ASSERT(rhs.mValue >= this->mValue - std::numeric_limits<T>::max());
-            } else {
-                // rhs is positive: |rhs| is at most the distance between min and |this|
-                // rhs is negative: negative minus negative won't overflow
-                ASSERT(rhs.mValue <= this->mValue - std::numeric_limits<T>::min());
-            }
-            return TypedIntegerImpl(this->mValue - rhs.mValue);
+        constexpr TypedIntegerImpl operator-(TypedIntegerImpl rhs) const {
+            auto result = SubImpl(*this, rhs);
+            static_assert(std::is_same<T, decltype(result)>::value, "Use ityp::Sub instead.");
+            return TypedIntegerImpl(result);
         }
     };
 
@@ -209,4 +226,37 @@
 
 }  // namespace std
 
+namespace ityp {
+
+    // These helpers are provided because the default arithmetic operators on small integer
+    // types like uint8_t and uint16_t promote to int rather than returning the same type. To
+    // avoid lots of casting or conditional code between Release/Debug, callsites should use
+    // ityp::Add(a, b) and ityp::Sub(a, b) instead.
+
+    template <typename Tag, typename T>
+    constexpr ::detail::TypedIntegerImpl<Tag, T> Add(::detail::TypedIntegerImpl<Tag, T> lhs,
+                                                     ::detail::TypedIntegerImpl<Tag, T> rhs) {
+        return ::detail::TypedIntegerImpl<Tag, T>(
+            static_cast<T>(::detail::TypedIntegerImpl<Tag, T>::AddImpl(lhs, rhs)));
+    }
+
+    template <typename Tag, typename T>
+    constexpr ::detail::TypedIntegerImpl<Tag, T> Sub(::detail::TypedIntegerImpl<Tag, T> lhs,
+                                                     ::detail::TypedIntegerImpl<Tag, T> rhs) {
+        return ::detail::TypedIntegerImpl<Tag, T>(
+            static_cast<T>(::detail::TypedIntegerImpl<Tag, T>::SubImpl(lhs, rhs)));
+    }
+
+    template <typename T>
+    constexpr std::enable_if_t<std::is_integral<T>::value, T> Add(T lhs, T rhs) {
+        return static_cast<T>(lhs + rhs);
+    }
+
+    template <typename T>
+    constexpr std::enable_if_t<std::is_integral<T>::value, T> Sub(T lhs, T rhs) {
+        return static_cast<T>(lhs - rhs);
+    }
+
+}  // namespace ityp
+
 #endif  // COMMON_TYPEDINTEGER_H_
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 6372a6d..4e5b54f 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -119,7 +119,7 @@
         if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
             ASSERT(mLastRenderPipeline != nullptr);
 
-            const std::bitset<kMaxVertexBuffers>& requiredVertexBuffers =
+            const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& requiredVertexBuffers =
                 mLastRenderPipeline->GetVertexBufferSlotsUsed();
             if ((mVertexBufferSlotsUsed & requiredVertexBuffers) == requiredVertexBuffers) {
                 mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
@@ -230,7 +230,7 @@
         mIndexFormat = format;
     }
 
-    void CommandBufferStateTracker::SetVertexBuffer(uint32_t slot) {
+    void CommandBufferStateTracker::SetVertexBuffer(VertexBufferSlot slot) {
         mVertexBufferSlotsUsed.set(slot);
     }
 
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 146214d..ac02330 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -17,11 +17,11 @@
 
 #include "common/Constants.h"
 #include "common/ityp_array.h"
+#include "common/ityp_bitset.h"
 #include "dawn_native/BindingInfo.h"
 #include "dawn_native/Error.h"
 #include "dawn_native/Forward.h"
 
-#include <bitset>
 #include <map>
 #include <set>
 
@@ -39,7 +39,7 @@
         void SetRenderPipeline(RenderPipelineBase* pipeline);
         void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
         void SetIndexBuffer(wgpu::IndexFormat format);
-        void SetVertexBuffer(uint32_t slot);
+        void SetVertexBuffer(VertexBufferSlot slot);
 
         static constexpr size_t kNumAspects = 4;
         using ValidationAspects = std::bitset<kNumAspects>;
@@ -54,7 +54,7 @@
         ValidationAspects mAspects;
 
         ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
-        std::bitset<kMaxVertexBuffers> mVertexBufferSlotsUsed;
+        ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
         bool mIndexBufferSet = false;
         wgpu::IndexFormat mIndexFormat;
 
diff --git a/src/dawn_native/Commands.h b/src/dawn_native/Commands.h
index 3d34693..d78ff13 100644
--- a/src/dawn_native/Commands.h
+++ b/src/dawn_native/Commands.h
@@ -235,7 +235,7 @@
     };
 
     struct SetVertexBufferCmd {
-        uint32_t slot;
+        VertexBufferSlot slot;
         Ref<BufferBase> buffer;
         uint64_t offset;
         uint64_t size;
diff --git a/src/dawn_native/IntegerTypes.h b/src/dawn_native/IntegerTypes.h
index 689cf8e..bd31c79 100644
--- a/src/dawn_native/IntegerTypes.h
+++ b/src/dawn_native/IntegerTypes.h
@@ -36,6 +36,13 @@
     constexpr ColorAttachmentIndex kMaxColorAttachmentsTyped =
         ColorAttachmentIndex(kMaxColorAttachments);
 
+    using VertexBufferSlot = TypedInteger<struct VertexBufferSlotT, uint8_t>;
+    using VertexAttributeLocation = TypedInteger<struct VertexAttributeLocationT, uint8_t>;
+
+    constexpr VertexBufferSlot kMaxVertexBuffersTyped = VertexBufferSlot(kMaxVertexBuffers);
+    constexpr VertexAttributeLocation kMaxVertexAttributesTyped =
+        VertexAttributeLocation(kMaxVertexAttributes);
+
 }  // namespace dawn_native
 
 #endif  // DAWNNATIVE_INTEGERTYPES_H_
diff --git a/src/dawn_native/RenderEncoderBase.cpp b/src/dawn_native/RenderEncoderBase.cpp
index b25e6cc..1c2cb62 100644
--- a/src/dawn_native/RenderEncoderBase.cpp
+++ b/src/dawn_native/RenderEncoderBase.cpp
@@ -215,7 +215,7 @@
 
             SetVertexBufferCmd* cmd =
                 allocator->Allocate<SetVertexBufferCmd>(Command::SetVertexBuffer);
-            cmd->slot = slot;
+            cmd->slot = VertexBufferSlot(static_cast<uint8_t>(slot));
             cmd->buffer = buffer;
             cmd->offset = offset;
             cmd->size = size;
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index cbe99d0..7144f33 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -412,21 +412,24 @@
             mVertexState = VertexStateDescriptor();
         }
 
-        for (uint32_t slot = 0; slot < mVertexState.vertexBufferCount; ++slot) {
+        for (uint8_t slot = 0; slot < mVertexState.vertexBufferCount; ++slot) {
             if (mVertexState.vertexBuffers[slot].attributeCount == 0) {
                 continue;
             }
 
-            mVertexBufferSlotsUsed.set(slot);
-            mVertexBufferInfos[slot].arrayStride = mVertexState.vertexBuffers[slot].arrayStride;
-            mVertexBufferInfos[slot].stepMode = mVertexState.vertexBuffers[slot].stepMode;
+            VertexBufferSlot typedSlot(slot);
 
-            uint32_t location = 0;
+            mVertexBufferSlotsUsed.set(typedSlot);
+            mVertexBufferInfos[typedSlot].arrayStride =
+                mVertexState.vertexBuffers[slot].arrayStride;
+            mVertexBufferInfos[typedSlot].stepMode = mVertexState.vertexBuffers[slot].stepMode;
+
             for (uint32_t i = 0; i < mVertexState.vertexBuffers[slot].attributeCount; ++i) {
-                location = mVertexState.vertexBuffers[slot].attributes[i].shaderLocation;
+                VertexAttributeLocation location = VertexAttributeLocation(static_cast<uint8_t>(
+                    mVertexState.vertexBuffers[slot].attributes[i].shaderLocation));
                 mAttributeLocationsUsed.set(location);
                 mAttributeInfos[location].shaderLocation = location;
-                mAttributeInfos[location].vertexBufferSlot = slot;
+                mAttributeInfos[location].vertexBufferSlot = typedSlot;
                 mAttributeInfos[location].offset =
                     mVertexState.vertexBuffers[slot].attributes[i].offset;
                 mAttributeInfos[location].format =
@@ -489,23 +492,26 @@
         return &mVertexState;
     }
 
-    const std::bitset<kMaxVertexAttributes>& RenderPipelineBase::GetAttributeLocationsUsed() const {
+    const ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>&
+    RenderPipelineBase::GetAttributeLocationsUsed() const {
         ASSERT(!IsError());
         return mAttributeLocationsUsed;
     }
 
-    const VertexAttributeInfo& RenderPipelineBase::GetAttribute(uint32_t location) const {
+    const VertexAttributeInfo& RenderPipelineBase::GetAttribute(
+        VertexAttributeLocation location) const {
         ASSERT(!IsError());
         ASSERT(mAttributeLocationsUsed[location]);
         return mAttributeInfos[location];
     }
 
-    const std::bitset<kMaxVertexBuffers>& RenderPipelineBase::GetVertexBufferSlotsUsed() const {
+    const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
+    RenderPipelineBase::GetVertexBufferSlotsUsed() const {
         ASSERT(!IsError());
         return mVertexBufferSlotsUsed;
     }
 
-    const VertexBufferInfo& RenderPipelineBase::GetVertexBuffer(uint32_t slot) const {
+    const VertexBufferInfo& RenderPipelineBase::GetVertexBuffer(VertexBufferSlot slot) const {
         ASSERT(!IsError());
         ASSERT(mVertexBufferSlotsUsed[slot]);
         return mVertexBufferInfos[slot];
@@ -582,12 +588,6 @@
         return mAttachmentState.Get();
     }
 
-    std::bitset<kMaxVertexAttributes> RenderPipelineBase::GetAttributesUsingVertexBuffer(
-        uint32_t slot) const {
-        ASSERT(!IsError());
-        return attributesUsingVertexBuffer[slot];
-    }
-
     size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
         // Hash modules and layout
         size_t hash = PipelineBase::HashForCache(pipeline);
@@ -619,15 +619,15 @@
 
         // Hash vertex state
         HashCombine(&hash, pipeline->mAttributeLocationsUsed);
-        for (uint32_t i : IterateBitSet(pipeline->mAttributeLocationsUsed)) {
-            const VertexAttributeInfo& desc = pipeline->GetAttribute(i);
+        for (VertexAttributeLocation location : IterateBitSet(pipeline->mAttributeLocationsUsed)) {
+            const VertexAttributeInfo& desc = pipeline->GetAttribute(location);
             HashCombine(&hash, desc.shaderLocation, desc.vertexBufferSlot, desc.offset,
                         desc.format);
         }
 
         HashCombine(&hash, pipeline->mVertexBufferSlotsUsed);
-        for (uint32_t i : IterateBitSet(pipeline->mVertexBufferSlotsUsed)) {
-            const VertexBufferInfo& desc = pipeline->GetVertexBuffer(i);
+        for (VertexBufferSlot slot : IterateBitSet(pipeline->mVertexBufferSlotsUsed)) {
+            const VertexBufferInfo& desc = pipeline->GetVertexBuffer(slot);
             HashCombine(&hash, desc.arrayStride, desc.stepMode);
         }
 
@@ -709,9 +709,9 @@
             return false;
         }
 
-        for (uint32_t i : IterateBitSet(a->mAttributeLocationsUsed)) {
-            const VertexAttributeInfo& descA = a->GetAttribute(i);
-            const VertexAttributeInfo& descB = b->GetAttribute(i);
+        for (VertexAttributeLocation loc : IterateBitSet(a->mAttributeLocationsUsed)) {
+            const VertexAttributeInfo& descA = a->GetAttribute(loc);
+            const VertexAttributeInfo& descB = b->GetAttribute(loc);
             if (descA.shaderLocation != descB.shaderLocation ||
                 descA.vertexBufferSlot != descB.vertexBufferSlot || descA.offset != descB.offset ||
                 descA.format != descB.format) {
@@ -723,9 +723,9 @@
             return false;
         }
 
-        for (uint32_t i : IterateBitSet(a->mVertexBufferSlotsUsed)) {
-            const VertexBufferInfo& descA = a->GetVertexBuffer(i);
-            const VertexBufferInfo& descB = b->GetVertexBuffer(i);
+        for (VertexBufferSlot slot : IterateBitSet(a->mVertexBufferSlotsUsed)) {
+            const VertexBufferInfo& descA = a->GetVertexBuffer(slot);
+            const VertexBufferInfo& descB = b->GetVertexBuffer(slot);
             if (descA.arrayStride != descB.arrayStride || descA.stepMode != descB.stepMode) {
                 return false;
             }
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index a5d9cca..817a3e5 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -15,7 +15,9 @@
 #ifndef DAWNNATIVE_RENDERPIPELINE_H_
 #define DAWNNATIVE_RENDERPIPELINE_H_
 
+#include "common/TypedInteger.h"
 #include "dawn_native/AttachmentState.h"
+#include "dawn_native/IntegerTypes.h"
 #include "dawn_native/Pipeline.h"
 
 #include "dawn_native/dawn_platform.h"
@@ -45,8 +47,8 @@
     struct VertexAttributeInfo {
         wgpu::VertexFormat format;
         uint64_t offset;
-        uint32_t shaderLocation;
-        uint32_t vertexBufferSlot;
+        VertexAttributeLocation shaderLocation;
+        VertexBufferSlot vertexBufferSlot;
     };
 
     struct VertexBufferInfo {
@@ -62,10 +64,11 @@
         static RenderPipelineBase* MakeError(DeviceBase* device);
 
         const VertexStateDescriptor* GetVertexStateDescriptor() const;
-        const std::bitset<kMaxVertexAttributes>& GetAttributeLocationsUsed() const;
-        const VertexAttributeInfo& GetAttribute(uint32_t location) const;
-        const std::bitset<kMaxVertexBuffers>& GetVertexBufferSlotsUsed() const;
-        const VertexBufferInfo& GetVertexBuffer(uint32_t slot) const;
+        const ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>&
+        GetAttributeLocationsUsed() const;
+        const VertexAttributeInfo& GetAttribute(VertexAttributeLocation location) const;
+        const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& GetVertexBufferSlotsUsed() const;
+        const VertexBufferInfo& GetVertexBuffer(VertexBufferSlot slot) const;
 
         const ColorStateDescriptor* GetColorStateDescriptor(
             ColorAttachmentIndex attachmentSlot) const;
@@ -84,10 +87,6 @@
 
         const AttachmentState* GetAttachmentState() const;
 
-        std::bitset<kMaxVertexAttributes> GetAttributesUsingVertexBuffer(uint32_t slot) const;
-        std::array<std::bitset<kMaxVertexAttributes>, kMaxVertexBuffers>
-            attributesUsingVertexBuffer;
-
         // Functors necessary for the unordered_set<RenderPipelineBase*>-based cache.
         struct HashFunc {
             size_t operator()(const RenderPipelineBase* pipeline) const;
@@ -101,10 +100,11 @@
 
         // Vertex state
         VertexStateDescriptor mVertexState;
-        std::bitset<kMaxVertexAttributes> mAttributeLocationsUsed;
-        std::array<VertexAttributeInfo, kMaxVertexAttributes> mAttributeInfos;
-        std::bitset<kMaxVertexBuffers> mVertexBufferSlotsUsed;
-        std::array<VertexBufferInfo, kMaxVertexBuffers> mVertexBufferInfos;
+        ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> mAttributeLocationsUsed;
+        ityp::array<VertexAttributeLocation, VertexAttributeInfo, kMaxVertexAttributes>
+            mAttributeInfos;
+        ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
+        ityp::array<VertexBufferSlot, VertexBufferInfo, kMaxVertexBuffers> mVertexBufferInfos;
 
         // Attachments
         Ref<AttachmentState> mAttachmentState;
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index fb7e08e..90999c1 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -452,9 +452,12 @@
     namespace {
         class VertexBufferTracker {
           public:
-            void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset, uint64_t size) {
+            void OnSetVertexBuffer(VertexBufferSlot slot,
+                                   Buffer* buffer,
+                                   uint64_t offset,
+                                   uint64_t size) {
                 mStartSlot = std::min(mStartSlot, slot);
-                mEndSlot = std::max(mEndSlot, slot + 1);
+                mEndSlot = std::max(mEndSlot, ityp::Add(slot, VertexBufferSlot(uint8_t(1))));
 
                 auto* d3d12BufferView = &mD3D12BufferViews[slot];
                 d3d12BufferView->BufferLocation = buffer->GetVA() + offset;
@@ -466,11 +469,8 @@
                        const RenderPipeline* renderPipeline) {
                 ASSERT(renderPipeline != nullptr);
 
-                std::bitset<kMaxVertexBuffers> vertexBufferSlotsUsed =
-                    renderPipeline->GetVertexBufferSlotsUsed();
-
-                uint32_t startSlot = mStartSlot;
-                uint32_t endSlot = mEndSlot;
+                VertexBufferSlot startSlot = mStartSlot;
+                VertexBufferSlot endSlot = mEndSlot;
 
                 // If the vertex state has changed, we need to update the StrideInBytes
                 // for the D3D12 buffer views. We also need to extend the dirty range to
@@ -478,9 +478,11 @@
                 if (mLastAppliedRenderPipeline != renderPipeline) {
                     mLastAppliedRenderPipeline = renderPipeline;
 
-                    for (uint32_t slot : IterateBitSet(vertexBufferSlotsUsed)) {
+                    for (VertexBufferSlot slot :
+                         IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
                         startSlot = std::min(startSlot, slot);
-                        endSlot = std::max(endSlot, slot + 1);
+                        endSlot =
+                            std::max(endSlot, ityp::Add(slot, VertexBufferSlot(uint8_t(1))));
                         mD3D12BufferViews[slot].StrideInBytes =
                             renderPipeline->GetVertexBuffer(slot).arrayStride;
                     }
@@ -494,11 +496,12 @@
                 // to SetVertexBuffer. This makes it correct to only track the start
                 // and end of the dirty range. When Apply is called,
                 // we will at worst set non-dirty vertex buffers in duplicate.
-                uint32_t count = endSlot - startSlot;
-                commandList->IASetVertexBuffers(startSlot, count, &mD3D12BufferViews[startSlot]);
+                commandList->IASetVertexBuffers(static_cast<uint8_t>(startSlot),
+                                                static_cast<uint8_t>(ityp::Sub(endSlot, startSlot)),
+                                                &mD3D12BufferViews[startSlot]);
 
-                mStartSlot = kMaxVertexBuffers;
-                mEndSlot = 0;
+                mStartSlot = VertexBufferSlot(kMaxVertexBuffers);
+                mEndSlot = VertexBufferSlot(uint8_t(0));
             }
 
           private:
@@ -507,9 +510,10 @@
             // represent the union of the dirty ranges (the union may have non-dirty
             // data in the middle of the range).
             const RenderPipeline* mLastAppliedRenderPipeline = nullptr;
-            uint32_t mStartSlot = kMaxVertexBuffers;
-            uint32_t mEndSlot = 0;
-            std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> mD3D12BufferViews = {};
+            VertexBufferSlot mStartSlot{kMaxVertexBuffers};
+            VertexBufferSlot mEndSlot{uint8_t(0)};
+            ityp::array<VertexBufferSlot, D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers>
+                mD3D12BufferViews = {};
         };
 
         class IndexBufferTracker {
diff --git a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
index 52ec371..1187f75 100644
--- a/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/RenderPipelineD3D12.cpp
@@ -419,17 +419,17 @@
     D3D12_INPUT_LAYOUT_DESC RenderPipeline::ComputeInputLayout(
         std::array<D3D12_INPUT_ELEMENT_DESC, kMaxVertexAttributes>* inputElementDescriptors) {
         unsigned int count = 0;
-        for (auto i : IterateBitSet(GetAttributeLocationsUsed())) {
+        for (VertexAttributeLocation loc : IterateBitSet(GetAttributeLocationsUsed())) {
             D3D12_INPUT_ELEMENT_DESC& inputElementDescriptor = (*inputElementDescriptors)[count++];
 
-            const VertexAttributeInfo& attribute = GetAttribute(i);
+            const VertexAttributeInfo& attribute = GetAttribute(loc);
 
             // If the HLSL semantic is TEXCOORDN the SemanticName should be "TEXCOORD" and the
             // SemanticIndex N
             inputElementDescriptor.SemanticName = "TEXCOORD";
-            inputElementDescriptor.SemanticIndex = static_cast<uint32_t>(i);
+            inputElementDescriptor.SemanticIndex = static_cast<uint8_t>(loc);
             inputElementDescriptor.Format = VertexFormatType(attribute.format);
-            inputElementDescriptor.InputSlot = attribute.vertexBufferSlot;
+            inputElementDescriptor.InputSlot = static_cast<uint8_t>(attribute.vertexBufferSlot);
 
             const VertexBufferInfo& input = GetVertexBuffer(attribute.vertexBufferSlot);
 
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index e0ca8f8..3b0c7bd 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -506,16 +506,13 @@
                 : mLengthTracker(lengthTracker) {
             }
 
-            void OnSetVertexBuffer(uint32_t slot, Buffer* buffer, uint64_t offset) {
+            void OnSetVertexBuffer(VertexBufferSlot slot, Buffer* buffer, uint64_t offset) {
                 mVertexBuffers[slot] = buffer->GetMTLBuffer();
                 mVertexBufferOffsets[slot] = offset;
 
                 ASSERT(buffer->GetSize() < std::numeric_limits<uint32_t>::max());
                 mVertexBufferBindingSizes[slot] = static_cast<uint32_t>(buffer->GetSize() - offset);
-
-                // Use 64 bit masks and make sure there are no shift UB
-                static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
-                mDirtyVertexBuffers |= 1ull << slot;
+                mDirtyVertexBuffers.set(slot);
             }
 
             void OnSetPipeline(RenderPipeline* lastPipeline, RenderPipeline* pipeline) {
@@ -528,21 +525,21 @@
             void Apply(id<MTLRenderCommandEncoder> encoder,
                        RenderPipeline* pipeline,
                        bool enableVertexPulling) {
-                std::bitset<kMaxVertexBuffers> vertexBuffersToApply =
+                const auto& vertexBuffersToApply =
                     mDirtyVertexBuffers & pipeline->GetVertexBufferSlotsUsed();
 
-                for (uint32_t dawnIndex : IterateBitSet(vertexBuffersToApply)) {
-                    uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(dawnIndex);
+                for (VertexBufferSlot slot : IterateBitSet(vertexBuffersToApply)) {
+                    uint32_t metalIndex = pipeline->GetMtlVertexBufferIndex(slot);
 
                     if (enableVertexPulling) {
                         // Insert lengths for vertex buffers bound as storage buffers
                         mLengthTracker->data[SingleShaderStage::Vertex][metalIndex] =
-                            mVertexBufferBindingSizes[dawnIndex];
+                            mVertexBufferBindingSizes[slot];
                         mLengthTracker->dirtyStages |= wgpu::ShaderStage::Vertex;
                     }
 
-                    [encoder setVertexBuffers:&mVertexBuffers[dawnIndex]
-                                      offsets:&mVertexBufferOffsets[dawnIndex]
+                    [encoder setVertexBuffers:&mVertexBuffers[slot]
+                                      offsets:&mVertexBufferOffsets[slot]
                                     withRange:NSMakeRange(metalIndex, 1)];
                 }
 
@@ -551,10 +548,10 @@
 
           private:
             // All the indices in these arrays are Dawn vertex buffer indices
-            std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
-            std::array<id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
-            std::array<NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
-            std::array<uint32_t, kMaxVertexBuffers> mVertexBufferBindingSizes;
+            ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mDirtyVertexBuffers;
+            ityp::array<VertexBufferSlot, id<MTLBuffer>, kMaxVertexBuffers> mVertexBuffers;
+            ityp::array<VertexBufferSlot, NSUInteger, kMaxVertexBuffers> mVertexBufferOffsets;
+            ityp::array<VertexBufferSlot, uint32_t, kMaxVertexBuffers> mVertexBufferBindingSizes;
 
             StorageBufferLengthTracker* mLengthTracker;
         };
diff --git a/src/dawn_native/metal/RenderPipelineMTL.h b/src/dawn_native/metal/RenderPipelineMTL.h
index 4d8656b..80f3bba 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.h
+++ b/src/dawn_native/metal/RenderPipelineMTL.h
@@ -38,7 +38,7 @@
 
         // For each Dawn vertex buffer, give the index in which it will be positioned in the Metal
         // vertex buffer table.
-        uint32_t GetMtlVertexBufferIndex(uint32_t dawnIndex) const;
+        uint32_t GetMtlVertexBufferIndex(VertexBufferSlot slot) const;
 
         wgpu::ShaderStage GetStagesRequiringStorageBufferLength() const;
 
@@ -54,7 +54,7 @@
         MTLCullMode mMtlCullMode;
         id<MTLRenderPipelineState> mMtlRenderPipelineState = nil;
         id<MTLDepthStencilState> mMtlDepthStencilState = nil;
-        std::array<uint32_t, kMaxVertexBuffers> mMtlVertexBufferIndices;
+        ityp::array<VertexBufferSlot, uint32_t, kMaxVertexBuffers> mMtlVertexBufferIndices;
 
         wgpu::ShaderStage mStagesRequiringStorageBufferLength = wgpu::ShaderStage::None;
     };
diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm
index ee5fd04..48e6f13 100644
--- a/src/dawn_native/metal/RenderPipelineMTL.mm
+++ b/src/dawn_native/metal/RenderPipelineMTL.mm
@@ -431,9 +431,9 @@
         return mMtlDepthStencilState;
     }
 
-    uint32_t RenderPipeline::GetMtlVertexBufferIndex(uint32_t dawnIndex) const {
-        ASSERT(dawnIndex < kMaxVertexBuffers);
-        return mMtlVertexBufferIndices[dawnIndex];
+    uint32_t RenderPipeline::GetMtlVertexBufferIndex(VertexBufferSlot slot) const {
+        ASSERT(slot < kMaxVertexBuffersTyped);
+        return mMtlVertexBufferIndices[slot];
     }
 
     wgpu::ShaderStage RenderPipeline::GetStagesRequiringStorageBufferLength() const {
@@ -447,8 +447,8 @@
         uint32_t mtlVertexBufferIndex =
             ToBackend(GetLayout())->GetBufferBindingCount(SingleShaderStage::Vertex);
 
-        for (uint32_t dawnVertexBufferSlot : IterateBitSet(GetVertexBufferSlotsUsed())) {
-            const VertexBufferInfo& info = GetVertexBuffer(dawnVertexBufferSlot);
+        for (VertexBufferSlot slot : IterateBitSet(GetVertexBufferSlotsUsed())) {
+            const VertexBufferInfo& info = GetVertexBuffer(slot);
 
             MTLVertexBufferLayoutDescriptor* layoutDesc = [MTLVertexBufferLayoutDescriptor new];
             if (info.arrayStride == 0) {
@@ -456,10 +456,10 @@
                 // but the arrayStride must NOT be 0, so we made up it with
                 // max(attrib.offset + sizeof(attrib) for each attrib)
                 size_t maxArrayStride = 0;
-                for (uint32_t attribIndex : IterateBitSet(GetAttributeLocationsUsed())) {
-                    const VertexAttributeInfo& attrib = GetAttribute(attribIndex);
+                for (VertexAttributeLocation loc : IterateBitSet(GetAttributeLocationsUsed())) {
+                    const VertexAttributeInfo& attrib = GetAttribute(loc);
                     // Only use the attributes that use the current input
-                    if (attrib.vertexBufferSlot != dawnVertexBufferSlot) {
+                    if (attrib.vertexBufferSlot != slot) {
                         continue;
                     }
                     maxArrayStride = std::max(
@@ -479,18 +479,18 @@
             mtlVertexDescriptor.layouts[mtlVertexBufferIndex] = layoutDesc;
             [layoutDesc release];
 
-            mMtlVertexBufferIndices[dawnVertexBufferSlot] = mtlVertexBufferIndex;
+            mMtlVertexBufferIndices[slot] = mtlVertexBufferIndex;
             mtlVertexBufferIndex++;
         }
 
-        for (uint32_t i : IterateBitSet(GetAttributeLocationsUsed())) {
-            const VertexAttributeInfo& info = GetAttribute(i);
+        for (VertexAttributeLocation loc : IterateBitSet(GetAttributeLocationsUsed())) {
+            const VertexAttributeInfo& info = GetAttribute(loc);
 
             auto attribDesc = [MTLVertexAttributeDescriptor new];
             attribDesc.format = VertexFormatType(info.format);
             attribDesc.offset = info.offset;
             attribDesc.bufferIndex = mMtlVertexBufferIndices[info.vertexBufferSlot];
-            mtlVertexDescriptor.attributes[i] = attribDesc;
+            mtlVertexDescriptor.attributes[static_cast<uint8_t>(loc)] = attribDesc;
             [attribDesc release];
         }
 
diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm
index 2588aa0..ca19ae1 100644
--- a/src/dawn_native/metal/ShaderModuleMTL.mm
+++ b/src/dawn_native/metal/ShaderModuleMTL.mm
@@ -124,14 +124,15 @@
         // Add vertex buffers bound as storage buffers
         if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
             stage == SingleShaderStage::Vertex) {
-            for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
-                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
+            for (VertexBufferSlot slot :
+                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
+                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
 
                 spirv_cross::MSLResourceBinding mslBinding;
 
                 mslBinding.stage = spv::ExecutionModelVertex;
                 mslBinding.desc_set = kPullingBufferBindingSet;
-                mslBinding.binding = dawnIndex;
+                mslBinding.binding = static_cast<uint8_t>(slot);
                 mslBinding.msl_buffer = metalIndex;
                 compiler.add_msl_resource_binding(mslBinding);
             }
diff --git a/src/dawn_native/opengl/CommandBufferGL.cpp b/src/dawn_native/opengl/CommandBufferGL.cpp
index aa17ff5..4c64ec2 100644
--- a/src/dawn_native/opengl/CommandBufferGL.cpp
+++ b/src/dawn_native/opengl/CommandBufferGL.cpp
@@ -142,13 +142,10 @@
                 mIndexBuffer = ToBackend(buffer);
             }
 
-            void OnSetVertexBuffer(uint32_t slot, BufferBase* buffer, uint64_t offset) {
+            void OnSetVertexBuffer(VertexBufferSlot slot, BufferBase* buffer, uint64_t offset) {
                 mVertexBuffers[slot] = ToBackend(buffer);
                 mVertexBufferOffsets[slot] = offset;
-
-                // Use 64 bit masks and make sure there are no shift UB
-                static_assert(kMaxVertexBuffers <= 8 * sizeof(unsigned long long) - 1, "");
-                mDirtyVertexBuffers |= 1ull << slot;
+                mDirtyVertexBuffers.set(slot);
             }
 
             void OnSetPipeline(RenderPipelineBase* pipeline) {
@@ -168,13 +165,14 @@
                     mIndexBufferDirty = false;
                 }
 
-                for (uint32_t slot : IterateBitSet(mDirtyVertexBuffers &
-                                                   mLastPipeline->GetVertexBufferSlotsUsed())) {
-                    for (uint32_t location :
-                         IterateBitSet(mLastPipeline->GetAttributesUsingVertexBuffer(slot))) {
+                for (VertexBufferSlot slot : IterateBitSet(
+                         mDirtyVertexBuffers & mLastPipeline->GetVertexBufferSlotsUsed())) {
+                    for (VertexAttributeLocation location : IterateBitSet(
+                             ToBackend(mLastPipeline)->GetAttributesUsingVertexBuffer(slot))) {
                         const VertexAttributeInfo& attribute =
                             mLastPipeline->GetAttribute(location);
 
+                        GLuint attribIndex = static_cast<GLuint>(static_cast<uint8_t>(location));
                         GLuint buffer = mVertexBuffers[slot]->GetHandle();
                         uint64_t offset = mVertexBufferOffsets[slot];
 
@@ -186,11 +184,11 @@
                         gl.BindBuffer(GL_ARRAY_BUFFER, buffer);
                         if (VertexFormatIsInt(attribute.format)) {
                             gl.VertexAttribIPointer(
-                                location, components, formatType, vertexBuffer.arrayStride,
+                                attribIndex, components, formatType, vertexBuffer.arrayStride,
                                 reinterpret_cast<void*>(
                                     static_cast<intptr_t>(offset + attribute.offset)));
                         } else {
-                            gl.VertexAttribPointer(location, components, formatType, normalized,
+                            gl.VertexAttribPointer(attribIndex, components, formatType, normalized,
                                                    vertexBuffer.arrayStride,
                                                    reinterpret_cast<void*>(static_cast<intptr_t>(
                                                        offset + attribute.offset)));
@@ -205,9 +203,9 @@
             bool mIndexBufferDirty = false;
             Buffer* mIndexBuffer = nullptr;
 
-            std::bitset<kMaxVertexBuffers> mDirtyVertexBuffers;
-            std::array<Buffer*, kMaxVertexBuffers> mVertexBuffers;
-            std::array<uint64_t, kMaxVertexBuffers> mVertexBufferOffsets;
+            ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mDirtyVertexBuffers;
+            ityp::array<VertexBufferSlot, Buffer*, kMaxVertexBuffers> mVertexBuffers;
+            ityp::array<VertexBufferSlot, uint64_t, kMaxVertexBuffers> mVertexBufferOffsets;
 
             RenderPipelineBase* mLastPipeline = nullptr;
         };
diff --git a/src/dawn_native/opengl/RenderPipelineGL.cpp b/src/dawn_native/opengl/RenderPipelineGL.cpp
index 5130e59..f176cd6 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.cpp
+++ b/src/dawn_native/opengl/RenderPipelineGL.cpp
@@ -218,29 +218,36 @@
         return mGlPrimitiveTopology;
     }
 
+    ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>
+    RenderPipeline::GetAttributesUsingVertexBuffer(VertexBufferSlot slot) const {
+        ASSERT(!IsError());
+        return mAttributesUsingVertexBuffer[slot];
+    }
+
     void RenderPipeline::CreateVAOForVertexState(const VertexStateDescriptor* vertexState) {
         const OpenGLFunctions& gl = ToBackend(GetDevice())->gl;
 
         gl.GenVertexArrays(1, &mVertexArrayObject);
         gl.BindVertexArray(mVertexArrayObject);
 
-        for (uint32_t location : IterateBitSet(GetAttributeLocationsUsed())) {
+        for (VertexAttributeLocation location : IterateBitSet(GetAttributeLocationsUsed())) {
             const auto& attribute = GetAttribute(location);
-            gl.EnableVertexAttribArray(location);
+            GLuint glAttrib = static_cast<GLuint>(static_cast<uint8_t>(location));
+            gl.EnableVertexAttribArray(glAttrib);
 
-            attributesUsingVertexBuffer[attribute.vertexBufferSlot][location] = true;
+            mAttributesUsingVertexBuffer[attribute.vertexBufferSlot][location] = true;
             const VertexBufferInfo& vertexBuffer = GetVertexBuffer(attribute.vertexBufferSlot);
 
             if (vertexBuffer.arrayStride == 0) {
                 // Emulate a stride of zero (constant vertex attribute) by
                 // setting the attribute instance divisor to a huge number.
-                gl.VertexAttribDivisor(location, 0xffffffff);
+                gl.VertexAttribDivisor(glAttrib, 0xffffffff);
             } else {
                 switch (vertexBuffer.stepMode) {
                     case wgpu::InputStepMode::Vertex:
                         break;
                     case wgpu::InputStepMode::Instance:
-                        gl.VertexAttribDivisor(location, 1);
+                        gl.VertexAttribDivisor(glAttrib, 1);
                         break;
                     default:
                         UNREACHABLE();
diff --git a/src/dawn_native/opengl/RenderPipelineGL.h b/src/dawn_native/opengl/RenderPipelineGL.h
index 2f3363a..e15c864 100644
--- a/src/dawn_native/opengl/RenderPipelineGL.h
+++ b/src/dawn_native/opengl/RenderPipelineGL.h
@@ -32,6 +32,8 @@
         RenderPipeline(Device* device, const RenderPipelineDescriptor* descriptor);
 
         GLenum GetGLPrimitiveTopology() const;
+        ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> GetAttributesUsingVertexBuffer(
+            VertexBufferSlot slot) const;
 
         void ApplyNow(PersistentPipelineState& persistentPipelineState);
 
@@ -42,6 +44,11 @@
         // TODO(yunchao.he@intel.com): vao need to be deduplicated between pipelines.
         GLuint mVertexArrayObject;
         GLenum mGlPrimitiveTopology;
+
+        ityp::array<VertexBufferSlot,
+                    ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes>,
+                    kMaxVertexBuffers>
+            mAttributesUsingVertexBuffer;
     };
 
 }}  // namespace dawn_native::opengl
diff --git a/src/dawn_native/vulkan/CommandBufferVk.cpp b/src/dawn_native/vulkan/CommandBufferVk.cpp
index 240614c..b238fc0 100644
--- a/src/dawn_native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn_native/vulkan/CommandBufferVk.cpp
@@ -1050,7 +1050,8 @@
                     VkBuffer buffer = ToBackend(cmd->buffer)->GetHandle();
                     VkDeviceSize offset = static_cast<VkDeviceSize>(cmd->offset);
 
-                    device->fn.CmdBindVertexBuffers(commands, cmd->slot, 1, &*buffer, &offset);
+                    device->fn.CmdBindVertexBuffers(commands, static_cast<uint8_t>(cmd->slot), 1,
+                                                    &*buffer, &offset);
                     break;
                 }
 
diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp
index d23d515..7908451 100644
--- a/src/dawn_native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp
@@ -517,11 +517,11 @@
         PipelineVertexInputStateCreateInfoTemporaryAllocations* tempAllocations) {
         // Fill in the "binding info" that will be chained in the create info
         uint32_t bindingCount = 0;
-        for (uint32_t i : IterateBitSet(GetVertexBufferSlotsUsed())) {
-            const VertexBufferInfo& bindingInfo = GetVertexBuffer(i);
+        for (VertexBufferSlot slot : IterateBitSet(GetVertexBufferSlotsUsed())) {
+            const VertexBufferInfo& bindingInfo = GetVertexBuffer(slot);
 
             VkVertexInputBindingDescription* bindingDesc = &tempAllocations->bindings[bindingCount];
-            bindingDesc->binding = i;
+            bindingDesc->binding = static_cast<uint8_t>(slot);
             bindingDesc->stride = bindingInfo.arrayStride;
             bindingDesc->inputRate = VulkanInputRate(bindingInfo.stepMode);
 
@@ -530,13 +530,13 @@
 
         // Fill in the "attribute info" that will be chained in the create info
         uint32_t attributeCount = 0;
-        for (uint32_t i : IterateBitSet(GetAttributeLocationsUsed())) {
-            const VertexAttributeInfo& attributeInfo = GetAttribute(i);
+        for (VertexAttributeLocation loc : IterateBitSet(GetAttributeLocationsUsed())) {
+            const VertexAttributeInfo& attributeInfo = GetAttribute(loc);
 
             VkVertexInputAttributeDescription* attributeDesc =
                 &tempAllocations->attributes[attributeCount];
-            attributeDesc->location = i;
-            attributeDesc->binding = attributeInfo.vertexBufferSlot;
+            attributeDesc->location = static_cast<uint8_t>(loc);
+            attributeDesc->binding = static_cast<uint8_t>(attributeInfo.vertexBufferSlot);
             attributeDesc->format = VulkanVertexFormat(attributeInfo.format);
             attributeDesc->offset = attributeInfo.offset;