Range Analysis: Compute range of a `Let` expression

Bug: 348701956
Test: tint_unittests
Change-Id: Icc5c3ae3347e1557495f62e86bd9519ad57189b5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/241376
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 6f73a2b..7c7c0af 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -36,6 +36,7 @@
 #include "src/tint/lang/core/ir/exit_loop.h"
 #include "src/tint/lang/core/ir/function.h"
 #include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/let.h"
 #include "src/tint/lang/core/ir/load.h"
 #include "src/tint/lang/core/ir/loop.h"
 #include "src/tint/lang/core/ir/multi_in_block.h"
@@ -203,6 +204,8 @@
         return GetInfo(function_param, index);
     }
 
+    const IntegerRangeInfo* GetInfo(const Let* let) { return GetInfo(let->Value()); }
+
     const IntegerRangeInfo* GetInfo(const Constant* constant) {
         if (!IsConstantInteger(constant)) {
             return nullptr;
@@ -225,11 +228,12 @@
                 return GetInfo(param, 0);
             },
             [&](const InstructionResult* r) {
-                // TODO(348701956): Support more instruction types. e.g. Let, Convert, ...
+                // TODO(348701956): Support more instruction types. e.g. Convert, ...
                 return Switch(
                     r->Instruction(), [&](const Var* var) { return GetInfo(var); },
                     [&](const Load* load) { return GetInfo(load); },
                     [&](const Access* access) { return GetInfo(access); },
+                    [&](const Let* let) { return GetInfo(let); },
                     [&](const Binary* binary) { return GetInfo(binary); },
                     [](Default) { return nullptr; });
             },
@@ -914,4 +918,8 @@
     return impl_->GetInfo(binary);
 }
 
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Let* let) {
+    return impl_->GetInfo(let);
+}
+
 }  // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index cf8441c..ccb1184 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -38,6 +38,7 @@
 class Constant;
 class Function;
 class FunctionParam;
+class Let;
 class Load;
 class Loop;
 class Value;
@@ -100,6 +101,10 @@
     /// it has a meaningful range. Returns nullptr otherwise.
     const IntegerRangeInfo* GetInfo(const Access* access);
 
+    /// Returns the integer range info of a given `Let` variable if it is an integer variable and it
+    /// has a meaningful range. Returns nullptr otherwise.
+    const IntegerRangeInfo* GetInfo(const Let* let);
+
     /// Returns the integer range info of a given `Constant` if it is an integer.
     /// Returns nullptr otherwise.
     const IntegerRangeInfo* GetInfo(const Constant* constant);
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 065d384..64a1ea6 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -9187,6 +9187,202 @@
     EXPECT_EQ(14, range_add.max_bound);
 }
 
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithAccess) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+    local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({local_invocation_id});
+
+    Let* let = nullptr;
+    b.Append(func->Block(), [&] {
+        // access_x: [0, 3]
+        auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+        let = b.Let(access_x);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:u32 = access %local_id, 0u
+    %4:u32 = let %3
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* info = analysis.GetInfo(let);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+    const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+    EXPECT_EQ(0u, range.min_bound);
+    EXPECT_EQ(3u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithLoad) {
+    Var* idx = nullptr;
+    Let* let = nullptr;
+    auto* func = b.Function("func", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {
+            // idx = 0
+            idx = b.Var("idx", 0_i);
+            b.NextIteration(loop);
+        });
+        b.Append(loop->Body(), [&] {
+            // idx < 10
+            auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+            auto* ifelse = b.If(binary);
+            b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+            b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+            auto* load = b.Load(idx);
+            let = b.Let(load);
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            // idx++
+            b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+            b.NextIteration(loop);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3, c: $B4] {  # loop_1
+      $B2: {  # initializer
+        %idx:ptr<function, i32, read_write> = var 0i
+        next_iteration  # -> $B3
+      }
+      $B3: {  # body
+        %3:i32 = load %idx
+        %4:bool = lt %3, 10i
+        if %4 [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_loop  # loop_1
+          }
+        }
+        %5:i32 = load %idx
+        %6:i32 = let %5
+        continue  # -> $B4
+      }
+      $B4: {  # continuing
+        %7:i32 = load %idx
+        %8:i32 = add %7, 1i
+        store %idx, %8
+        next_iteration  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* info = analysis.GetInfo(let);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+    const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+    EXPECT_EQ(0, range.min_bound);
+    EXPECT_EQ(9, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetWithBinary) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+    local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({local_invocation_id});
+
+    Let* let = nullptr;
+    b.Append(func->Block(), [&] {
+        // access_x: [0, 3]
+        auto* access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+        // access_y: [0, 2]
+        auto* access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+        auto* add = b.Add<u32>(access_x, access_y);
+        let = b.Let(add);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:u32 = access %local_id, 0u
+    %4:u32 = access %local_id, 1u
+    %5:u32 = add %3, %4
+    %6:u32 = let %5
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* info = analysis.GetInfo(let);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+    const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+    EXPECT_EQ(0u, range.min_bound);
+    EXPECT_EQ(5u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LetAsOperand) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());
+    local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({local_invocation_id});
+
+    Let* let = nullptr;
+    Binary* add = nullptr;
+    b.Append(func->Block(), [&] {
+        // let (access_x): [0, 3]
+        let = b.Let(b.Access(ty.u32(), local_invocation_id, 0_u));
+        // access_y: [0, 2]
+        auto* access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+        add = b.Add<u32>(let, access_y);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%local_id:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:u32 = access %local_id, 0u
+    %4:u32 = let %3
+    %5:u32 = access %local_id, 1u
+    %6:u32 = add %4, %5
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* info_let = analysis.GetInfo(let);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_let->range));
+    const auto& range_let = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_let->range);
+    EXPECT_EQ(0u, range_let.min_bound);
+    EXPECT_EQ(3u, range_let.max_bound);
+
+    auto* info_add = analysis.GetInfo(add);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range));
+    const auto& range_add = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info_add->range);
+    EXPECT_EQ(0u, range_add.min_bound);
+    EXPECT_EQ(5u, range_add.max_bound);
+}
+
 TEST_F(IR_IntegerRangeAnalysisTest, MultipleBinaryAdds) {
     auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
     auto* local_invocation_id = b.FunctionParam("local_id", mod.Types().vec3<u32>());