Validate workgroup size and storage requirements

We define hard limits on these attributes for compute stages. This
enforces them.

BUG: dawn:322
Change-Id: I9b279774e877b5d40d912cb9f812f23d61c20a42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/56806
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index c02e809..36d1c27 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -14,6 +14,7 @@
 
 #include "dawn_native/ShaderModule.h"
 
+#include "common/Constants.h"
 #include "common/HashUtils.h"
 #include "common/VertexFormatUtils.h"
 #include "dawn_native/BindGroupLayout.h"
@@ -903,6 +904,50 @@
                 DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
 
                 if (metadata->stage == SingleShaderStage::Compute) {
+                    if (entryPoint.workgroup_size_x > kMaxComputeWorkgroupSizeX) {
+                        errorStream << "Workgroup X dimension exceeds maximum allowed:"
+                                    << entryPoint.workgroup_size_x << " > "
+                                    << kMaxComputeWorkgroupSizeX;
+                        return DAWN_VALIDATION_ERROR(errorStream.str());
+                    }
+                    if (entryPoint.workgroup_size_y > kMaxComputeWorkgroupSizeY) {
+                        errorStream << "Workgroup Y dimension exceeds maximum allowed: "
+                                    << entryPoint.workgroup_size_y << " > "
+                                    << kMaxComputeWorkgroupSizeY;
+                        return DAWN_VALIDATION_ERROR(errorStream.str());
+                    }
+                    if (entryPoint.workgroup_size_z > kMaxComputeWorkgroupSizeZ) {
+                        errorStream << "Workgroup Z dimension exceeds maximum allowed: "
+                                    << entryPoint.workgroup_size_z << " > "
+                                    << kMaxComputeWorkgroupSizeZ;
+                        return DAWN_VALIDATION_ERROR(errorStream.str());
+                    }
+
+                    // Dimensions have already been validated against their individual limits above.
+                    // This assertion ensures that the product of such limited dimensions cannot
+                    // possibly overflow a uint32_t.
+                    static_assert(static_cast<uint64_t>(kMaxComputeWorkgroupSizeX) *
+                                          kMaxComputeWorkgroupSizeY * kMaxComputeWorkgroupSizeZ <=
+                                      std::numeric_limits<uint32_t>::max(),
+                                  "Per-dimension workgroup size limits are too high");
+                    uint32_t num_invocations = entryPoint.workgroup_size_x *
+                                               entryPoint.workgroup_size_y *
+                                               entryPoint.workgroup_size_z;
+                    if (num_invocations > kMaxComputeWorkgroupInvocations) {
+                        errorStream << "Number of workgroup invocations exceeds maximum allowed: "
+                                    << num_invocations << " > " << kMaxComputeWorkgroupInvocations;
+                        return DAWN_VALIDATION_ERROR(errorStream.str());
+                    }
+
+                    const size_t workgroup_storage_size =
+                        inspector.GetWorkgroupStorageSize(entryPoint.name);
+                    if (workgroup_storage_size > kMaxComputeWorkgroupStorageSize) {
+                        errorStream << "Workgroup shared storage size for " << entryPoint.name
+                                    << " exceeds the maximum allowed: " << workgroup_storage_size
+                                    << " > " << kMaxComputeWorkgroupStorageSize;
+                        return DAWN_VALIDATION_ERROR(errorStream.str());
+                    }
+
                     metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
                     metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
                     metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
diff --git a/src/tests/perf_tests/ShaderRobustnessPerf.cpp b/src/tests/perf_tests/ShaderRobustnessPerf.cpp
index 36d1e25..a1fcf11 100644
--- a/src/tests/perf_tests/ShaderRobustnessPerf.cpp
+++ b/src/tests/perf_tests/ShaderRobustnessPerf.cpp
@@ -17,7 +17,7 @@
 #include "utils/WGPUHelpers.h"
 
 namespace {
-    constexpr uint32_t kTileSize = 64u;
+    constexpr uint32_t kTileSize = 32u;
 
     const std::string& kMatMulFloatHeader = R"(
         [[block]] struct Uniforms {
@@ -62,18 +62,18 @@
 
         let RowPerThread : u32 = 4u;
         let ColPerThread : u32 = 4u;
-        let TileAOuter : u32 = 64u;
-        let TileBOuter : u32 = 64u;
-        let TileInner : u32 = 64u;)";
+        let TileAOuter : u32 = 32u;
+        let TileBOuter : u32 = 32u;
+        let TileInner : u32 = 32u;)";
 
     const std::string& kMatMulFloatSharedArray1D = R"(
-        var<workgroup> mm_Asub : array<f32, 4096>;
-        var<workgroup> mm_Bsub : array<f32, 4096>;)";
+        var<workgroup> mm_Asub : array<f32, 1024>;
+        var<workgroup> mm_Bsub : array<f32, 1024>;)";
     const std::string& kMatMulFloatSharedArray2D = R"(
