Dawn Native: Setup ClampFragDepth offset in Vulkan Backend

This CL removes the hard-coded clamp frag depth info and uses
pipeline-allocated offsets for the clamp frag depth constants.

It adds RenderImmediateConstantTracker in CommandBufferVk for
render pass recording. The tracker helps manage clamp frag depth
updates.

Bug: 366291600
Change-Id: I76a218fc83884bad336c67470a56f347e671ca93
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/221274
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Shaobo Yan <shaoboyan@microsoft.com>
diff --git a/src/dawn/native/ImmediateConstantsLayout.h b/src/dawn/native/ImmediateConstantsLayout.h
index af72771..d8bec60 100644
--- a/src/dawn/native/ImmediateConstantsLayout.h
+++ b/src/dawn/native/ImmediateConstantsLayout.h
@@ -85,6 +85,35 @@
     return ((1u << constantCount) - 1u) << firstIndex;
 }
 
+// Returns the byte offset of the member designated by `ptr` in the packed immediates of the
+// pipeline. The pointer-to-member points into the structure containing all the potential
+// immediates. However, pipelines don't need all of them and use a compacted layout with the
+// immediates in the same order, just with some of them skipped. For example the pipeline mask
+// 11001111, representing "userConstants: 4 | trivial_constants: 0 (2 at most) | clamp_frag: 2",
+// maps to the pipeline immediate constant layout "userConstants: 4 | clamp_frag: 2".
+template <typename Object, typename Member>
+uint32_t GetImmediateByteOffsetInPipeline(Member Object::*ptr,
+                                          const ImmediateConstantMask& pipelineImmediateMask) {
+    Object obj = {};
+    ptrdiff_t offset = reinterpret_cast<char*>(&(obj.*ptr)) - reinterpret_cast<char*>(&obj);
+
+    const ImmediateConstantMask prefixBits =
+        (1u << (offset / kImmediateConstantElementByteSize)) - 1u;
+
+    return (prefixBits & pipelineImmediateMask).count() * kImmediateConstantElementByteSize;
+}
+
+// Returns true if `pipelineImmediateMask` has any bit of the member designated by `ptr` set.
+template <typename Object, typename Member>
+bool HasImmediateConstants(Member Object::*ptr,
+                           const ImmediateConstantMask& pipelineImmediateMask) {
+    Object obj = {};
+    ptrdiff_t offset = reinterpret_cast<char*>(&(obj.*ptr)) - reinterpret_cast<char*>(&obj);
+
+    const ImmediateConstantMask memberBits = GetImmediateConstantBlockBits(offset, sizeof(Member));
+    return (pipelineImmediateMask & memberBits).count() != 0;
+}
+
 }  // namespace dawn::native
 
 #endif  // SRC_DAWN_NATIVE_IMMEDIATECONSTANTSLAYOUT_H_
