Range Analysis: Compute range of `Access` from local invocation ID

This patch computes the range of all the `Access` instructions that
accesses the built-in function parameter local invocation ID and
returns nullptr when computing the range of any other `Access`
instructions.

Bug: 348701956
Test: tint_unittests
Change-Id: Idf0f40858b5a6b4542407792936277195616a39e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/239484
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 56306a3..b6401f1 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -30,6 +30,7 @@
 #include <limits>
 
 #include "src/tint/lang/core/constant/scalar.h"
+#include "src/tint/lang/core/ir/access.h"
 #include "src/tint/lang/core/ir/binary.h"
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/exit_loop.h"
@@ -105,41 +106,50 @@
             return nullptr;
         }
 
+        // Currently we only support the query of ranges of `local_invocation_id` or
+        // `local_invocation_index`.
+        if (!param->Builtin()) {
+            return nullptr;
+        }
+
+        switch (*param->Builtin()) {
+            case BuiltinValue::kLocalInvocationIndex:
+            case BuiltinValue::kLocalInvocationId:
+                break;
+            default:
+                return nullptr;
+        }
+
         const auto& info = integer_function_param_range_info_map_.GetOrAdd(
             param, [&]() -> Vector<IntegerRangeInfo, 3> {
-                if (param->Builtin() == core::BuiltinValue::kLocalInvocationIndex) {
-                    // We shouldn't be trying to use range analysis on a module that has
-                    // non-constant workgroup sizes, since we will always have replaced pipeline
-                    // overrides with constant values early in the pipeline.
-                    TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
-                    std::array<uint32_t, 3> workgroup_size =
-                        function_->WorkgroupSizeAsConst().value();
-                    uint64_t max_bound =
-                        workgroup_size[0] * workgroup_size[1] * workgroup_size[2] - 1u;
-                    constexpr uint64_t kMinBound = 0;
+                switch (*param->Builtin()) {
+                    case BuiltinValue::kLocalInvocationIndex: {
+                        // We shouldn't be trying to use range analysis on a module that has
+                        // non-constant workgroup sizes, since we will always have replaced pipeline
+                        // overrides with constant values early in the pipeline.
+                        TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
+                        std::array<uint32_t, 3> workgroup_size =
+                            function_->WorkgroupSizeAsConst().value();
+                        uint64_t max_bound =
+                            workgroup_size[0] * workgroup_size[1] * workgroup_size[2] - 1u;
+                        constexpr uint64_t kMinBound = 0;
 
-                    return {IntegerRangeInfo(kMinBound, max_bound)};
-                }
-
-                if (param->Builtin() == core::BuiltinValue::kLocalInvocationId) {
-                    TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
-                    std::array<uint32_t, 3> workgroup_size =
-                        function_->WorkgroupSizeAsConst().value();
-
-                    constexpr uint64_t kMinBound = 0;
-                    Vector<IntegerRangeInfo, 3> integerRanges;
-                    for (uint32_t size_x_y_z : workgroup_size) {
-                        integerRanges.Push({kMinBound, size_x_y_z - 1u});
+                        return {IntegerRangeInfo(kMinBound, max_bound)};
                     }
-                    return integerRanges;
-                }
+                    case BuiltinValue::kLocalInvocationId: {
+                        TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
+                        std::array<uint32_t, 3> workgroup_size =
+                            function_->WorkgroupSizeAsConst().value();
 
-                if (param->Type()->IsUnsignedIntegerScalar()) {
-                    return {IntegerRangeInfo(0, std::numeric_limits<uint64_t>::max())};
-                } else {
-                    TINT_ASSERT(param->Type()->IsSignedIntegerScalar());
-                    return {IntegerRangeInfo(std::numeric_limits<int64_t>::min(),
-                                             std::numeric_limits<int64_t>::max())};
+                        constexpr uint64_t kMinBound = 0;
+                        Vector<IntegerRangeInfo, 3> integerRanges;
+                        for (uint32_t size_x_y_z : workgroup_size) {
+                            integerRanges.Push({kMinBound, size_x_y_z - 1u});
+                        }
+                        return integerRanges;
+                    }
+                    default:
+                        TINT_UNREACHABLE();
                 }
             });
 
@@ -163,6 +173,26 @@
         return GetInfo(load_from_var);
     }
 
+    const IntegerRangeInfo* GetInfo(const Access* access) {
+        const Value* obj = access->Object();
+
+        // Currently we only support the access to `local_invocation_id` or `local_invocation_index`
+        // as a function parameter.
+        const FunctionParam* function_param = obj->As<FunctionParam>();
+        if (!function_param) {
+            return nullptr;
+        }
+        if (access->Indices().Length() > 1) {
+            return nullptr;
+        }
+        if (!access->Indices()[0]->As<Constant>()) {
+            return nullptr;
+        }
+        uint32_t index =
+            static_cast<uint32_t>(GetValueFromConstant(access->Indices()[0]->As<Constant>()));
+        return GetInfo(function_param, index);
+    }
+
     /// Analyze a loop to compute the range of the loop control variable if possible.
     void AnalyzeLoop(const Loop* loop) {
         const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
@@ -682,4 +712,8 @@
     return impl_->GetInfo(load);
 }
 
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Access* access) {
+    return impl_->GetInfo(access);
+}
+
 }  // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index b6805a4..9c14509 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -33,6 +33,7 @@
 #include <variant>
 
 namespace tint::core::ir {
+class Access;
 class Binary;
 class Function;
 class FunctionParam;
@@ -93,6 +94,10 @@
     /// it has a meaningful range. Returns nullptr otherwise.
     const IntegerRangeInfo* GetInfo(const Load* load_var);
 
+    /// Returns the integer range info of a given `Access` variable if it is an integer variable and
+    /// it has a meaningful range. Returns nullptr otherwise.
+    const IntegerRangeInfo* GetInfo(const Access* access);
+
     /// Note: This function is only for tests.
     /// Returns the pointer of the loop control variable in the given loop when its initializer
     /// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 7b8e872..814193f 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8616,5 +8616,184 @@
     EXPECT_EQ(nullptr, analysis.GetInfo(load_a));
 }
 
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToLocalInvocationID) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* local_invocation_id = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+    local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({local_invocation_id});
+
+    Access* access_x = nullptr;
+    Access* access_y = nullptr;
+    Access* access_z = nullptr;
+    b.Append(func->Block(), [&] {
+        access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+        access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+        access_z = b.Access(ty.u32(), local_invocation_id, 2_u);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:u32 = access %localInvocationId, 0u
+    %4:u32 = access %localInvocationId, 1u
+    %5:u32 = access %localInvocationId, 2u
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* access_x_info = analysis.GetInfo(access_x);
+    ASSERT_NE(nullptr, access_x_info);
+    ASSERT_TRUE(
+        std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_x_info->range));
+    const auto& range_access_x =
+        std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_x_info->range);
+    EXPECT_EQ(0u, range_access_x.min_bound);
+    EXPECT_EQ(3u, range_access_x.max_bound);
+
+    auto* access_y_info = analysis.GetInfo(access_y);
+    ASSERT_NE(nullptr, access_x_info);
+    ASSERT_TRUE(
+        std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_y_info->range));
+    const auto& range_access_y =
+        std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_y_info->range);
+    EXPECT_EQ(0u, range_access_y.min_bound);
+    EXPECT_EQ(2u, range_access_y.max_bound);
+
+    auto* access_z_info = analysis.GetInfo(access_z);
+    ASSERT_NE(nullptr, access_x_info);
+    ASSERT_TRUE(
+        std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_z_info->range));
+    const auto& range_access_z =
+        std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_z_info->range);
+    EXPECT_EQ(0u, range_access_z.min_bound);
+    EXPECT_EQ(1u, range_access_z.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NotAccessToFunctionParam) {
+    auto* func = b.Function("func", ty.void_());
+    Access* access = nullptr;
+    b.Append(func->Block(), [&] {
+        auto* dst = b.Var(ty.ptr<function, array<u32, 24u>>());
+        access = b.Access(ty.ptr<function, u32>(), dst, 0_u);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func():void {
+  $B1: {
+    %2:ptr<function, array<u32, 24>, read_write> = var undef
+    %3:ptr<function, u32, read_write> = access %2, 0u
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    auto* info = analysis.GetInfo(access);
+    ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToFunctionParamNoRange) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* global_invocation_id = b.FunctionParam("globalId", mod.Types().vec3<u32>());
+    global_invocation_id->SetBuiltin(tint::core::BuiltinValue::kGlobalInvocationId);
+    func->SetParams({global_invocation_id});
+
+    Access* access_x = nullptr;
+    Access* access_y = nullptr;
+    Access* access_z = nullptr;
+    b.Append(func->Block(), [&] {
+        access_x = b.Access(ty.u32(), global_invocation_id, 0_u);
+        access_y = b.Access(ty.u32(), global_invocation_id, 1_u);
+        access_z = b.Access(ty.u32(), global_invocation_id, 2_u);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%globalId:vec3<u32> [@global_invocation_id]):void {
+  $B1: {
+    %3:u32 = access %globalId, 0u
+    %4:u32 = access %globalId, 1u
+    %5:u32 = access %globalId, 2u
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    ASSERT_EQ(nullptr, analysis.GetInfo(access_x));
+    ASSERT_EQ(nullptr, analysis.GetInfo(access_y));
+    ASSERT_EQ(nullptr, analysis.GetInfo(access_z));
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToNonIntegerFunctionParam) {
+    auto* func = b.Function("func", ty.void_());
+    Access* access = nullptr;
+    auto* param = b.FunctionParam("param", ty.vec4<f32>());
+    func->SetParams({param});
+    b.Append(func->Block(), [&] {
+        access = b.Access(ty.f32(), param, 0_u);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%func = func(%param:vec4<f32>):void {
+  $B1: {
+    %3:f32 = access %param, 0u
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+    auto* info = analysis.GetInfo(access);
+    ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NonConstantAccessIndex) {
+    auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+    auto* local_invocation_id = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+    local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+    func->SetParams({local_invocation_id});
+
+    Access* access_x = nullptr;
+    b.Append(func->Block(), [&] {
+        auto* var = b.Var(ty.ptr<function, u32>());
+        auto* index = b.Load(var);
+        access_x = b.Access(ty.u32(), local_invocation_id, index);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+  $B1: {
+    %3:ptr<function, u32, read_write> = var undef
+    %4:u32 = load %3
+    %5:u32 = access %localInvocationId, %4
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+    EXPECT_EQ(Validate(mod), Success);
+
+    IntegerRangeAnalysis analysis(func);
+
+    auto* access_x_info = analysis.GetInfo(access_x);
+    ASSERT_EQ(nullptr, access_x_info);
+}
+
 }  // namespace
 }  // namespace tint::core::ir::analysis