Dawn/Native: Implement SetImmediateData() in Vulkan Backend

This CL implement SetImmediateData() API in vulkan backend. User
could use SetImmediateData() in renderPass/ComputePass to upload
small constants.

Bug:366291600
Change-Id: I64b9b7299aa08aee0c6800f35efa45d363ba4a40
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/227494
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Shaobo Yan <shaoboyan@microsoft.com>
diff --git a/src/dawn/common/Constants.h b/src/dawn/common/Constants.h
index ad5af6a..247a1ab 100644
--- a/src/dawn/common/Constants.h
+++ b/src/dawn/common/Constants.h
@@ -58,6 +58,9 @@
 // Pick 32 here.
 static constexpr uint32_t kMaxImmediateConstantsPerPipeline = 32u;
 
+// Limit user immediate constants to 16 bytes.
+static constexpr uint32_t kMaxImmediateDataBytes = 16u;
+
 // Per stage maximum limits used to optimized Dawn internals.
 static constexpr uint32_t kMaxSampledTexturesPerShaderStage = 16;
 static constexpr uint32_t kMaxSamplersPerShaderStage = 16;
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index d79a65e..0986757 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -1710,6 +1710,10 @@
         mWGSLAllowedFeatures.extensions.insert(
             tint::wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
     }
+    if (mEnabledFeatures.IsEnabled(Feature::ChromiumExperimentalImmediateData)) {
+        mWGSLAllowedFeatures.extensions.insert(
+            tint::wgsl::Extension::kChromiumExperimentalPushConstant);
+    }
 
     // Language features are enabled instance-wide.
     const auto& allowedFeatures = GetInstance()->GetAllowedWGSLLanguageFeatures();
diff --git a/src/dawn/native/ImmediateConstantsTracker.h b/src/dawn/native/ImmediateConstantsTracker.h
index 9fb12ab..adcba55 100644
--- a/src/dawn/native/ImmediateConstantsTracker.h
+++ b/src/dawn/native/ImmediateConstantsTracker.h
@@ -77,13 +77,10 @@
     UserImmediateConstantsTrackerBase() {}
 
     // Setters
-    void SetImmediateData(uint32_t immediateDataRangeOffset, uint32_t* values, uint32_t count) {
-        uint32_t* destData = mContent.template Get<uint32_t>(offsetof(T, userConstants) +
-                                                             immediateDataRangeOffset *
-                                                                 kImmediateConstantElementByteSize);
-        size_t dataSize = count * kImmediateConstantElementByteSize;
-        if (memcmp(destData, values, dataSize) != 0) {
-            memcpy(destData, values, dataSize);
+    void SetImmediateData(uint32_t offset, uint8_t* values, uint32_t size) {
+        uint8_t* destData = mContent.template Get<uint8_t>(offsetof(T, userConstants) + offset);
+        if (memcmp(destData, values, size) != 0) {
+            memcpy(destData, values, size);
             mDirty |= GetImmediateConstantBlockBits(offsetof(T, userConstants),
                                                     sizeof(UserImmediateConstants));
         }
diff --git a/src/dawn/native/Limits.cpp b/src/dawn/native/Limits.cpp
index 39f1105..f5428da 100644
--- a/src/dawn/native/Limits.cpp
+++ b/src/dawn/native/Limits.cpp
@@ -401,4 +401,10 @@
     }
 }
 
+void NormalizeExperimentalLimits(CombinedLimits* limits) {
+    // Enforce immediate data bytes to ensure they don't go over a fixed limit in Dawn's internal
+    // code.
+    limits->experimentalImmediateDataLimits.maxImmediateDataRangeByteSize = kMaxImmediateDataBytes;
+}
+
 }  // namespace dawn::native
diff --git a/src/dawn/native/Limits.h b/src/dawn/native/Limits.h
index a9c9722..383aa50 100644
--- a/src/dawn/native/Limits.h
+++ b/src/dawn/native/Limits.h
@@ -80,6 +80,12 @@
 //      maxUniformBufferBindingSize must not be larger than maxBufferSize.
 void NormalizeLimits(Limits* limits);
 
+// Enforce restriction for experiment limit values, including:
+// 1. Enforce immediate data bytes to ensure they don't go over a fixed limit in Dawn's internal
+//    code.
+// TODO(crbug.com/366291600): Make ApplyLimitTiers and NormalizeLimits accept CombeindLimits.
+void NormalizeExperimentalLimits(CombinedLimits* limits);
+
 }  // namespace dawn::native
 
 #endif  // SRC_DAWN_NATIVE_LIMITS_H_
diff --git a/src/dawn/native/PhysicalDevice.cpp b/src/dawn/native/PhysicalDevice.cpp
index eaccd25..3ceacc9 100644
--- a/src/dawn/native/PhysicalDevice.cpp
+++ b/src/dawn/native/PhysicalDevice.cpp
@@ -69,6 +69,7 @@
         mName, mDriverDescription, mVendorId, mDeviceId, mBackend, mAdapterType);
 
     NormalizeLimits(&mLimits.v1);
+    NormalizeExperimentalLimits(&mLimits);
 
     return {};
 }
diff --git a/src/dawn/native/ProgrammableEncoder.cpp b/src/dawn/native/ProgrammableEncoder.cpp
index 3573edd..a61bca21 100644
--- a/src/dawn/native/ProgrammableEncoder.cpp
+++ b/src/dawn/native/ProgrammableEncoder.cpp
@@ -146,15 +146,17 @@
                     offset, size, maxImmediateDataRangeByteSize);
             }
 
+            // Skip SetImmediateData when uploading constants are empty.
+            if (size == 0) {
+                return {};
+            }
+
             SetImmediateDataCmd* cmd =
                 allocator->Allocate<SetImmediateDataCmd>(Command::SetImmediateData);
             cmd->offset = offset;
             cmd->size = size;
-
-            if (size > 0) {
-                uint8_t* immediateDatas = allocator->AllocateData<uint8_t>(cmd->size);
-                memcpy(immediateDatas, data, size);
-            }
+            uint8_t* immediateDatas = allocator->AllocateData<uint8_t>(cmd->size);
+            memcpy(immediateDatas, data, size);
 
             return {};
         },
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 01e5fd6..0ab88d6 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -1444,6 +1444,14 @@
                         "doesn't use a `pixel local` block.");
     }
 
+    // Validate that immediate data used by programmable state are smaller than pipelineLayout
+    // immediate data range bytes.
+    DAWN_INVALID_IF(entryPoint.immediateDataRangeByteSize > layout->GetImmediateDataRangeByteSize(),
+                    "The entry-point uses more bytes of immediate data (%u) than the reserved "
+                    "amount (%u) in %s.",
+                    entryPoint.immediateDataRangeByteSize, layout->GetImmediateDataRangeByteSize(),
+                    layout);
+
     return {};
 }
 
