Enforce per-dimension dispatch size limits
Note that this is for direct dispatch calls only. Indirect dispatch
calls are still not validated.
Bug: dawn:1006
Change-Id: I061c15208a01dfb803923823ba4afd38667cad22
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/59122
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp
index 88c7e69..dcc5df8 100644
--- a/src/dawn_native/ComputePassEncoder.cpp
+++ b/src/dawn_native/ComputePassEncoder.cpp
@@ -25,6 +25,18 @@
namespace dawn_native {
+ namespace {
+
+ MaybeError ValidatePerDimensionDispatchSizeLimit(uint32_t size) {
+ if (size > kMaxComputePerDimensionDispatchSize) {
+ return DAWN_VALIDATION_ERROR("Dispatch size exceeds defined limits");
+ }
+
+ return {};
+ }
+
+ } // namespace
+
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext)
@@ -63,6 +75,9 @@
mEncodingContext->TryEncode(this, [&](CommandAllocator* allocator) -> MaybeError {
if (IsValidationEnabled()) {
DAWN_TRY(mCommandBufferState.ValidateCanDispatch());
+ DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(x));
+ DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(y));
+ DAWN_TRY(ValidatePerDimensionDispatchSizeLimit(z));
}
// Record the synchronization scope for Dispatch, which is just the current bindgroups.
diff --git a/src/tests/unittests/validation/ComputeValidationTests.cpp b/src/tests/unittests/validation/ComputeValidationTests.cpp
index 8135997..6d66b87 100644
--- a/src/tests/unittests/validation/ComputeValidationTests.cpp
+++ b/src/tests/unittests/validation/ComputeValidationTests.cpp
@@ -12,9 +12,72 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "common/Constants.h"
#include "tests/unittests/validation/ValidationTest.h"
-
-class ComputeValidationTest : public ValidationTest {};
+#include "utils/WGPUHelpers.h"
// TODO(cwallez@chromium.org): Add a regression test for Disptach validation trying to acces the
// input state.
+
+class ComputeValidationTest : public ValidationTest {
+ protected:
+ void SetUp() override {
+ ValidationTest::SetUp();
+
+ wgpu::ShaderModule computeModule = utils::CreateShaderModule(device, R"(
+ [[stage(compute), workgroup_size(1)]] fn main() {
+ })");
+
+ // Set up compute pipeline
+ wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, nullptr);
+
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.layout = pl;
+ csDesc.compute.module = computeModule;
+ csDesc.compute.entryPoint = "main";
+ pipeline = device.CreateComputePipeline(&csDesc);
+ }
+
+ void TestDispatch(uint32_t x, uint32_t y, uint32_t z) {
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+ pass.SetPipeline(pipeline);
+ pass.Dispatch(x, y, z);
+ pass.EndPass();
+ encoder.Finish();
+ }
+
+ wgpu::ComputePipeline pipeline;
+};
+
+// Check that 1x1x1 dispatch is OK.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_SmallestValid) {
+ TestDispatch(1, 1, 1);
+}
+
+// Check that the largest allowed dispatch is OK.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_LargestValid) {
+ constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize;
+ TestDispatch(kMax, kMax, kMax);
+}
+
+// Check that exceeding the maximum on the X dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidX) {
+ ASSERT_DEVICE_ERROR(TestDispatch(kMaxComputePerDimensionDispatchSize + 1, 1, 1));
+}
+
+// Check that exceeding the maximum on the Y dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidY) {
+ ASSERT_DEVICE_ERROR(TestDispatch(1, kMaxComputePerDimensionDispatchSize + 1, 1));
+}
+
+// Check that exceeding the maximum on the Z dimension results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidZ) {
+ ASSERT_DEVICE_ERROR(TestDispatch(1, 1, kMaxComputePerDimensionDispatchSize + 1));
+}
+
+// Check that exceeding the maximum on all dimensions results in validation failure.
+TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidAll) {
+ constexpr uint32_t kMax = kMaxComputePerDimensionDispatchSize;
+ ASSERT_DEVICE_ERROR(TestDispatch(kMax + 1, kMax + 1, kMax + 1));
+}