validation: Allow unsigned workgroup_size component

Bug: tint:923
Change-Id: I7bd7d22279d9c6ce4c3225bdfd8693261b9084f9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58121
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index fde0988..da0a2c9 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -357,6 +357,108 @@
             "declared here:");
 }
 
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
+  // let x = 4u;
+  // let x = 8u;
+  // [[stage(compute), workgroup_size(x, y, 16u]
+  // fn main() {}
+  GlobalConst("x", ty.u32(), Expr(4u));
+  GlobalConst("y", ty.u32(), Expr(8u));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
+  // [[stage(compute), workgroup_size(1u, 2u, 3u)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Source{{12, 34}}, Expr(1u), Expr(2u), Expr(3u))});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) {
+  // [[stage(compute), workgroup_size(1u, 2u, 3)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(1u), Expr(2u), Expr(Source{{12, 34}}, 3))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameters must be of the same "
+            "type, either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) {
+  // [[stage(compute), workgroup_size(1, 2u, 3)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(1), Expr(Source{{12, 34}}, 2u), Expr(3))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameters must be of the same "
+            "type, either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
+  // let x = 64u;
+  // [[stage(compute), workgroup_size(1, x)]
+  // fn main() {}
+  GlobalConst("x", ty.u32(), Expr(64u));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(1), Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameters must be of the same "
+            "type, either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
+  // let x = 64u;
+  // let y = 32;
+  // [[stage(compute), workgroup_size(x, y)]
+  // fn main() {}
+  GlobalConst("x", ty.u32(), Expr(64u));
+  GlobalConst("y", ty.i32(), Expr(32));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr("x"), Expr(Source{Source::Location{12, 34}}, "y"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameters must be of the same "
+            "type, either i32 or u32");
+}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
+  // let x = 4u;
+  // let x = 8u;
+  // [[stage(compute), workgroup_size(x, y, 16]
+  // fn main() {}
+  GlobalConst("x", ty.u32(), Expr(4u));
+  GlobalConst("y", ty.u32(), Expr(8u));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr("x"), Expr("y"),
+                      Expr(Source{Source::Location{12, 34}}, 16))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameters must be of the same "
+            "type, either i32 or u32");
+}
+
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
   // [[stage(compute), workgroup_size(64.0)]
   // fn main() {}