diff --git a/src/dawn/native/vulkan/CommandBufferVk.cpp b/src/dawn/native/vulkan/CommandBufferVk.cpp
index f66d8a3..eb4a653 100644
--- a/src/dawn/native/vulkan/CommandBufferVk.cpp
+++ b/src/dawn/native/vulkan/CommandBufferVk.cpp
@@ -194,16 +194,18 @@
     uint32_t mInternalImmediateDataSize = 0;
 };
 
-class RenderImmediateConstantTracker : public RenderImmediateConstantsTrackerBase {
+template <typename T>
+class ImmediateConstantTracker : public T {
   public:
-    RenderImmediateConstantTracker() = default;
+    ImmediateConstantTracker() = default;
 
     void Apply(Device* device, VkCommandBuffer commandBuffer) {
-        if (!mLastPipeline) {
+        auto* lastPipeline = this->mLastPipeline;
+        if (!lastPipeline) {
             return;
         }
 
-        const ImmediateConstantMask& pipelineImmediateMask = mLastPipeline->GetImmediateMask();
+        const ImmediateConstantMask& pipelineImmediateMask = lastPipeline->GetImmediateMask();
         const size_t maxImmediateConstantSize =
             pipelineImmediateMask.count() * kImmediateConstantElementByteSize;
 
@@ -211,12 +213,12 @@
         uint32_t immediateContentStartOffset = 0;
         uint32_t immediateDataCount = 0;
 
-        ImmediateConstantMask uploadBits = mDirty & mLastPipeline->GetImmediateMask();
+        ImmediateConstantMask uploadBits = this->mDirty & lastPipeline->GetImmediateMask();
         ImmediateConstantMask prefixBits = ImmediateConstantMask(0u);
 
         // TODO(crbug.com/366291600): Add IterateBitRanges helper function to achieve iteration on
         // ranges.
-        for (ImmediateConstantIndex i : IterateBitSet(mLastPipeline->GetImmediateMask())) {
+        for (ImmediateConstantIndex i : IterateBitSet(lastPipeline->GetImmediateMask())) {
             if (uploadBits.test(i)) {
                 uint32_t index = static_cast<uint32_t>(i);
                 if (immediateDataCount == 0) {
@@ -229,11 +231,11 @@
             } else {
                 if (immediateDataCount > 0) {
                     device->fn.CmdPushConstants(
-                        commandBuffer, ToBackend(mLastPipeline)->GetVkLayout(),
-                        ToBackend(mLastPipeline->GetLayout())->GetImmediateDataRangeStage(),
+                        commandBuffer, ToBackend(lastPipeline)->GetVkLayout(),
+                        ToBackend(lastPipeline->GetLayout())->GetImmediateDataRangeStage(),
                         pushConstantRangeStartOffset,
                         immediateDataCount * kImmediateConstantElementByteSize,
-                        mContent.Get<uint32_t>(immediateContentStartOffset));
+                        this->mContent.template Get<uint32_t>(immediateContentStartOffset));
                     immediateDataCount = 0;
                 }
             }
@@ -243,16 +245,16 @@
         if (immediateDataCount > 0) {
             DAWN_ASSERT(pushConstantRangeStartOffset < maxImmediateConstantSize);
             device->fn.CmdPushConstants(
-                commandBuffer, ToBackend(mLastPipeline)->GetVkLayout(),
-                ToBackend(mLastPipeline->GetLayout())->GetImmediateDataRangeStage(),
+                commandBuffer, ToBackend(lastPipeline)->GetVkLayout(),
+                ToBackend(lastPipeline->GetLayout())->GetImmediateDataRangeStage(),
                 pushConstantRangeStartOffset,
                 immediateDataCount * kImmediateConstantElementByteSize,
-                mContent.Get<uint32_t>(immediateContentStartOffset));
+                this->mContent.template Get<uint32_t>(immediateContentStartOffset));
             immediateDataCount = 0;
         }
 
         // Reset all dirty bits after uploading.
-        mDirty.reset();
+        this->mDirty.reset();
     }
 };
 
@@ -1081,6 +1083,7 @@
 
     uint64_t currentDispatch = 0;
     DescriptorSetTracker descriptorSets = {};
+    ImmediateConstantTracker<ComputeImmediateConstantsTrackerBase> immediates = {};
 
     Command type;
     while (mCommands.NextCommandId(&type)) {
@@ -1105,7 +1108,7 @@
                 DAWN_TRY(TransitionAndClearForSyncScope(
                     device, recordingContext, resourceUsages.dispatchUsages[currentDispatch]));
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_COMPUTE);
-
+                immediates.Apply(device, commands);
                 device->fn.CmdDispatch(commands, dispatch->x, dispatch->y, dispatch->z);
                 currentDispatch++;
                 break;
@@ -1118,7 +1121,7 @@
                 DAWN_TRY(TransitionAndClearForSyncScope(
                     device, recordingContext, resourceUsages.dispatchUsages[currentDispatch]));
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_COMPUTE);
-
+                immediates.Apply(device, commands);
                 device->fn.CmdDispatchIndirect(commands, indirectBuffer,
                                                static_cast<VkDeviceSize>(dispatch->indirectOffset));
                 currentDispatch++;
@@ -1146,6 +1149,7 @@
                 device->fn.CmdBindPipeline(commands, VK_PIPELINE_BIND_POINT_COMPUTE,
                                            pipeline->GetHandle());
                 descriptorSets.OnSetPipeline<ComputePipeline>(pipeline);
+                immediates.OnSetPipeline(pipeline);
                 break;
             }
 
@@ -1207,8 +1211,14 @@
                 break;
             }
 
-            case Command::SetImmediateData:
-                return DAWN_UNIMPLEMENTED_ERROR("SetImmediateData unimplemented");
+            case Command::SetImmediateData: {
+                SetImmediateDataCmd* cmd = mCommands.NextCommand<SetImmediateDataCmd>();
+                DAWN_ASSERT(cmd->size > 0);
+                uint8_t* value = nullptr;
+                value = mCommands.NextData<uint8_t>(cmd->size);
+                immediates.SetImmediateData(cmd->offset, value, cmd->size);
+                break;
+            }
 
             default:
                 DAWN_UNREACHABLE();
@@ -1236,7 +1246,7 @@
 
     DAWN_TRY(RecordBeginRenderPass(recordingContext, device, renderPassCmd));
 
-    RenderImmediateConstantTracker renderImmediateConstantTracker = {};
+    ImmediateConstantTracker<RenderImmediateConstantsTrackerBase> immediates = {};
     // Set the default value for the dynamic state
     {
         device->fn.CmdSetLineWidth(commands, 1.0f);
@@ -1270,7 +1280,7 @@
         device->fn.CmdSetScissor(commands, 0, 1, &scissorRect);
 
         // Apply default frag depth
-        renderImmediateConstantTracker.SetClampFragDepth(0.0, 1.0);
+        immediates.SetClampFragDepth(0.0, 1.0);
     }
 
     DescriptorSetTracker descriptorSets = {};
