tint/resolver: Materialize arguments to @workgroup_size

Bug: tint:1504
Change-Id: I69b448e62a4ebd684f6832f76fd28d8a31892a1a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91847
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 38970a5..5f47efb 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2546,10 +2546,20 @@
     }
 
     /// Creates an ast::WorkgroupAttribute
+    /// @param source the source information
     /// @param x the x dimension expression
     /// @param y the y dimension expression
     /// @returns the workgroup attribute pointer
     template <typename EXPR_X, typename EXPR_Y>
+    const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x, EXPR_Y&& y) {
+        return WorkgroupSize(source, std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
+    }
+
+    /// Creates an ast::WorkgroupAttribute
+    /// @param x the x dimension expression
+    /// @param y the y dimension expression
+    /// @returns the workgroup attribute pointer
+    template <typename EXPR_X, typename EXPR_Y, typename = DisableIfSource<EXPR_X>>
     const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
         return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
     }
@@ -2575,7 +2585,7 @@
     /// @param y the y dimension expression
     /// @param z the z dimension expression
     /// @returns the workgroup attribute pointer
-    template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
+    template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z, typename = DisableIfSource<EXPR_X>>
     const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
         return create<ast::WorkgroupAttribute>(source_, Expr(std::forward<EXPR_X>(x)),
                                                Expr(std::forward<EXPR_Y>(y)),
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index ce37ba5..317cadc 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -429,9 +429,8 @@
     // fn main() {}
     auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
     auto* y = GlobalConst("y", ty.u32(), Expr(8_u));
-    auto* func = Func(
-        "main", {}, ty.void_(), {},
-        {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Expr("x"), Expr("y"), Expr(16_u))});
+    auto* func = Func("main", {}, ty.void_(), {},
+                      {Stage(ast::PipelineStage::kCompute), WorkgroupSize("x", "y", 16_u)});
 
     ASSERT_TRUE(r()->Resolve()) << r()->error();
 
@@ -447,43 +446,68 @@
     EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
 }
 
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
+    // @stage(compute) @workgroup_size(1i, 2i, 3i)
+    // fn main() {}
+
+    Func("main", {}, ty.void_(), {},
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_i, 3_i)});
+
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
     // @stage(compute) @workgroup_size(1u, 2u, 3u)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Source{{12, 34}}, Expr(1_u), Expr(2_u), Expr(3_u))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_u, 3_u)});
 
     ASSERT_TRUE(r()->Resolve()) << r()->error();
 }
 
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) {
-    // @stage(compute) @workgroup_size(1u, 2u, 3_i)
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) {
+    // @stage(compute) @workgroup_size(1, 2i, 3)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Expr(1_u), Expr(2_u), Expr(Source{{12, 34}}, 3_i))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_a, 2_i, 3_a)});
 
-    EXPECT_FALSE(r()->Resolve());
-    EXPECT_EQ(r()->error(),
-              "12:34 error: workgroup_size arguments must be of the same type, "
-              "either i32 or u32");
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
 }
 
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) {
-    // @stage(compute) @workgroup_size(1_i, 2u, 3_i)
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
+    // @stage(compute) @workgroup_size(1u, 2, 3u)
     // fn main() {}
 
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, 2_u), Expr(3_i))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_u)});
+
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
+    // @stage(compute) @workgroup_size(1u, 2, 3i)
+    // fn main() {}
+
+    Func("main", {}, ty.void_(), {},
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_i)});
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: workgroup_size arguments must be of the same type, "
-              "either i32 or u32");
+              "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) {
+    // @stage(compute) @workgroup_size(1i, 2u, 3)
+    // fn main() {}
+
+    Func("main", {}, ty.void_(), {},
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_u, 3_a)});
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
@@ -492,13 +516,11 @@
     // fn main() {}
     GlobalConst("x", ty.u32(), Expr(64_u));
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, "x"))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, "x")});
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: workgroup_size arguments must be of the same type, "
-              "either i32 or u32");
+              "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
@@ -509,13 +531,11 @@
     GlobalConst("x", ty.u32(), Expr(64_u));
     GlobalConst("y", ty.i32(), Expr(32_i));
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y")});
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: workgroup_size arguments must be of the same type, "
-              "either i32 or u32");
+              "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
 }
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
     // let x = 4u;
@@ -525,13 +545,11 @@
     GlobalConst("x", ty.u32(), Expr(4_u));
     GlobalConst("y", ty.u32(), Expr(8_u));
     Func("main", {}, ty.void_(), {},
-         {Stage(ast::PipelineStage::kCompute),
-          WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16_i))});
+         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y", 16_i)});
 
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
-              "12:34 error: workgroup_size arguments must be of the same type, "
-              "either i32 or u32");
+              "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index bcb1671..8f4451e 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -134,6 +134,11 @@
     //   default: {}
     // }
     kSwitchCaseWithAbstractCase,
