Range Analysis: Compute range of `Access` from local invocation ID
This patch computes the range of all the `Access` instructions that
accesses the built-in function parameter local invocation ID and
returns nullptr when computing the range of any other `Access`
instructions.
Bug: 348701956
Test: tint_unittests
Change-Id: Idf0f40858b5a6b4542407792936277195616a39e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/239484
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
index 56306a3..b6401f1 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.cc
@@ -30,6 +30,7 @@
#include <limits>
#include "src/tint/lang/core/constant/scalar.h"
+#include "src/tint/lang/core/ir/access.h"
#include "src/tint/lang/core/ir/binary.h"
#include "src/tint/lang/core/ir/exit_if.h"
#include "src/tint/lang/core/ir/exit_loop.h"
@@ -105,41 +106,50 @@
return nullptr;
}
+ // Currently we only support the query of ranges of `local_invocation_id` or
+ // `local_invocation_index`.
+ if (!param->Builtin()) {
+ return nullptr;
+ }
+
+ switch (*param->Builtin()) {
+ case BuiltinValue::kLocalInvocationIndex:
+ case BuiltinValue::kLocalInvocationId:
+ break;
+ default:
+ return nullptr;
+ }
+
const auto& info = integer_function_param_range_info_map_.GetOrAdd(
param, [&]() -> Vector<IntegerRangeInfo, 3> {
- if (param->Builtin() == core::BuiltinValue::kLocalInvocationIndex) {
- // We shouldn't be trying to use range analysis on a module that has
- // non-constant workgroup sizes, since we will always have replaced pipeline
- // overrides with constant values early in the pipeline.
- TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
- std::array<uint32_t, 3> workgroup_size =
- function_->WorkgroupSizeAsConst().value();
- uint64_t max_bound =
- workgroup_size[0] * workgroup_size[1] * workgroup_size[2] - 1u;
- constexpr uint64_t kMinBound = 0;
+ switch (*param->Builtin()) {
+ case BuiltinValue::kLocalInvocationIndex: {
+ // We shouldn't be trying to use range analysis on a module that has
+ // non-constant workgroup sizes, since we will always have replaced pipeline
+ // overrides with constant values early in the pipeline.
+ TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
+ std::array<uint32_t, 3> workgroup_size =
+ function_->WorkgroupSizeAsConst().value();
+ uint64_t max_bound =
+ workgroup_size[0] * workgroup_size[1] * workgroup_size[2] - 1u;
+ constexpr uint64_t kMinBound = 0;
- return {IntegerRangeInfo(kMinBound, max_bound)};
- }
-
- if (param->Builtin() == core::BuiltinValue::kLocalInvocationId) {
- TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
- std::array<uint32_t, 3> workgroup_size =
- function_->WorkgroupSizeAsConst().value();
-
- constexpr uint64_t kMinBound = 0;
- Vector<IntegerRangeInfo, 3> integerRanges;
- for (uint32_t size_x_y_z : workgroup_size) {
- integerRanges.Push({kMinBound, size_x_y_z - 1u});
+ return {IntegerRangeInfo(kMinBound, max_bound)};
}
- return integerRanges;
- }
+ case BuiltinValue::kLocalInvocationId: {
+ TINT_ASSERT(function_->WorkgroupSizeAsConst().has_value());
+ std::array<uint32_t, 3> workgroup_size =
+ function_->WorkgroupSizeAsConst().value();
- if (param->Type()->IsUnsignedIntegerScalar()) {
- return {IntegerRangeInfo(0, std::numeric_limits<uint64_t>::max())};
- } else {
- TINT_ASSERT(param->Type()->IsSignedIntegerScalar());
- return {IntegerRangeInfo(std::numeric_limits<int64_t>::min(),
- std::numeric_limits<int64_t>::max())};
+ constexpr uint64_t kMinBound = 0;
+ Vector<IntegerRangeInfo, 3> integerRanges;
+ for (uint32_t size_x_y_z : workgroup_size) {
+ integerRanges.Push({kMinBound, size_x_y_z - 1u});
+ }
+ return integerRanges;
+ }
+ default:
+ TINT_UNREACHABLE();
}
});
@@ -163,6 +173,26 @@
return GetInfo(load_from_var);
}
+ const IntegerRangeInfo* GetInfo(const Access* access) {
+ const Value* obj = access->Object();
+
+ // Currently we only support the access to `local_invocation_id` or `local_invocation_index`
+ // as a function parameter.
+ const FunctionParam* function_param = obj->As<FunctionParam>();
+ if (!function_param) {
+ return nullptr;
+ }
+ if (access->Indices().Length() > 1) {
+ return nullptr;
+ }
+ if (!access->Indices()[0]->As<Constant>()) {
+ return nullptr;
+ }
+ uint32_t index =
+ static_cast<uint32_t>(GetValueFromConstant(access->Indices()[0]->As<Constant>()));
+ return GetInfo(function_param, index);
+ }
+
/// Analyze a loop to compute the range of the loop control variable if possible.
void AnalyzeLoop(const Loop* loop) {
const Var* index = GetLoopControlVariableFromConstantInitializer(loop);
@@ -682,4 +712,8 @@
return impl_->GetInfo(load);
}
+const IntegerRangeInfo* IntegerRangeAnalysis::GetInfo(const Access* access) {
+ return impl_->GetInfo(access);
+}
+
} // namespace tint::core::ir::analysis
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis.h b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
index b6805a4..9c14509 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis.h
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis.h
@@ -33,6 +33,7 @@
#include <variant>
namespace tint::core::ir {
+class Access;
class Binary;
class Function;
class FunctionParam;
@@ -93,6 +94,10 @@
/// it has a meaningful range. Returns nullptr otherwise.
const IntegerRangeInfo* GetInfo(const Load* load_var);
+ /// Returns the integer range info of a given `Access` variable if it is an integer variable and
+ /// it has a meaningful range. Returns nullptr otherwise.
+ const IntegerRangeInfo* GetInfo(const Access* access);
+
/// Note: This function is only for tests.
/// Returns the pointer of the loop control variable in the given loop when its initializer
/// meets the below requirements.
diff --git a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
index 7b8e872..814193f 100644
--- a/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
+++ b/src/tint/lang/core/ir/analysis/integer_range_analysis_test.cc
@@ -8616,5 +8616,184 @@
EXPECT_EQ(nullptr, analysis.GetInfo(load_a));
}
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToLocalInvocationID) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Access* access_x = nullptr;
+ Access* access_y = nullptr;
+ Access* access_z = nullptr;
+ b.Append(func->Block(), [&] {
+ access_x = b.Access(ty.u32(), local_invocation_id, 0_u);
+ access_y = b.Access(ty.u32(), local_invocation_id, 1_u);
+ access_z = b.Access(ty.u32(), local_invocation_id, 2_u);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %localInvocationId, 0u
+ %4:u32 = access %localInvocationId, 1u
+ %5:u32 = access %localInvocationId, 2u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* access_x_info = analysis.GetInfo(access_x);
+ ASSERT_NE(nullptr, access_x_info);
+ ASSERT_TRUE(
+ std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_x_info->range));
+ const auto& range_access_x =
+ std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_x_info->range);
+ EXPECT_EQ(0u, range_access_x.min_bound);
+ EXPECT_EQ(3u, range_access_x.max_bound);
+
+ auto* access_y_info = analysis.GetInfo(access_y);
+ ASSERT_NE(nullptr, access_x_info);
+ ASSERT_TRUE(
+ std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_y_info->range));
+ const auto& range_access_y =
+ std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_y_info->range);
+ EXPECT_EQ(0u, range_access_y.min_bound);
+ EXPECT_EQ(2u, range_access_y.max_bound);
+
+ auto* access_z_info = analysis.GetInfo(access_z);
+ ASSERT_NE(nullptr, access_x_info);
+ ASSERT_TRUE(
+ std::holds_alternative<IntegerRangeInfo::UnsignedIntegerRange>(access_z_info->range));
+ const auto& range_access_z =
+ std::get<IntegerRangeInfo::UnsignedIntegerRange>(access_z_info->range);
+ EXPECT_EQ(0u, range_access_z.min_bound);
+ EXPECT_EQ(1u, range_access_z.max_bound);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NotAccessToFunctionParam) {
+ auto* func = b.Function("func", ty.void_());
+ Access* access = nullptr;
+ b.Append(func->Block(), [&] {
+ auto* dst = b.Var(ty.ptr<function, array<u32, 24u>>());
+ access = b.Access(ty.ptr<function, u32>(), dst, 0_u);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func():void {
+ $B1: {
+ %2:ptr<function, array<u32, 24>, read_write> = var undef
+ %3:ptr<function, u32, read_write> = access %2, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ auto* info = analysis.GetInfo(access);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToFunctionParamNoRange) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* global_invocation_id = b.FunctionParam("globalId", mod.Types().vec3<u32>());
+ global_invocation_id->SetBuiltin(tint::core::BuiltinValue::kGlobalInvocationId);
+ func->SetParams({global_invocation_id});
+
+ Access* access_x = nullptr;
+ Access* access_y = nullptr;
+ Access* access_z = nullptr;
+ b.Append(func->Block(), [&] {
+ access_x = b.Access(ty.u32(), global_invocation_id, 0_u);
+ access_y = b.Access(ty.u32(), global_invocation_id, 1_u);
+ access_z = b.Access(ty.u32(), global_invocation_id, 2_u);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%globalId:vec3<u32> [@global_invocation_id]):void {
+ $B1: {
+ %3:u32 = access %globalId, 0u
+ %4:u32 = access %globalId, 1u
+ %5:u32 = access %globalId, 2u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ ASSERT_EQ(nullptr, analysis.GetInfo(access_x));
+ ASSERT_EQ(nullptr, analysis.GetInfo(access_y));
+ ASSERT_EQ(nullptr, analysis.GetInfo(access_z));
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, AccessToNonIntegerFunctionParam) {
+ auto* func = b.Function("func", ty.void_());
+ Access* access = nullptr;
+ auto* param = b.FunctionParam("param", ty.vec4<f32>());
+ func->SetParams({param});
+ b.Append(func->Block(), [&] {
+ access = b.Access(ty.f32(), param, 0_u);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%func = func(%param:vec4<f32>):void {
+ $B1: {
+ %3:f32 = access %param, 0u
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+ auto* info = analysis.GetInfo(access);
+ ASSERT_EQ(nullptr, info);
+}
+
+TEST_F(IR_IntegerRangeAnalysisTest, NonConstantAccessIndex) {
+ auto* func = b.ComputeFunction("my_func", 4_u, 3_u, 2_u);
+ auto* local_invocation_id = b.FunctionParam("localInvocationId", mod.Types().vec3<u32>());
+ local_invocation_id->SetBuiltin(tint::core::BuiltinValue::kLocalInvocationId);
+ func->SetParams({local_invocation_id});
+
+ Access* access_x = nullptr;
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var(ty.ptr<function, u32>());
+ auto* index = b.Load(var);
+ access_x = b.Access(ty.u32(), local_invocation_id, index);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%my_func = @compute @workgroup_size(4u, 3u, 2u) func(%localInvocationId:vec3<u32> [@local_invocation_id]):void {
+ $B1: {
+ %3:ptr<function, u32, read_write> = var undef
+ %4:u32 = load %3
+ %5:u32 = access %localInvocationId, %4
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ EXPECT_EQ(Validate(mod), Success);
+
+ IntegerRangeAnalysis analysis(func);
+
+ auto* access_x_info = analysis.GetInfo(access_x);
+ ASSERT_EQ(nullptr, access_x_info);
+}
+
} // namespace
} // namespace tint::core::ir::analysis