Range Analysis: Compute range of a `Let` expression
Bug: 348701956
Test: tint_unittests
Change-Id: Icc5c3ae3347e1557495f62e86bd9519ad57189b5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241376
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 6f73a2b..7c7c0af 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -36,6 +36,7 @@
#include "src/tint/lang/core/ir/exit_loop.h"
#include "src/tint/lang/core/ir/function.h"
#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/let.h"
#include "src/tint/lang/core/ir/load.h"
#include "src/tint/lang/core/ir/loop.h"
#include "src/tint/lang/core/ir/multi_in_block.h"
@@ -203,6 +204,8 @@
return GetInfo(function_param, index);
}
+ const IntegerRangeInfo* GetInfo(const Let* let) { return GetInfo(let->Value()); }
+
const IntegerRangeInfo* GetInfo(const Constant* constant) {
if (!IsConstantInteger(constant)) {
return nullptr;
@@ -225,11 +228,12 @@
return GetInfo(param, 0);
},
[&](const InstructionResult* r) {
- // TODO(348701956): Support more instruction types. e.g. Let, Convert, ...
+ // TODO(348701956): Support more instruction types. e.g. Convert, ...
return Switch(
r->Instruction(), [&](const Var* var) { return GetInfo(var); },
[&](const Load* load) { return GetInfo(load); },
[&](const Access* access) { return GetInfo(access); },
+ [&](const Let* let) { return GetInfo(let); },
[&](const Binary* binary) { return GetInfo(binary); },
[](Default) { return nullptr; });
},
@@ -914,4 +918,8 @@
return impl_->GetInfo(binary);
}
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Let* let) {
+ return impl_->GetInfo(let);
+}
+
} // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index cf8441c..ccb1184 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -38,6 +38,7 @@
class Constant;
class Function;
class FunctionParam;
+class Let;
class Load;
class Loop;
class Value;
@@ -100,6 +101,10 @@
/// it has a meaningful range. Returns nullptr otherwise.
const IntegerRangeInfo* GetInfo(const Access* access);
+ /// Returns the integer range info of a given `Let` variable if it is an integer variable and it
+ /// has a meaningful range. Returns nullptr otherwise.
+ const IntegerRangeInfo* GetInfo(const Let* let);
+
/// Returns the integer range info of a given `Constant` if it is an integer.
/// Returns nullptr otherwise.
const IntegerRangeInfo* GetInfo(const Constant* constant);
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 065d384..64a1ea6 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -9187,6 +9187,202 @@
EXPECT_EQ(14, range_add.max_bound);
}
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithAccess) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Let* let = nullptr;
+ b.Append(func->Block(), [&] {
+ // access_x: [0, 3]
+ auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+ let = b.Let(access_x);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ %4:u32 = let %3
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* info = analysis.GetInfo(let);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+ EXPECT_EQ(0u, range.min_bound);
+ EXPECT_EQ(3u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithLoad) {
+ Var* idx = nullptr;
+ Let* let = nullptr;
+ auto* func = b.Function("func", ty.void_());
+ b.Append(func->Block(), [&] {
+ auto* loop = b.Loop();
+ b.Append(loop->Initializer(), [&] {
+ // idx = 0
+ idx = b.Var("idx", 0_i);
+ b.NextIteration(loop);
+ });
+ b.Append(loop->Body(), [&] {
+ // idx < 10
+ auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+ auto* ifelse = b.If(binary);
+ b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+ b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+ auto* load = b.Load(idx);
+ let = b.Let(load);
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] {
+ // idx++
+ b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+ b.NextIteration(loop);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ loop [i: $B2, b: $B3, c: $B4] { # loop_1
+ $B2: { # initializer
+ %idx:ptr<function, i32, read_write> = var 0i
+ next_iteration # -> $B3
+ }
+ $B3: { # body
+ %3:i32 = load %idx
+ %4:bool = lt %3, 10i
+ if %4 [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_loop # loop_1
+ }
+ }
+ %5:i32 = load %idx
+ %6:i32 = let %5
+ continue # -> $B4
+ }
+ $B4: { # continuing
+ %7:i32 = load %idx
+ %8:i32 = add %7, 1i
+ store %idx, %8
+ next_iteration # -> $B3
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* info = analysis.GetInfo(let);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+ EXPECT_EQ(0, range.min_bound);
+ EXPECT_EQ(9, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithBinary) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Let* let = nullptr;
+ b.Append(func->Block(), [&] {
+ // access_x: [0, 3]
+ auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+ // access_y: [0, 2]
+ auto* access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+ auto* add = b.Add<u32>(access_x, access_y);
+ let = b.Let(add);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ %4:u32 = access %local_id, 1u
+ %5:u32 = add %3, %4
+ %6:u32 = let %5
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* info = analysis.GetInfo(let);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+ const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+ EXPECT_EQ(0u, range.min_bound);
+ EXPECT_EQ(5u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetAsOperand) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Let* let = nullptr;
+ Binary* add = nullptr;
+ b.Append(func->Block(), [&] {
+ // let (access_x): [0, 3]
+ let = b.Let(b.Access(ty.u32(), local_invocation_id, 0_u));
+ // access_y: [0, 2]
+ auto* access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+ add = b.Add<u32>(let, access_y);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %local_id, 0u
+ %4:u32 = let %3
+ %5:u32 = access %local_id, 1u
+ %6:u32 = add %4, %5
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* info_let = analysis.GetInfo(let);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_let->range));
+ const auto& range_let = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_let->range);
+ EXPECT_EQ(0u, range_let.min_bound);
+ EXPECT_EQ(3u, range_let.max_bound);
+
+ auto* info_add = analysis.GetInfo(add);
+ ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range));
+ const auto& range_add = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range);
+ EXPECT_EQ(0u, range_add.min_bound);
+ EXPECT_EQ(5u, range_add.max_bound);
+}
+
TEST_F(IR_IntegerRangeAnalysisTest, MultipleBinaryAdds) {
auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());