@@ -368,8 +470,8 @@
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
-            "12:34 error: workgroup_size parameter must be a literal i32 or an "
-            "i32 module-scope constant");
+            "12:34 error: workgroup_size parameter must be either literal or "
+            "module-scope constant of type i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
@@ -382,9 +484,8 @@
             Source{Source::Location{12, 34}}, Literal(-2)))});
 
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: workgroup_size parameter must be a positive i32 value");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be at least 1");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
@@ -397,9 +498,8 @@
             Source{Source::Location{12, 34}}, Literal(0)))});
 
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: workgroup_size parameter must be a positive i32 value");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be at least 1");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
@@ -413,8 +513,8 @@
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
-            "12:34 error: workgroup_size parameter must be a literal i32 or an "
-            "i32 module-scope constant");
+            "12:34 error: workgroup_size parameter must be either literal or "
+            "module-scope constant of type i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
@@ -427,9 +527,8 @@
         WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
 
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: workgroup_size parameter must be a positive i32 value");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be at least 1");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
@@ -442,9 +541,8 @@
         WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
 
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: workgroup_size parameter must be a positive i32 value");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be at least 1");
 }
 
 TEST_F(ResolverFunctionValidationTest,
@@ -459,9 +557,8 @@
         WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
 
   EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(
-      r()->error(),
-      "12:34 error: workgroup_size parameter must be a positive i32 value");
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be at least 1");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
@@ -475,8 +572,8 @@
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
-            "12:34 error: workgroup_size parameter must be a literal i32 or an "
-            "i32 module-scope constant");
+            "12:34 error: workgroup_size parameter must be either literal or "
+            "module-scope constant of type i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, ReturnIsConstructible_NonPlain) {
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index eac0b63..014b257 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1608,6 +1608,8 @@
   if (auto* workgroup =
           ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
     auto values = workgroup->values();
+    auto is_i32 = false;
+    auto is_less_than_one = true;
     for (int i = 0; i < 3; i++) {
       // Each argument to this decoration can either be a literal, an
       // identifier for a module-scope constants, or nullptr if not specified.
@@ -1619,7 +1621,7 @@
 
       Mark(values[i]);
 
-      int32_t value = 0;
+      uint32_t value = 0;
       if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
         // We have an identifier of a module-scope constant.
         if (!Identifier(ident)) {
@@ -1628,10 +1630,10 @@
 
         VariableInfo* var;
         if (!variable_stack_.get(ident->symbol(), &var) ||
-            !(var->declaration->is_const() && var->type->Is<sem::I32>())) {
+            !(var->declaration->is_const() && var->type->is_integer_scalar())) {
           AddError(
-              "workgroup_size parameter must be a literal i32 or an i32 "
-              "module-scope constant",
+              "workgroup_size parameter must be either literal or module-scope "
+              "constant of type i32 or u32",
               values[i]->source());
           return false;
         }
@@ -1646,12 +1648,28 @@
         if (constructor) {
           // Resolve the constructor expression to use as the default value.
           auto val = ConstantValueOf(constructor);
-          if (!val.IsValid() || !val.Type()->Is<sem::I32>()) {
+          if (!val.IsValid() || !val.Type()->is_integer_scalar()) {
             TINT_ICE(Resolver, diagnostics_)
                 << "failed to resolve workgroup_size constant value";
             return false;
           }
-          value = val.Elements()[0].i32;
+
+          if (i == 0) {
+            is_i32 = val.Type()->Is<sem::I32>();
+          } else {
+            if (is_i32 != val.Type()->Is<sem::I32>()) {
+              AddError(
+                  "workgroup_size parameters must be of the same type, "
+                  "either i32 or u32",
+                  values[i]->source());
+              return false;
+            }
+          }
+          is_less_than_one =
+              is_i32 ? val.Elements()[0].i32 < 1 : val.Elements()[0].u32 < 1;
+
+          value = is_i32 ? static_cast<uint32_t>(val.Elements()[0].i32)
+                         : val.Elements()[0].u32;
         } else {
           // No constructor means this value must be overriden by the user.
           info->workgroup_size[i].value = 0;
@@ -1661,22 +1679,36 @@
                      values[i]->As<ast::ScalarConstructorExpression>()) {
         // We have a literal.
         Mark(scalar->literal());
-
-        auto* i32_literal = scalar->literal()->As<ast::IntLiteral>();
-        if (!i32_literal) {
+        auto* literal = scalar->literal()->As<ast::IntLiteral>();
+        if (!literal) {
           AddError(
-              "workgroup_size parameter must be a literal i32 or an i32 "
-              "module-scope constant",
+              "workgroup_size parameter must be either literal or module-scope "
+              "constant of type i32 or u32",
               values[i]->source());
           return false;
         }
 
-        value = i32_literal->value_as_i32();
+        if (i == 0) {
+          is_i32 = literal->Is<ast::SintLiteral>();
+        } else {
+          if (literal->Is<ast::SintLiteral>() != is_i32) {
+            AddError(
+                "workgroup_size parameters must be of the same type, "
+                "either i32 or u32",
+                values[i]->source());
+            return false;
+          }
+        }
+
+        is_less_than_one =
+            is_i32 ? literal->value_as_i32() < 1 : literal->value_as_u32() < 1;
+        value = is_i32 ? static_cast<uint32_t>(literal->value_as_i32())
+                       : literal->value_as_u32();
       }
 
       // Validate and set the default value for this dimension.
-      if (value < 1) {
-        AddError("workgroup_size parameter must be a positive i32 value",
+      if (is_less_than_one) {
+        AddError("workgroup_size parameter must be at least 1",
                  values[i]->source());
         return false;
       }