Range Analysis: Compute range of a scalar integer `Constant`
This patch computes the range of an integer scalar `Constant` and
returns an `IntegerRangeInfo` object so that we can always compute
the valid integer ranges among `IntegerRangeInfo` objects.
Bug: 348701956
Test: tint_unittests
Change-Id: If4dd9902bdba7bec59e2674ecde36f563f2fb7d9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/240675
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index b6401f1..51d7f90 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -83,6 +83,16 @@
const Binary* binary = nullptr;
};
+IntegerRangeInfo ToIntegerRangeInfo(const Constant* constant,
+ int64_t min_value,
+ int64_t max_value) {
+ if (constant->Type()->IsSignedIntegerScalar()) {
+ return IntegerRangeInfo(min_value, max_value);
+ } else {
+ return IntegerRangeInfo(static_cast<uint64_t>(min_value), static_cast<uint64_t>(max_value));
+ }
+}
+
} // namespace
IntegerRangeInfo::IntegerRangeInfo(int64_t min_bound, int64_t max_bound) {
@@ -193,6 +203,18 @@
return GetInfo(function_param, index);
}
+ const IntegerRangeInfo* GetInfo(const Constant* constant) {
+ if (!IsConstantInteger(constant)) {
+ return nullptr;
+ }
+ const IntegerRangeInfo& range_info =
+ integer_constant_range_info_map_.GetOrAdd(constant, [&]() -> IntegerRangeInfo {
+ int64_t const_value = GetValueFromConstant(constant);
+ return ToIntegerRangeInfo(constant, const_value, const_value);
+ });
+ return &range_info;
+ }
+
/// Analyze a loop to compute the range of the loop control variable if possible.
void AnalyzeLoop(const Loop* loop) {
const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
@@ -213,19 +235,12 @@
TINT_ASSERT(index->Initializer()->As<Constant>());
// for (var i = const_init; ...)
- int64_t const_init = GetValueFromConstant(index->Initializer()->As<Constant>());
+ const Constant* constant_initializer = index->Initializer()->As<Constant>();
+ int64_t const_init = GetValueFromConstant(constant_initializer);
// for (...; i++) or for(...; i--)
bool index_is_increasing = update->Op() == BinaryOp::kAdd;
- auto to_integer_range_info = [&](int64_t min_value, int64_t max_value) {
- if (index->Initializer()->As<Constant>()->Type()->IsSignedIntegerScalar()) {
- return IntegerRangeInfo(min_value, max_value);
- } else {
- return IntegerRangeInfo(static_cast<uint64_t>(min_value),
- static_cast<uint64_t>(max_value));
- }
- };
switch (compare_info.op) {
case BinaryOp::kLessThanEqual: {
// for (var index = const_init; index <= const_rhs; index++)
@@ -238,8 +253,8 @@
// - `index <= const_rhs` can correctly exit when `const_init + 1` is the maximum
// value of `i32` or `u32`.
if (index_is_increasing && const_init <= compare_info.const_rhs) {
- IntegerRangeInfo range_info =
- to_integer_range_info(const_init, compare_info.const_rhs);
+ IntegerRangeInfo range_info = ToIntegerRangeInfo(
+ constant_initializer, const_init, compare_info.const_rhs);
integer_var_range_info_map_.Add(index, range_info);
}
break;
@@ -254,8 +269,8 @@
// - `index >= const_rhs` can correctly exit when `const_init - 1` is the minimum
// value of `i32` or `u32`.
if (!index_is_increasing && const_init >= compare_info.const_rhs) {
- IntegerRangeInfo range_info =
- to_integer_range_info(compare_info.const_rhs, const_init);
+ IntegerRangeInfo range_info = ToIntegerRangeInfo(
+ constant_initializer, compare_info.const_rhs, const_init);
integer_var_range_info_map_.Add(index, range_info);
}
break;
@@ -676,6 +691,7 @@
Hashmap<const FunctionParam*, Vector<IntegerRangeInfo, 3>, 4>
integer_function_param_range_info_map_;
Hashmap<const Var*, IntegerRangeInfo, 8> integer_var_range_info_map_;
+ Hashmap<const Constant*, IntegerRangeInfo, 8> integer_constant_range_info_map_;
};
IntegerRangeAnalysis::IntegerRangeAnalysis(Function* func)
@@ -716,4 +732,8 @@
return impl_->GetInfo(access);
}
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Constant* constant) {
+ return impl_->GetInfo(constant);
+}
+
} // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index 9c14509..d92f77c 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -35,6 +35,7 @@
namespace tint::core::ir {
class Access;
class Binary;
+class Constant;
class Function;
class FunctionParam;
class Load;
@@ -98,6 +99,10 @@
/// it has a meaningful range. Returns nullptr otherwise.
const IntegerRangeInfo* GetInfo(const Access* access);
+ /// Returns the integer range info of a given `Constant` if it is an integer.
+ /// Returns nullptr otherwise.
+ const IntegerRangeInfo* GetInfo(const Constant* constant);
+
/// Note: This function is only for tests.
/// Returns the pointer of the loop control variable in the given loop when its initializer
/// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 814193f..0979ecb 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8795,5 +8795,88 @@
ASSERT_EQ(nullptr, access_x_info);
}
+TEST_F(IR_IntegerRangeAnalysisTest, SignedIntegerScalarConstant) {
+ auto* func = b.Function("func", ty.void_());
+ Constant* constant = nullptr;
+ b.Append(func->Block(), [&] {
+ constant = b.Constant(10_i);
+ b.Var("a", constant);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ %a:ptr<function, i32, read_write> = var 10i
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ auto* info = analysis.GetInfo(constant);
+ ASSERT_NE(nullptr, info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+ EXPECT_EQ(10, range.min_bound);
+ EXPECT_EQ(10, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, UnsignedIntegerScalarConstant) {
+ auto* func = b.Function("func", ty.void_());
+ Constant* constant = nullptr;
+ b.Append(func->Block(), [&] {
+ constant = b.Constant(20_u);
+ b.Var("a", constant);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ %a:ptr<function, u32, read_write> = var 20u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ auto* info = analysis.GetInfo(constant);
+ ASSERT_NE(nullptr, info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+ EXPECT_EQ(20u, range.min_bound);
+ EXPECT_EQ(20u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NonIntegerConstant) {
+ auto* func = b.Function("func", ty.void_());
+ Constant* constant = nullptr;
+ b.Append(func->Block(), [&] {
+ constant = b.Constant(1.0_f);
+ b.Var("a", constant);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ %a:ptr<function, f32, read_write> = var 1.0f
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ auto* info = analysis.GetInfo(constant);
+ ASSERT_EQ(nullptr, info);
+}
+
} // namespace
} // namespace tint::core::ir::analysis