@@ -1282,7 +1292,7 @@
                 DrawCmd* draw = iter->NextCommand<DrawCmd>();
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
                 device->fn.CmdDraw(commands, draw->vertexCount, draw->instanceCount,
                                    draw->firstVertex, draw->firstInstance);
                 break;
@@ -1292,7 +1302,7 @@
                 DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
                 device->fn.CmdDrawIndexed(commands, draw->indexCount, draw->instanceCount,
                                           draw->firstIndex, draw->baseVertex, draw->firstInstance);
                 break;
@@ -1303,7 +1313,7 @@
                 Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
                 device->fn.CmdDrawIndirect(commands, buffer->GetHandle(),
                                            static_cast<VkDeviceSize>(draw->indirectOffset), 1, 0);
                 break;
@@ -1315,7 +1325,7 @@
                 DAWN_ASSERT(buffer != nullptr);
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
                 device->fn.CmdDrawIndexedIndirect(commands, buffer->GetHandle(),
                                                   static_cast<VkDeviceSize>(draw->indirectOffset),
                                                   1, 0);
@@ -1332,7 +1342,7 @@
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
 
                 if (countBuffer == nullptr) {
                     device->fn.CmdDrawIndirect(commands, indirectBuffer->GetHandle(),
@@ -1357,7 +1367,7 @@
                 Buffer* countBuffer = ToBackend(cmd->drawCountBuffer.Get());
 
                 descriptorSets.Apply(device, recordingContext, VK_PIPELINE_BIND_POINT_GRAPHICS);
-                renderImmediateConstantTracker.Apply(device, commands);
+                immediates.Apply(device, commands);
 
                 if (countBuffer == nullptr) {
                     device->fn.CmdDrawIndexedIndirect(
@@ -1456,7 +1466,7 @@
                 lastPipeline = pipeline;
 
                 descriptorSets.OnSetPipeline<RenderPipeline>(pipeline);
-                renderImmediateConstantTracker.OnSetPipeline(pipeline);
+                immediates.OnSetPipeline(pipeline);
                 break;
             }
 
@@ -1470,6 +1480,15 @@
                 break;
             }
 
+            case Command::SetImmediateData: {
+                SetImmediateDataCmd* cmd = mCommands.NextCommand<SetImmediateDataCmd>();
+                DAWN_ASSERT(cmd->size > 0);
+                uint8_t* value = nullptr;
+                value = mCommands.NextData<uint8_t>(cmd->size);
+                immediates.SetImmediateData(cmd->offset, value, cmd->size);
+                break;
+            }
+
             default:
                 DAWN_UNREACHABLE();
                 break;
@@ -1536,8 +1555,7 @@
 
                 // Try applying the push constants that contain min/maxDepth immediately. This can
                 // be deferred if no pipeline is currently bound.
-                renderImmediateConstantTracker.SetClampFragDepth(viewport.minDepth,
-                                                                 viewport.maxDepth);
+                immediates.SetClampFragDepth(viewport.minDepth, viewport.maxDepth);
                 break;
             }
 
@@ -1591,9 +1609,6 @@
                 break;
             }
 
-            case Command::SetImmediateData:
-                return DAWN_UNIMPLEMENTED_ERROR("SetImmediateData unimplemented");
-
             default: {
                 EncodeRenderBundleCommand(&mCommands, type);
                 break;
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
index 8f480e9..158f107 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
@@ -36,6 +36,7 @@
 #include "dawn/common/GPUInfo.h"
 #include "dawn/native/ChainUtils.h"
 #include "dawn/native/Error.h"
+#include "dawn/native/ImmediateConstantsLayout.h"
 #include "dawn/native/Instance.h"
 #include "dawn/native/Limits.h"
 #include "dawn/native/vulkan/BackendVk.h"
@@ -499,6 +500,8 @@
     if (mDeviceInfo.HasExt(DeviceExt::ImageDrmFormatModifier)) {
         EnableFeature(Feature::DawnDrmFormatCapabilities);
     }
+
+    EnableFeature(Feature::ChromiumExperimentalImmediateData);
 }
 
 MaybeError PhysicalDevice::InitializeSupportedLimitsImpl(CombinedLimits* limits) {
@@ -695,6 +698,13 @@
     limits->experimentalSubgroupLimits.maxSubgroupSize =
         mDeviceInfo.subgroupSizeControlProperties.maxSubgroupSize;
 
+    // vulkan needs to have enough push constant range size for all
+    // internal and external immediate data usages.
+    constexpr uint32_t kMinVulkanPushConstants = 128;
+    DAWN_ASSERT(vkLimits.maxPushConstantsSize >= kMinVulkanPushConstants);
+    static_assert(kMinVulkanPushConstants >= sizeof(RenderImmediateConstants));
+    static_assert(kMinVulkanPushConstants >= sizeof(ComputeImmediateConstants));
+
     return {};
 }
 
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 4cac367..0f8c178 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -630,6 +630,7 @@
     "end2end/FramebufferFetchTests.cpp",
     "end2end/GpuMemorySynchronizationTests.cpp",
     "end2end/HistogramTests.cpp",
+    "end2end/ImmediateDataTests.cpp",
     "end2end/IndexFormatTests.cpp",
     "end2end/InfiniteLoopTests.cpp",
     "end2end/MaxLimitTests.cpp",
