Range Analysis: Compute range of a scalar integer `Constant`

This patch computes the range of an integer scalar `Constant` and
returns an `IntegerRangeInfo` object so that we can always compute
the valid integer ranges among `IntegerRangeInfo` objects.

Bug: 348701956
Test: tint_unittests
Change-Id: If4dd9902bdba7bec59e2674ecde36f563f2fb7d9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/240675
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index b6401f1..51d7f90 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -83,6 +83,16 @@
     const Binary* binary = nullptr;
 };
 
+IntegerRangeInfo ToIntegerRangeInfo(const Constant* constant,
+                                    int64_t min_value,
+                                    int64_t max_value) {
+    if (constant->Type()->IsSignedIntegerScalar()) {
+        return IntegerRangeInfo(min_value, max_value);
+    } else {
+        return IntegerRangeInfo(static_cast<uint64_t>(min_value), static_cast<uint64_t>(max_value));
+    }
+}
+
 }  // namespace
 
 IntegerRangeInfo::IntegerRangeInfo(int64_t min_bound, int64_t max_bound) {
@@ -193,6 +203,18 @@
         return GetInfo(function_param, index);
     }
 
+    const IntegerRangeInfo* GetInfo(const Constant* constant) {
+        if (!IsConstantInteger(constant)) {
+            return nullptr;
+        }
+        const IntegerRangeInfo& range_info =
+            integer_constant_range_info_map_.GetOrAdd(constant, [&]() -> IntegerRangeInfo {
+                int64_t const_value = GetValueFromConstant(constant);
+                return ToIntegerRangeInfo(constant, const_value, const_value);
+            });
+        return &range_info;
+    }
+
     /// Analyze a loop to compute the range of the loop control variable if possible.
     void AnalyzeLoop(const Loop* loop) {
         const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
@@ -213,19 +235,12 @@
         TINT_ASSERT(index->Initializer()->As<Constant>());
 
         // for (var i = const_init; ...)
-        int64_t const_init = GetValueFromConstant(index->Initializer()->As<Constant>());
+        const Constant* constant_initializer = index->Initializer()->As<Constant>();
+        int64_t const_init = GetValueFromConstant(constant_initializer);
 
         // for (...; i++) or for(...; i--)
         bool index_is_increasing = update->Op() == BinaryOp::kAdd;
 
-        auto to_integer_range_info = [&](int64_t min_value, int64_t max_value) {
-            if (index->Initializer()->As<Constant>()->Type()->IsSignedIntegerScalar()) {
-                return IntegerRangeInfo(min_value, max_value);
-            } else {
-                return IntegerRangeInfo(static_cast<uint64_t>(min_value),
-                                        static_cast<uint64_t>(max_value));
-            }
-        };
         switch (compare_info.op) {
             case BinaryOp::kLessThanEqual: {
                 // for (var index = const_init; index <= const_rhs; index++)
@@ -238,8 +253,8 @@
                 // - `index <= const_rhs` can correctly exit when `const_init + 1` is the maximum
                 //   value of `i32` or `u32`.
                 if (index_is_increasing && const_init <= compare_info.const_rhs) {
-                    IntegerRangeInfo range_info =
-                        to_integer_range_info(const_init, compare_info.const_rhs);
+                    IntegerRangeInfo range_info = ToIntegerRangeInfo(
+                        constant_initializer, const_init, compare_info.const_rhs);
                     integer_var_range_info_map_.Add(index, range_info);
                 }
                 break;
@@ -254,8 +269,8 @@
                 // - `index >= const_rhs` can correctly exit when `const_init - 1` is the minimum
                 //   value of `i32` or `u32`.
                 if (!index_is_increasing && const_init >= compare_info.const_rhs) {
-                    IntegerRangeInfo range_info =
-                        to_integer_range_info(compare_info.const_rhs, const_init);
+                    IntegerRangeInfo range_info = ToIntegerRangeInfo(
+                        constant_initializer, compare_info.const_rhs, const_init);
                     integer_var_range_info_map_.Add(index, range_info);
                 }
                 break;
@@ -676,6 +691,7 @@
     Hashmap<const FunctionParam*, Vector<IntegerRangeInfo, 3>, 4>
         integer_function_param_range_info_map_;
     Hashmap<const Var*, IntegerRangeInfo, 8> integer_var_range_info_map_;
+    Hashmap<const Constant*, IntegerRangeInfo, 8> integer_constant_range_info_map_;
 };
 
 IntegerRangeAnalysis::IntegerRangeAnalysis(Function* func)
@@ -716,4 +732,8 @@
     return impl_->GetInfo(access);
 }
 
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Constant* constant) {
+    return impl_->GetInfo(constant);
+}
+
 }  // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index 9c14509..d92f77c 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -35,6 +35,7 @@
 namespace tint::core::ir {
 class Access;
 class Binary;
+class Constant;
 class Function;
 class FunctionParam;
 class Load;
@@ -98,6 +99,10 @@
     /// it has a meaningful range. Returns nullptr otherwise.
     const IntegerRangeInfo* GetInfo(const Access* access);
 
+    /// Returns the integer range info of a given `Constant` if it is an integer.
+    /// Returns nullptr otherwise.
+    const IntegerRangeInfo* GetInfo(const Constant* constant);
+
     /// Note: This function is only for tests.
     /// Returns the pointer of the loop control variable in the given loop when its initializer
     /// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 814193f..0979ecb 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8795,5 +8795,88 @@
     ASSERT_EQ(nullptr, access_x_info);
 }
 
+TEST_F(IR_IntegerRangeAnalysisTest, SignedIntegerScalarConstant) {
+    auto* func = b.Function("func", ty.void_());
+    Constant* constant = nullptr;
+    b.Append(func->Block(), [&] {
+        constant = b.Constant(10_i);
+        b.Var("a", constant);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    %a:ptr<function, i32, read_write> = var 10i
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    auto* info = analysis.GetInfo(constant);
+    ASSERT_NE(nullptr, info);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::SignedIntegerRange>(info->range));
+    const auto& range = std::get<IntegerRangeInfo::SignedIntegerRange>(info->range);
+    EXPECT_EQ(10, range.min_bound);
+    EXPECT_EQ(10, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, UnsignedIntegerScalarConstant) {
+    auto* func = b.Function("func", ty.void_());
+    Constant* constant = nullptr;
+    b.Append(func->Block(), [&] {
+        constant = b.Constant(20_u);
+        b.Var("a", constant);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    %a:ptr<function, u32, read_write> = var 20u
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    auto* info = analysis.GetInfo(constant);
+    ASSERT_NE(nullptr, info);
+    ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+    const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+    EXPECT_EQ(20u, range.min_bound);
+    EXPECT_EQ(20u, range.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NonIntegerConstant) {
+    auto* func = b.Function("func", ty.void_());
+    Constant* constant = nullptr;
+    b.Append(func->Block(), [&] {
+        constant = b.Constant(1.0_f);
+        b.Var("a", constant);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    %a:ptr<function, f32, read_write> = var 1.0f
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    auto* info = analysis.GetInfo(constant);
+    ASSERT_EQ(nullptr, info);
+}
+
 }  // namespace
 }  // namespace tint::core::ir::analysis