diff --git a/src/dawn/native/ImmediateConstantsTracker.cpp b/src/dawn/native/ImmediateConstantsTracker.cpp
index ef0973f..90b6df1 100644
--- a/src/dawn/native/ImmediateConstantsTracker.cpp
+++ b/src/dawn/native/ImmediateConstantsTracker.cpp
@@ -29,21 +29,6 @@
 #include "ImmediateConstantsLayout.h"
 
 namespace dawn::native {
-RenderImmediateConstantsTrackerBase::RenderImmediateConstantsTrackerBase()
-    : UserImmediateConstantsTrackerBase() {}
-
-// Render pipeline changes reset all pipeline related dirty bits and
-// keep frag depth dirty bits which related to viewport.
-// TODO(crbug.com/366291600): Support immediate data compatible.
-void RenderImmediateConstantsTrackerBase::OnPipelineChange(PipelineBase* pipeline) {
-    mPipelineMask = pipeline->GetPipelineMask();
-
-    // frag depth args are related to viewport instead of pipeline
-    static constexpr ImmediateConstantMask fragDepth = GetImmediateConstantBlockBits(
-        offsetof(RenderImmediateConstants, clampFragDepth), sizeof(ClampFragDepthArgs));
-    mDirty &= fragDepth;
-}
-
 void RenderImmediateConstantsTrackerBase::SetClampFragDepth(float minClampFragDepth,
                                                             float maxClampFragDepth) {
     // Put the data in the right layout to match the RenderImmediateConstants struct
@@ -68,16 +53,6 @@
     UpdateImmediateConstants(offsetof(RenderImmediateConstants, firstInstance), firstInstance);
 }
 
-ComputeImmediateConstantsTrackerBase::ComputeImmediateConstantsTrackerBase()
-    : UserImmediateConstantsTrackerBase() {}
-
-// Pipeline changes reset all dirty bits.
-// TODO(crbug.com/366291600): Support immediate data compatible.
-void ComputeImmediateConstantsTrackerBase::OnPipelineChange(PipelineBase* pipeline) {
-    mPipelineMask = pipeline->GetPipelineMask();
-    mDirty.reset();
-}
-
 void ComputeImmediateConstantsTrackerBase::SetNumWorkgroups(uint32_t numWorkgroupX,
                                                             uint32_t numWorkgroupY,
                                                             uint32_t numWorkgroupZ) {
diff --git a/src/dawn/native/ImmediateConstantsTracker.h b/src/dawn/native/ImmediateConstantsTracker.h
index 7a82b4f..9fb12ab 100644
--- a/src/dawn/native/ImmediateConstantsTracker.h
+++ b/src/dawn/native/ImmediateConstantsTracker.h
@@ -37,10 +37,13 @@
 #include "dawn/common/Constants.h"
 #include "dawn/common/ityp_bitset.h"
 #include "dawn/common/ityp_span.h"
+#include "dawn/native/ComputePipeline.h"
 #include "dawn/native/Device.h"
 #include "dawn/native/ImmediateConstantsLayout.h"
 #include "dawn/native/IntegerTypes.h"
 #include "dawn/native/Pipeline.h"
+#include "dawn/native/RenderPipeline.h"
+#include "partition_alloc/pointers/raw_ptr_exclusion.h"
 
 namespace dawn::native {
 
@@ -68,7 +71,7 @@
     alignas(T) unsigned char mData[sizeof(T)] = {0};
 };
 
-template <typename T>
+template <typename T, typename PipelineType>
 class UserImmediateConstantsTrackerBase {
   public:
     UserImmediateConstantsTrackerBase() {}
@@ -86,9 +89,17 @@
         }
     }
 
-    // Getters
-    const ImmediateConstantMask& GetPipelineMask() const { return mPipelineMask; }
+    // TODO(crbug.com/366291600): Support immediate data compatible.
+    void OnSetPipeline(PipelineType* pipeline) {
+        if (mLastPipeline == pipeline) {
+            return;  // Same pipeline as last time: keep the current dirty bits.
         }
 
+        mDirty = pipeline->GetImmediateMask();  // Everything the new pipeline uses is dirty.
+        mLastPipeline = pipeline;
+    }
+
+    // Getters
     const ImmediateConstantMask& GetDirtyBits() const { return mDirty; }
 
     const ImmediateDataContent<T>& GetContent() const { return mContent; }
@@ -108,14 +119,13 @@
 
     ImmediateDataContent<T> mContent;
     ImmediateConstantMask mDirty = ImmediateConstantMask(0);
-    ImmediateConstantMask mPipelineMask = ImmediateConstantMask(0);
+    RAW_PTR_EXCLUSION PipelineType* mLastPipeline = nullptr;
 };
 
 class RenderImmediateConstantsTrackerBase