+
+    // @workgroup_size(target_expr, abstract_expr, 123)
+    // @stage(compute)
+    // fn f() {}
+    kWorkgroupSize
 };
 
 static std::ostream& operator<<(std::ostream& o, Method m) {
@@ -162,6 +167,8 @@
             return o << "switch-cond-with-abstract";
         case Method::kSwitchCaseWithAbstractCase:
             return o << "switch-case-with-abstract";
+        case Method::kWorkgroupSize:
+            return o << "workgroup-size";
     }
     return o << "<unknown>";
 }
@@ -286,6 +293,11 @@
                                   Case(abstract_expr->As<ast::IntLiteralExpression>()),  //
                                   DefaultCase()));
             break;
+        case Method::kWorkgroupSize:
+            Func("f", {}, ty.void_(), {},
+                 {WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)),
+                  Stage(ast::PipelineStage::kCompute)});
+            break;
     }
 
     auto check_types_and_values = [&](const sem::Expression* expr) {
@@ -461,6 +473,19 @@
                                               Types<u32, AInt>(AInt(kLowestU32), kLowestU32),    //
                                           })));
 
+INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
+                         MaterializeAbstractNumericToConcreteType,
+                         testing::Combine(testing::Values(Expectation::kMaterialize),
+                                          testing::Values(Method::kWorkgroupSize),
+                                          testing::ValuesIn(std::vector<Data>{
+                                              Types<i32, AInt>(1_a, 1.0),          //
+                                              Types<i32, AInt>(10_a, 10.0),        //
+                                              Types<i32, AInt>(65535_a, 65535.0),  //
+                                              Types<u32, AInt>(1_a, 1.0),          //
+                                              Types<u32, AInt>(10_a, 10.0),        //
+                                              Types<u32, AInt>(65535_a, 65535.0),  //
+                                          })));
+
 // TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
 INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
                          MaterializeAbstractNumericToConcreteType,
@@ -558,6 +583,11 @@
     //   default: {}
     // }
     kSwitch,
+
+    // @workgroup_size(abstract_expr)
+    // @stage(compute)
+    // fn f() {}
+    kWorkgroupSize
 };
 
 static std::ostream& operator<<(std::ostream& o, Method m) {
@@ -576,6 +606,8 @@
             return o << "array-length";
         case Method::kSwitch:
             return o << "switch";
+        case Method::kWorkgroupSize:
+            return o << "workgroup-size";
     }
     return o << "<unknown>";
 }
@@ -656,6 +688,10 @@
                                   Case(abstract_expr()->As<ast::IntLiteralExpression>()),
                                   DefaultCase()));
             break;
+        case Method::kWorkgroupSize:
+            Func("f", {}, ty.void_(), {},
+                 {WorkgroupSize(abstract_expr()), Stage(ast::PipelineStage::kCompute)});
+            break;
     }
 
     auto check_types_and_values = [&](const sem::Expression* expr) {
@@ -734,11 +770,6 @@
     Method::kVar,
 };
 
-/// Methods that support materialization for switch cases
-constexpr Method kSwitchMethods[] = {
-    Method::kSwitch,
-};
-
 INSTANTIATE_TEST_SUITE_P(
     MaterializeScalar,
     MaterializeAbstractNumericToDefaultType,
@@ -798,13 +829,23 @@
 INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
                          MaterializeAbstractNumericToDefaultType,
                          testing::Combine(testing::Values(Expectation::kMaterialize),
-                                          testing::ValuesIn(kSwitchMethods),
+                                          testing::Values(Method::kSwitch),
                                           testing::ValuesIn(std::vector<Data>{
                                               Types<i32, AInt>(0_a, 0.0),                        //
                                               Types<i32, AInt>(AInt(kHighestI32), kHighestI32),  //
                                               Types<i32, AInt>(AInt(kLowestI32), kLowestI32),    //
                                           })));
 
+INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
+                         MaterializeAbstractNumericToDefaultType,
+                         testing::Combine(testing::Values(Expectation::kMaterialize),
+                                          testing::Values(Method::kWorkgroupSize),
+                                          testing::ValuesIn(std::vector<Data>{
+                                              Types<i32, AInt>(1_a, 1.0),          //
+                                              Types<i32, AInt>(10_a, 10.0),        //
+                                              Types<i32, AInt>(65535_a, 65535.0),  //
+                                          })));
+
 INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented,
                          MaterializeAbstractNumericToDefaultType,
                          testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
