Range Analysis: Compute range of a `Convert` expression

This patch implements the computation of the range of a `Convert`
expression when it is converting between `i32` and `u32`. A `nullptr`
will be returned when any overflow or underflow happens.

Bug: 348701956
Test: tint_unittests
Change-Id: Iab00f2065508f7fcf32f7c7bd8846694a465b34c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/243875
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 6e50258..b70c95e 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -32,6 +32,7 @@
 #include "src/tint/lang/core/constant/scalar.h"
 #include "src/tint/lang/core/ir/access.h"
 #include "src/tint/lang/core/ir/binary.h"
+#include "src/tint/lang/core/ir/convert.h"
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/exit_loop.h"
 #include "src/tint/lang/core/ir/function.h"
@@ -177,6 +178,57 @@
         return &range_info;
     }
 
+    const IntegerRangeInfo* GetInfo(const Convert* convert) {
+        const IntegerRangeInfo* existing_range = integer_convert_range_info_map_.Get(convert).value;
+        if (existing_range) {
+            return existing_range;
+        }
+
+        auto* result_type = convert->Result()->Type();
+        if (!result_type->IsIntegerScalar()) {
+            return nullptr;
+        }
+
+        const auto* operand = convert->Operand(Convert::kValueOperandOffset);
+        const IntegerRangeInfo* operand_range_info = GetInfo(operand);
+        if (!operand_range_info) {
+            return nullptr;
+        }
+        auto* operand_type = operand->Type();
+        TINT_ASSERT(operand_type->IsIntegerScalar());
+
+        if (operand_type == result_type) {
+            return operand_range_info;
+        }
+
+        if (std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(
+                operand_range_info->range)) {
+            // result = convert<u32>(operand), operand cannot be negative.
+            TINT_ASSERT(result_type->As<type::U32>());
+            const auto& range =
+                std::get<IntegerRangeInfo::SignedIntegerRange>(operand_range_info->range);
+            if (range.min_bound < 0) {
+                return nullptr;
+            }
+            auto result = integer_convert_range_info_map_.Add(
+                convert, IntegerRangeInfo(static_cast<uint64_t>(range.min_bound),
+                                          static_cast<uint64_t>(range.max_bound)));
+            return &result.value;
+        } else {
+            // result = convert<i32>(operand), operand cannot be greater than `i32::kHighestValue`.
+            TINT_ASSERT(result_type->As<type::I32>());
+            const auto& range =
+                std::get<IntegerRangeInfo::UnsignedIntegerRange>(operand_range_info->range);
+            if (range.max_bound > i32::kHighestValue) {
+                return nullptr;
+            }
+            auto result = integer_convert_range_info_map_.Add(
+                convert, IntegerRangeInfo(static_cast<int64_t>(range.min_bound),
+                                          static_cast<int64_t>(range.max_bound)));
+            return &result.value;
+        }
+    }
+
     const IntegerRangeInfo* GetInfo(const Value* value) {
         return Switch(
             value, [&](const Constant* constant) { return GetInfo(constant); },
@@ -187,13 +239,14 @@
                 return GetInfo(param, 0);
             },
             [&](const InstructionResult* r) {
-                // TODO(348701956): Support more instruction types. e.g. Convert, ...
+                // TODO(348701956): Support more instruction types
                 return Switch(
                     r->Instruction(), [&](const Var* var) { return GetInfo(var); },
                     [&](const Load* load) { return GetInfo(load); },
                     [&](const Access* access) { return GetInfo(access); },
                     [&](const Let* let) { return GetInfo(let); },
                     [&](const Binary* binary) { return GetInfo(binary); },
+                    [&](const Convert* convert) { return GetInfo(convert); },
                     [](Default) { return nullptr; });
             },
             [](Default) { return nullptr; });
@@ -944,6 +997,7 @@
     Hashmap<const Var*, IntegerRangeInfo, 8> integer_var_range_info_map_;
     Hashmap<const Constant*, IntegerRangeInfo, 8> integer_constant_range_info_map_;
     Hashmap<const Binary*, IntegerRangeInfo, 8> integer_binary_range_info_map_;