-    : public UserImmediateConstantsTrackerBase<RenderImmediateConstants> {
+    : public UserImmediateConstantsTrackerBase<RenderImmediateConstants, RenderPipelineBase> {
   public:
-    RenderImmediateConstantsTrackerBase();
-    void OnPipelineChange(PipelineBase* pipeline);
+    RenderImmediateConstantsTrackerBase() = default;
     void SetClampFragDepth(float minClampFragDepth, float maxClampFragDepth);
     void SetFirstIndexOffset(uint32_t firstVertex, uint32_t firstInstance);
     void SetFirstVertex(uint32_t firstVertex);
@@ -123,10 +133,9 @@
 };
 
 class ComputeImmediateConstantsTrackerBase
-    : public UserImmediateConstantsTrackerBase<ComputeImmediateConstants> {
+    : public UserImmediateConstantsTrackerBase<ComputeImmediateConstants, ComputePipelineBase> {
   public:
-    ComputeImmediateConstantsTrackerBase();
-    void OnPipelineChange(PipelineBase* pipeline);
+    ComputeImmediateConstantsTrackerBase() = default;
     void SetNumWorkgroups(uint32_t numWorkgroupX, uint32_t numWorkgroupY, uint32_t numWorkgroupZ);
 };
 
diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp
index 25a3913..46c7e36 100644
--- a/src/dawn/native/Pipeline.cpp
+++ b/src/dawn/native/Pipeline.cpp
@@ -35,6 +35,7 @@
 #include "dawn/common/Enumerator.h"
 #include "dawn/native/BindGroupLayout.h"
 #include "dawn/native/Device.h"
+#include "dawn/native/ImmediateConstantsLayout.h"
 #include "dawn/native/ObjectBase.h"
 #include "dawn/native/ObjectContentHasher.h"
 #include "dawn/native/PipelineLayout.h"
@@ -291,8 +292,8 @@
     return mStageMask;
 }
 
-const ImmediateConstantMask& PipelineBase::GetPipelineMask() const {
-    return mPipelineMask;
+const ImmediateConstantMask& PipelineBase::GetImmediateMask() const {
+    return mImmediateMask;
 }
 
 MaybeError PipelineBase::ValidateGetBindGroupLayout(BindGroupIndex groupIndex) {
@@ -379,12 +380,19 @@
     if (!scopedUsePrograms) {
         scopedUsePrograms = UseShaderPrograms();
     }
+
+    // Set immediate constant status. userConstants is the first element in both
+    // RenderImmediateConstants and ComputeImmediateConstants.
+    ImmediateConstantMask userConstantsBits =
+        GetImmediateConstantBlockBits(0, GetLayout()->GetImmediateDataRangeByteSize());
+    mImmediateMask |= userConstantsBits;
+
     DAWN_TRY_CONTEXT(InitializeImpl(), "initializing %s", this);
     return {};
 }
 
-void PipelineBase::SetPipelineMaskForTesting(ImmediateConstantMask immediateConstantMask) {
-    mPipelineMask = immediateConstantMask;
+void PipelineBase::SetImmediateMaskForTesting(ImmediateConstantMask immediateConstantMask) {
+    mImmediateMask = immediateConstantMask;
 }
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h
index e0951cc..712b705 100644
--- a/src/dawn/native/Pipeline.h
+++ b/src/dawn/native/Pipeline.h
@@ -77,7 +77,7 @@
     const PerStage<ProgrammableStage>& GetAllStages() const;
     bool HasStage(SingleShaderStage stage) const;
     wgpu::ShaderStage GetStageMask() const;
-    const ImmediateConstantMask& GetPipelineMask() const;
+    const ImmediateConstantMask& GetImmediateMask() const;
 
     ResultOrError<Ref<BindGroupLayoutBase>> GetBindGroupLayout(uint32_t groupIndex);
 
@@ -94,7 +94,7 @@
     // Initialize() should only be called once by the frontend.
     MaybeError Initialize(std::optional<ScopedUseShaderPrograms> scopedUsePrograms = std::nullopt);
 
-    void SetPipelineMaskForTesting(ImmediateConstantMask immediateConstantMask);
+    void SetImmediateMaskForTesting(ImmediateConstantMask immediateConstantMask);
 
   protected:
     PipelineBase(DeviceBase* device,
@@ -103,7 +103,7 @@
                  std::vector<StageAndDescriptor> stages);
     PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag, StringView label);
 
-    ImmediateConstantMask mPipelineMask = ImmediateConstantMask(0);
+    ImmediateConstantMask mImmediateMask = ImmediateConstantMask(0);
 
   private:
     MaybeError ValidateGetBindGroupLayout(BindGroupIndex group);
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index b9b2d65..bbc9716 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -949,8 +949,7 @@
 // RenderPipelineBase
 
 RenderPipelineBase::RenderPipelineBase(DeviceBase* device,
-                                       const UnpackedPtr<RenderPipelineDescriptor>& descriptor,
-                                       ImmediateConstantMask requiredInternalImmediateConstants)
+                                       const UnpackedPtr<RenderPipelineDescriptor>& descriptor)
     : PipelineBase(device,
                    descriptor->layout,
                    descriptor->label,
diff --git a/src/dawn/native/RenderPipeline.h b/src/dawn/native/RenderPipeline.h
index 7450c3a..670b39e 100644
--- a/src/dawn/native/RenderPipeline.h
+++ b/src/dawn/native/RenderPipeline.h
@@ -89,10 +89,7 @@
 class RenderPipelineBase : public PipelineBase,
                            public ContentLessObjectCacheable<RenderPipelineBase> {
   public:
-    RenderPipelineBase(
-        DeviceBase* device,
-        const UnpackedPtr<RenderPipelineDescriptor>& descriptor,
-        ImmediateConstantMask requiredInternalImmediateConstants = ImmediateConstantMask(0u));
+    RenderPipelineBase(DeviceBase* device, const UnpackedPtr<RenderPipelineDescriptor>& descriptor);
     ~RenderPipelineBase() override;
 
     static Ref<RenderPipelineBase> MakeError(DeviceBase* device, StringView label);
diff --git a/src/dawn/native/vulkan/CommandBufferVk.cpp b/src/dawn/native/vulkan/CommandBufferVk.cpp
index f78cd17..a2ae115 100644
--- a/src/dawn/native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn/native/vulkan/CommandBufferVk.cpp
@@ -28,6 +28,7 @@
 #include "dawn/native/vulkan/CommandBufferVk.h"
 
 #include <algorithm>
+#include <limits>
 #include <vector>
 
 #include "dawn/native/BindGroupTracker.h"
@@ -36,6 +37,7 @@
 #include "dawn/native/Commands.h"
 #include "dawn/native/DynamicUploader.h"
 #include "dawn/native/EnumMaskIterator.h"
+#include "dawn/native/ImmediateConstantsTracker.h"
 #include "dawn/native/RenderBundle.h"
 #include "dawn/native/vulkan/BindGroupVk.h"
 #include "dawn/native/vulkan/BufferVk.h"
@@ -192,6 +194,68 @@
     uint32_t mInternalImmediateDataSize = 0;
 };
 
+// Tracks the immediate constants of a Vulkan render pass and records the
+// CmdPushConstants calls needed to bring the last set pipeline up to date.
+class RenderImmediateConstantTracker : public RenderImmediateConstantsTrackerBase {
+  public:
+    RenderImmediateConstantTracker() = default;
+
+    // Uploads every dirty immediate constant used by the last set pipeline, then
+    // clears all dirty bits. Does nothing while no pipeline has been set.
+    void Apply(Device* device, VkCommandBuffer commandBuffer) {
+        if (!mLastPipeline) {
+            return;
+        }
+
+        const ImmediateConstantMask& pipelineImmediateMask = mLastPipeline->GetImmediateMask();
+        const size_t maxImmediateConstantSize =
+            pipelineImmediateMask.count() * kImmediateConstantElementByteSize;
+        const ImmediateConstantMask uploadBits = mDirty & pipelineImmediateMask;
+
+        // First index (in the full, unpacked layout) and length of the pending run.
+        uint32_t runFirstIndex = 0;
+        uint32_t runCount = 0;
+
+        // Pushes the pending run, if any. mContent uses the full layout while the
+        // pipeline uses the packed one, so the two start offsets are computed separately.
+        auto FlushRun = [&] {
+            if (runCount == 0) {
+                return;
+            }
+            const ImmediateConstantMask prefixBits = (1u << runFirstIndex) - 1u;
+            const uint32_t pushConstantRangeStartOffset =
+                (prefixBits & pipelineImmediateMask).count() * kImmediateConstantElementByteSize;
+            const uint32_t byteSize = runCount * kImmediateConstantElementByteSize;
+            DAWN_ASSERT(pushConstantRangeStartOffset + byteSize <= maxImmediateConstantSize);
+            device->fn.CmdPushConstants(
+                commandBuffer, ToBackend(mLastPipeline)->GetVkLayout(),
+                ToBackend(mLastPipeline->GetLayout())->GetImmediateDataRangeStage(),
+                pushConstantRangeStartOffset, byteSize,
+                mContent.Get<uint32_t>(runFirstIndex * kImmediateConstantElementByteSize));
+            runCount = 0;
+        };
+
+        // TODO(crbug.com/366291600): Add IterateBitRanges helper function to achieve iteration on
+        // ranges.
+        for (ImmediateConstantIndex i : IterateBitSet(uploadBits)) {
+            const uint32_t index = static_cast<uint32_t>(i);
+            // A run must be contiguous in the full layout: a skipped index is either a clean
+            // constant or one the pipeline doesn't use, and mContent still holds its slot.
+            if (runCount > 0 && index != runFirstIndex + runCount) {
+                FlushRun();
+            }
+            if (runCount == 0) {
+                runFirstIndex = index;
+            }
+            ++runCount;
+        }
+        FlushRun();
+
+        // Reset all dirty bits after uploading.
+        mDirty.reset();
+    }
+};
+
 // Records the necessary barriers for a synchronization scope using the resource usage
 // data pre-computed in the frontend. Also performs lazy initialization if required.
 MaybeError TransitionAndClearForSyncScope(Device* device,
@@ -1170,6 +1234,7 @@
 
     DAWN_TRY(RecordBeginRenderPass(recordingContext, device, renderPassCmd));
 
+    RenderImmediateConstantTracker renderImmediateConstantTracker = {};
     // Set the default value for the dynamic state
     {
         device->fn.CmdSetLineWidth(commands, 1.0f);
@@ -1201,35 +1266,21 @@
         scissorRect.extent.width = renderPassCmd->width;
         scissorRect.extent.height = renderPassCmd->height;
         device->fn.CmdSetScissor(commands, 0, 1, &scissorRect);
+
+        // Apply default frag depth
+        renderImmediateConstantTracker.SetClampFragDepth(0.0, 1.0);
     }
 
     DescriptorSetTracker descriptorSets = {};
     RenderPipeline* lastPipeline = nullptr;
 
-    // Tracking for the push constants needed by the ClampFragDepth transform.
-    // TODO(dawn:1125): Avoid the need for this when the depthClamp feature is available, but doing
-    // so would require fixing issue dawn:1576 first to have more dynamic push constant usage. (and
-    // also additional tests that the dirtying logic here is correct so with a Toggle we can test it
-    // on our infra).
-    ClampFragDepthArgs clampFragDepthArgs = {0.0f, 1.0f};
-    bool clampFragDepthArgsDirty = true;
-    auto ApplyClampFragDepthArgs = [&] {
-        if (!clampFragDepthArgsDirty || lastPipeline == nullptr) {
-            return;
-        }
-        device->fn.CmdPushConstants(
-            commands, lastPipeline->GetVkLayout(),
-            ToBackend(lastPipeline->GetLayout())->GetImmediateDataRangeStage(),
-            kClampFragDepthArgsOffset, kClampFragDepthArgsSize, &clampFragDepthArgs);
-        clampFragDepthArgsDirty = false;
-    };
-
     auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
         switch (type) {
             case Command::Draw: {
                 DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
                 device->fn.CmdDraw(commands, draw->vertexCount, draw->instanceCount,
                                    draw->firstVertex, draw->firstInstance);
                 break;
@@ -1239,6 +1290,7 @@
                 DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
                 device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
                                           draw->firstIndex, draw->baseVertex, draw->firstInstance);
                 break;
@@ -1249,6 +1301,7 @@
                 Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
                 device->fn.CmdDrawIndirect(commands, buffer->GetHandle(),
                                            static_cast<VkDeviceSize>(draw->indirectOffset), 1, 0);
                 break;
@@ -1260,6 +1313,7 @@
                 DAWN_ASSERT(buffer != nullptr);
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
                 device->fn.CmdDrawIndexedIndirect(commands, buffer->GetHandle(),
                                                   static_cast<VkDeviceSize>(draw->indirectOffset),
                                                   1, 0);
@@ -1276,6 +1330,7 @@
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
 
                 if (countBuffer == nullptr) {
                     device->fn.CmdDrawIndirect(commands, indirectBuffer->GetHandle(),
@@ -1300,6 +1355,7 @@
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
+                renderImmediateConstantTracker.Apply(device, commands);
 
                 if (countBuffer == nullptr) {
                     device->fn.CmdDrawIndexedIndirect(
@@ -1398,9 +1454,7 @@
                 lastPipeline = pipeline;
 
                 descriptorSets.OnSetPipeline<RenderPipeline>(pipeline);
-
-                // Apply the deferred min/maxDepth push constants update if needed.
-                ApplyClampFragDepthArgs();
+                renderImmediateConstantTracker.OnSetPipeline(pipeline);
                 break;
             }
 
@@ -1480,9 +1534,8 @@
 
                 // Try applying the push constants that contain min/maxDepth immediately. This can
                 // be deferred if no pipeline is currently bound.
-                clampFragDepthArgs = {viewport.minDepth, viewport.maxDepth};
-                clampFragDepthArgsDirty = true;
-                ApplyClampFragDepthArgs();
+                renderImmediateConstantTracker.SetClampFragDepth(viewport.minDepth,
+                                                                 viewport.maxDepth);
                 break;
             }
 
diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp
index 6235ff5..0e173a8 100644
--- a/src/dawn/native/vulkan/ComputePipelineVk.cpp
+++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp
@@ -58,11 +58,6 @@
     // Vulkan devices need cache UUID field to be serialized into pipeline cache keys.
     StreamIn(&mCacheKey, device->GetDeviceInfo().properties.pipelineCacheUUID);
 
-    // Set Immediate Constants states
-    mPipelineMask |=
-        GetImmediateConstantBlockBits(offsetof(ComputeImmediateConstants, userConstants),
-                                      GetLayout()->GetImmediateDataRangeByteSize());
-
     // Compute pipeline doesn't have clamp depth feature.
     // TODO(crbug.com/366291600): Setting immediate data size if needed.
     DAWN_TRY(InitializeBase(layout, 0));
@@ -86,8 +81,7 @@
     ShaderModule::ModuleAndSpirv moduleAndSpirv;
     DAWN_TRY_ASSIGN(moduleAndSpirv,
                     module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout,
-                                              /*clampFragDepth*/ false,
-                                              /*emitPointSize*/ false));
+                                              /*emitPointSize*/ false, GetImmediateMask()));
 
     createInfo.stage.module = moduleAndSpirv.module;
     createInfo.stage.pName = kRemappedEntryPointName;
diff --git a/src/dawn/native/vulkan/PipelineVk.h b/src/dawn/native/vulkan/PipelineVk.h
index 833a538..299127b 100644
--- a/src/dawn/native/vulkan/PipelineVk.h
+++ b/src/dawn/native/vulkan/PipelineVk.h
@@ -49,6 +49,7 @@
     uint32_t GetInternalImmediateDataSize() const;
 
   protected:
+    // TODO(crbug.com/366291600): Accept immediate data mask instead of size.
     MaybeError InitializeBase(PipelineLayout* layout, uint32_t internalImmediateDataSize);
     void DestroyImpl();
 
diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp
index 1ae4a5b..15a6db6 100644
--- a/src/dawn/native/vulkan/RenderPipelineVk.cpp
+++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp
@@ -346,12 +346,7 @@
 Ref<RenderPipeline> RenderPipeline::CreateUninitialized(
     Device* device,
     const UnpackedPtr<RenderPipelineDescriptor>& descriptor) {
-    // Possible required internal immediate constants for RenderPipelineVk:
-    // - ClampFragDepth
-    const ImmediateConstantMask requiredInternalConstants = GetImmediateConstantBlockBits(
-        offsetof(RenderImmediateConstants, clampFragDepth), sizeof(ClampFragDepthArgs));
-
-    return AcquireRef(new RenderPipeline(device, descriptor, requiredInternalConstants));
+    return AcquireRef(new RenderPipeline(device, descriptor));
 }
 
 MaybeError RenderPipeline::InitializeImpl() {
@@ -361,14 +356,9 @@
     // Vulkan devices need cache UUID field to be serialized into pipeline cache keys.
     StreamIn(&mCacheKey, device->GetDeviceInfo().properties.pipelineCacheUUID);
 
-    // Set immediate constant status
-    mPipelineMask |=
-        GetImmediateConstantBlockBits(offsetof(RenderImmediateConstants, userConstants),
-                                      GetLayout()->GetImmediateDataRangeByteSize());
-
     // Gather list of internal immediate constants used by this pipeline
     if (UsesFragDepth() && !HasUnclippedDepth()) {
-        mPipelineMask |= GetImmediateConstantBlockBits(
+        mImmediateMask |= GetImmediateConstantBlockBits(
             offsetof(RenderImmediateConstants, clampFragDepth), sizeof(ClampFragDepthArgs));
     }
 
@@ -378,12 +368,12 @@
     uint32_t stageCount = 0;
 
     auto AddShaderStage = [&](SingleShaderStage stage, VkShaderStageFlagBits vkStage,
-                              bool clampFragDepth, bool emitPointSize) -> MaybeError {
+                              bool emitPointSize) -> MaybeError {
         const ProgrammableStage& programmableStage = GetStage(stage);
         ShaderModule::ModuleAndSpirv moduleAndSpirv;
         DAWN_TRY_ASSIGN(moduleAndSpirv, ToBackend(programmableStage.module)
                                             ->GetHandleAndSpirv(stage, programmableStage, layout,
-                                                                clampFragDepth, emitPointSize));
+                                                                emitPointSize, GetImmediateMask()));
         mHasInputAttachment = mHasInputAttachment || moduleAndSpirv.hasInputAttachment;
         // Record cache key for each shader since it will become inaccessible later on.
         StreamIn(&mCacheKey, stream::Iterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount));
@@ -404,14 +394,12 @@
 
     // Add the vertex stage that's always present.
     DAWN_TRY(AddShaderStage(SingleShaderStage::Vertex, VK_SHADER_STAGE_VERTEX_BIT,
-                            /*clampFragDepth*/ false,
                             GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList));
 
     // Add the fragment stage if present.
     if (GetStageMask() & wgpu::ShaderStage::Fragment) {
-        bool clampFragDepth = UsesFragDepth() && !HasUnclippedDepth();
         DAWN_TRY(AddShaderStage(SingleShaderStage::Fragment, VK_SHADER_STAGE_FRAGMENT_BIT,
-                                clampFragDepth, /*emitPointSize*/ false));
+                                /*emitPointSize*/ false));
     }
 
     PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations;
@@ -575,7 +563,12 @@
     }
 
     // TODO(crbug.com/366291600): Add internal immediate data size when needed.
