Range Analysis: Compute range of `Binary` with `Add` operator
This patch computes the range of a `Binary` with `Add` operator. To
compute the range we first get the ranges of the two operands, and
then compute the range of their sum. If any overflow or underflow
happens, no valid range is returned.
Bug: 348701956
Test: tint_unittests
Change-Id: Ibe6eda96c8c654c3ffe22798472125506d595819
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241141
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 51b98c2..d5312a8 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -225,16 +225,42 @@
return GetInfo(param, 0);
},
[&](const InstructionResult* r) {
- // TODO(348701956): Support more instruction types. e.g. Binary, Let, ...
+ // TODO(348701956): Support more instruction types. e.g. Let, Convert, ...
return Switch(
r->Instruction(), [&](const Var* var) { return GetInfo(var); },
[&](const Load* load) { return GetInfo(load); },
[&](const Access* access) { return GetInfo(access); },
+ [&](const Binary* binary) { return GetInfo(binary); },
[](Default) { return nullptr; });
},
[](Default) { return nullptr; });
}
+ const IntegerRangeInfo* GetInfo(const Binary* binary) {
+ const IntegerRangeInfo* existing_range = integer_binary_range_info_map_.Get(binary).value;
+ if (existing_range) {
+ return existing_range;
+ }
+
+ const IntegerRangeInfo* range_lhs = GetInfo(binary->LHS());
+ if (!range_lhs) {
+ return nullptr;
+ }
+ const IntegerRangeInfo* range_rhs = GetInfo(binary->RHS());
+ if (!range_rhs) {
+ return nullptr;
+ }
+
+ // TODO(348701956): Support more binary operators
+ switch (binary->Op()) {
+ case BinaryOp::kAdd: {
+ return ComputeAndCacheIntegerRangeForBinaryAdd(binary, range_lhs, range_rhs);
+ }
+ default:
+ return nullptr;
+ }
+ }
+
/// Analyze a loop to compute the range of the loop control variable if possible.
void AnalyzeLoop(const Loop* loop) {
const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
@@ -707,11 +733,70 @@
}
private:
+ const IntegerRangeInfo* ComputeAndCacheIntegerRangeForBinaryAdd(const Binary* binary,
+ const IntegerRangeInfo* lhs,
+ const IntegerRangeInfo* rhs) {
+ // Add two 32-bit signed integer values saved in int64_t. Return {} when either overflow or
+ // underflow happens.
+ auto SafeAddI32 = [](int64_t a, int64_t b) -> std::optional<int64_t> {
+ TINT_ASSERT(a >= i32::kLowestValue && a <= i32::kHighestValue);
+ TINT_ASSERT(b >= i32::kLowestValue && b <= i32::kHighestValue);
+
+ int64_t sum = a + b;
+ if (sum > i32::kHighestValue || sum < i32::kLowestValue) {
+ return {};
+ }
+ return sum;
+ };
+
+ // Add two 32-bit unsigned integer values saved in uint64_t. Return {} when overflow
+ // happens.
+ auto SafeAddU32 = [](uint64_t a, uint64_t b) -> std::optional<uint64_t> {
+ TINT_ASSERT(a <= u32::kHighestValue);
+ TINT_ASSERT(b <= u32::kHighestValue);
+
+ uint64_t sum = a + b;
+ if (sum > u32::kHighestValue) {
+ return {};
+ }
+ return sum;
+ };
+
+ if (std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(lhs->range)) {
+ auto lhs_i32 = std::get<IntegerRangeInfo::SignedIntegerRange>(lhs->range);
+ auto rhs_i32 = std::get<IntegerRangeInfo::SignedIntegerRange>(rhs->range);
+
+ // [min1, max1] + [min2, max2] => [min1 + min2, max1 + max2]
+ std::optional<int64_t> min_bound = SafeAddI32(lhs_i32.min_bound, rhs_i32.min_bound);
+ std::optional<int64_t> max_bound = SafeAddI32(lhs_i32.max_bound, rhs_i32.max_bound);
+ if (!min_bound.has_value() || !max_bound.has_value()) {
+ return nullptr;
+ }
+ auto result = integer_binary_range_info_map_.Add(
+ binary, IntegerRangeInfo(*min_bound, *max_bound));
+ return &result.value;
+ } else {
+ auto lhs_u32 = std::get<IntegerRangeInfo::UnsignedIntegerRange>(lhs->range);
+ auto rhs_u32 = std::get<IntegerRangeInfo::UnsignedIntegerRange>(rhs->range);
+
+ // [min1, max1] + [min2, max2] => [min1 + min2, max1 + max2]
+ std::optional<uint64_t> min_bound = SafeAddU32(lhs_u32.min_bound, rhs_u32.min_bound);
+ std::optional<uint64_t> max_bound = SafeAddU32(lhs_u32.max_bound, rhs_u32.max_bound);
+ if (!min_bound || !max_bound) {
+ return nullptr;
+ }
+ auto result = integer_binary_range_info_map_.Add(
+ binary, IntegerRangeInfo(*min_bound, *max_bound));
+ return &result.value;
+ }
+ }
+
Function* function_;
Hashmap<const FunctionParam*, Vector<IntegerRangeInfo, 3>, 4>
integer_function_param_range_info_map_;
Hashmap<const Var*, IntegerRangeInfo, 8> integer_var_range_info_map_;
Hashmap<const Constant*, IntegerRangeInfo, 8> integer_constant_range_info_map_;
+ Hashmap<const Binary*, IntegerRangeInfo, 8> integer_binary_range_info_map_;
};
IntegerRangeAnalysis::IntegerRangeAnalysis(Function* func)
@@ -760,4 +845,8 @@
return impl_->GetInfo(value);
}
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Binary* binary) {
+ return impl_->GetInfo(binary);
+}
+
} // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index 33defaf..cf8441c 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -108,6 +108,10 @@
/// it has a meaningful range. Returns nullptr otherwise.
const IntegerRangeInfo* GetInfo(const Value* value);
+ /// Returns the integer range info of a given `Binary` instruction if it is an integer value and
+ /// it has a meaningful range. Returns nullptr otherwise.
+ const IntegerRangeInfo* GetInfo(const Binary* binary);
+
/// Note: This function is only for tests.
/// Returns the pointer of the loop control variable in the given loop when its initializer
/// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index c76976d..5133cf8 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8886,6 +8886,7 @@
func->SetParams({localInvocationIndex});
b.Append(func->Block(), [&] {
+ // add = localInvocationIndex + 5
add = b.Add<u32>(localInvocationIndex, 5_u);
b.Return(func);
});
@@ -8902,13 +8903,22 @@
EXPECT_EQ(Validate(mod), Success);
IntegerRangeAnalysis analysis(func);
+
+ // Range of `add->LHS()` (`localInvocationIndex`)
auto* info = analysis.GetInfo(add->LHS());
ASSERT_NE(nullptr, info);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
-
const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
EXPECT_EQ(0u, range.min_bound);
EXPECT_EQ(23u, range.max_bound);
+
+ // Range of `add` (`localInvocationIndex + 5`)
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_NE(nullptr, info_add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range);
+ EXPECT_EQ(5u, range_add.min_bound);
+ EXPECT_EQ(28u, range_add.max_bound);
}
TEST_F(IR_IntegerRangeAnalysisTest, ValueAsVectorFunctionParameter) {
@@ -8971,6 +8981,7 @@
IntegerRangeAnalysis analysis(func);
+ // Range of `add->LHS()` (`local_id.x`)
auto* info_lhs = analysis.GetInfo(add->LHS());
ASSERT_NE(nullptr, info_lhs);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_lhs->range));
@@ -8978,12 +8989,21 @@
EXPECT_EQ(0u, range_lhs.min_bound);
EXPECT_EQ(3u, range_lhs.max_bound);
+ // Range of `add->RHS()` (`local_id.y`)
auto* info_rhs = analysis.GetInfo(add->RHS());
ASSERT_NE(nullptr, info_rhs);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_rhs->range));
const auto& range_rhs = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_rhs->range);
EXPECT_EQ(0u, range_rhs.min_bound);
EXPECT_EQ(2u, range_rhs.max_bound);
+
+ // Range of `add` (`local_id.x + local_id.y`)
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_NE(nullptr, info_add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range);
+ EXPECT_EQ(0u, range_add.min_bound);
+ EXPECT_EQ(5u, range_add.max_bound);
}
TEST_F(IR_IntegerRangeAnalysisTest, ValueAsLoadAndConstant) {
@@ -9054,6 +9074,7 @@
IntegerRangeAnalysis analysis(func);
+ // Range of `add->LHS()` (`idx`)
auto* info_lhs = analysis.GetInfo(add->LHS());
ASSERT_NE(nullptr, info_lhs);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info_lhs->range));
@@ -9061,17 +9082,27 @@
EXPECT_EQ(0, range_lhs.min_bound);
EXPECT_EQ(9, range_lhs.max_bound);
+ // Range of `add->RHS()` (5)
auto* info_rhs = analysis.GetInfo(add->RHS());
ASSERT_NE(nullptr, info_rhs);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info_rhs->range));
const auto& range_rhs = std::get<IntegerRangeInfo::SignedIntegerRange>(info_rhs->range);
EXPECT_EQ(5, range_rhs.min_bound);
EXPECT_EQ(5, range_rhs.max_bound);
+
+ // Range of `add` (`idx + 5`)
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_NE(nullptr, info_add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::SignedIntegerRange>(info_add->range);
+ EXPECT_EQ(5, range_add.min_bound);
+ EXPECT_EQ(14, range_add.max_bound);
}
TEST_F(IR_IntegerRangeAnalysisTest, ValueAsVar) {
Var* idx = nullptr;
Value* value = nullptr;
+ Binary* add = nullptr;
auto* func = b.Function("func", ty.void_());
b.Append(func->Block(), [&] {
auto* loop = b.Loop();
@@ -9087,7 +9118,8 @@
b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
value = b.Value(idx);
- b.Add<i32>(b.Load(value), 5_i);
+ // add = value + 5
+ add = b.Add<i32>(b.Load(value), 5_i);
b.Continue(loop);
});
b.Append(loop->Continuing(), [&] {
@@ -9138,12 +9170,303 @@
IntegerRangeAnalysis analysis(func);
+ // Range of `value`
auto* info = analysis.GetInfo(value);
ASSERT_NE(nullptr, info);
ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
EXPECT_EQ(0, range.min_bound);
EXPECT_EQ(9, range.max_bound);
+
+ // Range of `add` (`value + 5`)
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_NE(nullptr, info_add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::SignedIntegerRange>(info_add->range);
+ EXPECT_EQ(5, range_add.min_bound);
+ EXPECT_EQ(14, range_add.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, MultipleBinaryAdds) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Var* idx = nullptr;
+ Binary* add = nullptr;
+ b.Append(func->Block(), [&] {
+ auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+ auto* access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 0
+ idx = b.Var("idx", 0_u);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_u);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // add1 = local_id.x + local_id.y
+ auto* add1 = b.Add<u32>(access_x, access_y);
+ // add = idx + add1
+ add = b.Add<u32>(b.Load(idx), add1);
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<u32>(b.Load(idx), 1_u));
+ b.NextIteration(loop);
+ });
+
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ %4:u32 = access %local_id, 1u
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, u32, read_write> = var 0u
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %6:u32 = load %idx
+ %7:bool = lt %6, 10u
+ if %7 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %8:u32 = add %3, %4
+ %9:u32 = load %idx
+ %10:u32 = add %9, %8
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %11:u32 = load %idx
+ %12:u32 = add %11, 1u
+ store %idx, %12
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `add` (`idx + (local_id.x + local_id.y)`)
+ // access_x: [0, 3], access_y: [0, 2], idx: [0, 9]
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_NE(nullptr, info_add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range);
+ EXPECT_EQ(0u, range_add.min_bound);
+ EXPECT_EQ(14u, range_add.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinaryAdd_U32_Overflow) {
+ auto* func = b.ComputeFunction("my_func", 8_u, 1_u, 1_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ // kLargeValue = 4294967289u
+ constexpr uint32_t kLargeValue = u32::kHighestValue - 6u;
+ Binary* add = nullptr;
+ b.Append(func->Block(), [&] {
+ auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+ // add = local_id.x + kLargeValue
+ add = b.Add<u32>(access_x, b.Constant(u32(kLargeValue)));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(8u, 1u, 1u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ %4:u32 = add %3, 4294967289u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `add` (`local_id.x + 4294967289`)
+ // local_id.x: [0, 7], 4294967289 > u32::kHighestValue - 7
+ auto* info = analysis.GetInfo(add);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinaryAdd_I32_Overflow) {
+ // kLargeValue = 2147483639
+ constexpr int32_t kLargeValue = i32::kHighestValue - 8;
+
+ Var* idx = nullptr;
+ Binary* add = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 0
+ idx = b.Var("idx", 0_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // add = idx + kLargeValue
+ add = b.Add<i32>(b.Load(idx), b.Constant(i32(kLargeValue)));
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var 0i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 10i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = add %5, 2147483639i
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `add` (`idx + 2147483639`)
+ // idx: [0, 9], 2147483639 > i32::kHighestValue - 9
+ auto* info = analysis.GetInfo(add);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, BinaryAdd_I32_Underflow) {
+ // kSmallValue = -2147483640
+ constexpr int32_t kSmallValue = i32::kLowestValue + 8;
+
+ Var* idx = nullptr;
+ Binary* add = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = -9
+ idx = b.Var("idx", -9_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 1
+ auto* binary = b.LessThan<bool>(b.Load(idx), 1_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ // add = idx + kSmallValue
+ add = b.Add<i32>(b.Load(idx), b.Constant(i32(kSmallValue)));
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var -9i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 1i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = add %5, -2147483640i
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ // Range of `add` (`idx + (-2147483640)`)
+ // idx: [-9, 0], -2147483640 < i32::kLowestValue + 9
+ auto* info = analysis.GetInfo(add);
+ ASSERT_EQ(nullptr, info);
}
} // namespace