+    Hashmap<const Convert*, IntegerRangeInfo, 8> integer_convert_range_info_map_;
 };
 
 IntegerRangeAnalysis::IntegerRangeAnalysis(Module* ir_module)
@@ -1001,4 +1055,8 @@
     return impl_->GetInfo(let);
 }
 
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Convert* convert) {
+    return impl_->GetInfo(convert);
+}
+
 }  // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index cfff143..b4be69c 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -36,6 +36,7 @@
 class Access;
 class Binary;
 class Constant;
+class Convert;
 class Function;
 class FunctionParam;
 class Let;
@@ -119,6 +120,10 @@
     /// it has a meaningful range. Returns nullptr otherwise.
     const IntegerRangeInfo* GetInfo(const Binary* binary);
 
+    /// Returns the integer range info of a given `Convert` variable if it is an integer variable
+    /// and it has a meaningful range. Returns nullptr otherwise.
+    const IntegerRangeInfo* GetInfo(const Convert* convert);
+
     /// Note: This function is only for tests.
     /// Returns the pointer of the loop control variable in the given loop when its initializer
     /// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index dd6d8bc..b697ce5 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -11028,5 +11028,328 @@
     ASSERT_EQ(nullptr, info);
 }
 
+TEST_F(IR_IntegerRangeAnalysisTest, Convert_Success_U32ToI32) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* localInvocationIndex = b.FunctionParam("localInvocationIndex", mod.Types().u32());
+    localInvocationIndex->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationIndex);
+    func->SetParams({localInvocationIndex});
+
+    Convert* convert = nullptr;
+    b.Append(func->Block(), [&] {
+        convert = b.Convert<i32>(localInvocationIndex);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationIndex:u32 [@local_invocation_index]):void {
+  $B1: {
+    %3:i32 = convert %localInvocationIndex
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(&mod);
+    auto* info = analysis.GetInfo(convert);
+    ASSERT_NE(nullptr, info);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+
+    const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+    EXPECT_EQ(0, range.min_bound);
+    EXPECT_EQ(23, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, Convert_Success_I32ToU32) {
+    Var* idx = nullptr;
+    Convert* convert = nullptr;
+    auto* func = b.Function("func", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {
+            // idx = 1
+            idx = b.Var("idx", 1_i);
+            b.NextIteration(loop);
+        });
+        b.Append(loop->Body(), [&] {
+            // idx < 10
+            auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+            auto* ifelse = b.If(binary);
+            b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+            b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+            convert = b.Convert<u32>(b.Load(idx));
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            // idx++
+            b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+            b.NextIteration(loop);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3, c: $B4] {  # loop_1
+      $B2: {  # initializer
+        %idx:ptr<function, i32, read_write> = var 1i
+        next_iteration  # -> $B3
+      }
+      $B3: {  # body
+        %3:i32 = load %idx
+        %4:bool = lt %3, 10i
+        if %4 [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_loop  # loop_1
+          }
+        }
+        %5:i32 = load %idx
+        %6:u32 = convert %5
+        continue  # -> $B4
+      }
+      $B4: {  # continuing
+        %7:i32 = load %idx
+        %8:i32 = add %7, 1i
+        store %idx, %8
+        next_iteration  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(&mod);
+
+    auto* info = analysis.GetInfo(convert);
+    ASSERT_NE(nullptr, info);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+
+    const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+    EXPECT_EQ(1u, range.min_bound);
+    EXPECT_EQ(9u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, Convert_Failure_NegativeI32ToU32) {
+    Var* idx = nullptr;
+    Convert* convert = nullptr;
+    auto* func = b.Function("func", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {
+            // idx = -1
+            idx = b.Var("idx", -1_i);
+            b.NextIteration(loop);
+        });
+        b.Append(loop->Body(), [&] {
+            // idx < 10
+            auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+            auto* ifelse = b.If(binary);
+            b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+            b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+            convert = b.Convert<u32>(b.Load(idx));
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            // idx++
+            b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+            b.NextIteration(loop);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3, c: $B4] {  # loop_1
+      $B2: {  # initializer
+        %idx:ptr<function, i32, read_write> = var -1i
+        next_iteration  # -> $B3
+      }
+      $B3: {  # body
+        %3:i32 = load %idx
+        %4:bool = lt %3, 10i
+        if %4 [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_loop  # loop_1
+          }
+        }
+        %5:i32 = load %idx
+        %6:u32 = convert %5
+        continue  # -> $B4
+      }
+      $B4: {  # continuing
+        %7:i32 = load %idx
+        %8:i32 = add %7, 1i
+        store %idx, %8
+        next_iteration  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(&mod);
+
+    auto* info = analysis.GetInfo(convert);
+    ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, Convert_Failure_LargeU32ToI32) {
+    Var* idx = nullptr;
+    Convert* convert = nullptr;
+    auto* func = b.Function("func", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {
+            // idx = 0u
+            idx = b.Var("idx", 0_u);
+            b.NextIteration(loop);
+        });
+        b.Append(loop->Body(), [&] {
+            // idx <= u32(i32::kHighestValue) + 1u
+            auto* binary = b.LessThanEqual<bool>(
+                b.Load(idx), u32(static_cast<uint32_t>(i32::kHighestValue) + 1u));
+            auto* ifelse = b.If(binary);
+            b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+            b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+            convert = b.Convert<i32>(b.Load(idx));
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            // idx++
+            b.Store(idx, b.Add<u32>(b.Load(idx), 1_u));
+            b.NextIteration(loop);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3, c: $B4] {  # loop_1
+      $B2: {  # initializer
+        %idx:ptr<function, u32, read_write> = var 0u
+        next_iteration  # -> $B3
+      }
+      $B3: {  # body
+        %3:u32 = load %idx
+        %4:bool = lte %3, 2147483648u
+        if %4 [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_loop  # loop_1
+          }
+        }
+        %5:u32 = load %idx
+        %6:i32 = convert %5
+        continue  # -> $B4
+      }
+      $B4: {  # continuing
+        %7:u32 = load %idx
+        %8:u32 = add %7, 1u
+        store %idx, %8
+        next_iteration  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(&mod);
+
+    auto* info = analysis.GetInfo(convert);
+    ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, Convert_Failure_ConvertToNonInteger) {
+    Var* idx = nullptr;
+    Convert* convert = nullptr;
+    auto* func = b.Function("func", ty.void_());
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {
+            // idx = -1
+            idx = b.Var("idx", -1_i);
+            b.NextIteration(loop);
+        });
+        b.Append(loop->Body(), [&] {
+            // idx < 10
+            auto* binary = b.LessThan<bool>(b.Load(idx), 10_i);
+            auto* ifelse = b.If(binary);
+            b.Append(ifelse->True(), [&] { b.ExitIf(ifelse); });
+            b.Append(ifelse->False(), [&] { b.ExitLoop(loop); });
+            convert = b.Convert<f32>(b.Load(idx));
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {
+            // idx++
+            b.Store(idx, b.Add<i32>(b.Load(idx), 1_i));
+            b.NextIteration(loop);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    loop [i: $B2, b: $B3, c: $B4] {  # loop_1
+      $B2: {  # initializer
+        %idx:ptr<function, i32, read_write> = var -1i
+        next_iteration  # -> $B3
+      }
+      $B3: {  # body
+        %3:i32 = load %idx
+        %4:bool = lt %3, 10i
+        if %4 [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_loop  # loop_1
+          }
+        }
+        %5:i32 = load %idx
+        %6:f32 = convert %5
+        continue  # -> $B4
+      }
+      $B4: {  # continuing
+        %7:i32 = load %idx
+        %8:i32 = add %7, 1i
+        store %idx, %8
+        next_iteration  # -> $B3
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(&mod);
+
+    auto* info = analysis.GetInfo(convert);
+    ASSERT_EQ(nullptr, info);
+}
+
 }  // namespace
 }  // namespace tint::core::ir::analysis