-        var<workgroup> mm_Asub : array<array<f32, 64>, 64>;
-        var<workgroup> mm_Bsub : array<array<f32, 64>, 64>;)";
+        var<workgroup> mm_Asub : array<array<f32, 32>, 32>;
+        var<workgroup> mm_Bsub : array<array<f32, 32>, 32>;)";
     const std::string& kMatMulFloatBodyPart1 = R"(
-        [[stage(compute), workgroup_size(16, 16, 1)]]
+        [[stage(compute), workgroup_size(8, 8, 1)]]
         fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
                 [[builtin(global_invocation_id)]] global_id  : vec3<u32>) {
             let tileRow : u32 = local_id.y * RowPerThread;
@@ -95,9 +95,9 @@
                 acc[index] = 0.;
             }
 
-            let ColPerThreadA : u32 = TileInner / 16u;
+            let ColPerThreadA : u32 = TileInner / 8u;
             let tileColA : u32 = local_id.x * ColPerThreadA;
-            let RowPerThreadB : u32 = TileInner / 16u;
+            let RowPerThreadB : u32 = TileInner / 8u;
             let tileRowB : u32 = local_id.y * RowPerThreadB;
 
             // Loop over shared dimension.
@@ -229,17 +229,16 @@
 
         let RowPerThread : u32 = 4u;
         let ColPerThread : u32 = 4u;
-        let TileAOuter : u32 = 64u;
-        let TileBOuter : u32 = 64u;
-        let TileInner : u32 = 64u;)";
+        let TileOuter : u32 = 32u;
+        let TileInner : u32 = 32u;)";
     const std::string& kMatMulVec4SharedArray1D = R"(
-        var<workgroup> mm_Asub : array<vec4<f32>, 1024>;
-        var<workgroup> mm_Bsub : array<vec4<f32>, 1024>;)";
+        var<workgroup> mm_Asub : array<vec4<f32>, 256>;
+        var<workgroup> mm_Bsub : array<vec4<f32>, 256>;)";
     const std::string& kMatMulVec4SharedArray2D = R"(
-        var<workgroup> mm_Asub : array<array<vec4<f32>, 16>, 64>;
-        var<workgroup> mm_Bsub : array<array<vec4<f32>, 16>, 64>;)";
+        var<workgroup> mm_Asub : array<array<vec4<f32>, 8>, 32>;
+        var<workgroup> mm_Bsub : array<array<vec4<f32>, 8>, 32>;)";
     const std::string& kMatMulVec4BodyPart1 = R"(
-        [[stage(compute), workgroup_size(16, 16, 1)]]
+        [[stage(compute), workgroup_size(8, 8, 1)]]
         fn main([[builtin(local_invocation_id)]] local_id : vec3<u32>,
                 [[builtin(global_invocation_id)]] global_id  : vec3<u32>) {
             let tileRow : u32 = local_id.y * RowPerThread;
@@ -262,7 +261,7 @@
             }
 
             var globalColA : u32 = tileCol;
-            let RowPerThreadB : u32 = TileInner / 16u;
+            let RowPerThreadB : u32 = TileInner / 8u;
             let tileRowB : u32 = local_id.y * RowPerThreadB;
 
             // Loop over shared dimension.
@@ -281,7 +280,7 @@
                 for (var innerRow : u32 = 0u; innerRow < RowPerThreadB; innerRow = innerRow + 1u) {
                     let inputRow : u32 = tileRowB + innerRow;
                     let inputCol : u32 = tileCol;
-                    let index : u32 = inputRow * TileBOuter / ColPerThread + inputCol;
+                    let index : u32 = inputRow * TileOuter / ColPerThread + inputCol;
                     mm_Bsub[index] = mm_readB(t * TileInner + inputRow, globalCol);;
                 }
 
