Range Analysis: Compute range of `Binary` with `Subtract` operator
This patch computes the range of a `Binary` with `Subtract` operator.
When any overflow or underflow happen, no valid range will be returned.
Bug: 348701956
Test: tint_unittests
Change-Id: I947028dabbc2448641921c0671ff14f63ba292ac
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241415
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index d5312a8..6f73a2b 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -253,9 +253,12 @@
// TODO(348701956): Support more binary operators
switch (binary->Op()) {
- case BinaryOp::kAdd: {
+ case BinaryOp::kAdd:
return ComputeAndCacheIntegerRangeForBinaryAdd(binary, range_lhs, range_rhs);
- }
+
+ case BinaryOp::kSubtract:
+ return ComputeAndCacheIntegerRangeForBinarySubtract(binary, range_lhs, range_rhs);
+
default:
return nullptr;
}
@@ -791,6 +794,68 @@
}
}
+ const IntegerRangeInfo* ComputeAndCacheIntegerRangeForBinarySubtract(
+ const Binary* binary,
+ const IntegerRangeInfo* lhs,
+ const IntegerRangeInfo* rhs) {
+ // Subtract two 32-bit signed integer values saved in int64_t. Return {} when either
+ // overflow or underflow happens.
+ auto SafeSubtractI32 = [](int64_t a, int64_t b) -> std::optional<int64_t> {
+ TINT_ASSERT(a >= i32::kLowestValue && a <= i32::kHighestValue);
+ TINT_ASSERT(b >= i32::kLowestValue && b <= i32::kHighestValue);
+
+ int64_t diff = a - b;
+ if (diff > i32::kHighestValue || diff < i32::kLowestValue) {
+ return {};
+ }
+ return diff;
+ };
+
+ // No-underflow Subtract between two 32-bit unsigned integer values saved in uint64_t.
+ // Return {} when underflow happens.
+ auto SafeSubtractU32 = [](uint64_t a, uint64_t b) -> std::optional<uint64_t> {
+ TINT_ASSERT(a <= u32::kHighestValue);
+ TINT_ASSERT(b <= u32::kHighestValue);
+
+ if (a < b) {
+ return {};
+ }
+ return a - b;
+ };
+
+ if (std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(lhs->range)) {
+ auto lhs_i32 = std::get<IntegerRangeInfo::SignedIntegerRange>(lhs->range);
+ auto rhs_i32 = std::get<IntegerRangeInfo::SignedIntegerRange>(rhs->range);
+
+ // [min1, max1] - [min2, max2] => [min1 - max2, max1 - min2]
+ std::optional<int64_t> min_bound =
+ SafeSubtractI32(lhs_i32.min_bound, rhs_i32.max_bound);
+ std::optional<int64_t> max_bound =
+ SafeSubtractI32(lhs_i32.max_bound, rhs_i32.min_bound);
+ if (!min_bound.has_value() || !max_bound.has_value()) {
+ return nullptr;
+ }
+ auto result = integer_binary_range_info_map_.Add(
+ binary, IntegerRangeInfo(*min_bound, *max_bound));
+ return &result.value;
+ } else {
+ auto lhs_u32 = std::get<IntegerRangeInfo::UnsignedIntegerRange>(lhs->range);
+ auto rhs_u32 = std::get<IntegerRangeInfo::UnsignedIntegerRange>(rhs->range);
+
+ // [min1, max1] - [min2, max2] => [min1 - max2, max1 - min2]
+ std::optional<uint64_t> min_bound =
+ SafeSubtractU32(lhs_u32.min_bound, rhs_u32.max_bound);
+ std::optional<uint64_t> max_bound =
+ SafeSubtractU32(lhs_u32.max_bound, rhs_u32.min_bound);
+ if (!min_bound || !max_bound) {
+ return nullptr;
+ }
+ auto result = integer_binary_range_info_map_.Add(
+ binary, IntegerRangeInfo(*min_bound, *max_bound));
+ return &result.value;
+ }
+ }
+
Function* function_;
Hashmap<const FunctionParam*, Vector<IntegerRangeInfo, 3>, 4>
integer_function_param_range_info_map_;
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 5133cf8..065d384 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -9469,5 +9469,496 @@
ASSERT_EQ(nullptr, info);
}
+TEST_F(IR_IntegerRangeAnalysisTest, BinarySubtract_Success_U32) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 1_u, 1_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Var* idx = nullptr;
+ Binary* subtract = nullptr;
+ b.Append(func->Block(), [&] {
+ auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 4
+ idx = b.Var("idx", 4_u);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_u);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // add = idx - local_id.x
+ subtract = b.Subtract<u32>(b.Load(idx), access_x);
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<u32>(b.Load(idx), 1_u));
+ b.NextIteration(loop);
+ });
+
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 1u, 1u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, u32, read_write> = var 4u
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %5:u32 = load %idx
+ %6:bool = lt %5, 10u
+ if %6 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %7:u32 = load %idx
+ %8:u32 = sub %7, %3
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %9:u32 = load %idx
+ %10:u32 = add %9, 1u
+ store %idx, %10
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `subtract` (`idx - local_id.x`)
+ // idx: [4, 9] local_id.x: [0, 3]
+ auto* info = analysis.GetInfo(subtract);
+ ASSERT_NE(nullptr, info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+ EXPECT_EQ(1u, range.min_bound);
+ EXPECT_EQ(9u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinarySubtract_Success_I32) {
+ Var* idx = nullptr;
+ Binary* subtract = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 5
+ idx = b.Var("idx", 5_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+
+ Var* idy = nullptr;
+ auto* loop2 = b.Loop();
+ b.Append(loop2->Initializer(), [&] {
+ // idy = 1
+ idy = b.Var("idy", 1_i);
+ b.NextIteration(loop2);
+ });
+ b.Append(loop2->Body(), [&] {
+ // idy < 4
+ auto* binary_inner = b.LessThan<bool>(b.Load(idy), 4_i);
+ auto* ifelse_inner = b.If(binary_inner);
+ b.Append(ifelse_inner->True(), [&] { b.ExitIf(ifelse_inner); });
+ b.Append(ifelse_inner->False(), [&] { b.ExitLoop(loop2); });
+ auto* loadx = b.Load(idx);
+ auto* loady = b.Load(idy);
+ subtract = b.Subtract<i32>(loadx, loady);
+ b.Continue(loop2);
+ });
+ b.Append(loop2->Continuing(), [&] {
+ // idy++
+ b.Store(idy, b.Add<i32>(b.Load(idy), 1_i));
+ b.NextIteration(loop2);
+ });
+
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var 5i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 10i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ loop [i: $B7, b: $B8, c: $B9] { # loop_2
+ $B7: { # initializer
+ %idy:ptr<function, i32, read_write> = var 1i
+ next_iteration # -> $B8
+ }
+ $B8: { # body
+ %6:i32 = load %idy
+ %7:bool = lt %6, 4i
+ if %7 [t: $B10, f: $B11] { # if_2
+ $B10: { # true
+ exit_if # if_2
+ }
+ $B11: { # false
+ exit_loop # loop_2
+ }
+ }
+ %8:i32 = load %idx
+ %9:i32 = load %idy
+ %10:i32 = sub %8, %9
+ continue # -> $B9
+ }
+ $B9: { # continuing
+ %11:i32 = load %idy
+ %12:i32 = add %11, 1i
+ store %idy, %12
+ next_iteration # -> $B8
+ }
+ }
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %13:i32 = load %idx
+ %14:i32 = add %13, 1i
+ store %idx, %14
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `subtract` (`idx - idy`)
+ // idx: [5, 9], idy: [1, 3]
+ auto* info = analysis.GetInfo(subtract);
+ ASSERT_NE(nullptr, info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+ EXPECT_EQ(2, range.min_bound);
+ EXPECT_EQ(8, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinarySubtract_Failure_Underflow_U32) {
+ Var* idx = nullptr;
+ Binary* subtract = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 5
+ idx = b.Var("idx", 5_u);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_u);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+
+ Var* idy = nullptr;
+ auto* loop2 = b.Loop();
+ b.Append(loop2->Initializer(), [&] {
+ // idy = 1
+ idy = b.Var("idy", 1_u);
+ b.NextIteration(loop2);
+ });
+ b.Append(loop2->Body(), [&] {
+ // idy < 7
+ auto* binary_inner = b.LessThan<bool>(b.Load(idy), 7_u);
+ auto* ifelse_inner = b.If(binary_inner);
+ b.Append(ifelse_inner->True(), [&] { b.ExitIf(ifelse_inner); });
+ b.Append(ifelse_inner->False(), [&] { b.ExitLoop(loop2); });
+ auto* loadx = b.Load(idx);
+ auto* loady = b.Load(idy);
+ subtract = b.Subtract<u32>(loadx, loady);
+ b.Continue(loop2);
+ });
+ b.Append(loop2->Continuing(), [&] {
+ // idy++
+ b.Store(idy, b.Add<u32>(b.Load(idy), 1_u));
+ b.NextIteration(loop2);
+ });
+
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<u32>(b.Load(idx), 1_u));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, u32, read_write> = var 5u
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:u32 = load %idx
+ %4:bool = lt %3, 10u
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ loop [i: $B7, b: $B8, c: $B9] { # loop_2
+ $B7: { # initializer
+ %idy:ptr<function, u32, read_write> = var 1u
+ next_iteration # -> $B8
+ }
+ $B8: { # body
+ %6:u32 = load %idy
+ %7:bool = lt %6, 7u
+ if %7 [t: $B10, f: $B11] { # if_2
+ $B10: { # true
+ exit_if # if_2
+ }
+ $B11: { # false
+ exit_loop # loop_2
+ }
+ }
+ %8:u32 = load %idx
+ %9:u32 = load %idy
+ %10:u32 = sub %8, %9
+ continue # -> $B9
+ }
+ $B9: { # continuing
+ %11:u32 = load %idy
+ %12:u32 = add %11, 1u
+ store %idy, %12
+ next_iteration # -> $B8
+ }
+ }
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %13:u32 = load %idx
+ %14:u32 = add %13, 1u
+ store %idx, %14
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `subtract` (`idx - idy`)
+ // idx: [5, 9], idy: [1, 7], idx.min_bound < idy.max_bound
+ auto* info = analysis.GetInfo(subtract);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinarySubtract_Failure_Overflow_I32) {
+ // kSmallValue = -2147483640
+ constexpr int32_t kSmallValue = i32::kLowestValue + 8;
+
+ Var* idx = nullptr;
+ Binary* subtract = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 0
+ idx = b.Var("idx", 0_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 9
+ auto* binary = b.LessThan<bool>(b.Load(idx), 9_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // subtract = idx - kSmallValue
+ subtract = b.Subtract<i32>(b.Load(idx), b.Constant(i32(kSmallValue)));
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var 0i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 9i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = sub %5, -2147483640i
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `subtract` (`idx - (-2147483640)`)
+ // idx: [0, 8], 2147483640 > i32::kHighestValue - 8
+ auto* info = analysis.GetInfo(subtract);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinarySubtract_Failure_Underflow_I32) {
+ // kLargeValue = 2147483640
+ constexpr int32_t kLargeValue = i32::kHighestValue - 7;
+
+ Var* idx = nullptr;
+ Binary* subtract = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = -9
+ idx = b.Var("idx", -9_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 0
+ auto* binary = b.LessThan<bool>(b.Load(idx), 0_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // subtract = idx - kLargeValue
+ subtract = b.Subtract<i32>(b.Load(idx), b.Constant(i32(kLargeValue)));
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var -9i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 0i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = sub %5, 2147483640i
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `subtract` (`idx - 2147483640`)
+ // idx: [-9, 0], idx < i32::kLowestValue + 2147483640
+ auto* info = analysis.GetInfo(subtract);
+ ASSERT_EQ(nullptr, info);
+}
+
} // namespace
} // namespace tint::core::ir::analysis