Enforce per-dimension dispatch size limits

Note that this is for direct dispatch calls only. Indirect dispatch
calls are still not validated.

Bug: dawn:1006
Change-Id: I061c15208a01dfb803923823ba4afd38667cad22
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59122
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp
index 88c7e69..dcc5df8 100644
--- a/src/dawn_native/ComputePassEncoder.cpp
+++ b/src/dawn_native/ComputePassEncoder.cpp
@@ -25,6 +25,18 @@
 
 namespace dawn_native {
 
+    namespace {
+
+        MaybeError ValidatePerDimensionDispatchSizeLimit(uint32_t size) {
+            if (size > kMaxComputePerDimensionDispatchSize) {
+                return DAWN_VALIDATION_ERROR("Dispatch size exceeds defined limits");
+            }
+
+            return {};
+        }
+
+    }  // namespace
+
     ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                            CommandEncoder* commandEncoder,
                                            EncodingContext* encodingContext)
@@ -63,6 +75,9 @@
         mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
             if (IsValidationEnabled()) {
                 DAWN_TRY(mCommandBufferState.ValidateCanDispatch());
+                DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(x));
+                DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(y));
+                DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(z));
             }
 
             // Record the synchronization scope for Dispatch, which is just the current bindgroups.
diff --git a/src/tests/unittests/validation/ComputeValidationTests.cpp b/src/tests/unittests/validation/ComputeValidationTests.cpp
index 8135997..6d66b87 100644
--- a/src/tests/unittests/validation/ComputeValidationTests.cpp
+++ b/src/tests/unittests/validation/ComputeValidationTests.cpp
@@ -12,9 +12,72 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "common/Constants.h"
 #include "tests/unittests/validation/ValidationTest.h"
-
-class ComputeValidationTest : public ValidationTest {};
+#include "utils/WGPUHelpers.h"
 
 // TODO(cwallez@chromium.org): Add a regression test for Disptach validation trying to acces the
 // input state.
+
+class ComputeValidationTest : public ValidationTest {
+  protected:
+    void SetUp() override {
+        ValidationTest::SetUp();
+
+        wgpu::ShaderModule computeModule = utils::CreateShaderModule(device, R"(
+            [[stage(compute), workgroup_size(1)]] fn main() {
+            })");
+
+        // Set up compute pipeline
+        wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, nullptr);
+
+        wgpu::ComputePipelineDescriptor csDesc;
+        csDesc.layout = pl;
+        csDesc.compute.module = computeModule;
+        csDesc.compute.entryPoint = "main";
+        pipeline = device.CreateComputePipeline(&csDesc);
+    }
+
+    void TestDispatch(uint32_t x, uint32_t y, uint32_t z) {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.Dispatch(x, y, z);
+        pass.EndPass();
+        encoder.Finish();
+    }
+
+    wgpu::ComputePipeline pipeline;
+};
+
+// Check that 1x1x1 dispatch is OK.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_SmallestValid) {
+    TestDispatch(1, 1, 1);
+}
+
+// Check that the largest allowed dispatch is OK.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_LargestValid) {
+    constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize;
+    TestDispatch(kMax, kMax, kMax);
+}
+
+// Check that exceeding the maximum on the X dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidX) {
+    ASSERT_DEVICE_ERROR(TestDispatch(kMaxComputePerDimensionDispatchSize + 1, 1, 1));
+}
+
+// Check that exceeding the maximum on the Y dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidY) {
+    ASSERT_DEVICE_ERROR(TestDispatch(1, kMaxComputePerDimensionDispatchSize + 1, 1));
+}
+
+// Check that exceeding the maximum on the Z dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidZ) {
+    ASSERT_DEVICE_ERROR(TestDispatch(1, 1, kMaxComputePerDimensionDispatchSize + 1));
+}
+
+// Check that exceeding the maximum on all dimensions results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidAll) {
+    constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize;
+    ASSERT_DEVICE_ERROR(TestDispatch(kMax + 1, kMax + 1, kMax + 1));
+}