Range Analysis: Compute range of `Load` from loop control variable
This patch computes the range of all the `Load` instructions that
loads value from a loop control variable with valid range.
Bug: 348701956
Test: tint_unittests
Change-Id: I7d532826350690d90b07b5074b480329cfbb9a12
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/239276
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 4813e2e..56306a3 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -151,8 +151,20 @@
return integer_var_range_info_map_.Get(var).value;
}
+ const IntegerRangeInfo* GetInfo(const Load* load) {
+ const InstructionResult* instruction = load->From()->As<InstructionResult>();
+ if (!instruction) {
+ return nullptr;
+ }
+ const Var* load_from_var = instruction->Instruction()->As<Var>();
+ if (!load_from_var) {
+ return nullptr;
+ }
+ return GetInfo(load_from_var);
+ }
+
/// Analyze a loop to compute the range of the loop control variable if possible.
- void AnalyzeLoop(Loop* loop) {
+ void AnalyzeLoop(const Loop* loop) {
const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
if (!index) {
return;
@@ -666,4 +678,8 @@
return impl_->GetInfo(var);
}
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Load* load) {
+ return impl_->GetInfo(load);
+}
+
} // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index 4b2f5b2..b6805a4 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -36,6 +36,7 @@
class Binary;
class Function;
class FunctionParam;
+class Load;
class Loop;
class Var;
} // namespace tint::core::ir
@@ -82,12 +83,16 @@
/// @returns the integer range info
const IntegerRangeInfo* GetInfo(const FunctionParam* param, uint32_t index = 0);
- /// Returns the integer range info of a given variable with given index, if it is an integer
- /// variable.
+ /// Returns the integer range info of a given variable if it is an integer variable and it has a
+ /// meaningful range. Returns nullptr otherwise.
/// @param var the variable to get information about
/// @returns the integer range info
const IntegerRangeInfo* GetInfo(const Var* var);
+ /// Returns the integer range info of a given `Load` variable if it is an integer variable and
+ /// it has a meaningful range. Returns nullptr otherwise.
+ const IntegerRangeInfo* GetInfo(const Load* load_var);
+
/// Note: This function is only for tests.
/// Returns the pointer of the loop control variable in the given loop when its initializer
/// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 8807932..7b8e872 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8421,5 +8421,200 @@
EXPECT_EQ(nullptr, info);
}
+TEST_F(IR_IntegerRangeAnalysisTest, LoadFromLoopControlVariableWithRange) {
+ Var* idx = nullptr;
+ Loop* loop = nullptr;
+ Load* load_idx = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 0
+ idx = b.Var("idx", 0_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ load_idx = b.Load(idx);
+ b.Add<i32>(load_idx, 4_i);
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var 0i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 10i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = add %5, 4i
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ const IntegerRangeInfo* idx_info = analysis.GetInfo(idx);
+ EXPECT_NE(nullptr, idx_info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(idx_info->range));
+
+ const auto& range_idx = std::get<IntegerRangeInfo::SignedIntegerRange>(idx_info->range);
+ EXPECT_EQ(0, range_idx.min_bound);
+ EXPECT_EQ(9, range_idx.max_bound);
+
+ const IntegerRangeInfo* load_idx_info = analysis.GetInfo(load_idx);
+ EXPECT_NE(nullptr, load_idx_info);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(load_idx_info->range));
+
+ const auto& range_load_idx =
+ std::get<IntegerRangeInfo::SignedIntegerRange>(load_idx_info->range);
+ EXPECT_EQ(0, range_load_idx.min_bound);
+ EXPECT_EQ(9, range_load_idx.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LoadFromLoopControlVariableWithoutRange) {
+ Var* idx = nullptr;
+ Loop* loop = nullptr;
+ Binary* binary = nullptr;
+ Load* load_idx = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 20u
+ idx = b.Var("idx", 20_u);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // 30u > idx
+ binary = b.GreaterThan<bool>(30_u, b.Load(idx));
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ load_idx = b.Load(idx);
+ b.Add<u32>(load_idx, 4_u);
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx--
+ b.Store(idx, b.Subtract<u32>(b.Load(idx), 1_u));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, u32, read_write> = var 20u
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:u32 = load %idx
+ %4:bool = gt 30u, %3
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:u32 = load %idx
+ %6:u32 = add %5, 4u
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:u32 = load %idx
+ %8:u32 = sub %7, 1u
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ EXPECT_EQ(idx, analysis.GetLoopControlVariableFromConstantInitializerForTest(loop));
+ EXPECT_EQ(binary, analysis.GetBinaryToCompareLoopControlVariableInLoopBodyForTest(loop, idx));
+
+ const IntegerRangeInfo* idx_info = analysis.GetInfo(idx);
+ EXPECT_EQ(nullptr, idx_info);
+ const IntegerRangeInfo* load_idx_info = analysis.GetInfo(idx);
+ EXPECT_EQ(nullptr, load_idx_info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LoadFromNonLoopControlVariable) {
+ Load* load_a = nullptr;
+
+ auto* var_a = mod.root_block->Append(b.Var<workgroup, u32>("a"));
+ auto* func = b.ComputeFunction("foo", 3_u, 5_u, 7_u);
+ b.Append(func->Block(), [&] { //
+ load_a = b.Load(var_a);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %a:ptr<workgroup, u32, read_write> = var undef
+}
+
+%foo = @compute @workgroup_size(3u, 5u, 7u) func():void {
+ $B2: {
+ %3:u32 = load %a
+ ret
+ }
+}
+)";
+
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ EXPECT_EQ(nullptr, analysis.GetInfo(load_a));
+}
+
} // namespace
} // namespace tint::core::ir::analysis