[tint][ir][val] Check numeric type of params to @workgroup_size
Refactors checks for @workgroup_size into a separate utility function,
and implements rules about numeric types for params.
A follow up CL will implement rules related to params being
const/override expressions.
Issue: 376624999
Change-Id: Ib7cf63f2dd191bba95c2e94d45313584311c9342
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/213296
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 5c3b6f5..2221117 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -79,6 +79,7 @@
#include "src/tint/lang/core/ir/var.h"
#include "src/tint/lang/core/type/bool.h"
#include "src/tint/lang/core/type/f32.h"
+#include "src/tint/lang/core/type/i32.h"
#include "src/tint/lang/core/type/i8.h"
#include "src/tint/lang/core/type/memory_view.h"
#include "src/tint/lang/core/type/pointer.h"
@@ -883,6 +884,10 @@
/// @param func the function to validate
void CheckFunction(const Function* func);
+ /// Validates the workgroup_size attribute for a given function
+ /// @param func the function to validate
+ void CheckWorkgroupSize(const Function* func);
+
/// Validates the specific function as a vertex entry point
/// @param ep the function to validate
void CheckVertexEntryPoint(const Function* ep);
@@ -1903,21 +1908,14 @@
}
}
- if (func->Stage() == Function::PipelineStage::kCompute) {
- if (DAWN_UNLIKELY(!func->WorkgroupSize().has_value())) {
- AddError(func) << "compute entry point requires workgroup size attribute";
- }
+ CheckWorkgroupSize(func);
+ if (func->Stage() == Function::PipelineStage::kCompute) {
if (DAWN_UNLIKELY(func->ReturnType() && !func->ReturnType()->Is<core::type::Void>())) {
AddError(func) << "compute entry point must not have a return type";
}
}
- if (DAWN_UNLIKELY(func->Stage() != Function::PipelineStage::kCompute &&
- func->WorkgroupSize().has_value())) {
- AddError(func) << "workgroup size attribute only valid on compute entry point";
- }
-
if (func->Stage() == Function::PipelineStage::kFragment) {
if (!func->ReturnType()->Is<core::type::Void>()) {
CheckFunctionReturnAttributes(
@@ -1977,6 +1975,50 @@
ProcessTasks();
}
+void Validator::CheckWorkgroupSize(const Function* func) {
+ if (func->Stage() != Function::PipelineStage::kCompute) {
+ if (func->WorkgroupSize().has_value()) {
+ AddError(func) << "@workgroup_size only valid on compute entry point";
+ }
+ return;
+ }
+
+ if (!func->WorkgroupSize().has_value()) {
+ AddError(func) << "compute entry point requires @workgroup_size";
+ return;
+ }
+
+ auto workgroup_sizes = func->WorkgroupSize().value();
+ // The number parameters cannot be checked here, since it is stored internally as a 3 element
+ // array, so will always have 3 elements at this point.
+ TINT_ASSERT(workgroup_sizes.size() == 3);
+
+ std::optional<const core::type::Type*> sizes_ty;
+ for (auto* size : workgroup_sizes) {
+ if (!size || !size->Type()) {
+ AddError(func) << "a @workgroup_size param is undefined or missing a type";
+ return;
+ }
+
+ auto* ty = size->Type();
+ if (!ty->IsAnyOf<core::type::I32, core::type::U32>()) {
+ AddError(func) << "@workgroup_size params must be an i32 or u32";
+ return;
+ }
+
+ if (!sizes_ty.has_value()) {
+ sizes_ty = ty;
+ }
+
+ if (sizes_ty != ty) {
+ AddError(func) << "@workgroup_size params must be all i32s or all u32s";
+ return;
+ }
+
+ // TODO(376624999): Implement enforcing rules around override and constant expressions
+ }
+}
+
void Validator::CheckVertexEntryPoint(const Function* ep) {
bool contains_position = IsPositionPresent(ep->ReturnAttributes(), ep->ReturnType());
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 3060f7c..b2fccba 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -835,26 +835,6 @@
)");
}
-TEST_F(IR_ValidatorTest, Function_MissingWorkgroupSize) {
- auto* f = b.Function("f", ty.void_(), Function::PipelineStage::kCompute);
- b.Append(f->Block(), [&] { b.Return(f); });
-
- auto res = ir::Validate(mod);
- ASSERT_NE(res, Success);
- EXPECT_EQ(res.Failure().reason.Str(),
- R"(:1:1 error: compute entry point requires workgroup size attribute
-%f = @compute func():void {
-^^
-
-note: # Disassembly
-%f = @compute func():void {
- $B1: {
- ret
- }
-}
-)");
-}
-
TEST_F(IR_ValidatorTest, Function_UnnamedEntryPoint) {
auto* f = b.Function(ty.void_(), ir::Function::PipelineStage::kCompute);
f->SetWorkgroupSize({b.Constant(1_u), b.Constant(1_u), b.Constant(1_u)});
@@ -946,7 +926,7 @@
TEST_F(IR_ValidatorTest, Function_Compute_NonVoidReturn) {
auto* f = b.Function("my_func", ty.f32(), core::ir::Function::PipelineStage::kCompute);
- f->SetWorkgroupSize(b.Constant(0_u), b.Constant(0_u), b.Constant(0_u));
+ f->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
b.Append(f->Block(), [&] { b.Unreachable(); });
@@ -954,11 +934,11 @@
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
R"(:1:1 error: compute entry point must not have a return type
-%my_func = @compute @workgroup_size(0u, 0u, 0u) func():f32 {
+%my_func = @compute @workgroup_size(1u, 1u, 1u) func():f32 {
^^^^^^^^
note: # Disassembly
-%my_func = @compute @workgroup_size(0u, 0u, 0u) func():f32 {
+%my_func = @compute @workgroup_size(1u, 1u, 1u) func():f32 {
$B1: {
unreachable
}
@@ -966,7 +946,27 @@
)");
}
-TEST_F(IR_ValidatorTest, Function_WorkspaceSizeOnlyOnCompute) {
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_MissingOnCompute) {
+ auto* f = b.Function("f", ty.void_(), Function::PipelineStage::kCompute);
+ b.Append(f->Block(), [&] { b.Return(f); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:1:1 error: compute entry point requires @workgroup_size
+%f = @compute func():void {
+^^
+
+note: # Disassembly
+%f = @compute func():void {
+ $B1: {
+ ret
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_NonCompute) {
auto* f = FragmentEntryPoint();
f->SetWorkgroupSize(b.Constant(1_u), b.Constant(1_u), b.Constant(1_u));
@@ -975,7 +975,7 @@
auto res = ir::Validate(mod);
ASSERT_NE(res, Success);
EXPECT_EQ(res.Failure().reason.Str(),
- R"(:1:1 error: workgroup size attribute only valid on compute entry point
+ R"(:1:1 error: @workgroup_size only valid on compute entry point
%f = @fragment @workgroup_size(1u, 1u, 1u) func():void {
^^
@@ -988,6 +988,72 @@
)");
}
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamUndefined) {
+ auto* f = ComputeEntryPoint();
+ f->SetWorkgroupSize({nullptr, b.Constant(2_u), b.Constant(3_u)});
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:1:1 error: a @workgroup_size param is undefined or missing a type
+%f = @compute @workgroup_size(undef, 2u, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(undef, 2u, 3u) func():void {
+ $B1: {
+ unreachable
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamWrongType) {
+ auto* f = ComputeEntryPoint();
+ f->SetWorkgroupSize({b.Constant(1_f), b.Constant(2_u), b.Constant(3_u)});
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:1:1 error: @workgroup_size params must be an i32 or u32
+%f = @compute @workgroup_size(1.0f, 2u, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(1.0f, 2u, 3u) func():void {
+ $B1: {
+ unreachable
+ }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamsSameType) {
+ auto* f = ComputeEntryPoint();
+ f->SetWorkgroupSize({b.Constant(1_u), b.Constant(2_i), b.Constant(3_u)});
+
+ b.Append(f->Block(), [&] { b.Unreachable(); });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:1:1 error: @workgroup_size params must be all i32s or all u32s
+%f = @compute @workgroup_size(1u, 2i, 3u) func():void {
+^^
+
+note: # Disassembly
+%f = @compute @workgroup_size(1u, 2i, 3u) func():void {
+ $B1: {
+ unreachable
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, Function_Vertex_BasicPosition) {
auto* f = b.Function("my_func", ty.vec4<f32>(), Function::PipelineStage::kVertex);
f->SetReturnBuiltin(BuiltinValue::kPosition);