@@ -289,10 +288,10 @@
 
                 // Compute acc values for a single thread.
                 for (var k : u32 = 0u; k < TileInner / ColPerThread; k = k + 1u) {
-                    BCached[0] = mm_Bsub[(k * ColPerThread) * (TileBOuter / ColPerThread) + tileCol];
-                    BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileBOuter / ColPerThread) + tileCol];
-                    BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileBOuter / ColPerThread) + tileCol];
-                    BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileBOuter / ColPerThread) + tileCol];
+                    BCached[0] = mm_Bsub[(k * ColPerThread) * (TileOuter / ColPerThread) + tileCol];
+                    BCached[1] = mm_Bsub[(k * ColPerThread + 1u) * (TileOuter / ColPerThread) + tileCol];
+                    BCached[2] = mm_Bsub[(k * ColPerThread + 2u) * (TileOuter / ColPerThread) + tileCol];
+                    BCached[3] = mm_Bsub[(k * ColPerThread + 3u) * (TileOuter / ColPerThread) + tileCol];
 
                     for (var i : u32 = 0u; i < RowPerThread; i = i + 1u) {
                         ACached = mm_Asub[(tileRow + i) * (TileInner / ColPerThread) + k];)";
diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
index a83ab88..58fff4c 100644
--- a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
+++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp
@@ -282,3 +282,64 @@
         ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, fragmentShader.c_str()));
     }
 }
+
+// Tests that we validate workgroup size limits.
+TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) {
+    DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
+
+    auto MakeShaderWithWorkgroupSize = [this](uint32_t x, uint32_t y, uint32_t z) {
+        std::ostringstream ss;
+        ss << "[[stage(compute), workgroup_size(" << x << "," << y << "," << z
+           << ")]] fn main() {}";
+        utils::CreateShaderModule(device, ss.str().c_str());
+    };
+
+    MakeShaderWithWorkgroupSize(1, 1, 1);
+    MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX, 1, 1);
+    MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY, 1);
+    MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ);
+
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(kMaxComputeWorkgroupSizeX + 1, 1, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, kMaxComputeWorkgroupSizeY + 1, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(1, 1, kMaxComputeWorkgroupSizeZ + 1));
+
+    // No individual dimension exceeds its limit, but the combined size should definitely exceed the
+    // total invocation limit.
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupSize(
+        kMaxComputeWorkgroupSizeX, kMaxComputeWorkgroupSizeY, kMaxComputeWorkgroupSizeZ));
+}
+
+// Tests that we validate workgroup storage size limits.
+TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) {
+    DAWN_SKIP_TEST_IF(!HasToggleEnabled("use_tint_generator"));
+
+    constexpr uint32_t kVec4Size = 16;
+    constexpr uint32_t kMaxVec4Count = kMaxComputeWorkgroupStorageSize / kVec4Size;
+    constexpr uint32_t kMat4Size = 64;
+    constexpr uint32_t kMaxMat4Count = kMaxComputeWorkgroupStorageSize / kMat4Size;
+
+    auto MakeShaderWithWorkgroupStorage = [this](uint32_t vec4_count, uint32_t mat4_count) {
+        std::ostringstream ss;
+        std::ostringstream body;
+        if (vec4_count > 0) {
+            ss << "var<workgroup> vec4_data: array<vec4<f32>, " << vec4_count << ">;";
+            body << "ignore(vec4_data);";
+        }
+        if (mat4_count > 0) {
+            ss << "var<workgroup> mat4_data: array<mat4x4<f32>, " << mat4_count << ">;";
+            body << "ignore(mat4_data);";
+        }
+        ss << "[[stage(compute), workgroup_size(1)]] fn main() { " << body.str() << " }";
+        utils::CreateShaderModule(device, ss.str().c_str());
+    };
+
+    MakeShaderWithWorkgroupStorage(1, 1);
+    MakeShaderWithWorkgroupStorage(kMaxVec4Count, 0);
+    MakeShaderWithWorkgroupStorage(0, kMaxMat4Count);
+    MakeShaderWithWorkgroupStorage(kMaxVec4Count - 4, 1);
+    MakeShaderWithWorkgroupStorage(4, kMaxMat4Count - 1);
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count + 1, 0));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(kMaxVec4Count - 3, 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(0, kMaxMat4Count + 1));
+    ASSERT_DEVICE_ERROR(MakeShaderWithWorkgroupStorage(4, kMaxMat4Count));
+}