[tint][ir][val] Check numeric type of params to @workgroup_size

Refactors checks for @workgroup_size into a separate utility function,
and implements rules about numeric types for params.

A follow up CL will implement rules related to params being
const/override expressions.

Issue: 376624999

Change-Id: Ib7cf63f2dd191bba95c2e94d45313584311c9342
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/213296
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 5c3b6f5..2221117 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -79,6 +79,7 @@
 #include "src/tint/lang/core/ir/var.h"
 #include "src/tint/lang/core/type/bool.h"
 #include "src/tint/lang/core/type/f32.h"
+#include "src/tint/lang/core/type/i32.h"
 #include "src/tint/lang/core/type/i8.h"
 #include "src/tint/lang/core/type/memory_view.h"
 #include "src/tint/lang/core/type/pointer.h"
@@ -883,6 +884,10 @@
     /// @param func the function to validate
     void CheckFunction(const Function* func);
 
+    /// Validates the workgroup_size attribute for a given function
+    /// @param func the function to validate
+    void CheckWorkgroupSize(const Function* func);
+
     /// Validates the specific function as a vertex entry point
     /// @param ep the function to validate
     void CheckVertexEntryPoint(const Function* ep);
@@ -1903,21 +1908,14 @@
         }
     }
 
-    if (func->Stage() == Function::PipelineStage::kCompute) {
-        if (DAWN_UNLIKELY(!func->WorkgroupSize().has_value())) {
-            AddError(func) << "compute entry point requires workgroup size attribute";
-        }
+    CheckWorkgroupSize(func);
 
+    if (func->Stage() == Function::PipelineStage::kCompute) {
         if (DAWN_UNLIKELY(func->ReturnType() && !func->ReturnType()->Is<core::type::Void>())) {
             AddError(func) << "compute entry point must not have a return type";
         }
     }
 
-    if (DAWN_UNLIKELY(func->Stage() != Function::PipelineStage::kCompute &&
-                      func->WorkgroupSize().has_value())) {
-        AddError(func) << "workgroup size attribute only valid on compute entry point";
-    }
-
     if (func->Stage() == Function::PipelineStage::kFragment) {
         if (!func->ReturnType()->Is<core::type::Void>()) {
             CheckFunctionReturnAttributes(
@@ -1977,6 +1975,50 @@
     ProcessTasks();
 }
 
+void Validator::CheckWorkgroupSize(const Function* func) {
+    if (func->Stage() != Function::PipelineStage::kCompute) {
+        if (func->WorkgroupSize().has_value()) {
+            AddError(func) << "@workgroup_size only valid on compute entry point";
+        }
+        return;
+    }
+
+    if (!func->WorkgroupSize().has_value()) {
+        AddError(func) << "compute entry point requires @workgroup_size";
+        return;
+    }
+
+    auto workgroup_sizes = func->WorkgroupSize().value();
+    // The number parameters cannot be checked here, since it is stored internally as a 3 element
+    // array, so will always have 3 elements at this point.
+    TINT_ASSERT(workgroup_sizes.size() == 3);
+
+    std::optional<const core::type::Type*> sizes_ty;
+    for (auto* size : workgroup_sizes) {
+        if (!size || !size->Type()) {
+            AddError(func) << "a @workgroup_size param is undefined or missing a type";
+            return;
+        }
+
+        auto* ty = size->Type();
+        if (!ty->IsAnyOf<core::type::I32, core::type::U32>()) {
+            AddError(func) << "@workgroup_size params must be an i32 or u32";
+            return;
+        }
+
+        if (!sizes_ty.has_value()) {
+            sizes_ty = ty;
+        }
+
+        if (sizes_ty != ty) {
+            AddError(func) << "@workgroup_size params must be all i32s or all u32s";
+            return;
+        }
+
+        // TODO(376624999): Implement enforcing rules around override and constant expressions
+    }
+}
+
 void Validator::CheckVertexEntryPoint(const Function* ep) {
     bool contains_position = IsPositionPresent(ep->ReturnAttributes(), ep->ReturnType());
 
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 3060f7c..b2fccba 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -835,26 +835,6 @@
 )");
 }
 
-TEST_F(IR_ValidatorTest, Function_MissingWorkgroupSize) {
-    auto* f = b.Function("f", ty.void_(), Function::PipelineStage::kCompute);
-    b.Append(f->Block(), [&] { b.Return(f); });
-
-    auto res = ir::Validate(mod);
-    ASSERT_NE(res, Success);
-    EXPECT_EQ(res.Failure().reason.Str(),
-              R"(:1:1 error: compute entry point requires workgroup size attribute
-%f = @compute func():void {
-^^
-
-note: # Disassembly
-%f = @compute func():void {
-  $B1: {
-    ret
-  }
-}
-)");
-}
-
 TEST_F(IR_ValidatorTest, Function_UnnamedEntryPoint) {
     auto* f = b.Function(ty.void_(), ir::Function::PipelineStage::kCompute);
     f->SetWorkgroupSize({b.Constant(1_u), b.Constant(1_u), b.Constant(1_u)});
@@ -946,7 +926,7 @@
 
 TEST_F(IR_ValidatorTest, Function_Compute_NonVoidReturn) {
     auto* f = b.Function("my_func", ty.f32(), core::ir::Function::PipelineStage::kCompute);
-    f->SetWorkgroupSize(b.Constant(0_u), b.Constant(0_u), b.Constant(0_u));
+    f->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
 
     b.Append(f->Block(), [&] { b.Unreachable(); });
 
@@ -954,11 +934,11 @@
     ASSERT_NE(res, Success);
     EXPECT_EQ(res.Failure().reason.Str(),
               R"(:1:1 error: compute entry point must not have a return type
-%my_func = @compute @workgroup_size(0u, 0u, 0u) func():f32 {
+%my_func = @compute @workgroup_size(1u, 1u, 1u) func():f32 {
 ^^^^^^^^
 
 note: # Disassembly
-%my_func = @compute @workgroup_size(0u, 0u, 0u) func():f32 {
+%my_func = @compute @workgroup_size(1u, 1u, 1u) func():f32 {
   $B1: {
     unreachable
   }
@@ -966,7 +946,27 @@
 )");
 }
 
-TEST_F(IR_ValidatorTest, Function_WorkspaceSizeOnlyOnCompute) {
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_MissingOnCompute) {
+    auto* f = b.Function("f", ty.void_(), Function::PipelineStage::kCompute);
+    b.Append(f->Block(), [&] { b.Return(f); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:1 error: compute entry point requires @workgroup_size
+%f = @compute func():void {
+^^
+
+note: # Disassembly
+%f = @compute func():void {
+  $B1: {
+    ret
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_NonCompute) {
     auto* f = FragmentEntryPoint();
     f->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
 
@@ -975,7 +975,7 @@
     auto res = ir::Validate(mod);
     ASSERT_NE(res, Success);
     EXPECT_EQ(res.Failure().reason.Str(),
-              R"(:1:1 error: workgroup size attribute only valid on compute entry point
+              R"(:1:1 error: @workgroup_size only valid on compute entry point
 %f = @fragment @workgroup_size(1u, 1u, 1u) func():void {
 ^^
 
@@ -988,6 +988,72 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamUndefined) {
+    auto* f = ComputeEntryPoint();
+    f->SetWorkgroupSize({nullptr, b.Constant(2_u), b.Constant(3_u)});
+
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:1 error: a @workgroup_size param is undefined or missing a type
+%f = @compute @workgroup_size(undef, 2u, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(undef, 2u, 3u) func():void {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamWrongType) {
+    auto* f = ComputeEntryPoint();
+    f->SetWorkgroupSize({b.Constant(1_f), b.Constant(2_u), b.Constant(3_u)});
+
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:1 error: @workgroup_size params must be an i32 or u32
+%f = @compute @workgroup_size(1.0f, 2u, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(1.0f, 2u, 3u) func():void {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamsSameType) {
+    auto* f = ComputeEntryPoint();
+    f->SetWorkgroupSize({b.Constant(1_u), b.Constant(2_i), b.Constant(3_u)});
+
+    b.Append(f->Block(), [&] { b.Unreachable(); });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(),
+              R"(:1:1 error: @workgroup_size params must be all i32s or all u32s
+%f = @compute @workgroup_size(1u, 2i, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(1u, 2i, 3u) func():void {
+  $B1: {
+    unreachable
+  }
+}
+)");
+}
+
 TEST_F(IR_ValidatorTest, Function_Vertex_BasicPosition) {
     auto* f = b.Function("my_func", ty.vec4<f32>(), Function::PipelineStage::kVertex);
     f->SetReturnBuiltin(BuiltinValue::kPosition);