-    DAWN_TRY(InitializeBase(layout, kClampFragDepthArgsSize));
+    ImmediateConstantMask userConstantBits =
+        GetImmediateConstantBlockBits(offsetof(RenderImmediateConstants, userConstants),
+                                      GetLayout()->GetImmediateDataRangeByteSize());
+    uint32_t internalImmediateConstantsSize =
+        (mImmediateMask & userConstantBits.flip()).count() * kImmediateConstantElementByteSize;
+    DAWN_TRY(InitializeBase(layout, internalImmediateConstantsSize));
 
     // The create info chains in a bunch of things created on the stack here or inside state
     // objects.
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 35a45d6..ae1970d 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -38,6 +38,7 @@
 #include "dawn/common/MatchVariant.h"
 #include "dawn/native/Adapter.h"
 #include "dawn/native/CacheRequest.h"
+#include "dawn/native/ImmediateConstantsLayout.h"
 #include "dawn/native/PhysicalDevice.h"
 #include "dawn/native/Serializable.h"
 #include "dawn/native/TintUtils.h"
@@ -212,8 +213,8 @@
     SingleShaderStage stage,
     const ProgrammableStage& programmableStage,
     const PipelineLayout* layout,
-    bool clampFragDepth,
-    bool emitPointSize) {
+    bool emitPointSize,
+    const ImmediateConstantMask& pipelineImmediateMask) {
     TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
 
     ScopedTintICEHandler scopedICEHandler(GetDevice());
@@ -346,7 +347,6 @@
 
     req.tintOptions.statically_paired_texture_binding_points =
         std::move(statically_paired_texture_binding_points);
-    req.tintOptions.clamp_frag_depth = clampFragDepth;
     req.tintOptions.disable_robustness = !GetDevice()->IsRobustnessEnabled();
     req.tintOptions.emit_vertex_point_size = emitPointSize;
 
@@ -383,6 +383,14 @@
         req.tintOptions.pass_matrix_by_pointer = true;
     }
 