diff --git a/src/dawn/tests/end2end/ImmediateDataTests.cpp b/src/dawn/tests/end2end/ImmediateDataTests.cpp
new file mode 100644
index 0000000..93bc4d1
--- /dev/null
+++ b/src/dawn/tests/end2end/ImmediateDataTests.cpp
@@ -0,0 +1,490 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <array>
+#include <limits>
+#include <vector>
+
+#include "dawn/tests/DawnTest.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
+#include "dawn/utils/WGPUHelpers.h"
+
+namespace dawn {
+namespace {
+
+constexpr uint32_t kRTSize = 1;
+
+class ImmediateDataTests : public DawnTest {
+  protected:
+    std::vector<wgpu::FeatureName> GetRequiredFeatures() override {
+        return {wgpu::FeatureName::ChromiumExperimentalImmediateData};
+    }
+
+    void SetUp() override {
+        DawnTest::SetUp();
+
+        mShaderModule = utils::CreateShaderModule(device, R"(
+            enable chromium_experimental_push_constant;
+            struct PushConstant {
+                color: vec3<f32>,
+                colorDiff: f32,
+            };
+            var<push_constant> constants: PushConstant;
+            struct VertexOut {
+                @location(0) color : vec3f,
+                @builtin(position) position : vec4f,
+            }
+
+            @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> VertexOut {
+                const pos = array(
+                    vec2( 1.0, -1.0),
+                    vec2(-1.0, -1.0),
+                    vec2( 0.0,  1.0),
+                );
+                var output: VertexOut;
+                output.position = vec4f(pos[VertexIndex], 0.0, 1.0);
+                output.color = constants.color;
+                return output;
+            }
+
+            // to reuse the same pipeline layout
+            @fragment fn fsMain(@location(0) color:vec3f) -> @location(0) vec4f {
+                return vec4f(color + vec3f(constants.colorDiff), 1.0);
+            }
+
+            var<push_constant> computeConstants: vec4u;
+            @group(0) @binding(0) var<storage, read_write> output : vec4u;
+
+            @compute @workgroup_size(1, 1, 1)
+            fn csMain() {
+                output = computeConstants;
+            })");
+
+        wgpu::BufferDescriptor bufferDesc;
+        bufferDesc.size = sizeof(uint32_t) * 4;
+        bufferDesc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage;
+        mStorageBuffer = device.CreateBuffer(&bufferDesc);
+    }
+
+    wgpu::BindGroupLayout CreateBindGroupLayout() {
+        wgpu::BindGroupLayoutEntry entries[1];
+        entries[0].binding = 0;
+        entries[0].visibility = wgpu::ShaderStage::Compute;
+        entries[0].buffer.type = wgpu::BufferBindingType::Storage;
+
+        wgpu::BindGroupLayoutDescriptor bindGroupLayoutDesc;
+        bindGroupLayoutDesc.entryCount = 1;
+        bindGroupLayoutDesc.entries = entries;
+
+        return device.CreateBindGroupLayout(&bindGroupLayoutDesc);
+    }
+
+    wgpu::PipelineLayout CreatePipelineLayout() {
+        wgpu::BindGroupLayout bindGroupLayout = CreateBindGroupLayout();
+
+        wgpu::PipelineLayoutDescriptor pipelineLayoutDesc;
+        pipelineLayoutDesc.bindGroupLayoutCount = 1;
+        pipelineLayoutDesc.bindGroupLayouts = &bindGroupLayout;
+        pipelineLayoutDesc.immediateDataRangeByteSize = kMaxImmediateDataBytes;
+        return device.CreatePipelineLayout(&pipelineLayoutDesc);
+    }
+
+    wgpu::RenderPipeline CreateRenderPipeline() {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = mShaderModule;
+        pipelineDescriptor.cFragment.module = mShaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        pipelineDescriptor.layout = CreatePipelineLayout();
+
+        return device.CreateRenderPipeline(&pipelineDescriptor);
+    }
+
+    wgpu::ComputePipeline CreateComputePipeline() {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = mShaderModule;
+        csDesc.layout = CreatePipelineLayout();
+
+        return device.CreateComputePipeline(&csDesc);
+    }
+
+    wgpu::BindGroup CreateBindGroup() {
+        return utils::MakeBindGroup(device, CreateBindGroupLayout(), {{0, mStorageBuffer}});
+    }
+
+    wgpu::ShaderModule mShaderModule;
+    wgpu::Buffer mStorageBuffer;
+};
+
+// ImmediateData has been uploaded successfully.
+TEST_P(ImmediateDataTests, BasicRenderPipeline) {
+    wgpu::RenderPipeline pipeline = CreateRenderPipeline();
+    utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+    // rgba8unorm: {0.1, 0.3, 0.5} + {0.1 diff} => {0.2, 0.4, 0.6} => {51, 102, 153, 255}
+    std::array<float, 4> immediateData = {0.1, 0.3, 0.5, 0.1};
+    wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+    wgpu::RenderPassEncoder renderPassEncoder =
+        commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+    renderPassEncoder.SetImmediateData(0, immediateData.data(),
+                                       immediateData.size() * sizeof(uint32_t));
+    renderPassEncoder.SetPipeline(CreateRenderPipeline());
+    renderPassEncoder.SetBindGroup(0, CreateBindGroup());
+    renderPassEncoder.Draw(3);
+    renderPassEncoder.End();
+    wgpu::CommandBuffer commands = commandEncoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(51, 102, 153, 255), renderPass.color, 0, 0);
+}
+
+// ImmediateData has been uploaded successfully.
+TEST_P(ImmediateDataTests, BasicComputePipeline) {
+    std::array<uint32_t, 4> immediateData = {25, 128, 240, 255};
+    wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+    wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+    computePassEncoder.SetPipeline(CreateComputePipeline());
+    computePassEncoder.SetImmediateData(0, immediateData.data(),
+                                        immediateData.size() * sizeof(uint32_t));
+    computePassEncoder.SetBindGroup(0, CreateBindGroup());
+    computePassEncoder.DispatchWorkgroups(1);
+    computePassEncoder.End();
+    wgpu::CommandBuffer commands = commandEncoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_BUFFER_U32_RANGE_EQ(immediateData.data(), mStorageBuffer, 0, immediateData.size());
+}
+
+// ImmediateData range should be initialized to 0.
+TEST_P(ImmediateDataTests, ImmediateDataInitialization) {
+    // Render pipeline
+    {
+        wgpu::RenderPipeline pipeline = CreateRenderPipeline();
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        // rgba8unorm: {0.0, 0.4, 0.6} + {0.0 diff} => {0.0, 0.4, 0.6} => {0, 102, 153, 255}
+        std::array<float, 2> immediateData = {0.4, 0.6};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::RenderPassEncoder renderPassEncoder =
+            commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+        renderPassEncoder.SetImmediateData(4, immediateData.data(), 8);
+        renderPassEncoder.SetPipeline(CreateRenderPipeline());
+        renderPassEncoder.SetBindGroup(0, CreateBindGroup());
+        renderPassEncoder.Draw(3);
+        renderPassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(0, 102, 153, 255), renderPass.color, 0, 0);
+    }
+
+    // Compute Pipeline
+    {
+        std::array<uint32_t, 2> immediateData = {128, 240};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+        computePassEncoder.SetPipeline(CreateComputePipeline());
+        computePassEncoder.SetImmediateData(4, immediateData.data(), 8);
+        computePassEncoder.SetBindGroup(0, CreateBindGroup());
+        computePassEncoder.DispatchWorkgroups(1);
+        computePassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        std::array<uint32_t, 4> expected = {0, 128, 240, 0};
+        EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), mStorageBuffer, 0, immediateData.size());
+    }
+}
+
+// SetImmediateData with offset on immediate data range.
+TEST_P(ImmediateDataTests, SetImmediateDataWithRangeOffset) {
+    constexpr uint32_t kHalfImmediateDataSize = 8;
+    // Render Pipeline
+    {
+        wgpu::RenderPipeline pipeline = CreateRenderPipeline();
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        // rgba8unorm: {0.1, 0.3, 0.5} + {0.1 diff} => {0.2, 0.4, 0.6} => {51, 102, 153, 255}
+        std::array<float, 4> immediateData = {0.1, 0.3, 0.5, 0.1};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::RenderPassEncoder renderPassEncoder =
+            commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+        renderPassEncoder.SetImmediateData(0, immediateData.data(), 16);
+        // Update {0.1, 0.3, 0.5} to {0.1,0.5,0.7} and + {0.1 diff} => {0.2, 0.6, 0.8} => {51,
+        // 153, 204, 255}
+        std::array<float, 2> immediateDataUpdated = {0.5, 0.7};
+        renderPassEncoder.SetImmediateData(4, immediateDataUpdated.data(), 8);
+        renderPassEncoder.SetPipeline(CreateRenderPipeline());
+        renderPassEncoder.SetBindGroup(0, CreateBindGroup());
+        renderPassEncoder.Draw(3);
+        renderPassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(51, 153, 204, 255), renderPass.color, 0, 0);
+    }
+
+    // Compute Pipeline
+    {
+        std::array<uint32_t, 4> immediateData = {25, 128, 240, 255};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+        computePassEncoder.SetPipeline(CreateComputePipeline());
+        // Using two SetImmediateData + Offset to swap first half and second half value in immediate
+        // data range.
+        computePassEncoder.SetImmediateData(kHalfImmediateDataSize, immediateData.data(),
+                                            kHalfImmediateDataSize);
+        computePassEncoder.SetImmediateData(0, immediateData.data() + 2, kHalfImmediateDataSize);
+        computePassEncoder.SetBindGroup(0, CreateBindGroup());
+        computePassEncoder.DispatchWorkgroups(1);
+        computePassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        std::array<uint32_t, 4> expected = {240, 255, 25, 128};
+        EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), mStorageBuffer, 0, expected.size());
+    }
+}
+
+// SetImmediateData should upload dirtied, latest contents between pipeline switches before draw or
+// dispatch.
+TEST_P(ImmediateDataTests, SetImmediateDataMultipleTimes) {
+    // Render Pipeline
+    {
+        wgpu::RenderPipeline pipeline = CreateRenderPipeline();
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        // rgba8unorm: {0.1, 0.3, 0.5} + {0.1 diff} => {0.2, 0.4, 0.6} => {51, 102, 153, 255}
+        std::array<float, 4> immediateData = {0.1, 0.3, 0.5, 0.1};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::RenderPassEncoder renderPassEncoder =
+            commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+        // Using 4 SetImmediateData to update all immediate data to 0.1.
+        renderPassEncoder.SetImmediateData(0, immediateData.data(), immediateData.size() * 4);
+        renderPassEncoder.SetImmediateData(4, immediateData.data(), (immediateData.size() - 1) * 4);
+        renderPassEncoder.SetPipeline(CreateRenderPipeline());
+        renderPassEncoder.SetImmediateData(8, immediateData.data(), 8);
+        renderPassEncoder.SetPipeline(CreateRenderPipeline());
+        renderPassEncoder.SetImmediateData(12, immediateData.data(), 4);
+        renderPassEncoder.SetBindGroup(0, CreateBindGroup());
+        renderPassEncoder.Draw(3);
+        renderPassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(51, 51, 51, 255), renderPass.color, 0, 0);
+    }
+
+    // Compute Pipeline
+    {
+        std::array<uint32_t, 4> immediateData = {25, 128, 240, 255};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+
+        // Using 4 SetImmediateData to update all immediate data to 25.
+        computePassEncoder.SetImmediateData(0, immediateData.data(), immediateData.size() * 4);
+        computePassEncoder.SetImmediateData(4, immediateData.data(),
+                                            (immediateData.size() - 1) * 4);
+        computePassEncoder.SetPipeline(CreateComputePipeline());
+        computePassEncoder.SetImmediateData(8, immediateData.data(), 8);
+        computePassEncoder.SetPipeline(CreateComputePipeline());
+        computePassEncoder.SetImmediateData(12, immediateData.data(), 4);
+
+        computePassEncoder.SetBindGroup(0, CreateBindGroup());
+        computePassEncoder.DispatchWorkgroups(1);
+        computePassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        std::array<uint32_t, 4> expected = {25, 25, 25, 25};
+        EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), mStorageBuffer, 0, expected.size());
+    }
+}
+
+// Test that clamp frag depth(supported by internal immediate constants)
+// works fine when shaders have user immediate data
+TEST_P(ImmediateDataTests, UsingImmediateDataDontAffectClampFragDepth) {
+    wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+        enable chromium_experimental_push_constant;
+        var<push_constant> constants: vec4f;
+        @vertex fn vs() -> @builtin(position) vec4f {
+            return vec4f(0.0, 0.0, 0.5, 1.0);
+        }
+
+        @fragment fn fs() -> @builtin(frag_depth) f32 {
+            return constants.r;
+        }
+    )");
+
+    // Create the pipeline that uses frag_depth to output the depth.
+    utils::ComboRenderPipelineDescriptor pDesc;
+    pDesc.vertex.module = module;
+    pDesc.primitive.topology = wgpu::PrimitiveTopology::PointList;
+    pDesc.cFragment.module = module;
+    pDesc.cFragment.targetCount = 0;
+
+    wgpu::DepthStencilState* pDescDS = pDesc.EnableDepthStencil(wgpu::TextureFormat::Depth32Float);
+    pDescDS->depthWriteEnabled = wgpu::OptionalBool::True;
+    pDescDS->depthCompare = wgpu::CompareFunction::Always;
+    wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&pDesc);
+
+    // Create a depth-only render pass.
+    wgpu::TextureDescriptor depthDesc;
+    depthDesc.size = {1, 1};
+    depthDesc.usage = wgpu::TextureUsage::RenderAttachment | wgpu::TextureUsage::CopySrc;
+    depthDesc.format = wgpu::TextureFormat::Depth32Float;
+    wgpu::Texture depthTexture = device.CreateTexture(&depthDesc);
+
+    std::array<float, 4> immediateData = {1.0, 1.0, 1.0, 1.0};
+
+    utils::ComboRenderPassDescriptor renderPassDesc({}, depthTexture.CreateView());
+    renderPassDesc.cDepthStencilAttachmentInfo.stencilLoadOp = wgpu::LoadOp::Undefined;
+    renderPassDesc.cDepthStencilAttachmentInfo.stencilStoreOp = wgpu::StoreOp::Undefined;
+
+    // Draw a point with a skewed viewport, so 1.0 depth gets clamped to 0.5.
+    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+    wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPassDesc);
+    pass.SetViewport(0, 0, 1, 1, 0.0, 0.5);
+    pass.SetImmediateData(0, immediateData.data(), immediateData.size() * 4);
+    pass.SetPipeline(pipeline);
+    pass.Draw(1);
+    pass.End();
+
+    wgpu::CommandBuffer commands = encoder.Finish();
+    queue.Submit(1, &commands);
+
+    EXPECT_PIXEL_FLOAT_EQ(0.5f, depthTexture, 0, 0);
+}
+
+// SetImmediateData Multiple times should upload dirtied, latest contents.
+TEST_P(ImmediateDataTests, SetImmediateDataWithPipelineSwitch) {
+    wgpu::ShaderModule shaderModuleWithLessImmediateData = utils::CreateShaderModule(device, R"(
+        enable chromium_experimental_push_constant;
+        struct PushConstant {
+            color: vec3<f32>,
+        };
+        var<push_constant> constants: PushConstant;
+        struct VertexOut {
+            @location(0) color : vec3f,
+            @builtin(position) position : vec4f,
+        }
+
+        @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> VertexOut {
+            const pos = array(
+                vec2( 1.0, -1.0),
+                vec2(-1.0, -1.0),
+                vec2( 0.0,  1.0),
+            );
+            var output: VertexOut;
+            output.position = vec4f(pos[VertexIndex], 0.0, 1.0);
+            output.color = constants.color;
+            return output;
+        }
+
+        // to reuse the same pipeline layout
+        @fragment fn fsMain(@location(0) color:vec3f) -> @location(0) vec4f {
+            return vec4f(color, 1.0);
+        }
+
+        var<push_constant> computeConstants: vec3u;
+        @group(0) @binding(0) var<storage, read_write> output : vec3u;
+
+        @compute @workgroup_size(1, 1, 1)
+        fn csMain() {
+            output = computeConstants;
+        })");
+
+    // Render Pipeline
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = shaderModuleWithLessImmediateData;
+        pipelineDescriptor.cFragment.module = shaderModuleWithLessImmediateData;
+        pipelineDescriptor.cFragment.targetCount = 1;
+
+        wgpu::RenderPipeline pipelineWithLessImmediateData =
+            device.CreateRenderPipeline(&pipelineDescriptor);
+
+        wgpu::RenderPipeline pipeline = CreateRenderPipeline();
+        utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::RenderPassEncoder renderPassEncoder =
+            commandEncoder.BeginRenderPass(&renderPass.renderPassInfo);
+
+        // rgba8unorm: {0.2, 0.4, 0.6} + {0.1 diff} => {0.3, 0.5, 0.7}
+        std::array<float, 4> immediateData = {0.2, 0.4, 0.6, 0.1};
+        renderPassEncoder.SetImmediateData(0, immediateData.data(), immediateData.size() * 4);
+        renderPassEncoder.SetPipeline(CreateRenderPipeline());
+
+        // replace the pipeline and rgba8unorm: {0.4, 0.4, 0.6} => {102, 102, 153}
+        float data = 0.4;
+        renderPassEncoder.SetImmediateData(0, &data, 4);
+        renderPassEncoder.SetPipeline(pipelineWithLessImmediateData);
+        renderPassEncoder.Draw(3);
+        renderPassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(utils::RGBA8(102, 102, 153, 255), renderPass.color, 0, 0);
+    }
+
+    // Compute Pipeline
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = shaderModuleWithLessImmediateData;
+
+        wgpu::ComputePipeline pipelineWithLessImmediateData = device.CreateComputePipeline(&csDesc);
+
+        wgpu::BindGroup bindGroup = utils::MakeBindGroup(
+            device, pipelineWithLessImmediateData.GetBindGroupLayout(0), {{0, mStorageBuffer}});
+
+        std::array<uint32_t, 4> immediateData = {25, 128, 240, 255};
+        wgpu::CommandEncoder commandEncoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder computePassEncoder = commandEncoder.BeginComputePass();
+
+        computePassEncoder.SetImmediateData(0, immediateData.data(), immediateData.size() * 4);
+        computePassEncoder.SetPipeline(CreateComputePipeline());
+
+        uint32_t data = 128;
+        computePassEncoder.SetImmediateData(0, &data, 4);
+        computePassEncoder.SetPipeline(pipelineWithLessImmediateData);
+
+        computePassEncoder.SetBindGroup(0, bindGroup);
+        computePassEncoder.DispatchWorkgroups(1);
+        computePassEncoder.End();
+        wgpu::CommandBuffer commands = commandEncoder.Finish();
+        queue.Submit(1, &commands);
+
+        std::array<uint32_t, 3> expected = {128, 128, 240};
+        EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), mStorageBuffer, 0, expected.size());
+    }
+}
+
+DAWN_INSTANTIATE_TEST(ImmediateDataTests, VulkanBackend());
+
+}  // anonymous namespace
+}  // namespace dawn
diff --git a/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp b/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
index aa6982c..6588afb 100644
--- a/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
+++ b/src/dawn/tests/unittests/native/ImmediateConstantsTrackerTests.cpp
@@ -88,9 +88,9 @@
 
 // Test immediate setting update dirty bits and contents correctly.
 TEST_F(ImmediateConstantsTrackerTest, SetImmediateData) {
-    static constexpr uint32_t rangeOffset = 1u;
+    static constexpr uint32_t rangeOffset = 1u * kImmediateConstantElementByteSize;
     static constexpr uint32_t dataOffset = 2u;
-    static constexpr uint32_t userImmediateDataCount = 2u;
+    static constexpr uint32_t userImmediateDataSize = 2u * kImmediateConstantElementByteSize;
     ImmediateConstantMask expected =
         GetImmediateConstantBlockBits(0u, sizeof(UserImmediateConstants));
 
@@ -100,15 +100,13 @@
         RenderImmediateConstantsTrackerBase tracker;
         int32_t userImmediateData[] = {2, 4, -6, 8};
         tracker.SetImmediateData(rangeOffset,
-                                 reinterpret_cast<uint32_t*>(&userImmediateData[dataOffset]),
-                                 userImmediateDataCount);
+                                 reinterpret_cast<uint8_t*>(&userImmediateData[dataOffset]),
+                                 userImmediateDataSize);
         EXPECT_TRUE(tracker.GetDirtyBits() == expected);
 
