[ir] Validate intermediate workgroup size products When checking that the total workgroup size is less that UINT32_MAX, we were only checking the final x*y*z product. This may overflow a uint64_t and wrap around to be a valid uint32_t value, so we need to check the intermediate products instead. Fixed: 463283605 Change-Id: Ie4cb2354bc6693230b5b591d152ec3d95b0469c4 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/277935 Reviewed-by: Peter McNeeley <petermcneeley@google.com> Commit-Queue: James Price <jrprice@google.com> Commit-Queue: Peter McNeeley <petermcneeley@google.com> Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc index 1d34719..79bb2c9 100644 --- a/src/tint/lang/core/ir/validator.cc +++ b/src/tint/lang/core/ir/validator.cc
@@ -2970,6 +2970,12 @@ return; } total_size *= c->Value()->ValueAs<uint64_t>(); + + constexpr uint64_t kMaxGridSize = 0xffffffff; + if (total_size > kMaxGridSize) { + AddError(func) << "workgroup grid size cannot exceed 0x" << std::hex + << kMaxGridSize; + } continue; } @@ -3007,11 +3013,6 @@ AddError(func) << "@workgroup_size must be an InstructionResult or a Constant"; } - - constexpr uint64_t kMaxGridSize = 0xffffffff; - if (total_size > kMaxGridSize) { - AddError(func) << "workgroup grid size cannot exceed 0x" << std::hex << kMaxGridSize; - } } void Validator::CheckPositionPresentForVertexOutput(const Function* ep) {
diff --git a/src/tint/lang/core/ir/validator_function_test.cc b/src/tint/lang/core/ir/validator_function_test.cc index 66f1f30..fd3addf 100644 --- a/src/tint/lang/core/ir/validator_function_test.cc +++ b/src/tint/lang/core/ir/validator_function_test.cc
@@ -2498,6 +2498,24 @@ )")) << res.Failure(); } +// Test the case where the intermediate workgroup product overflows a uint64_t and wraps back around +// to be a valid uint32_t value. +TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamsTooLarge_U64Overflow) { + auto* f = ComputeEntryPoint(); + f->SetWorkgroupSize( + {b.Constant(1526726656_i), b.Constant(1526726656_i), b.Constant(1526726656_i)}); + + b.Append(f->Block(), [&] { b.Unreachable(); }); + + auto res = ir::Validate(mod); + ASSERT_NE(res, Success); + EXPECT_THAT(res.Failure().reason, + testing::HasSubstr(R"(:1:1 error: workgroup grid size cannot exceed 0xffffffff +%f = @compute @workgroup_size(1526726656i, 1526726656i, 1526726656i) func():void { +^^ +)")) << res.Failure(); +} + TEST_F(IR_ValidatorTest, Function_WorkgroupSize_OverrideWithoutAllowOverrides) { auto* o = b.Override(ty.u32()); auto* f = ComputeEntryPoint();