tint/resolver: Materialize arguments to @workgroup_size
Bug: tint:1504
Change-Id: I69b448e62a4ebd684f6832f76fd28d8a31892a1a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91847
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 38970a5..5f47efb 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2546,10 +2546,20 @@
}
/// Creates an ast::WorkgroupAttribute
+ /// @param source the source information
/// @param x the x dimension expression
/// @param y the y dimension expression
/// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y>
+ const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x, EXPR_Y&& y) {
+ return WorkgroupSize(source, std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
+ }
+
+ /// Creates an ast::WorkgroupAttribute
+ /// @param x the x dimension expression
+ /// @param y the y dimension expression
+ /// @returns the workgroup attribute pointer
+ template <typename EXPR_X, typename EXPR_Y, typename = DisableIfSource<EXPR_X>>
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y), nullptr);
}
@@ -2575,7 +2585,7 @@
/// @param y the y dimension expression
/// @param z the z dimension expression
/// @returns the workgroup attribute pointer
- template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
+ template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z, typename = DisableIfSource<EXPR_X>>
const ast::WorkgroupAttribute* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
return create<ast::WorkgroupAttribute>(source_, Expr(std::forward<EXPR_X>(x)),
Expr(std::forward<EXPR_Y>(y)),
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index ce37ba5..317cadc 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -429,9 +429,8 @@
// fn main() {}
auto* x = GlobalConst("x", ty.u32(), Expr(4_u));
auto* y = GlobalConst("y", ty.u32(), Expr(8_u));
- auto* func = Func(
- "main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Expr("x"), Expr("y"), Expr(16_u))});
+ auto* func = Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize("x", "y", 16_u)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -447,43 +446,68 @@
EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32) {
+ // @stage(compute) @workgroup_size(1i, 2i, 3i)
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_i, 3_i)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
// @stage(compute) @workgroup_size(1u, 2u, 3u)
// fn main() {}
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Source{{12, 34}}, Expr(1_u), Expr(2_u), Expr(3_u))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_u, 3_u)});
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeU32) {
- // @stage(compute) @workgroup_size(1u, 2u, 3_i)
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_I32_AInt) {
+ // @stage(compute) @workgroup_size(1, 2i, 3)
// fn main() {}
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Expr(1_u), Expr(2_u), Expr(Source{{12, 34}}, 3_i))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_a, 2_i, 3_a)});
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size arguments must be of the same type, "
- "either i32 or u32");
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchTypeI32) {
- // @stage(compute) @workgroup_size(1_i, 2u, 3_i)
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32_AInt) {
+ // @stage(compute) @workgroup_size(1u, 2, 3u)
// fn main() {}
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, 2_u), Expr(3_i))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_u)});
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
+ // @stage(compute) @workgroup_size(1u, 2, 3i)
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_u, 2_a, 3_i)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size arguments must be of the same type, "
- "either i32 or u32");
+ "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_I32) {
+ // @stage(compute) @workgroup_size(1i, 2u, 3)
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, 2_u, 3_a)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch) {
@@ -492,13 +516,11 @@
// fn main() {}
GlobalConst("x", ty.u32(), Expr(64_u));
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Expr(1_i), Expr(Source{{12, 34}}, "x"))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, 1_i, "x")});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size arguments must be of the same type, "
- "either i32 or u32");
+ "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_TypeMismatch2) {
@@ -509,13 +531,11 @@
GlobalConst("x", ty.u32(), Expr(64_u));
GlobalConst("y", ty.i32(), Expr(32_i));
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Expr("x"), Expr(Source{{12, 34}}, "y"))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y")});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size arguments must be of the same type, "
- "either i32 or u32");
+ "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Mismatch_ConstU32) {
// let x = 4u;
@@ -525,13 +545,11 @@
GlobalConst("x", ty.u32(), Expr(4_u));
GlobalConst("y", ty.u32(), Expr(8_u));
Func("main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Expr("x"), Expr("y"), Expr(Source{{12, 34}}, 16_i))});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(Source{{12, 34}}, "x", "y", 16_i)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: workgroup_size arguments must be of the same type, "
- "either i32 or u32");
+ "12:34 error: workgroup_size arguments must be of the same type, either i32 or u32");
}
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index bcb1671..8f4451e 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -134,6 +134,11 @@
// default: {}
// }
kSwitchCaseWithAbstractCase,
+
+ // @workgroup_size(target_expr, abstract_expr, 123)
+ // @stage(compute)
+ // fn f() {}
+ kWorkgroupSize
};
static std::ostream& operator<<(std::ostream& o, Method m) {
@@ -162,6 +167,8 @@
return o << "switch-cond-with-abstract";
case Method::kSwitchCaseWithAbstractCase:
return o << "switch-case-with-abstract";
+ case Method::kWorkgroupSize:
+ return o << "workgroup-size";
}
return o << "<unknown>";
}
@@ -286,6 +293,11 @@
Case(abstract_expr->As<ast::IntLiteralExpression>()), //
DefaultCase()));
break;
+ case Method::kWorkgroupSize:
+ Func("f", {}, ty.void_(), {},
+ {WorkgroupSize(target_expr(), abstract_expr, Expr(123_a)),
+ Stage(ast::PipelineStage::kCompute)});
+ break;
}
auto check_types_and_values = [&](const sem::Expression* expr) {
@@ -461,6 +473,19 @@
Types<u32, AInt>(AInt(kLowestU32), kLowestU32), //
})));
+INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
+ MaterializeAbstractNumericToConcreteType,
+ testing::Combine(testing::Values(Expectation::kMaterialize),
+ testing::Values(Method::kWorkgroupSize),
+ testing::ValuesIn(std::vector<Data>{
+ Types<i32, AInt>(1_a, 1.0), //
+ Types<i32, AInt>(10_a, 10.0), //
+ Types<i32, AInt>(65535_a, 65535.0), //
+ Types<u32, AInt>(1_a, 1.0), //
+ Types<u32, AInt>(10_a, 10.0), //
+ Types<u32, AInt>(65535_a, 65535.0), //
+ })));
+
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
MaterializeAbstractNumericToConcreteType,
@@ -558,6 +583,11 @@
// default: {}
// }
kSwitch,
+
+ // @workgroup_size(abstract_expr)
+ // @stage(compute)
+ // fn f() {}
+ kWorkgroupSize
};
static std::ostream& operator<<(std::ostream& o, Method m) {
@@ -576,6 +606,8 @@
return o << "array-length";
case Method::kSwitch:
return o << "switch";
+ case Method::kWorkgroupSize:
+ return o << "workgroup-size";
}
return o << "<unknown>";
}
@@ -656,6 +688,10 @@
Case(abstract_expr()->As<ast::IntLiteralExpression>()),
DefaultCase()));
break;
+ case Method::kWorkgroupSize:
+ Func("f", {}, ty.void_(), {},
+ {WorkgroupSize(abstract_expr()), Stage(ast::PipelineStage::kCompute)});
+ break;
}
auto check_types_and_values = [&](const sem::Expression* expr) {
@@ -734,11 +770,6 @@
Method::kVar,
};
-/// Methods that support materialization for switch cases
-constexpr Method kSwitchMethods[] = {
- Method::kSwitch,
-};
-
INSTANTIATE_TEST_SUITE_P(
MaterializeScalar,
MaterializeAbstractNumericToDefaultType,
@@ -798,13 +829,23 @@
INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kMaterialize),
- testing::ValuesIn(kSwitchMethods),
+ testing::Values(Method::kSwitch),
testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(AInt(kHighestI32), kHighestI32), //
Types<i32, AInt>(AInt(kLowestI32), kLowestI32), //
})));
+INSTANTIATE_TEST_SUITE_P(MaterializeWorkgroupSize,
+ MaterializeAbstractNumericToDefaultType,
+ testing::Combine(testing::Values(Expectation::kMaterialize),
+ testing::Values(Method::kWorkgroupSize),
+ testing::ValuesIn(std::vector<Data>{
+ Types<i32, AInt>(1_a, 1.0), //
+ Types<i32, AInt>(10_a, 10.0), //
+ Types<i32, AInt>(65535_a, 65535.0), //
+ })));
+
INSTANTIATE_TEST_SUITE_P(ScalarValueCannotBeRepresented,
MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
@@ -840,7 +881,16 @@
INSTANTIATE_TEST_SUITE_P(SwitchValueCannotBeRepresented,
MaterializeAbstractNumericToDefaultType,
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
- testing::ValuesIn(kSwitchMethods),
+ testing::Values(Method::kSwitch),
+ testing::ValuesIn(std::vector<Data>{
+ Types<i32, AInt>(0_a, kHighestI32 + 1), //
+ Types<i32, AInt>(0_a, kLowestI32 - 1), //
+ })));
+
+INSTANTIATE_TEST_SUITE_P(WorkgroupSizeValueCannotBeRepresented,
+ MaterializeAbstractNumericToDefaultType,
+ testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented),
+ testing::Values(Method::kWorkgroupSize),
testing::ValuesIn(std::vector<Data>{
Types<i32, AInt>(0_a, kHighestI32 + 1), //
Types<i32, AInt>(0_a, kLowestI32 - 1), //
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 807dc62..c964d8a 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -721,52 +721,61 @@
}
auto values = attr->Values();
- auto any_i32 = false;
- auto any_u32 = false;
+ std::array<const sem::Expression*, 3> args = {};
+ std::array<const sem::Type*, 3> arg_tys = {};
+ size_t arg_count = 0;
+
+ constexpr const char* kErrBadType =
+ "workgroup_size argument must be either literal or module-scope constant of type i32 "
+ "or u32";
+
for (int i = 0; i < 3; i++) {
- // Each argument to this attribute can either be a literal, an
- // identifier for a module-scope constants, or nullptr if not specified.
-
- auto* expr = values[i];
+ // Each argument to this attribute can either be a literal, an identifier for a module-scope
// constant, or nullptr if not specified.
+ auto* value = values[i];
+ if (!value) {
+ break;
+ }
+ const auto* expr = Expression(value);
if (!expr) {
- // Not specified, just use the default.
- continue;
+ return false;
}
-
- auto* expr_sem = Expression(expr);
- if (!expr_sem) {
+ auto* ty = expr->Type();
+ if (!ty->IsAnyOf<sem::I32, sem::U32, sem::AbstractInt>()) {
+ AddError(kErrBadType, value->source);
return false;
}
- constexpr const char* kErrBadType =
- "workgroup_size argument must be either literal or module-scope "
- "constant of type i32 or u32";
- constexpr const char* kErrInconsistentType =
- "workgroup_size arguments must be of the same type, either i32 "
- "or u32";
+ args[i] = expr;
+ arg_tys[i] = ty;
+ arg_count++;
+ }
- auto* ty = sem_.TypeOf(expr);
- bool is_i32 = ty->UnwrapRef()->Is<sem::I32>();
- bool is_u32 = ty->UnwrapRef()->Is<sem::U32>();
- if (!is_i32 && !is_u32) {
- AddError(kErrBadType, expr->source);
- return false;
- }
+ auto* common_ty = sem::Type::Common(arg_tys.data(), arg_count);
+ if (!common_ty) {
+ AddError("workgroup_size arguments must be of the same type, either i32 or u32",
+ attr->source);
+ return false;
+ }
- any_i32 = any_i32 || is_i32;
- any_u32 = any_u32 || is_u32;
- if (any_i32 && any_u32) {
- AddError(kErrInconsistentType, expr->source);
+ // If all arguments are abstract-integers, then materialize to i32.
+ if (common_ty->Is<sem::AbstractInt>()) {
+ common_ty = builder_->create<sem::I32>();
+ }
+
+ for (size_t i = 0; i < arg_count; i++) {
+ auto* materialized = Materialize(args[i], common_ty);
+ if (!materialized) {
return false;
}
sem::Constant value;
- if (auto* user = sem_.Get(expr)->As<sem::VariableUser>()) {
+ if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration();
if (!decl->is_const) {
- AddError(kErrBadType, expr->source);
+ AddError(kErrBadType, values[i]->source);
return false;
}
// Capture the constant if it is pipeline-overridable.
@@ -781,8 +790,8 @@
ws[i].value = 0;
continue;
}
- } else if (expr->Is<ast::LiteralExpression>()) {
- value = sem_.Get(expr)->ConstantValue();
+ } else if (values[i]->Is<ast::LiteralExpression>()) {
+ value = materialized->ConstantValue();
} else {
AddError(
"workgroup_size argument must be either a literal or a "