-        uint32_t userImmediateDataRangeOffset =
-            userImmediateDataStartByteOffset + rangeOffset * kImmediateConstantElementByteSize;
+        uint32_t userImmediateDataRangeOffset = userImmediateDataStartByteOffset + rangeOffset;
         EXPECT_TRUE(memcmp(tracker.GetContent().Get<int32_t>(userImmediateDataRangeOffset),
-                           &userImmediateData[dataOffset],
-                           sizeof(int32_t) * userImmediateDataCount) == 0);
+                           &userImmediateData[dataOffset], userImmediateDataSize) == 0);
     }
 
     // ComputeImmediateConstantsTracker
@@ -116,15 +114,13 @@
         ComputeImmediateConstantsTrackerBase tracker;
         int32_t userImmediateData[] = {2, 4, -6, 8};
         tracker.SetImmediateData(rangeOffset,
-                                 reinterpret_cast<uint32_t*>(&userImmediateData[dataOffset]),
-                                 userImmediateDataCount);
+                                 reinterpret_cast<uint8_t*>(&userImmediateData[dataOffset]),
+                                 userImmediateDataSize);
         EXPECT_TRUE(tracker.GetDirtyBits() == expected);
 
-        uint32_t userImmediateDataRangeOffset =
-            userImmediateDataStartByteOffset + rangeOffset * kImmediateConstantElementByteSize;
+        uint32_t userImmediateDataRangeOffset = userImmediateDataStartByteOffset + rangeOffset;
         EXPECT_TRUE(memcmp(tracker.GetContent().Get<int32_t>(userImmediateDataRangeOffset),
-                           &userImmediateData[dataOffset],
-                           sizeof(int32_t) * userImmediateDataCount) == 0);
+                           &userImmediateData[dataOffset], userImmediateDataSize) == 0);
     }
 
     device.Destroy();
