Tint: Implement range analysis on `LocalInvocationId`

This patch adds the basic infrastructure of integer range analysis
and implements the computation of the range on `LocalInvocationId`.

Bug: chromium:348701956
Test: tint_unittests
Change-Id: I8536727f78e2d127b78a3db7f32c034aa6bb125e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/220454
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index ed339f5..0a0e635 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -46,13 +46,13 @@
 struct IntegerRangeAnalysisImpl {
     explicit IntegerRangeAnalysisImpl(Function* func) : function_(func) {}
 
-    const IntegerRangeInfo* GetInfo(const FunctionParam* param) {
-        if (!param->Type()->IsIntegerScalar()) {
+    const IntegerRangeInfo* GetInfo(const FunctionParam* param, uint32_t index) {
+        if (!param->Type()->IsIntegerScalarOrVector()) {
             return nullptr;
         }
 
-        const auto& info =
-            integer_function_param_range_info_map_.GetOrAdd(param, [&]() -> IntegerRangeInfo {
+        const auto& info = integer_function_param_range_info_map_.GetOrAdd(
+            param, [&]() -> Vector<IntegerRangeInfo, 3> {
                 if (param->Builtin() == core::BuiltinValue::kLocalInvocationIndex) {
                     // We shouldn't be trying to use range analysis on a module that has
                     // non-constant workgroup sizes, since we will always have replaced pipeline
@@ -63,31 +63,47 @@
                     uint64_t max_bound =
                         workgroup_size[0] * workgroup_size[1] * workgroup_size[2] - 1u;
                     constexpr uint64_t kMinBound = 0;
-                    return IntegerRangeInfo(kMinBound, max_bound);
+
+                    return {IntegerRangeInfo(kMinBound, max_bound)};
+                }
+
+                if (param->Builtin() == core::BuiltinValue::kLocalInvocationId) {
+                    TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
+                    std::array<uint32_t, 3> workgroup_size =
+                        function_->WorkgroupSizeAsConst().value();
+
+                    constexpr uint64_t kMinBound = 0;
+                    Vector<IntegerRangeInfo, 3> integerRanges;
+                    for (uint32_t size_x_y_z : workgroup_size) {
+                        integerRanges.Push({kMinBound, size_x_y_z - 1u});
+                    }
+                    return integerRanges;
                 }
 
                 if (param->Type()->IsUnsignedIntegerScalar()) {
-                    return IntegerRangeInfo(0, std::numeric_limits<uint64_t>::max());
+                    return {IntegerRangeInfo(0, std::numeric_limits<uint64_t>::max())};
                 } else {
                     TINT_ASSERT(param->Type()->IsSignedIntegerScalar());
-                    return IntegerRangeInfo(std::numeric_limits<int64_t>::min(),
-                                            std::numeric_limits<int64_t>::max());
+                    return {IntegerRangeInfo(std::numeric_limits<int64_t>::min(),
+                                             std::numeric_limits<int64_t>::max())};
                 }
             });
 
-        return &info;
+        TINT_ASSERT(info.Length() > index);
+        return &info[index];
     }
 
   private:
     Function* function_;
-    Hashmap<const FunctionParam*, IntegerRangeInfo, 4> integer_function_param_range_info_map_;
+    Hashmap<const FunctionParam*, Vector<IntegerRangeInfo, 3>, 4>
+        integer_function_param_range_info_map_;
 };
 
 IntegerRangeAnalysis::IntegerRangeAnalysis(Function* func)
     : impl_(new IntegerRangeAnalysisImpl(func)) {}
 IntegerRangeAnalysis::~IntegerRangeAnalysis() = default;
 
-const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const FunctionParam* param) {
-    return impl_->GetInfo(param);
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const FunctionParam* param, uint32_t index) {
+    return impl_->GetInfo(param, index);
 }
 }  // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index 03bd633..b8574b8 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -65,12 +65,16 @@
     explicit IntegerRangeAnalysis(ir::Function* func);
     ~IntegerRangeAnalysis();
 
-    /// Returns the integer range info of a given parameter, if it is an integer parameter.
+    /// Returns the integer range info of a given parameter with given index, if it is an integer
+    /// or an integer vector parameter. The index must not be over the maximum size of the vector
+    /// and must be 0 if the parameter is an integer.
     /// Otherwise is not analyzable and returns nullptr. If it is the first time to query the info,
     /// the result will also be stored into a cache for future queries.
     /// @param param the variable to get information about
+    /// @param index the vector component index when the parameter is a vector type. if the
+    /// parameter is a scalar, then `index` must be zero.
     /// @returns the integer range info
-    const IntegerRangeInfo* GetInfo(const FunctionParam* param);
+    const IntegerRangeInfo* GetInfo(const FunctionParam* param, uint32_t index = 0);
 
   private:
     IntegerRangeAnalysis(const IntegerRangeAnalysis&) = delete;
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index c01ba58..14d9022 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -184,5 +184,150 @@
     EXPECT_EQ(15u, range.max_bound);
 }
 
+TEST_F(IR_IntegerRangeAnalysisTest, LocalInvocationID_u32_XYZ) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* localInvocationId = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+    localInvocationId->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({localInvocationId});
+
+    b.Append(func->Block(), [&] {
+        auto* dst_x = b.Var(ty.ptr<function, array<u32, 4u>>());
+        auto* access_src_x = b.Access(ty.u32(), localInvocationId, 0_u);
+        auto* access_dst_x = b.Access(ty.ptr<function, u32>(), dst_x, access_src_x);
+        b.Store(access_dst_x, access_src_x);
+        auto* dst_y = b.Var(ty.ptr<function, array<u32, 3u>>());
+        auto* access_src_y = b.Access(ty.u32(), localInvocationId, 1_u);
+        auto* access_dst_y = b.Access(ty.ptr<function, u32>(), dst_y, access_src_y);
+        b.Store(access_dst_y, access_src_y);
+        auto* dst_z = b.Var(ty.ptr<function, array<u32, 2u>>());
+        auto* access_src_z = b.Access(ty.u32(), localInvocationId, 2_u);
+        auto* access_dst_z = b.Access(ty.ptr<function, u32>(), dst_z, access_src_z);
+        b.Store(access_dst_z, access_src_z);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:ptr<function, array<u32, 4>, read_write> = var
+    %4:u32 = access %localInvocationId, 0u
+    %5:ptr<function, u32, read_write> = access %3, %4
+    store %5, %4
+    %6:ptr<function, array<u32, 3>, read_write> = var
+    %7:u32 = access %localInvocationId, 1u
+    %8:ptr<function, u32, read_write> = access %6, %7
+    store %8, %7
+    %9:ptr<function, array<u32, 2>, read_write> = var
+    %10:u32 = access %localInvocationId, 2u
+    %11:ptr<function, u32, read_write> = access %9, %10
+    store %11, %10
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    std::array<uint32_t, 3> expected_max_bounds = {3u, 2u, 1u};
+    for (uint32_t i = 0; i < expected_max_bounds.size(); ++i) {
+        auto* info = analysis.GetInfo(localInvocationId, i);
+
+        ASSERT_NE(nullptr, info);
+        ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+
+        const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+        EXPECT_EQ(0u, range.min_bound);
+        EXPECT_EQ(expected_max_bounds[i], range.max_bound);
+    }
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LocalInvocationID_u32_1_Y_1) {
+    auto* func = b.ComputeFunction("my_func", 1_u, 8_u, 1_u);
+    auto* localInvocationId = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+    localInvocationId->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({localInvocationId});
+
+    b.Append(func->Block(), [&] {
+        auto* dst = b.Var(ty.ptr<function, array<u32, 8u>>());
+        auto* access_src = b.Access(ty.u32(), localInvocationId, 1_u);
+        auto* access_dst = b.Access(ty.ptr<function, u32>(), dst, access_src);
+        b.Store(access_dst, access_src);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(1u, 8u, 1u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:ptr<function, array<u32, 8>, read_write> = var
+    %4:u32 = access %localInvocationId, 1u
+    %5:ptr<function, u32, read_write> = access %3, %4
+    store %5, %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    std::array<uint32_t, 3> expected_max_bounds = {0u, 7u, 0u};
+    for (uint32_t i = 0; i < expected_max_bounds.size(); ++i) {
+        auto* info = analysis.GetInfo(localInvocationId, i);
+
+        ASSERT_NE(nullptr, info);
+        ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+
+        const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+        EXPECT_EQ(0u, range.min_bound);
+        EXPECT_EQ(expected_max_bounds[i], range.max_bound);
+    }
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, LocalInvocationID_u32_1_1_Z) {
+    auto* func = b.ComputeFunction("my_func", 1_u, 1_u, 16_u);
+    auto* localInvocationId = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+    localInvocationId->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({localInvocationId});
+
+    b.Append(func->Block(), [&] {
+        auto* dst = b.Var(ty.ptr<function, array<u32, 16u>>());
+        auto* access_src = b.Access(ty.u32(), localInvocationId, 2_u);
+        auto* access_dst = b.Access(ty.ptr<function, u32>(), dst, access_src);
+        b.Store(access_dst, access_src);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(1u, 1u, 16u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:ptr<function, array<u32, 16>, read_write> = var
+    %4:u32 = access %localInvocationId, 2u
+    %5:ptr<function, u32, read_write> = access %3, %4
+    store %5, %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    std::array<uint32_t, 3> expected_max_bounds = {0u, 0u, 15u};
+    for (uint32_t i = 0; i < expected_max_bounds.size(); ++i) {
+        auto* info = analysis.GetInfo(localInvocationId, i);
+
+        ASSERT_NE(nullptr, info);
+        ASSERT_TRUE(std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(info->range));
+
+        const auto& range = std::get<IntegerRangeInfo::UnsignedIntegerRange>(info->range);
+        EXPECT_EQ(0u, range.min_bound);
+        EXPECT_EQ(expected_max_bounds[i], range.max_bound);
+    }
+}
+
 }  // namespace
 }  // namespace tint::core::ir::analysis