@@ -840,7 +881,16 @@
 INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented,
                          MaterializeAbstractNumericToDefaultType,
                          testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
-                                          testing::ValuesIn(kSwitchMethods),
+                                          testing::Values(Method::kSwitch),
+                                          testing::ValuesIn(std::vector<Data>{
+                                              Types<i32, AInt>(0_a, kHighestI32 + 1),  //
+                                              Types<i32, AInt>(0_a, kLowestI32 - 1),   //
+                                          })));
+
+INSTANTIATE_TEST_SUITE_P(WorkgroupSizeValueCannotBeRepresented,
+                         MaterializeAbstractNumericToDefaultType,
+                         testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
+                                          testing::Values(Method::kWorkgroupSize),
                                           testing::ValuesIn(std::vector<Data>{
                                               Types<i32, AInt>(0_a, kHighestI32 + 1),  //
                                               Types<i32, AInt>(0_a, kLowestI32 - 1),   //
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 807dc62..c964d8a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -721,52 +721,61 @@
     }
 
     auto values = attr->Values();
-    auto any_i32 = false;
-    auto any_u32 = false;
+    std::array<const sem::Expression*, 3> args = {};
+    std::array<const sem::Type*, 3> arg_tys = {};
+    size_t arg_count = 0;
+
+    constexpr const char* kErrBadType =
+        "workgroup_size argument must be either literal or module-scope constant of type i32 "
+        "or u32";
+
     for (int i = 0; i < 3; i++) {
-        // Each argument to this attribute can either be a literal, an
-        // identifier for a module-scope constants, or nullptr if not specified.
-
-        auto* expr = values[i];
+        // Each argument is a literal, an identifier of a module-scope constant, or nullptr if not
+        // specified. A nullptr ends the argument list: later arguments are not examined.
+        auto* value = values[i];
+        if (!value) {
+            break;
+        }
+        const auto* expr = Expression(value);
         if (!expr) {
-            // Not specified, just use the default.
-            continue;
+            return false;
         }
-
-        auto* expr_sem = Expression(expr);
-        if (!expr_sem) {
+        auto* ty = expr->Type();
+        if (!ty->IsAnyOf<sem::I32, sem::U32, sem::AbstractInt>()) {
+            AddError(kErrBadType, value->source);
             return false;
         }
 
-        constexpr const char* kErrBadType =
-            "workgroup_size argument must be either literal or module-scope "
-            "constant of type i32 or u32";
-        constexpr const char* kErrInconsistentType =
-            "workgroup_size arguments must be of the same type, either i32 "
-            "or u32";
+        args[i] = expr;
+        arg_tys[i] = ty;
+        arg_count++;
+    }
 
-        auto* ty = sem_.TypeOf(expr);
-        bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
-        bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
-        if (!is_i32 && !is_u32) {
-            AddError(kErrBadType, expr->source);
-            return false;
-        }
+    auto* common_ty = sem::Type::Common(arg_tys.data(), arg_count);
+    if (!common_ty) {
+        AddError("workgroup_size arguments must be of the same type, either i32 or u32",
+                 attr->source);
+        return false;
+    }
 
-        any_i32 = any_i32 || is_i32;
-        any_u32 = any_u32 || is_u32;
-        if (any_i32 && any_u32) {
-            AddError(kErrInconsistentType, expr->source);
+    // If all arguments are abstract-integers, then materialize to i32.
+    if (common_ty->Is<sem::AbstractInt>()) {
+        common_ty = builder_->create<sem::I32>();
+    }
+
+    for (size_t i = 0; i < arg_count; i++) {
+        auto* materialized = Materialize(args[i], common_ty);
+        if (!materialized) {
             return false;
         }
 
         sem::Constant value;
 
-        if (auto* user = sem_.Get(expr)->As<sem::VariableUser>()) {
+        if (auto* user = args[i]->As<sem::VariableUser>()) {
             // We have an variable of a module-scope constant.
             auto* decl = user->Variable()->Declaration();
             if (!decl->is_const) {
-                AddError(kErrBadType, expr->source);
+                AddError(kErrBadType, values[i]->source);
                 return false;
             }
             // Capture the constant if it is pipeline-overridable.
@@ -781,8 +790,8 @@
                 ws[i].value = 0;
                 continue;
             }
-        } else if (expr->Is<ast::LiteralExpression>()) {
-            value = sem_.Get(expr)->ConstantValue();
+        } else if (values[i]->Is<ast::LiteralExpression>()) {
+            value = materialized->ConstantValue();
         } else {
             AddError(
                 "workgroup_size argument must be either a literal or a "