diff --git a/src/dawn/tests/unittests/validation/ImmediateDataTests.cpp b/src/dawn/tests/unittests/validation/ImmediateDataTests.cpp
index f710c67..407d346 100644
--- a/src/dawn/tests/unittests/validation/ImmediateDataTests.cpp
+++ b/src/dawn/tests/unittests/validation/ImmediateDataTests.cpp
@@ -32,6 +32,7 @@
 
 #include "dawn/common/NonMovable.h"
 #include "dawn/tests/unittests/validation/ValidationTest.h"
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
 #include "dawn/utils/WGPUHelpers.h"
 
 namespace dawn {
@@ -57,30 +58,30 @@
         return {wgpu::FeatureName::ChromiumExperimentalImmediateData};
     }
 
-    uint32_t GetMaxImmediateDataRangeByteSize() {
-        if (maxImmediateDataByteSize != std::numeric_limits<uint32_t>::max()) {
-            return maxImmediateDataByteSize;
-        }
-        wgpu::Limits supportedLimits = {};
-        wgpu::DawnExperimentalImmediateDataLimits immediateDataLimits = {};
-        supportedLimits.nextInChain = &immediateDataLimits;
-        device.GetLimits(&supportedLimits);
-        for (auto* chain = supportedLimits.nextInChain; chain; chain = chain->nextInChain) {
-            switch (chain->sType) {
-                case (wgpu::SType::DawnExperimentalImmediateDataLimits): {
-                    auto* t = static_cast<wgpu::DawnExperimentalImmediateDataLimits*>(
-                        supportedLimits.nextInChain);
-                    maxImmediateDataByteSize = t->maxImmediateDataRangeByteSize;
-                    break;
-                }
-                default:
-                    DAWN_UNREACHABLE();
-            }
-        }
-        return maxImmediateDataByteSize;
+    wgpu::BindGroupLayout CreateBindGroupLayout() {
+        wgpu::BindGroupLayoutEntry entries[1];
+        entries[0].binding = 0;
+        entries[0].visibility = wgpu::ShaderStage::Compute;
+        entries[0].buffer.type = wgpu::BufferBindingType::Storage;
+
+        wgpu::BindGroupLayoutDescriptor bindGroupLayoutDesc;
+        bindGroupLayoutDesc.entryCount = 1;
+        bindGroupLayoutDesc.entries = entries;
+
+        return device.CreateBindGroupLayout(&bindGroupLayoutDesc);
     }
 
-    uint32_t maxImmediateDataByteSize = std::numeric_limits<uint32_t>::max();
+    wgpu::PipelineLayout CreatePipelineLayout(uint32_t requiredImmediateDataRangeByteSize) {
+        wgpu::BindGroupLayout bindGroupLayout = CreateBindGroupLayout();
+
+        wgpu::PipelineLayoutDescriptor pipelineLayoutDesc;
+        pipelineLayoutDesc.bindGroupLayoutCount = 1;
+        pipelineLayoutDesc.bindGroupLayouts = &bindGroupLayout;
+        pipelineLayoutDesc.immediateDataRangeByteSize = requiredImmediateDataRangeByteSize;
+        return device.CreatePipelineLayout(&pipelineLayoutDesc);
+    }
+
+    wgpu::ShaderModule mShaderModule;
 };
 
 // Check that non-zero immediateDataRangeByteSize is possible with feature enabled and size must