+    // Set internal immediate constant offsets
+    if (HasImmediateConstants(&RenderImmediateConstants::clampFragDepth, pipelineImmediateMask)) {
+        uint32_t offsetStartBytes = GetImmediateByteOffsetInPipeline(
+            &RenderImmediateConstants::clampFragDepth, pipelineImmediateMask);
+        req.tintOptions.depth_range_offsets = {
+            offsetStartBytes, offsetStartBytes + kImmediateConstantElementByteSize};
+    }
+
     const CombinedLimits& limits = GetDevice()->GetLimits();
     req.limits = LimitsForCompilationRequest::Create(limits.v1);
     req.adapter = UnsafeUnkeyedValue(static_cast<const AdapterBase*>(GetDevice()->GetAdapter()));
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h
index 5881dc6..50c6084 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.h
+++ b/src/dawn/native/vulkan/ShaderModuleVk.h
@@ -85,8 +85,8 @@
     ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(SingleShaderStage stage,
                                                     const ProgrammableStage& programmableStage,
                                                     const PipelineLayout* layout,
-                                                    bool clampFragDepth,
-                                                    bool emitPointSize);
+                                                    bool emitPointSize,
+                                                    const ImmediateConstantMask& pipelineMask);
 
   private:
     ShaderModule(Device* device,
diff --git a/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp b/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
index 4f2c029..aa6982c 100644
--- a/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
+++ b/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
@@ -71,41 +71,17 @@
 
 // Test pipeline change reset dirty bits and update tracked pipeline constants mask.
 TEST_F(ImmediateConstantsTrackerTest, OnPipelineChange) {
-    // RenderImmediateConstantsTrackerBase
-    {
-        RenderImmediateConstantsTrackerBase tracker;
+    RenderImmediateConstantsTrackerBase tracker;
 
-        // Control Case
-        tracker.SetDirtyBitsForTesting({0b00100101});
-        EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0b00100101));
-        EXPECT_TRUE(tracker.GetPipelineMask() == ImmediateConstantMask(0));
+    // Control Case
+    EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0));
 