@@ -91,16 +92,15 @@
     wgpu::PipelineLayoutDescriptor desc;
     desc.bindGroupLayoutCount = 0;
 
-    uint32_t maxImmediateDataRangeByteSize = GetMaxImmediateDataRangeByteSize();
     // Success case with valid immediateDataRangeByteSize.
     {
-        desc.immediateDataRangeByteSize = maxImmediateDataRangeByteSize;
+        desc.immediateDataRangeByteSize = kMaxImmediateDataBytes;
         device.CreatePipelineLayout(&desc);
     }
 
     // Failed case with invalid immediateDataRangeByteSize that exceed limits.
     {
-        desc.immediateDataRangeByteSize = maxImmediateDataRangeByteSize + 1;
+        desc.immediateDataRangeByteSize = kMaxImmediateDataBytes + 1;
         ASSERT_DEVICE_ERROR(device.CreatePipelineLayout(&desc));
     }
 }
@@ -151,14 +151,12 @@
 TEST_F(ImmediateDataTest, ValidateSetImmediateDataOOB) {
     DAWN_SKIP_TEST_IF(!device.HasFeature(wgpu::FeatureName::ChromiumExperimentalImmediateData));
 
-    uint32_t maxImmediateDataRangeByteSize = GetMaxImmediateDataRangeByteSize();
-
     // Success cases
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        std::vector<uint32_t> data(maxImmediateDataRangeByteSize / 4, 0);
+        std::vector<uint32_t> data(kMaxImmediateDataBytes / 4, 0);
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
-        computePass.SetImmediateData(0, data.data(), maxImmediateDataRangeByteSize);
+        computePass.SetImmediateData(0, data.data(), kMaxImmediateDataBytes);
         computePass.End();
         encoder.Finish();
     }
@@ -166,7 +164,7 @@
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
-        computePass.SetImmediateData(maxImmediateDataRangeByteSize, nullptr, 0);
+        computePass.SetImmediateData(kMaxImmediateDataBytes, nullptr, 0);
         computePass.End();
         encoder.Finish();
     }
@@ -175,7 +173,7 @@
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
         uint32_t data = 0;
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
-        computePass.SetImmediateData(maxImmediateDataRangeByteSize - 4, &data, 4);
+        computePass.SetImmediateData(kMaxImmediateDataBytes - 4, &data, 4);
         computePass.End();
         encoder.Finish();
     }
@@ -183,7 +181,7 @@
     // Failed case with offset oob
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        uint32_t offset = maxImmediateDataRangeByteSize + 4;
+        uint32_t offset = kMaxImmediateDataBytes + 4;
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
         computePass.SetImmediateData(offset, nullptr, 0);
         computePass.End();
@@ -193,7 +191,7 @@
     // Failed cases with size oob
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        uint32_t size = maxImmediateDataRangeByteSize + 4;
+        uint32_t size = kMaxImmediateDataBytes + 4;
         std::vector<uint32_t> data(size / 4, 0);
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
         computePass.SetImmediateData(0, data.data(), size);
@@ -204,13 +202,205 @@
     // Failed cases with offset + size oob
     {
         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        uint32_t offset = maxImmediateDataRangeByteSize;
+        uint32_t offset = kMaxImmediateDataBytes;
         uint32_t data[] = {0};
         wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
         computePass.SetImmediateData(offset, data, 4);
         computePass.End();
         ASSERT_DEVICE_ERROR(encoder.Finish());
     }
+
+    // Failed case with super large offset + size oob but looping back to zero
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        uint32_t offset = std::numeric_limits<uint32_t>::max() - 3;
+        uint32_t data[] = {0};
+        wgpu::ComputePassEncoder computePass = encoder.BeginComputePass();
+        computePass.SetImmediateData(offset, data, 4);
+        computePass.End();
+        ASSERT_DEVICE_ERROR(encoder.Finish());
+    }
+}
+
+// Check that pipelineLayout immediate data bytes compatible with shaders.
+TEST_F(ImmediateDataTest, ValidatePipelineLayoutImmediateDataBytesAndShaders) {
+    DAWN_SKIP_TEST_IF(!device.HasFeature(wgpu::FeatureName::ChromiumExperimentalImmediateData));
+    constexpr uint32_t kShaderImmediateDataBytes = 12u;
+    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
+        enable chromium_experimental_push_constant;
+        var<push_constant> fragmentConstants: vec3f;
+        var<push_constant> computeConstants: vec3u;
+        @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4f {
+            const pos = array(
+                vec2( 1.0, -1.0),
+                vec2(-1.0, -1.0),
+                vec2( 0.0,  1.0),
+            );
+            return vec4(pos[VertexIndex], 0.0, 1.0);
+        }
+
+        // to reuse the same pipeline layout
+        @fragment fn fsMain() -> @location(0) vec4f {
+            return vec4f(fragmentConstants, 1.0);
+        }
+
+
+        @group(0) @binding(0) var<storage, read_write> output : vec3u;
+
+        @compute @workgroup_size(1, 1, 1)
+        fn csMain() {
+            output = computeConstants;
+        })");
+
+    // Success cases
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = shaderModule;
+        pipelineDescriptor.cFragment.module = shaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        pipelineDescriptor.layout = CreatePipelineLayout(kShaderImmediateDataBytes);
+        device.CreateRenderPipeline(&pipelineDescriptor);
+    }
+
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = shaderModule;
+        csDesc.layout = CreatePipelineLayout(kShaderImmediateDataBytes);
+
+        device.CreateComputePipeline(&csDesc);
+    }
+
+    // Default layout
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = shaderModule;
+        pipelineDescriptor.cFragment.module = shaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        device.CreateRenderPipeline(&pipelineDescriptor);
+    }
+
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = shaderModule;
+
+        device.CreateComputePipeline(&csDesc);
+    }
+
+    // Failed case with fragment shader requires more immediate data.
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = shaderModule;
+        pipelineDescriptor.cFragment.module = shaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        pipelineDescriptor.layout = CreatePipelineLayout(kShaderImmediateDataBytes - 4);
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&pipelineDescriptor));
+    }
+
+    // Failed cases with compute shader requires more immediate data.
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = shaderModule;
+        csDesc.layout = CreatePipelineLayout(kShaderImmediateDataBytes - 4);
+
+        ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+    }
+}
+
+// Check that default pipelineLayout has too many immediate data bytes .
+TEST_F(ImmediateDataTest, ValidateDefaultPipelineLayout) {
+    DAWN_SKIP_TEST_IF(!device.HasFeature(wgpu::FeatureName::ChromiumExperimentalImmediateData));
+    wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
+        enable chromium_experimental_push_constant;
+        var<push_constant> fragmentConstants: vec4f;
+        var<push_constant> computeConstants: vec4u;
+        @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4f {
+            const pos = array(
+                vec2( 1.0, -1.0),
+                vec2(-1.0, -1.0),
+                vec2( 0.0,  1.0),
+            );
+            return vec4(pos[VertexIndex], 0.0, 1.0);
+        }
+
+        // to reuse the same pipeline layout
+        @fragment fn fsMain() -> @location(0) vec4f {
+            return fragmentConstants;
+        }
+
+        @group(0) @binding(0) var<storage, read_write> output : vec4u;
+
+        @compute @workgroup_size(1, 1, 1)
+        fn csMain() {
+            output = computeConstants;
+        })");
+
+    wgpu::ShaderModule oobShaderModule = utils::CreateShaderModule(device, R"(
+            enable chromium_experimental_push_constant;
+            struct FragmentConstants {
+                constants: vec4f,
+                constantsOOB: f32,
+            };
+
+            struct ComputeConstants {
+                constants: vec4u,
+                constantsOOB: u32,
+            };
+            var<push_constant> fragmentConstants: FragmentConstants;
+            var<push_constant> computeConstants: ComputeConstants;
+            @vertex fn vsMain(@builtin(vertex_index) VertexIndex : u32) -> @builtin(position) vec4f {
+                const pos = array(
+                    vec2( 1.0, -1.0),
+                    vec2(-1.0, -1.0),
+                    vec2( 0.0,  1.0),
+                );
+                return vec4(pos[VertexIndex], 0.0, 1.0);
+            }
+
+            // to reuse the same pipeline layout
+            @fragment fn fsMain() -> @location(0) vec4f {
+                return vec4f(fragmentConstants.constants.x + fragmentConstants.constantsOOB,
+                             fragmentConstants.constants.yzw);
+            }
+
+            @group(0) @binding(0) var<storage, read_write> output : vec4u;
+
+            @compute @workgroup_size(1, 1, 1)
+            fn csMain() {
+                output = vec4u(computeConstants.constants.x + computeConstants.constantsOOB,
+                               computeConstants.constants.yzw);
+            })");
+
+    // Success cases
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = shaderModule;
+        pipelineDescriptor.cFragment.module = shaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        device.CreateRenderPipeline(&pipelineDescriptor);
+    }
+
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = shaderModule;
+
+        device.CreateComputePipeline(&csDesc);
+    }
+
+    // Using too many immediate data cases
+    {
+        utils::ComboRenderPipelineDescriptor pipelineDescriptor;
+        pipelineDescriptor.vertex.module = oobShaderModule;
+        pipelineDescriptor.cFragment.module = oobShaderModule;
+        pipelineDescriptor.cFragment.targetCount = 1;
+        ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&pipelineDescriptor));
+    }
+
+    {
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.compute.module = oobShaderModule;
+
+        ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&csDesc));
+    }
 }
 
 }  // anonymous namespace