-        // Pipeline change should reset dirty bits
-        wgpu::RenderPipeline wgpuPipeline = MakeTestRenderPipeline();
-        RenderPipelineBase* pipeline = FromAPI(wgpuPipeline.Get());
-        pipeline->SetPipelineMaskForTesting({0b01010101});
-        tracker.OnPipelineChange(pipeline);
-        EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0b00100000));
-        EXPECT_TRUE(tracker.GetPipelineMask() == ImmediateConstantMask(0b01010101));
-    }
-
-    // ComputeImmediateConstantsTrackerBase
-    {
-        ComputeImmediateConstantsTrackerBase tracker;
-
-        // Control Case
-        tracker.SetDirtyBitsForTesting({0b00100101});
-        EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0b00100101));
-        EXPECT_TRUE(tracker.GetPipelineMask() == ImmediateConstantMask(0));
-
-        // Pipeline change should reset dirty bits
-        wgpu::ComputePipeline wgpuPipeline = MakeTestComputePipeline();
-        ComputePipelineBase* pipeline = FromAPI(wgpuPipeline.Get());
-        pipeline->SetPipelineMaskForTesting({0b01000101});
-        tracker.OnPipelineChange(pipeline);
-        EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0));
-        EXPECT_TRUE(tracker.GetPipelineMask() == ImmediateConstantMask(0b01000101));
-    }
+    // Pipeline change should mark every constant in the pipeline mask as dirty
+    wgpu::RenderPipeline wgpuPipeline = MakeTestRenderPipeline();
+    RenderPipelineBase* pipeline = FromAPI(wgpuPipeline.Get());
+    pipeline->SetImmediateMaskForTesting({0b01010101});
+    tracker.OnSetPipeline(pipeline);
+    EXPECT_TRUE(tracker.GetDirtyBits() == ImmediateConstantMask(0b01010101));
 
     device.Destroy();
 }
diff --git a/src/dawn/tests/white_box/ImmediateConstantOffsetTests.cpp b/src/dawn/tests/white_box/ImmediateConstantOffsetTests.cpp
index 898833e..6079122 100644
--- a/src/dawn/tests/white_box/ImmediateConstantOffsetTests.cpp
+++ b/src/dawn/tests/white_box/ImmediateConstantOffsetTests.cpp
@@ -76,7 +76,7 @@
     expectedImmediateConstantMask |= (1u << 5u);
 
     // Check dirty bits are set correctly.
-    EXPECT_TRUE(FromAPI(MakeTestRenderPipelineWithClampingFragDepth().Get())->GetPipelineMask() ==
+    EXPECT_TRUE(FromAPI(MakeTestRenderPipelineWithClampingFragDepth().Get())->GetImmediateMask() ==
                 expectedImmediateConstantMask);
 }